# SPDX-License-Identifier: Apache-2.0 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project # Copyright 2025 The vLLM team. # Copyright 2025 The Qwen Team. # Copyright 2025 The HuggingFace Inc. team. # All rights reserved. # # This code is based on EleutherAI's GPT-NeoX library and the GPT-NeoX # and OPT implementations in this library. It has been modified from its # original forms to accommodate minor architectural differences compared # to GPT-NeoX and OPT used by the Meta AI team that trained the model. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. 
"""Inference-only Qwen3VL model compatible with HuggingFace weights.

This module defines the Qwen3-VL vision transformer (patch embedding, MLP,
attention blocks, patch mergers including the deepstack mergers), the
processing-info, dummy-input builder and multimodal processor used to turn
images/videos into model inputs.
"""

from collections.abc import Callable, Iterable, Iterator, Mapping, Sequence
from functools import lru_cache, partial
from itertools import islice
from typing import Any

import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from transformers import BatchFeature
from transformers.models.qwen2_vl import Qwen2VLImageProcessorFast
from transformers.models.qwen2_vl.image_processing_qwen2_vl import (
    smart_resize as image_smart_resize,
)
from transformers.models.qwen3_vl import Qwen3VLProcessor, Qwen3VLVideoProcessor
from transformers.models.qwen3_vl.configuration_qwen3_vl import (
    Qwen3VLConfig,
    Qwen3VLVisionConfig,
)
from transformers.models.qwen3_vl.video_processing_qwen3_vl import (
    smart_resize as video_smart_resize,
)
from transformers.video_utils import VideoMetadata

from vllm.compilation.decorators import support_torch_compile
from vllm.config import VllmConfig
from vllm.config.multimodal import BaseDummyOptions, VideoDummyOptions
from vllm.distributed import get_pp_group, parallel_state
from vllm.logger import init_logger
from vllm.model_executor.layers.activation import _ACTIVATION_REGISTRY
from vllm.model_executor.layers.attention.mm_encoder_attention import (
    MMEncoderAttention,
)
from vllm.model_executor.layers.conv import Conv3dLayer
from vllm.model_executor.layers.linear import (
    ColumnParallelLinear,
    RowParallelLinear,
)
from vllm.model_executor.layers.logits_processor import LogitsProcessor
from vllm.model_executor.layers.quantization import QuantizationConfig
from vllm.model_executor.layers.rotary_embedding import get_rope
from vllm.model_executor.layers.vocab_parallel_embedding import ParallelLMHead
from vllm.model_executor.model_loader.weight_utils import default_weight_loader
from vllm.model_executor.models.module_mapping import MultiModelKeys
from vllm.multimodal import MULTIMODAL_REGISTRY
# EVS helpers: compute_retained_tokens_count is used to size the pruned
# video placeholder when video_pruning_rate > 0.
from vllm.multimodal.evs import (
    compute_mrope_for_media,
    compute_retained_tokens_count,
    compute_retention_mask,
    recompute_mrope_positions,
)
from vllm.multimodal.inputs import (
    MultiModalDataDict,
    MultiModalFeatureSpec,
    MultiModalFieldConfig,
    MultiModalFieldElem,
    MultiModalKwargsItem,
    MultiModalKwargsItems,
    PlaceholderRange,
    VideoItem,
)
from vllm.multimodal.parse import ImageSize, MultiModalDataItems
from vllm.multimodal.processing import (
    BaseDummyInputsBuilder,
    BaseMultiModalProcessor,
    PromptReplacement,
    PromptUpdate,
    PromptUpdateDetails,
)
from vllm.sequence import IntermediateTensors
from vllm.tokenizers.protocol import TokenizerLike
from vllm.tokenizers.registry import cached_tokenizer_from_config
from vllm.utils.collection_utils import is_list_of
from vllm.utils.math_utils import round_up

from .interfaces import (
    MultiModalEmbeddings,
    SupportsEagle,
    SupportsEagle3,
    SupportsLoRA,
    SupportsMRoPE,
    SupportsMultiModal,
    SupportsMultiModalPruning,
    SupportsPP,
    _require_is_multimodal,
)
from .qwen2_5_vl import (
    Qwen2_5_VisionAttention,
    Qwen2_5_VLImageEmbeddingInputs,
    Qwen2_5_VLImageInputs,
    Qwen2_5_VLImagePixelInputs,
    Qwen2_5_VLVideoEmbeddingInputs,
    Qwen2_5_VLVideoInputs,
    Qwen2_5_VLVideoPixelInputs,
)
from .qwen2_vl import (
    Qwen2VLMultiModalDataParser,
    Qwen2VLProcessingInfo,
    _create_qwen2vl_field_factory,
)
from .qwen3 import Qwen3ForCausalLM, Qwen3Model
from .utils import (
    AutoWeightsLoader,
    PPMissingLayer,
    WeightsMapper,
    _merge_multimodal_embeddings,
    maybe_prefix,
)
from .vision import (
    get_vit_attn_backend,
    is_vit_use_data_parallel,
    run_dp_sharded_mrope_vision_model,
)

# Module-level logger; used e.g. to warn when dummy-input overrides
# (video num_frames / width / height) are ignored.
logger = init_logger(__name__)

# Upper bound on the number of dummy video frames (2048) considered when
# searching for the frame count whose vision embeddings have the maximum
# size; passed as ``max_frames_per_video`` to
# ``get_num_frames_with_most_features``.
DUMMY_VIDEO_NUM_FRAMES = 2048 class Qwen3_VisionPatchEmbed(nn.Module): def __init__( self, patch_size: int = 14, temporal_patch_size: int = 2, in_channels: int = 3, hidden_size: int = 1152, ) -> None: super().__init__() self.patch_size = patch_size self.temporal_patch_size = temporal_patch_size self.hidden_size = hidden_size kernel_size = (temporal_patch_size, patch_size, patch_size) self.proj = Conv3dLayer( in_channels, hidden_size, kernel_size=kernel_size, stride=kernel_size, bias=True, ) def forward(self, x: torch.Tensor) -> torch.Tensor: L, C = x.shape x = x.view(L, -1, self.temporal_patch_size, self.patch_size, self.patch_size) x = self.proj(x).view(L, self.hidden_size) return x class Qwen3_VisionMLP(nn.Module): def __init__( self, in_features: int, hidden_features: int, bias: bool = False, act_fn: Callable[[torch.Tensor], torch.Tensor] = F.silu, quant_config: QuantizationConfig | None = None, prefix: str = "", ): super().__init__() use_data_parallel = is_vit_use_data_parallel() self.linear_fc1 = ColumnParallelLinear( in_features, hidden_features, bias=bias, quant_config=quant_config, return_bias=False, prefix=f"{prefix}.linear_fc1", disable_tp=use_data_parallel, ) self.linear_fc2 = RowParallelLinear( hidden_features, in_features, bias=bias, quant_config=quant_config, return_bias=False, prefix=f"{prefix}.linear_fc2", disable_tp=use_data_parallel, ) self.act_fn = act_fn def forward(self, x: torch.Tensor): mlp_output = self.linear_fc2(self.act_fn(self.linear_fc1(x))) return mlp_output class Qwen3_VisionBlock(nn.Module): def __init__( self, dim: int, num_heads: int, mlp_hidden_dim: int, act_fn: Callable[[torch.Tensor], torch.Tensor] = F.silu, norm_layer: Callable[[int], nn.Module] | None = None, quant_config: QuantizationConfig | None = None, prefix: str = "", ) -> None: super().__init__() if norm_layer is None: norm_layer = partial(nn.LayerNorm, eps=1e-6) self.norm1 = norm_layer(dim) self.norm2 = norm_layer(dim) self.attn = Qwen2_5_VisionAttention( 
embed_dim=dim, num_heads=num_heads, projection_size=dim, quant_config=quant_config, prefix=f"{prefix}.attn", ) self.mlp = Qwen3_VisionMLP( dim, mlp_hidden_dim, act_fn=act_fn, bias=True, quant_config=quant_config, prefix=f"{prefix}.mlp", ) def forward( self, x: torch.Tensor, cu_seqlens: torch.Tensor, rotary_pos_emb_cos: torch.Tensor, rotary_pos_emb_sin: torch.Tensor, max_seqlen: torch.Tensor, # Only used for Flash Attention sequence_lengths: torch.Tensor, # Only used for FlashInfer CuDNN backend ) -> torch.Tensor: x = x + self.attn( self.norm1(x), cu_seqlens=cu_seqlens, rotary_pos_emb_cos=rotary_pos_emb_cos, rotary_pos_emb_sin=rotary_pos_emb_sin, max_seqlen=max_seqlen, sequence_lengths=sequence_lengths, ) x = x + self.mlp(self.norm2(x)) return x class Qwen3_VisionPatchMerger(nn.Module): def __init__( self, d_model: int, context_dim: int, norm_layer: Callable[[int], nn.Module] | None = None, spatial_merge_size: int = 2, use_postshuffle_norm: bool = False, quant_config: QuantizationConfig | None = None, prefix: str = "", ) -> None: super().__init__() use_data_parallel = is_vit_use_data_parallel() self.hidden_size = context_dim * (spatial_merge_size**2) self.use_postshuffle_norm = use_postshuffle_norm if self.use_postshuffle_norm: context_dim = self.hidden_size if norm_layer is None: norm_layer = partial(nn.LayerNorm, eps=1e-6) self.norm = norm_layer(context_dim) self.linear_fc1 = ColumnParallelLinear( self.hidden_size, self.hidden_size, bias=True, quant_config=quant_config, prefix=f"{prefix}.linear_fc1", disable_tp=use_data_parallel, ) self.act_fn = nn.GELU() self.linear_fc2 = RowParallelLinear( self.hidden_size, d_model, bias=True, quant_config=quant_config, prefix=f"{prefix}.linear_fc2", disable_tp=use_data_parallel, ) def forward(self, x: torch.Tensor) -> torch.Tensor: if self.use_postshuffle_norm: x = self.norm(x.view(-1, self.hidden_size)) else: x = self.norm(x).view(-1, self.hidden_size) x_parallel, _ = self.linear_fc1(x) x_parallel = self.act_fn(x_parallel) 
out, _ = self.linear_fc2(x_parallel) return out class Qwen3_VisionTransformer(nn.Module): def __init__( self, vision_config: Qwen3VLVisionConfig, norm_eps: float = 1e-6, quant_config: QuantizationConfig | None = None, prefix: str = "", ) -> None: super().__init__() self.hidden_size = vision_config.hidden_size self.num_heads = vision_config.num_heads self.num_position_embeddings = vision_config.num_position_embeddings self.patch_size = vision_config.patch_size self.spatial_merge_size = vision_config.spatial_merge_size self.spatial_merge_unit = self.spatial_merge_size**2 self.temporal_patch_size = vision_config.temporal_patch_size self.deepstack_visual_indexes = ( vision_config.deepstack_visual_indexes if hasattr(vision_config, "deepstack_visual_indexes") else [] ) self.num_grid_per_side = int(self.num_position_embeddings**0.5) use_data_parallel = is_vit_use_data_parallel() self.tp_size = ( 1 if use_data_parallel else parallel_state.get_tensor_model_parallel_world_size() ) # NOTE: This is used for creating empty tensor for all_gather for # DP ViT. 
Here out_hidden_size is enlarged due to deepstack self.out_hidden_size = vision_config.out_hidden_size * ( 1 + len(self.deepstack_visual_indexes) ) self.patch_embed = Qwen3_VisionPatchEmbed( patch_size=self.patch_size, temporal_patch_size=self.temporal_patch_size, in_channels=vision_config.in_channels, hidden_size=self.hidden_size, ) self.pos_embed = nn.Embedding(self.num_position_embeddings, self.hidden_size) norm_layer = partial(nn.LayerNorm, eps=norm_eps) head_dim = self.hidden_size // self.num_heads self.rotary_pos_emb = get_rope( head_size=head_dim, max_position=8192, is_neox_style=True, rope_parameters={"partial_rotary_factor": 0.5}, ) self.merger = Qwen3_VisionPatchMerger( d_model=vision_config.out_hidden_size, context_dim=self.hidden_size, norm_layer=norm_layer, spatial_merge_size=self.spatial_merge_size, quant_config=quant_config, prefix=f"{prefix}.merger", ) self.deepstack_merger_list = nn.ModuleList( [ Qwen3_VisionPatchMerger( d_model=vision_config.out_hidden_size, context_dim=self.hidden_size, spatial_merge_size=self.spatial_merge_size, use_postshuffle_norm=True, norm_layer=norm_layer, quant_config=quant_config, prefix=f"{prefix}.deepstack_merger_list.{layer_idx}", ) for layer_idx in range(len(self.deepstack_visual_indexes)) ] ) self.attn_backend = get_vit_attn_backend( head_size=head_dim, dtype=torch.get_default_dtype(), ) self.blocks = nn.ModuleList( [ Qwen3_VisionBlock( dim=self.hidden_size, num_heads=self.num_heads, mlp_hidden_dim=vision_config.intermediate_size, act_fn=_ACTIVATION_REGISTRY[vision_config.hidden_act], norm_layer=norm_layer, quant_config=quant_config, prefix=f"{prefix}.blocks.{layer_idx}", ) for layer_idx in range(vision_config.depth) ] ) @property def dtype(self) -> torch.dtype: return self.patch_embed.proj.weight.dtype @property def device(self) -> torch.device: return self.patch_embed.proj.weight.device @staticmethod @lru_cache(maxsize=1024) def rot_pos_ids(h: int, w: int, spatial_merge_size: int) -> torch.Tensor: hpos_ids = 
np.broadcast_to(np.arange(h).reshape(h, 1), (h, w)) h_div = h // spatial_merge_size w_div = w // spatial_merge_size hpos_ids = hpos_ids.reshape( h_div, spatial_merge_size, w_div, spatial_merge_size, ) hpos_ids = hpos_ids.transpose(0, 2, 1, 3) hpos_ids = hpos_ids.flatten() wpos_ids = np.broadcast_to(np.arange(w).reshape(1, w), (h, w)) wpos_ids = wpos_ids.reshape( h_div, spatial_merge_size, w_div, spatial_merge_size, ) wpos_ids = wpos_ids.transpose(0, 2, 1, 3) wpos_ids = wpos_ids.flatten() return torch.from_numpy(np.stack([hpos_ids, wpos_ids], axis=-1)) def rot_pos_emb(self, grid_thw: list[list[int]]): max_grid_size = max(max(h, w) for _, h, w in grid_thw) pos_ids = [ self.rot_pos_ids(h, w, self.spatial_merge_size) if t == 1 else self.rot_pos_ids(h, w, self.spatial_merge_size).repeat(t, 1) for t, h, w in grid_thw ] pos_ids = torch.cat(pos_ids, dim=0).to(self.device, non_blocking=True) # Use pre-computed cos_sin_cache from RotaryEmbedding cos, sin = self.rotary_pos_emb.get_cos_sin(max_grid_size) cos_combined = cos[pos_ids].flatten(1) sin_combined = sin[pos_ids].flatten(1) return cos_combined, sin_combined def fast_pos_embed_interpolate(self, grid_thw: list[list[int]]) -> torch.Tensor: num_grid_per_side = self.num_grid_per_side m_size = self.spatial_merge_size hidden_dim = self.pos_embed.embedding_dim outputs = [] for t, h, w in grid_thw: h_idxs = torch.linspace( 0, num_grid_per_side - 1, h, dtype=torch.float32, device=self.device ) w_idxs = torch.linspace( 0, num_grid_per_side - 1, w, dtype=torch.float32, device=self.device ) h_floor = h_idxs.to(torch.long) w_floor = w_idxs.to(torch.long) h_ceil = torch.clamp(h_floor + 1, max=num_grid_per_side - 1) w_ceil = torch.clamp(w_floor + 1, max=num_grid_per_side - 1) dh = h_idxs - h_floor dw = w_idxs - w_floor # Create meshgrid view for all h, w vars dh_grid, dw_grid = torch.meshgrid(dh, dw, indexing="ij") h_floor_grid, w_floor_grid = torch.meshgrid(h_floor, w_floor, indexing="ij") h_ceil_grid, w_ceil_grid = 
torch.meshgrid(h_ceil, w_ceil, indexing="ij") # original computation of weights # w00 = (1 - dh_grid) * (1 - dw_grid) # w01 = (1 - dh_grid) * dw_grid # w10 = dh_grid * (1 - dw_grid) # w11 = dh_grid * dw_grid # we reuse w11 here to avoid duplicate # dh_grid * dw_grid computation w11 = dh_grid * dw_grid w10 = dh_grid - w11 w01 = dw_grid - w11 w00 = 1 - dh_grid - w01 h_grid = torch.stack([h_floor_grid, h_floor_grid, h_ceil_grid, h_ceil_grid]) w_grid = torch.stack([w_floor_grid, w_ceil_grid, w_floor_grid, w_ceil_grid]) h_grid_idx = h_grid * num_grid_per_side indices = (h_grid_idx + w_grid).reshape(4, -1) weights = torch.stack([w00, w01, w10, w11], dim=0).reshape(4, -1, 1) weights = weights.to(dtype=self.dtype) embeds = self.pos_embed(indices) embeds *= weights combined = embeds.sum(dim=0) combined = combined.reshape( h // m_size, m_size, w // m_size, m_size, hidden_dim ) combined = combined.permute(0, 2, 1, 3, 4).reshape(1, -1, hidden_dim) repeated = combined.expand(t, -1, -1).reshape(-1, hidden_dim) outputs.append(repeated) return torch.cat(outputs, dim=0) def forward( self, x: torch.Tensor, grid_thw: torch.Tensor | list[list[int]], ) -> torch.Tensor: hidden_states = x.to(device=self.device, dtype=self.dtype, non_blocking=True) hidden_states = self.patch_embed(hidden_states) if isinstance(grid_thw, list): grid_thw_list = grid_thw grid_thw = np.array(grid_thw, dtype=np.int32) else: grid_thw_list = grid_thw.tolist() grid_thw = grid_thw.numpy() pos_embeds = self.fast_pos_embed_interpolate(grid_thw_list) hidden_states = hidden_states + pos_embeds rotary_pos_emb_cos, rotary_pos_emb_sin = self.rot_pos_emb(grid_thw_list) cu_seqlens = np.repeat(grid_thw[:, 1] * grid_thw[:, 2], grid_thw[:, 0]).cumsum( axis=0, dtype=np.int32 ) cu_seqlens = np.concatenate([np.zeros(1, dtype=np.int32), cu_seqlens]) sequence_lengths = MMEncoderAttention.maybe_compute_seq_lens( self.attn_backend, cu_seqlens, self.device ) max_seqlen = torch.tensor( 
MMEncoderAttention.compute_max_seqlen(self.attn_backend, cu_seqlens), dtype=torch.int32, ) cu_seqlens = MMEncoderAttention.maybe_recompute_cu_seqlens( self.attn_backend, cu_seqlens, self.hidden_size, self.tp_size, self.device, ) hidden_states = hidden_states.unsqueeze(1) deepstack_feature_lists = [] for layer_num, blk in enumerate(self.blocks): hidden_states = blk( hidden_states, cu_seqlens=cu_seqlens, rotary_pos_emb_cos=rotary_pos_emb_cos, rotary_pos_emb_sin=rotary_pos_emb_sin, max_seqlen=max_seqlen, sequence_lengths=sequence_lengths, ) if layer_num in self.deepstack_visual_indexes: deepstack_merger_idx = self.deepstack_visual_indexes.index(layer_num) deepstack_feature = self.deepstack_merger_list[deepstack_merger_idx]( hidden_states ) deepstack_feature_lists.append(deepstack_feature) hidden_states = self.merger(hidden_states) hidden_states = torch.cat( [hidden_states] + deepstack_feature_lists, dim=1 ) # [seq_len, hidden_size * (1 + depth_of_deepstack)] return hidden_states def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]: stacked_params_mapping = [ # (param_name, shard_name, shard_id) ("attn.qkv.", "attn.q.", "q"), ("attn.qkv.", "attn.k.", "k"), ("attn.qkv.", "attn.v.", "v"), ] params_dict = dict(self.named_parameters(remove_duplicate=False)) loaded_params: set[str] = set() for name, loaded_weight in weights: for param_name, weight_name, shard_id in stacked_params_mapping: if weight_name not in name: continue name = name.replace(weight_name, param_name) param = params_dict[name] weight_loader = param.weight_loader weight_loader(param, loaded_weight, shard_id) break else: param = params_dict[name] weight_loader = getattr(param, "weight_loader", default_weight_loader) weight_loader(param, loaded_weight) loaded_params.add(name) return loaded_params class Qwen3VLProcessingInfo(Qwen2VLProcessingInfo): def get_hf_config(self): return self.ctx.get_hf_config(Qwen3VLConfig) def get_hf_processor(self, **kwargs: object) -> Qwen3VLProcessor: 
return self.ctx.get_hf_processor( Qwen3VLProcessor, use_fast=kwargs.pop("use_fast", True), **kwargs, ) def get_image_processor(self, **kwargs: object) -> Qwen2VLImageProcessorFast: return self.get_hf_processor(**kwargs).image_processor def get_video_processor(self, **kwargs: object) -> Qwen3VLVideoProcessor: return self.get_hf_processor(**kwargs).video_processor def get_data_parser(self): return Qwen2VLMultiModalDataParser( self.get_hf_config().vision_config.spatial_merge_size, video_needs_metadata=True, expected_hidden_size=self._get_expected_hidden_size(), ) def _get_vision_info( self, *, image_width: int, image_height: int, num_frames: int = 2, do_resize: bool = True, image_processor: Qwen2VLImageProcessorFast | Qwen3VLVideoProcessor, mm_kwargs: Mapping[str, object], ) -> tuple[ImageSize, int]: is_video = isinstance(image_processor, Qwen3VLVideoProcessor) hf_config = self.get_hf_config() vision_config = hf_config.vision_config patch_size = vision_config.patch_size merge_size = vision_config.spatial_merge_size temporal_patch_size = vision_config.temporal_patch_size mm_kwargs = self.ctx.get_merged_mm_kwargs(mm_kwargs) size = image_processor.size if override_size := mm_kwargs.get("size"): size = size | override_size if (override_min_pixels := mm_kwargs.get("min_pixels")) is not None: size = size | {"shortest_edge": override_min_pixels} if (override_max_pixels := mm_kwargs.get("max_pixels")) is not None: size = size | {"longest_edge": override_max_pixels} if do_resize: if is_video: smart_resize = video_smart_resize extra_kwargs = { "num_frames": num_frames, "temporal_factor": temporal_patch_size, } else: smart_resize = image_smart_resize extra_kwargs = {} resized_height, resized_width = smart_resize( height=image_height, width=image_width, factor=patch_size * merge_size, min_pixels=size["shortest_edge"], max_pixels=size["longest_edge"], **extra_kwargs, ) preprocessed_size = ImageSize(width=resized_width, height=resized_height) else: preprocessed_size = 
ImageSize(width=image_width, height=image_height) padded_num_frames = round_up(num_frames, temporal_patch_size) grid_t = max(padded_num_frames // temporal_patch_size, 1) grid_h = preprocessed_size.height // patch_size grid_w = preprocessed_size.width // patch_size num_patches = grid_t * grid_h * grid_w num_vision_tokens = num_patches // (merge_size**2) return preprocessed_size, num_vision_tokens def _get_max_video_frames(self, max_tokens: int, start_num_frames: int = 2) -> int: return super()._get_max_video_frames( max_tokens, start_num_frames=start_num_frames ) def get_num_frames_with_most_features( self, seq_len: int, mm_counts: Mapping[str, int], ) -> int: return super().get_num_frames_with_most_features( seq_len, mm_counts, max_frames_per_video=DUMMY_VIDEO_NUM_FRAMES ) def get_max_video_tokens( self, seq_len: int, mm_counts: Mapping[str, int], ) -> int: video_processor = self.get_video_processor() mm_kwargs = self.ctx.get_merged_mm_kwargs({}) video_size = mm_kwargs.get("size", video_processor.size) temporal_patch_size = mm_kwargs.get( "temporal_patch_size", video_processor.temporal_patch_size ) # video_max_pixels contains the temporal compression factor, # so we divide by 2 to get the maximum number of image pixels. 
video_max_pixels = video_size["longest_edge"] target_width, target_height = self.get_image_size_with_most_features( max_pixels=video_max_pixels // temporal_patch_size ) num_video_soft_tokens = self.get_num_video_tokens( image_width=target_width, image_height=target_height, num_frames=2, image_processor=video_processor, mm_kwargs={}, ) return num_video_soft_tokens def _calculate_timestamps( self, indices: list[int] | torch.Tensor, video_fps: float, merge_size: int ): if not isinstance(indices, list): indices = indices.tolist() if len(indices) % merge_size != 0: # don't update metadata's frames_indices directly indices = indices + [indices[-1]] * (merge_size - len(indices) % merge_size) timestamps = [idx / video_fps for idx in indices] timestamps = [ (timestamps[i] + timestamps[i + merge_size - 1]) / 2 for i in range(0, len(timestamps), merge_size) ] return timestamps def _get_video_second_idx( self, metadata: dict[str, Any], do_sample_frames: bool | None = None, sampled_fps: float | None = None, sampled_num_frames: int | None = None, ) -> list[int]: video_processor = self.get_video_processor() temporal_patch_size = video_processor.temporal_patch_size indices = metadata["frames_indices"] # metadata["fps"] refers to the true fps of the input video. video_fps = metadata["fps"] if do_sample_frames is None: do_sample_frames = metadata.get("do_sample_frames", False) # If video frames are sampled in HF processor (instead of vLLM # video loader), we need to re-calculate the indices from original # metadata. if do_sample_frames: total_num_frames = metadata["total_num_frames"] # When num_frames is explicitly provided, use it directly # instead of computing from fps. This mirrors the behavior of # HF's Qwen3VLVideoProcessor.sample_frames where num_frames # and fps are mutually exclusive. if sampled_num_frames is not None: num_frames = sampled_num_frames else: # here video_fps is the fps of the sampled video, and # metadata["fps"] refers to the fps of the original video. 
sampled_fps = sampled_fps if sampled_fps else video_processor.fps num_frames = int(total_num_frames / metadata["fps"] * sampled_fps) num_frames = min( min( max(num_frames, video_processor.min_frames), video_processor.max_frames, ), total_num_frames, ) indices = ( np.linspace(0, total_num_frames - 1, num_frames) .round() .astype(int) .tolist() ) timestamps = self._calculate_timestamps(indices, video_fps, temporal_patch_size) return timestamps class Qwen3VLDummyInputsBuilder(BaseDummyInputsBuilder[Qwen3VLProcessingInfo]): def get_dummy_text(self, mm_counts: Mapping[str, int]) -> str: num_images = mm_counts.get("image", 0) num_videos = mm_counts.get("video", 0) image_token = "<|vision_start|><|image_pad|><|vision_end|>" video_token = "<|vision_start|><|video_pad|><|vision_end|>" return image_token * num_images + video_token * num_videos def get_dummy_mm_data( self, seq_len: int, mm_counts: Mapping[str, int], mm_options: Mapping[str, BaseDummyOptions], ) -> MultiModalDataDict: num_images = mm_counts.get("image", 0) num_videos = mm_counts.get("video", 0) image_overrides = mm_options.get("image") video_overrides = mm_options.get("video") target_image_width, target_image_height = ( self.info.get_image_size_with_most_features() ) # treat videos as special images target_num_frames = 2 if video_overrides: assert isinstance(video_overrides, VideoDummyOptions) num_frames_override = video_overrides.num_frames if num_frames_override: if num_frames_override > target_num_frames: logger.warning( "video.num_frames override (%d) exceeds model's " "maximum number of frames (%d), will be ignored", num_frames_override, target_num_frames, ) if num_frames_override < 2: logger.warning( "video.num_frames override (%d) cannot be less " "than 2, will be ignored", num_frames_override, ) target_num_frames = min(target_num_frames, num_frames_override) target_num_frames = max(target_num_frames, 2) video_processor = self.info.get_video_processor() mm_kwargs = self.info.ctx.get_merged_mm_kwargs({}) 
video_size = mm_kwargs.get("size", video_processor.size) temporal_patch_size = mm_kwargs.get( "temporal_patch_size", video_processor.temporal_patch_size ) # video_max_pixels contains the temporal compression factor, # so we divide by 2 to get the maximum number of image pixels. video_max_pixels = video_size["longest_edge"] target_video_width, target_video_height = ( self.info.get_image_size_with_most_features( max_pixels=video_max_pixels // temporal_patch_size ) ) target_video_size, _ = self.info._get_vision_info( image_width=target_video_width, image_height=target_video_height, num_frames=target_num_frames, image_processor=video_processor, mm_kwargs={}, ) # NOTE: we need to do this check here since Qwen3-VL resizes video # frames depending on how many frames there are. target_video_width, target_video_height = ( target_video_size.width, target_video_size.height, ) if video_overrides: assert isinstance(video_overrides, VideoDummyOptions) width_override = video_overrides.width if width_override: if width_override > target_video_width: logger.warning( "video.width override (%d) exceeds model's " "maximum width (%d), will be ignored", width_override, target_video_width, ) target_video_width = min(target_video_width, width_override) height_override = video_overrides.height if height_override: if height_override > target_video_height: logger.warning( "video.height override (%d) exceeds model's " "maximum height (%d), will be ignored", height_override, target_video_height, ) target_video_height = min(target_video_height, height_override) return { "image": self._get_dummy_images( width=target_image_width, height=target_image_height, num_images=num_images, overrides=image_overrides, ), "video": self._get_dummy_videos( width=target_video_width, height=target_video_height, num_frames=target_num_frames, num_videos=num_videos, ), } def _get_dummy_videos( self, *, width: int, height: int, num_frames: int, num_videos: int, ) -> list[VideoItem]: video = np.full((num_frames, 
width, height, 3), 255, dtype=np.uint8) video_items = [] for i in range(num_videos): video_metadata = { "fps": 2.0, "duration": num_frames / 2.0, "total_num_frames": num_frames, "frames_indices": [i for i in range(num_frames)], "video_backend": "opencv", "do_sample_frames": False, } video_item = (video.copy(), video_metadata) video_items.append(video_item) return video_items class Qwen3VLMultiModalProcessor(BaseMultiModalProcessor[Qwen3VLProcessingInfo]): def _call_hf_processor( self, prompt: str, mm_data: Mapping[str, object], mm_kwargs: Mapping[str, object], tok_kwargs: Mapping[str, object], ) -> BatchFeature: mm_data = dict(mm_data) processor = self.info.get_hf_processor(**mm_kwargs) # Separate video processing from image processing. Because the videos # are processed into several image patches if videos := mm_data.pop("videos", []): video_grid_thw_lst = [] pixel_values_videos_lst = [] timestamps_per_video = [] for item in videos: video_array, metadata = item # NOTE: @JJJYmmm new attr metadata.frames_indices indicates # the sampled frames indices of pre-sampled videos, which is # used to calculate the timestamps. Make sure that # do_sample_frames in mm_kwargs is false for presampled videos. # NOTE: a copy of is created to update do_sample_frames, # otherwise mm_hash for the object will be incorrect. video_mm_kwargs = dict(**mm_kwargs) if "do_sample_frames" not in video_mm_kwargs: # qwen_vl_utils already has "do_sample_frames" in # mm_kwargs, don't overwrite it. 
video_mm_kwargs["do_sample_frames"] = metadata.get( "do_sample_frames", False ) metadata = VideoMetadata( **{k: metadata[k] for k in metadata if k != "do_sample_frames"} ) # Compute timestamps here where we have access to metadata timestamps = self.info._get_video_second_idx( metadata=metadata, do_sample_frames=video_mm_kwargs["do_sample_frames"], sampled_fps=video_mm_kwargs.get("fps"), sampled_num_frames=video_mm_kwargs.get("num_frames"), ) timestamps_per_video.append(timestamps) video_mm_data = dict() video_mm_data["videos"] = [[video_array]] video_mm_data["video_metadata"] = [[metadata]] # When num_frames is specified, explicitly set fps=None # to prevent HF's BaseVideoProcessor.preprocess() from # filling in the class default (fps=2) via setdefault(), # which would conflict with num_frames (mutually exclusive). if "num_frames" in video_mm_kwargs and "fps" not in video_mm_kwargs: video_mm_kwargs["fps"] = None video_outputs = super()._call_hf_processor( prompt="<|vision_start|><|video_pad|><|vision_end|>", mm_data=video_mm_data, mm_kwargs=video_mm_kwargs, tok_kwargs=tok_kwargs, ) merge_size = processor.video_processor.merge_size # Get video grid info for EVS calculation. video_grid_thw = video_outputs["video_grid_thw"] num_frames = int(video_grid_thw[0, 0]) tokens_per_frame_base = int(video_grid_thw[0, 1:].prod()) // ( merge_size**2 ) # Apply EVS if enabled. video_pruning_rate = self.info.ctx.get_mm_config().video_pruning_rate if video_pruning_rate is not None and video_pruning_rate > 0.0: num_tokens = compute_retained_tokens_count( tokens_per_frame=tokens_per_frame_base, num_frames=num_frames, q=video_pruning_rate, ) # Here we just need placeholders that won't actually be replaced - # we just need to make sure the total number of tokens is correct # assign all tokens to the first frame. 
tokens_per_frame = [num_tokens] + [0] * (num_frames - 1) select_token_id = False else: tokens_per_frame = [tokens_per_frame_base] * num_frames select_token_id = True # Generate the video replacement with EVS-adjusted token counts tokenizer = self.info.get_tokenizer() hf_config = self.info.get_hf_config() video_repl = Qwen3VLMultiModalProcessor.get_video_repl( tokens_per_frame=tokens_per_frame, timestamps=timestamps, tokenizer=tokenizer, vision_start_token_id=hf_config.vision_start_token_id, vision_end_token_id=hf_config.vision_end_token_id, video_token_id=hf_config.video_token_id, select_token_id=select_token_id, ) # Convert token IDs to text for the HF processor flow video_placeholder = tokenizer.decode( video_repl.full, skip_special_tokens=False ) input_ids = video_outputs.pop("input_ids") video_placeholder = processor.tokenizer.batch_decode(input_ids)[0] prompt = prompt.replace( "<|vision_start|><|video_pad|><|vision_end|>", video_placeholder, 1, ) video_grid_thw_lst.append(video_outputs["video_grid_thw"]) pixel_values_videos_lst.append(video_outputs["pixel_values_videos"]) video_outputs = dict( pixel_values_videos=torch.cat(pixel_values_videos_lst), video_grid_thw=torch.cat(video_grid_thw_lst), timestamps=timestamps_per_video, ) else: video_outputs = dict() processed_outputs = super()._call_hf_processor( prompt=prompt, mm_data=mm_data, mm_kwargs=mm_kwargs, tok_kwargs=tok_kwargs, ) combined_outputs = dict( processed_outputs, **video_outputs, ) return BatchFeature(combined_outputs) def _get_mm_fields_config( self, hf_inputs: BatchFeature, hf_processor_mm_kwargs: Mapping[str, object], ) -> Mapping[str, MultiModalFieldConfig]: return _create_qwen2vl_field_factory( self.info.get_hf_config().vision_config.spatial_merge_size )(hf_inputs) def _get_prompt_updates( self, mm_items: MultiModalDataItems, hf_processor_mm_kwargs: Mapping[str, Any], out_mm_kwargs: MultiModalKwargsItems, ) -> Sequence[PromptUpdate]: hf_processor = 
self.info.get_hf_processor(**hf_processor_mm_kwargs) image_processor = self.info.get_image_processor(**hf_processor_mm_kwargs) tokenizer = self.info.get_tokenizer() hf_config = self.info.get_hf_config() video_token_id = hf_config.video_token_id vision_start_token_id = hf_config.vision_start_token_id vision_end_token_id = hf_config.vision_end_token_id merge_length = image_processor.merge_size**2 def get_image_replacement_qwen3vl(item_idx: int): out_item = out_mm_kwargs["image"][item_idx] grid_thw = out_item["image_grid_thw"].data assert isinstance(grid_thw, torch.Tensor) num_tokens = int(grid_thw.prod()) // merge_length return [hf_processor.image_token_id] * num_tokens def get_video_replacement_qwen3vl(item_idx: int): out_item = out_mm_kwargs["video"][item_idx] grid_thw = out_item["video_grid_thw"].data assert isinstance(grid_thw, torch.Tensor) sampled_fps = hf_processor_mm_kwargs.get("fps") if is_list_of(sampled_fps, float): sampled_fps = sampled_fps[item_idx] timestamps = out_item["timestamps"].data assert len(timestamps) == grid_thw[0], ( f"The timestamps length({len(timestamps)}) should be equal " f"video length ({grid_thw[0]})." 
) # Compute tokens per frame, with EVS support num_frames = int(grid_thw[0]) tokens_per_frame_base = int(grid_thw[1:].prod()) // merge_length video_pruning_rate = self.info.ctx.get_mm_config().video_pruning_rate if video_pruning_rate is not None and video_pruning_rate > 0.0: num_tokens = compute_retained_tokens_count( tokens_per_frame=tokens_per_frame_base, num_frames=num_frames, q=video_pruning_rate, ) tokens_per_frame = [num_tokens] + [0] * (num_frames - 1) select_token_id = False else: tokens_per_frame = [tokens_per_frame_base] * num_frames select_token_id = True return Qwen3VLMultiModalProcessor.get_video_repl( tokens_per_frame=tokens_per_frame, timestamps=timestamps, tokenizer=tokenizer, vision_start_token_id=vision_start_token_id, vision_end_token_id=vision_end_token_id, video_token_id=video_token_id, select_token_id=select_token_id, ) return [ PromptReplacement( modality="image", target=hf_processor.image_token, replacement=get_image_replacement_qwen3vl, ), # NOTE: We match string on purpose since searching sequence of # token ids takes more time. PromptReplacement( modality="video", target="<|vision_start|><|video_pad|><|vision_end|>", replacement=get_video_replacement_qwen3vl, ), ] @staticmethod def get_video_repl( *, tokens_per_frame: list[int], timestamps: list[float | int], tokenizer: TokenizerLike, vision_start_token_id: int, vision_end_token_id: int, video_token_id: int, select_token_id: bool = False, ) -> PromptUpdateDetails[list[int]]: """Build prompt replacement for a video in Qwen3VL format. The replacement structure for each frame is: timestamp_tokens + vision_start_token + video_tokens + vision_end_token Args: tokens_per_frame: Number of video tokens per frame (can vary per frame for EVS). 
timestamps: List of timestamps in seconds for each frame tokenizer: Tokenizer to encode timestamp strings vision_start_token_id: Token ID for vision start marker vision_end_token_id: Token ID for vision end marker video_token_id: Token ID for video content Returns: PromptUpdateDetails with full token sequence """ assert len(timestamps) == len(tokens_per_frame), ( "timestamps and tokens_per_frame must have the same length" ) # Tokenize timestamp strings independently to avoid tokenizer merging # tokens across boundaries. # TODO: switch to `_seq2tokens` which has some caching. timestamp_token_ids = [ tokenizer.encode(f"<{timestamp:.1f} seconds>", add_special_tokens=False) for timestamp in timestamps ] # Build the full token sequence all_token_ids = [] for frame_timestamp_ids, num_tokens in zip( timestamp_token_ids, tokens_per_frame ): # Add timestamp tokens all_token_ids.extend(frame_timestamp_ids) # Add vision tokens: vision_start + video_tokens + vision_end all_token_ids.append(vision_start_token_id) all_token_ids.extend([video_token_id] * num_tokens) all_token_ids.append(vision_end_token_id) if select_token_id: return PromptUpdateDetails.select_token_id(all_token_ids, video_token_id) # NOTE: we use `from_seq` instead of `select_token_id` because we want all # tokens in the placeholder to be initially marked as candidates. Then # in `get_input_embeddings``, we refine the mask to only replace # `video_token_id` / `image_token_id`` positions with video/image embeddings, # keeping text embeddings for timestamps and structural tokens. return PromptUpdateDetails.from_seq(all_token_ids) @support_torch_compile( dynamic_arg_dims={ "input_ids": 0, # positions is of shape (3, seq_len) if mrope is enabled for qwen2-vl, # otherwise (seq_len, ). 
"positions": -1, "intermediate_tensors": 0, "inputs_embeds": 0, # the same shape as input_embeds "deepstack_input_embeds": 0, } ) class Qwen3LLMModel(Qwen3Model): def forward( self, input_ids: torch.Tensor | None, positions: torch.Tensor, intermediate_tensors: IntermediateTensors | None = None, inputs_embeds: torch.Tensor | None = None, # args for deepstack deepstack_input_embeds: IntermediateTensors | None = None, ) -> torch.Tensor | IntermediateTensors: if get_pp_group().is_first_rank: if inputs_embeds is not None: hidden_states = inputs_embeds else: hidden_states = self.embed_input_ids(input_ids) residual = None else: assert intermediate_tensors is not None hidden_states = intermediate_tensors["hidden_states"] residual = intermediate_tensors["residual"] aux_hidden_states = self._maybe_add_hidden_state([], 0, hidden_states, residual) for layer_idx, layer in islice( enumerate(self.layers), self.start_layer, self.end_layer ): hidden_states, residual = layer( positions, hidden_states, residual, ) if deepstack_input_embeds is not None and layer_idx in range( 0, len(deepstack_input_embeds) ): hidden_states = ( hidden_states + deepstack_input_embeds[f"deepstack_input_embeds_{layer_idx}"] ) self._maybe_add_hidden_state( aux_hidden_states, layer_idx + 1, hidden_states, residual ) if not get_pp_group().is_last_rank: return IntermediateTensors( {"hidden_states": hidden_states, "residual": residual} ) hidden_states, _ = self.norm(hidden_states, residual) if len(aux_hidden_states) > 0: return hidden_states, aux_hidden_states return hidden_states class Qwen3LLMForCausalLM(Qwen3ForCausalLM): def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): super(Qwen3ForCausalLM, self).__init__() config = vllm_config.model_config.hf_config quant_config = vllm_config.quant_config self.config = config self.quant_config = quant_config self.model = Qwen3LLMModel( vllm_config=vllm_config, prefix=maybe_prefix(prefix, "model") ) if get_pp_group().is_last_rank: if 
config.tie_word_embeddings: self.lm_head = self.model.embed_tokens else: self.lm_head = ParallelLMHead( config.vocab_size, config.hidden_size, quant_config=quant_config, prefix="lm_head", ) else: self.lm_head = PPMissingLayer() self.logits_processor = LogitsProcessor(config.vocab_size) self.make_empty_intermediate_tensors = ( self.model.make_empty_intermediate_tensors ) @MULTIMODAL_REGISTRY.register_processor( Qwen3VLMultiModalProcessor, info=Qwen3VLProcessingInfo, dummy_inputs=Qwen3VLDummyInputsBuilder, ) class Qwen3VLForConditionalGeneration( nn.Module, SupportsMultiModal, SupportsLoRA, SupportsPP, SupportsMRoPE, SupportsEagle, SupportsEagle3, SupportsMultiModalPruning, ): packed_modules_mapping = { "qkv_proj": [ "q_proj", "k_proj", "v_proj", ], "gate_up_proj": [ "gate_proj", "up_proj", ], "qkv": ["qkv"], # For vision tower's already-packed QKV } supports_encoder_tp_data = True # To ensure correct weight loading and mapping. hf_to_vllm_mapper = WeightsMapper( orig_to_new_prefix={ "model.visual.": "visual.", "lm_head.": "language_model.lm_head.", "model.language_model.": "language_model.model.", } ) @classmethod def get_placeholder_str(cls, modality: str, i: int) -> str | None: if modality.startswith("image"): return "<|vision_start|><|image_pad|><|vision_end|>" if modality.startswith("video"): return "<|vision_start|><|video_pad|><|vision_end|>" raise ValueError("Only image or video modality is supported") def __init__(self, *, vllm_config: VllmConfig, prefix: str = "model"): super().__init__() config: Qwen3VLConfig = vllm_config.model_config.hf_config quant_config = vllm_config.quant_config multimodal_config = vllm_config.model_config.multimodal_config self.config = config self._tokenizer = cached_tokenizer_from_config(vllm_config.model_config) self.multimodal_config = multimodal_config self.use_data_parallel = multimodal_config.mm_encoder_tp_mode == "data" self.video_pruning_rate = multimodal_config.video_pruning_rate self.is_multimodal_pruning_enabled = ( 
multimodal_config.is_multimodal_pruning_enabled() ) self.use_deepstack = hasattr(config.vision_config, "deepstack_visual_indexes") self.deepstack_num_level = ( len(config.vision_config.deepstack_visual_indexes) if self.use_deepstack else 0 ) self.visual_dim = config.vision_config.out_hidden_size self.multiscale_dim = self.visual_dim * self.deepstack_num_level with self._mark_tower_model(vllm_config, {"image", "video"}): self.visual = Qwen3_VisionTransformer( config.vision_config, norm_eps=getattr(config, "rms_norm_eps", 1e-6), quant_config=quant_config, prefix=maybe_prefix(prefix, "visual"), ) # register buffer for deepstack if self.use_deepstack: self.deepstack_input_embeds = [ torch.zeros( vllm_config.scheduler_config.max_num_batched_tokens, config.text_config.hidden_size, ) for _ in range(self.deepstack_num_level) ] with self._mark_language_model(vllm_config): self.language_model = Qwen3LLMForCausalLM( vllm_config=vllm_config.with_hf_config(config.text_config), prefix=maybe_prefix(prefix, "language_model"), ) if not get_pp_group().is_first_rank and hasattr( config.vision_config, "deepstack_visual_indexes" ): assert self.language_model.start_layer >= len( config.vision_config.deepstack_visual_indexes ), ( "start_layer should be greater than or equal to " "len(deepstack_visual_indexes)" ) self.make_empty_intermediate_tensors = ( self.language_model.make_empty_intermediate_tensors ) def _get_deepstack_input_embeds( self, num_tokens: int, ) -> IntermediateTensors | None: if not getattr(self, "deepstack_input_embeds", None): return None # If vision tower is skipped # get deepstack_input_embeds from buffer, and clear the buffer return IntermediateTensors( { f"deepstack_input_embeds_{idx}": self.deepstack_input_embeds[idx][ :num_tokens ] for idx in range(self.deepstack_num_level) } ) def _set_deepstack_input_embeds(self, deepstack_input_embeds: torch.Tensor) -> None: if not getattr(self, "deepstack_input_embeds", None): return # set deepstack_input_embeds to buffer 
num_tokens = deepstack_input_embeds.size(1) if num_tokens > self.deepstack_input_embeds[0].size(0): self.deepstack_input_embeds = [ torch.zeros( num_tokens, self.config.text_config.hidden_size, device=self.deepstack_input_embeds[0].device, dtype=self.deepstack_input_embeds[0].dtype, ) for _ in range(self.deepstack_num_level) ] for idx in range(self.deepstack_num_level): self.deepstack_input_embeds[idx][:num_tokens].copy_( deepstack_input_embeds[idx] ) def _clear_deepstack_input_embeds(self, num_tokens: int) -> None: if not getattr(self, "deepstack_input_embeds", None): return # clear deepstack_input_embeds in buffer if num_tokens > 0: for idx in range(self.deepstack_num_level): self.deepstack_input_embeds[idx][:num_tokens].zero_() def _parse_and_validate_image_input( self, **kwargs: object ) -> Qwen2_5_VLImageInputs | None: pixel_values = kwargs.pop("pixel_values", None) image_embeds = kwargs.pop("image_embeds", None) image_grid_thw = kwargs.pop("image_grid_thw", None) if pixel_values is None and image_embeds is None: return None if pixel_values is not None: return Qwen2_5_VLImagePixelInputs( type="pixel_values", pixel_values=pixel_values, image_grid_thw=image_grid_thw, ) if image_embeds is not None: return Qwen2_5_VLImageEmbeddingInputs( type="image_embeds", image_embeds=image_embeds, image_grid_thw=image_grid_thw, ) def _parse_and_validate_video_input( self, **kwargs: object ) -> Qwen2_5_VLVideoInputs | None: pixel_values_videos = kwargs.pop("pixel_values_videos", None) video_embeds = kwargs.pop("video_embeds", None) video_grid_thw = kwargs.pop("video_grid_thw", None) second_per_grid_ts = kwargs.pop("second_per_grid_ts", None) timestamps = kwargs.pop("timestamps", None) if pixel_values_videos is None and video_embeds is None: return None if pixel_values_videos is not None: return Qwen2_5_VLVideoPixelInputs( type="pixel_values_videos", pixel_values_videos=pixel_values_videos, video_grid_thw=video_grid_thw, second_per_grid_ts=second_per_grid_ts, 
timestamps=timestamps, ) if video_embeds is not None: return Qwen2_5_VLVideoEmbeddingInputs( type="video_embeds", video_embeds=video_embeds, video_grid_thw=video_grid_thw, timestamps=timestamps, ) def _process_image_input( self, image_input: Qwen2_5_VLImageInputs ) -> tuple[torch.Tensor, ...]: grid_thw = image_input["image_grid_thw"] assert grid_thw.ndim == 2 if image_input["type"] == "image_embeds": image_embeds = image_input["image_embeds"].type(self.visual.dtype) else: pixel_values = image_input["pixel_values"].type(self.visual.dtype) if self.use_data_parallel: return run_dp_sharded_mrope_vision_model( self.visual, pixel_values, grid_thw.tolist(), rope_type="rope_3d" ) else: image_embeds = self.visual(pixel_values, grid_thw=grid_thw) # Split concatenated embeddings for each image item. merge_size = self.visual.spatial_merge_size sizes = (grid_thw.prod(-1) // merge_size // merge_size).tolist() return image_embeds.split(sizes) def _process_video_input( self, video_input: Qwen2_5_VLVideoInputs ) -> tuple[torch.Tensor, ...]: grid_thw = video_input["video_grid_thw"] assert grid_thw.ndim == 2 if video_input["type"] == "video_embeds": video_embeds = video_input["video_embeds"].type(self.visual.dtype) else: pixel_values_videos = video_input["pixel_values_videos"].type( self.visual.dtype ) if self.use_data_parallel: grid_thw_list = grid_thw.tolist() return run_dp_sharded_mrope_vision_model( self.visual, pixel_values_videos, grid_thw_list, rope_type="rope_3d" ) else: video_embeds = self.visual(pixel_values_videos, grid_thw=grid_thw) # Split concatenated embeddings for each video item. merge_size = self.visual.spatial_merge_size sizes = (grid_thw.prod(-1) // merge_size // merge_size).tolist() return video_embeds.split(sizes) def _postprocess_image_embeds_evs( self, image_embeds_split: tuple[torch.Tensor, ...], image_input: Qwen2_5_VLImageInputs, ) -> tuple[torch.Tensor, ...]: """ Append mrope positions for each for images. 
This is necessary to recover correct mrope positions after video pruning Args: image_embeds_split: Tuple of image embeddings for each image item. image_input: Image input data. Returns: Tuple of image embeddings for each image item. Resulting embeddings will have extra 5 channels for computed mrope positions, consistent with video embeddings. """ if self.is_multimodal_pruning_enabled: merge_size = self.visual.spatial_merge_size grid_thw = image_input["image_grid_thw"] grid_thw_list = grid_thw.tolist() image_embeds_out = [] for emb, size in zip(image_embeds_split, grid_thw_list): positions = compute_mrope_for_media(size, merge_size).to(emb.device) positions = torch.cat( [ positions, torch.zeros_like( positions[:, 0:1] ), # Dummy extra fifth channel ], dim=1, ) emb = torch.cat([emb, positions], dim=1) image_embeds_out.append(emb) image_embeds_split = tuple(image_embeds_out) return image_embeds_split def _postprocess_video_embeds_evs( self, video_embeds_split: tuple[torch.Tensor, ...], video_input: Qwen2_5_VLVideoInputs, ) -> tuple[torch.Tensor, ...]: """ Prunes video embeddings via Efficient Video Sampling (EVS) and then appends mrope positions for each retained embeddings Args: video_embeds_split: Tuple of video embeddings for each video item. video_input: Video input data. Returns: Tuple of video embeddings for each video item. Resulting embeddings will have extra 5 channels for computed mrope positions, and whether the index corresponds to a video embedding. """ grid_thw = video_input["video_grid_thw"] assert grid_thw.ndim == 2 grid_thw_list = grid_thw.tolist() merge_size = self.visual.spatial_merge_size # Apply EVS to each video. video_embeds_out = [] for video_idx, (emb, size) in enumerate(zip(video_embeds_split, grid_thw_list)): # Compute positions. timestamps = video_input.timestamps[video_idx] num_frames = len(timestamps) t, h, w = size if self.is_multimodal_pruning_enabled: # For each video, compute retention mask using EVS. # retention_mask: [11424]. 
retention_mask = compute_retention_mask( emb, size, spatial_merge_size=self.visual.spatial_merge_size, q=self.video_pruning_rate, ) # Apply retention mask. emb = emb[retention_mask] # Calculate the actual number of retained tokens per frame. num_frames, rows, cols = ( t, h // merge_size, w // merge_size, ) retention_mask_thw = retention_mask.reshape(num_frames, rows, cols) num_tokens_per_frame = ( retention_mask_thw.sum(dim=(1, 2)).long().tolist() ) else: feature_size = emb.shape[0] // num_frames num_tokens_per_frame = [feature_size] * num_frames retention_mask = None emb = self._create_final_video_embeddings( video_embeddings=emb, num_tokens_per_frame=num_tokens_per_frame, timestamps=timestamps, video_grid_thw=size, retention_mask=retention_mask, ) video_embeds_out.append(emb) return tuple(video_embeds_out) def _create_final_video_embeddings( self, video_embeddings: torch.Tensor, num_tokens_per_frame: list[int], timestamps: list[float], video_grid_thw: list[int], retention_mask: torch.Tensor, ) -> torch.Tensor: """Create final embeddings that combine video embeddings with text embeddings of indicator tokens. These final embeddings contain: - Actual video embeddings in positions corresponding to video content - Text embeddings for indicator tokens (, , and frame separation text) in their respective positions These embeddings will replace the placeholder embeddings to create input_embeds for the LLM. """ device = video_embeddings.device # Generate video replacement token IDs using get_video_repl # This tokenizes each frame separator independently, then uses pre-tokenized # special tokens to ensure consistent tokenization regardless of # num_tokens_per_frame values. 
video_repl = Qwen3VLMultiModalProcessor.get_video_repl( tokens_per_frame=num_tokens_per_frame, tokenizer=self._tokenizer, timestamps=timestamps, vision_start_token_id=self.config.vision_start_token_id, vision_end_token_id=self.config.vision_end_token_id, video_token_id=self.config.video_token_id, select_token_id=self.is_multimodal_pruning_enabled, ) repl_token_ids = torch.tensor(video_repl.full, device=device) embed_token_id = _cached_tensor(self.config.video_token_id, device=device) is_video_embed = torch.isin(repl_token_ids, embed_token_id) # Get text embeddings for indicator tokens (has only `visual_dim``). text_embeddings = self.get_language_model().embed_input_ids(repl_token_ids) if self.use_deepstack: ( deepstack_input_embeds, multimodal_embeddings, ) = self._compute_deepstack_embeds( inputs_embeds=text_embeddings, multimodal_embeddings=[video_embeddings], is_multimodal=is_video_embed, ) else: deepstack_input_embeds = None multimodal_embeddings = [video_embeddings] merged_embeddings = _merge_multimodal_embeddings( inputs_embeds=text_embeddings, multimodal_embeddings=multimodal_embeddings, is_multimodal=is_video_embed, ) to_concat = [merged_embeddings] if deepstack_input_embeds is not None: to_concat.append( deepstack_input_embeds.permute(1, 0, 2).reshape( deepstack_input_embeds.shape[1], -1 ) ) expanded_positions = None if self.is_multimodal_pruning_enabled: is_vision_start = repl_token_ids.eq(self.config.vision_start_token_id) expanded_positions = self._get_expanded_positions( device=merged_embeddings.device, seq_len=merged_embeddings.shape[0], video_grid_thw=video_grid_thw, num_tokens_per_frame=num_tokens_per_frame, timestamps=timestamps, is_video_embed=is_video_embed, is_vision_start=is_vision_start, retention_mask=retention_mask, ) to_concat.append(expanded_positions) final_video_embeddings = torch.cat(to_concat, dim=-1) return final_video_embeddings def _get_expanded_positions( self, device, seq_len, video_grid_thw, num_tokens_per_frame, timestamps, 
is_video_embed, is_vision_start, retention_mask, ): embed_token_id = _cached_tensor(self.config.video_token_id, device=device) # Expand positions to match the full sequence length # (includes both video tokens and indicator tokens) # Shape: [full_length, 5] where positions are filled for video tokens # and zeros for indicator tokens. # Channel 3 flags VISION_START tokens so that # recompute_mrope_positions can reliably count timestamp tokens # (even when early frames have all video tokens pruned). # Channel 4 flags video-embedding tokens. expanded_positions = torch.zeros( seq_len, 5, # [t_index, h_index, w_index, is_vision_start, is_video] device=device, dtype=torch.long, ) _, h, w = video_grid_thw merge_size = self.visual.spatial_merge_size num_frames = len(num_tokens_per_frame) unpruned_token_ids = Qwen3VLMultiModalProcessor.get_video_repl( tokens_per_frame=[(h // merge_size) * (w // merge_size)] * num_frames, tokenizer=self._tokenizer, timestamps=timestamps, vision_start_token_id=self.config.vision_start_token_id, vision_end_token_id=self.config.vision_end_token_id, video_token_id=self.config.video_token_id, ).full unpruned_token_ids_tensor = torch.tensor(unpruned_token_ids, device=device) mm_feature = MultiModalFeatureSpec( data=MultiModalKwargsItem( { "video_grid_thw": MultiModalFieldElem( data=torch.tensor(video_grid_thw), field=None, # HACK. 
), } ), modality="video", identifier="DUMMY", mm_position=PlaceholderRange(offset=0, length=len(unpruned_token_ids)), ) original_mrope = ( self.get_mrope_input_positions( input_tokens=unpruned_token_ids, mm_features=[mm_feature], )[0] .to(device) .permute(1, 0) ) full_is_video_embed = unpruned_token_ids_tensor == embed_token_id expanded_positions[is_video_embed, :3] = original_mrope[full_is_video_embed][ retention_mask ] expanded_positions[~is_video_embed, :3] = original_mrope[~full_is_video_embed] expanded_positions[..., 3] = is_vision_start expanded_positions[..., 4] = is_video_embed return expanded_positions def _parse_and_validate_multimodal_inputs(self, **kwargs: object) -> dict: mm_input_by_modality = {} for input_key in kwargs: if ( input_key in ("pixel_values", "image_embeds") and "image" not in mm_input_by_modality ): mm_input_by_modality["image"] = self._parse_and_validate_image_input( **kwargs ) if ( input_key in ("pixel_values_videos", "video_embeds") and "video" not in mm_input_by_modality ): mm_input_by_modality["video"] = self._parse_and_validate_video_input( **kwargs ) return mm_input_by_modality @staticmethod def _iter_mm_grid_hw( input_tokens: list[int], mm_features: list[MultiModalFeatureSpec], video_token_id: int, vision_start_token_id: int, vision_end_token_id: int, spatial_merge_size: int, ) -> Iterator[tuple[int, int, int, int]]: """Iterate over multimodal features and yield position info. Args: input_tokens: List of token IDs in the input sequence. mm_features: List of multimodal feature specifications containing image/video data and position information. video_token_id: Token ID used for video tokens. vision_start_token_id: Token ID marking the start of a vision sequence. vision_end_token_id: Token ID marking the end of a vision sequence. spatial_merge_size: Size of the spatial merge operation used to compute logical grid dimensions from the original feature grid. Yields: offset: Position of the first video/image token in the sequence. 
llm_grid_h: Logical grid height (may not match actual token count with EVS). llm_grid_w: Logical grid width (may not match actual token count with EVS). actual_num_tokens: Actual number of video/image tokens in the placeholder. """ for mm_feature in sorted(mm_features, key=lambda f: f.mm_position.offset): offset = mm_feature.mm_position.offset if mm_feature.modality == "image": t, h, w = mm_feature.data["image_grid_thw"].data.tolist() assert t == 1, f"Image must have 1 frame, got {t}" llm_grid_h = h // spatial_merge_size llm_grid_w = w // spatial_merge_size yield offset, llm_grid_h, llm_grid_w, llm_grid_h * llm_grid_w elif mm_feature.modality == "video": t, h, w = mm_feature.data["video_grid_thw"].data.tolist() llm_grid_h = h // spatial_merge_size llm_grid_w = w // spatial_merge_size for _ in range(t): # When EVS is enabled, some frames may have 0 video tokens in the # placeholder. We use `vision_start_token_id` to locate each frame # since it is always present for every frame. # We then look for the first `video_token_id` after # `vision_start_token_id` and before `vision_end_token_id`. offset = input_tokens.index(vision_start_token_id, offset) vision_end_offset = input_tokens.index(vision_end_token_id, offset) try: actual_num_tokens = 0 video_offset = input_tokens.index( video_token_id, offset, vision_end_offset ) # NOTE: looking at the # `Qwen3VLMultiModalProcessor.get_video_repl` code, we can # see that we can use the below formula to get the token # count, since everything in between `video_offset` and # `vision_end_offset` is populated as `video_token_id`. # This saves us from manually counting the number tokens # that match `video_token_id` in between. actual_num_tokens += vision_end_offset - video_offset except ValueError: # No `video_token_id` in this frame (EVS with 0 tokens for # this frame) -> use `offset + 1`` to move past # `vision_start_token_id`. 
video_offset = offset + 1 yield video_offset, llm_grid_h, llm_grid_w, actual_num_tokens # Move offset past this frame for next iteration. offset = vision_end_offset + 1 else: raise ValueError(f"Unsupported modality: {mm_feature.modality}") def get_mrope_input_positions( self, input_tokens: list[int], mm_features: list[MultiModalFeatureSpec], ) -> tuple[torch.Tensor, int]: return self._get_mrope_input_positions( input_tokens=input_tokens, mm_features=mm_features, config=self.config, ) @staticmethod def _get_mrope_input_positions( input_tokens: list[int], mm_features: list[MultiModalFeatureSpec], config: Qwen3VLConfig, ): llm_pos_ids_list = [] st = 0 for ( offset, llm_grid_h, llm_grid_w, actual_num_tokens, ) in Qwen3VLForConditionalGeneration._iter_mm_grid_hw( input_tokens, mm_features, video_token_id=config.video_token_id, vision_start_token_id=config.vision_start_token_id, vision_end_token_id=config.vision_end_token_id, spatial_merge_size=config.vision_config.spatial_merge_size, ): # Skip frames with 0 tokens (EVS placeholder with tokens lumped elsewhere) if actual_num_tokens == 0: continue text_len = offset - st st_idx = llm_pos_ids_list[-1].max() + 1 if len(llm_pos_ids_list) > 0 else 0 llm_pos_ids_list.append( np.broadcast_to(np.arange(text_len), (3, text_len)) + st_idx ) # Check if this is a "lumped placeholder" (all tokens from multiple frames # assigned to the 0-th frame - see # `Qwen3VLMultiModalProcessor.get_video_repl`. expected_tokens_per_frame = llm_grid_h * llm_grid_w if actual_num_tokens > expected_tokens_per_frame: # Lumped placeholder: create grid positions for all "logical" frames # represented. num_logical_frames = actual_num_tokens // expected_tokens_per_frame remainder = actual_num_tokens % expected_tokens_per_frame # Create positions for complete frames. 
for _ in range(num_logical_frames): grid_indices = np.indices((1, llm_grid_h, llm_grid_w)).reshape( 3, -1 ) llm_pos_ids_list.append(grid_indices + text_len + st_idx) st_idx = llm_pos_ids_list[-1].max() + 1 text_len = 0 # No text between frames within the lump # Handle remainder tokens if any (partial frame). # NOTE: this should never be the case. Should we have an assert? if remainder > 0: # Create a partial grid - take first 'remainder' positions full_grid = np.indices((1, llm_grid_h, llm_grid_w)).reshape(3, -1) grid_indices = full_grid[:, :remainder] llm_pos_ids_list.append(grid_indices + text_len + st_idx) else: # Normal case: frame has exactly the expected tokens (after actual EVS # pruning). grid_indices = np.indices((1, llm_grid_h, llm_grid_w)).reshape(3, -1) llm_pos_ids_list.append(grid_indices + text_len + st_idx) st = offset + actual_num_tokens if st < len(input_tokens): st_idx = llm_pos_ids_list[-1].max() + 1 if len(llm_pos_ids_list) > 0 else 0 text_len = len(input_tokens) - st llm_pos_ids_list.append( np.broadcast_to(np.arange(text_len), (3, text_len)) + st_idx ) llm_positions = np.concatenate(llm_pos_ids_list, axis=1).reshape(3, -1) mrope_position_delta = (llm_positions.max() + 1 - len(input_tokens)).item() return torch.from_numpy(llm_positions), mrope_position_delta def recompute_mrope_positions( self, input_ids: list[int], multimodal_embeddings: MultiModalEmbeddings, mrope_positions: torch.LongTensor, num_computed_tokens: int, ) -> tuple[MultiModalEmbeddings, torch.Tensor, int]: """ Update part of input mrope positions (starting with num_computed_tokens index). Original mrope_positions are computed for unpruned sequence and becomes incorrect once pruning occurs, so once we prune media tokens we should reflect this in the mrope_positions before we feed it to LLM. Args: input_ids: (N,) All input tokens of the prompt containing entire sequence. multimodal_embeddings: Tuple of multimodal embeddings that fits into the prefill chunk that is being processed. 
mrope_positions: Existing mrope positions (3, N) for entire sequence num_computed_tokens: A number of computed tokens so far. Returns: Tuple of (multimodal_embeddings, mrope_positions, mrope_position_delta). """ return self._recompute_mrope_positions( input_ids=input_ids, multimodal_embeddings=multimodal_embeddings, mrope_positions=mrope_positions, num_computed_tokens=num_computed_tokens, image_token_id=self.config.image_token_id, video_token_id=self.config.video_token_id, vision_start_token_id=self.config.vision_start_token_id, ) @staticmethod def _recompute_mrope_positions( input_ids: list[int], multimodal_embeddings: MultiModalEmbeddings, mrope_positions: torch.LongTensor, num_computed_tokens: int, vision_start_token_id: int, image_token_id: int, video_token_id: int, ) -> tuple[MultiModalEmbeddings, torch.Tensor, int]: # Device device = ( multimodal_embeddings[0].device if len(multimodal_embeddings) else mrope_positions.device ) # Tensors input_ids_t = torch.as_tensor(input_ids, device=device, dtype=torch.long) mm_embeddings_out = [] mm_embeddings_pos = [] # Strip position information from embeddings (last 5 channels) # For Qwen3 VL, handle potentially empty frames (from unpacking) for mm in multimodal_embeddings: if mm.shape[0] > 0: # Only process non-empty frames mm_embeddings_out.append(mm[:, :-5]) mm_embeddings_pos.append(mm[:, -5:].permute(1, 0).long()) else: # Empty frame - keep as is mm_embeddings_out.append(mm) # Create empty position tensor with correct shape mm_embeddings_pos.append( torch.empty(5, 0, device=device, dtype=torch.long) ) positions, mrope_positions_delta = recompute_mrope_positions( input_ids_t, mm_embeddings_pos, mrope_positions, num_computed_tokens, vision_start_token_id, image_token_id, video_token_id, ) return tuple(mm_embeddings_out), positions, mrope_positions_delta def embed_multimodal(self, **kwargs: object) -> MultiModalEmbeddings | None: mm_input_by_modality = self._parse_and_validate_multimodal_inputs(**kwargs) if not 
mm_input_by_modality: return None # The result multimodal_embeddings is tuple of tensors, with each # tensor corresponding to a multimodal data item (image or video). multimodal_embeddings: list[torch.Tensor] = [] # NOTE: It is important to iterate over the keys in this dictionary # to preserve the order of the modalities. for modality in mm_input_by_modality: multimodal_input = mm_input_by_modality[modality] if modality == "image": image_embeddings = self._process_image_input(multimodal_input) image_embeddings = self._postprocess_image_embeds_evs( image_embeddings, multimodal_input ) multimodal_embeddings.extend(image_embeddings) if modality == "video": video_embeddings = self._process_video_input(multimodal_input) if self.is_multimodal_pruning_enabled: video_embeddings = self._postprocess_video_embeds_evs( video_embeddings, multimodal_input ) multimodal_embeddings.extend(video_embeddings) embeddings_tuple = tuple(multimodal_embeddings) return embeddings_tuple def _compute_deepstack_embeds( self, inputs_embeds: torch.Tensor, multimodal_embeddings: MultiModalEmbeddings, is_multimodal: torch.Tensor, ) -> tuple[torch.Tensor, MultiModalEmbeddings]: visual_lens = [len(x) for x in multimodal_embeddings] multimodal_embeddings_cat = torch.cat(multimodal_embeddings, dim=0) ( multimodal_embeddings_main, multimodal_embeddings_multiscale, ) = torch.split( multimodal_embeddings_cat, [self.visual_dim, self.multiscale_dim], dim=-1, ) multimodal_embeddings = torch.split( multimodal_embeddings_main, visual_lens, dim=0 ) multimodal_embeddings_multiscale = torch.split( multimodal_embeddings_multiscale, visual_lens, dim=0 ) deepstack_input_embeds = inputs_embeds.new_zeros( inputs_embeds.size(0), self.deepstack_num_level * inputs_embeds.size(1) ) deepstack_input_embeds = _merge_multimodal_embeddings( inputs_embeds=deepstack_input_embeds, multimodal_embeddings=multimodal_embeddings_multiscale, is_multimodal=is_multimodal, ) deepstack_input_embeds = deepstack_input_embeds.view( 
inputs_embeds.shape[0], self.deepstack_num_level, self.visual_dim ) deepstack_input_embeds = deepstack_input_embeds.permute(1, 0, 2) return deepstack_input_embeds, multimodal_embeddings def embed_input_ids( self, input_ids: torch.Tensor, multimodal_embeddings: MultiModalEmbeddings | None = None, *, is_multimodal: torch.Tensor | None = None, ) -> torch.Tensor: inputs_embeds = self._embed_text_input_ids( input_ids, self.language_model.embed_input_ids, is_multimodal=is_multimodal, ) if multimodal_embeddings is None or len(multimodal_embeddings) == 0: return inputs_embeds is_multimodal = _require_is_multimodal(is_multimodal) if self.use_deepstack: ( deepstack_input_embeds, multimodal_embeddings, ) = self._compute_deepstack_embeds( inputs_embeds=inputs_embeds, multimodal_embeddings=multimodal_embeddings, is_multimodal=is_multimodal, ) else: deepstack_input_embeds = None inputs_embeds = _merge_multimodal_embeddings( inputs_embeds=inputs_embeds, multimodal_embeddings=multimodal_embeddings, is_multimodal=is_multimodal, ) if deepstack_input_embeds is not None: self._set_deepstack_input_embeds(deepstack_input_embeds) return inputs_embeds def forward( self, input_ids: torch.Tensor | None, positions: torch.Tensor, intermediate_tensors: IntermediateTensors | None = None, inputs_embeds: torch.Tensor | None = None, **kwargs: object, ) -> torch.Tensor | IntermediateTensors: """Run forward pass for Qwen3VL. Args: input_ids: Flattened (concatenated) input_ids corresponding to a batch. positions: Flattened (concatenated) position ids corresponding to a batch. **NOTE**: If mrope is enabled (default setting for Qwen3VL opensource models), the shape will be `(3, seq_len)`, otherwise it will be `(seq_len,). intermediate_tensors: Intermediate tensors from previous pipeline stages. inputs_embeds: Pre-computed input embeddings. **kwargs: Additional keyword arguments including: - pixel_values: Pixel values to be fed to a model. `None` if no images are passed. 
- image_grid_thw: Tensor `(n_images, 3)` of image 3D grid in LLM. `None` if no images are passed. - pixel_values_videos: Pixel values of videos to be fed to a model. `None` if no videos are passed. - video_grid_thw: Tensor `(n_videos, 3)` of video 3D grid in LLM. `None` if no videos are passed. """ if intermediate_tensors is not None: inputs_embeds = None if inputs_embeds is not None and get_pp_group().is_first_rank: deepstack_input_embeds = self._get_deepstack_input_embeds( inputs_embeds.size(0) ) else: deepstack_input_embeds = None hidden_states = self.language_model.model( input_ids=input_ids, positions=positions, intermediate_tensors=intermediate_tensors, inputs_embeds=inputs_embeds, # args for deepstack deepstack_input_embeds=deepstack_input_embeds, ) if inputs_embeds is not None and get_pp_group().is_first_rank: self._clear_deepstack_input_embeds(inputs_embeds.size(0)) return hidden_states def compute_logits( self, hidden_states: torch.Tensor, ) -> torch.Tensor | None: return self.language_model.compute_logits(hidden_states) def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]: loader = AutoWeightsLoader(self) return loader.load_weights(weights, mapper=self.hf_to_vllm_mapper) def get_mm_mapping(self) -> MultiModelKeys: """ Get the module prefix in multimodal models """ return MultiModelKeys.from_string_field( language_model="language_model", connector=["visual.merger", "visual.deepstack_merger_list"], tower_model="visual.", ) def get_num_mm_encoder_tokens( self, num_image_tokens: int, ) -> int: hf_config = self.config vision_config = hf_config.vision_config merge_size = vision_config.spatial_merge_size return num_image_tokens * merge_size**2 def get_num_mm_connector_tokens( self, num_vision_tokens: int, ) -> int: hf_config = self.config vision_config = hf_config.vision_config merge_size = vision_config.spatial_merge_size return num_vision_tokens // merge_size**2 @lru_cache def _cached_tensor(x, device) -> torch.Tensor: return 
torch.tensor(x, device=device)