qwen3_vl.py 79.1 KB
Newer Older
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project

# Copyright 2025 The vLLM team.
# Copyright 2025 The Qwen Team.
# Copyright 2025 The HuggingFace Inc. team.
# All rights reserved.
#
# This code is based on EleutherAI's GPT-NeoX library and the GPT-NeoX
# and OPT implementations in this library. It has been modified from its
# original forms to accommodate minor architectural differences compared
# to GPT-NeoX and OPT used by the Meta AI team that trained the model.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Inference-only Qwen3VL model compatible with HuggingFace weights."""
26

27
from collections.abc import Callable, Iterable, Iterator, Mapping, Sequence
28
from functools import lru_cache, partial
29
from itertools import islice
30
from typing import Any
31

32
import os
33
34
35
36
37
38
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from transformers import BatchFeature
from transformers.models.qwen2_vl import Qwen2VLImageProcessorFast
39
from transformers.models.qwen2_vl.image_processing_qwen2_vl import (
40
41
42
    smart_resize as image_smart_resize,
)
from transformers.models.qwen3_vl import Qwen3VLProcessor, Qwen3VLVideoProcessor
43
from transformers.models.qwen3_vl.configuration_qwen3_vl import (
44
45
46
    Qwen3VLConfig,
    Qwen3VLVisionConfig,
)
47
from transformers.models.qwen3_vl.video_processing_qwen3_vl import (
48
49
    smart_resize as video_smart_resize,
)
50
51
52
from transformers.video_utils import VideoMetadata

from vllm.compilation.decorators import support_torch_compile
53
from vllm.config import VllmConfig
54
from vllm.config.multimodal import BaseDummyOptions, VideoDummyOptions
55
56
57
from vllm.distributed import get_pp_group
from vllm.logger import init_logger
from vllm.model_executor.layers.activation import _ACTIVATION_REGISTRY
58
from vllm.model_executor.layers.conv import Conv3dLayer
59
60
61
62
from vllm.model_executor.layers.linear import (
    ColumnParallelLinear,
    RowParallelLinear,
)
63
64
from vllm.model_executor.layers.logits_processor import LogitsProcessor
from vllm.model_executor.layers.quantization import QuantizationConfig
65
from vllm.model_executor.layers.rotary_embedding import get_rope
66
67
68
69
from vllm.model_executor.layers.vocab_parallel_embedding import ParallelLMHead
from vllm.model_executor.model_loader.weight_utils import default_weight_loader
from vllm.model_executor.models.module_mapping import MultiModelKeys
from vllm.multimodal import MULTIMODAL_REGISTRY
70
71
72
73
74
75
from vllm.multimodal.evs import (
    compute_mrope_for_media,
    compute_retained_tokens_count,
    compute_retention_mask,
    recompute_mrope_positions,
)
76
77
from vllm.multimodal.inputs import (
    MultiModalDataDict,
78
    MultiModalFeatureSpec,
79
80
81
    MultiModalFieldConfig,
    MultiModalKwargsItem,
    MultiModalKwargsItems,
82
    PlaceholderRange,
83
84
85
86
    VideoItem,
)
from vllm.multimodal.parse import ImageSize, MultiModalDataItems, MultiModalDataParser
from vllm.multimodal.processing import (
87
    BaseDummyInputsBuilder,
88
89
90
91
92
    BaseMultiModalProcessor,
    PromptReplacement,
    PromptUpdate,
    PromptUpdateDetails,
)
93
from vllm.sequence import IntermediateTensors
94
from vllm.utils.collection_utils import is_list_of
95
from vllm.utils.math_utils import round_up
96
from vllm.v1.attention.backends.registry import AttentionBackendEnum
97

98
99
from .interfaces import (
    MultiModalEmbeddings,
100
    SupportsEagle3,
101
    SupportsLoRA,
102
    SupportsMRoPE,
103
    SupportsMultiModal,
104
    SupportsMultiModalPruning,
105
    SupportsPP,
106
    _require_is_multimodal,
107
108
109
110
111
112
113
114
115
116
)
from .qwen2_5_vl import (
    Qwen2_5_VisionAttention,
    Qwen2_5_VLImageEmbeddingInputs,
    Qwen2_5_VLImageInputs,
    Qwen2_5_VLImagePixelInputs,
    Qwen2_5_VLVideoEmbeddingInputs,
    Qwen2_5_VLVideoInputs,
    Qwen2_5_VLVideoPixelInputs,
)
117
118
119
120
121
from .qwen2_vl import (
    Qwen2VLMultiModalDataParser,
    Qwen2VLProcessingInfo,
    _create_qwen2vl_field_factory,
)
122
from .qwen3 import Qwen3ForCausalLM, Qwen3Model
123
124
125
126
127
128
129
from .utils import (
    AutoWeightsLoader,
    PPMissingLayer,
    WeightsMapper,
    _merge_multimodal_embeddings,
    maybe_prefix,
)
130
131
from .vision import (
    get_vit_attn_backend,
132
    is_vit_use_data_parallel,
133
134
    run_dp_sharded_mrope_vision_model,
)
135
136
137

logger = init_logger(__name__)

# Upper bound on the number of frames in a dummy video. It is passed as
# `max_frames_per_video` when searching for the frame count whose vision
# embeddings are the largest (used for memory profiling).
DUMMY_VIDEO_NUM_FRAMES = 2048


class Qwen3_VisionPatchEmbed(nn.Module):
    """Embed flattened image/video patches with a non-overlapping 3D conv.

    The conv kernel and stride both equal the patch extent
    ``(temporal_patch_size, patch_size, patch_size)``, so each input patch is
    projected to exactly one ``hidden_size`` vector.
    """

    def __init__(
        self,
        patch_size: int = 14,
        temporal_patch_size: int = 2,
        in_channels: int = 3,
        hidden_size: int = 1152,
    ) -> None:
        super().__init__()
        self.patch_size = patch_size
        self.temporal_patch_size = temporal_patch_size
        self.hidden_size = hidden_size

        # kernel == stride: the conv acts as a per-patch linear projection.
        kernel_size = (temporal_patch_size, patch_size, patch_size)
        self.proj = Conv3dLayer(
            in_channels,
            hidden_size,
            kernel_size=kernel_size,
            stride=kernel_size,
            bias=True,
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Project flattened patches to embeddings.

        Args:
            x: 2-D tensor ``(num_patches, C * T * P * P)``; the channel count
                is inferred from the row width.

        Returns:
            Tensor of shape ``(num_patches, hidden_size)``.
        """
        # Unpacking (rather than x.shape[0]) keeps the 2-D input check.
        num_patches, _ = x.shape
        x = x.view(
            num_patches,
            -1,
            self.temporal_patch_size,
            self.patch_size,
            self.patch_size,
        )
        # When the ROCm/MIOpen NDHWC hint is enabled, feed the conv a
        # channels-last-3d input to match that layout preference.
        if os.environ.get('PYTORCH_MIOPEN_SUGGEST_NDHWC') == '1':
            x = x.to(memory_format=torch.channels_last_3d)
        return self.proj(x).view(num_patches, self.hidden_size)


class Qwen3_VisionMLP(nn.Module):
    """Two-layer vision MLP computing ``linear_fc2(act_fn(linear_fc1(x)))``.

    ``linear_fc1`` is column-parallel and ``linear_fc2`` row-parallel. Tensor
    parallelism is disabled for both when the ViT runs data-parallel
    (``is_vit_use_data_parallel()``).
    """

    def __init__(
        self,
        in_features: int,
        hidden_features: int,
        bias: bool = False,
        act_fn: Callable[[torch.Tensor], torch.Tensor] = F.silu,
        quant_config: QuantizationConfig | None = None,
        prefix: str = "",
    ):
        super().__init__()
        use_data_parallel = is_vit_use_data_parallel()
        self.linear_fc1 = ColumnParallelLinear(
            in_features,
            hidden_features,
            bias=bias,
            quant_config=quant_config,
            return_bias=False,
            prefix=f"{prefix}.linear_fc1",
            disable_tp=use_data_parallel,
        )
        self.linear_fc2 = RowParallelLinear(
            hidden_features,
            in_features,
            bias=bias,
            quant_config=quant_config,
            return_bias=False,
            prefix=f"{prefix}.linear_fc2",
            disable_tp=use_data_parallel,
        )
        self.act_fn = act_fn

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # return_bias=False above, so each linear returns a plain tensor.
        return self.linear_fc2(self.act_fn(self.linear_fc1(x)))


class Qwen3_VisionBlock(nn.Module):
    """Pre-norm vision transformer block: attention then MLP, each residual.

    ``norm_layer`` defaults to ``LayerNorm(eps=1e-6)``.
    """

    def __init__(
        self,
        dim: int,
        num_heads: int,
        mlp_hidden_dim: int,
        act_fn: Callable[[torch.Tensor], torch.Tensor] = F.silu,
        norm_layer: Callable[[int], nn.Module] | None = None,
        quant_config: QuantizationConfig | None = None,
        prefix: str = "",
    ) -> None:
        super().__init__()
        if norm_layer is None:
            norm_layer = partial(nn.LayerNorm, eps=1e-6)
        self.norm1 = norm_layer(dim)
        self.norm2 = norm_layer(dim)
        self.attn = Qwen2_5_VisionAttention(
            embed_dim=dim,
            num_heads=num_heads,
            projection_size=dim,
            quant_config=quant_config,
            prefix=f"{prefix}.attn",
        )
        self.mlp = Qwen3_VisionMLP(
            dim,
            mlp_hidden_dim,
            act_fn=act_fn,
            bias=True,
            quant_config=quant_config,
            prefix=f"{prefix}.mlp",
        )

    def forward(
        self,
        x: torch.Tensor,
        cu_seqlens: torch.Tensor,
        rotary_pos_emb_cos: torch.Tensor,
        rotary_pos_emb_sin: torch.Tensor,
        max_seqlen: torch.Tensor,  # Only used by flash-attention backends
    ) -> torch.Tensor:
        """Apply one block.

        Args:
            x: Hidden states (presumably ``(seq, 1, dim)`` as produced by the
                vision transformer — TODO confirm).
            cu_seqlens: Cumulative sequence boundaries of the packed items.
            rotary_pos_emb_cos: Rotary cosine table, one row per token.
            rotary_pos_emb_sin: Rotary sine table, one row per token.
            max_seqlen: Longest packed sequence. It is only computed for the
                FLASH_ATTN / ROCM_AITER_FA backends and is a zero scalar
                otherwise.
        """
        x = x + self.attn(
            self.norm1(x),
            cu_seqlens=cu_seqlens,
            rotary_pos_emb_cos=rotary_pos_emb_cos,
            rotary_pos_emb_sin=rotary_pos_emb_sin,
            max_seqlen=max_seqlen,
        )
        x = x + self.mlp(self.norm2(x))
        return x


class Qwen3_VisionPatchMerger(nn.Module):
    """Merge each group of ``spatial_merge_size**2`` tokens into one and
    project it to ``d_model`` with a GELU MLP.

    Normalization placement:
    - ``use_postshuffle_norm=False``: the norm runs per token (over
      ``context_dim``) before the tokens are grouped.
    - ``use_postshuffle_norm=True``: the norm runs on the grouped rows (over
      ``context_dim * spatial_merge_size**2``).
    """

    def __init__(
        self,
        d_model: int,
        context_dim: int,
        norm_layer: Callable[[int], nn.Module] | None = None,
        spatial_merge_size: int = 2,
        use_postshuffle_norm: bool = False,
        quant_config: QuantizationConfig | None = None,
        prefix: str = "",
    ) -> None:
        super().__init__()
        use_data_parallel = is_vit_use_data_parallel()
        # Width of one merged row: all tokens of a merge window concatenated.
        self.hidden_size = context_dim * (spatial_merge_size**2)

        self.use_postshuffle_norm = use_postshuffle_norm
        if self.use_postshuffle_norm:
            # The norm sees the merged row, so size it accordingly.
            context_dim = self.hidden_size

        if norm_layer is None:
            norm_layer = partial(nn.LayerNorm, eps=1e-6)
        self.norm = norm_layer(context_dim)
        self.linear_fc1 = ColumnParallelLinear(
            self.hidden_size,
            self.hidden_size,
            bias=True,
            quant_config=quant_config,
            prefix=f"{prefix}.linear_fc1",
            disable_tp=use_data_parallel,
        )
        self.act_fn = nn.GELU()
        self.linear_fc2 = RowParallelLinear(
            self.hidden_size,
            d_model,
            bias=True,
            quant_config=quant_config,
            prefix=f"{prefix}.linear_fc2",
            disable_tp=use_data_parallel,
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Return merged tokens of width ``d_model``.

        Consecutive groups of ``spatial_merge_size**2`` tokens in ``x`` are
        concatenated into one row via ``view(-1, hidden_size)``.
        """
        if self.use_postshuffle_norm:
            x = self.norm(x.view(-1, self.hidden_size))
        else:
            x = self.norm(x).view(-1, self.hidden_size)

        # These linears keep the default return_bias, so they return
        # (output, bias) pairs; only the output is used.
        x_parallel, _ = self.linear_fc1(x)
        x_parallel = self.act_fn(x_parallel)
        out, _ = self.linear_fc2(x_parallel)
        return out


class Qwen3_VisionTransformer(nn.Module):
    """Qwen3-VL vision encoder.

    Processing steps:
    1. Patch-embed the pixel input.
    2. Add bilinearly interpolated learned position embeddings.
    3. Run ``vision_config.depth`` blocks with 2D rotary position embeddings.
    4. Return the merged final features concatenated (along dim 1) with
       "deepstack" features. These come from the layers listed in
       ``vision_config.deepstack_visual_indexes``, each passed through its
       own merger.
    """

    def __init__(
        self,
        vision_config: Qwen3VLVisionConfig,
        norm_eps: float = 1e-6,
        quant_config: QuantizationConfig | None = None,
        prefix: str = "",
    ) -> None:
        super().__init__()
        self.hidden_size = vision_config.hidden_size
        self.num_heads = vision_config.num_heads
        self.num_position_embeddings = vision_config.num_position_embeddings
        self.patch_size = vision_config.patch_size
        self.spatial_merge_size = vision_config.spatial_merge_size
        self.spatial_merge_unit = self.spatial_merge_size**2
        self.temporal_patch_size = vision_config.temporal_patch_size
        self.deepstack_visual_indexes = vision_config.deepstack_visual_indexes
        # The learned position table is treated as a square grid of this side.
        self.num_grid_per_side = int(self.num_position_embeddings**0.5)

        # NOTE: This is used for creating empty tensor for all_gather for
        # DP ViT. Here out_hidden_size is enlarged due to deepstack: the
        # output concatenates the main merger features with one feature map
        # per deepstack layer.
        self.out_hidden_size = vision_config.out_hidden_size * (
            1 + len(self.deepstack_visual_indexes)
        )

        self.patch_embed = Qwen3_VisionPatchEmbed(
            patch_size=self.patch_size,
            temporal_patch_size=self.temporal_patch_size,
            in_channels=vision_config.in_channels,
            hidden_size=self.hidden_size,
        )

        self.pos_embed = nn.Embedding(self.num_position_embeddings, self.hidden_size)

        norm_layer = partial(nn.LayerNorm, eps=norm_eps)
        head_dim = self.hidden_size // self.num_heads
        # Rotary embedding over part of each head (partial_rotary_factor=0.5).
        # rot_pos_emb() indexes its cos/sin tables with (row, col) positions.
        self.rotary_pos_emb = get_rope(
            head_size=head_dim,
            max_position=8192,
            is_neox_style=True,
            rope_parameters={"partial_rotary_factor": 0.5},
        )

        self.merger = Qwen3_VisionPatchMerger(
            d_model=vision_config.out_hidden_size,
            context_dim=self.hidden_size,
            norm_layer=norm_layer,
            spatial_merge_size=self.spatial_merge_size,
            quant_config=quant_config,
            prefix=f"{prefix}.merger",
        )

        # One merger per deepstack layer; these normalize after grouping.
        self.deepstack_merger_list = nn.ModuleList(
            [
                Qwen3_VisionPatchMerger(
                    d_model=vision_config.out_hidden_size,
                    context_dim=self.hidden_size,
                    spatial_merge_size=self.spatial_merge_size,
                    use_postshuffle_norm=True,
                    norm_layer=norm_layer,
                    quant_config=quant_config,
                    prefix=f"{prefix}.deepstack_merger_list.{layer_idx}",
                )
                for layer_idx in range(len(self.deepstack_visual_indexes))
            ]
        )

        self.attn_backend = get_vit_attn_backend(
            head_size=head_dim,
            dtype=torch.get_default_dtype(),
        )

        if self.attn_backend not in {
            AttentionBackendEnum.FLASH_ATTN,
            AttentionBackendEnum.TORCH_SDPA,
            AttentionBackendEnum.ROCM_AITER_FA,
        }:
            raise RuntimeError(
                f"Qwen3-VL does not support {self.attn_backend} backend now."
            )
        self.blocks = nn.ModuleList(
            [
                Qwen3_VisionBlock(
                    dim=self.hidden_size,
                    num_heads=self.num_heads,
                    mlp_hidden_dim=vision_config.intermediate_size,
                    act_fn=_ACTIVATION_REGISTRY[vision_config.hidden_act],
                    norm_layer=norm_layer,
                    quant_config=quant_config,
                    prefix=f"{prefix}.blocks.{layer_idx}",
                )
                for layer_idx in range(vision_config.depth)
            ]
        )

    @property
    def dtype(self) -> torch.dtype:
        return self.patch_embed.proj.weight.dtype

    @property
    def device(self) -> torch.device:
        return self.patch_embed.proj.weight.device

    @staticmethod
    @lru_cache(maxsize=1024)
    def rot_pos_ids(h: int, w: int, spatial_merge_size: int) -> torch.Tensor:
        """Return ``(h * w, 2)`` (row, col) position ids for an ``h x w`` grid.

        Ids are ordered so that each ``spatial_merge_size x
        spatial_merge_size`` window is contiguous, which matches the
        grouping done by the patch merger. Results are cached. Callers must
        not modify the returned tensor in place.
        """
        hpos_ids = np.broadcast_to(np.arange(h).reshape(h, 1), (h, w))
        h_div = h // spatial_merge_size
        w_div = w // spatial_merge_size
        hpos_ids = hpos_ids.reshape(
            h_div,
            spatial_merge_size,
            w_div,
            spatial_merge_size,
        )
        hpos_ids = hpos_ids.transpose(0, 2, 1, 3)
        hpos_ids = hpos_ids.flatten()

        wpos_ids = np.broadcast_to(np.arange(w).reshape(1, w), (h, w))
        wpos_ids = wpos_ids.reshape(
            h_div,
            spatial_merge_size,
            w_div,
            spatial_merge_size,
        )
        wpos_ids = wpos_ids.transpose(0, 2, 1, 3)
        wpos_ids = wpos_ids.flatten()

        return torch.from_numpy(np.stack([hpos_ids, wpos_ids], axis=-1))

    def rot_pos_emb(self, grid_thw: list[list[int]]):
        """Return per-token rotary ``(cos, sin)`` for all items in ``grid_thw``.

        The row and column embeddings of each token are concatenated along
        the last dim.
        """
        max_grid_size = max(max(h, w) for _, h, w in grid_thw)
        # Spatial ids are shared across frames, so repeat them t times.
        pos_ids = [
            self.rot_pos_ids(h, w, self.spatial_merge_size)
            if t == 1
            else self.rot_pos_ids(h, w, self.spatial_merge_size).repeat(t, 1)
            for t, h, w in grid_thw
        ]
        pos_ids = torch.cat(pos_ids, dim=0).to(self.device, non_blocking=True)

        # Use pre-computed cos_sin_cache from RotaryEmbedding
        cos, sin = self.rotary_pos_emb.get_cos_sin(max_grid_size)

        # cos[pos_ids] -> (num_tokens, 2, d); flatten joins row/col halves.
        cos_combined = cos[pos_ids].flatten(1)
        sin_combined = sin[pos_ids].flatten(1)

        return cos_combined, sin_combined

    def fast_pos_embed_interpolate(self, grid_thw: list[list[int]]) -> torch.Tensor:
        """Bilinearly resample the learned position table to each item's grid.

        For every ``(t, h, w)`` item, the ``num_grid_per_side``-square table
        is resampled to ``h x w``, reordered into merge-window order, and
        repeated for the ``t`` temporal steps.

        Returns:
            Tensor of shape ``(sum(t * h * w), hidden_size)``.
        """
        num_grid_per_side = self.num_grid_per_side
        m_size = self.spatial_merge_size
        hidden_dim = self.pos_embed.embedding_dim

        outputs = []
        for t, h, w in grid_thw:
            # Fractional source coordinates of each target row/column.
            h_idxs = torch.linspace(
                0, num_grid_per_side - 1, h, dtype=torch.float32, device=self.device
            )
            w_idxs = torch.linspace(
                0, num_grid_per_side - 1, w, dtype=torch.float32, device=self.device
            )

            h_floor = h_idxs.to(torch.long)
            w_floor = w_idxs.to(torch.long)
            h_ceil = torch.clamp(h_floor + 1, max=num_grid_per_side - 1)
            w_ceil = torch.clamp(w_floor + 1, max=num_grid_per_side - 1)

            dh = h_idxs - h_floor
            dw = w_idxs - w_floor

            # Create meshgrid view for all h, w vars
            dh_grid, dw_grid = torch.meshgrid(dh, dw, indexing="ij")
            h_floor_grid, w_floor_grid = torch.meshgrid(h_floor, w_floor, indexing="ij")
            h_ceil_grid, w_ceil_grid = torch.meshgrid(h_ceil, w_ceil, indexing="ij")

            # original computation of weights
            # w00 = (1 - dh_grid) * (1 - dw_grid)
            # w01 = (1 - dh_grid) * dw_grid
            # w10 = dh_grid * (1 - dw_grid)
            # w11 = dh_grid * dw_grid
            # we reuse w11 here to avoid duplicate
            # dh_grid * dw_grid computation
            w11 = dh_grid * dw_grid
            w10 = dh_grid - w11
            w01 = dw_grid - w11
            w00 = 1 - dh_grid - w01

            # Flat table indices of the four neighbours of every target cell.
            h_grid = torch.stack([h_floor_grid, h_floor_grid, h_ceil_grid, h_ceil_grid])
            w_grid = torch.stack([w_floor_grid, w_ceil_grid, w_floor_grid, w_ceil_grid])
            h_grid_idx = h_grid * num_grid_per_side

            indices = (h_grid_idx + w_grid).reshape(4, -1)
            weights = torch.stack([w00, w01, w10, w11], dim=0).reshape(4, -1, 1)
            weights = weights.to(dtype=self.dtype)

            embeds = self.pos_embed(indices)
            embeds *= weights
            combined = embeds.sum(dim=0)

            # Reorder from row-major to merge-window order.
            combined = combined.reshape(
                h // m_size, m_size, w // m_size, m_size, hidden_dim
            )
            combined = combined.permute(0, 2, 1, 3, 4).reshape(1, -1, hidden_dim)
            repeated = combined.expand(t, -1, -1).reshape(-1, hidden_dim)
            outputs.append(repeated)

        return torch.cat(outputs, dim=0)

    def compute_attn_mask_seqlen(
        self,
        cu_seqlens: torch.Tensor,
    ) -> torch.Tensor:
        """Return the longest packed sequence length for flash backends.

        For other backends this is a 0-d zero tensor.
        """
        max_seqlen = torch.zeros([], device=cu_seqlens.device)
        if (
            self.attn_backend == AttentionBackendEnum.FLASH_ATTN
            or self.attn_backend == AttentionBackendEnum.ROCM_AITER_FA
        ):
            max_seqlen = (cu_seqlens[1:] - cu_seqlens[:-1]).max()
        return max_seqlen

    def forward(
        self,
        x: torch.Tensor,
        grid_thw: torch.Tensor | list[list[int]],
    ) -> torch.Tensor:
        """Encode packed pixel patches.

        Args:
            x: Flattened pixel patches for all items, packed along dim 0.
            grid_thw: Per-item ``(t, h, w)`` patch-grid sizes, as a list or a
                tensor on any device.

        Returns:
            Tensor of width ``out_hidden_size``: the merger output
            concatenated with every deepstack merger output along dim 1.
        """
        hidden_states = x.to(device=self.device, dtype=self.dtype, non_blocking=True)
        hidden_states = self.patch_embed(hidden_states)

        if isinstance(grid_thw, list):
            grid_thw_list = grid_thw
            grid_thw_np = np.array(grid_thw, dtype=np.int32)
        else:
            grid_thw_list = grid_thw.tolist()
            # .numpy() raises for tensors on an accelerator. .cpu() is a
            # no-op for tensors already on the CPU.
            grid_thw_np = grid_thw.cpu().numpy()

        pos_embeds = self.fast_pos_embed_interpolate(grid_thw_list)
        hidden_states = hidden_states + pos_embeds
        rotary_pos_emb_cos, rotary_pos_emb_sin = self.rot_pos_emb(grid_thw_list)

        # Each frame (h * w patches) is a separate attention sequence.
        cu_seqlens = np.repeat(
            grid_thw_np[:, 1] * grid_thw_np[:, 2], grid_thw_np[:, 0]
        ).cumsum(axis=0, dtype=np.int32)
        cu_seqlens = np.concatenate([np.zeros(1, dtype=np.int32), cu_seqlens])
        cu_seqlens = torch.from_numpy(cu_seqlens)

        # Blocks expect a singleton dim at axis 1 (presumably seq-first layout).
        hidden_states = hidden_states.unsqueeze(1)
        # Compute max_seqlen on the CPU copy before moving cu_seqlens over.
        max_seqlen = self.compute_attn_mask_seqlen(cu_seqlens)
        cu_seqlens = cu_seqlens.to(self.device, non_blocking=True)

        deepstack_feature_lists = []
        for layer_num, blk in enumerate(self.blocks):
            hidden_states = blk(
                hidden_states,
                cu_seqlens=cu_seqlens,
                rotary_pos_emb_cos=rotary_pos_emb_cos,
                rotary_pos_emb_sin=rotary_pos_emb_sin,
                max_seqlen=max_seqlen,
            )
            if layer_num in self.deepstack_visual_indexes:
                deepstack_merger_idx = self.deepstack_visual_indexes.index(layer_num)
                deepstack_feature = self.deepstack_merger_list[deepstack_merger_idx](
                    hidden_states
                )
                deepstack_feature_lists.append(deepstack_feature)
        hidden_states = self.merger(hidden_states)
        hidden_states = torch.cat(
            [hidden_states] + deepstack_feature_lists, dim=1
        )  # [seq_len, hidden_size * (1 + depth_of_deepstack)]
        return hidden_states

    def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]:
        """Load checkpoint weights into this module.

        Separate ``attn.q``/``attn.k``/``attn.v`` checkpoint tensors are
        routed into the fused ``attn.qkv`` parameters via their shard-aware
        ``weight_loader``. Everything else is loaded by name.

        Returns:
            The set of parameter names that were loaded.
        """
        stacked_params_mapping = [
            # (param_name, shard_name, shard_id)
            ("attn.qkv.", "attn.q.", "q"),
            ("attn.qkv.", "attn.k.", "k"),
            ("attn.qkv.", "attn.v.", "v"),
        ]
        params_dict = dict(self.named_parameters(remove_duplicate=False))
        loaded_params: set[str] = set()

        for name, loaded_weight in weights:
            for param_name, weight_name, shard_id in stacked_params_mapping:
                if weight_name not in name:
                    continue
                name = name.replace(weight_name, param_name)

                param = params_dict[name]
                weight_loader = param.weight_loader
                weight_loader(param, loaded_weight, shard_id)
                break
            else:
                param = params_dict[name]
                weight_loader = getattr(param, "weight_loader", default_weight_loader)
                weight_loader(param, loaded_weight)
            loaded_params.add(name)
        return loaded_params


class Qwen3VLProcessingInfo(Qwen2VLProcessingInfo):
    """Qwen3-VL processing info.

    Provides access to the HF config and processors, and computes vision
    token counts and video frame timestamps.
    """

    def get_hf_config(self):
        return self.ctx.get_hf_config(Qwen3VLConfig)

    def get_hf_processor(self, **kwargs: object) -> Qwen3VLProcessor:
        # Default to the fast processor unless the caller passes use_fast.
        return self.ctx.get_hf_processor(
            Qwen3VLProcessor,
            use_fast=kwargs.pop("use_fast", True),
            **kwargs,
        )

    def get_image_processor(self, **kwargs: object) -> Qwen2VLImageProcessorFast:
        return self.get_hf_processor(**kwargs).image_processor

    def get_video_processor(self, **kwargs: object) -> Qwen3VLVideoProcessor:
        return self.get_hf_processor(**kwargs).video_processor

    def _get_vision_info(
        self,
        *,
        image_width: int,
        image_height: int,
        num_frames: int = 2,
        do_resize: bool = True,
        image_processor: Qwen2VLImageProcessorFast | Qwen3VLVideoProcessor | None,
    ) -> tuple[ImageSize, int]:
        """Compute the preprocessed size and vision-token count of an input.

        If ``image_processor`` is None, the video processor is used when
        ``num_frames > 1`` and the image processor otherwise. Video inputs
        use the video ``smart_resize``, which also accounts for the frame
        count.

        Returns:
            ``(preprocessed_size, num_vision_tokens)``, where the token count
            is ``grid_t * grid_h * grid_w // spatial_merge_size**2``.
        """
        if image_processor is None and num_frames > 1:
            image_processor = self.get_video_processor()
        elif image_processor is None:
            image_processor = self.get_image_processor()

        is_video = isinstance(image_processor, Qwen3VLVideoProcessor)

        hf_config = self.get_hf_config()
        vision_config = hf_config.vision_config
        patch_size = vision_config.patch_size
        merge_size = vision_config.spatial_merge_size
        temporal_patch_size = vision_config.temporal_patch_size

        if do_resize:
            if is_video:
                smart_resize = video_smart_resize
                extra_kwargs = {
                    "num_frames": num_frames,
                    "temporal_factor": temporal_patch_size,
                }
            else:
                smart_resize = image_smart_resize
                extra_kwargs = {}
            resized_height, resized_width = smart_resize(
                height=image_height,
                width=image_width,
                factor=patch_size * merge_size,
                min_pixels=image_processor.size["shortest_edge"],
                max_pixels=image_processor.size["longest_edge"],
                **extra_kwargs,
            )
            preprocessed_size = ImageSize(width=resized_width, height=resized_height)
        else:
            preprocessed_size = ImageSize(width=image_width, height=image_height)

        # Frames are padded up to a multiple of the temporal patch size.
        padded_num_frames = round_up(num_frames, temporal_patch_size)

        grid_t = max(padded_num_frames // temporal_patch_size, 1)
        grid_h = preprocessed_size.height // patch_size
        grid_w = preprocessed_size.width // patch_size

        num_patches = grid_t * grid_h * grid_w
        num_vision_tokens = num_patches // (merge_size**2)

        return preprocessed_size, num_vision_tokens

    def _get_max_video_frames(self, max_tokens: int, start_num_frames: int = 2) -> int:
        return super()._get_max_video_frames(
            max_tokens, start_num_frames=start_num_frames
        )

    def get_num_frames_with_most_features(
        self,
        seq_len: int,
        mm_counts: Mapping[str, int],
    ) -> int:
        return super().get_num_frames_with_most_features(
            seq_len, mm_counts, max_frames_per_video=DUMMY_VIDEO_NUM_FRAMES
        )

    def get_max_video_tokens(
        self,
        seq_len: int,
        mm_counts: Mapping[str, int],
    ) -> int:
        """Return the vision-token count of a maximally sized 2-frame video."""
        video_processor = self.get_video_processor()
        video_max_pixels = video_processor.size["longest_edge"]
        # video_max_pixels contains the temporal compression factor, so we
        # divide by temporal_patch_size to get the maximum per-image pixels.
        target_width, target_height = self.get_image_size_with_most_features(
            max_pixels=video_max_pixels // video_processor.temporal_patch_size
        )
        num_video_soft_tokens = self.get_num_video_tokens(
            image_width=target_width,
            image_height=target_height,
            num_frames=2,
            image_processor=None,
        )
        return num_video_soft_tokens

    def _calculate_timestamps(
        self, indices: list[int] | torch.Tensor, video_fps: float, merge_size: int
    ) -> list[float]:
        """Return one timestamp (seconds) per group of ``merge_size`` frames.

        Each timestamp is the mean of the first and last frame time in its
        group. A trailing partial group is padded by repeating the last index.
        """
        if not isinstance(indices, list):
            indices = indices.tolist()
        if len(indices) % merge_size != 0:
            # don't update metadata's frames_indices directly
            indices = indices + [indices[-1]] * (merge_size - len(indices) % merge_size)
        timestamps = [idx / video_fps for idx in indices]
        timestamps = [
            (timestamps[i] + timestamps[i + merge_size - 1]) / 2
            for i in range(0, len(timestamps), merge_size)
        ]
        return timestamps

    def _get_video_second_idx(
        self,
        metadata: dict[str, Any],
        out_item: MultiModalKwargsItem,
        do_sample_frames: bool | None = None,
        sampled_fps: float | None = None,
    ) -> list[float]:
        """Return per-token-group timestamps (in seconds) for a video.

        When frames are sampled by the HF processor (``do_sample_frames``),
        the sampled frame indices are recomputed from the original metadata
        first.
        """
        video_processor = self.get_video_processor()
        merge_size = video_processor.merge_size
        indices = metadata["frames_indices"]

        # metadata["fps"] refers to the true fps of the input video.
        video_fps = metadata["fps"]
        if do_sample_frames is None:
            do_sample_frames = metadata.get("do_sample_frames", False)

        # If video frames are sampled in HF processor (instead of vLLM
        # video loader), we need to re-calculate the indices from original
        # metadata.
        if do_sample_frames:
            # sampled_fps is the target sampling rate, while metadata["fps"]
            # is the fps of the original video.
            sampled_fps = sampled_fps if sampled_fps else video_processor.fps
            total_num_frames = metadata["total_num_frames"]
            num_frames = int(total_num_frames / metadata["fps"] * sampled_fps)
            # Clamp to the processor's [min_frames, max_frames] and never
            # exceed the frames actually available.
            num_frames = min(
                min(
                    max(num_frames, video_processor.min_frames),
                    video_processor.max_frames,
                ),
                total_num_frames,
            )
            indices = (
                np.linspace(0, total_num_frames - 1, num_frames)
                .round()
                .astype(int)
                .tolist()
            )
        timestamps = self._calculate_timestamps(indices, video_fps, merge_size)
        return timestamps


class Qwen3VLDummyInputsBuilder(BaseDummyInputsBuilder[Qwen3VLProcessingInfo]):
    """Builds dummy prompts and multimodal data for Qwen3-VL."""

    def get_dummy_text(self, mm_counts: Mapping[str, int]) -> str:
        """Return one placeholder per requested image, then one per video."""
        num_images = mm_counts.get("image", 0)
        num_videos = mm_counts.get("video", 0)

        image_token = "<|vision_start|><|image_pad|><|vision_end|>"
        video_token = "<|vision_start|><|video_pad|><|vision_end|>"

        return image_token * num_images + video_token * num_videos

    def get_dummy_mm_data(
        self,
        seq_len: int,
        mm_counts: Mapping[str, int],
        mm_options: Mapping[str, BaseDummyOptions] | None = None,
    ) -> MultiModalDataDict:
        """Return dummy images and videos at the model's maximum sizes.

        Images use ``get_image_size_with_most_features()``. Videos use 2
        frames at the largest frame size allowed by the video processor.
        User overrides in ``mm_options`` can only shrink width/height; a
        frame-count override outside the supported value is ignored with a
        warning.
        """
        num_images = mm_counts.get("image", 0)
        num_videos = mm_counts.get("video", 0)

        image_overrides = mm_options.get("image") if mm_options else None
        video_overrides = mm_options.get("video") if mm_options else None

        target_image_width, target_image_height = (
            self.info.get_image_size_with_most_features()
        )

        # treat videos as special images
        target_num_frames = 2
        if video_overrides:
            assert isinstance(video_overrides, VideoDummyOptions)
            num_frames_override = video_overrides.num_frames
            if num_frames_override:
                if num_frames_override > target_num_frames:
                    logger.warning(
                        "video.num_frames override (%d) exceeds model's "
                        "maximum number of frames (%d), will be ignored",
                        num_frames_override,
                        target_num_frames,
                    )
                if num_frames_override < 2:
                    logger.warning(
                        "video.num_frames override (%d) cannot be less "
                        "than 2, will be ignored",
                        num_frames_override,
                    )
                target_num_frames = min(target_num_frames, num_frames_override)
        # Never go below 2 frames (see the warning above).
        target_num_frames = max(target_num_frames, 2)

        video_processor = self.info.get_video_processor()
        video_max_pixels = video_processor.size["longest_edge"]
        # video_max_pixels contains the temporal compression factor, so we
        # divide by temporal_patch_size to get the maximum number of image
        # pixels per frame.
        target_video_width, target_video_height = (
            self.info.get_image_size_with_most_features(
                max_pixels=video_max_pixels // video_processor.temporal_patch_size
            )
        )
        target_video_size, _ = self.info._get_vision_info(
            image_width=target_video_width,
            image_height=target_video_height,
            num_frames=target_num_frames,
            image_processor=video_processor,
        )
        # NOTE: we need to do this check here since Qwen3-VL resizes video
        # frames depending on how many frames there are.
        target_video_width, target_video_height = (
            target_video_size.width,
            target_video_size.height,
        )
        if video_overrides:
            assert isinstance(video_overrides, VideoDummyOptions)
            width_override = video_overrides.width
            if width_override:
                if width_override > target_video_width:
                    logger.warning(
                        "video.width override (%d) exceeds model's "
                        "maximum width (%d), will be ignored",
                        width_override,
                        target_video_width,
                    )
                target_video_width = min(target_video_width, width_override)
            height_override = video_overrides.height
            if height_override:
                if height_override > target_video_height:
                    logger.warning(
                        "video.height override (%d) exceeds model's "
                        "maximum height (%d), will be ignored",
                        height_override,
                        target_video_height,
                    )
                target_video_height = min(target_video_height, height_override)

        return {
            "image": self._get_dummy_images(
                width=target_image_width,
                height=target_image_height,
                num_images=num_images,
                overrides=image_overrides,
            ),
            "video": self._get_dummy_videos(
                width=target_video_width,
                height=target_video_height,
                num_frames=target_num_frames,
                num_videos=num_videos,
            ),
        }

    def _get_dummy_videos(
        self,
        *,
        width: int,
        height: int,
        num_frames: int,
        num_videos: int,
    ) -> list[VideoItem]:
        """Return ``num_videos`` white videos paired with their metadata.

        Frames are laid out channel-last as ``(num_frames, height, width, 3)``.
        The metadata declares a 2 fps video whose frames were all kept
        (``do_sample_frames=False``).
        """
        video = np.full((num_frames, height, width, 3), 255, dtype=np.uint8)
        video_items = []
        for _ in range(num_videos):
            video_metadata = {
                "fps": 2.0,
                "duration": num_frames / 2.0,
                "total_num_frames": num_frames,
                "frames_indices": list(range(num_frames)),
                "video_backend": "opencv",
                "do_sample_frames": False,
            }
            video_items.append((video.copy(), video_metadata))
        return video_items


class Qwen3VLMultiModalProcessor(BaseMultiModalProcessor[Qwen3VLProcessingInfo]):
    """Multimodal processor for Qwen3-VL.

    Videos are sent through the HF processor one at a time, so that each
    video's expanded placeholder can be spliced into the prompt. The
    remaining data (e.g. images) is then processed with the updated prompt.
    """

    def _get_data_parser(self) -> MultiModalDataParser:
        return Qwen2VLMultiModalDataParser(
            self.info.get_hf_config().vision_config.spatial_merge_size,
            video_needs_metadata=True,
        )

    def _call_hf_processor(
        self,
        prompt: str,
        mm_data: Mapping[str, object],
        mm_kwargs: Mapping[str, object],
        tok_kwargs: Mapping[str, object],
    ) -> BatchFeature:
        """Run the HF processor, handling every video individually.

        For each video, a single-video prompt is processed and the decoded
        placeholder replaces the first remaining video placeholder in
        ``prompt``. The video tensors are concatenated and merged with the
        outputs of processing the rest of ``mm_data`` with the new prompt.
        """
        mm_data = dict(mm_data)
        processor = self.info.get_hf_processor(**mm_kwargs)

        # Separate video processing from image processing, because the videos
        # are processed into several image patches.
        if videos := mm_data.pop("videos", []):
            video_grid_thw_lst = []
            pixel_values_videos_lst = []

            for item in videos:
                video_array, metadata = item

                # NOTE: @JJJYmmm new attr metadata.frames_indices indicates
                # the sampled frames indices of pre-sampled videos, which is
                # used to calculate the timestamps. Make sure that
                # do_sample_frames in mm_kwargs is false for presampled videos.

                # NOTE: a copy of mm_kwargs is created to update
                # do_sample_frames, otherwise mm_hash for the object will be
                # incorrect.
                video_mm_kwargs = dict(mm_kwargs)
                if "do_sample_frames" not in video_mm_kwargs:
                    # qwen_vl_utils already has "do_sample_frames" in
                    # mm_kwargs, don't overwrite it.
                    video_mm_kwargs["do_sample_frames"] = metadata.get(
                        "do_sample_frames", False
                    )

                metadata = VideoMetadata(
                    **{k: metadata[k] for k in metadata if k != "do_sample_frames"}
                )

                video_mm_data = {
                    "videos": [[video_array]],
                    "video_metadata": [[metadata]],
                }

                single_video_outputs = super()._call_hf_processor(
                    prompt="<|vision_start|><|video_pad|><|vision_end|>",
                    mm_data=video_mm_data,
                    mm_kwargs=video_mm_kwargs,
                    tok_kwargs=tok_kwargs,
                )
                input_ids = single_video_outputs.pop("input_ids")
                video_placeholder = processor.tokenizer.batch_decode(input_ids)[0]
                prompt = prompt.replace(
                    "<|vision_start|><|video_pad|><|vision_end|>",
                    video_placeholder,
                    1,
                )

                video_grid_thw_lst.append(single_video_outputs["video_grid_thw"])
                pixel_values_videos_lst.append(
                    single_video_outputs["pixel_values_videos"]
                )
            video_outputs = dict(
                pixel_values_videos=torch.cat(pixel_values_videos_lst),
                video_grid_thw=torch.cat(video_grid_thw_lst),
            )
        else:
            video_outputs = {}

        processed_outputs = super()._call_hf_processor(
            prompt=prompt,
            mm_data=mm_data,
            mm_kwargs=mm_kwargs,
            tok_kwargs=tok_kwargs,
        )
        combined_outputs = dict(
            processed_outputs,
            **video_outputs,
        )
        return BatchFeature(combined_outputs)

    def _get_mm_fields_config(
        self,
        hf_inputs: BatchFeature,
        hf_processor_mm_kwargs: Mapping[str, object],
    ) -> Mapping[str, MultiModalFieldConfig]:
        """Describe how the HF outputs map to multimodal fields."""
        return _create_qwen2vl_field_factory(
            self.info.get_hf_config().vision_config.spatial_merge_size
        )(hf_inputs)

    def _get_prompt_updates(
        self,
        mm_items: MultiModalDataItems,
        hf_processor_mm_kwargs: Mapping[str, Any],
        out_mm_kwargs: MultiModalKwargsItems,
    ) -> Sequence[PromptUpdate]:
        """Build the prompt replacements for image and video placeholders.

        An image placeholder expands to ``prod(grid_thw) / merge_size**2``
        image tokens. A video placeholder expands, per temporal patch, to a
        ``<{t:.1f} seconds>`` timestamp followed by
        ``vision_start + video tokens + vision_end``. When video pruning is
        enabled, the first frame keeps all of its tokens and the retained
        budget is spread over the other frames.
        """
        hf_processor = self.info.get_hf_processor(**hf_processor_mm_kwargs)
        image_processor = self.info.get_image_processor(**hf_processor_mm_kwargs)
        tokenizer = self.info.get_tokenizer()
        hf_config = self.info.get_hf_config()

        video_token_id = hf_config.video_token_id
        vision_start_token_id = hf_config.vision_start_token_id
        vision_end_token_id = hf_config.vision_end_token_id

        merge_length = image_processor.merge_size**2

        def get_image_replacement_qwen3vl(item_idx: int):
            out_item = out_mm_kwargs["image"][item_idx]
            grid_thw = out_item["image_grid_thw"].data
            assert isinstance(grid_thw, torch.Tensor)

            num_tokens = int(grid_thw.prod()) // merge_length
            return [hf_processor.image_token_id] * num_tokens

        def get_video_replacement_qwen3vl(item_idx: int):
            out_item = out_mm_kwargs["video"][item_idx]
            grid_thw = out_item["video_grid_thw"].data
            assert isinstance(grid_thw, torch.Tensor)

            _, metadata = mm_items["video"][item_idx]
            do_sample_frames = hf_processor_mm_kwargs.get("do_sample_frames")
            sampled_fps = hf_processor_mm_kwargs.get("fps")
            if is_list_of(sampled_fps, float):
                sampled_fps = sampled_fps[item_idx]
            timestamps = self.info._get_video_second_idx(
                metadata, out_item, do_sample_frames, sampled_fps
            )

            assert len(timestamps) == grid_thw[0], (
                f"The timestamps length({len(timestamps)}) should be equal "
                f"video length ({grid_thw[0]})."
            )

            frames_idx_token = [
                tokenizer.encode(f"<{curr_time:.1f} seconds>", add_special_tokens=False)
                for curr_time in timestamps
            ]
            num_frames = len(frames_idx_token)
            tokens_per_frame = int(grid_thw[1:].prod()) // merge_length
            per_frame_token_counts = [tokens_per_frame] * num_frames

            video_pruning_rate = self.info.ctx.get_mm_config().video_pruning_rate
            if video_pruning_rate is not None and video_pruning_rate > 0.0:
                total_retained = compute_retained_tokens_count(
                    tokens_per_frame,
                    num_frames,
                    video_pruning_rate,
                )
                if num_frames == 0:
                    per_frame_token_counts = []
                elif num_frames == 1:
                    per_frame_token_counts = [tokens_per_frame]
                else:
                    # The first frame keeps all of its tokens. The rest of the
                    # retained budget is split evenly over the other frames,
                    # and the first `remainder` of them get one extra token.
                    first_frame_tokens = tokens_per_frame
                    remaining_tokens = max(total_retained - first_frame_tokens, 0)
                    base, remainder = divmod(remaining_tokens, num_frames - 1)
                    per_frame_token_counts = [first_frame_tokens] + [
                        base + (1 if i < remainder else 0)
                        for i in range(num_frames - 1)
                    ]

            placeholder = []
            for timestamp_tokens, tokens_this_frame in zip(
                frames_idx_token, per_frame_token_counts
            ):
                placeholder.extend(timestamp_tokens)
                placeholder.append(vision_start_token_id)
                placeholder.extend([video_token_id] * tokens_this_frame)
                placeholder.append(vision_end_token_id)
            return PromptUpdateDetails.select_token_id(placeholder, video_token_id)

        return [
            PromptReplacement(
                modality="image",
                target=hf_processor.image_token,
                replacement=get_image_replacement_qwen3vl,
            ),
            # NOTE: We match string on purpose since searching sequence of
            # token ids takes more time.
            PromptReplacement(
                modality="video",
                target="<|vision_start|><|video_pad|><|vision_end|>",
                replacement=get_video_replacement_qwen3vl,
            ),
        ]


@support_torch_compile(
    dynamic_arg_dims={
        "input_ids": 0,
        # positions is of shape (3, seq_len) if mrope is enabled for qwen2-vl,
        # otherwise (seq_len, ).
        "positions": -1,
        "intermediate_tensors": 0,
        "inputs_embeds": 0,
        # the same shape as input_embeds
        "deepstack_input_embeds": 0,
    }
)
class Qwen3LLMModel(Qwen3Model):
    """Qwen3 text decoder that also adds DeepStack visual embeddings to the
    outputs of its first ``len(deepstack_input_embeds)`` layers."""

    def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
        super().__init__(vllm_config=vllm_config, prefix=prefix)
        if not get_pp_group().is_first_rank:
            # DeepStack embeddings are only added in layers
            # [0, len(deepstack_visual_indexes)), so none of those layers may
            # be placed on a later pipeline-parallel rank.
            assert self.start_layer >= len(
                vllm_config.model_config.hf_config.vision_config.deepstack_visual_indexes
            ), (
                "start_layer should be greater than or equal to "
                "len(deepstack_visual_indexes)"
            )

    def forward(
        self,
        input_ids: torch.Tensor | None,
        positions: torch.Tensor,
        intermediate_tensors: IntermediateTensors | None = None,
        inputs_embeds: torch.Tensor | None = None,
        # args for deepstack
        deepstack_input_embeds: IntermediateTensors | None = None,
    ) -> (
        torch.Tensor
        | IntermediateTensors
        | tuple[torch.Tensor, list[torch.Tensor]]
    ):
        """Run the decoder layers owned by this pipeline-parallel rank.

        Args:
            input_ids: Token ids; embedded when ``inputs_embeds`` is None
                (first rank only).
            positions: Position ids passed to every layer.
            intermediate_tensors: ``hidden_states``/``residual`` received
                from the previous rank (non-first ranks only).
            inputs_embeds: Precomputed input embeddings (first rank only).
            deepstack_input_embeds: Tensors keyed
                ``deepstack_input_embeds_{i}``, added to the hidden states
                after layer ``i`` for each ``i < len(deepstack_input_embeds)``.

        Returns:
            On a non-last rank, an ``IntermediateTensors`` with
            ``hidden_states`` and ``residual``. On the last rank, the
            normalized hidden states, or ``(hidden_states, aux_hidden_states)``
            when any layer in ``self.aux_hidden_state_layers`` was run.
        """
        if get_pp_group().is_first_rank:
            if inputs_embeds is not None:
                hidden_states = inputs_embeds
            else:
                hidden_states = self.embed_input_ids(input_ids)
            residual = None
        else:
            assert intermediate_tensors is not None
            hidden_states = intermediate_tensors["hidden_states"]
            residual = intermediate_tensors["residual"]

        aux_hidden_states = []
        for layer_idx, layer in islice(
            enumerate(self.layers), self.start_layer, self.end_layer
        ):
            if layer_idx in self.aux_hidden_state_layers:
                # residual is None before the first layer on the first rank;
                # in that case hidden_states alone is the layer input.
                aux_hidden_states.append(
                    hidden_states if residual is None else hidden_states + residual
                )

            hidden_states, residual = layer(
                positions,
                hidden_states,
                residual,
            )

            if (
                deepstack_input_embeds is not None
                and layer_idx < len(deepstack_input_embeds)
            ):
                hidden_states = (
                    hidden_states
                    + deepstack_input_embeds[f"deepstack_input_embeds_{layer_idx}"]
                )

        if not get_pp_group().is_last_rank:
            return IntermediateTensors(
                {"hidden_states": hidden_states, "residual": residual}
            )
        hidden_states, _ = self.norm(hidden_states, residual)

        if len(aux_hidden_states) > 0:
            return hidden_states, aux_hidden_states
        return hidden_states


class Qwen3LLMForCausalLM(Qwen3ForCausalLM):
    """Causal-LM wrapper around :class:`Qwen3LLMModel`, configured from the
    ``text_config`` of the multimodal HF config."""

    def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
        # Bypass Qwen3ForCausalLM.__init__ (super() resolves to the class after
        # it in the MRO) and build every submodule here from text_config.
        super(Qwen3ForCausalLM, self).__init__()
        config = vllm_config.model_config.hf_config.text_config
        quant_config = vllm_config.quant_config

        self.config = config
        self.quant_config = quant_config

        self.model = Qwen3LLMModel(
            vllm_config=vllm_config, prefix=maybe_prefix(prefix, "model")
        )

        if get_pp_group().is_last_rank:
            if config.tie_word_embeddings:
                # Share weights with the input embedding table.
                self.lm_head = self.model.embed_tokens
            else:
                self.lm_head = ParallelLMHead(
                    config.vocab_size,
                    config.hidden_size,
                    quant_config=quant_config,
                    prefix="lm_head",
                )
        else:
            # Only the last pipeline-parallel rank computes logits.
            self.lm_head = PPMissingLayer()

        self.logits_processor = LogitsProcessor(config.vocab_size)

        self.make_empty_intermediate_tensors = (
            self.model.make_empty_intermediate_tensors
        )


@MULTIMODAL_REGISTRY.register_processor(
    Qwen3VLMultiModalProcessor,
    info=Qwen3VLProcessingInfo,
    dummy_inputs=Qwen3VLDummyInputsBuilder,
)
class Qwen3VLForConditionalGeneration(
    nn.Module,
    SupportsMultiModal,
    SupportsLoRA,
    SupportsPP,
    SupportsMRoPE,
    SupportsEagle3,
    SupportsMultiModalPruning,
):
    """Qwen3-VL: a vision transformer feeding a Qwen3 language model."""

    # Maps each fused vLLM module to the checkpoint modules packed into it.
    packed_modules_mapping = {
        "qkv_proj": [
            "q_proj",
            "k_proj",
            "v_proj",
        ],
        "gate_up_proj": [
            "gate_proj",
            "up_proj",
        ],
        "qkv": ["qkv"],  # For vision tower's already-packed QKV
    }

    supports_encoder_tp_data = True

    # To ensure correct weight loading and mapping.
    hf_to_vllm_mapper = WeightsMapper(
        orig_to_new_prefix={
            "model.visual.": "visual.",
            "lm_head.": "language_model.lm_head.",
            "model.language_model.": "language_model.model.",
        }
    )

    @classmethod
    def get_placeholder_str(cls, modality: str, i: int) -> str | None:
        """Return the prompt placeholder for one item of ``modality``.

        ``i`` (the item index) does not affect the result.

        Raises:
            ValueError: If ``modality`` starts with neither ``"image"`` nor
                ``"video"``.
        """
        if modality.startswith("image"):
            return "<|vision_start|><|image_pad|><|vision_end|>"
        if modality.startswith("video"):
            return "<|vision_start|><|video_pad|><|vision_end|>"

        raise ValueError("Only image or video modality is supported")

    def __init__(self, *, vllm_config: VllmConfig, prefix: str = "model"):
        """Build the vision tower, DeepStack scratch buffers and the LM."""
        super().__init__()
        config: Qwen3VLConfig = vllm_config.model_config.hf_config
        quant_config = vllm_config.quant_config
        multimodal_config = vllm_config.model_config.multimodal_config

        self.config = config
        self.multimodal_config = multimodal_config
        self.use_data_parallel = multimodal_config.mm_encoder_tp_mode == "data"
        self.video_pruning_rate = multimodal_config.video_pruning_rate
        self.is_multimodal_pruning_enabled = (
            multimodal_config.is_multimodal_pruning_enabled()
        )

        self.use_deepstack = hasattr(config.vision_config, "deepstack_visual_indexes")
        self.deepstack_num_level = (
            len(config.vision_config.deepstack_visual_indexes)
            if self.use_deepstack
            else 0
        )
        self.visual_dim = config.vision_config.out_hidden_size
        self.multiscale_dim = self.visual_dim * self.deepstack_num_level

        with self._mark_tower_model(vllm_config, {"image", "video"}):
            self.visual = Qwen3_VisionTransformer(
                config.vision_config,
                norm_eps=getattr(config, "rms_norm_eps", 1e-6),
                quant_config=quant_config,
                prefix=maybe_prefix(prefix, "visual"),
            )

            # Scratch tensors for DeepStack, one per level, sized for
            # max_num_batched_tokens. NOTE: this is a plain Python list, not a
            # buffer registered via register_buffer, so Module.to()/state_dict
            # do not see it.
            if self.use_deepstack:
                self.deepstack_input_embeds = [
                    torch.zeros(
                        vllm_config.scheduler_config.max_num_batched_tokens,
                        config.text_config.hidden_size,
                    )
                    for _ in range(self.deepstack_num_level)
                ]

        with self._mark_language_model(vllm_config):
            self.language_model = Qwen3LLMForCausalLM(
                vllm_config=vllm_config, prefix=maybe_prefix(prefix, "language_model")
            )

        self.make_empty_intermediate_tensors = (
            self.language_model.make_empty_intermediate_tensors
        )

1305
1306
1307
1308
1309
1310
1311
    def set_aux_hidden_state_layers(self, layers: tuple[int, ...]) -> None:
        self.language_model.model.aux_hidden_state_layers = layers

    def get_eagle3_aux_hidden_state_layers(self) -> tuple[int, ...]:
        """Return the default EAGLE-3 auxiliary layers.

        These are layer 2, the middle layer, and the third-from-last layer of
        the language model.
        """
        depth = len(self.language_model.model.layers)
        middle_layer, late_layer = depth // 2, depth - 3
        return (2, middle_layer, late_layer)

1312
1313
1314
1315
1316
1317
1318
    def _get_deepstack_input_embeds(
        self,
        num_tokens: int,
    ) -> IntermediateTensors | None:
        if not getattr(self, "deepstack_input_embeds", None):
            return None  # If vision tower is skipped

1319
        # get deepstack_input_embeds from buffer, and clear the buffer
1320
1321
1322
1323
1324
1325
1326
1327
1328
1329
        return IntermediateTensors(
            {
                f"deepstack_input_embeds_{idx}": self.deepstack_input_embeds[idx][
                    :num_tokens
                ]
                for idx in range(self.deepstack_num_level)
            }
        )

    def _set_deepstack_input_embeds(self, deepstack_input_embeds: torch.Tensor) -> None:
        """Copy per-level DeepStack embeddings into the scratch buffers.

        ``deepstack_input_embeds`` is indexed by level on dim 0 and by token
        on dim 1. If it has more tokens than the buffers hold, the buffers are
        reallocated (zeros, same device and dtype) at the larger size. Does
        nothing when the buffers do not exist or are empty.
        """
        if not getattr(self, "deepstack_input_embeds", None):
            return

        # set deepstack_input_embeds to buffer
        num_tokens = deepstack_input_embeds.size(1)
        if num_tokens > self.deepstack_input_embeds[0].size(0):
            self.deepstack_input_embeds = [
                torch.zeros(
                    num_tokens,
                    self.config.text_config.hidden_size,
                    device=self.deepstack_input_embeds[0].device,
                    dtype=self.deepstack_input_embeds[0].dtype,
                )
                for _ in range(self.deepstack_num_level)
            ]
        for idx in range(self.deepstack_num_level):
            self.deepstack_input_embeds[idx][:num_tokens].copy_(
                deepstack_input_embeds[idx]
            )

    def _clear_deepstack_input_embeds(self, num_tokens: int) -> None:
        """Zero the first ``num_tokens`` rows of every DeepStack buffer.

        Does nothing when the buffers do not exist or are empty, or when
        ``num_tokens`` is not positive.
        """
        if not getattr(self, "deepstack_input_embeds", None):
            return

        # clear deepstack_input_embeds in buffer
        if num_tokens > 0:
            for idx in range(self.deepstack_num_level):
                self.deepstack_input_embeds[idx][:num_tokens].zero_()

    def _parse_and_validate_image_input(
        self, **kwargs: object
    ) -> Qwen2_5_VLImageInputs | None:
        """Pop the image inputs from ``kwargs`` and wrap them in a typed input.

        Raw ``pixel_values`` take precedence over precomputed
        ``image_embeds``. Returns ``None`` when neither is provided.
        """
        pixel_values = kwargs.pop("pixel_values", None)
        image_embeds = kwargs.pop("image_embeds", None)
        image_grid_thw = kwargs.pop("image_grid_thw", None)

        if pixel_values is not None:
            return Qwen2_5_VLImagePixelInputs(
                type="pixel_values",
                pixel_values=pixel_values,
                image_grid_thw=image_grid_thw,
            )

        if image_embeds is not None:
            return Qwen2_5_VLImageEmbeddingInputs(
                type="image_embeds",
                image_embeds=image_embeds,
                image_grid_thw=image_grid_thw,
            )

        return None

    def _parse_and_validate_video_input(
        self, **kwargs: object
    ) -> Qwen2_5_VLVideoInputs | None:
        """Pop the video inputs from ``kwargs`` and wrap them in a typed input.

        Raw ``pixel_values_videos`` take precedence over precomputed
        ``video_embeds``. ``second_per_grid_ts`` is only forwarded for pixel
        inputs. Returns ``None`` when neither input is provided.
        """
        pixel_values_videos = kwargs.pop("pixel_values_videos", None)
        video_embeds = kwargs.pop("video_embeds", None)
        video_grid_thw = kwargs.pop("video_grid_thw", None)
        second_per_grid_ts = kwargs.pop("second_per_grid_ts", None)

        if pixel_values_videos is not None:
            return Qwen2_5_VLVideoPixelInputs(
                type="pixel_values_videos",
                pixel_values_videos=pixel_values_videos,
                video_grid_thw=video_grid_thw,
                second_per_grid_ts=second_per_grid_ts,
            )

        if video_embeds is not None:
            return Qwen2_5_VLVideoEmbeddingInputs(
                type="video_embeds",
                video_embeds=video_embeds,
                video_grid_thw=video_grid_thw,
            )

        return None

    def _process_image_input(
        self, image_input: Qwen2_5_VLImageInputs
    ) -> tuple[torch.Tensor, ...]:
        """Encode images (or reuse given embeddings) and split them per image.

        With a data-parallel encoder, the result of
        ``run_dp_sharded_mrope_vision_model`` is returned directly.
        Otherwise image ``i`` gets ``prod(grid_thw[i]) / spatial_merge_size**2``
        rows of the concatenated embeddings.
        """
        grid_thw = image_input["image_grid_thw"]
        assert grid_thw.ndim == 2

        if image_input["type"] == "image_embeds":
            image_embeds = image_input["image_embeds"].type(self.visual.dtype)
        else:
            pixel_values = image_input["pixel_values"].type(self.visual.dtype)
            if self.use_data_parallel:
                return run_dp_sharded_mrope_vision_model(
                    self.visual, pixel_values, grid_thw.tolist(), rope_type="rope_3d"
                )
            image_embeds = self.visual(pixel_values, grid_thw=grid_thw)

        # Split concatenated embeddings for each image item.
        merge_size = self.visual.spatial_merge_size
        sizes = (grid_thw.prod(-1) // merge_size // merge_size).tolist()
        return image_embeds.split(sizes)

    def _process_video_input(
        self, video_input: Qwen2_5_VLVideoInputs
    ) -> tuple[torch.Tensor, ...]:
        """Encode videos (or reuse given embeddings) and split them per video.

        With a data-parallel encoder, the result of
        ``run_dp_sharded_mrope_vision_model`` is returned directly.
        Otherwise video ``i`` gets ``prod(grid_thw[i]) / spatial_merge_size**2``
        rows of the concatenated embeddings.
        """
        grid_thw = video_input["video_grid_thw"]
        assert grid_thw.ndim == 2

        if video_input["type"] == "video_embeds":
            video_embeds = video_input["video_embeds"].type(self.visual.dtype)
        else:
            pixel_values_videos = video_input["pixel_values_videos"].type(
                self.visual.dtype
            )
            if self.use_data_parallel:
                grid_thw_list = grid_thw.tolist()
                return run_dp_sharded_mrope_vision_model(
                    self.visual, pixel_values_videos, grid_thw_list, rope_type="rope_3d"
                )
            video_embeds = self.visual(pixel_values_videos, grid_thw=grid_thw)

        # Split concatenated embeddings for each video item.
        merge_size = self.visual.spatial_merge_size
        sizes = (grid_thw.prod(-1) // merge_size // merge_size).tolist()
        return video_embeds.split(sizes)

1456
1457
1458
1459
1460
1461
1462
1463
1464
1465
1466
1467
1468
1469
1470
1471
1472
1473
1474
1475
1476
1477
1478
1479
1480
1481
1482
1483
1484
1485
1486
1487
1488
1489
1490
1491
1492
1493
1494
1495
1496
1497
1498
1499
1500
1501
1502
1503
1504
1505
1506
1507
1508
1509
1510
1511
1512
1513
1514
1515
1516
1517
1518
1519
1520
1521
1522
1523
1524
1525
1526
1527
1528
1529
1530
1531
1532
1533
1534
1535
1536
1537
1538
1539
1540
1541
1542
1543
1544
1545
1546
1547
1548
1549
1550
1551
1552
1553
1554
1555
1556
1557
1558
    def _postprocess_image_embeds_evs(
        self,
        image_embeds_split: tuple[torch.Tensor, ...],
        image_input: Qwen2_5_VLImageInputs,
    ) -> tuple[torch.Tensor, ...]:
        """
        Append mrope positions for each for images.
        This is necessary to recover correct mrope
        positions after video pruning

        Args:
            image_embeds_split: Tuple of image embeddings for
                each image item.
            image_input: Image input data.

        Returns:
            Tuple of image embeddings for each image item.
            Resulting embeddings will have extra 4 channels for
            computed mrope positions.
        """
        merge_size = self.visual.spatial_merge_size
        grid_thw = image_input["image_grid_thw"]
        grid_thw_list = grid_thw.tolist()
        image_embeds_out = []
        for emb, size in zip(image_embeds_split, grid_thw_list):
            positions = compute_mrope_for_media(size, merge_size).to(emb.device)
            emb = torch.cat([emb, positions], dim=1)
            image_embeds_out.append(emb)
        image_embeds_split = image_embeds_out
        return tuple(image_embeds_split)

    def _postprocess_video_embeds_evs(
        self,
        video_embeds_split: tuple[torch.Tensor, ...],
        video_input: Qwen2_5_VLVideoInputs,
    ) -> tuple[torch.Tensor, ...]:
        """
        Prune video embeddings via Efficient Video Sampling (EVS), then
        append mrope positions to each retained embedding.

        Args:
            video_embeds_split: Tuple of video embeddings for each video item.
            video_input: Video input data (must contain ``video_grid_thw``;
                ``second_per_grid_ts`` is optional).

        Returns:
            Tuple of pruned video embeddings, one per video item. Each
            retained row has 4 extra trailing channels for the computed mrope
            positions.
        """
        import logging

        grid_thw = video_input["video_grid_thw"]
        assert grid_thw.ndim == 2
        grid_thw_list = grid_thw.tolist()
        merge_size = self.visual.spatial_merge_size

        # Cast to long to match the original code
        # https://github.com/huggingface/transformers/blob/41980ce93e775f6c88500c51c8db7946fc6a2add/src/transformers/models/qwen2_5_vl/modular_qwen2_5_vl.py#L491 # noqa
        second_per_grid_ts = video_input.get("second_per_grid_ts")
        if second_per_grid_ts is None:
            # For Qwen3-VL, second_per_grid_ts might not be available.
            # Use a default value of 1 for each video.
            second_per_grid_ts = torch.ones(len(grid_thw_list), dtype=torch.long)
        else:
            second_per_grid_ts = second_per_grid_ts.long()
        tokens_per_second = getattr(self.config.vision_config, "tokens_per_second", 1.0)

        # The pruning statistics below need extra reductions and .item() reads
        # (host syncs when the mask lives on an accelerator). Compute them only
        # when DEBUG logging is actually enabled. `logger` is used with stdlib
        # %-style args, so it is assumed to be a logging.Logger.
        log_pruning_stats = logger.isEnabledFor(logging.DEBUG)

        video_embeds_out = []
        for emb, size, video_second_per_grid_t in zip(
            video_embeds_split, grid_thw_list, second_per_grid_ts
        ):
            # For each video, compute the EVS retention mask.
            retention_mask = compute_retention_mask(
                emb,
                size,
                spatial_merge_size=merge_size,
                q=self.video_pruning_rate,
            )

            if log_pruning_stats:
                logger.debug(
                    "EVS: Video tokens pruned from %d to %d (T=%d,H=%d,W=%d, "
                    "pruning_rate=%.2f, reduction=%.1f%%)",
                    emb.shape[0],
                    retention_mask.sum().item(),
                    size[0],
                    size[1],
                    size[2],
                    self.video_pruning_rate,
                    (1 - retention_mask.float().mean().item()) * 100,
                )

            # Positions are computed for the full (unpruned) grid, then
            # filtered with the same mask as the embeddings.
            positions = compute_mrope_for_media(
                size,
                merge_size,
                tokens_per_second=tokens_per_second,
                video_second_per_grid=video_second_per_grid_t.item(),
            ).to(emb.device)

            video_embeds_out.append(
                torch.cat([emb[retention_mask], positions[retention_mask]], dim=1)
            )
        return tuple(video_embeds_out)

1559
1560
1561
    def _parse_and_validate_multimodal_inputs(self, **kwargs: object) -> dict:
        mm_input_by_modality = {}
        for input_key in kwargs:
1562
1563
1564
1565
1566
1567
1568
1569
1570
1571
1572
1573
1574
1575
            if (
                input_key in ("pixel_values", "image_embeds")
                and "image" not in mm_input_by_modality
            ):
                mm_input_by_modality["image"] = self._parse_and_validate_image_input(
                    **kwargs
                )
            if (
                input_key in ("pixel_values_videos", "video_embeds")
                and "video" not in mm_input_by_modality
            ):
                mm_input_by_modality["video"] = self._parse_and_validate_video_input(
                    **kwargs
                )
1576
1577
        return mm_input_by_modality

1578
1579
1580
    def iter_mm_grid_hw(
        self, input_tokens: list[int], mm_features: list[MultiModalFeatureSpec]
    ) -> Iterator[tuple[int, int, int]]:
        """
        Iterate over multimodal features in prompt order and yield grid info.

        Images yield once. Videos yield once per frame. For videos with EVS
        (Efficient Video Sampling) enabled, frame offsets come from the
        ``is_embed`` mask stored in ``mm_position`` rather than from
        ``input_tokens.index()``, which would fail once tokens are pruned.

        Args:
            input_tokens: List of token IDs in the prompt.
            mm_features: List of multimodal feature specifications.

        Yields:
            Tuple of (offset, grid_h, grid_w) for each image / video frame,
            where grid_h and grid_w are the vision grid sizes divided by the
            spatial merge size.

        Raises:
            RuntimeError: If EVS is enabled but per-frame offsets cannot be
                recovered from a video's ``is_embed`` mask.
            ValueError: If a feature has an unsupported modality.
        """
        video_token_id = self.config.video_token_id
        spatial_merge_size = self.config.vision_config.spatial_merge_size
        # The EVS setting does not change while iterating; decide it once.
        is_evs_enabled = (
            getattr(self, "video_pruning_rate", None) is not None
            and self.video_pruning_rate > 0.0
        )

        for mm_feature in sorted(mm_features, key=lambda f: f.mm_position.offset):
            offset = mm_feature.mm_position.offset

            if mm_feature.modality == "image":
                t, h, w = mm_feature.data["image_grid_thw"].data.tolist()
                assert t == 1, f"Image must have 1 frame, got {t}"
                yield offset, h // spatial_merge_size, w // spatial_merge_size
                continue

            if mm_feature.modality != "video":
                raise ValueError(f"Unsupported modality: {mm_feature.modality}")

            t, h, w = mm_feature.data["video_grid_thw"].data.tolist()
            llm_grid_h = h // spatial_merge_size
            llm_grid_w = w // spatial_merge_size

            if not is_evs_enabled:
                # Non-EVS mode: each frame is assumed to start at the next
                # video token and span llm_grid_h * llm_grid_w tokens.
                for _ in range(t):
                    offset = input_tokens.index(video_token_id, offset)
                    yield offset, llm_grid_h, llm_grid_w
                    offset += llm_grid_h * llm_grid_w
                continue

            frame_offsets = self._extract_frame_offsets_from_mask(
                mm_feature.mm_position, t
            )
            if frame_offsets is None:
                # With video_pruning_rate > 0 the prompt processor is expected
                # to always provide a usable is_embed mask.
                raise RuntimeError(
                    f"EVS is enabled (pruning_rate={self.video_pruning_rate}) "
                    "but is_embed mask is missing from mm_position, empty, or "
                    "has fewer segments than video frames. "
                    "This indicates a bug in prompt processing."
                )
            for rel_offset in frame_offsets:
                yield offset + rel_offset, llm_grid_h, llm_grid_w

1641
1642
1643
1644
1645
1646
1647
1648
1649
1650
1651
1652
1653
1654
1655
1656
1657
1658
1659
1660
1661
1662
1663
1664
1665
1666
1667
1668
1669
1670
1671
1672
1673
1674
1675
1676
1677
1678
1679
1680
1681
1682
1683
1684
1685
1686
1687
1688
1689
1690
1691
1692
1693
1694
1695
1696
1697
1698
1699
1700
1701
1702
1703
1704
1705
1706
1707
1708
1709
1710
1711
1712
1713
1714
1715
1716
1717
1718
1719
1720
1721
1722
1723
1724
1725
1726
1727
1728
1729
1730
1731
1732
1733
1734
1735
1736
1737
1738
1739
1740
1741
1742
1743
1744
1745
1746
1747
1748
1749
1750
1751
1752
1753
1754
1755
1756
1757
1758
1759
1760
1761
1762
1763
1764
1765
1766
1767
1768
1769
1770
1771
1772
1773
1774
1775
1776
1777
1778
1779
1780
1781
1782
1783
1784
1785
1786
1787
1788
1789
1790
1791
1792
1793
1794
1795
1796
1797
1798
    def _get_evs_mask_segments(
        self, mm_position: PlaceholderRange, expected_frames: int
    ) -> list[torch.Tensor] | None:
        """Extract contiguous segments from the EVS ``is_embed`` mask.

        The EVS (Efficient Video Sampling) mask marks which placeholder
        positions should be filled with video embeddings. This method splits
        the True positions of the mask into maximal runs of consecutive
        indices; each run is taken to represent one retained frame.

        The method reads ``mm_position`` only and does not modify any state.

        Args:
            mm_position: MultiModal position that may carry an ``is_embed``
                mask.
            expected_frames: Expected number of frame segments.

        Returns:
            A list with the first ``expected_frames`` segments, each a 1-D
            tensor of placeholder-relative indices. Any extra segments are
            dropped. Returns None if there is no mask, the mask has no True
            entries, or there are fewer segments than ``expected_frames``.
        """
        is_embed_mask = getattr(mm_position, "is_embed", None)
        if is_embed_mask is None:
            return None

        # Find all True positions in the mask.
        mask_tensor = torch.as_tensor(is_embed_mask, dtype=torch.bool).view(-1)
        true_indices = torch.nonzero(mask_tensor, as_tuple=False).flatten()
        if true_indices.numel() == 0:
            return None

        # A new segment starts wherever consecutive True indices are not
        # adjacent. torch.diff of a single element is empty, so a one-element
        # mask also ends up as a single segment.
        gaps = torch.nonzero(torch.diff(true_indices) != 1, as_tuple=False).flatten()
        if gaps.numel() == 0:
            segments = [true_indices]
        else:
            # tensor_split returns a tuple; keep the documented list type.
            segments = list(torch.tensor_split(true_indices, gaps.add(1).tolist()))

        if len(segments) < expected_frames:
            logger.debug(
                "EVS mask segments (%d) do not match expected frames (%d)",
                len(segments),
                expected_frames,
            )
            return None

        return segments[:expected_frames]

    def _extract_frame_offsets_from_mask(
        self, mm_position: PlaceholderRange, expected_frames: int
    ) -> list[int] | None:
        """Return the start offset of each EVS-retained frame.

        ``mm_position`` may carry a boolean ``is_embed`` mask marking which
        placeholder slots receive video embeddings. The first index of each
        contiguous run in that mask is the start of a retained frame, so no
        probing of ``input_tokens`` is needed.

        Args:
            mm_position: MultiModal position containing the is_embed mask.
            expected_frames: Expected number of frames.

        Returns:
            Starting offsets (relative to ``mm_position``), one per frame, or
            None when ``_get_evs_mask_segments`` cannot produce segments.
        """
        runs = self._get_evs_mask_segments(mm_position, expected_frames)
        if runs is None:
            return None
        return [int(run[0]) for run in runs]

    def _get_actual_frame_token_counts(
        self, mm_position: PlaceholderRange, expected_frames: int
    ) -> list[int] | None:
        """Return how many tokens each EVS-retained frame actually keeps.

        The count per frame is the length of the corresponding contiguous run
        in the ``is_embed`` mask. Counts can differ between frames because of
        content-aware pruning.

        Args:
            mm_position: MultiModal position containing the is_embed mask.
            expected_frames: Expected number of frames.

        Returns:
            Token counts, one per frame, or None when
            ``_get_evs_mask_segments`` cannot produce segments.
        """
        runs = self._get_evs_mask_segments(mm_position, expected_frames)
        if runs is None:
            return None
        return [run.numel() for run in runs]

    def recompute_mrope_positions(
        self,
        input_ids: list[int],
        multimodal_embeddings: tuple[torch.Tensor, ...],
        mrope_positions: torch.LongTensor,
        num_computed_tokens: int,
    ) -> tuple[tuple[torch.Tensor, ...], torch.Tensor, int]:
        """
        Update the part of the input mrope positions that starts at index
        ``num_computed_tokens``.

        The original mrope_positions are computed for the unpruned sequence
        and become incorrect once media tokens are pruned. The positions
        therefore have to be updated before they are fed to the LLM.

        Args:
            input_ids: (N,) All input tokens of the prompt (the entire
                sequence).
            multimodal_embeddings: Tuple of multimodal embeddings whose last
                4 channels hold per-token mrope positions.
            mrope_positions: Existing mrope positions (3, N) for the entire
                sequence.
            num_computed_tokens: Number of tokens computed so far.

        Returns:
            Tuple of (multimodal_embeddings with the 4 position channels
            removed, mrope_positions, mrope_position_delta).
        """
        cfg = self.config

        if len(multimodal_embeddings):
            target_device = multimodal_embeddings[0].device
        else:
            target_device = mrope_positions.device

        # Split each embedding into its feature part and its trailing 4
        # position channels; positions are transposed to (4, num_tokens).
        stripped_embeddings = tuple(mm[:, :-4] for mm in multimodal_embeddings)
        media_positions = [mm[:, -4:].t().long() for mm in multimodal_embeddings]

        # Module-level helper of the same name (not this method).
        new_positions, position_delta = recompute_mrope_positions(
            torch.as_tensor(input_ids, device=target_device, dtype=torch.long),
            media_positions,
            mrope_positions,
            num_computed_tokens,
            cfg.vision_start_token_id,
            cfg.image_token_id,
            cfg.video_token_id,
        )
        return stripped_embeddings, new_positions, position_delta

1799
    def get_mrope_input_positions(
        self,
        input_tokens: list[int],
        mm_features: list[MultiModalFeatureSpec],
    ) -> tuple[torch.Tensor, int]:
        """Compute 3-row mrope position ids for a whole prompt.

        Text tokens get the same value in all 3 rows, increasing by one per
        token. Every (offset, grid_h, grid_w) entry yielded by
        ``iter_mm_grid_hw`` gets an ``np.indices((1, grid_h, grid_w))`` grid
        shifted past the preceding text.

        With EVS enabled, a video frame may keep fewer tokens than its full
        grid. The retained count per frame is read from the video's
        ``is_embed`` mask, and only that many grid positions are emitted.

        Args:
            input_tokens: Token ids of the whole prompt.
            mm_features: Multimodal feature specs of the prompt.

        Returns:
            Tuple of (positions tensor of shape (3, -1), mrope_position_delta
            = max position + 1 - len(input_tokens)).

        Raises:
            RuntimeError: If EVS is enabled but per-frame token counts cannot
                be recovered, or more frames are yielded than were counted.
        """
        from bisect import bisect_right

        is_evs_enabled = (
            getattr(self, "video_pruning_rate", None) is not None
            and self.video_pruning_rate > 0.0
        )

        # Record every feature (images included) by start offset, in sorted
        # order, with its per-frame retained token counts (EVS videos) or
        # None (images / non-EVS videos). Each yielded entry is then
        # attributed to the feature with the greatest start offset <= its
        # offset. Keeping images in this table prevents an image that follows
        # an EVS video from being charged against that video's frame counts.
        feature_starts: list[int] = []
        feature_frame_counts: list[list[int] | None] = []
        for mm_feature in sorted(mm_features, key=lambda f: f.mm_position.offset):
            token_counts = None
            if is_evs_enabled and mm_feature.modality == "video":
                t = mm_feature.data["video_grid_thw"].data.tolist()[0]
                token_counts = self._get_actual_frame_token_counts(
                    mm_feature.mm_position, t
                )
                if token_counts is None:
                    raise RuntimeError(
                        "EVS enabled but failed to extract frame token counts "
                        "from is_embed mask"
                    )
            feature_starts.append(mm_feature.mm_position.offset)
            feature_frame_counts.append(token_counts)
        frames_consumed = [0] * len(feature_starts)

        llm_pos_ids_list: list[np.ndarray] = []
        st = 0
        for offset, llm_grid_h, llm_grid_w in self.iter_mm_grid_hw(
            input_tokens, mm_features
        ):
            text_len = offset - st
            st_idx = llm_pos_ids_list[-1].max() + 1 if llm_pos_ids_list else 0

            # Determine the actual token count for this frame.
            owner = bisect_right(feature_starts, offset) - 1
            counts = feature_frame_counts[owner] if owner >= 0 else None
            if counts is None:
                # Image or non-EVS video frame: use the theoretical grid size.
                actual_frame_tokens = llm_grid_h * llm_grid_w
            else:
                idx = frames_consumed[owner]
                if idx >= len(counts):
                    raise RuntimeError(
                        f"EVS frame index {idx} out of range "
                        f"(total frames: {len(counts)})"
                    )
                actual_frame_tokens = counts[idx]
                frames_consumed[owner] = idx + 1

            # Text segment before this image / frame.
            llm_pos_ids_list.append(
                np.broadcast_to(np.arange(text_len), (3, text_len)) + st_idx
            )
            st_idx += text_len

            # Frame segment: only the first actual_frame_tokens grid positions.
            grid_indices = np.indices((1, llm_grid_h, llm_grid_w)).reshape(3, -1)
            llm_pos_ids_list.append(grid_indices[:, :actual_frame_tokens] + st_idx)

            st = offset + actual_frame_tokens

        # Trailing text after the last image / frame.
        if st < len(input_tokens):
            st_idx = llm_pos_ids_list[-1].max() + 1 if llm_pos_ids_list else 0
            text_len = len(input_tokens) - st
            llm_pos_ids_list.append(
                np.broadcast_to(np.arange(text_len), (3, text_len)) + st_idx
            )

        llm_positions = np.concatenate(llm_pos_ids_list, axis=1).reshape(3, -1)
        mrope_position_delta = (llm_positions.max() + 1 - len(input_tokens)).item()
        return torch.from_numpy(llm_positions), mrope_position_delta
1891

1892
    def embed_multimodal(self, **kwargs: object) -> MultiModalEmbeddings | None:
        """Encode every image / video item found in ``kwargs``.

        Returns:
            None when no multimodal input is present. Otherwise, a tuple with
            one embedding tensor per multimodal item (image or video). Items
            are grouped by modality in the order the modalities appear in the
            dict returned by ``_parse_and_validate_multimodal_inputs``. When
            ``is_multimodal_pruning_enabled`` is set, every tensor carries the
            extra mrope-position channels added by the EVS post-processing.
        """
        parsed_inputs = self._parse_and_validate_multimodal_inputs(**kwargs)
        if not parsed_inputs:
            return None

        # NOTE: iterate the dict in its own order to preserve modality order.
        per_item_embeddings: list[torch.Tensor] = []
        for modality, mm_input in parsed_inputs.items():
            if modality == "image":
                item_embeds = self._process_image_input(mm_input)
                if self.is_multimodal_pruning_enabled:
                    item_embeds = self._postprocess_image_embeds_evs(
                        item_embeds, mm_input
                    )
            elif modality == "video":
                item_embeds = self._process_video_input(mm_input)
                if self.is_multimodal_pruning_enabled:
                    item_embeds = self._postprocess_video_embeds_evs(
                        item_embeds, mm_input
                    )
            else:
                continue
            per_item_embeddings.extend(item_embeds)
        return tuple(per_item_embeddings)

    def _compute_deepstack_embeds(
        self,
        inputs_embeds: torch.Tensor,
        multimodal_embeddings: MultiModalEmbeddings,
        is_multimodal: torch.Tensor,
    ) -> tuple[torch.Tensor, MultiModalEmbeddings]:
        """Separate main and multiscale (deepstack) vision features.

        Each multimodal embedding's last dimension is split into
        ``visual_dim`` main channels and ``multiscale_dim`` deepstack
        channels. The deepstack channels are scattered into a zero buffer at
        the ``is_multimodal`` positions.

        Returns:
            Tuple of (deepstack embeddings shaped
            (deepstack_num_level, num_tokens, visual_dim), per-item main
            embeddings).
        """
        item_lengths = [len(item) for item in multimodal_embeddings]

        main_flat, multiscale_flat = torch.cat(multimodal_embeddings, dim=0).split(
            [self.visual_dim, self.multiscale_dim], dim=-1
        )
        main_per_item = main_flat.split(item_lengths, dim=0)
        multiscale_per_item = multiscale_flat.split(item_lengths, dim=0)

        num_tokens = inputs_embeds.size(0)
        zero_buffer = inputs_embeds.new_zeros(
            num_tokens, self.deepstack_num_level * inputs_embeds.size(1)
        )
        filled = _merge_multimodal_embeddings(
            inputs_embeds=zero_buffer,
            multimodal_embeddings=multiscale_per_item,
            is_multimodal=is_multimodal,
        )
        # (num_tokens, level * dim) -> (level, num_tokens, dim)
        per_level = filled.view(
            num_tokens, self.deepstack_num_level, self.visual_dim
        ).permute(1, 0, 2)

        return per_level, main_per_item

1962
    def embed_input_ids(
        self,
        input_ids: torch.Tensor,
        multimodal_embeddings: MultiModalEmbeddings | None = None,
        *,
        is_multimodal: torch.Tensor | None = None,
        handle_oov_mm_token: bool = False,
    ) -> torch.Tensor:
        """Embed text tokens and merge multimodal embeddings into them.

        With deepstack enabled, the multiscale part of the multimodal
        embeddings is stored via ``_set_deepstack_input_embeds`` and only the
        main part is merged into the returned embeddings.
        """
        text_embeds = self._embed_text_input_ids(
            input_ids,
            self.language_model.embed_input_ids,
            is_multimodal=is_multimodal,
            handle_oov_mm_token=handle_oov_mm_token,
        )

        has_mm = multimodal_embeddings is not None and len(multimodal_embeddings) != 0
        if not has_mm:
            return text_embeds

        mm_mask = _require_is_multimodal(is_multimodal)

        deepstack_embeds = None
        if self.use_deepstack:
            deepstack_embeds, multimodal_embeddings = self._compute_deepstack_embeds(
                inputs_embeds=text_embeds,
                multimodal_embeddings=multimodal_embeddings,
                is_multimodal=mm_mask,
            )

        merged = _merge_multimodal_embeddings(
            inputs_embeds=text_embeds,
            multimodal_embeddings=multimodal_embeddings,
            is_multimodal=mm_mask,
        )

        if deepstack_embeds is not None:
            self._set_deepstack_input_embeds(deepstack_embeds)

        return merged

    def forward(
        self,
        input_ids: torch.Tensor | None,
        positions: torch.Tensor,
        intermediate_tensors: IntermediateTensors | None = None,
        inputs_embeds: torch.Tensor | None = None,
        **kwargs: object,
    ) -> torch.Tensor | IntermediateTensors:
        """Run forward pass for Qwen3VL.

        Args:
            input_ids: Flattened (concatenated) input_ids corresponding to a
                batch.
            positions: Flattened (concatenated) position ids corresponding to a
                batch.
                **NOTE**: If mrope is enabled (default setting for Qwen3VL
                opensource models), the shape will be `(3, seq_len)`,
                otherwise it will be `(seq_len,)`.
            intermediate_tensors: Intermediate tensors from previous pipeline
                stages. When given, ``inputs_embeds`` is ignored.
            inputs_embeds: Pre-computed input embeddings.
            **kwargs: Additional keyword arguments including:
                - pixel_values: Pixel values to be fed to a model.
                    `None` if no images are passed.
                - image_grid_thw: Tensor `(n_images, 3)` of image 3D grid in
                    LLM. `None` if no images are passed.
                - pixel_values_videos: Pixel values of videos to be fed to a
                    model. `None` if no videos are passed.
                - video_grid_thw: Tensor `(n_videos, 3)` of video 3D grid in
                    LLM. `None` if no videos are passed.
        """
        if intermediate_tensors is not None:
            inputs_embeds = None

        # Deepstack embeddings are fetched (and cleared afterwards) only on the
        # first pipeline-parallel rank, and only when embeddings are supplied.
        uses_deepstack_buffer = (
            inputs_embeds is not None and get_pp_group().is_first_rank
        )
        deepstack_input_embeds = (
            self._get_deepstack_input_embeds(inputs_embeds.size(0))
            if uses_deepstack_buffer
            else None
        )

        hidden_states = self.language_model.model(
            input_ids=input_ids,
            positions=positions,
            intermediate_tensors=intermediate_tensors,
            inputs_embeds=inputs_embeds,
            # args for deepstack
            deepstack_input_embeds=deepstack_input_embeds,
        )

        if uses_deepstack_buffer:
            self._clear_deepstack_input_embeds(inputs_embeds.size(0))

        return hidden_states

    def compute_logits(
        self,
        hidden_states: torch.Tensor,
2064
    ) -> torch.Tensor | None:
2065
        return self.language_model.compute_logits(hidden_states)
2066

2067
    def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]:
        """Load ``(name, tensor)`` pairs through ``AutoWeightsLoader``.

        Checkpoint names are remapped with ``self.hf_to_vllm_mapper``. Returns
        whatever set of names the loader reports (presumably the loaded
        parameter names — confirm against AutoWeightsLoader).
        """
        return AutoWeightsLoader(self).load_weights(
            weights, mapper=self.hf_to_vllm_mapper
        )

    def get_mm_mapping(self) -> MultiModelKeys:
        """
        Return the module prefixes of the language model, the connector
        modules, and the vision tower of this multimodal model.
        """
        connector_prefixes = ["visual.merger", "visual.deepstack_merger_list"]
        return MultiModelKeys.from_string_field(
            language_model="language_model",
            connector=connector_prefixes,
            tower_model="visual.",
        )
2080

2081
2082
2083
2084
2085
2086
2087
2088
2089
2090
2091
2092
2093
2094
2095
2096
2097
2098
    def get_num_mm_encoder_tokens(
        self,
        num_image_tokens: int,
    ) -> int:
        """Convert an LLM-side image token count to vision-encoder tokens.

        Each LLM token corresponds to ``spatial_merge_size ** 2`` encoder
        tokens.
        """
        tokens_per_merged = self.config.vision_config.spatial_merge_size ** 2
        return num_image_tokens * tokens_per_merged

    def get_num_mm_connector_tokens(
        self,
        num_vision_tokens: int,
    ) -> int:
        """Convert a vision-encoder token count to connector output tokens.

        The connector merges ``spatial_merge_size ** 2`` vision tokens into
        one output token (floor division).
        """
        tokens_per_merged = self.config.vision_config.spatial_merge_size ** 2
        return num_vision_tokens // tokens_per_merged