Unverified Commit a8aad935 authored by yizhang2077, committed by GitHub

qwen2vl fix bug for #1971 #1897 (#1984)

parent 47ffe7af
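
For context when reading the hunks below: the commit stores the per-request M-RoPE position delta on the request's ImageInputs instead of threading it through Req and ModelWorkerBatch, and the decode path falls back to a delta of 0 when a request carries no image inputs. A minimal sketch of that data flow; the class and field names mirror the diff, while the helper name and the concrete values are hypothetical, and the real field is typed Optional[torch.Tensor] rather than the plain int used here to keep the sketch dependency-free:

```python
from dataclasses import dataclass
from typing import Optional


@dataclass
class ImageInputs:
    # New in this commit: the M-RoPE position delta computed during the
    # prefill/extend pass is cached on the per-request image inputs.
    mrope_position_delta: Optional[int] = None


def delta_for_decode(image_inputs: Optional[ImageInputs]) -> int:
    # Mirrors the decode branch in ForwardBatch: requests without image
    # inputs fall back to a delta of 0.
    if image_inputs is None:
        return 0
    return image_inputs.mrope_position_delta


# The prefill/extend pass stores the delta once per request ...
req_images = ImageInputs()
req_images.mrope_position_delta = 3  # hypothetical value from the prefill pass

# ... and every later decode step just reads it back.
assert delta_for_decode(req_images) == 3
assert delta_for_decode(None) == 0
```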
@@ -133,6 +133,7 @@ class ImageInputs:
     aspect_ratio_mask: Optional[List[torch.Tensor]] = None
     # QWen2-VL related
     image_grid_thws: List[Tuple[int, int, int]] = None
+    mrope_position_delta: Optional[torch.Tensor] = None
 
     @staticmethod
     def from_dict(obj, vocab_size):
@@ -251,9 +252,6 @@ class Req:
         # The number of cached tokens, that were already cached in the KV cache
         self.cached_tokens = 0
 
-        # For Qwen2-VL
-        self.mrope_position_delta = []  # use mutable object
-
     # whether request reached finished condition
     def finished(self) -> bool:
         return self.finished_reason is not None
@@ -983,8 +981,6 @@ class ScheduleBatch:
         global bid
         bid += 1
 
-        mrope_positions_delta = [req.mrope_position_delta for req in self.reqs]
-
         return ModelWorkerBatch(
             bid=bid,
             forward_mode=self.forward_mode,
@@ -1007,7 +1003,6 @@ class ScheduleBatch:
             encoder_out_cache_loc=self.encoder_out_cache_loc,
             lora_paths=[req.lora_path for req in self.reqs],
             sampling_info=self.sampling_info,
-            mrope_positions_delta=mrope_positions_delta,
         )
 
     def copy(self):
@@ -1074,9 +1069,6 @@ class ModelWorkerBatch:
     # Sampling info
     sampling_info: SamplingBatchInfo
 
-    # For Qwen2-VL
-    mrope_positions_delta: List[List[int]]
-
     def copy(self):
         return dataclasses.replace(self, sampling_info=self.sampling_info.copy())
......
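
One detail that presumably keeps this relocation working: the ImageInputs objects travel by reference between the scheduler's Req and the worker batches, and ModelWorkerBatch.copy() uses dataclasses.replace, which copies fields shallowly, so a delta written onto an ImageInputs object during the extend forward remains visible to later reads of that same object. A small, self-contained sketch of the shallow-copy behaviour, using hypothetical stand-in classes rather than the real ones:

```python
import dataclasses
from typing import List, Optional


@dataclasses.dataclass
class FakeImageInputs:  # stand-in for ImageInputs
    mrope_position_delta: Optional[int] = None


@dataclasses.dataclass
class FakeWorkerBatch:  # stand-in for ModelWorkerBatch
    image_inputs: List[Optional[FakeImageInputs]]

    def copy(self):
        # Same pattern as ModelWorkerBatch.copy(): a shallow field copy.
        return dataclasses.replace(self)


batch = FakeWorkerBatch(image_inputs=[FakeImageInputs()])
snapshot = batch.copy()

# A delta written through one reference is visible through the other,
# because dataclasses.replace does not deep-copy the list or its elements.
batch.image_inputs[0].mrope_position_delta = 5
print(snapshot.image_inputs[0].mrope_position_delta)  # 5
```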
@@ -136,8 +136,13 @@ class ForwardBatch:
         mrope_positions_list = [None] * self.seq_lens.shape[0]
         if self.forward_mode.is_decode():
             for i, _ in enumerate(mrope_positions_list):
+                mrope_position_delta = (
+                    0
+                    if batch.image_inputs[i] is None
+                    else batch.image_inputs[i].mrope_position_delta
+                )
                 mrope_positions_list[i] = MRotaryEmbedding.get_next_input_positions(
-                    batch.mrope_positions_delta[i][0],
+                    mrope_position_delta,
                     int(self.seq_lens[i]) - 1,
                     int(self.seq_lens[i]),
                 )
@@ -159,7 +164,6 @@ class ForwardBatch:
                             )
                         ]
                     ] * 3
-                    mrope_position_delta = 0
                 else:
                     # TODO: current qwen2-vl do not support radix cache since mrope position calculation
                     mrope_positions, mrope_position_delta = (
@@ -173,8 +177,8 @@ class ForwardBatch:
                             context_len=0,
                         )
                     )
+                    batch.image_inputs[i].mrope_position_delta = mrope_position_delta
                 mrope_positions_list[i] = mrope_positions
-                batch.mrope_positions_delta[i].append(mrope_position_delta)
         self.mrope_positions = torch.concat(
             [torch.tensor(pos, device=device) for pos in mrope_positions_list],
......
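
If MRotaryEmbedding.get_next_input_positions follows the usual M-RoPE convention (the vLLM helper of the same name returns the next positions, offset by the cached delta and replicated across the temporal/height/width sections), the decode branch above effectively computes something like the sketch below; treat the body as an illustrative assumption, not the library's verbatim code:

```python
from typing import List


def get_next_input_positions_sketch(
    mrope_position_delta: int, context_len: int, seq_len: int
) -> List[List[int]]:
    # One row per M-RoPE section (temporal, height, width); during decode
    # all three advance together, shifted by the delta cached at prefill time.
    return [
        list(range(context_len + mrope_position_delta, seq_len + mrope_position_delta))
        for _ in range(3)
    ]


# For a request at sequence length 10 with a cached delta of -4, the single
# decode position would be 9 - 4 = 5 in each of the three sections.
print(get_next_input_positions_sketch(-4, 9, 10))  # [[5], [5], [5]]
```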
@@ -649,8 +649,6 @@ class Qwen2VLForConditionalGeneration(nn.Module):
                 ]
                 image_embeds_offset += num_image_tokens
 
-            input_ids = None
-
         hidden_states = self.model(
             input_ids=input_ids,
             positions=positions,
......