Unverified commit 70c73df6, authored by William Zhang, committed by GitHub
Browse files

[Bugfix] Fix EVS implementation for Qwen3 VL (#33607)


Signed-off-by: 2ez4bz <133824995+2ez4bz@users.noreply.github.com>
parent 9a9d4424
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import dataclasses
import random
from dataclasses import dataclass
import pytest
import torch
from vllm.model_executor.models.qwen3_vl import Qwen3VLForConditionalGeneration
from vllm.multimodal.inputs import (
MultiModalFeatureSpec,
MultiModalFieldElem,
MultiModalKwargsItem,
PlaceholderRange,
)
@pytest.fixture(autouse=True, scope="module")
def _force_cpu_default_device():
    """Pin torch's default device to CPU for every test in this module.

    ``_get_mrope_input_positions`` returns CPU tensors (via
    ``torch.from_numpy``), so tensors that the tests create without an
    explicit device must also live on the CPU to be comparable.
    """
    previous_device = torch.get_default_device()
    torch.set_default_device("cpu")
    try:
        yield
    finally:
        # Always hand back the default device that was active before the
        # module's tests ran, even if the generator is closed abnormally.
        torch.set_default_device(previous_device)
# Special token IDs used by the dummy config below. The random "text" tokens
# generated in this module are drawn from range(1, 100), so these values can
# never collide with them.
IMAGE_TOKEN_ID = 999
VIDEO_TOKEN_ID = 888
VISION_START_TOKEN_ID = 777
VISION_END_TOKEN_ID = 778
@dataclass
class DummyVisionConfig:
    """Minimal stand-in for the vision sub-config of the model config."""

    # Spatial merge factor; the test divides the grid height and width by it
    # to obtain the number of video tokens per frame.
    spatial_merge_size: int = 1
@dataclass
class DummyConfig:
    """Minimal stand-in for the HF model config used by the tests.

    It only carries the special token IDs and the nested vision config.
    """

    image_token_id: int = IMAGE_TOKEN_ID
    video_token_id: int = VIDEO_TOKEN_ID
    vision_start_token_id: int = VISION_START_TOKEN_ID
    vision_end_token_id: int = VISION_END_TOKEN_ID
    # `default_factory` gives every instance its own vision config, so a test
    # that mutates `spatial_merge_size` cannot leak into other instances.
    vision_config: DummyVisionConfig = dataclasses.field(
        default_factory=DummyVisionConfig
    )
def make_video_embedding(
    t: int,
    h: int,
    w: int,
    interleave_text_tokens: tuple[int, int],
    video_pruning_rate: float = 0.0,
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
    """Build a synthetic video token sequence and a randomly pruned copy of it.

    Each of the ``t`` frames contributes a random run of text tokens followed
    by ``VISION_START, VIDEO * (h * w), VISION_END``. Only ``VIDEO`` tokens
    are ever candidates for pruning.

    Args:
        t: Number of frames.
        h: Number of rows.
        w: Number of columns.
        interleave_text_tokens: Tuple of minimum and maximum number of text
            tokens to interleave before each frame (both inclusive).
        video_pruning_rate: Probability with which each video token is
            dropped.

    Returns:
        Tuple of (unpruned_tokens_sequence, pruned_tokens_sequence,
        retention_mask). ``retention_mask`` is a boolean tensor over the
        unpruned sequence that is ``True`` for every token that was kept.
    """
    text_token_population = list(range(1, 100))
    min_text_tokens, max_text_tokens = interleave_text_tokens
    # The vision span is identical for every frame, so build it only once.
    frame_vision_tokens = (
        [VISION_START_TOKEN_ID] + [VIDEO_TOKEN_ID] * h * w + [VISION_END_TOKEN_ID]
    )

    token_ids: list[int] = []
    for _ in range(t):
        num_text_tokens = random.randint(min_text_tokens, max_text_tokens)
        token_ids.extend(random.choices(text_token_population, k=num_text_tokens))
        token_ids.extend(frame_vision_tokens)

    unpruned_tokens_sequence = torch.tensor(token_ids, dtype=torch.long)
    video_token_mask = unpruned_tokens_sequence == VIDEO_TOKEN_ID
    # Non-video positions get probability 0 and therefore can never be drawn.
    pruning_mask = torch.bernoulli(
        video_token_mask.float() * video_pruning_rate
    ).bool()  # type: ignore[attr-defined]
    # Sanity check that we don't prune what should not be pruned.
    assert not pruning_mask[~video_token_mask].any()

    retention_mask = ~pruning_mask
    pruned_tokens_sequence = unpruned_tokens_sequence[retention_mask]
    return unpruned_tokens_sequence, pruned_tokens_sequence, retention_mask
def _make_video_feature_spec(
    grid_thw: list[int], placeholder_length: int
) -> MultiModalFeatureSpec:
    """Build a single-video feature spec whose placeholder starts at offset 0."""
    return MultiModalFeatureSpec(
        data=MultiModalKwargsItem(
            {
                "video_grid_thw": MultiModalFieldElem(
                    data=torch.tensor(grid_thw),
                    field=None,  # HACK.
                ),
            }
        ),
        modality="video",
        identifier="DUMMY",
        mm_position=PlaceholderRange(offset=0, length=placeholder_length),
    )


@pytest.mark.parametrize("spatial_merge_size", [1, 2])
@pytest.mark.parametrize("grid_thw", [[3, 8, 7], [128, 10, 12]])
@pytest.mark.parametrize("num_prefix_tokens", [1, 11])
@pytest.mark.parametrize("num_suffix_tokens", [0, 7])
@pytest.mark.parametrize("video_pruning_rate", [0, 0.25, 0.75])
@pytest.mark.parametrize("interleave_text_tokens", [(0, 0), (1, 4)])
def test_match_qwen3vl_mrope_evs_on(
    spatial_merge_size: int,
    num_prefix_tokens: int,
    grid_thw: list[int],
    num_suffix_tokens: int,
    video_pruning_rate: float,
    interleave_text_tokens: tuple[int, int],
):
    """Recomputed mRoPE positions for a pruned prompt must match the unpruned
    ground truth restricted to the retained tokens."""
    hf_config = DummyConfig()
    hf_config.vision_config.spatial_merge_size = spatial_merge_size
    t, h, w = grid_thw

    population = list(range(1, 100))
    prefix_tokens = random.choices(population, k=num_prefix_tokens)
    suffix_tokens = random.choices(population, k=num_suffix_tokens)
    video_tokens, video_tokens_pruned, retention_mask = make_video_embedding(
        t,
        h // spatial_merge_size,
        w // spatial_merge_size,
        interleave_text_tokens=interleave_text_tokens,
        video_pruning_rate=video_pruning_rate,
    )
    assert len(video_tokens) == len(retention_mask)

    input_tokens = prefix_tokens + video_tokens.tolist() + suffix_tokens
    input_tokens_pruned = prefix_tokens + video_tokens_pruned.tolist() + suffix_tokens
    # Prefix and suffix text tokens are never pruned.
    whole_sequence_retention_mask = torch.cat(
        [
            torch.ones(len(prefix_tokens), dtype=torch.bool),
            retention_mask,
            torch.ones(len(suffix_tokens), dtype=torch.bool),
        ],
        dim=0,
    )

    # Build the GT mrope for unpruned input.
    expected_mrope, _ = Qwen3VLForConditionalGeneration._get_mrope_input_positions(
        input_tokens=input_tokens,
        mm_features=[_make_video_feature_spec(grid_thw, len(input_tokens))],
        config=hf_config,
    )
    expected_mrope_masked = expected_mrope[:, whole_sequence_retention_mask]

    # Compute mrope for a video-only media (unpruned).
    video_mrope, _ = Qwen3VLForConditionalGeneration._get_mrope_input_positions(
        input_tokens=video_tokens.tolist(),
        mm_features=[_make_video_feature_spec(grid_thw, video_tokens.numel())],
        config=hf_config,
    )
    video_mrope = video_mrope.permute(1, 0)  # [N, 3]

    hidden_size = 16
    is_video_embed = video_tokens_pruned == VIDEO_TOKEN_ID
    is_vision_start = video_tokens_pruned == VISION_START_TOKEN_ID

    # Per-token side channels appended to the embeddings:
    # [t_index, h_index, w_index, is_vision_start, is_video].
    expanded_positions = torch.full(
        (len(video_tokens_pruned), 5),
        fill_value=-100,
        device=video_mrope.device,
        dtype=torch.long,
    )
    # Every retained token (video or not) takes its position from the
    # unpruned video-only mrope.
    expanded_positions[:, :3] = video_mrope[retention_mask]
    expanded_positions[..., 3] = is_vision_start
    expanded_positions[..., 4] = is_video_embed
    # Check that all positions were filled, since we initialized them as negative.
    assert (expanded_positions >= 0).all()

    # The feature values are irrelevant here; use zeros rather than
    # uninitialised memory so that every run sees the same inputs.
    video_embeddings = torch.cat(
        [
            torch.zeros(
                (len(video_tokens_pruned), hidden_size), device=video_mrope.device
            ),
            expanded_positions.float(),
        ],
        dim=1,
    )
    multimodal_embeddings = [video_embeddings]

    # Only the prefix is "already computed"; everything after it starts as
    # zeros, which is deterministically wrong because every later position is
    # at least len(prefix_tokens) >= 1.
    computed_mrope = torch.zeros((3, len(input_tokens_pruned)), dtype=torch.long)
    computed_mrope[:, 0 : len(prefix_tokens)] = expected_mrope[
        :, 0 : len(prefix_tokens)
    ]
    # Paranoia check that computed_mrope is wrong.
    assert not torch.equal(computed_mrope, expected_mrope_masked)

    _, actual_mrope, _ = Qwen3VLForConditionalGeneration._recompute_mrope_positions(
        input_ids=input_tokens_pruned,
        multimodal_embeddings=multimodal_embeddings,
        mrope_positions=computed_mrope,
        num_computed_tokens=len(prefix_tokens),
        vision_start_token_id=hf_config.vision_start_token_id,
        image_token_id=hf_config.image_token_id,
        video_token_id=hf_config.video_token_id,
    )
    assert torch.equal(actual_mrope, expected_mrope_masked)
......@@ -195,6 +195,8 @@ class Qwen2_5_VLVideoPixelInputs(TensorSchema):
- second_per_grid_ts: The video time interval (in seconds) for each
grid along the temporal dimension in the 3D position IDs. Returned
when `videos` is not `None`.
- timestamps: List of timestamp values (in seconds) for each frame
after merging. Length equals the temporal dimension after merging.
"""
type: Literal["pixel_values_videos"]
......@@ -214,6 +216,8 @@ class Qwen2_5_VLVideoPixelInputs(TensorSchema):
TensorShape("nv"),
]
timestamps: list[list[float]] | None = None
class Qwen2_5_VLVideoEmbeddingInputs(TensorSchema):
"""
......@@ -232,6 +236,8 @@ class Qwen2_5_VLVideoEmbeddingInputs(TensorSchema):
- second_per_grid_ts: The video time interval (in seconds) for each
grid along the temporal dimension in the 3D position IDs. Returned
when `videos` is not `None`.
- timestamps: List of timestamp values (in seconds) for each frame
after merging. Length equals the temporal dimension after merging.
"""
type: Literal["video_embeds"]
......@@ -250,6 +256,7 @@ class Qwen2_5_VLVideoEmbeddingInputs(TensorSchema):
torch.Tensor | None,
TensorShape("nv"),
] = None
timestamps: list[list[float]] | None = None
Qwen2_5_VLVideoInputs: TypeAlias = (
......
......@@ -755,6 +755,7 @@ def _create_qwen2vl_field_factory(
"video", video_embed_grid_sizes
),
video_grid_thw=MultiModalFieldConfig.batched("video", keep_on_cpu=True),
timestamps=MultiModalFieldConfig.batched("video", keep_on_cpu=True),
)
return _qwen2vl_field_config
......
......@@ -628,6 +628,9 @@ class Qwen3_5MoeForCausalLM(Qwen3_5ForCausalLMBase, QwenNextMixtureOfExperts):
dummy_inputs=Qwen3VLDummyInputsBuilder,
)
class Qwen3_5ForConditionalGeneration(Qwen3VLForConditionalGeneration, IsHybrid):
# Qwen3.5 does not support multimodal pruning (EVS).
supports_multimodal_pruning = False
packed_modules_mapping = Qwen3VLForConditionalGeneration.packed_modules_mapping | {
"in_proj_qkvz": ["in_proj_qkv", "in_proj_z"],
"in_proj_ba": ["in_proj_b", "in_proj_a"],
......@@ -643,10 +646,8 @@ class Qwen3_5ForConditionalGeneration(Qwen3VLForConditionalGeneration, IsHybrid)
self.config = config
self.multimodal_config = multimodal_config
self.use_data_parallel = multimodal_config.mm_encoder_tp_mode == "data"
self.video_pruning_rate = multimodal_config.video_pruning_rate
self.is_multimodal_pruning_enabled = (
multimodal_config.is_multimodal_pruning_enabled()
)
# Qwen3.5 does not support multimodal pruning (EVS).
self.is_multimodal_pruning_enabled = False
with self._mark_tower_model(vllm_config, {"image", "video"}):
self.visual = Qwen3_VisionTransformer(
......@@ -693,6 +694,12 @@ class Qwen3_5ForConditionalGeneration(Qwen3VLForConditionalGeneration, IsHybrid)
return inputs_embeds
def recompute_mrope_positions(self, *args, **kwargs):
    """Always raise: Qwen3.5 does not support multimodal pruning (EVS).

    Raises:
        NotImplementedError: Unconditionally, whatever arguments are passed.
    """
    raise NotImplementedError(
        "Qwen3.5 does not support multimodal pruning (EVS). "
        "recompute_mrope_positions should never be called."
    )
def forward(
self,
input_ids: torch.Tensor,
......@@ -851,10 +858,8 @@ class Qwen3_5MoeForConditionalGeneration(
self.config = config
self.multimodal_config = multimodal_config
self.use_data_parallel = multimodal_config.mm_encoder_tp_mode == "data"
self.video_pruning_rate = multimodal_config.video_pruning_rate
self.is_multimodal_pruning_enabled = (
multimodal_config.is_multimodal_pruning_enabled()
)
# Qwen3.5 does not support multimodal pruning (EVS).
self.is_multimodal_pruning_enabled = False
with self._mark_tower_model(vllm_config, {"image", "video"}):
self.visual = Qwen3_VisionTransformer(
......
......@@ -79,6 +79,7 @@ from vllm.multimodal.inputs import (
MultiModalDataDict,
MultiModalFeatureSpec,
MultiModalFieldConfig,
MultiModalFieldElem,
MultiModalKwargsItem,
MultiModalKwargsItems,
PlaceholderRange,
......@@ -93,6 +94,8 @@ from vllm.multimodal.processing import (
PromptUpdateDetails,
)
from vllm.sequence import IntermediateTensors
from vllm.tokenizers.protocol import TokenizerLike
from vllm.tokenizers.registry import cached_tokenizer_from_config
from vllm.utils.collection_utils import is_list_of
from vllm.utils.math_utils import round_up
......@@ -763,7 +766,6 @@ class Qwen3VLProcessingInfo(Qwen2VLProcessingInfo):
def _get_video_second_idx(
self,
metadata: dict[str, Any],
out_item: MultiModalKwargsItem,
do_sample_frames: bool | None = None,
sampled_fps: float | None = None,
) -> list[int]:
......@@ -956,6 +958,7 @@ class Qwen3VLMultiModalProcessor(BaseMultiModalProcessor[Qwen3VLProcessingInfo])
if videos := mm_data.pop("videos", []):
video_grid_thw_lst = []
pixel_values_videos_lst = []
timestamps_per_video = []
for item in videos:
video_array, metadata = item
......@@ -979,6 +982,14 @@ class Qwen3VLMultiModalProcessor(BaseMultiModalProcessor[Qwen3VLProcessingInfo])
**{k: metadata[k] for k in metadata if k != "do_sample_frames"}
)
# Compute timestamps here where we have access to metadata
timestamps = self.info._get_video_second_idx(
metadata=metadata,
do_sample_frames=video_mm_kwargs["do_sample_frames"],
sampled_fps=video_mm_kwargs.get("fps"),
)
timestamps_per_video.append(timestamps)
video_mm_data = dict()
video_mm_data["videos"] = [[video_array]]
video_mm_data["video_metadata"] = [[metadata]]
......@@ -989,6 +1000,49 @@ class Qwen3VLMultiModalProcessor(BaseMultiModalProcessor[Qwen3VLProcessingInfo])
mm_kwargs=video_mm_kwargs,
tok_kwargs=tok_kwargs,
)
merge_size = processor.video_processor.merge_size
# Get video grid info for EVS calculation.
video_grid_thw = video_outputs["video_grid_thw"]
num_frames = int(video_grid_thw[0, 0])
tokens_per_frame_base = int(video_grid_thw[0, 1:].prod()) // (
merge_size**2
)
# Apply EVS if enabled.
video_pruning_rate = self.info.ctx.get_mm_config().video_pruning_rate
if video_pruning_rate is not None and video_pruning_rate > 0.0:
num_tokens = compute_retained_tokens_count(
tokens_per_frame=tokens_per_frame_base,
num_frames=num_frames,
q=video_pruning_rate,
)
# Here we just need placeholders that won't actually be replaced -
# we just need to make sure the total number of tokens is correct
# assign all tokens to the first frame.
tokens_per_frame = [num_tokens] + [0] * (num_frames - 1)
select_token_id = False
else:
tokens_per_frame = [tokens_per_frame_base] * num_frames
select_token_id = True
# Generate the video replacement with EVS-adjusted token counts
tokenizer = self.info.get_tokenizer()
hf_config = self.info.get_hf_config()
video_repl = Qwen3VLMultiModalProcessor.get_video_repl(
tokens_per_frame=tokens_per_frame,
timestamps=timestamps,
tokenizer=tokenizer,
vision_start_token_id=hf_config.vision_start_token_id,
vision_end_token_id=hf_config.vision_end_token_id,
video_token_id=hf_config.video_token_id,
select_token_id=select_token_id,
)
# Convert token IDs to text for the HF processor flow
video_placeholder = tokenizer.decode(
video_repl.full, skip_special_tokens=False
)
input_ids = video_outputs.pop("input_ids")
video_placeholder = processor.tokenizer.batch_decode(input_ids)[0]
prompt = prompt.replace(
......@@ -1002,6 +1056,7 @@ class Qwen3VLMultiModalProcessor(BaseMultiModalProcessor[Qwen3VLProcessingInfo])
video_outputs = dict(
pixel_values_videos=torch.cat(pixel_values_videos_lst),
video_grid_thw=torch.cat(video_grid_thw_lst),
timestamps=timestamps_per_video,
)
else:
video_outputs = dict()
......@@ -1057,60 +1112,42 @@ class Qwen3VLMultiModalProcessor(BaseMultiModalProcessor[Qwen3VLProcessingInfo])
grid_thw = out_item["video_grid_thw"].data
assert isinstance(grid_thw, torch.Tensor)
video, metadata = mm_items["video"][item_idx]
do_sample_frames = hf_processor_mm_kwargs.get("do_sample_frames")
sampled_fps = hf_processor_mm_kwargs.get("fps")
if is_list_of(sampled_fps, float):
sampled_fps = sampled_fps[item_idx]
timestamps = self.info._get_video_second_idx(
metadata, out_item, do_sample_frames, sampled_fps
)
timestamps = out_item["timestamps"].data
assert len(timestamps) == grid_thw[0], (
f"The timestamps length({len(timestamps)}) should be equal "
f"video length ({grid_thw[0]})."
)
frames_idx_token = [
tokenizer.encode(f"<{curr_time:.1f} seconds>", add_special_tokens=False)
for curr_time in timestamps
]
tokens_per_frame = int(grid_thw[1:].prod()) // merge_length
per_frame_token_counts = [tokens_per_frame for _ in frames_idx_token]
# Compute tokens per frame, with EVS support
num_frames = int(grid_thw[0])
tokens_per_frame_base = int(grid_thw[1:].prod()) // merge_length
video_pruning_rate = self.info.ctx.get_mm_config().video_pruning_rate
if video_pruning_rate is not None and video_pruning_rate > 0.0:
total_retained = compute_retained_tokens_count(
tokens_per_frame,
len(frames_idx_token),
video_pruning_rate,
num_tokens = compute_retained_tokens_count(
tokens_per_frame=tokens_per_frame_base,
num_frames=num_frames,
q=video_pruning_rate,
)
if len(frames_idx_token) == 0:
per_frame_token_counts = []
elif len(frames_idx_token) == 1:
per_frame_token_counts = [tokens_per_frame]
else:
first_frame_tokens = tokens_per_frame
remaining_tokens = max(total_retained - first_frame_tokens, 0)
base = remaining_tokens // (len(frames_idx_token) - 1)
remainder = remaining_tokens % (len(frames_idx_token) - 1)
per_frame_token_counts = [first_frame_tokens]
for frame_idx in range(1, len(frames_idx_token)):
extra = base + (1 if (frame_idx - 1) < remainder else 0)
per_frame_token_counts.append(extra)
placeholder = []
for frame_idx, timestamp_tokens in enumerate(frames_idx_token):
placeholder.extend(timestamp_tokens)
tokens_this_frame = per_frame_token_counts[
frame_idx if frame_idx < len(per_frame_token_counts) else -1
]
placeholder.extend(
[vision_start_token_id]
+ [video_token_id] * tokens_this_frame
+ [vision_end_token_id]
)
return PromptUpdateDetails.select_token_id(placeholder, video_token_id)
tokens_per_frame = [num_tokens] + [0] * (num_frames - 1)
select_token_id = False
else:
tokens_per_frame = [tokens_per_frame_base] * num_frames
select_token_id = True
return Qwen3VLMultiModalProcessor.get_video_repl(
tokens_per_frame=tokens_per_frame,
timestamps=timestamps,
tokenizer=tokenizer,
vision_start_token_id=vision_start_token_id,
vision_end_token_id=vision_end_token_id,
video_token_id=video_token_id,
select_token_id=select_token_id,
)
return [
PromptReplacement(
......@@ -1127,6 +1164,69 @@ class Qwen3VLMultiModalProcessor(BaseMultiModalProcessor[Qwen3VLProcessingInfo])
),
]
@staticmethod
def get_video_repl(
    *,
    tokens_per_frame: list[int],
    timestamps: list[float | int],
    tokenizer: TokenizerLike,
    vision_start_token_id: int,
    vision_end_token_id: int,
    video_token_id: int,
    select_token_id: bool = False,
) -> PromptUpdateDetails[list[int]]:
    """Build prompt replacement for a video in Qwen3VL format.

    The replacement structure for each frame is:

        timestamp_tokens + vision_start_token + video_tokens + vision_end_token

    Args:
        tokens_per_frame: Number of video tokens per frame (can vary per frame
            for EVS; a frame may have zero video tokens).
        timestamps: List of timestamps in seconds for each frame.
        tokenizer: Tokenizer to encode timestamp strings.
        vision_start_token_id: Token ID for vision start marker.
        vision_end_token_id: Token ID for vision end marker.
        video_token_id: Token ID for video content.
        select_token_id: If True, return
            `PromptUpdateDetails.select_token_id(ids, video_token_id)`;
            otherwise return `PromptUpdateDetails.from_seq(ids)`, which marks
            every token of the placeholder as a candidate.

    Returns:
        PromptUpdateDetails with full token sequence.
    """
    assert len(timestamps) == len(tokens_per_frame), (
        "timestamps and tokens_per_frame must have the same length"
    )

    all_token_ids: list[int] = []
    for timestamp, num_tokens in zip(timestamps, tokens_per_frame):
        # Tokenize each timestamp string independently to avoid the tokenizer
        # merging tokens across boundaries.
        # TODO: switch to `_seq2tokens` which has some caching.
        all_token_ids.extend(
            tokenizer.encode(f"<{timestamp:.1f} seconds>", add_special_tokens=False)
        )
        # Vision span: vision_start + video_tokens + vision_end.
        all_token_ids.append(vision_start_token_id)
        all_token_ids.extend([video_token_id] * num_tokens)
        all_token_ids.append(vision_end_token_id)

    if select_token_id:
        return PromptUpdateDetails.select_token_id(all_token_ids, video_token_id)

    # NOTE: we use `from_seq` instead of `select_token_id` because we want all
    # tokens in the placeholder to be initially marked as candidates. Then
    # in `get_input_embeddings`, we refine the mask to only replace
    # `video_token_id` / `image_token_id` positions with video/image embeddings,
    # keeping text embeddings for timestamps and structural tokens.
    return PromptUpdateDetails.from_seq(all_token_ids)
@support_torch_compile(
dynamic_arg_dims={
......@@ -1280,6 +1380,7 @@ class Qwen3VLForConditionalGeneration(
multimodal_config = vllm_config.model_config.multimodal_config
self.config = config
self._tokenizer = cached_tokenizer_from_config(vllm_config.model_config)
self.multimodal_config = multimodal_config
self.use_data_parallel = multimodal_config.mm_encoder_tp_mode == "data"
self.video_pruning_rate = multimodal_config.video_pruning_rate
......@@ -1419,6 +1520,7 @@ class Qwen3VLForConditionalGeneration(
video_embeds = kwargs.pop("video_embeds", None)
video_grid_thw = kwargs.pop("video_grid_thw", None)
second_per_grid_ts = kwargs.pop("second_per_grid_ts", None)
timestamps = kwargs.pop("timestamps", None)
if pixel_values_videos is None and video_embeds is None:
return None
......@@ -1429,6 +1531,7 @@ class Qwen3VLForConditionalGeneration(
pixel_values_videos=pixel_values_videos,
video_grid_thw=video_grid_thw,
second_per_grid_ts=second_per_grid_ts,
timestamps=timestamps,
)
if video_embeds is not None:
......@@ -1436,6 +1539,7 @@ class Qwen3VLForConditionalGeneration(
type="video_embeds",
video_embeds=video_embeds,
video_grid_thw=video_grid_thw,
timestamps=timestamps,
)
def _process_image_input(
......@@ -1502,19 +1606,29 @@ class Qwen3VLForConditionalGeneration(
Returns:
Tuple of image embeddings for each image item.
Resulting embeddings will have extra 4 channels for
computed mrope positions.
Resulting embeddings will have extra 5 channels for
computed mrope positions, consistent with video embeddings.
"""
merge_size = self.visual.spatial_merge_size
grid_thw = image_input["image_grid_thw"]
grid_thw_list = grid_thw.tolist()
image_embeds_out = []
for emb, size in zip(image_embeds_split, grid_thw_list):
positions = compute_mrope_for_media(size, merge_size).to(emb.device)
emb = torch.cat([emb, positions], dim=1)
image_embeds_out.append(emb)
image_embeds_split = image_embeds_out
return tuple(image_embeds_split)
if self.is_multimodal_pruning_enabled:
merge_size = self.visual.spatial_merge_size
grid_thw = image_input["image_grid_thw"]
grid_thw_list = grid_thw.tolist()
image_embeds_out = []
for emb, size in zip(image_embeds_split, grid_thw_list):
positions = compute_mrope_for_media(size, merge_size).to(emb.device)
positions = torch.cat(
[
positions,
torch.zeros_like(
positions[:, 0:1]
), # Dummy extra fifth channel
],
dim=1,
)
emb = torch.cat([emb, positions], dim=1)
image_embeds_out.append(emb)
image_embeds_split = tuple(image_embeds_out)
return image_embeds_split
def _postprocess_video_embeds_evs(
self,
......@@ -1531,62 +1645,218 @@ class Qwen3VLForConditionalGeneration(
Returns:
Tuple of video embeddings for each video item.
Resulting embeddings will have extra 4 channels for
computed mrope positions.
Resulting embeddings will have extra 5 channels for computed mrope
positions, and whether the index corresponds to a video embedding.
"""
grid_thw = video_input["video_grid_thw"]
assert grid_thw.ndim == 2
grid_thw_list = grid_thw.tolist()
merge_size = self.visual.spatial_merge_size
# Cast to long to match the original code
# https://github.com/huggingface/transformers/blob/41980ce93e775f6c88500c51c8db7946fc6a2add/src/transformers/models/qwen2_5_vl/modular_qwen2_5_vl.py#L491 # noqa
second_per_grid_ts = video_input.get("second_per_grid_ts")
if second_per_grid_ts is None:
# For Qwen3-VL, second_per_grid_ts might not be available
# Use default value of 1.0 for each video
second_per_grid_ts = torch.ones(len(grid_thw_list), dtype=torch.long)
# Apply EVS to each video.
video_embeds_out = []
for video_idx, (emb, size) in enumerate(zip(video_embeds_split, grid_thw_list)):
# Compute positions.
timestamps = video_input.timestamps[video_idx]
num_frames = len(timestamps)
t, h, w = size
if self.is_multimodal_pruning_enabled:
# For each video, compute retention mask using EVS.
# retention_mask: [11424].
retention_mask = compute_retention_mask(
emb,
size,
spatial_merge_size=self.visual.spatial_merge_size,
q=self.video_pruning_rate,
)
# Apply retention mask.
emb = emb[retention_mask]
# Calculate the actual number of retained tokens per frame.
num_frames, rows, cols = (
t,
h // merge_size,
w // merge_size,
)
retention_mask_thw = retention_mask.reshape(num_frames, rows, cols)
num_tokens_per_frame = (
retention_mask_thw.sum(dim=(1, 2)).long().tolist()
)
else:
feature_size = emb.shape[0] // num_frames
num_tokens_per_frame = [feature_size] * num_frames
retention_mask = None
emb = self._create_final_video_embeddings(
video_embeddings=emb,
num_tokens_per_frame=num_tokens_per_frame,
timestamps=timestamps,
video_grid_thw=size,
retention_mask=retention_mask,
)
video_embeds_out.append(emb)
return tuple(video_embeds_out)
def _create_final_video_embeddings(
self,
video_embeddings: torch.Tensor,
num_tokens_per_frame: list[int],
timestamps: list[float],
video_grid_thw: list[int],
retention_mask: torch.Tensor,
) -> torch.Tensor:
"""Create final embeddings that combine video embeddings with
text embeddings of indicator tokens.
These final embeddings contain:
- Actual video embeddings in positions corresponding to video content
- Text embeddings for indicator tokens (<img>, </img>, and
frame separation text) in their respective positions
These embeddings will replace the placeholder embeddings to create
input_embeds for the LLM.
"""
device = video_embeddings.device
# Generate video replacement token IDs using get_video_repl
# This tokenizes each frame separator independently, then uses pre-tokenized
# special tokens to ensure consistent tokenization regardless of
# num_tokens_per_frame values.
video_repl = Qwen3VLMultiModalProcessor.get_video_repl(
tokens_per_frame=num_tokens_per_frame,
tokenizer=self._tokenizer,
timestamps=timestamps,
vision_start_token_id=self.config.vision_start_token_id,
vision_end_token_id=self.config.vision_end_token_id,
video_token_id=self.config.video_token_id,
select_token_id=self.is_multimodal_pruning_enabled,
)
repl_token_ids = torch.tensor(video_repl.full, device=device)
embed_token_id = _cached_tensor(self.config.video_token_id, device=device)
is_video_embed = torch.isin(repl_token_ids, embed_token_id)
# Get text embeddings for indicator tokens (has only `visual_dim`).
text_embeddings = self.get_language_model().embed_input_ids(repl_token_ids)
if self.use_deepstack:
(
deepstack_input_embeds,
multimodal_embeddings,
) = self._compute_deepstack_embeds(
inputs_embeds=text_embeddings,
multimodal_embeddings=[video_embeddings],
is_multimodal=is_video_embed,
)
else:
second_per_grid_ts = second_per_grid_ts.long()
tokens_per_second = getattr(self.config.vision_config, "tokens_per_second", 1.0)
deepstack_input_embeds = None
multimodal_embeddings = [video_embeddings]
video_embeds_out = []
for emb, size, video_second_per_grid_t in zip(
video_embeds_split, grid_thw_list, second_per_grid_ts
):
# For each video, we compute retention mask using EVS
retention_mask = compute_retention_mask(
emb,
size,
spatial_merge_size=self.visual.spatial_merge_size,
q=self.video_pruning_rate,
merged_embeddings = _merge_multimodal_embeddings(
inputs_embeds=text_embeddings,
multimodal_embeddings=multimodal_embeddings,
is_multimodal=is_video_embed,
)
to_concat = [merged_embeddings]
if deepstack_input_embeds is not None:
to_concat.append(
deepstack_input_embeds.permute(1, 0, 2).reshape(
deepstack_input_embeds.shape[1], -1
)
)
# Debug logging for EVS pruning
logger.debug(
"EVS: Video tokens pruned from %d to %d (T=%d,H=%d,W=%d, "
"pruning_rate=%.2f, reduction=%.1f%%)",
emb.shape[0],
retention_mask.sum().item(),
size[0],
size[1],
size[2],
self.video_pruning_rate,
(1 - retention_mask.float().mean().item()) * 100,
expanded_positions = None
if self.is_multimodal_pruning_enabled:
is_vision_start = repl_token_ids.eq(self.config.vision_start_token_id)
expanded_positions = self._get_expanded_positions(
device=merged_embeddings.device,
seq_len=merged_embeddings.shape[0],
video_grid_thw=video_grid_thw,
num_tokens_per_frame=num_tokens_per_frame,
timestamps=timestamps,
is_video_embed=is_video_embed,
is_vision_start=is_vision_start,
retention_mask=retention_mask,
)
to_concat.append(expanded_positions)
positions = compute_mrope_for_media(
size,
merge_size,
tokens_per_second=tokens_per_second,
video_second_per_grid=video_second_per_grid_t.item(),
).to(emb.device)
final_video_embeddings = torch.cat(to_concat, dim=-1)
emb = emb[retention_mask]
positions = positions[retention_mask]
emb = torch.cat([emb, positions], dim=1)
video_embeds_out.append(emb)
return tuple(video_embeds_out)
return final_video_embeddings
def _get_expanded_positions(
    self,
    device: torch.device,
    seq_len: int,
    video_grid_thw: list[int],
    num_tokens_per_frame: list[int],
    timestamps: list[float],
    is_video_embed: torch.Tensor,
    is_vision_start: torch.Tensor,
    retention_mask: torch.Tensor,
) -> torch.Tensor:
    """Build the per-token position channels for a pruned video placeholder.

    The mrope positions are computed once for the *unpruned* placeholder
    (timestamps, vision start/end markers and the full grid of video tokens)
    and then gathered for the tokens that remain after pruning.

    Args:
        device: Device on which the result is created.
        seq_len: Length of the pruned placeholder sequence.
        video_grid_thw: ``(t, h, w)`` grid of the video; only ``h`` and ``w``
            are read here.
        num_tokens_per_frame: Retained video tokens per frame; only its length
            (the number of frames) is used.
        timestamps: Per-frame timestamps used to rebuild the unpruned
            placeholder.
        is_video_embed: Boolean mask over the pruned sequence marking video
            tokens.
        is_vision_start: Boolean mask over the pruned sequence marking
            vision-start tokens.
        retention_mask: Boolean mask over the unpruned video tokens selecting
            the ones that were kept.

    Returns:
        Long tensor of shape ``[seq_len, 5]`` laid out as
        ``[t_index, h_index, w_index, is_vision_start, is_video]``.
    """
    embed_token_id = _cached_tensor(self.config.video_token_id, device=device)

    # Expand positions to match the full sequence length
    # (includes both video tokens and indicator tokens).
    # Channel 3 flags VISION_START tokens so that
    # recompute_mrope_positions can reliably count timestamp tokens
    # (even when early frames have all video tokens pruned).
    # Channel 4 flags video-embedding tokens.
    expanded_positions = torch.zeros(
        seq_len,
        5,  # [t_index, h_index, w_index, is_vision_start, is_video]
        device=device,
        dtype=torch.long,
    )

    _, h, w = video_grid_thw
    merge_size = self.visual.spatial_merge_size
    num_frames = len(num_tokens_per_frame)
    tokens_per_full_frame = (h // merge_size) * (w // merge_size)

    # Rebuild the placeholder as it would look without any pruning.
    unpruned_token_ids = Qwen3VLMultiModalProcessor.get_video_repl(
        tokens_per_frame=[tokens_per_full_frame] * num_frames,
        tokenizer=self._tokenizer,
        timestamps=timestamps,
        vision_start_token_id=self.config.vision_start_token_id,
        vision_end_token_id=self.config.vision_end_token_id,
        video_token_id=self.config.video_token_id,
    ).full
    unpruned_token_ids_tensor = torch.tensor(unpruned_token_ids, device=device)

    mm_feature = MultiModalFeatureSpec(
        data=MultiModalKwargsItem(
            {
                "video_grid_thw": MultiModalFieldElem(
                    data=torch.tensor(video_grid_thw),
                    field=None,  # HACK.
                ),
            }
        ),
        modality="video",
        identifier="DUMMY",
        mm_position=PlaceholderRange(offset=0, length=len(unpruned_token_ids)),
    )
    # One row of (t, h, w) positions per unpruned token.
    original_mrope = (
        self.get_mrope_input_positions(
            input_tokens=unpruned_token_ids,
            mm_features=[mm_feature],
        )[0]
        .to(device)
        .permute(1, 0)
    )

    full_is_video_embed = unpruned_token_ids_tensor == embed_token_id
    # Video tokens: keep only the positions of the retained tokens.
    expanded_positions[is_video_embed, :3] = original_mrope[full_is_video_embed][
        retention_mask
    ]
    # Indicator tokens are never pruned, so they map one-to-one.
    expanded_positions[~is_video_embed, :3] = original_mrope[~full_is_video_embed]
    expanded_positions[..., 3] = is_vision_start
    expanded_positions[..., 4] = is_video_embed
    return expanded_positions
def _parse_and_validate_multimodal_inputs(self, **kwargs: object) -> dict:
mm_input_by_modality = {}
......@@ -1607,66 +1877,77 @@ class Qwen3VLForConditionalGeneration(
)
return mm_input_by_modality
def iter_mm_grid_hw(
self, input_tokens: list[int], mm_features: list[MultiModalFeatureSpec]
) -> Iterator[tuple[int, int, int]]:
"""
Iterate over multimodal features and yield grid information.
For videos with EVS (Efficient Video Sampling) enabled, this function
computes the offset based on the pruned token count rather than relying
on input_tokens.index(), which would fail when tokens are pruned.
@staticmethod
def _iter_mm_grid_hw(
input_tokens: list[int],
mm_features: list[MultiModalFeatureSpec],
video_token_id: int,
vision_start_token_id: int,
vision_end_token_id: int,
spatial_merge_size: int,
) -> Iterator[tuple[int, int, int, int]]:
"""Iterate over multimodal features and yield position info.
Args:
input_tokens: List of token IDs in the prompt
mm_features: List of multimodal feature specifications
input_tokens: List of token IDs in the input sequence.
mm_features: List of multimodal feature specifications containing
image/video data and position information.
video_token_id: Token ID used for video tokens.
vision_start_token_id: Token ID marking the start of a vision sequence.
vision_end_token_id: Token ID marking the end of a vision sequence.
spatial_merge_size: Size of the spatial merge operation used to
compute logical grid dimensions from the original feature grid.
Yields:
Tuple of (offset, grid_h, grid_w) for each frame/image
offset: Position of the first video/image token in the sequence.
llm_grid_h: Logical grid height (may not match actual token count with EVS).
llm_grid_w: Logical grid width (may not match actual token count with EVS).
actual_num_tokens: Actual number of video/image tokens in the placeholder.
"""
video_token_id = self.config.video_token_id
spatial_merge_size = self.config.vision_config.spatial_merge_size
for mm_feature in sorted(mm_features, key=lambda f: f.mm_position.offset):
offset = mm_feature.mm_position.offset
if mm_feature.modality == "image":
t, h, w = mm_feature.data["image_grid_thw"].data.tolist()
assert t == 1, f"Image must have 1 frame, got {t}"
yield offset, h // spatial_merge_size, w // spatial_merge_size
llm_grid_h = h // spatial_merge_size
llm_grid_w = w // spatial_merge_size
yield offset, llm_grid_h, llm_grid_w, llm_grid_h * llm_grid_w
elif mm_feature.modality == "video":
t, h, w = mm_feature.data["video_grid_thw"].data.tolist()
llm_grid_h = h // spatial_merge_size
llm_grid_w = w // spatial_merge_size
# Check if EVS (Efficient Video Sampling) is enabled
is_evs_enabled = (
hasattr(self, "video_pruning_rate")
and self.video_pruning_rate is not None
and self.video_pruning_rate > 0.0
)
if is_evs_enabled:
frame_offsets = self._extract_frame_offsets_from_mask(
mm_feature.mm_position, t
)
if frame_offsets is not None:
for rel_offset in frame_offsets:
yield offset + rel_offset, llm_grid_h, llm_grid_w
continue
# If EVS is enabled but mask is missing, this indicates a bug
# in the prompt processing pipeline. The is_embed mask should
# always be present when video_pruning_rate > 0.
raise RuntimeError(
f"EVS is enabled (pruning_rate={self.video_pruning_rate}) "
"but is_embed mask is missing from mm_position. "
"This indicates a bug in prompt processing."
)
else:
# Non-EVS mode: Use original logic with input_tokens.index()
for _ in range(t):
offset = input_tokens.index(video_token_id, offset)
yield offset, llm_grid_h, llm_grid_w
offset += llm_grid_h * llm_grid_w
for _ in range(t):
# When EVS is enabled, some frames may have 0 video tokens in the
# placeholder. We use `vision_start_token_id` to locate each frame
# since it is always present for every frame.
# We then look for the first `video_token_id` after
# `vision_start_token_id` and before `vision_end_token_id`.
offset = input_tokens.index(vision_start_token_id, offset)
vision_end_offset = input_tokens.index(vision_end_token_id, offset)
try:
actual_num_tokens = 0
video_offset = input_tokens.index(
video_token_id, offset, vision_end_offset
)
# NOTE: looking at the
# `Qwen3VLMultiModalProcessor.get_video_repl` code, we can
# see that we can use the below formula to get the token
# count, since everything in between `video_offset` and
# `vision_end_offset` is populated as `video_token_id`.
# This saves us from manually counting the number of tokens
# that match `video_token_id` in between.
actual_num_tokens += vision_end_offset - video_offset
except ValueError:
# No `video_token_id` in this frame (EVS with 0 tokens for
# this frame) -> use `offset + 1` to move past
# `vision_start_token_id`.
video_offset = offset + 1
yield video_offset, llm_grid_h, llm_grid_w, actual_num_tokens
# Move offset past this frame for next iteration.
offset = vision_end_offset + 1
else:
raise ValueError(f"Unsupported modality: {mm_feature.modality}")
......@@ -1771,13 +2052,100 @@ class Qwen3VLForConditionalGeneration(
return [len(seg) for seg in segments]
def get_mrope_input_positions(
    self,
    input_tokens: list[int],
    mm_features: list[MultiModalFeatureSpec],
) -> tuple[torch.Tensor, int]:
    """Compute M-RoPE positions for a prompt using this model's config.

    Thin instance-level entry point: it forwards its arguments, together
    with ``self.config``, to the static ``_get_mrope_input_positions``.

    Args:
        input_tokens: Token IDs of the prompt.
        mm_features: Multimodal feature specifications for the prompt.

    Returns:
        The result of ``_get_mrope_input_positions``, annotated here as a
        ``(positions, mrope_position_delta)`` tuple.
    """
    compute_positions = self._get_mrope_input_positions
    model_config = self.config
    return compute_positions(
        input_tokens=input_tokens,
        mm_features=mm_features,
        config=model_config,
    )
@staticmethod
def _get_mrope_input_positions(
    input_tokens: list[int],
    mm_features: list[MultiModalFeatureSpec],
    config: Qwen3VLConfig,
) -> tuple[torch.Tensor, int]:
    """Build the 3-row M-RoPE position ids for a whole prompt.

    The prompt is walked segment by segment using the
    ``(offset, llm_grid_h, llm_grid_w, actual_num_tokens)`` tuples yielded
    by ``_iter_mm_grid_hw``:

    * Text runs get the same value in all three rows, counting up from one
      past the largest position emitted so far.
    * Each vision frame gets ``np.indices((1, llm_grid_h, llm_grid_w))``
      shifted so that it starts right after the preceding text run.
    * Frames that report zero tokens are skipped; the tokens around them
      are absorbed into the text run of the next emitted segment.

    Args:
        input_tokens: Token IDs of the prompt.
        mm_features: Multimodal feature specifications for the prompt.
        config: Model config supplying ``video_token_id``,
            ``vision_start_token_id``, ``vision_end_token_id`` and
            ``vision_config.spatial_merge_size``.

    Returns:
        Tuple of ``(positions, mrope_position_delta)`` where ``positions``
        is a tensor with 3 rows and ``mrope_position_delta`` is
        ``positions.max() + 1 - len(input_tokens)``.
    """
    llm_pos_ids_list = []
    st = 0
    for (
        offset,
        llm_grid_h,
        llm_grid_w,
        actual_num_tokens,
    ) in Qwen3VLForConditionalGeneration._iter_mm_grid_hw(
        input_tokens,
        mm_features,
        video_token_id=config.video_token_id,
        vision_start_token_id=config.vision_start_token_id,
        vision_end_token_id=config.vision_end_token_id,
        spatial_merge_size=config.vision_config.spatial_merge_size,
    ):
        # Skip frames with 0 tokens (EVS placeholder with tokens lumped elsewhere)
        if actual_num_tokens == 0:
            continue

        # Text between the previous segment and this frame.
        text_len = offset - st
        st_idx = llm_pos_ids_list[-1].max() + 1 if len(llm_pos_ids_list) > 0 else 0
        llm_pos_ids_list.append(
            np.broadcast_to(np.arange(text_len), (3, text_len)) + st_idx
        )

        # One logical frame worth of (t, h, w) indices; adding an offset to it
        # below always produces a fresh array, so sharing it is safe.
        frame_grid = np.indices((1, llm_grid_h, llm_grid_w)).reshape(3, -1)
        expected_tokens_per_frame = llm_grid_h * llm_grid_w

        # A placeholder may hold exactly one frame, several frames lumped
        # together (all tokens from multiple frames assigned to the 0-th frame
        # - see `Qwen3VLMultiModalProcessor.get_video_repl`), or fewer tokens
        # than one frame. Splitting the count into whole frames plus a
        # remainder covers all three, and guarantees that exactly
        # `actual_num_tokens` positions are emitted so the result stays
        # aligned with `input_tokens`.
        num_logical_frames, remainder = divmod(
            actual_num_tokens, expected_tokens_per_frame
        )
        for _ in range(num_logical_frames):
            llm_pos_ids_list.append(frame_grid + text_len + st_idx)
            st_idx = llm_pos_ids_list[-1].max() + 1
            text_len = 0  # No text between frames within the lump
        if remainder > 0:
            # Partial frame: take the first `remainder` grid positions.
            llm_pos_ids_list.append(frame_grid[:, :remainder] + text_len + st_idx)

        st = offset + actual_num_tokens

    # Trailing text after the last vision segment.
    if st < len(input_tokens):
        st_idx = llm_pos_ids_list[-1].max() + 1 if len(llm_pos_ids_list) > 0 else 0
        text_len = len(input_tokens) - st
        llm_pos_ids_list.append(
            np.broadcast_to(np.arange(text_len), (3, text_len)) + st_idx
        )

    llm_positions = np.concatenate(llm_pos_ids_list, axis=1).reshape(3, -1)
    mrope_position_delta = (llm_positions.max() + 1 - len(input_tokens)).item()
    return torch.from_numpy(llm_positions), mrope_position_delta
def recompute_mrope_positions(
self,
input_ids: list[int],
multimodal_embeddings: tuple[torch.Tensor, ...],
multimodal_embeddings: MultiModalEmbeddings,
mrope_positions: torch.LongTensor,
num_computed_tokens: int,
) -> tuple[tuple[torch.Tensor, ...], torch.Tensor, int]:
) -> tuple[MultiModalEmbeddings, torch.Tensor, int]:
"""
Update part of input mrope positions (starting with
num_computed_tokens index). Original mrope_positions are computed
......@@ -1786,9 +2154,10 @@ class Qwen3VLForConditionalGeneration(
mrope_positions before we feed it to LLM.
Args:
input_ids: (N,) All input tokens of the prompt (Containing
entire sequence).
multimodal_embeddings: Tuple of multimodal embeddings.
input_ids: (N,) All input tokens of the prompt containing
entire sequence.
multimodal_embeddings: Tuple of multimodal embeddings that
fits into the prefill chunk that is being processed.
mrope_positions: Existing mrope positions (3, N) for entire
sequence
num_computed_tokens: A number of computed tokens so far.
......@@ -1797,10 +2166,26 @@ class Qwen3VLForConditionalGeneration(
Tuple of (multimodal_embeddings, mrope_positions,
mrope_position_delta).
"""
image_token_id = self.config.image_token_id
video_token_id = self.config.video_token_id
vision_start_token_id = self.config.vision_start_token_id
return self._recompute_mrope_positions(
input_ids=input_ids,
multimodal_embeddings=multimodal_embeddings,
mrope_positions=mrope_positions,
num_computed_tokens=num_computed_tokens,
image_token_id=self.config.image_token_id,
video_token_id=self.config.video_token_id,
vision_start_token_id=self.config.vision_start_token_id,
)
@staticmethod
def _recompute_mrope_positions(
input_ids: list[int],
multimodal_embeddings: MultiModalEmbeddings,
mrope_positions: torch.LongTensor,
num_computed_tokens: int,
vision_start_token_id: int,
image_token_id: int,
video_token_id: int,
) -> tuple[MultiModalEmbeddings, torch.Tensor, int]:
# Device
device = (
multimodal_embeddings[0].device
......@@ -1811,10 +2196,21 @@ class Qwen3VLForConditionalGeneration(
# Tensors
input_ids_t = torch.as_tensor(input_ids, device=device, dtype=torch.long)
mm_embeddings_out = [mm[:, :-4] for mm in multimodal_embeddings]
mm_embeddings_pos = [
mm[:, -4:].permute(1, 0).long() for mm in multimodal_embeddings
]
mm_embeddings_out = []
mm_embeddings_pos = []
# Strip position information from embeddings (last 5 channels)
# For Qwen3 VL, handle potentially empty frames (from unpacking)
for mm in multimodal_embeddings:
if mm.shape[0] > 0: # Only process non-empty frames
mm_embeddings_out.append(mm[:, :-5])
mm_embeddings_pos.append(mm[:, -5:].permute(1, 0).long())
else:
# Empty frame - keep as is
mm_embeddings_out.append(mm)
# Create empty position tensor with correct shape
mm_embeddings_pos.append(
torch.empty(5, 0, device=device, dtype=torch.long)
)
positions, mrope_positions_delta = recompute_mrope_positions(
input_ids_t,
......@@ -1828,107 +2224,14 @@ class Qwen3VLForConditionalGeneration(
return tuple(mm_embeddings_out), positions, mrope_positions_delta
def get_mrope_input_positions(
self,
input_tokens: list[int],
mm_features: list[MultiModalFeatureSpec],
) -> tuple[torch.Tensor, int]:
# Pre-collect actual frame token counts for EVS mode
frame_token_counts_map = {}
for mm_feature in mm_features:
if mm_feature.modality == "video":
is_evs_enabled = (
hasattr(self, "video_pruning_rate")
and self.video_pruning_rate is not None
and self.video_pruning_rate > 0.0
)
if is_evs_enabled:
t = mm_feature.data["video_grid_thw"].data.tolist()[0]
token_counts = self._get_actual_frame_token_counts(
mm_feature.mm_position, t
)
assert token_counts is not None, (
"EVS enabled but failed to extract frame token counts "
"from is_embed mask"
)
frame_token_counts_map[mm_feature.mm_position.offset] = token_counts
llm_pos_ids_list = []
st = 0
frame_counts_idx = {}
for offset, llm_grid_h, llm_grid_w in self.iter_mm_grid_hw(
input_tokens, mm_features
):
text_len = offset - st
st_idx = llm_pos_ids_list[-1].max() + 1 if len(llm_pos_ids_list) > 0 else 0
# Determine actual token count for this frame
base_offset = None
for feat_offset in frame_token_counts_map:
if offset >= feat_offset:
base_offset = feat_offset
if base_offset is not None:
# EVS mode: use actual token count from is_embed mask
assert base_offset in frame_token_counts_map, (
f"Found base_offset {base_offset} but not in frame_token_counts_map"
)
if base_offset not in frame_counts_idx:
frame_counts_idx[base_offset] = 0
counts = frame_token_counts_map[base_offset]
idx = frame_counts_idx[base_offset]
assert idx < len(counts), (
f"EVS frame index {idx} out of range (total frames: {len(counts)})"
)
actual_frame_tokens = counts[idx]
frame_counts_idx[base_offset] += 1
else:
# Non-EVS mode (or image): use theoretical grid size
actual_frame_tokens = llm_grid_h * llm_grid_w
# Add text segment
text_positions = (
np.broadcast_to(np.arange(text_len), (3, text_len)) + st_idx
)
llm_pos_ids_list.append(text_positions)
st_idx += text_len
# Add frame segment with actual token count (not theoretical)
grid_indices = np.indices((1, llm_grid_h, llm_grid_w)).reshape(3, -1)
# Only take the first actual_frame_tokens positions
frame_positions = grid_indices[:, :actual_frame_tokens] + st_idx
llm_pos_ids_list.append(frame_positions)
# Update st using actual token count
st = offset + actual_frame_tokens
# Handle final text segment
if st < len(input_tokens):
st_idx = llm_pos_ids_list[-1].max() + 1 if len(llm_pos_ids_list) > 0 else 0
text_len = len(input_tokens) - st
final_text_positions = (
np.broadcast_to(np.arange(text_len), (3, text_len)) + st_idx
)
llm_pos_ids_list.append(final_text_positions)
llm_positions = np.concatenate(llm_pos_ids_list, axis=1).reshape(3, -1)
mrope_position_delta = (llm_positions.max() + 1 - len(input_tokens)).item()
return torch.from_numpy(llm_positions), mrope_position_delta
def embed_multimodal(self, **kwargs: object) -> MultiModalEmbeddings | None:
mm_input_by_modality = self._parse_and_validate_multimodal_inputs(**kwargs)
if not mm_input_by_modality:
return None
# The result multimodal_embeddings is tuple of tensors, with each
# tensor correspoending to a multimodal data item (image or video).
multimodal_embeddings: tuple[torch.Tensor, ...] = ()
# tensor corresponding to a multimodal data item (image or video).
multimodal_embeddings: list[torch.Tensor] = []
# NOTE: It is important to iterate over the keys in this dictionary
# to preserve the order of the modalities.
......@@ -1936,19 +2239,20 @@ class Qwen3VLForConditionalGeneration(
multimodal_input = mm_input_by_modality[modality]
if modality == "image":
image_embeddings = self._process_image_input(multimodal_input)
if self.is_multimodal_pruning_enabled:
image_embeddings = self._postprocess_image_embeds_evs(
image_embeddings, multimodal_input
)
multimodal_embeddings += tuple(image_embeddings)
image_embeddings = self._postprocess_image_embeds_evs(
image_embeddings, multimodal_input
)
multimodal_embeddings.extend(image_embeddings)
if modality == "video":
video_embeddings = self._process_video_input(multimodal_input)
if self.is_multimodal_pruning_enabled:
video_embeddings = self._postprocess_video_embeds_evs(
video_embeddings, multimodal_input
)
multimodal_embeddings += tuple(video_embeddings)
return multimodal_embeddings
multimodal_embeddings.extend(video_embeddings)
embeddings_tuple = tuple(multimodal_embeddings)
return embeddings_tuple
def _compute_deepstack_embeds(
self,
......@@ -2128,3 +2432,8 @@ class Qwen3VLForConditionalGeneration(
vision_config = hf_config.vision_config
merge_size = vision_config.spatial_merge_size
return num_vision_tokens // merge_size**2
@lru_cache
def _cached_tensor(x, device) -> torch.Tensor:
return torch.tensor(x, device=device)
......@@ -45,6 +45,7 @@ from vllm.model_executor.model_loader.weight_utils import (
)
from vllm.multimodal import MULTIMODAL_REGISTRY
from vllm.sequence import IntermediateTensors
from vllm.tokenizers.registry import cached_tokenizer_from_config
from .interfaces import MixtureOfExperts
from .qwen3_moe import (
......@@ -415,6 +416,7 @@ class Qwen3VLMoeForConditionalGeneration(
multimodal_config = vllm_config.model_config.multimodal_config
self.config = config
self._tokenizer = cached_tokenizer_from_config(vllm_config.model_config)
self.multimodal_config = multimodal_config
self.use_data_parallel = multimodal_config.mm_encoder_tp_mode == "data"
self.video_pruning_rate = multimodal_config.video_pruning_rate
......
......@@ -170,9 +170,9 @@ def recompute_mrope_positions(
multimodal_embeddings may contain zero, some or even some part of all
multimodal_embeddings for a given prompt.
Each multimodal_positions has 4 extra channels
(First 3 channels corresponds to original 3 mrope positions, last channel
is the maximum width of the media repeated). Provided multimodal_positions
Each multimodal_positions has 4 or 5 extra channels
(first 3 channels correspond to the original 3 mrope positions;
remaining channels vary by model — see below). Provided multimodal_positions
do not reflect location of media position in sequence - they are computed
like the media is in the 0-th position in the sequence.
......@@ -186,6 +186,16 @@ def recompute_mrope_positions(
Args:
input_ids: (N,) All input tokens of the prompt (entire sequence).
multimodal_positions: List of mrope positions for each media.
If a given element is of shape (4, N), it is assumed to only describe
positions for video / image embeddings. This is the case of e.g. Qwen2.5 VL,
where each multimodal input is a contiguous chunk of embeddings.
The expected channels are [t, h, w, max_width].
If it is of shape (5, N), it is assumed to possibly describe positions for
both video / image embeddings, as well as text embeddings. This is the case
of e.g. Qwen3 VL, where each video input is comprised of individual
frames' embeddings, interleaved with embeddings for timestamp tokens,
and vision start / end tokens. The expected channels are
[t, h, w, is_vision_start, is_vision].
mrope_positions: Existing mrope positions (4, N) for entire sequence.
num_computed_tokens: A number of computed tokens so far.
vision_start_token_id: Token indicating start of vision media.
......@@ -233,6 +243,21 @@ def recompute_mrope_positions(
# - Current prefill chunk has no vision start indexes at all
# - Vision start token appeared in previous prefill round
# - Regular case
has_video_tokens = False
num_timestamp_tokens = 0
if mm_pos.shape[0] == 5 and mm_pos.shape[1] > 0:
# mm_pos[4, :] indicates which positions are for video embeddings.
# If there are no video embeddings, skip timestamp adjustment.
has_video_tokens = torch.any(mm_pos[4, :]).item()
if has_video_tokens:
# Channel 3 flags VISION_START tokens. Timestamp tokens
# precede the first VISION_START, so its index gives us the
# exact timestamp count. This is robust even when early
# frames have all their video tokens pruned (which would
# push argmax(channel 4) far into a later frame).
first_vs = (mm_pos[3, :] == 1).nonzero(as_tuple=True)[0]
num_timestamp_tokens = first_vs[0].item() if len(first_vs) > 0 else 0
seen_vision_start_indices = vision_start_indices[
vision_start_indices < num_computed_tokens
]
......@@ -249,6 +274,18 @@ def recompute_mrope_positions(
in_the_middle_of_media = (
seen_mm_tokens > seem_mm_tokens_before_last_vision_start
)
# For Qwen3 VL, we can be inside a media segment even before any
# video tokens appear (timestamp tokens are text). If we've passed
# the last vision_start token but haven't reached the first video
# embedding, treat this as "in the middle of media".
if (
not in_the_middle_of_media
and has_video_tokens
and num_computed_tokens > last_vision_start_token
and num_computed_tokens
<= last_vision_start_token + num_timestamp_tokens + 1
):
in_the_middle_of_media = True
if in_the_middle_of_media:
mm_embeddings_seen = (
......@@ -274,14 +311,39 @@ def recompute_mrope_positions(
mm_embeddings_seen = 0
global_mm_start = next_vision_start_token
# Offset right after vision_start_token
base = positions[-1, global_mm_start] + 1
local_start = global_mm_start + 1 + mm_embeddings_seen
# For Qwen3 VL, mm_pos includes timestamp tokens before vision_start
# when starting a new media. Adjust global_mm_start to point to where
# the sequence actually begins (before timestamp tokens).
adjusted_for_timestamps = False
if mm_pos.shape[0] == 5 and mm_embeddings_seen == 0 and has_video_tokens:
# NOTE: -1 is because there is a vision start token right after
# timestamp tokens before any video embeddings appear.
# Adjust global_mm_start to point to the first timestamp token
# instead of the vision_start token.
global_mm_start -= num_timestamp_tokens
adjusted_for_timestamps = True
# Offset calculation depends on whether we adjusted for timestamp tokens
if adjusted_for_timestamps:
# Start from position before the first timestamp token
base = positions[-1, global_mm_start - 1] + 1
local_start = global_mm_start + mm_embeddings_seen
else:
# Original logic: start after vision_start_token
base = positions[-1, global_mm_start] + 1
local_start = global_mm_start + 1 + mm_embeddings_seen
local_end = local_start + mm_pos.shape[1]
positions[:, local_start:local_end] = mm_pos[0:3] + base
# mm_pos[3, 0] is the max width of the media
offset = mm_pos[3, 0] + base
# For Qwen3 VL (5-channel), use the maximum position reached across
# all tokens (both video and text) in all dimensions (t, h, w).
# For Qwen2.5 VL (4-channel), mm_pos[3, 0] is the max width.
if mm_pos.shape[0] == 5:
offset = mm_pos[0:3, :].max() + base + 1
else:
offset = mm_pos[3, 0] + base
text_pos_sum = torch.cumsum(text_mask[local_end:].long(), dim=0)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment