Unverified commit feda9b11, authored by Mick, committed by GitHub

fix: fix one more bug from merging mm_inputs (#5718)


Co-authored-by: Xinyuan Tong <justinning0323@outlook.com>
Co-authored-by: XinyuanTong <115166877+JustinTong0323@users.noreply.github.com>
parent c3948ba6
@@ -1040,15 +1040,18 @@ class MRotaryEmbedding(RotaryEmbedding):
         mrope_position_delta: int,
         context_len: int,
         seq_len: int,
-    ) -> List[List[int]]:
-        return [
-            list(
-                range(
-                    context_len + mrope_position_delta, seq_len + mrope_position_delta
-                )
-            )
-            for _ in range(3)
-        ]
+    ) -> torch.Tensor:
+        return torch.tensor(
+            [
+                list(
+                    range(
+                        context_len + mrope_position_delta,
+                        seq_len + mrope_position_delta,
+                    )
+                )
+                for _ in range(3)
+            ]
+        )
 
 
 _ROPE_DICT: Dict[Tuple, RotaryEmbedding] = {}
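With this change, get_next_input_positions returns a [3, seq_len - context_len] tensor (one row per M-RoPE section) instead of three Python lists, so callers can combine per-delta blocks with torch.cat. A minimal standalone sketch of the equivalent computation, not the sglang source itself:

import torch

def next_input_positions(delta: int, context_len: int, seq_len: int) -> torch.Tensor:
    # One identical row per M-RoPE section (temporal, height, width),
    # shifted by the per-request position delta.
    row = torch.arange(context_len + delta, seq_len + delta)
    return row.unsqueeze(0).repeat(3, 1)  # shape [3, seq_len - context_len]

print(next_input_positions(delta=4, context_len=9, seq_len=10))
# -> a [3, 1] tensor whose every entry is 13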
@@ -351,7 +351,6 @@ class MultimodalInputs:
         optional_args = [
             "mm_items",
             "image_pad_len",
-            "mrope_position_delta",
         ]
         for arg in optional_args:
             self_arg = getattr(self, arg, None)
@@ -367,6 +366,14 @@ class MultimodalInputs:
             [self.mrope_positions, other.mrope_positions], dim=1
         )
 
+        mrope_position_delta = self.mrope_position_delta
+        if mrope_position_delta is not None:
+            if other.mrope_position_delta is None:
+                self.mrope_position_delta = mrope_position_delta
+            else:
+                self.mrope_position_delta = torch.cat(
+                    [self.mrope_position_delta, other.mrope_position_delta], dim=0
+                )
         # other args would be kept intact
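mrope_position_delta is dropped from the generic optional_args merge above because it is not a list to extend but a per-request tensor of deltas that must be concatenated along the batch dimension, and left alone when either side has none. A rough illustration of the intended semantics, using a hypothetical merge_deltas helper rather than the actual method:

import torch
from typing import Optional

def merge_deltas(
    ours: Optional[torch.Tensor], theirs: Optional[torch.Tensor]
) -> Optional[torch.Tensor]:
    # Mirrors the commit: if our side has no deltas nothing is merged,
    # otherwise stack both along dim 0 so each request keeps its rows.
    if ours is None or theirs is None:
        return ours
    return torch.cat([ours, theirs], dim=0)

print(merge_deltas(torch.tensor([[4]]), torch.tensor([[7]])))
# -> a [2, 1] tensor: both requests' deltas stacked along dim 0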
@@ -1455,7 +1462,6 @@ class ScheduleBatch(ScheduleBatchDisaggregationDecodeMixin):
         if self.model_config.is_encoder_decoder:
             self.encoder_lens = torch.cat([self.encoder_lens, other.encoder_lens])
             self.encoder_lens_cpu.extend(other.encoder_lens_cpu)
-
         self.req_pool_indices = torch.cat(
             [self.req_pool_indices, other.req_pool_indices]
         )
@@ -38,7 +38,7 @@
 import triton
 import triton.language as tl
 from sglang.srt.layers.rotary_embedding import MRotaryEmbedding
-from sglang.srt.utils import get_compiler_backend
+from sglang.srt.utils import flatten_nested_list, get_compiler_backend
 
 if TYPE_CHECKING:
     from sglang.srt.layers.attention.base_attn_backend import AttentionBackend
@@ -364,23 +364,23 @@ class ForwardBatch:
     def merge_mm_inputs(self) -> Optional[MultimodalInputs]:
         """
-        Merge all image inputs in the batch into a single MultiModalInputs object.
+        Merge all multimodal inputs in the batch into a single MultiModalInputs object.
 
         Returns:
-            if none, current batch contains no image input
+            if none, current batch contains no multimodal input
         """
         if not self.mm_inputs or all(x is None for x in self.mm_inputs):
             return None
 
         # Filter out None values
         valid_inputs = [x for x in self.mm_inputs if x is not None]
 
-        # Start with the first valid image input
-        merged = valid_inputs[0]
+        # TODO: is it expensive?
+        # a workaround to avoid importing `MultimodalInputs`
+        merged = valid_inputs[0].__class__(mm_items=[])
 
         # Merge remaining inputs
-        for mm_input in valid_inputs[1:]:
+        for mm_input in valid_inputs:
             merged.merge(mm_input)
 
         return merged
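This hunk is the core fix: merging used to start from the first request's own MultimodalInputs object, so merged.merge(...) mutated that request's state in place. Building a fresh empty instance via valid_inputs[0].__class__(mm_items=[]) and merging every valid input into it (including the first) leaves the per-request inputs intact. A toy reproduction of the aliasing hazard, using a hypothetical Inputs class rather than the sglang one:

from dataclasses import dataclass, field

@dataclass
class Inputs:
    mm_items: list = field(default_factory=list)

    def merge(self, other: "Inputs") -> None:
        self.mm_items.extend(other.mm_items)

a, b = Inputs(mm_items=["img0"]), Inputs(mm_items=["img1"])

# Buggy pattern: 'merged' aliases 'a', so merging corrupts request a's state.
merged = a
merged.merge(b)
assert a.mm_items == ["img0", "img1"]  # request a was silently mutated

# Fixed pattern: start from a fresh instance and merge every input into it.
a, b = Inputs(mm_items=["img0"]), Inputs(mm_items=["img1"])
merged = a.__class__(mm_items=[])
for x in (a, b):
    merged.merge(x)
assert a.mm_items == ["img0"]  # request a is untouched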
@@ -407,26 +407,34 @@ class ForwardBatch:
     def _compute_mrope_positions(
         self, model_runner: ModelRunner, batch: ModelWorkerBatch
     ):
-        mrope_positions_list = [None] * self.seq_lens.shape[0]
-        if self.forward_mode.is_decode():
-            for i, _ in enumerate(mrope_positions_list):
-                mrope_position_delta = (
-                    0
-                    if batch.multimodal_inputs[i] is None
-                    else batch.multimodal_inputs[i].mrope_position_delta
-                )
-                mrope_positions_list[i] = torch.tensor(
-                    MRotaryEmbedding.get_next_input_positions(
-                        mrope_position_delta,
-                        int(self.seq_lens[i]) - 1,
-                        int(self.seq_lens[i]),
-                    )
-                )
-        elif self.forward_mode.is_extend():
-            for i, mm_input in enumerate(batch.multimodal_inputs):
+        # batch_size * [3 * seq_len]
+        batch_size = self.seq_lens.shape[0]
+        mrope_positions_list = [[]] * batch_size
+        for batch_idx in range(batch_size):
+            mm_input = batch.multimodal_inputs[batch_idx]
+            if self.forward_mode.is_decode():
+                mrope_position_deltas = (
+                    [0]
+                    if mm_input is None
+                    else flatten_nested_list(mm_input.mrope_position_delta.tolist())
+                )
+                next_input_positions = []
+                for mrope_position_delta in mrope_position_deltas:
+                    # batched deltas needs to be processed separately
+                    # Convert list of lists to tensor with shape [3, seq_len]
+                    next_input_positions += [
+                        MRotaryEmbedding.get_next_input_positions(
+                            mrope_position_delta,
+                            int(self.seq_lens[batch_idx]) - 1,
+                            int(self.seq_lens[batch_idx]),
+                        )
+                    ]
+                # 3 * N
+                mrope_positions_list[batch_idx] = torch.cat(next_input_positions, dim=1)
+            elif self.forward_mode.is_extend():
                 extend_seq_len, extend_prefix_len = (
-                    batch.extend_seq_lens[i],
-                    batch.extend_prefix_lens[i],
+                    batch.extend_seq_lens[batch_idx],
+                    batch.extend_prefix_lens[batch_idx],
                 )
                 if mm_input is None:
                     # text only
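On the decode path, the per-request delta is no longer assumed to be a single scalar: after batches merge, mm_input.mrope_position_delta can carry several deltas, so the code flattens it with flatten_nested_list and emits one [3, 1] position block per delta, concatenated along dim 1. A standalone sketch of that shape logic; the flatten_nested helper here is an assumption about flatten_nested_list's behavior, not the sglang implementation:

import torch

def flatten_nested(items):
    # Assumed behavior of sglang's flatten_nested_list: depth-first flatten.
    out = []
    for x in items:
        if isinstance(x, list):
            out.extend(flatten_nested(x))
        else:
            out.append(x)
    return out

def decode_positions(deltas, seq_len: int) -> torch.Tensor:
    # One [3, 1] block per delta for the single new decode token,
    # concatenated along dim 1 into a [3, num_deltas] tensor.
    blocks = [
        torch.arange(seq_len - 1 + d, seq_len + d).unsqueeze(0).repeat(3, 1)
        for d in flatten_nested(deltas)
    ]
    return torch.cat(blocks, dim=1)

print(decode_positions([[4], [7]], seq_len=10).shape)  # torch.Size([3, 2])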
@@ -447,13 +455,12 @@
                     :,
                     extend_prefix_len : extend_prefix_len + extend_seq_len,
                 ]
-                mrope_positions_list[i] = mrope_positions
+                mrope_positions_list[batch_idx] = mrope_positions
 
         self.mrope_positions = torch.cat(
             [pos.to(device=model_runner.device) for pos in mrope_positions_list],
             dim=1,
-        ).to(device=model_runner.device)
-        self.mrope_positions = self.mrope_positions.to(torch.int64)
+        ).to(dtype=torch.int64, device=model_runner.device)
 
     def get_max_chunk_capacity(self):
         # Maximum number of tokens in each chunk
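The final cast is also folded into one call: torch.Tensor.to accepts dtype and device together, so the old two-step pattern (move to device, then a second .to(torch.int64)) collapses into a single conversion. For example:

import torch

x = torch.arange(4, dtype=torch.float32)
# dtype conversion and device placement happen in one .to(...) call.
y = x.to(dtype=torch.int64, device="cpu")
print(y.dtype, y.device)  # torch.int64 cpu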
@@ -307,7 +307,6 @@ class TestOpenAIVisionServer(CustomTestCase):
         self.assertGreater(len(video_response), 0)
 
     def test_regex(self):
-        return
         client = openai.Client(api_key=self.api_key, base_url=self.base_url)
         regex = (