Unverified commit feda9b11, authored by Mick, committed by GitHub

fix: fix one more bug from merging mm_inputs (#5718)


Co-authored-by: Xinyuan Tong <justinning0323@outlook.com>
Co-authored-by: XinyuanTong <115166877+JustinTong0323@users.noreply.github.com>
parent c3948ba6
@@ -1040,15 +1040,18 @@ class MRotaryEmbedding(RotaryEmbedding):
         mrope_position_delta: int,
         context_len: int,
         seq_len: int,
-    ) -> List[List[int]]:
-        return [
-            list(
-                range(
-                    context_len + mrope_position_delta, seq_len + mrope_position_delta
-                )
-            )
-            for _ in range(3)
-        ]
+    ) -> torch.Tensor:
+        return torch.tensor(
+            [
+                list(
+                    range(
+                        context_len + mrope_position_delta,
+                        seq_len + mrope_position_delta,
+                    )
+                )
+                for _ in range(3)
+            ]
+        )


 _ROPE_DICT: Dict[Tuple, RotaryEmbedding] = {}
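
With this change `get_next_input_positions` returns a `[3, seq_len - context_len]` tensor (one row per M-RoPE section) instead of three Python lists, so callers can combine per-delta results directly with `torch.cat`. A minimal standalone sketch of the new behavior, with the method body inlined for illustration:

    import torch

    def get_next_input_positions(
        mrope_position_delta: int, context_len: int, seq_len: int
    ) -> torch.Tensor:
        # Identical position ids for all three rotary sections,
        # shifted by the request's M-RoPE delta.
        return torch.tensor(
            [
                list(range(context_len + mrope_position_delta, seq_len + mrope_position_delta))
                for _ in range(3)
            ]
        )

    # One decode step: context_len = seq_len - 1.
    positions = get_next_input_positions(mrope_position_delta=5, context_len=9, seq_len=10)
    print(positions.shape)  # torch.Size([3, 1])
    print(positions)        # tensor([[14], [14], [14]])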
@@ -351,7 +351,6 @@ class MultimodalInputs:
         optional_args = [
             "mm_items",
             "image_pad_len",
-            "mrope_position_delta",
         ]
         for arg in optional_args:
             self_arg = getattr(self, arg, None)
@@ -367,6 +366,14 @@ class MultimodalInputs:
                 [self.mrope_positions, other.mrope_positions], dim=1
             )
+        mrope_position_delta = self.mrope_position_delta
+        if mrope_position_delta is not None:
+            if other.mrope_position_delta is None:
+                self.mrope_position_delta = mrope_position_delta
+            else:
+                self.mrope_position_delta = torch.cat(
+                    [self.mrope_position_delta, other.mrope_position_delta], dim=0
+                )

         # other args would be kept intact
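
`merge` now handles `mrope_position_delta` explicitly instead of relying on the generic optional-args loop: tensor deltas from two requests are stacked along dim 0 so the merged input keeps one delta row per multimodal item. A hedged sketch of the intended shape behavior; the `[n_items, 1]` layout is an assumption inferred from the `dim=0` concatenation:

    import torch

    # Request A carries one image delta, request B two; each delta is a row.
    self_delta = torch.tensor([[3]])
    other_delta = torch.tensor([[7], [11]])

    merged_delta = torch.cat([self_delta, other_delta], dim=0)
    print(merged_delta.shape)  # torch.Size([3, 1]) -- one row per image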
@@ -1455,7 +1462,6 @@ class ScheduleBatch(ScheduleBatchDisaggregationDecodeMixin):
         if self.model_config.is_encoder_decoder:
             self.encoder_lens = torch.cat([self.encoder_lens, other.encoder_lens])
             self.encoder_lens_cpu.extend(other.encoder_lens_cpu)
-
         self.req_pool_indices = torch.cat(
             [self.req_pool_indices, other.req_pool_indices]
         )
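
For reference, `merge_batch`-style merging joins per-request tensors by simple concatenation; a tiny standalone sketch with made-up indices:

    import torch

    # Two running batches, each holding row indices into the request pool.
    self_req_pool_indices = torch.tensor([0, 1])
    other_req_pool_indices = torch.tensor([4, 5, 6])

    # Merging batches just stacks their per-request entries.
    merged = torch.cat([self_req_pool_indices, other_req_pool_indices])
    print(merged)  # tensor([0, 1, 4, 5, 6])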
@@ -38,7 +38,7 @@
 import triton
 import triton.language as tl
 from sglang.srt.layers.rotary_embedding import MRotaryEmbedding
-from sglang.srt.utils import get_compiler_backend
+from sglang.srt.utils import flatten_nested_list, get_compiler_backend

 if TYPE_CHECKING:
     from sglang.srt.layers.attention.base_attn_backend import AttentionBackend
@@ -364,23 +364,23 @@ class ForwardBatch:
     def merge_mm_inputs(self) -> Optional[MultimodalInputs]:
         """
-        Merge all image inputs in the batch into a single MultiModalInputs object.
+        Merge all multimodal inputs in the batch into a single MultiModalInputs object.

         Returns:
-            if none, current batch contains no image input
+            if none, current batch contains no multimodal input
         """
         if not self.mm_inputs or all(x is None for x in self.mm_inputs):
             return None

         # Filter out None values
         valid_inputs = [x for x in self.mm_inputs if x is not None]

-        # Start with the first valid image input
-        merged = valid_inputs[0]
+        # TODO: is it expensive?
+        # a workaround to avoid importing `MultimodalInputs`
+        merged = valid_inputs[0].__class__(mm_items=[])

         # Merge remaining inputs
-        for mm_input in valid_inputs[1:]:
+        for mm_input in valid_inputs:
             merged.merge(mm_input)

         return merged
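
The rewritten loop merges every valid input into a fresh, empty instance (`valid_inputs[0].__class__(mm_items=[])`) rather than using the first input itself as the accumulator, so the first request's `MultimodalInputs` is never mutated in place. A minimal sketch of the pattern, with a hypothetical `Box` class standing in for `MultimodalInputs`:

    class Box:
        # Hypothetical stand-in for MultimodalInputs: one mergeable list field.
        def __init__(self, mm_items):
            self.mm_items = mm_items

        def merge(self, other):
            self.mm_items.extend(other.mm_items)

    valid_inputs = [Box(["img_a"]), Box(["img_b"])]

    # Merge into an empty accumulator of the same class; the inputs,
    # which may be shared with live requests, are left untouched.
    merged = valid_inputs[0].__class__(mm_items=[])
    for mm_input in valid_inputs:
        merged.merge(mm_input)

    print(merged.mm_items)           # ['img_a', 'img_b']
    print(valid_inputs[0].mm_items)  # ['img_a'] -- unchanged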
@@ -407,26 +407,34 @@ class ForwardBatch:
     def _compute_mrope_positions(
         self, model_runner: ModelRunner, batch: ModelWorkerBatch
     ):
-        mrope_positions_list = [None] * self.seq_lens.shape[0]
-        if self.forward_mode.is_decode():
-            for i, _ in enumerate(mrope_positions_list):
-                mrope_position_delta = (
-                    0
-                    if batch.multimodal_inputs[i] is None
-                    else batch.multimodal_inputs[i].mrope_position_delta
-                )
-                mrope_positions_list[i] = torch.tensor(
-                    MRotaryEmbedding.get_next_input_positions(
-                        mrope_position_delta,
-                        int(self.seq_lens[i]) - 1,
-                        int(self.seq_lens[i]),
-                    )
-                )
-        elif self.forward_mode.is_extend():
-            for i, mm_input in enumerate(batch.multimodal_inputs):
+        # batch_size * [3 * seq_len]
+        batch_size = self.seq_lens.shape[0]
+        mrope_positions_list = [[]] * batch_size
+        for batch_idx in range(batch_size):
+            mm_input = batch.multimodal_inputs[batch_idx]
+            if self.forward_mode.is_decode():
+                mrope_position_deltas = (
+                    [0]
+                    if mm_input is None
+                    else flatten_nested_list(mm_input.mrope_position_delta.tolist())
+                )
+                next_input_positions = []
+                for mrope_position_delta in mrope_position_deltas:
+                    # batched deltas need to be processed separately
+                    # Convert list of lists to tensor with shape [3, seq_len]
+                    next_input_positions += [
+                        MRotaryEmbedding.get_next_input_positions(
+                            mrope_position_delta,
+                            int(self.seq_lens[batch_idx]) - 1,
+                            int(self.seq_lens[batch_idx]),
+                        )
+                    ]
+                # 3 * N
+                mrope_positions_list[batch_idx] = torch.cat(next_input_positions, dim=1)
+            elif self.forward_mode.is_extend():
                 extend_seq_len, extend_prefix_len = (
-                    batch.extend_seq_lens[i],
-                    batch.extend_prefix_lens[i],
+                    batch.extend_seq_lens[batch_idx],
+                    batch.extend_prefix_lens[batch_idx],
                 )
                 if mm_input is None:
                     # text only
@@ -447,13 +455,12 @@ class ForwardBatch:
                         :,
                         extend_prefix_len : extend_prefix_len + extend_seq_len,
                     ]
-                mrope_positions_list[i] = mrope_positions
+                mrope_positions_list[batch_idx] = mrope_positions

         self.mrope_positions = torch.cat(
             [pos.to(device=model_runner.device) for pos in mrope_positions_list],
             dim=1,
-        ).to(device=model_runner.device)
-        self.mrope_positions = self.mrope_positions.to(torch.int64)
+        ).to(dtype=torch.int64, device=model_runner.device)

     def get_max_chunk_capacity(self):
         # Maximum number of tokens in each chunk
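
The decode path now supports multiple deltas per request (one per multimodal item after merging): each delta produces a `[3, 1]` position block for the next token, and the blocks are concatenated along dim 1 into the `3 * N` layout noted in the comment. A standalone sketch, with `flatten_nested_list` replaced by a simplified stand-in and `get_next_input_positions` inlined:

    import torch

    def flatten_nested_list(nested):
        # Simplified stand-in for sglang.srt.utils.flatten_nested_list.
        if isinstance(nested, list):
            return [x for item in nested for x in flatten_nested_list(item)]
        return [nested]

    def get_next_input_positions(delta, context_len, seq_len):
        return torch.tensor(
            [list(range(context_len + delta, seq_len + delta)) for _ in range(3)]
        )

    # A request at seq_len 12 whose merged mm_inputs carry two per-image deltas.
    mrope_position_delta = torch.tensor([[4], [9]])
    deltas = flatten_nested_list(mrope_position_delta.tolist())  # [4, 9]

    next_input_positions = [
        get_next_input_positions(d, 12 - 1, 12) for d in deltas  # each [3, 1]
    ]
    positions = torch.cat(next_input_positions, dim=1)
    print(positions.shape)  # torch.Size([3, 2])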
@@ -307,7 +307,6 @@ class TestOpenAIVisionServer(CustomTestCase):
         self.assertGreater(len(video_response), 0)

     def test_regex(self):
-        return
         client = openai.Client(api_key=self.api_key, base_url=self.base_url)

         regex = (