Unverified commit 952fbe47, authored by Yueyang Pan and committed by GitHub
Browse files

fix: fix the bug that causes qwen2_5_vl to crash with mixed_chunk (#11330)


Signed-off-by: PanJason <pyyjason@gmail.com>
Co-authored-by: Xinyuan Tong <115166877+JustinTong0323@users.noreply.github.com>
Co-authored-by: Yuan Luo <yuan.luo@hotmail.com>
parent edb25693
......@@ -557,6 +557,25 @@ class ForwardBatch:
self.mrope_positions = next_input_positions
def _expand_mrope_from_input(
    self,
    mm_input: MultimodalInputs,
    seq_len: int,
    device: torch.device,
) -> torch.Tensor:
    """Build mrope positions for one request from its cached position delta.

    Every element of ``mm_input.mrope_position_delta`` is offset by
    ``seq_len - 1`` and the resulting row is repeated three times.

    Args:
        mm_input: Object exposing a ``mrope_position_delta`` tensor. As a
            side effect the attribute is replaced by a copy on ``device``
            when it currently lives elsewhere, so later calls skip the
            transfer.
        seq_len: Value added (minus one) to every delta element.
        device: Target device; a ``torch.device`` or a device string such
            as ``"cuda"`` / ``"cuda:0"`` are both accepted.

    Returns:
        A tensor of shape ``(3, N)`` where ``N`` is the number of elements
        in ``mrope_position_delta``.
    """
    # Normalize so the guard below works for torch.device objects and for
    # indexed strings ("cuda:0"); comparing ``device.type`` (a str) with
    # either of those directly is always unequal.
    target = torch.device(device)
    delta = mm_input.mrope_position_delta
    current = delta.device
    on_target = current.type == target.type and (
        target.index is None or current.index == target.index
    )
    if not on_target:
        # Transfer mrope_position_delta to the device on first use only,
        # avoiding successive host-to-device data transfers.
        delta = delta.to(device, non_blocking=True)
        mm_input.mrope_position_delta = delta

    # One row of offsets, replicated across three rows
    # (presumably the three mrope position axes -- TODO confirm).
    offsets = delta.flatten() + seq_len - 1
    return offsets.unsqueeze(0).repeat(3, 1)
def _compute_mrope_positions(
self, model_runner: ModelRunner, batch: ModelWorkerBatch
):
......@@ -575,20 +594,10 @@ class ForwardBatch:
device=model_runner.device,
)
else:
if mm_input.mrope_position_delta.device.type != model_runner.device:
# transfer mrope_position_delta to device when the first running,
# avoiding successvie host-to-device data transfer
mm_input.mrope_position_delta = (
mm_input.mrope_position_delta.to(
model_runner.device, non_blocking=True
)
)
mrope_position_deltas = mm_input.mrope_position_delta.flatten()
mrope_positions_list[batch_idx] = (
(mrope_position_deltas + self.seq_lens[batch_idx] - 1)
.unsqueeze(0)
.repeat(3, 1)
mrope_positions = self._expand_mrope_from_input(
mm_input, self.seq_lens[batch_idx], model_runner.device
)
mrope_positions_list[batch_idx] = mrope_positions
elif self.forward_mode.is_extend():
extend_seq_len, extend_prefix_len = (
batch.extend_seq_lens[batch_idx],
......@@ -613,6 +622,10 @@ class ForwardBatch:
:,
extend_prefix_len : extend_prefix_len + extend_seq_len,
]
if mrope_positions.numel() == 0:
mrope_positions = self._expand_mrope_from_input(
mm_input, self.seq_lens[batch_idx], model_runner.device
)
mrope_positions_list[batch_idx] = mrope_positions
self.mrope_positions = torch.cat(
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment