Unverified commit f05d474c, authored by Lukas Geiger, committed by GitHub
Browse files

[Model][Qwen3VL] Use `mm_position` to compute mrope positions (#28730)


Signed-off-by: Lukas Geiger <lukas.geiger94@gmail.com>
Co-authored-by: Cyrus Leung <tlleungac@connect.ust.hk>
parent 9fc81ec7
......@@ -24,7 +24,7 @@
# limitations under the License.
"""Inference-only Qwen3VL model compatible with HuggingFace weights."""
from collections.abc import Callable, Iterable, Mapping, Sequence
from collections.abc import Callable, Iterable, Iterator, Mapping, Sequence
from functools import partial
from itertools import islice
from typing import Any
......@@ -1412,72 +1412,47 @@ class Qwen3VLForConditionalGeneration(
)
return mm_input_by_modality
def iter_mm_grid_hw(
    self, input_tokens: list[int], mm_features: list[MultiModalFeatureSpec]
) -> Iterator[tuple[int, int, int]]:
    """Yield the placeholder offset and merged grid size of each vision frame.

    Features are visited in ascending ``mm_position.offset`` order, regardless
    of the order in which they appear in ``mm_features``.

    * An image yields a single ``(offset, grid_h, grid_w)`` tuple located at
      the feature's own ``mm_position.offset``.
    * A video yields one tuple per frame. The start of every frame is found
      by searching ``input_tokens`` for the next ``video_token_id`` at or
      after the current position, so tokens that are not video tokens
      between frames are skipped over.

    ``grid_h`` and ``grid_w`` are the ``h`` and ``w`` entries of the feature's
    ``*_grid_thw`` data, floor-divided by
    ``config.vision_config.spatial_merge_size``.

    Args:
        input_tokens: Token ids of the prompt; searched for video tokens.
        mm_features: Multimodal features whose ``modality`` is ``"image"``
            or ``"video"``.

    Yields:
        ``(offset, grid_h, grid_w)`` for every image and every video frame.

    Raises:
        ValueError: If a feature has an unsupported modality, if an image
            reports a frame count other than 1, or (raised by
            ``list.index``) if a video frame's token cannot be found in
            ``input_tokens``.
    """
    video_token_id = self.config.video_token_id
    spatial_merge_size = self.config.vision_config.spatial_merge_size

    for mm_feature in sorted(mm_features, key=lambda f: f.mm_position.offset):
        offset = mm_feature.mm_position.offset
        if mm_feature.modality == "image":
            t, h, w = mm_feature.data["image_grid_thw"].data.tolist()
            # Raise explicitly instead of asserting: asserts are removed
            # under ``python -O``, which would let a multi-frame "image"
            # silently produce a single-frame grid.
            if t != 1:
                raise ValueError(f"Image must have 1 frame, got {t}")
            yield offset, h // spatial_merge_size, w // spatial_merge_size
        elif mm_feature.modality == "video":
            t, h, w = mm_feature.data["video_grid_thw"].data.tolist()
            llm_grid_h = h // spatial_merge_size
            llm_grid_w = w // spatial_merge_size
            for _ in range(t):
                # Jump to the first video token of this frame; anything
                # between the previous frame and this one is not part of it.
                offset = input_tokens.index(video_token_id, offset)
                yield offset, llm_grid_h, llm_grid_w
                # Step past this frame's placeholder tokens.
                offset += llm_grid_h * llm_grid_w
        else:
            raise ValueError(f"Unsupported modality: {mm_feature.modality}")
def get_mrope_input_positions(
self,
input_tokens: list[int],
mm_features: list[MultiModalFeatureSpec],
) -> tuple[torch.Tensor, int]:
kwargs = MultiModalFeatureSpec.gather_kwargs(
mm_features,
{"image_grid_thw", "video_grid_thw"},
)
image_grid_thw = [item.tolist() for item in kwargs.get("image_grid_thw", [])]
video_grid_thw = [item.tolist() for item in kwargs.get("video_grid_thw", [])]
video_grid_thw = [[1, h, w] for t, h, w in video_grid_thw for _ in range(t)]
hf_config = self.config
image_token_id = hf_config.image_token_id
video_token_id = hf_config.video_token_id
vision_start_token_id = hf_config.vision_start_token_id
spatial_merge_size = hf_config.vision_config.spatial_merge_size
input_tokens_array = np.array(input_tokens)
vision_start_mask = input_tokens_array == vision_start_token_id
vision_tokens = input_tokens_array[vision_start_mask.nonzero()[0] + 1]
image_nums = np.count_nonzero(vision_tokens == image_token_id)
video_nums = np.count_nonzero(vision_tokens == video_token_id)
llm_pos_ids_list: list = []
llm_pos_ids_list = []
st = 0
remain_images, remain_videos = image_nums, video_nums
image_index, video_index = 0, 0
for _ in range(image_nums + video_nums):
if image_token_id in input_tokens and remain_images > 0:
ed_image = input_tokens.index(image_token_id, st)
else:
ed_image = len(input_tokens) + 1
if video_token_id in input_tokens and remain_videos > 0:
ed_video = input_tokens.index(video_token_id, st)
else:
ed_video = len(input_tokens) + 1
if ed_image < ed_video:
t, h, w = image_grid_thw[image_index]
image_index += 1
remain_images -= 1
ed = ed_image
else:
t, h, w = video_grid_thw[video_index]
video_index += 1
remain_videos -= 1
ed = ed_video
llm_grid_t, llm_grid_h, llm_grid_w = (
t,
h // spatial_merge_size,
w // spatial_merge_size,
)
text_len = ed - st
for offset, llm_grid_h, llm_grid_w in self.iter_mm_grid_hw(
input_tokens, mm_features
):
text_len = offset - st
st_idx = llm_pos_ids_list[-1].max() + 1 if len(llm_pos_ids_list) > 0 else 0
llm_pos_ids_list.append(
np.broadcast_to(np.arange(text_len), (3, text_len)) + st_idx
)
grid_indices = np.indices((llm_grid_t, llm_grid_h, llm_grid_w))
llm_pos_ids_list.append(grid_indices.reshape(3, -1) + text_len + st_idx)
st = ed + llm_grid_t * llm_grid_h * llm_grid_w
grid_indices = np.indices((1, llm_grid_h, llm_grid_w)).reshape(3, -1)
llm_pos_ids_list.append(grid_indices + text_len + st_idx)
st = offset + llm_grid_h * llm_grid_w
if st < len(input_tokens):
st_idx = llm_pos_ids_list[-1].max() + 1 if len(llm_pos_ids_list) > 0 else 0
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment