"examples/vscode:/vscode.git/clone" did not exist on "c222f47992ce0bbcd3ccbce24736e045d8689be8"
Unverified Commit 70c73df6 authored by William Zhang's avatar William Zhang Committed by GitHub
Browse files

[Bugfix] Fix EVS implementation for Qwen3 VL (#33607)


Signed-off-by: default avatar2ez4bz <133824995+2ez4bz@users.noreply.github.com>
parent 9a9d4424
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import dataclasses
import random
from dataclasses import dataclass
import pytest
import torch
from vllm.model_executor.models.qwen3_vl import Qwen3VLForConditionalGeneration
from vllm.multimodal.inputs import (
MultiModalFeatureSpec,
MultiModalFieldElem,
MultiModalKwargsItem,
PlaceholderRange,
)
@pytest.fixture(autouse=True, scope="module")
def _force_cpu_default_device():
# _get_mrope_input_positions returns CPU tensors (via torch.from_numpy).
# Ensure the default device is CPU so the rest of the test tensors match.
original = torch.get_default_device()
torch.set_default_device("cpu")
yield
torch.set_default_device(original)
IMAGE_TOKEN_ID = 999
VIDEO_TOKEN_ID = 888
VISION_START_TOKEN_ID = 777
VISION_END_TOKEN_ID = 778
@dataclass
class DummyVisionConfig:
spatial_merge_size: int = 1
@dataclass
class DummyConfig:
image_token_id: int = IMAGE_TOKEN_ID
video_token_id: int = VIDEO_TOKEN_ID
vision_start_token_id: int = VISION_START_TOKEN_ID
vision_end_token_id: int = VISION_END_TOKEN_ID
vision_config: DummyVisionConfig = dataclasses.field(
default_factory=DummyVisionConfig
)
def make_video_embedding(
t, h, w, interleave_text_tokens: tuple[int, int], video_pruning_rate: float = 0.0
):
"""
Helper function to make a video embedding for a given video size and pruning rate.
Args:
t: Number of frames.
h: Number of rows.
w: Number of columns.
interleave_text_tokens: Tuple of minimum and maximum number of text tokens to
interleave with the video.
video_pruning_rate: Pruning rate for the video.
Returns:
Tuple of (unpruned_tokens_sequence, pruned_tokens_sequence, retention_mask)
"""
unpruned_tokens_sequence = []
population = list(range(1, 100))
for _ in range(t):
num_prefix_tokens = random.randint(
interleave_text_tokens[0], interleave_text_tokens[1]
)
prefix_tokens = random.choices(population, k=num_prefix_tokens)
vision_tokens = (
[VISION_START_TOKEN_ID] + [VIDEO_TOKEN_ID] * h * w + [VISION_END_TOKEN_ID]
)
unpruned_tokens_sequence.extend(prefix_tokens)
unpruned_tokens_sequence.extend(vision_tokens)
unpruned_tokens_sequence = torch.tensor(unpruned_tokens_sequence, dtype=torch.long)
video_token_mask = unpruned_tokens_sequence == VIDEO_TOKEN_ID
pruning_mask = torch.bernoulli(video_token_mask.float() * video_pruning_rate).bool() # type: ignore[attr-defined]
# Sanity check that we don't prune what should not be pruned.
assert not pruning_mask[~video_token_mask].any()
retention_mask = ~pruning_mask
pruned_tokens_sequence = unpruned_tokens_sequence[retention_mask]
return unpruned_tokens_sequence, pruned_tokens_sequence, retention_mask
@pytest.mark.parametrize("spatial_merge_size", [1, 2])
@pytest.mark.parametrize("grid_thw", [[3, 8, 7], [128, 10, 12]])
@pytest.mark.parametrize("num_prefix_tokens", [1, 11])
@pytest.mark.parametrize("num_suffix_tokens", [0, 7])
@pytest.mark.parametrize("video_pruning_rate", [0, 0.25, 0.75])
@pytest.mark.parametrize("interleave_text_tokens", [(0, 0), (1, 4)])
def test_match_qwen3vl_mrope_evs_on(
spatial_merge_size: int,
num_prefix_tokens: int,
grid_thw: tuple[int, int, int],
num_suffix_tokens: int,
video_pruning_rate: float,
interleave_text_tokens: tuple[int, int],
):
hf_config = DummyConfig()
hf_config.vision_config.spatial_merge_size = spatial_merge_size
t, h, w = grid_thw
population = list(range(1, 100))
prefix_tokens = random.choices(population, k=num_prefix_tokens)
suffix_tokens = random.choices(population, k=num_suffix_tokens)
video_tokens, video_tokens_pruned, retention_mask = make_video_embedding(
t,
h // spatial_merge_size,
w // spatial_merge_size,
interleave_text_tokens=interleave_text_tokens,
video_pruning_rate=video_pruning_rate,
)
assert len(video_tokens) == len(retention_mask)
input_tokens = prefix_tokens + video_tokens.tolist() + suffix_tokens
input_tokens_pruned = prefix_tokens + video_tokens_pruned.tolist() + suffix_tokens
whole_sequence_retention_mask = torch.cat(
[
torch.ones(len(prefix_tokens), dtype=torch.bool),
retention_mask,
torch.ones(len(suffix_tokens), dtype=torch.bool),
],
dim=0,
)
# Build the GT mrope for unpruned input.
mm_feature = MultiModalFeatureSpec(
data=MultiModalKwargsItem(
{
"video_grid_thw": MultiModalFieldElem(
data=torch.tensor(grid_thw),
field=None, # HACK.
),
}
),
modality="video",
identifier="DUMMY",
mm_position=PlaceholderRange(offset=0, length=len(input_tokens)),
)
expected_mrope, _ = Qwen3VLForConditionalGeneration._get_mrope_input_positions(
input_tokens=input_tokens,
mm_features=[mm_feature],
config=hf_config,
)
# Compute mrope for a video-only media (unpruned).
mm_feature = MultiModalFeatureSpec(
data=MultiModalKwargsItem(
{
"video_grid_thw": MultiModalFieldElem(
data=torch.tensor(grid_thw),
field=None, # HACK.
),
}
),
modality="video",
identifier="DUMMY",
mm_position=PlaceholderRange(offset=0, length=video_tokens.numel()),
)
video_mrope, _ = Qwen3VLForConditionalGeneration._get_mrope_input_positions(
input_tokens=video_tokens.tolist(),
mm_features=[mm_feature],
config=hf_config,
)
video_mrope = video_mrope.permute(1, 0) # [N, 3]
hidden_size = 16
is_video_embed = torch.isin(
video_tokens_pruned, torch.tensor([VIDEO_TOKEN_ID], dtype=torch.long)
)
expanded_positions = torch.full(
(len(video_tokens_pruned), 5),
fill_value=-100,
device=video_mrope.device,
dtype=torch.long,
)
expanded_positions[is_video_embed, :3] = video_mrope[retention_mask][is_video_embed]
expanded_positions[~is_video_embed, :3] = video_mrope[retention_mask][
~is_video_embed
]
is_vision_start = video_tokens_pruned == VISION_START_TOKEN_ID
expanded_positions[..., 3] = is_vision_start
expanded_positions[..., 4] = is_video_embed
# Check that all positions were filled, since we initialized them as negative.
assert (expanded_positions >= 0).all()
video_embeddings = torch.empty(
(len(video_tokens_pruned), hidden_size), device=video_mrope.device
)
video_embeddings = torch.cat(
[
video_embeddings,
expanded_positions.float(),
],
dim=1,
)
multimodal_embeddings = [video_embeddings]
expected_mrope_masked = expected_mrope[:, whole_sequence_retention_mask]
# Initialize computed_mrope with sequential positions for all prefix tokens
computed_mrope = torch.empty((3, len(input_tokens_pruned)), dtype=torch.long)
computed_mrope[:, 0 : len(prefix_tokens)] = expected_mrope[
:, 0 : len(prefix_tokens)
]
# Paranoia check that computed_mrope is wrong.
assert not torch.equal(computed_mrope, expected_mrope_masked)
_, actual_mrope, _ = Qwen3VLForConditionalGeneration._recompute_mrope_positions(
input_ids=input_tokens_pruned,
multimodal_embeddings=multimodal_embeddings,
mrope_positions=computed_mrope,
num_computed_tokens=len(prefix_tokens),
vision_start_token_id=hf_config.vision_start_token_id,
image_token_id=hf_config.image_token_id,
video_token_id=hf_config.video_token_id,
)
assert torch.equal(actual_mrope, expected_mrope_masked)
...@@ -195,6 +195,8 @@ class Qwen2_5_VLVideoPixelInputs(TensorSchema): ...@@ -195,6 +195,8 @@ class Qwen2_5_VLVideoPixelInputs(TensorSchema):
- second_per_grid_ts: The video time interval (in seconds) for each - second_per_grid_ts: The video time interval (in seconds) for each
grid along the temporal dimension in the 3D position IDs. Returned grid along the temporal dimension in the 3D position IDs. Returned
when `videos` is not `None`. when `videos` is not `None`.
- timestamps: List of timestamp values (in seconds) for each frame
after merging. Length equals the temporal dimension after merging.
""" """
type: Literal["pixel_values_videos"] type: Literal["pixel_values_videos"]
...@@ -214,6 +216,8 @@ class Qwen2_5_VLVideoPixelInputs(TensorSchema): ...@@ -214,6 +216,8 @@ class Qwen2_5_VLVideoPixelInputs(TensorSchema):
TensorShape("nv"), TensorShape("nv"),
] ]
timestamps: list[list[float]] | None = None
class Qwen2_5_VLVideoEmbeddingInputs(TensorSchema): class Qwen2_5_VLVideoEmbeddingInputs(TensorSchema):
""" """
...@@ -232,6 +236,8 @@ class Qwen2_5_VLVideoEmbeddingInputs(TensorSchema): ...@@ -232,6 +236,8 @@ class Qwen2_5_VLVideoEmbeddingInputs(TensorSchema):
- second_per_grid_ts: The video time interval (in seconds) for each - second_per_grid_ts: The video time interval (in seconds) for each
grid along the temporal dimension in the 3D position IDs. Returned grid along the temporal dimension in the 3D position IDs. Returned
when `videos` is not `None`. when `videos` is not `None`.
- timestamps: List of timestamp values (in seconds) for each frame
after merging. Length equals the temporal dimension after merging.
""" """
type: Literal["video_embeds"] type: Literal["video_embeds"]
...@@ -250,6 +256,7 @@ class Qwen2_5_VLVideoEmbeddingInputs(TensorSchema): ...@@ -250,6 +256,7 @@ class Qwen2_5_VLVideoEmbeddingInputs(TensorSchema):
torch.Tensor | None, torch.Tensor | None,
TensorShape("nv"), TensorShape("nv"),
] = None ] = None
timestamps: list[list[float]] | None = None
Qwen2_5_VLVideoInputs: TypeAlias = ( Qwen2_5_VLVideoInputs: TypeAlias = (
......
...@@ -755,6 +755,7 @@ def _create_qwen2vl_field_factory( ...@@ -755,6 +755,7 @@ def _create_qwen2vl_field_factory(
"video", video_embed_grid_sizes "video", video_embed_grid_sizes
), ),
video_grid_thw=MultiModalFieldConfig.batched("video", keep_on_cpu=True), video_grid_thw=MultiModalFieldConfig.batched("video", keep_on_cpu=True),
timestamps=MultiModalFieldConfig.batched("video", keep_on_cpu=True),
) )
return _qwen2vl_field_config return _qwen2vl_field_config
......
...@@ -628,6 +628,9 @@ class Qwen3_5MoeForCausalLM(Qwen3_5ForCausalLMBase, QwenNextMixtureOfExperts): ...@@ -628,6 +628,9 @@ class Qwen3_5MoeForCausalLM(Qwen3_5ForCausalLMBase, QwenNextMixtureOfExperts):
dummy_inputs=Qwen3VLDummyInputsBuilder, dummy_inputs=Qwen3VLDummyInputsBuilder,
) )
class Qwen3_5ForConditionalGeneration(Qwen3VLForConditionalGeneration, IsHybrid): class Qwen3_5ForConditionalGeneration(Qwen3VLForConditionalGeneration, IsHybrid):
# Qwen3.5 does not support multimodal pruning (EVS).
supports_multimodal_pruning = False
packed_modules_mapping = Qwen3VLForConditionalGeneration.packed_modules_mapping | { packed_modules_mapping = Qwen3VLForConditionalGeneration.packed_modules_mapping | {
"in_proj_qkvz": ["in_proj_qkv", "in_proj_z"], "in_proj_qkvz": ["in_proj_qkv", "in_proj_z"],
"in_proj_ba": ["in_proj_b", "in_proj_a"], "in_proj_ba": ["in_proj_b", "in_proj_a"],
...@@ -643,10 +646,8 @@ class Qwen3_5ForConditionalGeneration(Qwen3VLForConditionalGeneration, IsHybrid) ...@@ -643,10 +646,8 @@ class Qwen3_5ForConditionalGeneration(Qwen3VLForConditionalGeneration, IsHybrid)
self.config = config self.config = config
self.multimodal_config = multimodal_config self.multimodal_config = multimodal_config
self.use_data_parallel = multimodal_config.mm_encoder_tp_mode == "data" self.use_data_parallel = multimodal_config.mm_encoder_tp_mode == "data"
self.video_pruning_rate = multimodal_config.video_pruning_rate # Qwen3.5 does not support multimodal pruning (EVS).
self.is_multimodal_pruning_enabled = ( self.is_multimodal_pruning_enabled = False
multimodal_config.is_multimodal_pruning_enabled()
)
with self._mark_tower_model(vllm_config, {"image", "video"}): with self._mark_tower_model(vllm_config, {"image", "video"}):
self.visual = Qwen3_VisionTransformer( self.visual = Qwen3_VisionTransformer(
...@@ -693,6 +694,12 @@ class Qwen3_5ForConditionalGeneration(Qwen3VLForConditionalGeneration, IsHybrid) ...@@ -693,6 +694,12 @@ class Qwen3_5ForConditionalGeneration(Qwen3VLForConditionalGeneration, IsHybrid)
return inputs_embeds return inputs_embeds
def recompute_mrope_positions(self, *args, **kwargs):
raise NotImplementedError(
"Qwen3.5 does not support multimodal pruning (EVS). "
"recompute_mrope_positions should never be called."
)
def forward( def forward(
self, self,
input_ids: torch.Tensor, input_ids: torch.Tensor,
...@@ -851,10 +858,8 @@ class Qwen3_5MoeForConditionalGeneration( ...@@ -851,10 +858,8 @@ class Qwen3_5MoeForConditionalGeneration(
self.config = config self.config = config
self.multimodal_config = multimodal_config self.multimodal_config = multimodal_config
self.use_data_parallel = multimodal_config.mm_encoder_tp_mode == "data" self.use_data_parallel = multimodal_config.mm_encoder_tp_mode == "data"
self.video_pruning_rate = multimodal_config.video_pruning_rate # Qwen3.5 does not support multimodal pruning (EVS).
self.is_multimodal_pruning_enabled = ( self.is_multimodal_pruning_enabled = False
multimodal_config.is_multimodal_pruning_enabled()
)
with self._mark_tower_model(vllm_config, {"image", "video"}): with self._mark_tower_model(vllm_config, {"image", "video"}):
self.visual = Qwen3_VisionTransformer( self.visual = Qwen3_VisionTransformer(
......
...@@ -79,6 +79,7 @@ from vllm.multimodal.inputs import ( ...@@ -79,6 +79,7 @@ from vllm.multimodal.inputs import (
MultiModalDataDict, MultiModalDataDict,
MultiModalFeatureSpec, MultiModalFeatureSpec,
MultiModalFieldConfig, MultiModalFieldConfig,
MultiModalFieldElem,
MultiModalKwargsItem, MultiModalKwargsItem,
MultiModalKwargsItems, MultiModalKwargsItems,
PlaceholderRange, PlaceholderRange,
...@@ -93,6 +94,8 @@ from vllm.multimodal.processing import ( ...@@ -93,6 +94,8 @@ from vllm.multimodal.processing import (
PromptUpdateDetails, PromptUpdateDetails,
) )
from vllm.sequence import IntermediateTensors from vllm.sequence import IntermediateTensors
from vllm.tokenizers.protocol import TokenizerLike
from vllm.tokenizers.registry import cached_tokenizer_from_config
from vllm.utils.collection_utils import is_list_of from vllm.utils.collection_utils import is_list_of
from vllm.utils.math_utils import round_up from vllm.utils.math_utils import round_up
...@@ -763,7 +766,6 @@ class Qwen3VLProcessingInfo(Qwen2VLProcessingInfo): ...@@ -763,7 +766,6 @@ class Qwen3VLProcessingInfo(Qwen2VLProcessingInfo):
def _get_video_second_idx( def _get_video_second_idx(
self, self,
metadata: dict[str, Any], metadata: dict[str, Any],
out_item: MultiModalKwargsItem,
do_sample_frames: bool | None = None, do_sample_frames: bool | None = None,
sampled_fps: float | None = None, sampled_fps: float | None = None,
) -> list[int]: ) -> list[int]:
...@@ -956,6 +958,7 @@ class Qwen3VLMultiModalProcessor(BaseMultiModalProcessor[Qwen3VLProcessingInfo]) ...@@ -956,6 +958,7 @@ class Qwen3VLMultiModalProcessor(BaseMultiModalProcessor[Qwen3VLProcessingInfo])
if videos := mm_data.pop("videos", []): if videos := mm_data.pop("videos", []):
video_grid_thw_lst = [] video_grid_thw_lst = []
pixel_values_videos_lst = [] pixel_values_videos_lst = []
timestamps_per_video = []
for item in videos: for item in videos:
video_array, metadata = item video_array, metadata = item
...@@ -979,6 +982,14 @@ class Qwen3VLMultiModalProcessor(BaseMultiModalProcessor[Qwen3VLProcessingInfo]) ...@@ -979,6 +982,14 @@ class Qwen3VLMultiModalProcessor(BaseMultiModalProcessor[Qwen3VLProcessingInfo])
**{k: metadata[k] for k in metadata if k != "do_sample_frames"} **{k: metadata[k] for k in metadata if k != "do_sample_frames"}
) )
# Compute timestamps here where we have access to metadata
timestamps = self.info._get_video_second_idx(
metadata=metadata,
do_sample_frames=video_mm_kwargs["do_sample_frames"],
sampled_fps=video_mm_kwargs.get("fps"),
)
timestamps_per_video.append(timestamps)
video_mm_data = dict() video_mm_data = dict()
video_mm_data["videos"] = [[video_array]] video_mm_data["videos"] = [[video_array]]
video_mm_data["video_metadata"] = [[metadata]] video_mm_data["video_metadata"] = [[metadata]]
...@@ -989,6 +1000,49 @@ class Qwen3VLMultiModalProcessor(BaseMultiModalProcessor[Qwen3VLProcessingInfo]) ...@@ -989,6 +1000,49 @@ class Qwen3VLMultiModalProcessor(BaseMultiModalProcessor[Qwen3VLProcessingInfo])
mm_kwargs=video_mm_kwargs, mm_kwargs=video_mm_kwargs,
tok_kwargs=tok_kwargs, tok_kwargs=tok_kwargs,
) )
merge_size = processor.video_processor.merge_size
# Get video grid info for EVS calculation.
video_grid_thw = video_outputs["video_grid_thw"]
num_frames = int(video_grid_thw[0, 0])
tokens_per_frame_base = int(video_grid_thw[0, 1:].prod()) // (
merge_size**2
)
# Apply EVS if enabled.
video_pruning_rate = self.info.ctx.get_mm_config().video_pruning_rate
if video_pruning_rate is not None and video_pruning_rate > 0.0:
num_tokens = compute_retained_tokens_count(
tokens_per_frame=tokens_per_frame_base,
num_frames=num_frames,
q=video_pruning_rate,
)
# Here we just need placeholders that won't actually be replaced -
# we just need to make sure the total number of tokens is correct
# assign all tokens to the first frame.
tokens_per_frame = [num_tokens] + [0] * (num_frames - 1)
select_token_id = False
else:
tokens_per_frame = [tokens_per_frame_base] * num_frames
select_token_id = True
# Generate the video replacement with EVS-adjusted token counts
tokenizer = self.info.get_tokenizer()
hf_config = self.info.get_hf_config()
video_repl = Qwen3VLMultiModalProcessor.get_video_repl(
tokens_per_frame=tokens_per_frame,
timestamps=timestamps,
tokenizer=tokenizer,
vision_start_token_id=hf_config.vision_start_token_id,
vision_end_token_id=hf_config.vision_end_token_id,
video_token_id=hf_config.video_token_id,
select_token_id=select_token_id,
)
# Convert token IDs to text for the HF processor flow
video_placeholder = tokenizer.decode(
video_repl.full, skip_special_tokens=False
)
input_ids = video_outputs.pop("input_ids") input_ids = video_outputs.pop("input_ids")
video_placeholder = processor.tokenizer.batch_decode(input_ids)[0] video_placeholder = processor.tokenizer.batch_decode(input_ids)[0]
prompt = prompt.replace( prompt = prompt.replace(
...@@ -1002,6 +1056,7 @@ class Qwen3VLMultiModalProcessor(BaseMultiModalProcessor[Qwen3VLProcessingInfo]) ...@@ -1002,6 +1056,7 @@ class Qwen3VLMultiModalProcessor(BaseMultiModalProcessor[Qwen3VLProcessingInfo])
video_outputs = dict( video_outputs = dict(
pixel_values_videos=torch.cat(pixel_values_videos_lst), pixel_values_videos=torch.cat(pixel_values_videos_lst),
video_grid_thw=torch.cat(video_grid_thw_lst), video_grid_thw=torch.cat(video_grid_thw_lst),
timestamps=timestamps_per_video,
) )
else: else:
video_outputs = dict() video_outputs = dict()
...@@ -1057,60 +1112,42 @@ class Qwen3VLMultiModalProcessor(BaseMultiModalProcessor[Qwen3VLProcessingInfo]) ...@@ -1057,60 +1112,42 @@ class Qwen3VLMultiModalProcessor(BaseMultiModalProcessor[Qwen3VLProcessingInfo])
grid_thw = out_item["video_grid_thw"].data grid_thw = out_item["video_grid_thw"].data
assert isinstance(grid_thw, torch.Tensor) assert isinstance(grid_thw, torch.Tensor)
video, metadata = mm_items["video"][item_idx]
do_sample_frames = hf_processor_mm_kwargs.get("do_sample_frames")
sampled_fps = hf_processor_mm_kwargs.get("fps") sampled_fps = hf_processor_mm_kwargs.get("fps")
if is_list_of(sampled_fps, float): if is_list_of(sampled_fps, float):
sampled_fps = sampled_fps[item_idx] sampled_fps = sampled_fps[item_idx]
timestamps = self.info._get_video_second_idx(
metadata, out_item, do_sample_frames, sampled_fps
)
timestamps = out_item["timestamps"].data
assert len(timestamps) == grid_thw[0], ( assert len(timestamps) == grid_thw[0], (
f"The timestamps length({len(timestamps)}) should be equal " f"The timestamps length({len(timestamps)}) should be equal "
f"video length ({grid_thw[0]})." f"video length ({grid_thw[0]})."
) )
frames_idx_token = [ # Compute tokens per frame, with EVS support
tokenizer.encode(f"<{curr_time:.1f} seconds>", add_special_tokens=False) num_frames = int(grid_thw[0])
for curr_time in timestamps tokens_per_frame_base = int(grid_thw[1:].prod()) // merge_length
]
tokens_per_frame = int(grid_thw[1:].prod()) // merge_length
per_frame_token_counts = [tokens_per_frame for _ in frames_idx_token]
video_pruning_rate = self.info.ctx.get_mm_config().video_pruning_rate video_pruning_rate = self.info.ctx.get_mm_config().video_pruning_rate
if video_pruning_rate is not None and video_pruning_rate > 0.0: if video_pruning_rate is not None and video_pruning_rate > 0.0:
total_retained = compute_retained_tokens_count( num_tokens = compute_retained_tokens_count(
tokens_per_frame, tokens_per_frame=tokens_per_frame_base,
len(frames_idx_token), num_frames=num_frames,
video_pruning_rate, q=video_pruning_rate,
) )
if len(frames_idx_token) == 0: tokens_per_frame = [num_tokens] + [0] * (num_frames - 1)
per_frame_token_counts = [] select_token_id = False
elif len(frames_idx_token) == 1: else:
per_frame_token_counts = [tokens_per_frame] tokens_per_frame = [tokens_per_frame_base] * num_frames
else: select_token_id = True
first_frame_tokens = tokens_per_frame
remaining_tokens = max(total_retained - first_frame_tokens, 0) return Qwen3VLMultiModalProcessor.get_video_repl(
base = remaining_tokens // (len(frames_idx_token) - 1) tokens_per_frame=tokens_per_frame,
remainder = remaining_tokens % (len(frames_idx_token) - 1) timestamps=timestamps,
per_frame_token_counts = [first_frame_tokens] tokenizer=tokenizer,
for frame_idx in range(1, len(frames_idx_token)): vision_start_token_id=vision_start_token_id,
extra = base + (1 if (frame_idx - 1) < remainder else 0) vision_end_token_id=vision_end_token_id,
per_frame_token_counts.append(extra) video_token_id=video_token_id,
select_token_id=select_token_id,
placeholder = [] )
for frame_idx, timestamp_tokens in enumerate(frames_idx_token):
placeholder.extend(timestamp_tokens)
tokens_this_frame = per_frame_token_counts[
frame_idx if frame_idx < len(per_frame_token_counts) else -1
]
placeholder.extend(
[vision_start_token_id]
+ [video_token_id] * tokens_this_frame
+ [vision_end_token_id]
)
return PromptUpdateDetails.select_token_id(placeholder, video_token_id)
return [ return [
PromptReplacement( PromptReplacement(
...@@ -1127,6 +1164,69 @@ class Qwen3VLMultiModalProcessor(BaseMultiModalProcessor[Qwen3VLProcessingInfo]) ...@@ -1127,6 +1164,69 @@ class Qwen3VLMultiModalProcessor(BaseMultiModalProcessor[Qwen3VLProcessingInfo])
), ),
] ]
@staticmethod
def get_video_repl(
*,
tokens_per_frame: list[int],
timestamps: list[float | int],
tokenizer: TokenizerLike,
vision_start_token_id: int,
vision_end_token_id: int,
video_token_id: int,
select_token_id: bool = False,
) -> PromptUpdateDetails[list[int]]:
"""Build prompt replacement for a video in Qwen3VL format.
The replacement structure for each frame is:
timestamp_tokens + vision_start_token + video_tokens + vision_end_token
Args:
tokens_per_frame: Number of video tokens per frame (can vary per frame for
EVS).
timestamps: List of timestamps in seconds for each frame
tokenizer: Tokenizer to encode timestamp strings
vision_start_token_id: Token ID for vision start marker
vision_end_token_id: Token ID for vision end marker
video_token_id: Token ID for video content
Returns:
PromptUpdateDetails with full token sequence
"""
assert len(timestamps) == len(tokens_per_frame), (
"timestamps and tokens_per_frame must have the same length"
)
# Tokenize timestamp strings independently to avoid tokenizer merging
# tokens across boundaries.
# TODO: switch to `_seq2tokens` which has some caching.
timestamp_token_ids = [
tokenizer.encode(f"<{timestamp:.1f} seconds>", add_special_tokens=False)
for timestamp in timestamps
]
# Build the full token sequence
all_token_ids = []
for frame_timestamp_ids, num_tokens in zip(
timestamp_token_ids, tokens_per_frame
):
# Add timestamp tokens
all_token_ids.extend(frame_timestamp_ids)
# Add vision tokens: vision_start + video_tokens + vision_end
all_token_ids.append(vision_start_token_id)
all_token_ids.extend([video_token_id] * num_tokens)
all_token_ids.append(vision_end_token_id)
if select_token_id:
return PromptUpdateDetails.select_token_id(all_token_ids, video_token_id)
# NOTE: we use `from_seq` instead of `select_token_id` because we want all
# tokens in the placeholder to be initially marked as candidates. Then
# in `get_input_embeddings``, we refine the mask to only replace
# `video_token_id` / `image_token_id`` positions with video/image embeddings,
# keeping text embeddings for timestamps and structural tokens.
return PromptUpdateDetails.from_seq(all_token_ids)
@support_torch_compile( @support_torch_compile(
dynamic_arg_dims={ dynamic_arg_dims={
...@@ -1280,6 +1380,7 @@ class Qwen3VLForConditionalGeneration( ...@@ -1280,6 +1380,7 @@ class Qwen3VLForConditionalGeneration(
multimodal_config = vllm_config.model_config.multimodal_config multimodal_config = vllm_config.model_config.multimodal_config
self.config = config self.config = config
self._tokenizer = cached_tokenizer_from_config(vllm_config.model_config)
self.multimodal_config = multimodal_config self.multimodal_config = multimodal_config
self.use_data_parallel = multimodal_config.mm_encoder_tp_mode == "data" self.use_data_parallel = multimodal_config.mm_encoder_tp_mode == "data"
self.video_pruning_rate = multimodal_config.video_pruning_rate self.video_pruning_rate = multimodal_config.video_pruning_rate
...@@ -1419,6 +1520,7 @@ class Qwen3VLForConditionalGeneration( ...@@ -1419,6 +1520,7 @@ class Qwen3VLForConditionalGeneration(
video_embeds = kwargs.pop("video_embeds", None) video_embeds = kwargs.pop("video_embeds", None)
video_grid_thw = kwargs.pop("video_grid_thw", None) video_grid_thw = kwargs.pop("video_grid_thw", None)
second_per_grid_ts = kwargs.pop("second_per_grid_ts", None) second_per_grid_ts = kwargs.pop("second_per_grid_ts", None)
timestamps = kwargs.pop("timestamps", None)
if pixel_values_videos is None and video_embeds is None: if pixel_values_videos is None and video_embeds is None:
return None return None
...@@ -1429,6 +1531,7 @@ class Qwen3VLForConditionalGeneration( ...@@ -1429,6 +1531,7 @@ class Qwen3VLForConditionalGeneration(
pixel_values_videos=pixel_values_videos, pixel_values_videos=pixel_values_videos,
video_grid_thw=video_grid_thw, video_grid_thw=video_grid_thw,
second_per_grid_ts=second_per_grid_ts, second_per_grid_ts=second_per_grid_ts,
timestamps=timestamps,
) )
if video_embeds is not None: if video_embeds is not None:
...@@ -1436,6 +1539,7 @@ class Qwen3VLForConditionalGeneration( ...@@ -1436,6 +1539,7 @@ class Qwen3VLForConditionalGeneration(
type="video_embeds", type="video_embeds",
video_embeds=video_embeds, video_embeds=video_embeds,
video_grid_thw=video_grid_thw, video_grid_thw=video_grid_thw,
timestamps=timestamps,
) )
def _process_image_input( def _process_image_input(
...@@ -1502,19 +1606,29 @@ class Qwen3VLForConditionalGeneration( ...@@ -1502,19 +1606,29 @@ class Qwen3VLForConditionalGeneration(
Returns: Returns:
Tuple of image embeddings for each image item. Tuple of image embeddings for each image item.
Resulting embeddings will have extra 4 channels for Resulting embeddings will have extra 5 channels for
computed mrope positions. computed mrope positions, consistent with video embeddings.
""" """
merge_size = self.visual.spatial_merge_size if self.is_multimodal_pruning_enabled:
grid_thw = image_input["image_grid_thw"] merge_size = self.visual.spatial_merge_size
grid_thw_list = grid_thw.tolist() grid_thw = image_input["image_grid_thw"]
image_embeds_out = [] grid_thw_list = grid_thw.tolist()
for emb, size in zip(image_embeds_split, grid_thw_list): image_embeds_out = []
positions = compute_mrope_for_media(size, merge_size).to(emb.device) for emb, size in zip(image_embeds_split, grid_thw_list):
emb = torch.cat([emb, positions], dim=1) positions = compute_mrope_for_media(size, merge_size).to(emb.device)
image_embeds_out.append(emb) positions = torch.cat(
image_embeds_split = image_embeds_out [
return tuple(image_embeds_split) positions,
torch.zeros_like(
positions[:, 0:1]
), # Dummy extra fifth channel
],
dim=1,
)
emb = torch.cat([emb, positions], dim=1)
image_embeds_out.append(emb)
image_embeds_split = tuple(image_embeds_out)
return image_embeds_split
def _postprocess_video_embeds_evs( def _postprocess_video_embeds_evs(
self, self,
...@@ -1531,62 +1645,218 @@ class Qwen3VLForConditionalGeneration( ...@@ -1531,62 +1645,218 @@ class Qwen3VLForConditionalGeneration(
Returns: Returns:
Tuple of video embeddings for each video item. Tuple of video embeddings for each video item.
Resulting embeddings will have extra 4 channels for Resulting embeddings will have extra 5 channels for computed mrope
computed mrope positions. positions, and whether the index corresponds to a video embedding.
""" """
grid_thw = video_input["video_grid_thw"] grid_thw = video_input["video_grid_thw"]
assert grid_thw.ndim == 2 assert grid_thw.ndim == 2
grid_thw_list = grid_thw.tolist() grid_thw_list = grid_thw.tolist()
merge_size = self.visual.spatial_merge_size merge_size = self.visual.spatial_merge_size
# Cast to long to match the original code # Apply EVS to each video.
# https://github.com/huggingface/transformers/blob/41980ce93e775f6c88500c51c8db7946fc6a2add/src/transformers/models/qwen2_5_vl/modular_qwen2_5_vl.py#L491 # noqa video_embeds_out = []
second_per_grid_ts = video_input.get("second_per_grid_ts") for video_idx, (emb, size) in enumerate(zip(video_embeds_split, grid_thw_list)):
if second_per_grid_ts is None: # Compute positions.
# For Qwen3-VL, second_per_grid_ts might not be available timestamps = video_input.timestamps[video_idx]
# Use default value of 1.0 for each video num_frames = len(timestamps)
second_per_grid_ts = torch.ones(len(grid_thw_list), dtype=torch.long)
t, h, w = size
if self.is_multimodal_pruning_enabled:
# For each video, compute retention mask using EVS.
# retention_mask: [11424].
retention_mask = compute_retention_mask(
emb,
size,
spatial_merge_size=self.visual.spatial_merge_size,
q=self.video_pruning_rate,
)
# Apply retention mask.
emb = emb[retention_mask]
# Calculate the actual number of retained tokens per frame.
num_frames, rows, cols = (
t,
h // merge_size,
w // merge_size,
)
retention_mask_thw = retention_mask.reshape(num_frames, rows, cols)
num_tokens_per_frame = (
retention_mask_thw.sum(dim=(1, 2)).long().tolist()
)
else:
feature_size = emb.shape[0] // num_frames
num_tokens_per_frame = [feature_size] * num_frames
retention_mask = None
emb = self._create_final_video_embeddings(
video_embeddings=emb,
num_tokens_per_frame=num_tokens_per_frame,
timestamps=timestamps,
video_grid_thw=size,
retention_mask=retention_mask,
)
video_embeds_out.append(emb)
return tuple(video_embeds_out)
def _create_final_video_embeddings(
self,
video_embeddings: torch.Tensor,
num_tokens_per_frame: list[int],
timestamps: list[float],
video_grid_thw: list[int],
retention_mask: torch.Tensor,
) -> torch.Tensor:
"""Create final embeddings that combine video embeddings with
text embeddings of indicator tokens.
These final embeddings contain:
- Actual video embeddings in positions corresponding to video content
- Text embeddings for indicator tokens (<img>, </img>, and
frame separation text) in their respective positions
These embeddings will replace the placeholder embeddings to create
input_embeds for the LLM.
"""
device = video_embeddings.device
# Generate video replacement token IDs using get_video_repl
# This tokenizes each frame separator independently, then uses pre-tokenized
# special tokens to ensure consistent tokenization regardless of
# num_tokens_per_frame values.
video_repl = Qwen3VLMultiModalProcessor.get_video_repl(
tokens_per_frame=num_tokens_per_frame,
tokenizer=self._tokenizer,
timestamps=timestamps,
vision_start_token_id=self.config.vision_start_token_id,
vision_end_token_id=self.config.vision_end_token_id,
video_token_id=self.config.video_token_id,
select_token_id=self.is_multimodal_pruning_enabled,
)
repl_token_ids = torch.tensor(video_repl.full, device=device)
embed_token_id = _cached_tensor(self.config.video_token_id, device=device)
is_video_embed = torch.isin(repl_token_ids, embed_token_id)
# Get text embeddings for indicator tokens (has only `visual_dim``).
text_embeddings = self.get_language_model().embed_input_ids(repl_token_ids)
if self.use_deepstack:
(
deepstack_input_embeds,
multimodal_embeddings,
) = self._compute_deepstack_embeds(
inputs_embeds=text_embeddings,
multimodal_embeddings=[video_embeddings],
is_multimodal=is_video_embed,
)
else: else:
second_per_grid_ts = second_per_grid_ts.long() deepstack_input_embeds = None
tokens_per_second = getattr(self.config.vision_config, "tokens_per_second", 1.0) multimodal_embeddings = [video_embeddings]
video_embeds_out = [] merged_embeddings = _merge_multimodal_embeddings(
for emb, size, video_second_per_grid_t in zip( inputs_embeds=text_embeddings,
video_embeds_split, grid_thw_list, second_per_grid_ts multimodal_embeddings=multimodal_embeddings,
): is_multimodal=is_video_embed,
# For each video, we compute retention mask using EVS )
retention_mask = compute_retention_mask(
emb, to_concat = [merged_embeddings]
size, if deepstack_input_embeds is not None:
spatial_merge_size=self.visual.spatial_merge_size, to_concat.append(
q=self.video_pruning_rate, deepstack_input_embeds.permute(1, 0, 2).reshape(
deepstack_input_embeds.shape[1], -1
)
) )
# Debug logging for EVS pruning expanded_positions = None
logger.debug( if self.is_multimodal_pruning_enabled:
"EVS: Video tokens pruned from %d to %d (T=%d,H=%d,W=%d, " is_vision_start = repl_token_ids.eq(self.config.vision_start_token_id)
"pruning_rate=%.2f, reduction=%.1f%%)", expanded_positions = self._get_expanded_positions(
emb.shape[0], device=merged_embeddings.device,
retention_mask.sum().item(), seq_len=merged_embeddings.shape[0],
size[0], video_grid_thw=video_grid_thw,
size[1], num_tokens_per_frame=num_tokens_per_frame,
size[2], timestamps=timestamps,
self.video_pruning_rate, is_video_embed=is_video_embed,
(1 - retention_mask.float().mean().item()) * 100, is_vision_start=is_vision_start,
retention_mask=retention_mask,
) )
to_concat.append(expanded_positions)
positions = compute_mrope_for_media( final_video_embeddings = torch.cat(to_concat, dim=-1)
size,
merge_size,
tokens_per_second=tokens_per_second,
video_second_per_grid=video_second_per_grid_t.item(),
).to(emb.device)
emb = emb[retention_mask] return final_video_embeddings
positions = positions[retention_mask]
emb = torch.cat([emb, positions], dim=1) def _get_expanded_positions(
video_embeds_out.append(emb) self,
return tuple(video_embeds_out) device,
seq_len,
video_grid_thw,
num_tokens_per_frame,
timestamps,
is_video_embed,
is_vision_start,
retention_mask,
):
embed_token_id = _cached_tensor(self.config.video_token_id, device=device)
# Expand positions to match the full sequence length
# (includes both video tokens and indicator tokens)
# Shape: [full_length, 5] where positions are filled for video tokens
# and zeros for indicator tokens.
# Channel 3 flags VISION_START tokens so that
# recompute_mrope_positions can reliably count timestamp tokens
# (even when early frames have all video tokens pruned).
# Channel 4 flags video-embedding tokens.
expanded_positions = torch.zeros(
seq_len,
5, # [t_index, h_index, w_index, is_vision_start, is_video]
device=device,
dtype=torch.long,
)
_, h, w = video_grid_thw
merge_size = self.visual.spatial_merge_size
num_frames = len(num_tokens_per_frame)
unpruned_token_ids = Qwen3VLMultiModalProcessor.get_video_repl(
tokens_per_frame=[(h // merge_size) * (w // merge_size)] * num_frames,
tokenizer=self._tokenizer,
timestamps=timestamps,
vision_start_token_id=self.config.vision_start_token_id,
vision_end_token_id=self.config.vision_end_token_id,
video_token_id=self.config.video_token_id,
).full
unpruned_token_ids_tensor = torch.tensor(unpruned_token_ids, device=device)
mm_feature = MultiModalFeatureSpec(
data=MultiModalKwargsItem(
{
"video_grid_thw": MultiModalFieldElem(
data=torch.tensor(video_grid_thw),
field=None, # HACK.
),
}
),
modality="video",
identifier="DUMMY",
mm_position=PlaceholderRange(offset=0, length=len(unpruned_token_ids)),
)
original_mrope = (
self.get_mrope_input_positions(
input_tokens=unpruned_token_ids,
mm_features=[mm_feature],
)[0]
.to(device)
.permute(1, 0)
)
full_is_video_embed = unpruned_token_ids_tensor == embed_token_id
expanded_positions[is_video_embed, :3] = original_mrope[full_is_video_embed][
retention_mask
]
expanded_positions[~is_video_embed, :3] = original_mrope[~full_is_video_embed]
expanded_positions[..., 3] = is_vision_start
expanded_positions[..., 4] = is_video_embed
return expanded_positions
def _parse_and_validate_multimodal_inputs(self, **kwargs: object) -> dict: def _parse_and_validate_multimodal_inputs(self, **kwargs: object) -> dict:
mm_input_by_modality = {} mm_input_by_modality = {}
...@@ -1607,66 +1877,77 @@ class Qwen3VLForConditionalGeneration( ...@@ -1607,66 +1877,77 @@ class Qwen3VLForConditionalGeneration(
) )
return mm_input_by_modality return mm_input_by_modality
def iter_mm_grid_hw( @staticmethod
self, input_tokens: list[int], mm_features: list[MultiModalFeatureSpec] def _iter_mm_grid_hw(
) -> Iterator[tuple[int, int, int]]: input_tokens: list[int],
""" mm_features: list[MultiModalFeatureSpec],
Iterate over multimodal features and yield grid information. video_token_id: int,
vision_start_token_id: int,
For videos with EVS (Efficient Video Sampling) enabled, this function vision_end_token_id: int,
computes the offset based on the pruned token count rather than relying spatial_merge_size: int,
on input_tokens.index(), which would fail when tokens are pruned. ) -> Iterator[tuple[int, int, int, int]]:
"""Iterate over multimodal features and yield position info.
Args: Args:
input_tokens: List of token IDs in the prompt input_tokens: List of token IDs in the input sequence.
mm_features: List of multimodal feature specifications mm_features: List of multimodal feature specifications containing
image/video data and position information.
video_token_id: Token ID used for video tokens.
vision_start_token_id: Token ID marking the start of a vision sequence.
vision_end_token_id: Token ID marking the end of a vision sequence.
spatial_merge_size: Size of the spatial merge operation used to
compute logical grid dimensions from the original feature grid.
Yields: Yields:
Tuple of (offset, grid_h, grid_w) for each frame/image offset: Position of the first video/image token in the sequence.
llm_grid_h: Logical grid height (may not match actual token count with EVS).
llm_grid_w: Logical grid width (may not match actual token count with EVS).
actual_num_tokens: Actual number of video/image tokens in the placeholder.
""" """
video_token_id = self.config.video_token_id
spatial_merge_size = self.config.vision_config.spatial_merge_size
for mm_feature in sorted(mm_features, key=lambda f: f.mm_position.offset): for mm_feature in sorted(mm_features, key=lambda f: f.mm_position.offset):
offset = mm_feature.mm_position.offset offset = mm_feature.mm_position.offset
if mm_feature.modality == "image": if mm_feature.modality == "image":
t, h, w = mm_feature.data["image_grid_thw"].data.tolist() t, h, w = mm_feature.data["image_grid_thw"].data.tolist()
assert t == 1, f"Image must have 1 frame, got {t}" assert t == 1, f"Image must have 1 frame, got {t}"
yield offset, h // spatial_merge_size, w // spatial_merge_size llm_grid_h = h // spatial_merge_size
llm_grid_w = w // spatial_merge_size
yield offset, llm_grid_h, llm_grid_w, llm_grid_h * llm_grid_w
elif mm_feature.modality == "video": elif mm_feature.modality == "video":
t, h, w = mm_feature.data["video_grid_thw"].data.tolist() t, h, w = mm_feature.data["video_grid_thw"].data.tolist()
llm_grid_h = h // spatial_merge_size llm_grid_h = h // spatial_merge_size
llm_grid_w = w // spatial_merge_size llm_grid_w = w // spatial_merge_size
# Check if EVS (Efficient Video Sampling) is enabled for _ in range(t):
is_evs_enabled = ( # When EVS is enabled, some frames may have 0 video tokens in the
hasattr(self, "video_pruning_rate") # placeholder. We use `vision_start_token_id` to locate each frame
and self.video_pruning_rate is not None # since it is always present for every frame.
and self.video_pruning_rate > 0.0 # We then look for the first `video_token_id` after
) # `vision_start_token_id` and before `vision_end_token_id`.
offset = input_tokens.index(vision_start_token_id, offset)
if is_evs_enabled: vision_end_offset = input_tokens.index(vision_end_token_id, offset)
frame_offsets = self._extract_frame_offsets_from_mask(
mm_feature.mm_position, t try:
) actual_num_tokens = 0
if frame_offsets is not None: video_offset = input_tokens.index(
for rel_offset in frame_offsets: video_token_id, offset, vision_end_offset
yield offset + rel_offset, llm_grid_h, llm_grid_w )
continue # NOTE: looking at the
# `Qwen3VLMultiModalProcessor.get_video_repl` code, we can
# If EVS is enabled but mask is missing, this indicates a bug # see that we can use the below formula to get the token
# in the prompt processing pipeline. The is_embed mask should # count, since everything in between `video_offset` and
# always be present when video_pruning_rate > 0. # `vision_end_offset` is populated as `video_token_id`.
raise RuntimeError( # This saves us from manually counting the number tokens
f"EVS is enabled (pruning_rate={self.video_pruning_rate}) " # that match `video_token_id` in between.
"but is_embed mask is missing from mm_position. " actual_num_tokens += vision_end_offset - video_offset
"This indicates a bug in prompt processing." except ValueError:
) # No `video_token_id` in this frame (EVS with 0 tokens for
else: # this frame) -> use `offset + 1`` to move past
# Non-EVS mode: Use original logic with input_tokens.index() # `vision_start_token_id`.
for _ in range(t): video_offset = offset + 1
offset = input_tokens.index(video_token_id, offset)
yield offset, llm_grid_h, llm_grid_w yield video_offset, llm_grid_h, llm_grid_w, actual_num_tokens
offset += llm_grid_h * llm_grid_w # Move offset past this frame for next iteration.
offset = vision_end_offset + 1
else: else:
raise ValueError(f"Unsupported modality: {mm_feature.modality}") raise ValueError(f"Unsupported modality: {mm_feature.modality}")
...@@ -1771,13 +2052,100 @@ class Qwen3VLForConditionalGeneration( ...@@ -1771,13 +2052,100 @@ class Qwen3VLForConditionalGeneration(
return [len(seg) for seg in segments] return [len(seg) for seg in segments]
def get_mrope_input_positions(
self,
input_tokens: list[int],
mm_features: list[MultiModalFeatureSpec],
) -> tuple[torch.Tensor, int]:
return self._get_mrope_input_positions(
input_tokens=input_tokens,
mm_features=mm_features,
config=self.config,
)
@staticmethod
def _get_mrope_input_positions(
input_tokens: list[int],
mm_features: list[MultiModalFeatureSpec],
config: Qwen3VLConfig,
):
llm_pos_ids_list = []
st = 0
for (
offset,
llm_grid_h,
llm_grid_w,
actual_num_tokens,
) in Qwen3VLForConditionalGeneration._iter_mm_grid_hw(
input_tokens,
mm_features,
video_token_id=config.video_token_id,
vision_start_token_id=config.vision_start_token_id,
vision_end_token_id=config.vision_end_token_id,
spatial_merge_size=config.vision_config.spatial_merge_size,
):
# Skip frames with 0 tokens (EVS placeholder with tokens lumped elsewhere)
if actual_num_tokens == 0:
continue
text_len = offset - st
st_idx = llm_pos_ids_list[-1].max() + 1 if len(llm_pos_ids_list) > 0 else 0
llm_pos_ids_list.append(
np.broadcast_to(np.arange(text_len), (3, text_len)) + st_idx
)
# Check if this is a "lumped placeholder" (all tokens from multiple frames
# assigned to the 0-th frame - see
# `Qwen3VLMultiModalProcessor.get_video_repl`.
expected_tokens_per_frame = llm_grid_h * llm_grid_w
if actual_num_tokens > expected_tokens_per_frame:
# Lumped placeholder: create grid positions for all "logical" frames
# represented.
num_logical_frames = actual_num_tokens // expected_tokens_per_frame
remainder = actual_num_tokens % expected_tokens_per_frame
# Create positions for complete frames.
for _ in range(num_logical_frames):
grid_indices = np.indices((1, llm_grid_h, llm_grid_w)).reshape(
3, -1
)
llm_pos_ids_list.append(grid_indices + text_len + st_idx)
st_idx = llm_pos_ids_list[-1].max() + 1
text_len = 0 # No text between frames within the lump
# Handle remainder tokens if any (partial frame).
# NOTE: this should never be the case. Should we have an assert?
if remainder > 0:
# Create a partial grid - take first 'remainder' positions
full_grid = np.indices((1, llm_grid_h, llm_grid_w)).reshape(3, -1)
grid_indices = full_grid[:, :remainder]
llm_pos_ids_list.append(grid_indices + text_len + st_idx)
else:
# Normal case: frame has exactly the expected tokens (after actual EVS
# pruning).
grid_indices = np.indices((1, llm_grid_h, llm_grid_w)).reshape(3, -1)
llm_pos_ids_list.append(grid_indices + text_len + st_idx)
st = offset + actual_num_tokens
if st < len(input_tokens):
st_idx = llm_pos_ids_list[-1].max() + 1 if len(llm_pos_ids_list) > 0 else 0
text_len = len(input_tokens) - st
llm_pos_ids_list.append(
np.broadcast_to(np.arange(text_len), (3, text_len)) + st_idx
)
llm_positions = np.concatenate(llm_pos_ids_list, axis=1).reshape(3, -1)
mrope_position_delta = (llm_positions.max() + 1 - len(input_tokens)).item()
return torch.from_numpy(llm_positions), mrope_position_delta
def recompute_mrope_positions( def recompute_mrope_positions(
self, self,
input_ids: list[int], input_ids: list[int],
multimodal_embeddings: tuple[torch.Tensor, ...], multimodal_embeddings: MultiModalEmbeddings,
mrope_positions: torch.LongTensor, mrope_positions: torch.LongTensor,
num_computed_tokens: int, num_computed_tokens: int,
) -> tuple[tuple[torch.Tensor, ...], torch.Tensor, int]: ) -> tuple[MultiModalEmbeddings, torch.Tensor, int]:
""" """
Update part of input mrope positions (starting with Update part of input mrope positions (starting with
num_computed_tokens index). Original mrope_positions are computed num_computed_tokens index). Original mrope_positions are computed
...@@ -1786,9 +2154,10 @@ class Qwen3VLForConditionalGeneration( ...@@ -1786,9 +2154,10 @@ class Qwen3VLForConditionalGeneration(
mrope_positions before we feed it to LLM. mrope_positions before we feed it to LLM.
Args: Args:
input_ids: (N,) All input tokens of the prompt (Containing input_ids: (N,) All input tokens of the prompt containing
entire sequence). entire sequence.
multimodal_embeddings: Tuple of multimodal embeddings. multimodal_embeddings: Tuple of multimodal embeddings that
fits into the prefill chunk that is being processed.
mrope_positions: Existing mrope positions (3, N) for entire mrope_positions: Existing mrope positions (3, N) for entire
sequence sequence
num_computed_tokens: A number of computed tokens so far. num_computed_tokens: A number of computed tokens so far.
...@@ -1797,10 +2166,26 @@ class Qwen3VLForConditionalGeneration( ...@@ -1797,10 +2166,26 @@ class Qwen3VLForConditionalGeneration(
Tuple of (multimodal_embeddings, mrope_positions, Tuple of (multimodal_embeddings, mrope_positions,
mrope_position_delta). mrope_position_delta).
""" """
image_token_id = self.config.image_token_id return self._recompute_mrope_positions(
video_token_id = self.config.video_token_id input_ids=input_ids,
vision_start_token_id = self.config.vision_start_token_id multimodal_embeddings=multimodal_embeddings,
mrope_positions=mrope_positions,
num_computed_tokens=num_computed_tokens,
image_token_id=self.config.image_token_id,
video_token_id=self.config.video_token_id,
vision_start_token_id=self.config.vision_start_token_id,
)
@staticmethod
def _recompute_mrope_positions(
input_ids: list[int],
multimodal_embeddings: MultiModalEmbeddings,
mrope_positions: torch.LongTensor,
num_computed_tokens: int,
vision_start_token_id: int,
image_token_id: int,
video_token_id: int,
) -> tuple[MultiModalEmbeddings, torch.Tensor, int]:
# Device # Device
device = ( device = (
multimodal_embeddings[0].device multimodal_embeddings[0].device
...@@ -1811,10 +2196,21 @@ class Qwen3VLForConditionalGeneration( ...@@ -1811,10 +2196,21 @@ class Qwen3VLForConditionalGeneration(
# Tensors # Tensors
input_ids_t = torch.as_tensor(input_ids, device=device, dtype=torch.long) input_ids_t = torch.as_tensor(input_ids, device=device, dtype=torch.long)
mm_embeddings_out = [mm[:, :-4] for mm in multimodal_embeddings] mm_embeddings_out = []
mm_embeddings_pos = [ mm_embeddings_pos = []
mm[:, -4:].permute(1, 0).long() for mm in multimodal_embeddings # Strip position information from embeddings (last 5 channels)
] # For Qwen3 VL, handle potentially empty frames (from unpacking)
for mm in multimodal_embeddings:
if mm.shape[0] > 0: # Only process non-empty frames
mm_embeddings_out.append(mm[:, :-5])
mm_embeddings_pos.append(mm[:, -5:].permute(1, 0).long())
else:
# Empty frame - keep as is
mm_embeddings_out.append(mm)
# Create empty position tensor with correct shape
mm_embeddings_pos.append(
torch.empty(5, 0, device=device, dtype=torch.long)
)
positions, mrope_positions_delta = recompute_mrope_positions( positions, mrope_positions_delta = recompute_mrope_positions(
input_ids_t, input_ids_t,
...@@ -1828,107 +2224,14 @@ class Qwen3VLForConditionalGeneration( ...@@ -1828,107 +2224,14 @@ class Qwen3VLForConditionalGeneration(
return tuple(mm_embeddings_out), positions, mrope_positions_delta return tuple(mm_embeddings_out), positions, mrope_positions_delta
def get_mrope_input_positions(
self,
input_tokens: list[int],
mm_features: list[MultiModalFeatureSpec],
) -> tuple[torch.Tensor, int]:
# Pre-collect actual frame token counts for EVS mode
frame_token_counts_map = {}
for mm_feature in mm_features:
if mm_feature.modality == "video":
is_evs_enabled = (
hasattr(self, "video_pruning_rate")
and self.video_pruning_rate is not None
and self.video_pruning_rate > 0.0
)
if is_evs_enabled:
t = mm_feature.data["video_grid_thw"].data.tolist()[0]
token_counts = self._get_actual_frame_token_counts(
mm_feature.mm_position, t
)
assert token_counts is not None, (
"EVS enabled but failed to extract frame token counts "
"from is_embed mask"
)
frame_token_counts_map[mm_feature.mm_position.offset] = token_counts
llm_pos_ids_list = []
st = 0
frame_counts_idx = {}
for offset, llm_grid_h, llm_grid_w in self.iter_mm_grid_hw(
input_tokens, mm_features
):
text_len = offset - st
st_idx = llm_pos_ids_list[-1].max() + 1 if len(llm_pos_ids_list) > 0 else 0
# Determine actual token count for this frame
base_offset = None
for feat_offset in frame_token_counts_map:
if offset >= feat_offset:
base_offset = feat_offset
if base_offset is not None:
# EVS mode: use actual token count from is_embed mask
assert base_offset in frame_token_counts_map, (
f"Found base_offset {base_offset} but not in frame_token_counts_map"
)
if base_offset not in frame_counts_idx:
frame_counts_idx[base_offset] = 0
counts = frame_token_counts_map[base_offset]
idx = frame_counts_idx[base_offset]
assert idx < len(counts), (
f"EVS frame index {idx} out of range (total frames: {len(counts)})"
)
actual_frame_tokens = counts[idx]
frame_counts_idx[base_offset] += 1
else:
# Non-EVS mode (or image): use theoretical grid size
actual_frame_tokens = llm_grid_h * llm_grid_w
# Add text segment
text_positions = (
np.broadcast_to(np.arange(text_len), (3, text_len)) + st_idx
)
llm_pos_ids_list.append(text_positions)
st_idx += text_len
# Add frame segment with actual token count (not theoretical)
grid_indices = np.indices((1, llm_grid_h, llm_grid_w)).reshape(3, -1)
# Only take the first actual_frame_tokens positions
frame_positions = grid_indices[:, :actual_frame_tokens] + st_idx
llm_pos_ids_list.append(frame_positions)
# Update st using actual token count
st = offset + actual_frame_tokens
# Handle final text segment
if st < len(input_tokens):
st_idx = llm_pos_ids_list[-1].max() + 1 if len(llm_pos_ids_list) > 0 else 0
text_len = len(input_tokens) - st
final_text_positions = (
np.broadcast_to(np.arange(text_len), (3, text_len)) + st_idx
)
llm_pos_ids_list.append(final_text_positions)
llm_positions = np.concatenate(llm_pos_ids_list, axis=1).reshape(3, -1)
mrope_position_delta = (llm_positions.max() + 1 - len(input_tokens)).item()
return torch.from_numpy(llm_positions), mrope_position_delta
def embed_multimodal(self, **kwargs: object) -> MultiModalEmbeddings | None: def embed_multimodal(self, **kwargs: object) -> MultiModalEmbeddings | None:
mm_input_by_modality = self._parse_and_validate_multimodal_inputs(**kwargs) mm_input_by_modality = self._parse_and_validate_multimodal_inputs(**kwargs)
if not mm_input_by_modality: if not mm_input_by_modality:
return None return None
# The result multimodal_embeddings is tuple of tensors, with each # The result multimodal_embeddings is tuple of tensors, with each
# tensor correspoending to a multimodal data item (image or video). # tensor corresponding to a multimodal data item (image or video).
multimodal_embeddings: tuple[torch.Tensor, ...] = () multimodal_embeddings: list[torch.Tensor] = []
# NOTE: It is important to iterate over the keys in this dictionary # NOTE: It is important to iterate over the keys in this dictionary
# to preserve the order of the modalities. # to preserve the order of the modalities.
...@@ -1936,19 +2239,20 @@ class Qwen3VLForConditionalGeneration( ...@@ -1936,19 +2239,20 @@ class Qwen3VLForConditionalGeneration(
multimodal_input = mm_input_by_modality[modality] multimodal_input = mm_input_by_modality[modality]
if modality == "image": if modality == "image":
image_embeddings = self._process_image_input(multimodal_input) image_embeddings = self._process_image_input(multimodal_input)
if self.is_multimodal_pruning_enabled: image_embeddings = self._postprocess_image_embeds_evs(
image_embeddings = self._postprocess_image_embeds_evs( image_embeddings, multimodal_input
image_embeddings, multimodal_input )
) multimodal_embeddings.extend(image_embeddings)
multimodal_embeddings += tuple(image_embeddings)
if modality == "video": if modality == "video":
video_embeddings = self._process_video_input(multimodal_input) video_embeddings = self._process_video_input(multimodal_input)
if self.is_multimodal_pruning_enabled: if self.is_multimodal_pruning_enabled:
video_embeddings = self._postprocess_video_embeds_evs( video_embeddings = self._postprocess_video_embeds_evs(
video_embeddings, multimodal_input video_embeddings, multimodal_input
) )
multimodal_embeddings += tuple(video_embeddings) multimodal_embeddings.extend(video_embeddings)
return multimodal_embeddings
embeddings_tuple = tuple(multimodal_embeddings)
return embeddings_tuple
def _compute_deepstack_embeds( def _compute_deepstack_embeds(
self, self,
...@@ -2128,3 +2432,8 @@ class Qwen3VLForConditionalGeneration( ...@@ -2128,3 +2432,8 @@ class Qwen3VLForConditionalGeneration(
vision_config = hf_config.vision_config vision_config = hf_config.vision_config
merge_size = vision_config.spatial_merge_size merge_size = vision_config.spatial_merge_size
return num_vision_tokens // merge_size**2 return num_vision_tokens // merge_size**2
@lru_cache
def _cached_tensor(x, device) -> torch.Tensor:
return torch.tensor(x, device=device)
...@@ -45,6 +45,7 @@ from vllm.model_executor.model_loader.weight_utils import ( ...@@ -45,6 +45,7 @@ from vllm.model_executor.model_loader.weight_utils import (
) )
from vllm.multimodal import MULTIMODAL_REGISTRY from vllm.multimodal import MULTIMODAL_REGISTRY
from vllm.sequence import IntermediateTensors from vllm.sequence import IntermediateTensors
from vllm.tokenizers.registry import cached_tokenizer_from_config
from .interfaces import MixtureOfExperts from .interfaces import MixtureOfExperts
from .qwen3_moe import ( from .qwen3_moe import (
...@@ -415,6 +416,7 @@ class Qwen3VLMoeForConditionalGeneration( ...@@ -415,6 +416,7 @@ class Qwen3VLMoeForConditionalGeneration(
multimodal_config = vllm_config.model_config.multimodal_config multimodal_config = vllm_config.model_config.multimodal_config
self.config = config self.config = config
self._tokenizer = cached_tokenizer_from_config(vllm_config.model_config)
self.multimodal_config = multimodal_config self.multimodal_config = multimodal_config
self.use_data_parallel = multimodal_config.mm_encoder_tp_mode == "data" self.use_data_parallel = multimodal_config.mm_encoder_tp_mode == "data"
self.video_pruning_rate = multimodal_config.video_pruning_rate self.video_pruning_rate = multimodal_config.video_pruning_rate
......
...@@ -170,9 +170,9 @@ def recompute_mrope_positions( ...@@ -170,9 +170,9 @@ def recompute_mrope_positions(
multimodal_embeddings may contain zero, some or even some part of all multimodal_embeddings may contain zero, some or even some part of all
multimodal_embeddings for a given prompt. multimodal_embeddings for a given prompt.
Each multimodal_positions has 4 extra channels Each multimodal_positions has 4 or 5 extra channels
(First 3 channels corresponds to original 3 mrope positions, last channel (first 3 channels correspond to the original 3 mrope positions;
is the maximum width of the media repeated). Provided multimodal_positions remaining channels vary by model — see below). Provided multimodal_positions
do not reflect location of media position in sequence - they are computed do not reflect location of media position in sequence - they are computed
like the media is in the 0-th position in the sequence. like the media is in the 0-th position in the sequence.
...@@ -186,6 +186,16 @@ def recompute_mrope_positions( ...@@ -186,6 +186,16 @@ def recompute_mrope_positions(
Args: Args:
input_ids: (N,) All input tokens of the prompt (entire sequence). input_ids: (N,) All input tokens of the prompt (entire sequence).
multimodal_positions: List of mrope positions for each media. multimodal_positions: List of mrope positions for each media.
If a given element is of shape (4, N), it is assumed to only describe
positions for video / image embeddings. This is the case of e.g. Qwen2.5 VL,
where each multimodal input is a contiguous chunk of embeddings.
The expected channels are [t, h, w, max_width].
If it is of shape (5, N), it is assumed to possibly describe positions for
both video / image embeddings, as well as text embeddings. This is the case
of e.g. Qwen3 VL, where each video inputs are comprised of individual
frames' embeddings, interleaved with embeddings for timestamp tokens,
and vision start / end tokens. The expected channels are
[t, h, w, is_vision_start, is_vision].
mrope_positions: Existing mrope positions (4, N) for entire sequence. mrope_positions: Existing mrope positions (4, N) for entire sequence.
num_computed_tokens: A number of computed tokens so far. num_computed_tokens: A number of computed tokens so far.
vision_start_token_id: Token indicating start of vision media. vision_start_token_id: Token indicating start of vision media.
...@@ -233,6 +243,21 @@ def recompute_mrope_positions( ...@@ -233,6 +243,21 @@ def recompute_mrope_positions(
# - Current prefill chunk has no vision start indexes at all # - Current prefill chunk has no vision start indexes at all
# - Vision start token appeared in previous prefill round # - Vision start token appeared in previous prefill round
# - Regular case # - Regular case
has_video_tokens = False
num_timestamp_tokens = 0
if mm_pos.shape[0] == 5 and mm_pos.shape[1] > 0:
# mm_pos[4, :] indicates which positions are for video embeddings.
# If there are no video embeddings, skip timestamp adjustment.
has_video_tokens = torch.any(mm_pos[4, :]).item()
if has_video_tokens:
# Channel 3 flags VISION_START tokens. Timestamp tokens
# precede the first VISION_START, so its index gives us the
# exact timestamp count. This is robust even when early
# frames have all their video tokens pruned (which would
# push argmax(channel 4) far into a later frame).
first_vs = (mm_pos[3, :] == 1).nonzero(as_tuple=True)[0]
num_timestamp_tokens = first_vs[0].item() if len(first_vs) > 0 else 0
seen_vision_start_indices = vision_start_indices[ seen_vision_start_indices = vision_start_indices[
vision_start_indices < num_computed_tokens vision_start_indices < num_computed_tokens
] ]
...@@ -249,6 +274,18 @@ def recompute_mrope_positions( ...@@ -249,6 +274,18 @@ def recompute_mrope_positions(
in_the_middle_of_media = ( in_the_middle_of_media = (
seen_mm_tokens > seem_mm_tokens_before_last_vision_start seen_mm_tokens > seem_mm_tokens_before_last_vision_start
) )
# For Qwen3 VL, we can be inside a media segment even before any
# video tokens appear (timestamp tokens are text). If we've passed
# the last vision_start token but haven't reached the first video
# embedding, treat this as "in the middle of media".
if (
not in_the_middle_of_media
and has_video_tokens
and num_computed_tokens > last_vision_start_token
and num_computed_tokens
<= last_vision_start_token + num_timestamp_tokens + 1
):
in_the_middle_of_media = True
if in_the_middle_of_media: if in_the_middle_of_media:
mm_embeddings_seen = ( mm_embeddings_seen = (
...@@ -274,14 +311,39 @@ def recompute_mrope_positions( ...@@ -274,14 +311,39 @@ def recompute_mrope_positions(
mm_embeddings_seen = 0 mm_embeddings_seen = 0
global_mm_start = next_vision_start_token global_mm_start = next_vision_start_token
# Offset right after vision_start_token # For Qwen3 VL, mm_pos includes timestamp tokens before vision_start
base = positions[-1, global_mm_start] + 1 # when starting a new media. Adjust global_mm_start to point to where
local_start = global_mm_start + 1 + mm_embeddings_seen # the sequence actually begins (before timestamp tokens).
adjusted_for_timestamps = False
if mm_pos.shape[0] == 5 and mm_embeddings_seen == 0 and has_video_tokens:
# NOTE: -1 is because there is a vision start token right after
# timestamp tokens before any video embeddings appear.
# Adjust global_mm_start to point to the first timestamp token
# instead of the vision_start token.
global_mm_start -= num_timestamp_tokens
adjusted_for_timestamps = True
# Offset calculation depends on whether we adjusted for timestamp tokens
if adjusted_for_timestamps:
# Start from position before the first timestamp token
base = positions[-1, global_mm_start - 1] + 1
local_start = global_mm_start + mm_embeddings_seen
else:
# Original logic: start after vision_start_token
base = positions[-1, global_mm_start] + 1
local_start = global_mm_start + 1 + mm_embeddings_seen
local_end = local_start + mm_pos.shape[1] local_end = local_start + mm_pos.shape[1]
positions[:, local_start:local_end] = mm_pos[0:3] + base positions[:, local_start:local_end] = mm_pos[0:3] + base
# mm_pos[3, 0] is the max width of the media # For Qwen3 VL (5-channel), use the maximum position reached across
offset = mm_pos[3, 0] + base # all tokens (both video and text) in all dimensions (t, h, w).
# For Qwen2.5 VL (4-channel), mm_pos[3, 0] is the max width.
if mm_pos.shape[0] == 5:
offset = mm_pos[0:3, :].max() + base + 1
else:
offset = mm_pos[3, 0] + base
text_pos_sum = torch.cumsum(text_mask[local_end:].long(), dim=0) text_pos_sum = torch.cumsum(text_mask[local_end:].long(), dim=0)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment