qwen3_vl.py 80 KB
Newer Older
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project

# Copyright 2025 The vLLM team.
# Copyright 2025 The Qwen Team.
# Copyright 2025 The HuggingFace Inc. team.
# All rights reserved.
#
# This code is based on EleutherAI's GPT-NeoX library and the GPT-NeoX
# and OPT implementations in this library. It has been modified from its
# original forms to accommodate minor architectural differences compared
# to GPT-NeoX and OPT used by the Meta AI team that trained the model.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Inference-only Qwen3VL model compatible with HuggingFace weights."""
26

27
from collections.abc import Callable, Iterable, Iterator, Mapping, Sequence
28
from functools import lru_cache, partial
29
from itertools import islice
30
from typing import Any
31
32
33
34
35

import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
36
from transformers import BatchFeature
37
from transformers.models.qwen2_vl import Qwen2VLImageProcessorFast
38
from transformers.models.qwen2_vl.image_processing_qwen2_vl import (
39
40
41
    smart_resize as image_smart_resize,
)
from transformers.models.qwen3_vl import Qwen3VLProcessor, Qwen3VLVideoProcessor
42
from transformers.models.qwen3_vl.configuration_qwen3_vl import (
43
44
45
    Qwen3VLConfig,
    Qwen3VLVisionConfig,
)
46
from transformers.models.qwen3_vl.video_processing_qwen3_vl import (
47
48
    smart_resize as video_smart_resize,
)
49
50
from transformers.video_utils import VideoMetadata

51
from vllm.attention.backends.registry import AttentionBackendEnum
52
from vllm.compilation.decorators import support_torch_compile
53
from vllm.config import MultiModalConfig, VllmConfig
54
from vllm.config.multimodal import BaseDummyOptions, VideoDummyOptions
55
56
57
from vllm.distributed import get_pp_group
from vllm.logger import init_logger
from vllm.model_executor.layers.activation import _ACTIVATION_REGISTRY
58
from vllm.model_executor.layers.conv import Conv3dLayer
59
60
61
62
from vllm.model_executor.layers.linear import (
    ColumnParallelLinear,
    RowParallelLinear,
)
63
64
from vllm.model_executor.layers.logits_processor import LogitsProcessor
from vllm.model_executor.layers.quantization import QuantizationConfig
65
from vllm.model_executor.layers.rotary_embedding import get_rope
66
67
68
69
from vllm.model_executor.layers.vocab_parallel_embedding import ParallelLMHead
from vllm.model_executor.model_loader.weight_utils import default_weight_loader
from vllm.model_executor.models.module_mapping import MultiModelKeys
from vllm.multimodal import MULTIMODAL_REGISTRY
70
71
72
73
74
75
from vllm.multimodal.evs import (
    compute_mrope_for_media,
    compute_retained_tokens_count,
    compute_retention_mask,
    recompute_mrope_positions,
)
76
77
from vllm.multimodal.inputs import (
    MultiModalDataDict,
78
    MultiModalFeatureSpec,
79
80
81
    MultiModalFieldConfig,
    MultiModalKwargsItem,
    MultiModalKwargsItems,
82
    PlaceholderRange,
83
84
85
86
87
88
89
90
91
    VideoItem,
)
from vllm.multimodal.parse import ImageSize, MultiModalDataItems, MultiModalDataParser
from vllm.multimodal.processing import (
    BaseMultiModalProcessor,
    PromptReplacement,
    PromptUpdate,
    PromptUpdateDetails,
)
92
93
from vllm.multimodal.profiling import BaseDummyInputsBuilder
from vllm.sequence import IntermediateTensors
94
from vllm.utils.collection_utils import is_list_of
95

96
97
from .interfaces import (
    MultiModalEmbeddings,
98
    SupportsEagle3,
99
    SupportsLoRA,
100
    SupportsMRoPE,
101
    SupportsMultiModal,
102
    SupportsMultiModalPruning,
103
    SupportsPP,
104
    _require_is_multimodal,
105
106
107
108
109
110
111
112
113
114
)
from .qwen2_5_vl import (
    Qwen2_5_VisionAttention,
    Qwen2_5_VLImageEmbeddingInputs,
    Qwen2_5_VLImageInputs,
    Qwen2_5_VLImagePixelInputs,
    Qwen2_5_VLVideoEmbeddingInputs,
    Qwen2_5_VLVideoInputs,
    Qwen2_5_VLVideoPixelInputs,
)
115
from .qwen2_vl import Qwen2VLMultiModalDataParser, Qwen2VLProcessingInfo
116
from .qwen3 import Qwen3ForCausalLM, Qwen3Model
117
118
119
120
121
122
123
from .utils import (
    AutoWeightsLoader,
    PPMissingLayer,
    WeightsMapper,
    _merge_multimodal_embeddings,
    maybe_prefix,
)
124
125
126
127
from .vision import (
    get_vit_attn_backend,
    run_dp_sharded_mrope_vision_model,
)
128
129
130

logger = init_logger(__name__)

# Cap on frames per video, passed as ``max_frames_per_video`` when estimating
# the frame count that yields the most features. The value comes from the
# officially recommended max pixels (24576 * 32 * 32), i.e. 24576 units of
# 32x32 pixels.
_MAX_FRAMES_PER_VIDEO = 24576

134
135
136
137
138
139
140
141
142
143
144
145
146
147
148

class Qwen3_VisionPatchEmbed(nn.Module):
    """Project flattened pixel patches into the vision hidden space.

    Each input row holds one ``(temporal_patch_size, patch_size, patch_size)``
    patch with all of its channels flattened together. A 3D convolution whose
    kernel and stride both equal the patch shape maps every patch to exactly
    one ``hidden_size`` vector.
    """

    def __init__(
        self,
        patch_size: int = 14,
        temporal_patch_size: int = 2,
        in_channels: int = 3,
        hidden_size: int = 1152,
    ) -> None:
        super().__init__()
        self.patch_size = patch_size
        self.temporal_patch_size = temporal_patch_size
        self.hidden_size = hidden_size

        # kernel == stride: patches are covered once with no overlap, so the
        # convolution produces a single output position per patch.
        kernel_size = (temporal_patch_size, patch_size, patch_size)
        self.proj = Conv3dLayer(
            in_channels,
            hidden_size,
            kernel_size=kernel_size,
            stride=kernel_size,
            bias=True,
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Embed a 2-D tensor of flattened patches.

        Args:
            x: Tensor of shape ``(num_patches, flat_patch_dim)``.

        Returns:
            Tensor of shape ``(num_patches, hidden_size)``.
        """
        # Unpacking also enforces that ``x`` is 2-D.
        num_patches, _ = x.shape
        # Restore the per-patch (channels, T, H, W) layout; the channel count
        # is inferred from the row length. Presumably Conv3dLayer expects
        # (N, C, T, H, W) like nn.Conv3d -- TODO confirm.
        x = x.view(
            num_patches,
            -1,
            self.temporal_patch_size,
            self.patch_size,
            self.patch_size,
        )
        return self.proj(x).view(num_patches, self.hidden_size)


class Qwen3_VisionMLP(nn.Module):
    """Two-layer feed-forward network used inside each vision block.

    Computes ``linear_fc2(act_fn(linear_fc1(x)))``, with ``linear_fc1`` a
    column-parallel layer and ``linear_fc2`` a row-parallel layer. When the
    multimodal encoder runs in data-parallel mode
    (``mm_encoder_tp_mode == "data"``), tensor parallelism is disabled on both.
    """

    def __init__(
        self,
        in_features: int,
        hidden_features: int,
        bias: bool = False,
        act_fn: Callable[[torch.Tensor], torch.Tensor] = F.silu,
        quant_config: QuantizationConfig | None = None,
        multimodal_config: MultiModalConfig | None = None,
        prefix: str = "",
    ):
        super().__init__()
        # With a data-parallel vision encoder every rank keeps full weights.
        use_data_parallel = (
            multimodal_config.mm_encoder_tp_mode == "data"
            if multimodal_config
            else False
        )
        self.linear_fc1 = ColumnParallelLinear(
            in_features,
            hidden_features,
            bias=bias,
            quant_config=quant_config,
            return_bias=False,
            prefix=f"{prefix}.linear_fc1",
            disable_tp=use_data_parallel,
        )
        self.linear_fc2 = RowParallelLinear(
            hidden_features,
            in_features,
            bias=bias,
            quant_config=quant_config,
            return_bias=False,
            prefix=f"{prefix}.linear_fc2",
            disable_tp=use_data_parallel,
        )
        self.act_fn = act_fn

    def forward(self, x: torch.Tensor):
        """Apply fc1 -> activation -> fc2 (layers built with return_bias=False)."""
        mlp_output = self.linear_fc2(self.act_fn(self.linear_fc1(x)))
        return mlp_output


class Qwen3_VisionBlock(nn.Module):
    """One pre-norm transformer layer of the Qwen3-VL vision encoder.

    Computes ``x = x + attn(norm1(x))`` then ``x = x + mlp(norm2(x))``.
    Attention reuses :class:`Qwen2_5_VisionAttention`; the MLP is
    :class:`Qwen3_VisionMLP` with biases enabled.
    """

    def __init__(
        self,
        dim: int,
        num_heads: int,
        mlp_hidden_dim: int,
        act_fn: Callable[[torch.Tensor], torch.Tensor] = F.silu,
        norm_layer: Callable[[int], nn.Module] | None = None,
        multimodal_config: MultiModalConfig | None = None,
        quant_config: QuantizationConfig | None = None,
        prefix: str = "",
    ) -> None:
        super().__init__()
        if norm_layer is None:
            norm_layer = partial(nn.LayerNorm, eps=1e-6)
        self.norm1 = norm_layer(dim)
        self.norm2 = norm_layer(dim)
        self.attn = Qwen2_5_VisionAttention(
            embed_dim=dim,
            num_heads=num_heads,
            projection_size=dim,
            quant_config=quant_config,
            multimodal_config=multimodal_config,
            prefix=f"{prefix}.attn",
        )
        self.mlp = Qwen3_VisionMLP(
            dim,
            mlp_hidden_dim,
            act_fn=act_fn,
            bias=True,
            quant_config=quant_config,
            multimodal_config=multimodal_config,
            prefix=f"{prefix}.mlp",
        )

    def forward(
        self,
        x: torch.Tensor,
        cu_seqlens: torch.Tensor,
        rotary_pos_emb_cos: torch.Tensor,
        rotary_pos_emb_sin: torch.Tensor,
        max_seqlen: torch.Tensor,  # Only used for Flash Attention
    ) -> torch.Tensor:
        """Run attention and MLP sub-layers, each with a residual connection.

        Args:
            x: Hidden states.
            cu_seqlens: Cumulative boundaries of independent attention
                segments (the vision transformer in this file builds one
                segment per frame).
            rotary_pos_emb_cos: Rotary cosine table for every token.
            rotary_pos_emb_sin: Rotary sine table for every token.
            max_seqlen: Longest segment length; only meaningful for
                flash-attention style backends.
        """
        x = x + self.attn(
            self.norm1(x),
            cu_seqlens=cu_seqlens,
            rotary_pos_emb_cos=rotary_pos_emb_cos,
            rotary_pos_emb_sin=rotary_pos_emb_sin,
            max_seqlen=max_seqlen,
        )
        x = x + self.mlp(self.norm2(x))
        return x


class Qwen3_VisionPatchMerger(nn.Module):
    """Merge ``spatial_merge_size**2`` neighbouring tokens into one output token.

    Groups of ``spatial_merge_size**2`` consecutive rows are concatenated into
    a ``context_dim * spatial_merge_size**2`` vector, then passed through
    ``linear_fc1 -> GELU -> linear_fc2`` to produce a ``d_model`` output.

    With ``use_postshuffle_norm`` the norm is applied to the concatenated
    vector; otherwise each token is normalised before concatenation.
    """

    def __init__(
        self,
        d_model: int,
        context_dim: int,
        norm_layer: Callable[[int], nn.Module] | None = None,
        spatial_merge_size: int = 2,
        use_postshuffle_norm: bool = False,
        quant_config: QuantizationConfig | None = None,
        multimodal_config: MultiModalConfig | None = None,
        prefix: str = "",
    ) -> None:
        super().__init__()
        # With a data-parallel vision encoder every rank keeps full weights.
        use_data_parallel = (
            multimodal_config.mm_encoder_tp_mode == "data"
            if multimodal_config
            else False
        )
        self.hidden_size = context_dim * (spatial_merge_size**2)

        self.use_postshuffle_norm = use_postshuffle_norm
        if self.use_postshuffle_norm:
            # The norm then runs on the concatenated (merged) features.
            context_dim = self.hidden_size

        if norm_layer is None:
            norm_layer = partial(nn.LayerNorm, eps=1e-6)
        self.norm = norm_layer(context_dim)
        self.linear_fc1 = ColumnParallelLinear(
            self.hidden_size,
            self.hidden_size,
            bias=True,
            quant_config=quant_config,
            prefix=f"{prefix}.linear_fc1",
            disable_tp=use_data_parallel,
        )
        self.act_fn = nn.GELU()
        self.linear_fc2 = RowParallelLinear(
            self.hidden_size,
            d_model,
            bias=True,
            quant_config=quant_config,
            prefix=f"{prefix}.linear_fc2",
            disable_tp=use_data_parallel,
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Merge tokens of ``x`` and project them to ``d_model``.

        Assumes the total number of tokens is a multiple of
        ``spatial_merge_size**2`` and that consecutive rows belong to the same
        merge window -- TODO confirm for callers outside this file.
        """
        if self.use_postshuffle_norm:
            x = self.norm(x.view(-1, self.hidden_size))
        else:
            x = self.norm(x).view(-1, self.hidden_size)

        # These layers were built without return_bias=False, so they return
        # an (output, bias) pair; the bias element is discarded.
        x_parallel, _ = self.linear_fc1(x)
        x_parallel = self.act_fn(x_parallel)
        out, _ = self.linear_fc2(x_parallel)
        return out


class Qwen3_VisionTransformer(nn.Module):
    """Qwen3-VL vision encoder.

    Pipeline:
      1. patch embedding (``Qwen3_VisionPatchEmbed``);
      2. addition of a learned absolute position embedding, bilinearly
         resampled to each image/video grid;
      3. ``vision_config.depth`` ``Qwen3_VisionBlock`` layers using 2D rotary
         embeddings;
      4. a final ``Qwen3_VisionPatchMerger`` plus one extra "deepstack" merger
         for every layer index in ``vision_config.deepstack_visual_indexes``.

    The forward output concatenates the final merger output with all
    deepstack features along the last dimension.
    """

    def __init__(
        self,
        vision_config: Qwen3VLVisionConfig,
        norm_eps: float = 1e-6,
        quant_config: QuantizationConfig | None = None,
        multimodal_config: MultiModalConfig | None = None,
        prefix: str = "",
    ) -> None:
        super().__init__()
        self.hidden_size = vision_config.hidden_size
        self.num_heads = vision_config.num_heads
        self.num_position_embeddings = vision_config.num_position_embeddings
        self.patch_size = vision_config.patch_size
        self.spatial_merge_size = vision_config.spatial_merge_size
        self.spatial_merge_unit = self.spatial_merge_size**2
        self.temporal_patch_size = vision_config.temporal_patch_size
        self.deepstack_visual_indexes = vision_config.deepstack_visual_indexes
        # The position-embedding table is treated as a square grid.
        self.num_grid_per_side = int(self.num_position_embeddings**0.5)

        # NOTE: This is used for creating empty tensor for all_gather for
        # DP ViT. Here out_hidden_size is enlarged due to deepstack
        self.out_hidden_size = vision_config.out_hidden_size * (
            1 + len(self.deepstack_visual_indexes)
        )

        self.patch_embed = Qwen3_VisionPatchEmbed(
            patch_size=self.patch_size,
            temporal_patch_size=self.temporal_patch_size,
            in_channels=vision_config.in_channels,
            hidden_size=self.hidden_size,
        )

        self.pos_embed = nn.Embedding(self.num_position_embeddings, self.hidden_size)

        norm_layer = partial(nn.LayerNorm, eps=norm_eps)
        head_dim = self.hidden_size // self.num_heads
        self.rotary_pos_emb = get_rope(
            head_size=head_dim,
            max_position=8192,
            is_neox_style=True,
            rope_parameters={"partial_rotary_factor": 0.5},
        )

        self.merger = Qwen3_VisionPatchMerger(
            d_model=vision_config.out_hidden_size,
            context_dim=self.hidden_size,
            norm_layer=norm_layer,
            spatial_merge_size=self.spatial_merge_size,
            quant_config=quant_config,
            multimodal_config=multimodal_config,
            prefix=f"{prefix}.merger",
        )

        self.deepstack_merger_list = nn.ModuleList(
            [
                Qwen3_VisionPatchMerger(
                    d_model=vision_config.out_hidden_size,
                    context_dim=self.hidden_size,
                    spatial_merge_size=self.spatial_merge_size,
                    use_postshuffle_norm=True,
                    norm_layer=norm_layer,
                    quant_config=quant_config,
                    multimodal_config=multimodal_config,
                    prefix=f"{prefix}.deepstack_merger_list.{layer_idx}",
                )
                for layer_idx in range(len(self.deepstack_visual_indexes))
            ]
        )

        attn_backend_override = (
            multimodal_config.mm_encoder_attn_backend if multimodal_config else None
        )
        self.attn_backend = get_vit_attn_backend(
            head_size=head_dim,
            dtype=torch.get_default_dtype(),
            attn_backend_override=attn_backend_override,
        )

        if self.attn_backend not in {
            AttentionBackendEnum.FLASH_ATTN,
            AttentionBackendEnum.TORCH_SDPA,
            AttentionBackendEnum.ROCM_AITER_FA,
        }:
            raise RuntimeError(
                f"Qwen3-VL does not support {self.attn_backend} backend now."
            )
        self.blocks = nn.ModuleList(
            [
                Qwen3_VisionBlock(
                    dim=self.hidden_size,
                    num_heads=self.num_heads,
                    mlp_hidden_dim=vision_config.intermediate_size,
                    act_fn=_ACTIVATION_REGISTRY[vision_config.hidden_act],
                    norm_layer=norm_layer,
                    quant_config=quant_config,
                    multimodal_config=multimodal_config,
                    prefix=f"{prefix}.blocks.{layer_idx}",
                )
                for layer_idx in range(vision_config.depth)
            ]
        )

    @property
    def dtype(self) -> torch.dtype:
        """Dtype of the patch-embedding weights (used for activations)."""
        return self.patch_embed.proj.weight.dtype

    @property
    def device(self) -> torch.device:
        """Device of the patch-embedding weights."""
        return self.patch_embed.proj.weight.device

    @staticmethod
    @lru_cache(maxsize=1024)
    def rot_pos_ids(h: int, w: int, spatial_merge_size: int) -> torch.Tensor:
        """Return the (row, col) position of every patch in an ``h x w`` grid.

        The result has shape ``(h * w, 2)`` and is ordered so each
        ``spatial_merge_size x spatial_merge_size`` window is contiguous,
        matching the token order the patch merger groups. ``h`` and ``w``
        must be divisible by ``spatial_merge_size``.

        NOTE: results are cached and shared between calls; callers must not
        modify the returned tensor in place.
        """
        hpos_ids = np.broadcast_to(np.arange(h).reshape(h, 1), (h, w))
        h_div = h // spatial_merge_size
        w_div = w // spatial_merge_size
        hpos_ids = hpos_ids.reshape(
            h_div,
            spatial_merge_size,
            w_div,
            spatial_merge_size,
        )
        hpos_ids = hpos_ids.transpose(0, 2, 1, 3)
        hpos_ids = hpos_ids.flatten()

        wpos_ids = np.broadcast_to(np.arange(w).reshape(1, w), (h, w))
        wpos_ids = wpos_ids.reshape(
            h_div,
            spatial_merge_size,
            w_div,
            spatial_merge_size,
        )
        wpos_ids = wpos_ids.transpose(0, 2, 1, 3)
        wpos_ids = wpos_ids.flatten()

        return torch.from_numpy(np.stack([hpos_ids, wpos_ids], axis=-1))

    def rot_pos_emb(self, grid_thw: list[list[int]]):
        """Build per-token rotary cos/sin tables for all ``(t, h, w)`` grids.

        Every frame of a grid reuses the same 2D positions. The row and column
        cos/sin values are looked up and flattened together per token.
        """
        max_grid_size = max(max(h, w) for _, h, w in grid_thw)
        pos_ids = [
            self.rot_pos_ids(h, w, self.spatial_merge_size)
            if t == 1
            else self.rot_pos_ids(h, w, self.spatial_merge_size).repeat(t, 1)
            for t, h, w in grid_thw
        ]
        pos_ids = torch.cat(pos_ids, dim=0).to(self.device, non_blocking=True)

        # Use pre-computed cos_sin_cache from RotaryEmbedding
        cos, sin = self.rotary_pos_emb.get_cos_sin(max_grid_size)

        cos_combined = cos[pos_ids].flatten(1)
        sin_combined = sin[pos_ids].flatten(1)

        return cos_combined, sin_combined

    def fast_pos_embed_interpolate(self, grid_thw: list[list[int]]) -> torch.Tensor:
        """Bilinearly resample the learned position embeddings to each grid.

        The ``num_grid_per_side x num_grid_per_side`` embedding table is
        sampled at ``h x w`` evenly spaced points, reordered into merge-window
        order (same order as :meth:`rot_pos_ids`) and repeated for all ``t``
        frames. Results for every grid are concatenated along dim 0.
        """
        num_grid_per_side = self.num_grid_per_side
        m_size = self.spatial_merge_size
        hidden_dim = self.pos_embed.embedding_dim

        outputs = []
        for t, h, w in grid_thw:
            # Fractional sampling coordinates into the embedding grid.
            h_idxs = torch.linspace(
                0, num_grid_per_side - 1, h, dtype=torch.float32, device=self.device
            )
            w_idxs = torch.linspace(
                0, num_grid_per_side - 1, w, dtype=torch.float32, device=self.device
            )

            h_floor = h_idxs.to(torch.long)
            w_floor = w_idxs.to(torch.long)
            h_ceil = torch.clamp(h_floor + 1, max=num_grid_per_side - 1)
            w_ceil = torch.clamp(w_floor + 1, max=num_grid_per_side - 1)

            dh = h_idxs - h_floor
            dw = w_idxs - w_floor

            # Create meshgrid view for all h, w vars
            dh_grid, dw_grid = torch.meshgrid(dh, dw, indexing="ij")
            h_floor_grid, w_floor_grid = torch.meshgrid(h_floor, w_floor, indexing="ij")
            h_ceil_grid, w_ceil_grid = torch.meshgrid(h_ceil, w_ceil, indexing="ij")

            # Bilinear weights. Reference form:
            #   w00 = (1 - dh) * (1 - dw)    w01 = (1 - dh) * dw
            #   w10 = dh * (1 - dw)          w11 = dh * dw
            # w11 is computed once and reused to derive the others.
            w11 = dh_grid * dw_grid
            w10 = dh_grid - w11
            w01 = dw_grid - w11
            w00 = 1 - dh_grid - w01

            # Corner order must match the weight order: 00, 01, 10, 11.
            h_grid = torch.stack([h_floor_grid, h_floor_grid, h_ceil_grid, h_ceil_grid])
            w_grid = torch.stack([w_floor_grid, w_ceil_grid, w_floor_grid, w_ceil_grid])
            h_grid_idx = h_grid * num_grid_per_side

            indices = (h_grid_idx + w_grid).reshape(4, -1)
            weights = torch.stack([w00, w01, w10, w11], dim=0).reshape(4, -1, 1)
            weights = weights.to(dtype=self.dtype)

            embeds = self.pos_embed(indices)
            embeds *= weights
            combined = embeds.sum(dim=0)

            # Reorder row-major (h, w) into merge-window order.
            combined = combined.reshape(
                h // m_size, m_size, w // m_size, m_size, hidden_dim
            )
            combined = combined.permute(0, 2, 1, 3, 4).reshape(1, -1, hidden_dim)
            repeated = combined.expand(t, -1, -1).reshape(-1, hidden_dim)
            outputs.append(repeated)

        return torch.cat(outputs, dim=0)

    def compute_attn_mask_seqlen(
        self,
        cu_seqlens: torch.Tensor,
    ) -> torch.Tensor:
        """Return the longest segment length for flash-attention backends.

        For other backends a zero scalar tensor is returned instead.
        """
        max_seqlen = torch.zeros([], device=cu_seqlens.device)
        if (
            self.attn_backend == AttentionBackendEnum.FLASH_ATTN
            or self.attn_backend == AttentionBackendEnum.ROCM_AITER_FA
        ):
            max_seqlen = (cu_seqlens[1:] - cu_seqlens[:-1]).max()
        return max_seqlen

    def forward(
        self,
        x: torch.Tensor,
        grid_thw: torch.Tensor | list[list[int]],
    ) -> torch.Tensor:
        """Encode flattened patches into merged visual tokens.

        Args:
            x: Flattened pixel patches, one row per patch.
            grid_thw: One ``(t, h, w)`` patch-grid entry per image/video, as a
                list or an integer tensor (any device).

        Returns:
            Tensor of shape
            ``[seq_len, out_hidden_size * (1 + len(deepstack_visual_indexes))]``.
        """
        hidden_states = x.to(device=self.device, dtype=self.dtype, non_blocking=True)
        hidden_states = self.patch_embed(hidden_states)

        if isinstance(grid_thw, list):
            grid_thw_list = grid_thw
            grid_thw = np.array(grid_thw, dtype=np.int32)
        else:
            grid_thw_list = grid_thw.tolist()
            # Tensor.numpy() only works on CPU tensors; .cpu() is a no-op there.
            grid_thw = grid_thw.cpu().numpy()

        pos_embeds = self.fast_pos_embed_interpolate(grid_thw_list)
        hidden_states = hidden_states + pos_embeds
        rotary_pos_emb_cos, rotary_pos_emb_sin = self.rot_pos_emb(grid_thw_list)

        # One attention segment of h * w tokens per frame.
        cu_seqlens = np.repeat(grid_thw[:, 1] * grid_thw[:, 2], grid_thw[:, 0]).cumsum(
            axis=0, dtype=np.int32
        )
        cu_seqlens = np.concatenate([np.zeros(1, dtype=np.int32), cu_seqlens])
        cu_seqlens = torch.from_numpy(cu_seqlens)

        hidden_states = hidden_states.unsqueeze(1)
        max_seqlen = self.compute_attn_mask_seqlen(cu_seqlens)
        cu_seqlens = cu_seqlens.to(self.device, non_blocking=True)

        deepstack_feature_lists = []
        for layer_num, blk in enumerate(self.blocks):
            hidden_states = blk(
                hidden_states,
                cu_seqlens=cu_seqlens,
                rotary_pos_emb_cos=rotary_pos_emb_cos,
                rotary_pos_emb_sin=rotary_pos_emb_sin,
                max_seqlen=max_seqlen,
            )
            if layer_num in self.deepstack_visual_indexes:
                deepstack_merger_idx = self.deepstack_visual_indexes.index(layer_num)
                deepstack_feature = self.deepstack_merger_list[deepstack_merger_idx](
                    hidden_states
                )
                deepstack_feature_lists.append(deepstack_feature)
        hidden_states = self.merger(hidden_states)
        hidden_states = torch.cat(
            [hidden_states] + deepstack_feature_lists, dim=1
        )  # [seq_len, hidden_size * (1 + depth_of_deepstack)]
        return hidden_states

    def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]:
        """Load checkpoint weights, fusing separate q/k/v into ``attn.qkv``.

        Returns the set of parameter names that were loaded. Names that do
        not match a parameter raise ``KeyError``.
        """
        stacked_params_mapping = [
            # (param_name, shard_name, shard_id)
            ("attn.qkv.", "attn.q.", "q"),
            ("attn.qkv.", "attn.k.", "k"),
            ("attn.qkv.", "attn.v.", "v"),
        ]
        params_dict = dict(self.named_parameters(remove_duplicate=False))
        loaded_params: set[str] = set()

        for name, loaded_weight in weights:
            for param_name, weight_name, shard_id in stacked_params_mapping:
                if weight_name not in name:
                    continue
                name = name.replace(weight_name, param_name)

                param = params_dict[name]
                weight_loader = param.weight_loader
                weight_loader(param, loaded_weight, shard_id)
                break
            else:
                param = params_dict[name]
                weight_loader = getattr(param, "weight_loader", default_weight_loader)
                weight_loader(param, loaded_weight)
            loaded_params.add(name)
        return loaded_params


class Qwen3VLProcessingInfo(Qwen2VLProcessingInfo):
    """Processing metadata for Qwen3-VL.

    Reuses the Qwen2-VL logic, but:
      - resolves the Qwen3-VL HF config and processors;
      - uses the Qwen3-VL video ``smart_resize``, which also receives the
        frame count;
      - caps the profiled frame count at ``_MAX_FRAMES_PER_VIDEO``.
    """

    def get_hf_config(self):
        return self.ctx.get_hf_config(Qwen3VLConfig)

    def get_hf_processor(self, **kwargs: object) -> Qwen3VLProcessor:
        # Default to the fast processor unless the caller asks otherwise.
        return self.ctx.get_hf_processor(
            Qwen3VLProcessor,
            use_fast=kwargs.pop("use_fast", True),
            **kwargs,
        )

    def get_image_processor(self, **kwargs: object) -> Qwen2VLImageProcessorFast:
        return self.get_hf_processor(**kwargs).image_processor

    def get_video_processor(self, **kwargs: object) -> Qwen3VLVideoProcessor:
        return self.get_hf_processor(**kwargs).video_processor

    def _get_vision_info(
        self,
        *,
        image_width: int,
        image_height: int,
        num_frames: int = 2,
        do_resize: bool = True,
        image_processor: Qwen2VLImageProcessorFast | Qwen3VLVideoProcessor | None,
    ) -> tuple[ImageSize, int]:
        """Compute the preprocessed size and vision-token count for an input.

        Args:
            image_width: Original width in pixels.
            image_height: Original height in pixels.
            num_frames: Number of frames.
            do_resize: Whether to apply ``smart_resize`` first.
            image_processor: Processor whose ``size`` bounds are used. If
                ``None``, the video processor is used for ``num_frames > 1``
                and the image processor otherwise.

        Returns:
            ``(preprocessed_size, num_vision_tokens)``.
        """
        if image_processor is None:
            image_processor = (
                self.get_video_processor()
                if num_frames > 1
                else self.get_image_processor()
            )

        is_video = isinstance(image_processor, Qwen3VLVideoProcessor)

        hf_config = self.get_hf_config()
        vision_config = hf_config.vision_config
        patch_size = vision_config.patch_size
        merge_size = vision_config.spatial_merge_size
        temporal_patch_size = vision_config.temporal_patch_size

        if do_resize:
            if is_video:
                smart_resize = video_smart_resize
                extra_kwargs = {
                    "num_frames": num_frames,
                    "temporal_factor": temporal_patch_size,
                }
            else:
                smart_resize = image_smart_resize
                extra_kwargs = {}
            resized_height, resized_width = smart_resize(
                height=image_height,
                width=image_width,
                factor=patch_size * merge_size,
                min_pixels=image_processor.size["shortest_edge"],
                max_pixels=image_processor.size["longest_edge"],
                **extra_kwargs,
            )
            preprocessed_size = ImageSize(width=resized_width, height=resized_height)
        else:
            preprocessed_size = ImageSize(width=image_width, height=image_height)

        # Round the frame count up to a whole number of temporal patches.
        # (num_frames % tps only does this for tps <= 2.)
        padded_num_frames = num_frames + (-num_frames) % temporal_patch_size

        grid_t = max(padded_num_frames // temporal_patch_size, 1)
        grid_h = preprocessed_size.height // patch_size
        grid_w = preprocessed_size.width // patch_size

        num_patches = grid_t * grid_h * grid_w
        num_vision_tokens = num_patches // (merge_size**2)

        return preprocessed_size, num_vision_tokens

    def _get_max_video_frames(self, max_tokens: int, start_num_frames: int = 2) -> int:
        # Same as the parent; only the default start_num_frames is pinned to 2.
        return super()._get_max_video_frames(
            max_tokens, start_num_frames=start_num_frames
        )

    def get_num_frames_with_most_features(
        self,
        seq_len: int,
        mm_counts: Mapping[str, int],
    ) -> int:
        """Parent computation, capped at ``_MAX_FRAMES_PER_VIDEO`` frames."""
        return super().get_num_frames_with_most_features(
            seq_len, mm_counts, max_frames_per_video=_MAX_FRAMES_PER_VIDEO
        )

    def get_max_video_tokens(
        self,
        seq_len: int,
        mm_counts: Mapping[str, int],
    ) -> int:
        """Token count of the largest image size at the most-features frame count."""
        target_width, target_height = self.get_image_size_with_most_features()
        num_video_soft_tokens = self.get_num_video_tokens(
            image_width=target_width,
            image_height=target_height,
            num_frames=self.get_num_frames_with_most_features(seq_len, mm_counts),
            image_processor=None,
        )
        return num_video_soft_tokens

    def _calculate_timestamps(
        self, indices: list[int] | torch.Tensor, video_fps: float, merge_size: int
    ):
        """Return one timestamp (seconds) per group of ``merge_size`` frames.

        The last index is repeated to complete a partial final group. Each
        timestamp is the mean of the first and last frame time in its group.
        """
        if not isinstance(indices, list):
            indices = indices.tolist()
        if len(indices) % merge_size != 0:
            # don't update metadata's frames_indices directly
            indices = indices + [indices[-1]] * (merge_size - len(indices) % merge_size)
        timestamps = [idx / video_fps for idx in indices]
        timestamps = [
            (timestamps[i] + timestamps[i + merge_size - 1]) / 2
            for i in range(0, len(timestamps), merge_size)
        ]
        return timestamps

    def _get_video_second_idx(
        self,
        metadata: dict[str, Any],
        out_item: MultiModalKwargsItem,
        do_sample_frames: bool | None = None,
        sampled_fps: float | None = None,
    ) -> list[float]:
        """Return per-temporal-group timestamps (in seconds) for a video.

        Reads ``frames_indices``, ``fps`` and, when frames are sampled in the
        HF processor, ``total_num_frames`` from ``metadata``. ``out_item`` is
        accepted for interface compatibility and not used here.
        """
        video_processor = self.get_video_processor()
        merge_size = video_processor.merge_size
        indices = metadata["frames_indices"]

        # metadata["fps"] refers to the true fps of the input video.
        video_fps = metadata["fps"]
        if do_sample_frames is None:
            do_sample_frames = metadata.get("do_sample_frames", False)

        # If video frames are sampled in the HF processor (instead of the vLLM
        # video loader), recompute the sampled indices from the original
        # metadata so timestamps match the frames actually kept.
        if do_sample_frames:
            # The recomputed indices address frames of the original video, so
            # timestamps still divide by its fps (video_fps); sampled_fps only
            # decides how many frames are kept.
            sampled_fps = sampled_fps if sampled_fps else video_processor.fps
            total_num_frames = metadata["total_num_frames"]
            num_frames = int(total_num_frames / metadata["fps"] * sampled_fps)
            num_frames = min(
                min(
                    max(num_frames, video_processor.min_frames),
                    video_processor.max_frames,
                ),
                total_num_frames,
            )
            indices = (
                np.linspace(0, total_num_frames - 1, num_frames)
                .round()
                .astype(int)
                .tolist()
            )
        timestamps = self._calculate_timestamps(indices, video_fps, merge_size)
        return timestamps


class Qwen3VLDummyInputsBuilder(BaseDummyInputsBuilder[Qwen3VLProcessingInfo]):
    """Builds dummy prompt text and multimodal data items for Qwen3-VL."""

    def get_dummy_text(self, mm_counts: Mapping[str, int]) -> str:
        """Return one vision placeholder per requested image, then per video."""
        num_images = mm_counts.get("image", 0)
        num_videos = mm_counts.get("video", 0)

        image_token = "<|vision_start|><|image_pad|><|vision_end|>"
        video_token = "<|vision_start|><|video_pad|><|vision_end|>"

        return image_token * num_images + video_token * num_videos

    def get_dummy_mm_data(
        self,
        seq_len: int,
        mm_counts: Mapping[str, int],
        mm_options: Mapping[str, BaseDummyOptions] | None = None,
    ) -> MultiModalDataDict:
        """Build dummy images and videos for the requested counts.

        Images use the size with the most features. Videos use the frame
        count with the most features for ``seq_len`` (never fewer than 2),
        then the frame size ``_get_vision_info`` derives for that count.
        User overrides can only shrink these values; overrides above the
        model maximum are ignored with a warning.
        """
        num_images = mm_counts.get("image", 0)
        num_videos = mm_counts.get("video", 0)

        image_overrides = mm_options.get("image") if mm_options else None
        video_overrides = mm_options.get("video") if mm_options else None

        target_width, target_height = self.info.get_image_size_with_most_features()
        target_num_frames = self.info.get_num_frames_with_most_features(
            seq_len, mm_counts
        )

        if video_overrides:
            assert isinstance(video_overrides, VideoDummyOptions)
            num_frames_override = video_overrides.num_frames
            if num_frames_override:
                if num_frames_override > target_num_frames:
                    logger.warning(
                        "video.num_frames override (%d) exceeds model's "
                        "maximum number of frames (%d), will be ignored",
                        num_frames_override,
                        target_num_frames,
                    )
                if num_frames_override < 2:
                    logger.warning(
                        "video.num_frames override (%d) cannot be less "
                        "than 2, will be ignored",
                        num_frames_override,
                    )
                target_num_frames = min(target_num_frames, num_frames_override)
        target_num_frames = max(target_num_frames, 2)

        target_video_size, _ = self.info._get_vision_info(
            image_width=target_width,
            image_height=target_height,
            num_frames=target_num_frames,
            image_processor=self.info.get_video_processor(),
        )
        # NOTE: we need to do this check here since Qwen3-VL resizes video
        # frames depending on how many frames there are.
        width, height = target_video_size.width, target_video_size.height
        if video_overrides:
            assert isinstance(video_overrides, VideoDummyOptions)
            width_override = video_overrides.width
            if width_override:
                if width_override > width:
                    logger.warning(
                        "video.width override (%d) exceeds model's "
                        "maximum width (%d), will be ignored",
                        width_override,
                        width,
                    )
                width = min(width, width_override)
            height_override = video_overrides.height
            if height_override:
                if height_override > height:
                    logger.warning(
                        "video.height override (%d) exceeds model's "
                        "maximum height (%d), will be ignored",
                        height_override,
                        height,
                    )
                height = min(height, height_override)

        return {
            "image": self._get_dummy_images(
                width=target_width,
                height=target_height,
                num_images=num_images,
                overrides=image_overrides,
            ),
            "video": self._get_dummy_videos(
                width=width,
                height=height,
                num_frames=target_num_frames,
                num_videos=num_videos,
            ),
        }

    def _get_dummy_videos(
        self,
        *,
        width: int,
        height: int,
        num_frames: int,
        num_videos: int,
    ) -> list[VideoItem]:
        """Create ``num_videos`` all-white videos paired with metadata.

        Each item is a ``(frames, metadata)`` tuple. ``frames`` is a uint8
        array of shape ``(num_frames, height, width, 3)``. The metadata
        describes a 2 fps video whose frames are all pre-sampled
        (``do_sample_frames`` is False).
        """
        # Frames are laid out (T, H, W, C); the previous (T, W, H, C) layout
        # produced transposed frames whenever width != height.
        video = np.full((num_frames, height, width, 3), 255, dtype=np.uint8)
        video_items = []
        for _ in range(num_videos):
            video_metadata = {
                "fps": 2.0,
                "duration": num_frames / 2.0,
                "total_num_frames": num_frames,
                "frames_indices": list(range(num_frames)),
                "video_backend": "opencv",
                "do_sample_frames": False,
            }
            # Each item gets its own array copy.
            video_items.append((video.copy(), video_metadata))
        return video_items


898
class Qwen3VLMultiModalProcessor(BaseMultiModalProcessor[Qwen3VLProcessingInfo]):
    """Multimodal processor for Qwen3-VL images and videos.

    Videos are run through the HF processor one at a time, separately from
    images. The expanded placeholder for each video replaces that video's
    placeholder in the prompt.
    """

    def _get_data_parser(self) -> MultiModalDataParser:
        """Return a Qwen2-VL style parser that also keeps video metadata."""
        return Qwen2VLMultiModalDataParser(
            self.info.get_hf_config().vision_config.spatial_merge_size,
            video_needs_metadata=True,
        )

    def _call_hf_processor(
        self,
        prompt: str,
        mm_data: Mapping[str, object],
        mm_kwargs: Mapping[str, object],
        tok_kwargs: Mapping[str, object],
    ) -> BatchFeature:
        """Run the HF processor, handling each video on its own.

        For every video, a single-video prompt is processed, and the decoded
        expanded placeholder replaces the first remaining video placeholder
        in ``prompt``. The rewritten prompt and the images are then processed
        together, and the concatenated video tensors are merged into the
        result.
        """
        mm_data = dict(mm_data)
        processor = self.info.get_hf_processor(**mm_kwargs)

        # Separate video processing from image processing, because the
        # videos are processed into several image patches.
        if videos := mm_data.pop("videos", []):
            video_grid_thw_lst = []
            pixel_values_videos_lst = []

            for item in videos:
                video_array, metadata = item

                # NOTE: @JJJYmmm new attr metadata.frames_indices indicates
                # the sampled frames indices of pre-sampled videos, which is
                # used to calculate the timestamps. Make sure that
                # do_sample_frames in mm_kwargs is false for presampled videos.

                # NOTE: a copy of mm_kwargs is created to update
                # do_sample_frames, otherwise mm_hash for the object will be
                # incorrect.
                video_mm_kwargs = dict(**mm_kwargs)
                if "do_sample_frames" not in video_mm_kwargs:
                    # qwen_vl_utils already has "do_sample_frames" in
                    # mm_kwargs, don't overwrite it.
                    video_mm_kwargs["do_sample_frames"] = metadata.get(
                        "do_sample_frames", False
                    )

                metadata = VideoMetadata(
                    **{k: metadata[k] for k in metadata if k != "do_sample_frames"}
                )

                video_mm_data = dict()
                video_mm_data["videos"] = [[video_array]]
                video_mm_data["video_metadata"] = [[metadata]]

                video_outputs = super()._call_hf_processor(
                    prompt="<|vision_start|><|video_pad|><|vision_end|>",
                    mm_data=video_mm_data,
                    mm_kwargs=video_mm_kwargs,
                    tok_kwargs=tok_kwargs,
                )
                input_ids = video_outputs.pop("input_ids")
                video_placeholder = processor.tokenizer.batch_decode(input_ids)[0]
                prompt = prompt.replace(
                    "<|vision_start|><|video_pad|><|vision_end|>",
                    video_placeholder,
                    1,
                )

                video_grid_thw_lst.append(video_outputs["video_grid_thw"])
                pixel_values_videos_lst.append(video_outputs["pixel_values_videos"])
            video_outputs = dict(
                pixel_values_videos=torch.cat(pixel_values_videos_lst),
                video_grid_thw=torch.cat(video_grid_thw_lst),
            )
        else:
            video_outputs = dict()

        processed_outputs = super()._call_hf_processor(
            prompt=prompt,
            mm_data=mm_data,
            mm_kwargs=mm_kwargs,
            tok_kwargs=tok_kwargs,
        )
        combined_outputs = dict(
            processed_outputs,
            **video_outputs,
        )
        return BatchFeature(combined_outputs)

    def _get_mm_fields_config(
        self,
        hf_inputs: BatchFeature,
        hf_processor_mm_kwargs: Mapping[str, object],
    ) -> Mapping[str, MultiModalFieldConfig]:
        """Describe how flattened pixel/embedding tensors split per item.

        Each item owns ``prod(grid_thw)`` rows of the flattened tensors; the
        grid tensors themselves are batched per item and kept on CPU.
        """
        image_grid_thw = hf_inputs.get("image_grid_thw", torch.empty((0, 3)))
        image_grid_sizes = image_grid_thw.prod(-1)

        video_grid_thw = hf_inputs.get("video_grid_thw", torch.empty((0, 3)))
        video_grid_sizes = video_grid_thw.prod(-1)

        return dict(
            pixel_values=MultiModalFieldConfig.flat_from_sizes(
                "image", image_grid_sizes
            ),
            image_embeds=MultiModalFieldConfig.flat_from_sizes(
                "image", image_grid_sizes
            ),
            image_grid_thw=MultiModalFieldConfig.batched("image", keep_on_cpu=True),
            pixel_values_videos=MultiModalFieldConfig.flat_from_sizes(
                "video", video_grid_sizes
            ),
            video_embeds=MultiModalFieldConfig.flat_from_sizes(
                "video", video_grid_sizes
            ),
            video_grid_thw=MultiModalFieldConfig.batched("video", keep_on_cpu=True),
        )

    def _get_prompt_updates(
        self,
        mm_items: MultiModalDataItems,
        hf_processor_mm_kwargs: Mapping[str, Any],
        out_mm_kwargs: MultiModalKwargsItems,
    ) -> Sequence[PromptUpdate]:
        """Build replacements that expand image and video placeholders.

        Images expand to ``prod(grid_thw) // merge_size**2`` image tokens.
        Videos expand per frame into timestamp tokens followed by
        ``<vision_start>``, the frame's video tokens, and ``<vision_end>``.
        The per-frame token count is reduced when EVS video pruning is
        enabled.
        """
        hf_processor = self.info.get_hf_processor(**hf_processor_mm_kwargs)
        image_processor = self.info.get_image_processor(**hf_processor_mm_kwargs)
        tokenizer = self.info.get_tokenizer()
        hf_config = self.info.get_hf_config()

        video_token_id = hf_config.video_token_id
        vision_start_token_id = hf_config.vision_start_token_id
        vision_end_token_id = hf_config.vision_end_token_id

        merge_length = image_processor.merge_size**2

        def get_image_replacement_qwen3vl(item_idx: int):
            out_item = out_mm_kwargs["image"][item_idx]
            grid_thw = out_item["image_grid_thw"].data
            assert isinstance(grid_thw, torch.Tensor)

            num_tokens = int(grid_thw.prod()) // merge_length
            return [hf_processor.image_token_id] * num_tokens

        def get_video_replacement_qwen3vl(item_idx: int):
            out_item = out_mm_kwargs["video"][item_idx]
            grid_thw = out_item["video_grid_thw"].data
            assert isinstance(grid_thw, torch.Tensor)

            _, metadata = mm_items["video"][item_idx]
            do_sample_frames = hf_processor_mm_kwargs.get("do_sample_frames")
            sampled_fps = hf_processor_mm_kwargs.get("fps")
            if is_list_of(sampled_fps, float):
                sampled_fps = sampled_fps[item_idx]
            timestamps = self.info._get_video_second_idx(
                metadata, out_item, do_sample_frames, sampled_fps
            )

            assert len(timestamps) == grid_thw[0], (
                f"The timestamps length({len(timestamps)}) should be equal "
                f"video length ({grid_thw[0]})."
            )

            frames_idx_token = [
                tokenizer.encode(f"<{curr_time:.1f} seconds>", add_special_tokens=False)
                for curr_time in timestamps
            ]
            num_frames = len(frames_idx_token)
            tokens_per_frame = int(grid_thw[1:].prod()) // merge_length
            per_frame_token_counts = [tokens_per_frame] * num_frames

            # With EVS pruning, the first frame keeps all of its tokens and
            # the rest of the retained budget is spread as evenly as possible
            # over the remaining frames (earlier frames absorb the remainder).
            # With zero or one frame the default counts above already apply.
            video_pruning_rate = self.info.ctx.get_mm_config().video_pruning_rate
            if (
                video_pruning_rate is not None
                and video_pruning_rate > 0.0
                and num_frames > 1
            ):
                total_retained = compute_retained_tokens_count(
                    tokens_per_frame,
                    num_frames,
                    video_pruning_rate,
                )
                remaining_tokens = max(total_retained - tokens_per_frame, 0)
                base, remainder = divmod(remaining_tokens, num_frames - 1)
                per_frame_token_counts = [tokens_per_frame] + [
                    base + (1 if offset < remainder else 0)
                    for offset in range(num_frames - 1)
                ]

            placeholder = []
            for timestamp_tokens, tokens_this_frame in zip(
                frames_idx_token, per_frame_token_counts
            ):
                placeholder.extend(timestamp_tokens)
                placeholder.extend(
                    [vision_start_token_id]
                    + [video_token_id] * tokens_this_frame
                    + [vision_end_token_id]
                )
            return PromptUpdateDetails.select_token_id(placeholder, video_token_id)

        return [
            PromptReplacement(
                modality="image",
                target=hf_processor.image_token,
                replacement=get_image_replacement_qwen3vl,
            ),
            # NOTE: We match string on purpose since searching sequence of
            # token ids takes more time.
            PromptReplacement(
                modality="video",
                target="<|vision_start|><|video_pad|><|vision_end|>",
                replacement=get_video_replacement_qwen3vl,
            ),
        ]


@support_torch_compile(
    dynamic_arg_dims={
        "input_ids": 0,
        # positions is of shape (3, seq_len) if mrope is enabled for qwen2-vl,
        # otherwise (seq_len, ).
        "positions": -1,
        "intermediate_tensors": 0,
        "inputs_embeds": 0,
        # the same shape as inputs_embeds
        "deepstack_input_embeds": 0,
    }
)
class Qwen3LLMModel(Qwen3Model):
    """Qwen3 decoder that injects Qwen3-VL deepstack visual embeddings.

    The deepstack embedding for level ``i`` is added to the output of
    decoder layer ``i``.
    """

    def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
        super().__init__(vllm_config=vllm_config, prefix=prefix)
        if not get_pp_group().is_first_rank:
            # Deepstack embeddings are only added to the first
            # len(deepstack_visual_indexes) layers, so all of those layers
            # must live on the first pipeline rank.
            assert self.start_layer >= len(
                vllm_config.model_config.hf_config.vision_config.deepstack_visual_indexes
            ), (
                "start_layer should be greater than or equal to "
                "len(deepstack_visual_indexes)"
            )

    def forward(
        self,
        input_ids: torch.Tensor,
        positions: torch.Tensor,
        intermediate_tensors: IntermediateTensors | None = None,
        inputs_embeds: torch.Tensor | None = None,
        # args for deepstack
        deepstack_input_embeds: IntermediateTensors | None = None,
    ) -> (
        torch.Tensor
        | IntermediateTensors
        | tuple[torch.Tensor, list[torch.Tensor]]
    ):
        """Run the decoder layers assigned to this pipeline rank.

        Args:
            input_ids: Token ids; embedded on the first rank when
                ``inputs_embeds`` is None.
            positions: Position ids passed to every layer.
            intermediate_tensors: ``hidden_states``/``residual`` from the
                previous pipeline rank; required on non-first ranks.
            inputs_embeds: Precomputed input embeddings (first rank only).
            deepstack_input_embeds: Per-level visual embeddings keyed
                ``deepstack_input_embeds_{i}``; level ``i`` is added to the
                output of layer ``i``.

        Returns:
            ``IntermediateTensors`` on non-last ranks. Otherwise the
            normalized hidden states, paired with the collected aux hidden
            states when any configured aux layer ran on this rank.
        """
        if get_pp_group().is_first_rank:
            if inputs_embeds is not None:
                hidden_states = inputs_embeds
            else:
                hidden_states = self.embed_input_ids(input_ids)
            residual = None
        else:
            assert intermediate_tensors is not None
            hidden_states = intermediate_tensors["hidden_states"]
            residual = intermediate_tensors["residual"]

        num_deepstack_levels = (
            0 if deepstack_input_embeds is None else len(deepstack_input_embeds)
        )

        aux_hidden_states = []
        for layer_idx, layer in islice(
            enumerate(self.layers), self.start_layer, self.end_layer
        ):
            if layer_idx in self.aux_hidden_state_layers:
                # residual is still None before the first layer on the first
                # rank; adding None would raise a TypeError.
                aux_hidden_states.append(
                    hidden_states if residual is None else hidden_states + residual
                )

            hidden_states, residual = layer(
                positions,
                hidden_states,
                residual,
            )

            if layer_idx < num_deepstack_levels:
                hidden_states = (
                    hidden_states
                    + deepstack_input_embeds[f"deepstack_input_embeds_{layer_idx}"]
                )

        if not get_pp_group().is_last_rank:
            return IntermediateTensors(
                {"hidden_states": hidden_states, "residual": residual}
            )
        hidden_states, _ = self.norm(hidden_states, residual)

        if len(aux_hidden_states) > 0:
            return hidden_states, aux_hidden_states
        return hidden_states


class Qwen3LLMForCausalLM(Qwen3ForCausalLM):
    """Text backbone of Qwen3-VL: a deepstack-aware Qwen3 causal LM.

    The configuration is ``hf_config.text_config`` of the multimodal model
    config, and the decoder is a ``Qwen3LLMModel``.
    """

    def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
        # Deliberately skip Qwen3ForCausalLM.__init__: only the next
        # initializer in the MRO runs, and every submodule is built below.
        super(Qwen3ForCausalLM, self).__init__()
        text_config = vllm_config.model_config.hf_config.text_config
        quant_config = vllm_config.quant_config

        self.config = text_config
        self.quant_config = quant_config

        self.model = Qwen3LLMModel(
            vllm_config=vllm_config, prefix=maybe_prefix(prefix, "model")
        )

        if not get_pp_group().is_last_rank:
            # The output projection only exists on the last pipeline rank.
            self.lm_head = PPMissingLayer()
        elif text_config.tie_word_embeddings:
            # Reuse the input embedding module as the output projection.
            self.lm_head = self.model.embed_tokens
        else:
            self.lm_head = ParallelLMHead(
                text_config.vocab_size,
                text_config.hidden_size,
                quant_config=quant_config,
                prefix="lm_head",
            )

        self.logits_processor = LogitsProcessor(text_config.vocab_size)

        self.make_empty_intermediate_tensors = (
            self.model.make_empty_intermediate_tensors
        )


1219
1220
1221
1222
1223
1224
@MULTIMODAL_REGISTRY.register_processor(
    Qwen3VLMultiModalProcessor,
    info=Qwen3VLProcessingInfo,
    dummy_inputs=Qwen3VLDummyInputsBuilder,
)
class Qwen3VLForConditionalGeneration(
    nn.Module,
    SupportsMultiModal,
    SupportsLoRA,
    SupportsPP,
    SupportsMRoPE,
    SupportsEagle3,
    SupportsMultiModalPruning,
):
    # Fused module name -> checkpoint sub-module names packed into it.
    packed_modules_mapping = {
        "qkv_proj": ["q_proj", "k_proj", "v_proj"],
        "gate_up_proj": ["gate_proj", "up_proj"],
        "qkv": ["qkv"],  # For vision tower's already-packed QKV
    }

    supports_encoder_tp_data = True

    # To ensure correct weight loading and mapping: HF checkpoint prefixes
    # are rewritten to this module's attribute layout.
    hf_to_vllm_mapper = WeightsMapper(
        orig_to_new_prefix={
            "model.visual.": "visual.",
            "lm_head.": "language_model.lm_head.",
            "model.language_model.": "language_model.model.",
        }
    )

    @classmethod
    def get_placeholder_str(cls, modality: str, i: int) -> str | None:
        """Return the prompt placeholder for ``modality``.

        Any modality name starting with ``"image"`` or ``"video"`` is
        accepted; the item index ``i`` does not affect the result.

        Raises:
            ValueError: If the modality is neither image nor video.
        """
        placeholders = {
            "image": "<|vision_start|><|image_pad|><|vision_end|>",
            "video": "<|vision_start|><|video_pad|><|vision_end|>",
        }
        for modality_prefix, placeholder in placeholders.items():
            if modality.startswith(modality_prefix):
                return placeholder

        raise ValueError("Only image or video modality is supported")

    def __init__(self, *, vllm_config: VllmConfig, prefix: str = "model"):
        """Build the vision tower (if any modality is enabled) and the LM.

        Args:
            vllm_config: Full vLLM config; ``model_config.hf_config`` is the
                Qwen3-VL HF config.
            prefix: Weight-name prefix for submodules.
        """
        super().__init__()
        config: Qwen3VLConfig = vllm_config.model_config.hf_config
        quant_config = vllm_config.quant_config
        multimodal_config = vllm_config.model_config.multimodal_config

        self.config = config
        self.multimodal_config = multimodal_config
        self.use_data_parallel = multimodal_config.mm_encoder_tp_mode == "data"
        self.video_pruning_rate = multimodal_config.video_pruning_rate
        self.is_multimodal_pruning_enabled = (
            multimodal_config.is_multimodal_pruning_enabled()
        )

        # The vision tower is only built when images or videos are allowed
        # in prompts.
        if not multimodal_config.get_limit_per_prompt(
            "image"
        ) and not multimodal_config.get_limit_per_prompt("video"):
            self.visual = None
        else:
            self.visual = Qwen3_VisionTransformer(
                config.vision_config,
                norm_eps=getattr(config, "rms_norm_eps", 1e-6),
                quant_config=quant_config,
                multimodal_config=multimodal_config,
                prefix=maybe_prefix(prefix, "visual"),
            )

        self.language_model = Qwen3LLMForCausalLM(
            vllm_config=vllm_config, prefix=maybe_prefix(prefix, "language_model")
        )

        self.make_empty_intermediate_tensors = (
            self.language_model.make_empty_intermediate_tensors
        )

        self.use_deepstack = hasattr(config.vision_config, "deepstack_visual_indexes")
        self.deepstack_num_level = (
            len(config.vision_config.deepstack_visual_indexes)
            if self.use_deepstack
            else 0
        )
        # Per-level scratch tensors for deepstack embeddings, sized for the
        # scheduler's max_num_batched_tokens. NOTE: this is a plain Python
        # list, not registered nn.Module buffers, so these tensors are not
        # in the state_dict and are not moved by Module.to(); they use the
        # default device/dtype in effect at construction time.
        if self.use_deepstack and self.visual is not None:
            self.deepstack_input_embeds = [
                torch.zeros(
                    vllm_config.scheduler_config.max_num_batched_tokens,
                    config.text_config.hidden_size,
                )
                for _ in range(self.deepstack_num_level)
            ]
        else:
            self.deepstack_input_embeds = None
        self.visual_dim = config.vision_config.out_hidden_size
        self.multiscale_dim = self.visual_dim * self.deepstack_num_level

    def set_aux_hidden_state_layers(self, layers: tuple[int, ...]) -> None:
        """Choose which decoder layers' inputs are captured as aux states."""
        decoder = self.language_model.model
        decoder.aux_hidden_state_layers = layers

    def get_eagle3_aux_hidden_state_layers(self) -> tuple[int, ...]:
        """Default EAGLE-3 aux layers: an early, the middle and a late layer."""
        total_layers = len(self.language_model.model.layers)
        early_layer = 2
        middle_layer = total_layers // 2
        late_layer = total_layers - 3
        return (early_layer, middle_layer, late_layer)

    def _get_deepstack_input_embeds(self, num_tokens: int) -> IntermediateTensors:
        """Return the first ``num_tokens`` rows of every deepstack level.

        The returned tensors are slices (views) of
        ``self.deepstack_input_embeds``, keyed ``deepstack_input_embeds_{i}``.
        Nothing is cleared here; zeroing is done separately by
        ``_clear_deepstack_input_embeds``.
        """
        return IntermediateTensors(
            {
                f"deepstack_input_embeds_{idx}": self.deepstack_input_embeds[idx][
                    :num_tokens
                ]
                for idx in range(self.deepstack_num_level)
            }
        )

    def _set_deepstack_input_embeds(self, deepstack_input_embeds: torch.Tensor) -> None:
        """Copy per-level deepstack embeddings into the scratch tensors.

        ``deepstack_input_embeds`` is indexed by level on dim 0 and by token
        on dim 1. When it holds more tokens than the scratch tensors can
        store, they are replaced by larger zero tensors with the same
        device and dtype before copying.
        """
        token_count = deepstack_input_embeds.size(1)
        reference = self.deepstack_input_embeds[0]
        if token_count > reference.size(0):
            grown = []
            for _ in range(self.deepstack_num_level):
                grown.append(
                    torch.zeros(
                        token_count,
                        self.config.text_config.hidden_size,
                        device=reference.device,
                        dtype=reference.dtype,
                    )
                )
            self.deepstack_input_embeds = grown
        for level in range(self.deepstack_num_level):
            destination = self.deepstack_input_embeds[level][:token_count]
            destination.copy_(deepstack_input_embeds[level])

    def _clear_deepstack_input_embeds(self, num_tokens: int) -> None:
        """Zero the first ``num_tokens`` rows of every deepstack level."""
        if num_tokens <= 0:
            return
        for level in range(self.deepstack_num_level):
            self.deepstack_input_embeds[level][:num_tokens].zero_()

    def _parse_and_validate_image_input(
        self, **kwargs: object
    ) -> Qwen2_5_VLImageInputs | None:
        """Wrap image tensors found in ``kwargs`` into a typed image input.

        Raw ``pixel_values`` take precedence over ``image_embeds``. Returns
        ``None`` when neither is present.
        """
        pixel_values = kwargs.pop("pixel_values", None)
        image_embeds = kwargs.pop("image_embeds", None)
        image_grid_thw = kwargs.pop("image_grid_thw", None)

        if pixel_values is not None:
            return Qwen2_5_VLImagePixelInputs(
                type="pixel_values",
                pixel_values=pixel_values,
                image_grid_thw=image_grid_thw,
            )
        if image_embeds is not None:
            return Qwen2_5_VLImageEmbeddingInputs(
                type="image_embeds",
                image_embeds=image_embeds,
                image_grid_thw=image_grid_thw,
            )
        return None

    def _parse_and_validate_video_input(
        self, **kwargs: object
    ) -> Qwen2_5_VLVideoInputs | None:
        """Wrap video tensors found in ``kwargs`` into a typed video input.

        Raw ``pixel_values_videos`` take precedence over ``video_embeds``;
        ``second_per_grid_ts`` is only forwarded with raw pixels. Returns
        ``None`` when neither pixels nor embeddings are present.
        """
        pixel_values_videos = kwargs.pop("pixel_values_videos", None)
        video_embeds = kwargs.pop("video_embeds", None)
        video_grid_thw = kwargs.pop("video_grid_thw", None)
        second_per_grid_ts = kwargs.pop("second_per_grid_ts", None)

        if pixel_values_videos is not None:
            return Qwen2_5_VLVideoPixelInputs(
                type="pixel_values_videos",
                pixel_values_videos=pixel_values_videos,
                video_grid_thw=video_grid_thw,
                second_per_grid_ts=second_per_grid_ts,
            )
        if video_embeds is not None:
            return Qwen2_5_VLVideoEmbeddingInputs(
                type="video_embeds",
                video_embeds=video_embeds,
                video_grid_thw=video_grid_thw,
            )
        return None

    def _process_image_input(
        self, image_input: Qwen2_5_VLImageInputs
    ) -> tuple[torch.Tensor, ...]:
        """Encode images, or cast precomputed embeddings, per image item.

        Embeddings are split so that each image gets
        ``prod(grid_thw) // spatial_merge_size**2`` rows. In data-parallel
        encoder mode, the sharded runner's result is returned as is.
        """
        grid_thw = image_input["image_grid_thw"]
        assert grid_thw.ndim == 2
        visual = self.visual

        if image_input["type"] == "image_embeds":
            image_embeds = image_input["image_embeds"].type(visual.dtype)
        elif self.use_data_parallel:
            pixel_values = image_input["pixel_values"].type(visual.dtype)
            return run_dp_sharded_mrope_vision_model(
                visual, pixel_values, grid_thw.tolist(), rope_type="rope_3d"
            )
        else:
            pixel_values = image_input["pixel_values"].type(visual.dtype)
            image_embeds = visual(pixel_values, grid_thw=grid_thw)

        # Split concatenated embeddings for each image item.
        merge_size = visual.spatial_merge_size
        per_item_rows = (grid_thw.prod(-1) // merge_size // merge_size).tolist()
        return image_embeds.split(per_item_rows)

    def _process_video_input(
        self, video_input: Qwen2_5_VLVideoInputs
    ) -> tuple[torch.Tensor, ...]:
        """Encode videos, or cast precomputed embeddings, per video item.

        Embeddings are split so that each video gets
        ``prod(grid_thw) // spatial_merge_size**2`` rows. In data-parallel
        encoder mode, the sharded runner's result is returned as is.
        """
        grid_thw = video_input["video_grid_thw"]
        assert grid_thw.ndim == 2
        visual = self.visual

        if video_input["type"] == "video_embeds":
            video_embeds = video_input["video_embeds"].type(visual.dtype)
        elif self.use_data_parallel:
            pixel_values_videos = video_input["pixel_values_videos"].type(visual.dtype)
            return run_dp_sharded_mrope_vision_model(
                visual, pixel_values_videos, grid_thw.tolist(), rope_type="rope_3d"
            )
        else:
            pixel_values_videos = video_input["pixel_values_videos"].type(visual.dtype)
            video_embeds = visual(pixel_values_videos, grid_thw=grid_thw)

        # Split concatenated embeddings for each video item.
        merge_size = visual.spatial_merge_size
        per_item_rows = (grid_thw.prod(-1) // merge_size // merge_size).tolist()
        return video_embeds.split(per_item_rows)

    def _postprocess_image_embeds_evs(
        self,
        image_embeds_split: tuple[torch.Tensor, ...],
        image_input: Qwen2_5_VLImageInputs,
    ) -> tuple[torch.Tensor, ...]:
        """
        Append mrope positions to the embeddings of each image.
        This is necessary to recover correct mrope
        positions after video pruning.

        Args:
            image_embeds_split: Tuple of image embeddings, one per image item.
            image_input: Image input data; its ``image_grid_thw`` supplies
                the per-image grid sizes.

        Returns:
            Tuple of image embeddings for each image item, each concatenated
            along dim 1 with the positions computed by
            ``compute_mrope_for_media`` (4 extra channels).
        """
        merge_size = self.visual.spatial_merge_size
        grid_sizes = image_input["image_grid_thw"].tolist()
        return tuple(
            torch.cat(
                [emb, compute_mrope_for_media(size, merge_size).to(emb.device)],
                dim=1,
            )
            for emb, size in zip(image_embeds_split, grid_sizes)
        )

    def _postprocess_video_embeds_evs(
        self,
        video_embeds_split: tuple[torch.Tensor, ...],
        video_input: Qwen2_5_VLVideoInputs,
    ) -> tuple[torch.Tensor, ...]:
        """Prune video embeddings with EVS and append mrope positions.

        For every video item, ``compute_retention_mask`` selects which
        embeddings to keep (pruning rate ``self.video_pruning_rate``). The
        mrope positions of the full, unpruned grid are computed and filtered
        with the same mask, so each retained token keeps its original position.

        Args:
            video_embeds_split: One embedding tensor per video item.
            video_input: Parsed video input. ``video_grid_thw`` (a 2-D tensor,
                one ``(t, h, w)`` row per video) is required;
                ``second_per_grid_ts`` is optional.

        Returns:
            One tensor per video item containing only the retained rows, each
            with the 4 position channels from ``compute_mrope_for_media``
            concatenated along dim 1.
        """
        import logging

        grid_thw = video_input["video_grid_thw"]
        assert grid_thw.ndim == 2
        grid_thw_list = grid_thw.tolist()
        merge_size = self.visual.spatial_merge_size

        # Cast to long to match the original code
        # https://github.com/huggingface/transformers/blob/41980ce93e775f6c88500c51c8db7946fc6a2add/src/transformers/models/qwen2_5_vl/modular_qwen2_5_vl.py#L491 # noqa
        second_per_grid_ts = video_input.get("second_per_grid_ts")
        if second_per_grid_ts is None:
            # For Qwen3-VL, second_per_grid_ts might not be available
            # Use default value of 1.0 for each video
            second_per_grid_ts = torch.ones(len(grid_thw_list), dtype=torch.long)
        else:
            second_per_grid_ts = second_per_grid_ts.long()
        tokens_per_second = getattr(self.config.vision_config, "tokens_per_second", 1.0)

        # logger.debug() evaluates its arguments eagerly, and the pruning
        # statistics call .item(), which waits on the device when the mask
        # lives on an accelerator. Only compute them when they will be logged.
        log_pruning_stats = logger.isEnabledFor(logging.DEBUG)

        video_embeds_out = []
        for emb, size, video_second_per_grid_t in zip(
            video_embeds_split, grid_thw_list, second_per_grid_ts
        ):
            # For each video, we compute retention mask using EVS
            retention_mask = compute_retention_mask(
                emb,
                size,
                spatial_merge_size=merge_size,
                q=self.video_pruning_rate,
            )

            if log_pruning_stats:
                logger.debug(
                    "EVS: Video tokens pruned from %d to %d (T=%d,H=%d,W=%d, "
                    "pruning_rate=%.2f, reduction=%.1f%%)",
                    emb.shape[0],
                    retention_mask.sum().item(),
                    size[0],
                    size[1],
                    size[2],
                    self.video_pruning_rate,
                    (1 - retention_mask.float().mean().item()) * 100,
                )

            # Positions are computed for the full grid, then masked exactly
            # like the embeddings so rows stay aligned.
            positions = compute_mrope_for_media(
                size,
                merge_size,
                tokens_per_second=tokens_per_second,
                video_second_per_grid=video_second_per_grid_t.item(),
            ).to(emb.device)

            kept_emb = emb[retention_mask]
            kept_positions = positions[retention_mask]
            video_embeds_out.append(torch.cat([kept_emb, kept_positions], dim=1))
        return tuple(video_embeds_out)

1563
1564
1565
    def _parse_and_validate_multimodal_inputs(self, **kwargs: object) -> dict:
        mm_input_by_modality = {}
        for input_key in kwargs:
1566
1567
1568
1569
1570
1571
1572
1573
1574
1575
1576
1577
1578
1579
            if (
                input_key in ("pixel_values", "image_embeds")
                and "image" not in mm_input_by_modality
            ):
                mm_input_by_modality["image"] = self._parse_and_validate_image_input(
                    **kwargs
                )
            if (
                input_key in ("pixel_values_videos", "video_embeds")
                and "video" not in mm_input_by_modality
            ):
                mm_input_by_modality["video"] = self._parse_and_validate_video_input(
                    **kwargs
                )
1580
1581
        return mm_input_by_modality

1582
1583
1584
    def iter_mm_grid_hw(
        self, input_tokens: list[int], mm_features: list[MultiModalFeatureSpec]
    ) -> Iterator[tuple[int, int, int]]:
        """
        Iterate over multimodal features and yield grid information.

        Features are visited in order of their placeholder offset. Images yield
        one entry; videos yield one entry per frame.

        For videos with EVS (Efficient Video Sampling) enabled, frame offsets
        are derived from the ``is_embed`` mask of ``mm_position`` rather than
        from input_tokens.index(), which would fail when tokens are pruned.

        Args:
            input_tokens: List of token IDs in the prompt
            mm_features: List of multimodal feature specifications

        Yields:
            Tuple of (offset, grid_h, grid_w) for each frame/image, where the
            grid sizes are already divided by the spatial merge size.

        Raises:
            RuntimeError: EVS is enabled but a video's ``is_embed`` mask is
                missing, empty, or has fewer segments than frames.
            ValueError: A feature has an unsupported modality.
        """
        video_token_id = self.config.video_token_id
        spatial_merge_size = self.config.vision_config.spatial_merge_size
        # The pruning rate is a model-level setting, so resolve it only once.
        is_evs_enabled = (
            hasattr(self, "video_pruning_rate")
            and self.video_pruning_rate is not None
            and self.video_pruning_rate > 0.0
        )

        for mm_feature in sorted(mm_features, key=lambda f: f.mm_position.offset):
            offset = mm_feature.mm_position.offset
            if mm_feature.modality == "image":
                t, h, w = mm_feature.data["image_grid_thw"].data.tolist()
                assert t == 1, f"Image must have 1 frame, got {t}"
                yield offset, h // spatial_merge_size, w // spatial_merge_size
            elif mm_feature.modality == "video":
                t, h, w = mm_feature.data["video_grid_thw"].data.tolist()
                llm_grid_h = h // spatial_merge_size
                llm_grid_w = w // spatial_merge_size

                if is_evs_enabled:
                    frame_offsets = self._extract_frame_offsets_from_mask(
                        mm_feature.mm_position, t
                    )
                    if frame_offsets is None:
                        # The mask should always be usable when
                        # video_pruning_rate > 0; a None here points at the
                        # prompt processing pipeline.
                        raise RuntimeError(
                            f"EVS is enabled (pruning_rate={self.video_pruning_rate}) "
                            "but the is_embed mask in mm_position is missing, "
                            f"empty, or has fewer than {t} frame segments. "
                            "This indicates a bug in prompt processing."
                        )
                    for rel_offset in frame_offsets:
                        yield offset + rel_offset, llm_grid_h, llm_grid_w
                else:
                    # Non-EVS mode: every frame has the full grid of video
                    # tokens, so locate each frame's first video token.
                    for _ in range(t):
                        offset = input_tokens.index(video_token_id, offset)
                        yield offset, llm_grid_h, llm_grid_w
                        offset += llm_grid_h * llm_grid_w
            else:
                raise ValueError(f"Unsupported modality: {mm_feature.modality}")

1645
1646
1647
1648
1649
1650
1651
1652
1653
1654
1655
1656
1657
1658
1659
1660
1661
1662
1663
1664
1665
1666
1667
1668
1669
1670
1671
1672
1673
1674
1675
1676
1677
1678
1679
1680
1681
1682
1683
1684
1685
1686
1687
1688
1689
1690
1691
1692
1693
1694
1695
1696
1697
1698
1699
1700
1701
1702
1703
1704
1705
1706
1707
1708
1709
1710
1711
1712
1713
1714
1715
1716
1717
1718
1719
1720
1721
1722
1723
1724
1725
1726
1727
1728
1729
1730
1731
1732
1733
1734
1735
1736
1737
1738
1739
1740
1741
1742
1743
1744
1745
1746
1747
1748
1749
1750
1751
1752
1753
1754
1755
1756
1757
1758
1759
1760
1761
1762
1763
1764
1765
1766
1767
1768
1769
1770
1771
1772
1773
1774
1775
1776
1777
1778
1779
1780
1781
1782
1783
1784
1785
1786
1787
1788
1789
1790
1791
1792
1793
1794
1795
1796
1797
1798
1799
1800
1801
1802
    def _get_evs_mask_segments(
        self, mm_position: PlaceholderRange, expected_frames: int
    ) -> list[torch.Tensor] | None:
        """Split the EVS ``is_embed`` mask into one index run per frame.

        ``mm_position.is_embed`` marks which placeholder positions receive
        video embeddings. Consecutive ``True`` positions form one contiguous
        run, and each run is treated as one retained frame.

        The method only reads ``mm_position`` and has no side effects.

        Args:
            mm_position: Placeholder range that may carry an ``is_embed`` mask.
            expected_frames: Number of frame runs the caller expects.

        Returns:
            A list with the first ``expected_frames`` runs, each a 1-D tensor
            of indices into the mask (relative to the placeholder start). The
            result is None when the mask is absent, has no ``True`` entries,
            or yields fewer than ``expected_frames`` runs.
        """
        is_embed_mask = getattr(mm_position, "is_embed", None)
        if is_embed_mask is None:
            return None

        # reshape (not view) so that non-contiguous mask tensors also work.
        mask_tensor = torch.as_tensor(is_embed_mask, dtype=torch.bool).reshape(-1)
        true_indices = torch.nonzero(mask_tensor, as_tuple=False).flatten()
        if true_indices.numel() == 0:
            return None

        # A step other than 1 between consecutive retained indices starts a
        # new run. With a single index torch.diff is empty, so nothing splits.
        split_points = torch.nonzero(
            torch.diff(true_indices) != 1, as_tuple=False
        ).flatten()
        if split_points.numel() == 0:
            segments = [true_indices]
        else:
            # tensor_split returns a tuple; convert it so the declared list
            # return type holds on every path.
            segments = list(
                torch.tensor_split(true_indices, split_points.add(1).tolist())
            )

        # Validate segment count matches expected frames
        if len(segments) < expected_frames:
            logger.debug(
                "EVS mask segments (%d) do not match expected frames (%d)",
                len(segments),
                expected_frames,
            )
            return None

        return segments[:expected_frames]

    def _extract_frame_offsets_from_mask(
        self, mm_position: PlaceholderRange, expected_frames: int
    ) -> list[int] | None:
        """Return where each EVS-retained frame starts inside the placeholder.

        The frame starts are recovered from the ``is_embed`` mask stored on
        ``mm_position``. Each contiguous run of ``True`` entries is one frame,
        so ``input_tokens`` never has to be searched.

        Args:
            mm_position: MultiModal position containing the is_embed mask
            expected_frames: Expected number of frames

        Returns:
            The first index of every frame run, relative to the placeholder
            start, or None if the mask segments could not be extracted.
        """
        runs = self._get_evs_mask_segments(mm_position, expected_frames)
        return None if runs is None else [int(run[0].item()) for run in runs]

    def _get_actual_frame_token_counts(
        self, mm_position: PlaceholderRange, expected_frames: int
    ) -> list[int] | None:
        """Return how many tokens each EVS-retained frame actually kept.

        The count of a frame is the length of its contiguous run in the
        ``is_embed`` mask, so frames may differ after content-aware pruning.

        Args:
            mm_position: MultiModal position containing the is_embed mask
            expected_frames: Expected number of frames

        Returns:
            Per-frame token counts, or None if the mask segments could not be
            extracted.
        """
        runs = self._get_evs_mask_segments(mm_position, expected_frames)
        if runs is None:
            return None
        token_counts: list[int] = []
        for run in runs:
            token_counts.append(len(run))
        return token_counts

    def recompute_mrope_positions(
        self,
        input_ids: list[int],
        multimodal_embeddings: tuple[torch.Tensor, ...],
        mrope_positions: torch.LongTensor,
        num_computed_tokens: int,
    ) -> tuple[tuple[torch.Tensor, ...], torch.Tensor, int]:
        """
        Update mrope positions from index ``num_computed_tokens`` onwards.

        The original mrope_positions are computed for the unpruned sequence
        and become incorrect once media tokens are pruned. This method strips
        the 4 trailing position channels from every multimodal embedding and
        lets the module-level ``recompute_mrope_positions`` helper rebuild the
        positions from them before they are fed to the LLM.

        Args:
            input_ids: (N,) All input tokens of the prompt (Containing
                entire sequence).
            multimodal_embeddings: Tuple of multimodal embeddings, each with
                4 extra trailing position channels.
            mrope_positions: Existing mrope positions (3, N) for entire
                sequence
            num_computed_tokens: A number of computed tokens so far.

        Returns:
            Tuple of (multimodal_embeddings without the position channels,
                mrope_positions, mrope_position_delta).
        """
        cfg = self.config

        if len(multimodal_embeddings):
            device = multimodal_embeddings[0].device
        else:
            device = mrope_positions.device
        token_ids = torch.as_tensor(input_ids, device=device, dtype=torch.long)

        # Split each embedding into its features and its position channels.
        stripped_embeddings = tuple(mm[:, :-4] for mm in multimodal_embeddings)
        embedded_positions = [
            mm[:, -4:].permute(1, 0).long() for mm in multimodal_embeddings
        ]

        # Resolves to the module-level helper, not to this method.
        positions, mrope_positions_delta = recompute_mrope_positions(
            token_ids,
            embedded_positions,
            mrope_positions,
            num_computed_tokens,
            cfg.vision_start_token_id,
            cfg.image_token_id,
            cfg.video_token_id,
        )
        return stripped_embeddings, positions, mrope_positions_delta

1803
    def get_mrope_input_positions(
        self,
        input_tokens: list[int],
        mm_features: list[MultiModalFeatureSpec],
    ) -> tuple[torch.Tensor, int]:
        """Compute 3-axis M-RoPE position ids for a prompt.

        Text tokens get identical positions on all three axes, continuing from
        the previous maximum position + 1. Each image / video frame yielded by
        ``iter_mm_grid_hw`` gets ``(t, h, w)`` grid indices shifted by the
        running start index. With EVS enabled, a video frame only receives as
        many positions as its run in the ``is_embed`` mask (the first ``n``
        grid positions).

        Args:
            input_tokens: Token ids of the whole prompt.
            mm_features: Multimodal items of the prompt, in any order.

        Returns:
            ``(positions, mrope_position_delta)``, where ``positions`` is a
            ``(3, N)`` tensor and the delta is ``max(positions) + 1 -
            len(input_tokens)``.

        Raises:
            RuntimeError: EVS is enabled but per-frame token counts cannot be
                derived from a video's ``is_embed`` mask, or more frames are
                yielded for a video than its mask describes.
        """
        import bisect

        is_evs_enabled = (
            hasattr(self, "video_pruning_rate")
            and self.video_pruning_rate is not None
            and self.video_pruning_rate > 0.0
        )

        # Per-item table sorted by placeholder offset. Every offset yielded by
        # iter_mm_grid_hw lies inside exactly one item's placeholder, so the
        # owning item is the one with the largest start offset <= it. Looking
        # the owner up among *all* items (not just videos) keeps an image that
        # follows a video from consuming that video's EVS frame counts. Only
        # EVS-pruned videos carry counts; everything else uses the full grid.
        item_offsets: list[int] = []
        item_frame_counts: list[list[int] | None] = []
        for mm_feature in sorted(mm_features, key=lambda f: f.mm_position.offset):
            token_counts = None
            if is_evs_enabled and mm_feature.modality == "video":
                t = mm_feature.data["video_grid_thw"].data.tolist()[0]
                token_counts = self._get_actual_frame_token_counts(
                    mm_feature.mm_position, t
                )
                if token_counts is None:
                    raise RuntimeError(
                        "EVS enabled but failed to extract frame token counts "
                        "from is_embed mask"
                    )
            item_offsets.append(mm_feature.mm_position.offset)
            item_frame_counts.append(token_counts)
        # Number of frames already consumed for each item.
        frames_seen = [0] * len(item_offsets)

        llm_pos_ids_list = []
        st = 0
        for offset, llm_grid_h, llm_grid_w in self.iter_mm_grid_hw(
            input_tokens, mm_features
        ):
            text_len = offset - st
            st_idx = llm_pos_ids_list[-1].max() + 1 if llm_pos_ids_list else 0

            # Determine actual token count for this frame
            owner = bisect.bisect_right(item_offsets, offset) - 1
            counts = item_frame_counts[owner] if owner >= 0 else None
            if counts is not None:
                # EVS mode: use actual token count from is_embed mask
                idx = frames_seen[owner]
                if idx >= len(counts):
                    raise RuntimeError(
                        f"EVS frame index {idx} out of range "
                        f"(total frames: {len(counts)})"
                    )
                actual_frame_tokens = counts[idx]
                frames_seen[owner] += 1
            else:
                # Non-EVS mode (or image): use theoretical grid size
                actual_frame_tokens = llm_grid_h * llm_grid_w

            # Add text segment
            text_positions = (
                np.broadcast_to(np.arange(text_len), (3, text_len)) + st_idx
            )
            llm_pos_ids_list.append(text_positions)
            st_idx += text_len

            # Add frame segment with actual token count (not theoretical)
            grid_indices = np.indices((1, llm_grid_h, llm_grid_w)).reshape(3, -1)
            frame_positions = grid_indices[:, :actual_frame_tokens] + st_idx
            llm_pos_ids_list.append(frame_positions)

            # Update st using actual token count
            st = offset + actual_frame_tokens

        # Handle final text segment
        if st < len(input_tokens):
            st_idx = llm_pos_ids_list[-1].max() + 1 if llm_pos_ids_list else 0
            text_len = len(input_tokens) - st
            final_text_positions = (
                np.broadcast_to(np.arange(text_len), (3, text_len)) + st_idx
            )
            llm_pos_ids_list.append(final_text_positions)

        llm_positions = np.concatenate(llm_pos_ids_list, axis=1).reshape(3, -1)
        mrope_position_delta = (llm_positions.max() + 1 - len(input_tokens)).item()

        return torch.from_numpy(llm_positions), mrope_position_delta
1895

1896
1897
1898
    def get_language_model(self) -> torch.nn.Module:
        """Return the text backbone held by this multimodal model."""
        backbone = self.language_model
        return backbone

1899
    def embed_multimodal(self, **kwargs: object) -> MultiModalEmbeddings | None:
        """Encode every image / video item passed in ``kwargs``.

        When multimodal pruning is enabled, the per-item embeddings are
        post-processed with the EVS helpers, which append mrope position
        channels and, for videos, prune tokens.

        Returns:
            None when no multimodal input is present; otherwise a tuple with
            one tensor per multimodal item (image or video).
        """
        mm_input_by_modality = self._parse_and_validate_multimodal_inputs(**kwargs)
        if not mm_input_by_modality:
            return None

        # One tensor per multimodal data item (image or video).
        per_item_embeddings: list[torch.Tensor] = []

        # NOTE: Iterating the dictionary preserves the order in which the
        # modalities were parsed, which is the order the items appear in.
        for modality, mm_input in mm_input_by_modality.items():
            if modality == "image":
                embeds = self._process_image_input(mm_input)
                if self.is_multimodal_pruning_enabled:
                    embeds = self._postprocess_image_embeds_evs(embeds, mm_input)
            elif modality == "video":
                embeds = self._process_video_input(mm_input)
                if self.is_multimodal_pruning_enabled:
                    embeds = self._postprocess_video_embeds_evs(embeds, mm_input)
            else:
                continue
            per_item_embeddings.extend(embeds)

        return tuple(per_item_embeddings)

    def _compute_deepstack_embeds(
        self,
        inputs_embeds: torch.Tensor,
        multimodal_embeddings: MultiModalEmbeddings,
        is_multimodal: torch.Tensor,
    ) -> tuple[torch.Tensor, MultiModalEmbeddings]:
        """Separate main and multiscale (deepstack) visual features.

        Each multimodal embedding is split along its last dim into
        ``visual_dim`` main channels and ``multiscale_dim`` deepstack channels.
        The deepstack channels are scattered into a zero buffer at the
        ``is_multimodal`` positions.

        Returns:
            ``(deepstack_input_embeds, main_embeddings)``. The deepstack tensor
            is viewed as ``(num_tokens, deepstack_num_level, visual_dim)`` and
            then permuted to ``(deepstack_num_level, num_tokens, visual_dim)``.
            The main embeddings are the per-item main-channel tensors.
        """
        item_lengths = [len(item) for item in multimodal_embeddings]
        stacked = torch.cat(multimodal_embeddings, dim=0)

        main_part, multiscale_part = torch.split(
            stacked,
            [self.visual_dim, self.multiscale_dim],
            dim=-1,
        )
        per_item_main = torch.split(main_part, item_lengths, dim=0)
        per_item_multiscale = torch.split(multiscale_part, item_lengths, dim=0)

        num_tokens = inputs_embeds.size(0)
        flat_width = self.deepstack_num_level * inputs_embeds.size(1)
        deepstack_flat = _merge_multimodal_embeddings(
            inputs_embeds=inputs_embeds.new_zeros(num_tokens, flat_width),
            multimodal_embeddings=per_item_multiscale,
            is_multimodal=is_multimodal,
        )

        # (tokens, levels * visual_dim) -> (levels, tokens, visual_dim)
        deepstack_input_embeds = deepstack_flat.view(
            num_tokens, self.deepstack_num_level, self.visual_dim
        ).permute(1, 0, 2)

        return deepstack_input_embeds, per_item_main

1969
    def embed_input_ids(
        self,
        input_ids: torch.Tensor,
        multimodal_embeddings: MultiModalEmbeddings | None = None,
        *,
        is_multimodal: torch.Tensor | None = None,
        handle_oov_mm_token: bool = False,
    ) -> torch.Tensor:
        """Embed token ids and merge multimodal embeddings into them.

        When deepstack is in use, the multiscale part of the visual features
        is stored via ``_set_deepstack_input_embeds`` for the forward pass,
        and only the main part is merged into the returned embeddings.
        """
        inputs_embeds = self._embed_text_input_ids(
            input_ids,
            self.language_model.embed_input_ids,
            is_multimodal=is_multimodal,
            handle_oov_mm_token=handle_oov_mm_token,
        )

        has_mm_items = multimodal_embeddings is not None and len(multimodal_embeddings) > 0
        if not has_mm_items:
            return inputs_embeds

        is_multimodal = _require_is_multimodal(is_multimodal)

        deepstack_input_embeds = None
        if self.use_deepstack:
            deepstack_input_embeds, multimodal_embeddings = (
                self._compute_deepstack_embeds(
                    inputs_embeds=inputs_embeds,
                    multimodal_embeddings=multimodal_embeddings,
                    is_multimodal=is_multimodal,
                )
            )

        merged_embeds = _merge_multimodal_embeddings(
            inputs_embeds=inputs_embeds,
            multimodal_embeddings=multimodal_embeddings,
            is_multimodal=is_multimodal,
        )

        if deepstack_input_embeds is not None:
            self._set_deepstack_input_embeds(deepstack_input_embeds)

        return merged_embeds

    def forward(
        self,
        input_ids: torch.Tensor,
        positions: torch.Tensor,
        intermediate_tensors: IntermediateTensors | None = None,
        inputs_embeds: torch.Tensor | None = None,
        **kwargs: object,
    ) -> torch.Tensor | IntermediateTensors:
        """Run forward pass for Qwen3VL.

        Args:
            input_ids: Flattened (concatenated) input_ids corresponding to a
                batch.
            positions: Flattened (concatenated) position ids corresponding to a
                batch.
                **NOTE**: If mrope is enabled (default setting for Qwen3VL
                opensource models), the shape will be `(3, seq_len)`,
                otherwise it will be `(seq_len,)`.
            intermediate_tensors: Intermediate tensors from previous pipeline
                stages. When given, ``inputs_embeds`` is ignored.
            inputs_embeds: Pre-computed input embeddings.
            **kwargs: Additional keyword arguments (not read by this method),
                e.g.:
                - pixel_values: Pixel values to be fed to a model.
                    `None` if no images are passed.
                - image_grid_thw: Tensor `(n_images, 3)` of image 3D grid in
                    LLM. `None` if no images are passed.
                - pixel_values_videos: Pixel values of videos to be fed to a
                    model. `None` if no videos are passed.
                - video_grid_thw: Tensor `(n_videos, 3)` of video 3D grid in
                    LLM. `None` if no videos are passed.
        """
        # Embeddings are only consumed when no intermediate tensors are given.
        if intermediate_tensors is not None:
            inputs_embeds = None
        has_embeds = inputs_embeds is not None

        # Deepstack features are fetched from (and later cleared in) the
        # buffer only on the first pipeline rank.
        deepstack_input_embeds = None
        if self.use_deepstack and has_embeds and get_pp_group().is_first_rank:
            deepstack_input_embeds = self._get_deepstack_input_embeds(
                inputs_embeds.size(0)
            )

        hidden_states = self.language_model.model(
            input_ids=input_ids,
            positions=positions,
            intermediate_tensors=intermediate_tensors,
            inputs_embeds=inputs_embeds,
            # args for deepstack
            deepstack_input_embeds=deepstack_input_embeds,
        )

        if has_embeds and get_pp_group().is_first_rank:
            self._clear_deepstack_input_embeds(inputs_embeds.size(0))

        return hidden_states

    def compute_logits(
        self,
        hidden_states: torch.Tensor,
2075
    ) -> torch.Tensor | None:
2076
        return self.language_model.compute_logits(hidden_states)
2077

2078
    def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]:
        """Load checkpoint weights, renaming keys via ``hf_to_vllm_mapper``.

        Checkpoint entries under ``visual.`` are skipped when this instance
        has no vision tower (``self.visual is None``).

        Returns:
            The value returned by ``AutoWeightsLoader.load_weights``.
        """
        loader = AutoWeightsLoader(
            self,
            skip_prefixes=["visual."] if self.visual is None else [],
        )
        return loader.load_weights(weights, mapper=self.hf_to_vllm_mapper)

    def get_mm_mapping(self) -> MultiModelKeys:
        """
        Get the module prefix in multimodal models

        Maps the language model, the vision-to-language connector modules and
        the vision tower to their attribute prefixes on this model.
        """
        connector_modules = ["visual.merger", "visual.deepstack_merger_list"]
        return MultiModelKeys.from_string_field(
            language_model="language_model",
            connector=connector_modules,
            tower_model="visual.",
        )
2094

2095
2096
2097
2098
2099
2100
2101
2102
2103
2104
2105
2106
2107
2108
2109
2110
2111
2112
2113
    def get_num_mm_encoder_tokens(
        self,
        num_image_tokens: int,
    ) -> int:
        """Convert an LLM-side image token count to vision-encoder tokens.

        This multiplies by ``spatial_merge_size`` squared; it is the inverse of
        ``get_num_mm_connector_tokens``.
        """
        merge = self.config.vision_config.spatial_merge_size
        return num_image_tokens * merge * merge

    def get_num_mm_connector_tokens(
        self,
        num_vision_tokens: int,
    ) -> int:
        """Convert a vision-encoder token count to LLM-side tokens.

        This floor-divides by ``spatial_merge_size`` squared.
        """
        merge = self.config.vision_config.spatial_merge_size
        return num_vision_tokens // (merge * merge)

2114
2115
2116
2117
2118
2119
2120
    @classmethod
    def get_language_model_spec(cls) -> tuple[nn.Module | None, str | None]:
        """Return the language model spec.

        Returns:
            ``(language model class, language model attr)``: the class used
            for the text backbone, and the attribute name under which this
            model stores it.
        """
        attr_name = "language_model"
        return Qwen3LLMForCausalLM, attr_name