"docs/vscode:/vscode.git/clone" did not exist on "85fee74b337522f7e0807fc100b9e00682ff45e1"
Unverified Commit 70c73df6 authored by William Zhang's avatar William Zhang Committed by GitHub
Browse files

[Bugfix] Fix EVS implementation for Qwen3 VL (#33607)


Signed-off-by: default avatar2ez4bz <133824995+2ez4bz@users.noreply.github.com>
parent 9a9d4424
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import dataclasses
import random
from dataclasses import dataclass
import pytest
import torch
from vllm.model_executor.models.qwen3_vl import Qwen3VLForConditionalGeneration
from vllm.multimodal.inputs import (
MultiModalFeatureSpec,
MultiModalFieldElem,
MultiModalKwargsItem,
PlaceholderRange,
)
@pytest.fixture(autouse=True, scope="module")
def _force_cpu_default_device():
    """Pin torch's default device to CPU while this module's tests run.

    ``_get_mrope_input_positions`` returns CPU tensors (it goes through
    ``torch.from_numpy``), so every tensor the tests create must live on the
    CPU as well. The default device that was active beforehand is put back
    once the module's tests have finished.
    """
    previous_device = torch.get_default_device()
    torch.set_default_device("cpu")
    yield
    # Teardown: restore whatever default device was configured before.
    torch.set_default_device(previous_device)
# Placeholder special-token IDs used by the dummy config in this module.
# All of them lie outside range(1, 100), the pool the tests draw random text
# tokens from, so a text token can never collide with a vision marker.
IMAGE_TOKEN_ID = 999
VIDEO_TOKEN_ID = 888
# Markers that bracket each frame's run of video tokens.
VISION_START_TOKEN_ID = 777
VISION_END_TOKEN_ID = 778
@dataclass
class DummyVisionConfig:
    """Minimal stand-in for the vision sub-config used by these tests."""

    # Spatial merge factor; the test overrides it per parametrization.
    spatial_merge_size: int = 1
@dataclass
class DummyConfig:
    """Minimal stand-in for the model config handed to the mrope helpers.

    It only carries the special-token IDs and the nested vision config; the
    test passes an instance as ``config`` to ``_get_mrope_input_positions``.
    """

    image_token_id: int = IMAGE_TOKEN_ID
    video_token_id: int = VIDEO_TOKEN_ID
    vision_start_token_id: int = VISION_START_TOKEN_ID
    vision_end_token_id: int = VISION_END_TOKEN_ID
    # A factory is required so each instance owns its own (mutable) vision
    # config rather than sharing one default object.
    vision_config: DummyVisionConfig = dataclasses.field(
        default_factory=DummyVisionConfig
    )
def make_video_embedding(
    t, h, w, interleave_text_tokens: tuple[int, int], video_pruning_rate: float = 0.0
):
    """
    Build a synthetic token sequence for one video, plus a randomly pruned copy.

    Each of the ``t`` frames contributes a random run of text tokens followed by
    ``VISION_START``, ``h * w`` video tokens and ``VISION_END``.

    Args:
        t: Number of frames.
        h: Number of rows.
        w: Number of columns.
        interleave_text_tokens: Inclusive (minimum, maximum) number of text
            tokens placed in front of every frame.
        video_pruning_rate: Probability with which each individual video token
            is dropped.

    Returns:
        Tuple of (unpruned_tokens_sequence, pruned_tokens_sequence, retention_mask)
    """
    min_text_tokens = interleave_text_tokens[0]
    max_text_tokens = interleave_text_tokens[1]
    text_vocab = list(range(1, 100))

    all_tokens = []
    for _frame in range(t):
        # Draw the per-frame text run first, then append the frame itself.
        text_run_length = random.randint(min_text_tokens, max_text_tokens)
        all_tokens += random.choices(text_vocab, k=text_run_length)
        all_tokens.append(VISION_START_TOKEN_ID)
        all_tokens += [VIDEO_TOKEN_ID] * h * w
        all_tokens.append(VISION_END_TOKEN_ID)

    unpruned = torch.tensor(all_tokens, dtype=torch.long)
    is_video_token = unpruned == VIDEO_TOKEN_ID

    # Non-video tokens get a drop probability of exactly zero.
    drop_probability = is_video_token.float() * video_pruning_rate
    dropped = torch.bernoulli(drop_probability).bool()  # type: ignore[attr-defined]
    # Sanity check that we don't prune what should not be pruned.
    assert not dropped[~is_video_token].any()

    kept = ~dropped
    return unpruned, unpruned[kept], kept
@pytest.mark.parametrize("spatial_merge_size", [1, 2])
@pytest.mark.parametrize("grid_thw", [[3, 8, 7], [128, 10, 12]])
@pytest.mark.parametrize("num_prefix_tokens", [1, 11])
@pytest.mark.parametrize("num_suffix_tokens", [0, 7])
@pytest.mark.parametrize("video_pruning_rate", [0, 0.25, 0.75])
@pytest.mark.parametrize("interleave_text_tokens", [(0, 0), (1, 4)])
def test_match_qwen3vl_mrope_evs_on(
    spatial_merge_size: int,
    num_prefix_tokens: int,
    grid_thw: tuple[int, int, int],
    num_suffix_tokens: int,
    video_pruning_rate: float,
    interleave_text_tokens: tuple[int, int],
):
    """Recomputed mrope positions must equal the masked unpruned positions.

    The test builds ``prefix + video + suffix`` token sequences (unpruned and
    pruned), computes the ground-truth mrope positions on the unpruned
    sequence with ``_get_mrope_input_positions`` and keeps only the retained
    columns. It then packs the video-only positions into 5 extra embedding
    channels ``[t, h, w, is_vision_start, is_video_embed]`` and checks that
    ``_recompute_mrope_positions`` reproduces the ground truth on the pruned
    sequence.
    """
    hf_config = DummyConfig()
    hf_config.vision_config.spatial_merge_size = spatial_merge_size

    def _video_feature(num_tokens: int) -> MultiModalFeatureSpec:
        # Both reference computations below use the same video feature and
        # differ only in the length of the placeholder range.
        return MultiModalFeatureSpec(
            data=MultiModalKwargsItem(
                {
                    "video_grid_thw": MultiModalFieldElem(
                        data=torch.tensor(grid_thw),
                        field=None,  # HACK.
                    ),
                }
            ),
            modality="video",
            identifier="DUMMY",
            mm_position=PlaceholderRange(offset=0, length=num_tokens),
        )

    t, h, w = grid_thw
    population = list(range(1, 100))
    prefix_tokens = random.choices(population, k=num_prefix_tokens)
    suffix_tokens = random.choices(population, k=num_suffix_tokens)
    video_tokens, video_tokens_pruned, retention_mask = make_video_embedding(
        t,
        h // spatial_merge_size,
        w // spatial_merge_size,
        interleave_text_tokens=interleave_text_tokens,
        video_pruning_rate=video_pruning_rate,
    )
    assert len(video_tokens) == len(retention_mask)

    input_tokens = prefix_tokens + video_tokens.tolist() + suffix_tokens
    input_tokens_pruned = prefix_tokens + video_tokens_pruned.tolist() + suffix_tokens
    # Prefix and suffix text tokens are never pruned.
    whole_sequence_retention_mask = torch.cat(
        [
            torch.ones(len(prefix_tokens), dtype=torch.bool),
            retention_mask,
            torch.ones(len(suffix_tokens), dtype=torch.bool),
        ],
        dim=0,
    )

    # Build the GT mrope for unpruned input.
    expected_mrope, _ = Qwen3VLForConditionalGeneration._get_mrope_input_positions(
        input_tokens=input_tokens,
        mm_features=[_video_feature(len(input_tokens))],
        config=hf_config,
    )

    # Compute mrope for a video-only media (unpruned).
    video_mrope, _ = Qwen3VLForConditionalGeneration._get_mrope_input_positions(
        input_tokens=video_tokens.tolist(),
        mm_features=[_video_feature(video_tokens.numel())],
        config=hf_config,
    )
    video_mrope = video_mrope.permute(1, 0)  # [N, 3]

    hidden_size = 16
    is_video_embed = torch.isin(
        video_tokens_pruned, torch.tensor([VIDEO_TOKEN_ID], dtype=torch.long)
    )
    expanded_positions = torch.full(
        (len(video_tokens_pruned), 5),
        fill_value=-100,
        device=video_mrope.device,
        dtype=torch.long,
    )
    # Channels 0-2: (t, h, w) positions of every retained token, video and
    # text alike.
    expanded_positions[:, :3] = video_mrope[retention_mask]
    is_vision_start = video_tokens_pruned == VISION_START_TOKEN_ID
    expanded_positions[..., 3] = is_vision_start
    expanded_positions[..., 4] = is_video_embed
    # Check that all positions were filled, since we initialized them as negative.
    assert (expanded_positions >= 0).all()

    # The embedding values themselves are irrelevant here; zeros keep the
    # input deterministic instead of exposing uninitialized memory.
    video_embeddings = torch.zeros(
        (len(video_tokens_pruned), hidden_size), device=video_mrope.device
    )
    video_embeddings = torch.cat(
        [
            video_embeddings,
            expanded_positions.float(),
        ],
        dim=1,
    )
    multimodal_embeddings = [video_embeddings]

    expected_mrope_masked = expected_mrope[:, whole_sequence_retention_mask]

    # Seed computed_mrope with the known positions of the prefix tokens and a
    # negative sentinel everywhere else. A deterministic fill (rather than
    # torch.empty) guarantees the check below cannot pass or fail by accident
    # of whatever happened to be in uninitialized memory.
    computed_mrope = torch.full(
        (3, len(input_tokens_pruned)), fill_value=-1, dtype=torch.long
    )
    computed_mrope[:, 0 : len(prefix_tokens)] = expected_mrope[
        :, 0 : len(prefix_tokens)
    ]
    # Paranoia check that computed_mrope is wrong.
    assert not torch.equal(computed_mrope, expected_mrope_masked)

    _, actual_mrope, _ = Qwen3VLForConditionalGeneration._recompute_mrope_positions(
        input_ids=input_tokens_pruned,
        multimodal_embeddings=multimodal_embeddings,
        mrope_positions=computed_mrope,
        num_computed_tokens=len(prefix_tokens),
        vision_start_token_id=hf_config.vision_start_token_id,
        image_token_id=hf_config.image_token_id,
        video_token_id=hf_config.video_token_id,
    )
    assert torch.equal(actual_mrope, expected_mrope_masked)
...@@ -195,6 +195,8 @@ class Qwen2_5_VLVideoPixelInputs(TensorSchema): ...@@ -195,6 +195,8 @@ class Qwen2_5_VLVideoPixelInputs(TensorSchema):
- second_per_grid_ts: The video time interval (in seconds) for each - second_per_grid_ts: The video time interval (in seconds) for each
grid along the temporal dimension in the 3D position IDs. Returned grid along the temporal dimension in the 3D position IDs. Returned
when `videos` is not `None`. when `videos` is not `None`.
- timestamps: List of timestamp values (in seconds) for each frame
after merging. Length equals the temporal dimension after merging.
""" """
type: Literal["pixel_values_videos"] type: Literal["pixel_values_videos"]
...@@ -214,6 +216,8 @@ class Qwen2_5_VLVideoPixelInputs(TensorSchema): ...@@ -214,6 +216,8 @@ class Qwen2_5_VLVideoPixelInputs(TensorSchema):
TensorShape("nv"), TensorShape("nv"),
] ]
timestamps: list[list[float]] | None = None
class Qwen2_5_VLVideoEmbeddingInputs(TensorSchema): class Qwen2_5_VLVideoEmbeddingInputs(TensorSchema):
""" """
...@@ -232,6 +236,8 @@ class Qwen2_5_VLVideoEmbeddingInputs(TensorSchema): ...@@ -232,6 +236,8 @@ class Qwen2_5_VLVideoEmbeddingInputs(TensorSchema):
- second_per_grid_ts: The video time interval (in seconds) for each - second_per_grid_ts: The video time interval (in seconds) for each
grid along the temporal dimension in the 3D position IDs. Returned grid along the temporal dimension in the 3D position IDs. Returned
when `videos` is not `None`. when `videos` is not `None`.
- timestamps: List of timestamp values (in seconds) for each frame
after merging. Length equals the temporal dimension after merging.
""" """
type: Literal["video_embeds"] type: Literal["video_embeds"]
...@@ -250,6 +256,7 @@ class Qwen2_5_VLVideoEmbeddingInputs(TensorSchema): ...@@ -250,6 +256,7 @@ class Qwen2_5_VLVideoEmbeddingInputs(TensorSchema):
torch.Tensor | None, torch.Tensor | None,
TensorShape("nv"), TensorShape("nv"),
] = None ] = None
timestamps: list[list[float]] | None = None
Qwen2_5_VLVideoInputs: TypeAlias = ( Qwen2_5_VLVideoInputs: TypeAlias = (
......
...@@ -755,6 +755,7 @@ def _create_qwen2vl_field_factory( ...@@ -755,6 +755,7 @@ def _create_qwen2vl_field_factory(
"video", video_embed_grid_sizes "video", video_embed_grid_sizes
), ),
video_grid_thw=MultiModalFieldConfig.batched("video", keep_on_cpu=True), video_grid_thw=MultiModalFieldConfig.batched("video", keep_on_cpu=True),
timestamps=MultiModalFieldConfig.batched("video", keep_on_cpu=True),
) )
return _qwen2vl_field_config return _qwen2vl_field_config
......
...@@ -628,6 +628,9 @@ class Qwen3_5MoeForCausalLM(Qwen3_5ForCausalLMBase, QwenNextMixtureOfExperts): ...@@ -628,6 +628,9 @@ class Qwen3_5MoeForCausalLM(Qwen3_5ForCausalLMBase, QwenNextMixtureOfExperts):
dummy_inputs=Qwen3VLDummyInputsBuilder, dummy_inputs=Qwen3VLDummyInputsBuilder,
) )
class Qwen3_5ForConditionalGeneration(Qwen3VLForConditionalGeneration, IsHybrid): class Qwen3_5ForConditionalGeneration(Qwen3VLForConditionalGeneration, IsHybrid):
# Qwen3.5 does not support multimodal pruning (EVS).
supports_multimodal_pruning = False
packed_modules_mapping = Qwen3VLForConditionalGeneration.packed_modules_mapping | { packed_modules_mapping = Qwen3VLForConditionalGeneration.packed_modules_mapping | {
"in_proj_qkvz": ["in_proj_qkv", "in_proj_z"], "in_proj_qkvz": ["in_proj_qkv", "in_proj_z"],
"in_proj_ba": ["in_proj_b", "in_proj_a"], "in_proj_ba": ["in_proj_b", "in_proj_a"],
...@@ -643,10 +646,8 @@ class Qwen3_5ForConditionalGeneration(Qwen3VLForConditionalGeneration, IsHybrid) ...@@ -643,10 +646,8 @@ class Qwen3_5ForConditionalGeneration(Qwen3VLForConditionalGeneration, IsHybrid)
self.config = config self.config = config
self.multimodal_config = multimodal_config self.multimodal_config = multimodal_config
self.use_data_parallel = multimodal_config.mm_encoder_tp_mode == "data" self.use_data_parallel = multimodal_config.mm_encoder_tp_mode == "data"
self.video_pruning_rate = multimodal_config.video_pruning_rate # Qwen3.5 does not support multimodal pruning (EVS).
self.is_multimodal_pruning_enabled = ( self.is_multimodal_pruning_enabled = False
multimodal_config.is_multimodal_pruning_enabled()
)
with self._mark_tower_model(vllm_config, {"image", "video"}): with self._mark_tower_model(vllm_config, {"image", "video"}):
self.visual = Qwen3_VisionTransformer( self.visual = Qwen3_VisionTransformer(
...@@ -693,6 +694,12 @@ class Qwen3_5ForConditionalGeneration(Qwen3VLForConditionalGeneration, IsHybrid) ...@@ -693,6 +694,12 @@ class Qwen3_5ForConditionalGeneration(Qwen3VLForConditionalGeneration, IsHybrid)
return inputs_embeds return inputs_embeds
def recompute_mrope_positions(self, *args, **kwargs):
raise NotImplementedError(
"Qwen3.5 does not support multimodal pruning (EVS). "
"recompute_mrope_positions should never be called."
)
def forward( def forward(
self, self,
input_ids: torch.Tensor, input_ids: torch.Tensor,
...@@ -851,10 +858,8 @@ class Qwen3_5MoeForConditionalGeneration( ...@@ -851,10 +858,8 @@ class Qwen3_5MoeForConditionalGeneration(
self.config = config self.config = config
self.multimodal_config = multimodal_config self.multimodal_config = multimodal_config
self.use_data_parallel = multimodal_config.mm_encoder_tp_mode == "data" self.use_data_parallel = multimodal_config.mm_encoder_tp_mode == "data"
self.video_pruning_rate = multimodal_config.video_pruning_rate # Qwen3.5 does not support multimodal pruning (EVS).
self.is_multimodal_pruning_enabled = ( self.is_multimodal_pruning_enabled = False
multimodal_config.is_multimodal_pruning_enabled()
)
with self._mark_tower_model(vllm_config, {"image", "video"}): with self._mark_tower_model(vllm_config, {"image", "video"}):
self.visual = Qwen3_VisionTransformer( self.visual = Qwen3_VisionTransformer(
......
This diff is collapsed.
...@@ -45,6 +45,7 @@ from vllm.model_executor.model_loader.weight_utils import ( ...@@ -45,6 +45,7 @@ from vllm.model_executor.model_loader.weight_utils import (
) )
from vllm.multimodal import MULTIMODAL_REGISTRY from vllm.multimodal import MULTIMODAL_REGISTRY
from vllm.sequence import IntermediateTensors from vllm.sequence import IntermediateTensors
from vllm.tokenizers.registry import cached_tokenizer_from_config
from .interfaces import MixtureOfExperts from .interfaces import MixtureOfExperts
from .qwen3_moe import ( from .qwen3_moe import (
...@@ -415,6 +416,7 @@ class Qwen3VLMoeForConditionalGeneration( ...@@ -415,6 +416,7 @@ class Qwen3VLMoeForConditionalGeneration(
multimodal_config = vllm_config.model_config.multimodal_config multimodal_config = vllm_config.model_config.multimodal_config
self.config = config self.config = config
self._tokenizer = cached_tokenizer_from_config(vllm_config.model_config)
self.multimodal_config = multimodal_config self.multimodal_config = multimodal_config
self.use_data_parallel = multimodal_config.mm_encoder_tp_mode == "data" self.use_data_parallel = multimodal_config.mm_encoder_tp_mode == "data"
self.video_pruning_rate = multimodal_config.video_pruning_rate self.video_pruning_rate = multimodal_config.video_pruning_rate
......
...@@ -170,9 +170,9 @@ def recompute_mrope_positions( ...@@ -170,9 +170,9 @@ def recompute_mrope_positions(
multimodal_embeddings may contain zero, some or even some part of all multimodal_embeddings may contain zero, some or even some part of all
multimodal_embeddings for a given prompt. multimodal_embeddings for a given prompt.
Each multimodal_positions has 4 extra channels Each multimodal_positions has 4 or 5 extra channels
(First 3 channels corresponds to original 3 mrope positions, last channel (first 3 channels correspond to the original 3 mrope positions;
is the maximum width of the media repeated). Provided multimodal_positions remaining channels vary by model — see below). Provided multimodal_positions
do not reflect location of media position in sequence - they are computed do not reflect location of media position in sequence - they are computed
like the media is in the 0-th position in the sequence. like the media is in the 0-th position in the sequence.
...@@ -186,6 +186,16 @@ def recompute_mrope_positions( ...@@ -186,6 +186,16 @@ def recompute_mrope_positions(
Args: Args:
input_ids: (N,) All input tokens of the prompt (entire sequence). input_ids: (N,) All input tokens of the prompt (entire sequence).
multimodal_positions: List of mrope positions for each media. multimodal_positions: List of mrope positions for each media.
If a given element is of shape (4, N), it is assumed to only describe
positions for video / image embeddings. This is the case of e.g. Qwen2.5 VL,
where each multimodal input is a contiguous chunk of embeddings.
The expected channels are [t, h, w, max_width].
If it is of shape (5, N), it is assumed to possibly describe positions for
both video / image embeddings, as well as text embeddings. This is the case
of e.g. Qwen3 VL, where each video input is composed of individual
frames' embeddings, interleaved with embeddings for timestamp tokens,
and vision start / end tokens. The expected channels are
[t, h, w, is_vision_start, is_vision].
mrope_positions: Existing mrope positions (4, N) for entire sequence. mrope_positions: Existing mrope positions (4, N) for entire sequence.
num_computed_tokens: A number of computed tokens so far. num_computed_tokens: A number of computed tokens so far.
vision_start_token_id: Token indicating start of vision media. vision_start_token_id: Token indicating start of vision media.
...@@ -233,6 +243,21 @@ def recompute_mrope_positions( ...@@ -233,6 +243,21 @@ def recompute_mrope_positions(
# - Current prefill chunk has no vision start indexes at all # - Current prefill chunk has no vision start indexes at all
# - Vision start token appeared in previous prefill round # - Vision start token appeared in previous prefill round
# - Regular case # - Regular case
has_video_tokens = False
num_timestamp_tokens = 0
if mm_pos.shape[0] == 5 and mm_pos.shape[1] > 0:
# mm_pos[4, :] indicates which positions are for video embeddings.
# If there are no video embeddings, skip timestamp adjustment.
has_video_tokens = torch.any(mm_pos[4, :]).item()
if has_video_tokens:
# Channel 3 flags VISION_START tokens. Timestamp tokens
# precede the first VISION_START, so its index gives us the
# exact timestamp count. This is robust even when early
# frames have all their video tokens pruned (which would
# push argmax(channel 4) far into a later frame).
first_vs = (mm_pos[3, :] == 1).nonzero(as_tuple=True)[0]
num_timestamp_tokens = first_vs[0].item() if len(first_vs) > 0 else 0
seen_vision_start_indices = vision_start_indices[ seen_vision_start_indices = vision_start_indices[
vision_start_indices < num_computed_tokens vision_start_indices < num_computed_tokens
] ]
...@@ -249,6 +274,18 @@ def recompute_mrope_positions( ...@@ -249,6 +274,18 @@ def recompute_mrope_positions(
in_the_middle_of_media = ( in_the_middle_of_media = (
seen_mm_tokens > seem_mm_tokens_before_last_vision_start seen_mm_tokens > seem_mm_tokens_before_last_vision_start
) )
# For Qwen3 VL, we can be inside a media segment even before any
# video tokens appear (timestamp tokens are text). If we've passed
# the last vision_start token but haven't reached the first video
# embedding, treat this as "in the middle of media".
if (
not in_the_middle_of_media
and has_video_tokens
and num_computed_tokens > last_vision_start_token
and num_computed_tokens
<= last_vision_start_token + num_timestamp_tokens + 1
):
in_the_middle_of_media = True
if in_the_middle_of_media: if in_the_middle_of_media:
mm_embeddings_seen = ( mm_embeddings_seen = (
...@@ -274,14 +311,39 @@ def recompute_mrope_positions( ...@@ -274,14 +311,39 @@ def recompute_mrope_positions(
mm_embeddings_seen = 0 mm_embeddings_seen = 0
global_mm_start = next_vision_start_token global_mm_start = next_vision_start_token
# Offset right after vision_start_token # For Qwen3 VL, mm_pos includes timestamp tokens before vision_start
base = positions[-1, global_mm_start] + 1 # when starting a new media. Adjust global_mm_start to point to where
local_start = global_mm_start + 1 + mm_embeddings_seen # the sequence actually begins (before timestamp tokens).
adjusted_for_timestamps = False
if mm_pos.shape[0] == 5 and mm_embeddings_seen == 0 and has_video_tokens:
# NOTE: -1 is because there is a vision start token right after
# timestamp tokens before any video embeddings appear.
# Adjust global_mm_start to point to the first timestamp token
# instead of the vision_start token.
global_mm_start -= num_timestamp_tokens
adjusted_for_timestamps = True
# Offset calculation depends on whether we adjusted for timestamp tokens
if adjusted_for_timestamps:
# Start from position before the first timestamp token
base = positions[-1, global_mm_start - 1] + 1
local_start = global_mm_start + mm_embeddings_seen
else:
# Original logic: start after vision_start_token
base = positions[-1, global_mm_start] + 1
local_start = global_mm_start + 1 + mm_embeddings_seen
local_end = local_start + mm_pos.shape[1] local_end = local_start + mm_pos.shape[1]
positions[:, local_start:local_end] = mm_pos[0:3] + base positions[:, local_start:local_end] = mm_pos[0:3] + base
# mm_pos[3, 0] is the max width of the media # For Qwen3 VL (5-channel), use the maximum position reached across
offset = mm_pos[3, 0] + base # all tokens (both video and text) in all dimensions (t, h, w).
# For Qwen2.5 VL (4-channel), mm_pos[3, 0] is the max width.
if mm_pos.shape[0] == 5:
offset = mm_pos[0:3, :].max() + base + 1
else:
offset = mm_pos[3, 0] + base
text_pos_sum = torch.cumsum(text_mask[local_end:].long(), dim=0) text_pos_sum = torch.cumsum(text_mask[local_end:].long(), dim=0)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment