qwen3_vl.py 103 KB
Newer Older
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project

# Copyright 2025 The vLLM team.
# Copyright 2025 The Qwen Team.
# Copyright 2025 The HuggingFace Inc. team.
# All rights reserved.
#
# This code is based on EleutherAI's GPT-NeoX library and the GPT-NeoX
# and OPT implementations in this library. It has been modified from its
# original forms to accommodate minor architectural differences compared
# to GPT-NeoX and OPT used by the Meta AI team that trained the model.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Inference-only Qwen3VL model compatible with HuggingFace weights."""
26

27
from collections.abc import Callable, Iterable, Iterator, Mapping, Sequence
28
from functools import lru_cache, partial
29
from itertools import islice
30
from typing import Any
31
32
33
34
35

import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
36
from transformers import BatchFeature
37
from transformers.models.qwen2_vl import Qwen2VLImageProcessorFast
38
from transformers.models.qwen2_vl.image_processing_qwen2_vl import (
39
40
41
    smart_resize as image_smart_resize,
)
from transformers.models.qwen3_vl import Qwen3VLProcessor, Qwen3VLVideoProcessor
42
from transformers.models.qwen3_vl.configuration_qwen3_vl import (
43
44
45
    Qwen3VLConfig,
    Qwen3VLVisionConfig,
)
46
from transformers.models.qwen3_vl.video_processing_qwen3_vl import (
47
48
    smart_resize as video_smart_resize,
)
49
50
51
from transformers.video_utils import VideoMetadata

from vllm.compilation.decorators import support_torch_compile
52
from vllm.config import VllmConfig
53
from vllm.config.multimodal import BaseDummyOptions, VideoDummyOptions
54
from vllm.distributed import get_pp_group, parallel_state
55
from vllm.inputs import MultiModalDataDict
56
57
from vllm.logger import init_logger
from vllm.model_executor.layers.activation import _ACTIVATION_REGISTRY
58
59
60
from vllm.model_executor.layers.attention.mm_encoder_attention import (
    MMEncoderAttention,
)
61
from vllm.model_executor.layers.conv import Conv3dLayer
62
63
64
65
from vllm.model_executor.layers.linear import (
    ColumnParallelLinear,
    RowParallelLinear,
)
66
67
from vllm.model_executor.layers.logits_processor import LogitsProcessor
from vllm.model_executor.layers.quantization import QuantizationConfig
68
from vllm.model_executor.layers.rotary_embedding import get_rope
69
70
71
72
from vllm.model_executor.layers.vocab_parallel_embedding import ParallelLMHead
from vllm.model_executor.model_loader.weight_utils import default_weight_loader
from vllm.model_executor.models.module_mapping import MultiModelKeys
from vllm.multimodal import MULTIMODAL_REGISTRY
73
74
75
76
77
78
from vllm.multimodal.evs import (
    compute_mrope_for_media,
    compute_retained_tokens_count,
    compute_retention_mask,
    recompute_mrope_positions,
)
79
from vllm.multimodal.inputs import (
80
    MultiModalFeatureSpec,
81
    MultiModalFieldConfig,
82
    MultiModalFieldElem,
83
84
    MultiModalKwargsItem,
    MultiModalKwargsItems,
85
    PlaceholderRange,
86
87
    VideoItem,
)
88
from vllm.multimodal.parse import ImageSize, MultiModalDataItems
89
from vllm.multimodal.processing import (
90
    BaseDummyInputsBuilder,
91
92
93
94
95
    BaseMultiModalProcessor,
    PromptReplacement,
    PromptUpdate,
    PromptUpdateDetails,
)
96
from vllm.sequence import IntermediateTensors
97
98
from vllm.tokenizers.protocol import TokenizerLike
from vllm.tokenizers.registry import cached_tokenizer_from_config
99
from vllm.triton_utils import HAS_TRITON, tl, triton
100
from vllm.utils.collection_utils import is_list_of
101
from vllm.utils.math_utils import round_up
102

103
104
from .interfaces import (
    MultiModalEmbeddings,
105
    SupportsEagle,
106
    SupportsEagle3,
107
    SupportsEncoderCudaGraph,
108
    SupportsLoRA,
109
    SupportsMRoPE,
110
    SupportsMultiModal,
111
    SupportsMultiModalPruning,
112
    SupportsPP,
113
    _require_is_multimodal,
114
115
116
117
118
119
120
121
122
123
)
from .qwen2_5_vl import (
    Qwen2_5_VisionAttention,
    Qwen2_5_VLImageEmbeddingInputs,
    Qwen2_5_VLImageInputs,
    Qwen2_5_VLImagePixelInputs,
    Qwen2_5_VLVideoEmbeddingInputs,
    Qwen2_5_VLVideoInputs,
    Qwen2_5_VLVideoPixelInputs,
)
124
125
126
127
128
from .qwen2_vl import (
    Qwen2VLMultiModalDataParser,
    Qwen2VLProcessingInfo,
    _create_qwen2vl_field_factory,
)
129
from .qwen3 import Qwen3ForCausalLM, Qwen3Model
130
131
132
133
134
135
136
from .utils import (
    AutoWeightsLoader,
    PPMissingLayer,
    WeightsMapper,
    _merge_multimodal_embeddings,
    maybe_prefix,
)
137
138
from .vision import (
    get_vit_attn_backend,
139
    is_vit_use_data_parallel,
140
141
    run_dp_sharded_mrope_vision_model,
)
142
143
144

# Logger scoped to this module.
logger = init_logger(__name__)

# Number of frames in the dummy video; chosen so that the dummy input
# produces vision embeddings of the maximum size.
DUMMY_VIDEO_NUM_FRAMES = 2_048
148

149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
# ---------------------------------------------------------------------------
# Triton kernel: fused bilinear position-embedding interpolation
# ---------------------------------------------------------------------------
# Replaces many small eager-mode CUDA kernels with a single launch.
# The spatial-merge reorder is baked into the index math so the output
# is ready to be added to the patch embeddings directly.
# ---------------------------------------------------------------------------

if HAS_TRITON:

    @triton.jit
    def _bilinear_pos_embed_kernel(
        embed_ptr,
        output_ptr,
        H,
        W,
        h_scale,
        w_scale,
        NUM_GRID: tl.constexpr,
        M_SIZE: tl.constexpr,
        HIDDEN_DIM: tl.constexpr,
        BLOCK_D: tl.constexpr,
    ):
        """Fused bilinear pos-embed interpolation with spatial-merge reorder.

        Each program writes one ``HIDDEN_DIM``-wide output row. The program
        id's spatial part (``pid % (H * W)``) is interpreted in
        spatial-merge order: ``M_SIZE x M_SIZE`` blocks laid out row-major,
        with the ``M_SIZE**2`` patches of each block stored contiguously.
        Programs beyond ``H * W`` (later temporal frames) wrap around and
        reproduce the same rows.
        """
        pid = tl.program_id(0)
        total_spatial = H * W
        # All frames share one spatial layout, so drop the temporal index.
        spatial_idx = pid % total_spatial

        # Decode spatial-merge order back to (row, col) on the H x W grid.
        num_blocks_w = W // M_SIZE
        block_idx = spatial_idx // (M_SIZE * M_SIZE)
        local_idx = spatial_idx % (M_SIZE * M_SIZE)
        br = block_idx // num_blocks_w
        bc = block_idx % num_blocks_w
        lr = local_idx // M_SIZE
        lc = local_idx % M_SIZE
        row = br * M_SIZE + lr
        col = bc * M_SIZE + lc

        # Project (row, col) onto the NUM_GRID x NUM_GRID learned table;
        # the scales map the first/last patch onto the table's corners.
        h_frac = row.to(tl.float32) * h_scale
        w_frac = col.to(tl.float32) * w_scale

        hf = tl.math.floor(h_frac).to(tl.int32)
        wf = tl.math.floor(w_frac).to(tl.int32)
        hc = tl.minimum(hf + 1, NUM_GRID - 1)
        wc = tl.minimum(wf + 1, NUM_GRID - 1)

        # Bilinear weights, i.e. w00=(1-dh)(1-dw), w01=(1-dh)dw,
        # w10=dh(1-dw), w11=dh*dw, written to mirror the eager version.
        dh = h_frac - hf.to(tl.float32)
        dw = w_frac - wf.to(tl.float32)
        w11 = dh * dw
        w10 = dh - w11
        w01 = dw - w11
        w00 = 1.0 - dh - w01

        off00 = (hf * NUM_GRID + wf) * HIDDEN_DIM
        off01 = (hf * NUM_GRID + wc) * HIDDEN_DIM
        off10 = (hc * NUM_GRID + wf) * HIDDEN_DIM
        off11 = (hc * NUM_GRID + wc) * HIDDEN_DIM
        # tl.program_id is int32 and pid spans t*h*w rows; promote before
        # scaling by HIDDEN_DIM so long videos cannot overflow the offset.
        out_off = pid.to(tl.int64) * HIDDEN_DIM

        # Cast weights to output dtype so the multiply-accumulate stays
        # in the same precision as the native PyTorch implementation.
        out_dtype = output_ptr.dtype.element_ty
        w00_c = w00.to(out_dtype)
        w01_c = w01.to(out_dtype)
        w10_c = w10.to(out_dtype)
        w11_c = w11.to(out_dtype)

        for d in tl.range(0, HIDDEN_DIM, BLOCK_D):
            cols = d + tl.arange(0, BLOCK_D)
            mask = cols < HIDDEN_DIM

            e00 = tl.load(embed_ptr + off00 + cols, mask=mask)
            e01 = tl.load(embed_ptr + off01 + cols, mask=mask)
            e10 = tl.load(embed_ptr + off10 + cols, mask=mask)
            e11 = tl.load(embed_ptr + off11 + cols, mask=mask)

            val = w00_c * e00 + w01_c * e01 + w10_c * e10 + w11_c * e11

            tl.store(output_ptr + out_off + cols, val, mask=mask)

    def triton_pos_embed_interpolate(
        embed_weight: torch.Tensor,
        t: int,
        h: int,
        w: int,
        num_grid_per_side: int,
        m_size: int,
        dtype: torch.dtype,
    ) -> torch.Tensor:
        """Launch the fused Triton kernel for one (t, h, w) grid.

        Args:
            embed_weight: Learned position table with
                ``num_grid_per_side**2`` rows of ``hidden_dim`` features.
            t, h, w: Temporal / height / width patch counts of the grid.
            num_grid_per_side: Side length of the learned square table.
            m_size: Spatial-merge size; ``h`` and ``w`` must be multiples.
            dtype: dtype of the returned tensor.

        Returns:
            Tensor of shape ``(t * h * w, hidden_dim)`` with the
            bilinearly-interpolated position embeddings in spatial-merge
            order (identical for every frame).
        """
        assert h % m_size == 0 and w % m_size == 0, (
            f"h={h} and w={w} must be divisible by m_size={m_size}"
        )
        hidden_dim = embed_weight.shape[1]
        total_out = t * h * w
        output = torch.empty(
            total_out,
            hidden_dim,
            device=embed_weight.device,
            dtype=dtype,
        )

        # Same sampling positions as torch.linspace(0, N - 1, h / w).
        h_scale = float(num_grid_per_side - 1) / float(h - 1) if h > 1 else 0.0
        w_scale = float(num_grid_per_side - 1) / float(w - 1) if w > 1 else 0.0

        # A single loop iteration covers the whole hidden dimension.
        BLOCK_D = triton.next_power_of_2(hidden_dim)

        _bilinear_pos_embed_kernel[(total_out,)](
            embed_weight,
            output,
            h,
            w,
            h_scale,
            w_scale,
            num_grid_per_side,
            m_size,
            hidden_dim,
            BLOCK_D,
        )
        return output


def pos_embed_interpolate_native(
    embed_weight: torch.Tensor,
    t: int,
    h: int,
    w: int,
    num_grid_per_side: int,
    m_size: int,
    dtype: torch.dtype,
) -> torch.Tensor:
    """Eager-mode bilinear resampling of the learned position table.

    The table is treated as a ``num_grid_per_side x num_grid_per_side`` grid
    (row-major rows of ``embed_weight``) and sampled at ``h x w`` evenly
    spaced points spanning it corner-to-corner. Results are reordered into
    spatial-merge order (``m_size x m_size`` blocks, each stored
    contiguously) and tiled ``t`` times.

    Returns:
        Tensor of shape ``(t * h * w, hidden_dim)`` cast to ``dtype``.
    """
    assert h % m_size == 0 and w % m_size == 0, (
        f"h={h} and w={w} must be divisible by m_size={m_size}"
    )
    hidden_dim = embed_weight.shape[1]
    last = num_grid_per_side - 1

    def _axis(n: int):
        # Sample positions, their lower/upper neighbours, and the fraction
        # of the way from the lower neighbour to the upper one.
        coords = torch.linspace(
            0, last, n, dtype=torch.float32, device=embed_weight.device
        )
        lo = coords.to(torch.long)
        hi = torch.clamp(lo + 1, max=last)
        return lo, hi, coords - lo

    def _flat(rows: torch.Tensor, cols: torch.Tensor) -> torch.Tensor:
        # Outer product of row/col indices, flattened into table row ids.
        r, c = torch.meshgrid(rows, cols, indexing="ij")
        return r * num_grid_per_side + c

    h_lo, h_hi, frac_h = _axis(h)
    w_lo, w_hi, frac_w = _axis(w)

    fh, fw = torch.meshgrid(frac_h, frac_w, indexing="ij")
    w11 = fh * fw
    w10 = fh - w11
    w01 = fw - w11
    w00 = 1 - fh - w01

    corner_idx = torch.stack(
        [
            _flat(h_lo, w_lo),
            _flat(h_lo, w_hi),
            _flat(h_hi, w_lo),
            _flat(h_hi, w_hi),
        ]
    ).reshape(4, -1)
    corner_weights = (
        torch.stack([w00, w01, w10, w11], dim=0).reshape(4, -1, 1).to(dtype=dtype)
    )

    # Advanced indexing yields a fresh tensor, so scaling in place is safe
    # and keeps the table's dtype for the accumulation.
    corner_embeds = embed_weight[corner_idx]
    corner_embeds *= corner_weights
    mixed = corner_embeds.sum(dim=0)

    merged = (
        mixed.reshape(h // m_size, m_size, w // m_size, m_size, hidden_dim)
        .permute(0, 2, 1, 3, 4)
        .reshape(1, -1, hidden_dim)
    )
    tiled = merged.expand(t, -1, -1).reshape(-1, hidden_dim)
    return tiled.to(dtype=dtype)

344
345
346
347
348
349
350
351
352
353
354
355
356
357
358

class Qwen3_VisionPatchEmbed(nn.Module):
    """Embed flattened pixel patches with a non-overlapping 3D convolution.

    Each input row holds one patch of ``in_channels x temporal_patch_size x
    patch_size x patch_size`` values; the convolution maps it to a single
    ``hidden_size`` vector.
    """

    def __init__(
        self,
        patch_size: int = 14,
        temporal_patch_size: int = 2,
        in_channels: int = 3,
        hidden_size: int = 1152,
    ) -> None:
        super().__init__()
        self.patch_size = patch_size
        self.temporal_patch_size = temporal_patch_size
        self.hidden_size = hidden_size

        # Stride equals kernel size, so each patch is covered exactly once.
        patch_kernel = (temporal_patch_size, patch_size, patch_size)
        self.proj = Conv3dLayer(
            in_channels,
            hidden_size,
            kernel_size=patch_kernel,
            stride=patch_kernel,
            bias=True,
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # Unpacking enforces a 2-D (num_patches, flattened_patch) input.
        num_patches, _ = x.shape
        volumes = x.view(
            num_patches,
            -1,
            self.temporal_patch_size,
            self.patch_size,
            self.patch_size,
        )
        embedded = self.proj(volumes)
        return embedded.view(num_patches, self.hidden_size)


class Qwen3_VisionMLP(nn.Module):
    """Two-layer feed-forward network: ``fc2(act_fn(fc1(x)))``.

    ``linear_fc1`` is column-parallel and ``linear_fc2`` row-parallel; both
    disable tensor parallelism when the ViT runs data-parallel. Both are
    built with ``return_bias=False`` so they return plain tensors.
    """

    def __init__(
        self,
        in_features: int,
        hidden_features: int,
        bias: bool = False,
        act_fn: Callable[[torch.Tensor], torch.Tensor] = F.silu,
        quant_config: QuantizationConfig | None = None,
        prefix: str = "",
    ):
        super().__init__()
        # Options shared by the up- and down-projection.
        shared = dict(
            bias=bias,
            quant_config=quant_config,
            return_bias=False,
            disable_tp=is_vit_use_data_parallel(),
        )
        self.linear_fc1 = ColumnParallelLinear(
            in_features,
            hidden_features,
            prefix=f"{prefix}.linear_fc1",
            **shared,
        )
        self.linear_fc2 = RowParallelLinear(
            hidden_features,
            in_features,
            prefix=f"{prefix}.linear_fc2",
            **shared,
        )
        self.act_fn = act_fn

    def forward(self, x: torch.Tensor):
        activated = self.act_fn(self.linear_fc1(x))
        return self.linear_fc2(activated)


class Qwen3_VisionBlock(nn.Module):
    """Pre-norm vision transformer layer.

    Computes ``x + attn(norm1(x))`` followed by ``x + mlp(norm2(x))``.
    ``norm_layer`` defaults to ``LayerNorm(eps=1e-6)``.
    """

    def __init__(
        self,
        dim: int,
        num_heads: int,
        mlp_hidden_dim: int,
        act_fn: Callable[[torch.Tensor], torch.Tensor] = F.silu,
        norm_layer: Callable[[int], nn.Module] | None = None,
        quant_config: QuantizationConfig | None = None,
        prefix: str = "",
    ) -> None:
        super().__init__()
        make_norm = (
            norm_layer if norm_layer is not None else partial(nn.LayerNorm, eps=1e-6)
        )
        self.norm1 = make_norm(dim)
        self.norm2 = make_norm(dim)
        self.attn = Qwen2_5_VisionAttention(
            embed_dim=dim,
            num_heads=num_heads,
            projection_size=dim,
            quant_config=quant_config,
            prefix=f"{prefix}.attn",
        )
        self.mlp = Qwen3_VisionMLP(
            dim,
            mlp_hidden_dim,
            act_fn=act_fn,
            bias=True,
            quant_config=quant_config,
            prefix=f"{prefix}.mlp",
        )

    def forward(
        self,
        x: torch.Tensor,
        cu_seqlens: torch.Tensor,
        rotary_pos_emb_cos: torch.Tensor,
        rotary_pos_emb_sin: torch.Tensor,
        max_seqlen: torch.Tensor,  # Only used for Flash Attention
        sequence_lengths: torch.Tensor,  # Only used for FlashInfer CuDNN backend
    ) -> torch.Tensor:
        attn_out = self.attn(
            self.norm1(x),
            cu_seqlens=cu_seqlens,
            rotary_pos_emb_cos=rotary_pos_emb_cos,
            rotary_pos_emb_sin=rotary_pos_emb_sin,
            max_seqlen=max_seqlen,
            sequence_lengths=sequence_lengths,
        )
        residual = x + attn_out
        mlp_out = self.mlp(self.norm2(residual))
        return residual + mlp_out


class Qwen3_VisionPatchMerger(nn.Module):
    """Merge ``spatial_merge_size**2`` neighbouring patches into one token.

    The features of the merged patches are concatenated, giving
    ``context_dim * spatial_merge_size**2`` features (``self.hidden_size``).
    These are projected through ``fc1 -> GELU -> fc2`` to ``d_model``.
    With ``use_postshuffle_norm`` the LayerNorm runs on the concatenated
    features; otherwise it runs per patch before concatenation.
    """

    def __init__(
        self,
        d_model: int,
        context_dim: int,
        norm_layer: Callable[[int], nn.Module] | None = None,
        spatial_merge_size: int = 2,
        use_postshuffle_norm: bool = False,
        quant_config: QuantizationConfig | None = None,
        prefix: str = "",
    ) -> None:
        super().__init__()
        use_data_parallel = is_vit_use_data_parallel()
        merged_dim = context_dim * (spatial_merge_size**2)
        self.hidden_size = merged_dim
        self.use_postshuffle_norm = use_postshuffle_norm

        norm_dim = merged_dim if use_postshuffle_norm else context_dim
        make_norm = (
            norm_layer if norm_layer is not None else partial(nn.LayerNorm, eps=1e-6)
        )
        self.norm = make_norm(norm_dim)

        self.linear_fc1 = ColumnParallelLinear(
            merged_dim,
            merged_dim,
            bias=True,
            quant_config=quant_config,
            prefix=f"{prefix}.linear_fc1",
            disable_tp=use_data_parallel,
        )
        self.act_fn = nn.GELU()
        self.linear_fc2 = RowParallelLinear(
            merged_dim,
            d_model,
            bias=True,
            quant_config=quant_config,
            prefix=f"{prefix}.linear_fc2",
            disable_tp=use_data_parallel,
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        normed = (
            self.norm(x.view(-1, self.hidden_size))
            if self.use_postshuffle_norm
            else self.norm(x).view(-1, self.hidden_size)
        )
        # The parallel linears return (output, bias); only output is used.
        projected, _ = self.linear_fc1(normed)
        out, _ = self.linear_fc2(self.act_fn(projected))
        return out


class Qwen3_VisionTransformer(nn.Module):
    """Qwen3-VL vision encoder with DeepStack feature taps.

    Processing steps:

    * Patches are embedded by ``Qwen3_VisionPatchEmbed``.
    * Bilinearly interpolated learned position embeddings are added.
    * The result runs through ``vision_config.depth`` ``Qwen3_VisionBlock``
      layers, which receive rotary cos/sin tables indexed by each patch's
      (h, w) position.
    * The final hidden states go through ``merger``.
    * Hidden states after each layer in ``deepstack_visual_indexes`` are
      merged by ``deepstack_merger_list``.
    * All merged outputs are concatenated along the feature dimension.
    """

    def __init__(
        self,
        vision_config: Qwen3VLVisionConfig,
        norm_eps: float = 1e-6,
        quant_config: QuantizationConfig | None = None,
        prefix: str = "",
    ) -> None:
        super().__init__()
        self.hidden_size = vision_config.hidden_size
        self.num_heads = vision_config.num_heads
        self.num_position_embeddings = vision_config.num_position_embeddings
        self.patch_size = vision_config.patch_size
        self.spatial_merge_size = vision_config.spatial_merge_size
        self.spatial_merge_unit = self.spatial_merge_size**2
        self.temporal_patch_size = vision_config.temporal_patch_size
        # Configs without DeepStack get no extra feature taps.
        self.deepstack_visual_indexes = getattr(
            vision_config, "deepstack_visual_indexes", []
        )
        # The learned position table is indexed as a square grid.
        self.num_grid_per_side = int(self.num_position_embeddings**0.5)

        use_data_parallel = is_vit_use_data_parallel()
        self.tp_size = (
            1
            if use_data_parallel
            else parallel_state.get_tensor_model_parallel_world_size()
        )

        # NOTE: This is used for creating empty tensor for all_gather for
        # DP ViT. Here out_hidden_size is enlarged due to deepstack
        self.out_hidden_size = vision_config.out_hidden_size * (
            1 + len(self.deepstack_visual_indexes)
        )

        self.patch_embed = Qwen3_VisionPatchEmbed(
            patch_size=self.patch_size,
            temporal_patch_size=self.temporal_patch_size,
            in_channels=vision_config.in_channels,
            hidden_size=self.hidden_size,
        )

        self.pos_embed = nn.Embedding(self.num_position_embeddings, self.hidden_size)

        norm_layer = partial(nn.LayerNorm, eps=norm_eps)
        head_dim = self.hidden_size // self.num_heads
        self.rotary_pos_emb = get_rope(
            head_size=head_dim,
            max_position=8192,
            is_neox_style=True,
            rope_parameters={"partial_rotary_factor": 0.5},
        )

        self.merger = Qwen3_VisionPatchMerger(
            d_model=vision_config.out_hidden_size,
            context_dim=self.hidden_size,
            norm_layer=norm_layer,
            spatial_merge_size=self.spatial_merge_size,
            quant_config=quant_config,
            prefix=f"{prefix}.merger",
        )

        self.deepstack_merger_list = nn.ModuleList(
            [
                Qwen3_VisionPatchMerger(
                    d_model=vision_config.out_hidden_size,
                    context_dim=self.hidden_size,
                    spatial_merge_size=self.spatial_merge_size,
                    use_postshuffle_norm=True,
                    norm_layer=norm_layer,
                    quant_config=quant_config,
                    prefix=f"{prefix}.deepstack_merger_list.{layer_idx}",
                )
                for layer_idx in range(len(self.deepstack_visual_indexes))
            ]
        )

        self.attn_backend = get_vit_attn_backend(
            head_size=head_dim,
            dtype=torch.get_default_dtype(),
        )

        self.blocks = nn.ModuleList(
            [
                Qwen3_VisionBlock(
                    dim=self.hidden_size,
                    num_heads=self.num_heads,
                    mlp_hidden_dim=vision_config.intermediate_size,
                    act_fn=_ACTIVATION_REGISTRY[vision_config.hidden_act],
                    norm_layer=norm_layer,
                    quant_config=quant_config,
                    prefix=f"{prefix}.blocks.{layer_idx}",
                )
                for layer_idx in range(vision_config.depth)
            ]
        )

    @property
    def dtype(self) -> torch.dtype:
        return self.patch_embed.proj.weight.dtype

    @property
    def device(self) -> torch.device:
        return self.patch_embed.proj.weight.device

    @staticmethod
    @lru_cache(maxsize=1024)
    def rot_pos_ids(h: int, w: int, spatial_merge_size: int) -> torch.Tensor:
        """Return ``(h * w, 2)`` (row, col) ids in spatial-merge order.

        The result is cached; callers must not modify it in place.
        """
        hpos_ids = np.broadcast_to(np.arange(h).reshape(h, 1), (h, w))
        h_div = h // spatial_merge_size
        w_div = w // spatial_merge_size
        hpos_ids = hpos_ids.reshape(
            h_div,
            spatial_merge_size,
            w_div,
            spatial_merge_size,
        )
        hpos_ids = hpos_ids.transpose(0, 2, 1, 3)
        hpos_ids = hpos_ids.flatten()

        wpos_ids = np.broadcast_to(np.arange(w).reshape(1, w), (h, w))
        wpos_ids = wpos_ids.reshape(
            h_div,
            spatial_merge_size,
            w_div,
            spatial_merge_size,
        )
        wpos_ids = wpos_ids.transpose(0, 2, 1, 3)
        wpos_ids = wpos_ids.flatten()

        return torch.from_numpy(np.stack([hpos_ids, wpos_ids], axis=-1))

    def rot_pos_emb(self, grid_thw: list[list[int]]):
        """Gather rotary cos/sin rows for every patch of every grid.

        Row and column tables are concatenated per patch (``flatten(1)``).
        Rows follow the same (frame, spatial-merge) order as the patches.
        """
        max_grid_size = max(max(h, w) for _, h, w in grid_thw)
        pos_ids = [
            self.rot_pos_ids(h, w, self.spatial_merge_size)
            if t == 1
            else self.rot_pos_ids(h, w, self.spatial_merge_size).repeat(t, 1)
            for t, h, w in grid_thw
        ]
        pos_ids = torch.cat(pos_ids, dim=0).to(self.device, non_blocking=True)

        # Use pre-computed cos_sin_cache from RotaryEmbedding
        cos, sin = self.rotary_pos_emb.get_cos_sin(max_grid_size)

        cos_combined = cos[pos_ids].flatten(1)
        sin_combined = sin[pos_ids].flatten(1)

        return cos_combined, sin_combined

    def fast_pos_embed_interpolate(self, grid_thw: list[list[int]]) -> torch.Tensor:
        """Interpolate position embeddings for each (t, h, w) grid.

        Returns the per-grid results concatenated along dim 0. The fused
        Triton kernel is used when available and the table lives on an
        accelerator; otherwise the eager PyTorch path is used.
        """
        embed_weight = self.pos_embed.weight
        # Triton kernels cannot be launched on CPU tensors, so a CPU-resident
        # encoder must use the eager implementation even if Triton exists.
        use_triton = HAS_TRITON and embed_weight.device.type != "cpu"
        interpolate_fn = (
            triton_pos_embed_interpolate if use_triton else pos_embed_interpolate_native
        )
        outputs = [
            interpolate_fn(
                embed_weight,
                t,
                h,
                w,
                self.num_grid_per_side,
                self.spatial_merge_size,
                self.dtype,
            )
            for t, h, w in grid_thw
        ]
        return torch.cat(outputs, dim=0)

    def prepare_encoder_metadata(
        self,
        grid_thw_list: list[list[int]],
        *,
        max_batch_size: int | None = None,
        max_seqlen_override: int | None = None,
        device: torch.device | None = None,
    ) -> dict[str, torch.Tensor | None]:
        """Compute encoder metadata from grid_thw_list.

        Shared by the eager forward path, CUDA graph capture, and
        CUDA graph replay to avoid duplicated implementation.

        Args:
            grid_thw_list: Grid configurations as list of [t, h, w].
            max_batch_size: If set, pad cu_seqlens to this size
                (needed for CUDA graph capture/replay).
            max_seqlen_override: If set, use this value for max_seqlen
                instead of computing from cu_seqlens (needed for CUDA
                graph capture to cover worst-case replay scenarios).
            device: Device to place tensors on. Defaults to self.device.
        """
        if device is None:
            device = self.device

        metadata: dict[str, torch.Tensor | None] = {}

        # Positional embeddings
        metadata["pos_embeds"] = self.fast_pos_embed_interpolate(grid_thw_list)
        rotary_cos, rotary_sin = self.rot_pos_emb(grid_thw_list)
        metadata["rotary_pos_emb_cos"] = rotary_cos
        metadata["rotary_pos_emb_sin"] = rotary_sin

        # cu_seqlens from grid_thw: one attention sequence per frame.
        grid_thw_np = np.array(grid_thw_list, dtype=np.int32)
        patches_per_frame = grid_thw_np[:, 1] * grid_thw_np[:, 2]
        cu_seqlens = np.repeat(patches_per_frame, grid_thw_np[:, 0]).cumsum(
            dtype=np.int32
        )
        cu_seqlens = np.concatenate([np.zeros(1, dtype=np.int32), cu_seqlens])

        # Pad cu_seqlens if max_batch_size specified
        if max_batch_size is not None:
            num_seqs = len(cu_seqlens) - 1
            if num_seqs < max_batch_size:
                # Repeating the final offset adds zero-length sequences.
                cu_seqlens = np.concatenate(
                    [
                        cu_seqlens,
                        np.full(
                            max_batch_size - num_seqs,
                            cu_seqlens[-1],
                            dtype=np.int32,
                        ),
                    ]
                )

        # sequence_lengths (backend-specific)
        metadata["sequence_lengths"] = MMEncoderAttention.maybe_compute_seq_lens(
            self.attn_backend, cu_seqlens, device
        )

        # max_seqlen
        if max_seqlen_override is not None:
            max_seqlen_val = max_seqlen_override
        else:
            max_seqlen_val = MMEncoderAttention.compute_max_seqlen(
                self.attn_backend, cu_seqlens
            )
        # Keep max_seqlen on CPU: attention wrappers call .item() on it,
        # and having it on GPU would capture a wasteful D2H copy in CUDA
        # graphs without changing behavior (the scalar is baked at capture).
        metadata["max_seqlen"] = torch.tensor(max_seqlen_val, dtype=torch.int32)

        # Recompute cu_seqlens (backend-specific transformation)
        metadata["cu_seqlens"] = MMEncoderAttention.maybe_recompute_cu_seqlens(
            self.attn_backend,
            cu_seqlens,
            self.hidden_size,
            self.tp_size,
            device,
        )

        return metadata

    def forward(
        self,
        x: torch.Tensor,
        grid_thw: torch.Tensor | list[list[int]],
        *,
        encoder_metadata: dict[str, torch.Tensor] | None = None,
    ) -> torch.Tensor:
        """Encode pixel patches.

        Returns the merger output concatenated with every DeepStack feature
        along dim 1.
        """
        hidden_states = x.to(device=self.device, dtype=self.dtype, non_blocking=True)
        hidden_states = self.patch_embed(hidden_states)

        if encoder_metadata is None:
            if isinstance(grid_thw, list):
                grid_thw_list = grid_thw
            else:
                grid_thw_list = grid_thw.tolist()
            encoder_metadata = self.prepare_encoder_metadata(grid_thw_list)

        pos_embeds = encoder_metadata["pos_embeds"]
        hidden_states = hidden_states + pos_embeds
        hidden_states = hidden_states.unsqueeze(1)

        deepstack_feature_lists = []
        for layer_num, blk in enumerate(self.blocks):
            hidden_states = blk(
                hidden_states,
                cu_seqlens=encoder_metadata["cu_seqlens"],
                rotary_pos_emb_cos=encoder_metadata["rotary_pos_emb_cos"],
                rotary_pos_emb_sin=encoder_metadata["rotary_pos_emb_sin"],
                max_seqlen=encoder_metadata["max_seqlen"],
                sequence_lengths=encoder_metadata.get("sequence_lengths"),
            )
            if layer_num in self.deepstack_visual_indexes:
                deepstack_merger_idx = self.deepstack_visual_indexes.index(layer_num)
                deepstack_feature = self.deepstack_merger_list[deepstack_merger_idx](
                    hidden_states
                )
                deepstack_feature_lists.append(deepstack_feature)
        hidden_states = self.merger(hidden_states)
        hidden_states = torch.cat(
            [hidden_states] + deepstack_feature_lists, dim=1
        )  # [seq_len, hidden_size * (1 + depth_of_deepstack)]
        return hidden_states

    def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]:
        """Load checkpoint weights, fusing separate q/k/v into ``attn.qkv``.

        Returns the set of parameter names that were loaded.
        """
        stacked_params_mapping = [
            # (param_name, shard_name, shard_id)
            ("attn.qkv.", "attn.q.", "q"),
            ("attn.qkv.", "attn.k.", "k"),
            ("attn.qkv.", "attn.v.", "v"),
        ]
        params_dict = dict(self.named_parameters(remove_duplicate=False))
        loaded_params: set[str] = set()

        for name, loaded_weight in weights:
            for param_name, weight_name, shard_id in stacked_params_mapping:
                if weight_name not in name:
                    continue
                name = name.replace(weight_name, param_name)

                param = params_dict[name]
                weight_loader = param.weight_loader
                weight_loader(param, loaded_weight, shard_id)
                break
            else:
                param = params_dict[name]
                weight_loader = getattr(param, "weight_loader", default_weight_loader)
                weight_loader(param, loaded_weight)
            loaded_params.add(name)
        return loaded_params


class Qwen3VLProcessingInfo(Qwen2VLProcessingInfo):
    """Processing metadata and helpers for Qwen3-VL inputs.

    Extends the Qwen2-VL processing info with the Qwen3-VL HF processor
    classes, frame-count-aware video resizing, and per-frame timestamp
    computation for videos.
    """

    def get_hf_config(self):
        return self.ctx.get_hf_config(Qwen3VLConfig)

    def get_hf_processor(self, **kwargs: object) -> Qwen3VLProcessor:
        # Use the fast processors unless the caller explicitly opts out.
        return self.ctx.get_hf_processor(
            Qwen3VLProcessor,
            use_fast=kwargs.pop("use_fast", True),
            **kwargs,
        )

    def get_image_processor(self, **kwargs: object) -> Qwen2VLImageProcessorFast:
        return self.get_hf_processor(**kwargs).image_processor

    def get_video_processor(self, **kwargs: object) -> Qwen3VLVideoProcessor:
        return self.get_hf_processor(**kwargs).video_processor

    def get_data_parser(self):
        return Qwen2VLMultiModalDataParser(
            self.get_hf_config().vision_config.spatial_merge_size,
            video_needs_metadata=True,
            expected_hidden_size=self._get_expected_hidden_size(),
        )

    def _get_vision_info(
        self,
        *,
        image_width: int,
        image_height: int,
        num_frames: int = 2,
        do_resize: bool = True,
        image_processor: Qwen2VLImageProcessorFast | Qwen3VLVideoProcessor,
        mm_kwargs: Mapping[str, object],
    ) -> tuple[ImageSize, int]:
        """Compute the preprocessed size and vision-token count of an input.

        Args:
            image_width: Original width in pixels.
            image_height: Original height in pixels.
            num_frames: Number of frames; only affects the video path and
                the temporal grid size.
            do_resize: Whether to apply ``smart_resize`` before computing
                the patch grid.
            image_processor: HF image or video processor. Passing a
                ``Qwen3VLVideoProcessor`` selects the video resize function,
                whose result also depends on ``num_frames``.
            mm_kwargs: Per-request processor kwargs, merged with the
                model-level defaults. ``size``, ``min_pixels`` and
                ``max_pixels`` override the processor's size bounds.

        Returns:
            ``(preprocessed_size, num_vision_tokens)``.
        """
        is_video = isinstance(image_processor, Qwen3VLVideoProcessor)

        hf_config = self.get_hf_config()
        vision_config = hf_config.vision_config
        patch_size = vision_config.patch_size
        merge_size = vision_config.spatial_merge_size
        temporal_patch_size = vision_config.temporal_patch_size

        mm_kwargs = self.ctx.get_merged_mm_kwargs(mm_kwargs)

        # Overrides are layered over the processor's default size bounds.
        size = image_processor.size
        if override_size := mm_kwargs.get("size"):
            size = size | override_size
        if (override_min_pixels := mm_kwargs.get("min_pixels")) is not None:
            size = size | {"shortest_edge": override_min_pixels}
        if (override_max_pixels := mm_kwargs.get("max_pixels")) is not None:
            size = size | {"longest_edge": override_max_pixels}

        if do_resize:
            if is_video:
                smart_resize = video_smart_resize
                extra_kwargs = {
                    "num_frames": num_frames,
                    "temporal_factor": temporal_patch_size,
                }
            else:
                smart_resize = image_smart_resize
                extra_kwargs = {}

            resized_height, resized_width = smart_resize(
                height=image_height,
                width=image_width,
                factor=patch_size * merge_size,
                min_pixels=size["shortest_edge"],
                max_pixels=size["longest_edge"],
                **extra_kwargs,
            )
            preprocessed_size = ImageSize(width=resized_width, height=resized_height)
        else:
            preprocessed_size = ImageSize(width=image_width, height=image_height)

        # The frame count is padded up to a multiple of the temporal patch.
        padded_num_frames = round_up(num_frames, temporal_patch_size)

        grid_t = max(padded_num_frames // temporal_patch_size, 1)
        grid_h = preprocessed_size.height // patch_size
        grid_w = preprocessed_size.width // patch_size

        num_patches = grid_t * grid_h * grid_w
        # Each merge_size x merge_size block of patches becomes one token.
        num_vision_tokens = num_patches // (merge_size**2)

        return preprocessed_size, num_vision_tokens

    def _get_max_video_frames(self, max_tokens: int, start_num_frames: int = 2) -> int:
        return super()._get_max_video_frames(
            max_tokens, start_num_frames=start_num_frames
        )

    def get_num_frames_with_most_features(
        self,
        seq_len: int,
        mm_counts: Mapping[str, int],
    ) -> int:
        return super().get_num_frames_with_most_features(
            seq_len, mm_counts, max_frames_per_video=DUMMY_VIDEO_NUM_FRAMES
        )

    def get_max_video_tokens(
        self,
        seq_len: int,
        mm_counts: Mapping[str, int],
    ) -> int:
        """Return the vision-token count of the largest 2-frame video."""
        video_processor = self.get_video_processor()

        mm_kwargs = self.ctx.get_merged_mm_kwargs({})
        video_size = video_processor.size
        if override_size := mm_kwargs.get("size"):
            # Merge (possibly partial) overrides over the processor defaults,
            # the same way `_get_vision_info` does, so that overriding only
            # "shortest_edge" does not lose "longest_edge".
            video_size = video_size | override_size
        temporal_patch_size = mm_kwargs.get(
            "temporal_patch_size", video_processor.temporal_patch_size
        )

        # The video "longest_edge" pixel budget spans `temporal_patch_size`
        # frames, so divide by it to get the per-frame pixel budget.
        video_max_pixels = video_size["longest_edge"]
        target_width, target_height = self.get_image_size_with_most_features(
            max_pixels=video_max_pixels // temporal_patch_size
        )
        num_video_soft_tokens = self.get_num_video_tokens(
            image_width=target_width,
            image_height=target_height,
            num_frames=2,
            image_processor=video_processor,
            mm_kwargs={},
        )
        return num_video_soft_tokens

    def _calculate_timestamps(
        self, indices: list[int] | torch.Tensor, video_fps: float, merge_size: int
    ) -> list[float]:
        """Return one timestamp (in seconds) per group of ``merge_size`` frames.

        The indices are padded by repeating the last one until their count is
        a multiple of ``merge_size``. Each group's timestamp is the mean of
        the times of its first and last frame.
        """
        if not isinstance(indices, list):
            indices = indices.tolist()
        if len(indices) % merge_size != 0:
            # Build a new list so metadata's frames_indices is not mutated.
            indices = indices + [indices[-1]] * (merge_size - len(indices) % merge_size)
        timestamps = [idx / video_fps for idx in indices]
        timestamps = [
            (timestamps[i] + timestamps[i + merge_size - 1]) / 2
            for i in range(0, len(timestamps), merge_size)
        ]
        return timestamps

    def _get_video_second_idx(
        self,
        metadata: dict[str, Any],
        do_sample_frames: bool | None = None,
        sampled_fps: float | None = None,
        sampled_num_frames: int | None = None,
    ) -> list[float]:
        """Compute per-temporal-patch timestamps (in seconds) for a video.

        Args:
            metadata: Video metadata providing ``frames_indices``, ``fps``
                and, when frames are sampled here, ``total_num_frames``.
            do_sample_frames: Whether the HF processor samples frames itself.
                Defaults to ``metadata["do_sample_frames"]`` (or ``False``).
            sampled_fps: Target fps used for sampling when ``num_frames`` is
                not given; falls back to the video processor's ``fps``.
            sampled_num_frames: Explicit number of frames to sample; takes
                precedence over ``sampled_fps``.

        Returns:
            One timestamp per group of ``temporal_patch_size`` frames.
        """
        video_processor = self.get_video_processor()
        temporal_patch_size = video_processor.temporal_patch_size

        indices = metadata["frames_indices"]

        # metadata["fps"] refers to the true fps of the input video.
        video_fps = metadata["fps"]
        if do_sample_frames is None:
            do_sample_frames = metadata.get("do_sample_frames", False)

        # If video frames are sampled in HF processor (instead of vLLM
        # video loader), we need to re-calculate the indices from original
        # metadata.
        if do_sample_frames:
            total_num_frames = metadata["total_num_frames"]

            # When num_frames is explicitly provided, use it directly
            # instead of computing from fps. This mirrors the behavior of
            # HF's Qwen3VLVideoProcessor.sample_frames where num_frames
            # and fps are mutually exclusive.
            if sampled_num_frames is not None:
                num_frames = sampled_num_frames
            else:
                # `sampled_fps` is the target fps of the sampled clip, while
                # `video_fps` is the fps of the original video.
                sampled_fps = sampled_fps if sampled_fps else video_processor.fps
                num_frames = int(total_num_frames / video_fps * sampled_fps)

            # Clamp to the processor's frame bounds and the available frames.
            num_frames = min(
                max(num_frames, video_processor.min_frames),
                video_processor.max_frames,
                total_num_frames,
            )
            indices = (
                np.linspace(0, total_num_frames - 1, num_frames)
                .round()
                .astype(int)
                .tolist()
            )
        return self._calculate_timestamps(indices, video_fps, temporal_patch_size)


class Qwen3VLDummyInputsBuilder(BaseDummyInputsBuilder[Qwen3VLProcessingInfo]):
    """Builds dummy prompts and multimodal data for Qwen3-VL profiling."""

    def get_dummy_text(self, mm_counts: Mapping[str, int]) -> str:
        """Return one image/video placeholder per requested item."""
        num_images = mm_counts.get("image", 0)
        num_videos = mm_counts.get("video", 0)

        image_token = "<|vision_start|><|image_pad|><|vision_end|>"
        video_token = "<|vision_start|><|video_pad|><|vision_end|>"

        return image_token * num_images + video_token * num_videos

    def get_dummy_mm_data(
        self,
        seq_len: int,
        mm_counts: Mapping[str, int],
        mm_options: Mapping[str, BaseDummyOptions],
    ) -> MultiModalDataDict:
        """Build the largest dummy images and 2-frame videos for profiling.

        User overrides in ``mm_options`` may only shrink the dummy inputs.
        Larger values, and fewer than 2 video frames, are ignored with a
        warning.
        """
        num_images = mm_counts.get("image", 0)
        num_videos = mm_counts.get("video", 0)

        image_overrides = mm_options.get("image")
        video_overrides = mm_options.get("video")

        target_image_width, target_image_height = (
            self.info.get_image_size_with_most_features()
        )

        # treat videos as special images
        target_num_frames = 2
        if video_overrides:
            assert isinstance(video_overrides, VideoDummyOptions)
            num_frames_override = video_overrides.num_frames
            if num_frames_override:
                if num_frames_override > target_num_frames:
                    logger.warning(
                        "video.num_frames override (%d) exceeds model's "
                        "maximum number of frames (%d), will be ignored",
                        num_frames_override,
                        target_num_frames,
                    )
                if num_frames_override < 2:
                    logger.warning(
                        "video.num_frames override (%d) cannot be less "
                        "than 2, will be ignored",
                        num_frames_override,
                    )
                target_num_frames = min(target_num_frames, num_frames_override)
        target_num_frames = max(target_num_frames, 2)

        video_processor = self.info.get_video_processor()

        mm_kwargs = self.info.ctx.get_merged_mm_kwargs({})
        video_size = video_processor.size
        if override_size := mm_kwargs.get("size"):
            # Merge (possibly partial) overrides over the processor defaults
            # so that overriding only "shortest_edge" keeps "longest_edge".
            video_size = video_size | override_size
        temporal_patch_size = mm_kwargs.get(
            "temporal_patch_size", video_processor.temporal_patch_size
        )

        # The video "longest_edge" pixel budget spans `temporal_patch_size`
        # frames, so divide by it to get the per-frame pixel budget.
        video_max_pixels = video_size["longest_edge"]
        target_video_width, target_video_height = (
            self.info.get_image_size_with_most_features(
                max_pixels=video_max_pixels // temporal_patch_size
            )
        )
        target_video_size, _ = self.info._get_vision_info(
            image_width=target_video_width,
            image_height=target_video_height,
            num_frames=target_num_frames,
            image_processor=video_processor,
            mm_kwargs={},
        )
        # NOTE: we need to do this check here since Qwen3-VL resizes video
        # frames depending on how many frames there are.
        target_video_width, target_video_height = (
            target_video_size.width,
            target_video_size.height,
        )
        if video_overrides:
            assert isinstance(video_overrides, VideoDummyOptions)
            width_override = video_overrides.width
            if width_override:
                if width_override > target_video_width:
                    logger.warning(
                        "video.width override (%d) exceeds model's "
                        "maximum width (%d), will be ignored",
                        width_override,
                        target_video_width,
                    )
                target_video_width = min(target_video_width, width_override)
            height_override = video_overrides.height
            if height_override:
                if height_override > target_video_height:
                    logger.warning(
                        "video.height override (%d) exceeds model's "
                        "maximum height (%d), will be ignored",
                        height_override,
                        target_video_height,
                    )
                target_video_height = min(target_video_height, height_override)

        return {
            "image": self._get_dummy_images(
                width=target_image_width,
                height=target_image_height,
                num_images=num_images,
                overrides=image_overrides,
            ),
            # Video overrides were already applied to the targets above.
            "video": self._get_dummy_videos(
                width=target_video_width,
                height=target_video_height,
                num_frames=target_num_frames,
                num_videos=num_videos,
            ),
        }

    def _get_dummy_videos(
        self,
        *,
        width: int,
        height: int,
        num_frames: int,
        num_videos: int,
        overrides: VideoDummyOptions | None = None,
    ) -> list[VideoItem]:
        """Return dummy ``(video, metadata)`` pairs.

        Each video gets metadata describing a 2 fps clip whose frames are all
        listed in ``frames_indices``, with ``do_sample_frames`` disabled.
        """
        videos = super()._get_dummy_videos(
            width=width,
            height=height,
            num_frames=num_frames,
            num_videos=num_videos,
            overrides=overrides,
        )
        videos = [v.copy() for v in videos]

        video_items = []
        for video in videos:
            video_num_frames = video.shape[0]
            video_metadata = {
                "fps": 2.0,
                "duration": video_num_frames / 2.0,
                "total_num_frames": video_num_frames,
                "frames_indices": list(range(video_num_frames)),
                "video_backend": "opencv",
                "do_sample_frames": False,
            }
            video_items.append((video, video_metadata))

        return video_items


1183
class Qwen3VLMultiModalProcessor(BaseMultiModalProcessor[Qwen3VLProcessingInfo]):
    """Multimodal processor for Qwen3-VL.

    Videos are sent through the HF processor one at a time. This lets each
    video's placeholder (per-frame timestamps plus video tokens) be built
    from vLLM-computed timestamps and, when EVS is enabled, pruned token
    counts. Images go through the regular HF processor call.
    """

    @staticmethod
    def _get_video_tokens_per_frame(
        *,
        num_frames: int,
        tokens_per_frame_base: int,
        video_pruning_rate: float | None,
    ) -> tuple[list[int], bool]:
        """Return per-frame video-token counts and the ``select_token_id`` flag.

        Without EVS, every frame keeps ``tokens_per_frame_base`` tokens and
        only ``video_token_id`` positions are selected. With EVS
        (``video_pruning_rate > 0``), the retained-token total is computed up
        front and assigned entirely to the first frame. The per-frame split
        is only a placeholder; what matters is the total token count.
        """
        if video_pruning_rate is not None and video_pruning_rate > 0.0:
            num_tokens = compute_retained_tokens_count(
                tokens_per_frame=tokens_per_frame_base,
                num_frames=num_frames,
                q=video_pruning_rate,
            )
            return [num_tokens] + [0] * (num_frames - 1), False
        return [tokens_per_frame_base] * num_frames, True

    def _call_hf_processor(
        self,
        prompt: str,
        mm_data: Mapping[str, object],
        mm_kwargs: Mapping[str, object],
        tok_kwargs: Mapping[str, object],
    ) -> BatchFeature:
        """Run the HF processor, handling each video separately.

        Each item of ``mm_data["videos"]`` is a ``(array, metadata)`` pair.
        Each video is processed on its own. The first remaining
        ``<|vision_start|><|video_pad|><|vision_end|>`` in ``prompt`` is then
        replaced with that video's expanded placeholder text. Finally, the
        updated prompt and the remaining (non-video) data are processed
        together.

        Returns:
            The HF outputs for the prompt and non-video data. When videos are
            present, ``pixel_values_videos``, ``video_grid_thw`` and
            per-video ``timestamps`` are added.
        """
        mm_data = dict(mm_data)
        processor = self.info.get_hf_processor(**mm_kwargs)

        # Separate video processing from image processing, because the videos
        # are processed into several image patches.
        if videos := mm_data.pop("videos", []):
            # These do not depend on the individual video; fetch them once.
            tokenizer = self.info.get_tokenizer()
            hf_config = self.info.get_hf_config()
            merge_size = processor.video_processor.merge_size
            video_pruning_rate = self.info.ctx.get_mm_config().video_pruning_rate

            video_grid_thw_lst = []
            pixel_values_videos_lst = []
            timestamps_per_video = []

            for item in videos:
                video_array, metadata = item

                # NOTE: @JJJYmmm new attr metadata.frames_indices indicates
                # the sampled frames indices of pre-sampled videos, which is
                # used to calculate the timestamps. Make sure that
                # do_sample_frames in mm_kwargs is false for presampled videos.

                # NOTE: a copy of is created to update do_sample_frames,
                # otherwise mm_hash for the object will be incorrect.
                video_mm_kwargs = dict(**mm_kwargs)
                if "do_sample_frames" not in video_mm_kwargs:
                    # qwen_vl_utils already has "do_sample_frames" in
                    # mm_kwargs, don't overwrite it.
                    video_mm_kwargs["do_sample_frames"] = metadata.get(
                        "do_sample_frames", False
                    )

                metadata = VideoMetadata(
                    **{k: metadata[k] for k in metadata if k != "do_sample_frames"}
                )

                # Compute timestamps here where we have access to metadata
                timestamps = self.info._get_video_second_idx(
                    metadata=metadata,
                    do_sample_frames=video_mm_kwargs["do_sample_frames"],
                    sampled_fps=video_mm_kwargs.get("fps"),
                    sampled_num_frames=video_mm_kwargs.get("num_frames"),
                )
                timestamps_per_video.append(timestamps)

                video_mm_data = dict()
                video_mm_data["videos"] = [[video_array]]
                video_mm_data["video_metadata"] = [[metadata]]

                # When num_frames is specified, explicitly set fps=None
                # to prevent HF's BaseVideoProcessor.preprocess() from
                # filling in the class default (fps=2) via setdefault(),
                # which would conflict with num_frames (mutually exclusive).
                if "num_frames" in video_mm_kwargs and "fps" not in video_mm_kwargs:
                    video_mm_kwargs["fps"] = None

                video_outputs = super()._call_hf_processor(
                    prompt="<|vision_start|><|video_pad|><|vision_end|>",
                    mm_data=video_mm_data,
                    mm_kwargs=video_mm_kwargs,
                    tok_kwargs=tok_kwargs,
                )
                # Discard HF's expanded prompt ids. Popping keeps them out of
                # the merged outputs below. The placeholder text is rebuilt
                # from vLLM's own timestamps and (EVS-adjusted) token counts,
                # so it matches `_get_prompt_updates` exactly. Reusing HF's
                # expansion would ignore EVS pruning.
                video_outputs.pop("input_ids")

                # Get video grid info for EVS calculation.
                video_grid_thw = video_outputs["video_grid_thw"]
                num_frames = int(video_grid_thw[0, 0])
                tokens_per_frame_base = int(video_grid_thw[0, 1:].prod()) // (
                    merge_size**2
                )
                tokens_per_frame, select_token_id = self._get_video_tokens_per_frame(
                    num_frames=num_frames,
                    tokens_per_frame_base=tokens_per_frame_base,
                    video_pruning_rate=video_pruning_rate,
                )

                video_repl = Qwen3VLMultiModalProcessor.get_video_repl(
                    tokens_per_frame=tokens_per_frame,
                    timestamps=timestamps,
                    tokenizer=tokenizer,
                    vision_start_token_id=hf_config.vision_start_token_id,
                    vision_end_token_id=hf_config.vision_end_token_id,
                    video_token_id=hf_config.video_token_id,
                    select_token_id=select_token_id,
                )

                # Convert token IDs to text for the HF processor flow
                video_placeholder = tokenizer.decode(
                    video_repl.full, skip_special_tokens=False
                )
                prompt = prompt.replace(
                    "<|vision_start|><|video_pad|><|vision_end|>",
                    video_placeholder,
                    1,
                )

                video_grid_thw_lst.append(video_grid_thw)
                pixel_values_videos_lst.append(video_outputs["pixel_values_videos"])

            video_outputs = dict(
                pixel_values_videos=torch.cat(pixel_values_videos_lst),
                video_grid_thw=torch.cat(video_grid_thw_lst),
                timestamps=timestamps_per_video,
            )
        else:
            video_outputs = dict()

        processed_outputs = super()._call_hf_processor(
            prompt=prompt,
            mm_data=mm_data,
            mm_kwargs=mm_kwargs,
            tok_kwargs=tok_kwargs,
        )
        combined_outputs = dict(
            processed_outputs,
            **video_outputs,
        )
        return BatchFeature(combined_outputs)

    def _get_mm_fields_config(
        self,
        hf_inputs: BatchFeature,
        hf_processor_mm_kwargs: Mapping[str, object],
    ) -> Mapping[str, MultiModalFieldConfig]:
        return _create_qwen2vl_field_factory(
            self.info.get_hf_config().vision_config.spatial_merge_size
        )(hf_inputs)

    def _get_prompt_updates(
        self,
        mm_items: MultiModalDataItems,
        hf_processor_mm_kwargs: Mapping[str, Any],
        out_mm_kwargs: MultiModalKwargsItems,
    ) -> Sequence[PromptUpdate]:
        """Return replacements that expand image and video placeholders."""
        hf_processor = self.info.get_hf_processor(**hf_processor_mm_kwargs)
        image_processor = self.info.get_image_processor(**hf_processor_mm_kwargs)
        tokenizer = self.info.get_tokenizer()
        hf_config = self.info.get_hf_config()

        video_token_id = hf_config.video_token_id
        vision_start_token_id = hf_config.vision_start_token_id
        vision_end_token_id = hf_config.vision_end_token_id

        merge_length = image_processor.merge_size**2

        def get_image_replacement_qwen3vl(item_idx: int):
            out_item = out_mm_kwargs["image"][item_idx]
            grid_thw = out_item["image_grid_thw"].data
            assert isinstance(grid_thw, torch.Tensor)

            num_tokens = int(grid_thw.prod()) // merge_length
            return [hf_processor.image_token_id] * num_tokens

        def get_video_replacement_qwen3vl(item_idx: int):
            out_item = out_mm_kwargs["video"][item_idx]
            grid_thw = out_item["video_grid_thw"].data
            assert isinstance(grid_thw, torch.Tensor)

            timestamps = out_item["timestamps"].data
            assert len(timestamps) == grid_thw[0], (
                f"The timestamps length({len(timestamps)}) should be equal "
                f"video length ({grid_thw[0]})."
            )

            # Compute tokens per frame, with EVS support
            num_frames = int(grid_thw[0])
            tokens_per_frame_base = int(grid_thw[1:].prod()) // merge_length
            tokens_per_frame, select_token_id = self._get_video_tokens_per_frame(
                num_frames=num_frames,
                tokens_per_frame_base=tokens_per_frame_base,
                video_pruning_rate=self.info.ctx.get_mm_config().video_pruning_rate,
            )

            return Qwen3VLMultiModalProcessor.get_video_repl(
                tokens_per_frame=tokens_per_frame,
                timestamps=timestamps,
                tokenizer=tokenizer,
                vision_start_token_id=vision_start_token_id,
                vision_end_token_id=vision_end_token_id,
                video_token_id=video_token_id,
                select_token_id=select_token_id,
            )

        return [
            PromptReplacement(
                modality="image",
                target=hf_processor.image_token,
                replacement=get_image_replacement_qwen3vl,
            ),
            # NOTE: We match string on purpose since searching sequence of
            # token ids takes more time.
            PromptReplacement(
                modality="video",
                target="<|vision_start|><|video_pad|><|vision_end|>",
                replacement=get_video_replacement_qwen3vl,
            ),
        ]

    @staticmethod
    def get_video_repl(
        *,
        tokens_per_frame: list[int],
        timestamps: list[float | int],
        tokenizer: TokenizerLike,
        vision_start_token_id: int,
        vision_end_token_id: int,
        video_token_id: int,
        select_token_id: bool = False,
    ) -> PromptUpdateDetails[list[int]]:
        """Build prompt replacement for a video in Qwen3VL format.

        The replacement structure for each frame is:
        timestamp_tokens + vision_start_token + video_tokens + vision_end_token

        Args:
            tokens_per_frame: Number of video tokens per frame (can vary per
                frame for EVS).
            timestamps: Timestamp in seconds for each frame.
            tokenizer: Tokenizer used to encode the timestamp strings.
            vision_start_token_id: Token ID for the vision start marker.
            vision_end_token_id: Token ID for the vision end marker.
            video_token_id: Token ID for video content.
            select_token_id: If True, only ``video_token_id`` positions are
                marked as embedding targets. Otherwise the whole sequence is
                returned via ``PromptUpdateDetails.from_seq``.

        Returns:
            PromptUpdateDetails with the full token sequence.
        """
        assert len(timestamps) == len(tokens_per_frame), (
            "timestamps and tokens_per_frame must have the same length"
        )

        # Tokenize timestamp strings independently to avoid tokenizer merging
        # tokens across boundaries.
        # TODO: switch to `_seq2tokens` which has some caching.
        timestamp_token_ids = [
            tokenizer.encode(f"<{timestamp:.1f} seconds>", add_special_tokens=False)
            for timestamp in timestamps
        ]

        # Build the full token sequence
        all_token_ids = []
        for frame_timestamp_ids, num_tokens in zip(
            timestamp_token_ids, tokens_per_frame
        ):
            # Add timestamp tokens
            all_token_ids.extend(frame_timestamp_ids)

            # Add vision tokens: vision_start + video_tokens + vision_end
            all_token_ids.append(vision_start_token_id)
            all_token_ids.extend([video_token_id] * num_tokens)
            all_token_ids.append(vision_end_token_id)

        if select_token_id:
            return PromptUpdateDetails.select_token_id(all_token_ids, video_token_id)

        # NOTE: we use `from_seq` instead of `select_token_id` because we want
        # all tokens in the placeholder to be initially marked as candidates.
        # Then in `get_input_embeddings`, we refine the mask to only replace
        # `video_token_id` / `image_token_id` positions with video/image
        # embeddings, keeping text embeddings for timestamps and structural
        # tokens.
        return PromptUpdateDetails.from_seq(all_token_ids)

1476
1477
1478
1479
1480
1481
1482
1483
1484
1485

@support_torch_compile(
    dynamic_arg_dims={
        "input_ids": 0,
        # positions is of shape (3, seq_len) if mrope is enabled for qwen2-vl,
        # otherwise (seq_len, ).
        "positions": -1,
        "intermediate_tensors": 0,
        "inputs_embeds": 0,
        # the same shape as input_embeds
        "deepstack_input_embeds": 0,
    }
)
class Qwen3LLMModel(Qwen3Model):
    """Qwen3 decoder that adds deepstack visual embeddings to early layers."""

    def forward(
        self,
        input_ids: torch.Tensor | None,
        positions: torch.Tensor,
        intermediate_tensors: IntermediateTensors | None = None,
        inputs_embeds: torch.Tensor | None = None,
        # args for deepstack
        deepstack_input_embeds: IntermediateTensors | None = None,
    ) -> torch.Tensor | IntermediateTensors:
        """Run this pipeline stage's decoder layers.

        After decoder layer ``i`` (for ``i < len(deepstack_input_embeds)``),
        ``deepstack_input_embeds[f"deepstack_input_embeds_{i}"]`` is added to
        the hidden states.

        Returns:
            ``IntermediateTensors`` (hidden states and residual) on non-last
            pipeline ranks. Otherwise, the normalized hidden states, paired
            with the collected auxiliary hidden states when any were
            recorded.
        """
        if get_pp_group().is_first_rank:
            if inputs_embeds is not None:
                hidden_states = inputs_embeds
            else:
                hidden_states = self.embed_input_ids(input_ids)
            residual = None
        else:
            assert intermediate_tensors is not None
            hidden_states = intermediate_tensors["hidden_states"]
            residual = intermediate_tensors["residual"]

        aux_hidden_states = self._maybe_add_hidden_state([], 0, hidden_states, residual)
        for layer_idx, layer in islice(
            enumerate(self.layers), self.start_layer, self.end_layer
        ):
            hidden_states, residual = layer(
                positions,
                hidden_states,
                residual,
            )

            # layer_idx comes from enumerate(), so it is never negative; a
            # single upper-bound check is equivalent to the range test.
            if (
                deepstack_input_embeds is not None
                and layer_idx < len(deepstack_input_embeds)
            ):
                hidden_states = (
                    hidden_states
                    + deepstack_input_embeds[f"deepstack_input_embeds_{layer_idx}"]
                )

            self._maybe_add_hidden_state(
                aux_hidden_states, layer_idx + 1, hidden_states, residual
            )

        if not get_pp_group().is_last_rank:
            return IntermediateTensors(
                {"hidden_states": hidden_states, "residual": residual}
            )
        hidden_states, _ = self.norm(hidden_states, residual)

        if len(aux_hidden_states) > 0:
            return hidden_states, aux_hidden_states
        return hidden_states


class Qwen3LLMForCausalLM(Qwen3ForCausalLM):
    """Causal-LM wrapper used as the text backbone of Qwen3-VL.

    Owns a :class:`Qwen3LLMModel`, an LM head (real, tied or a pipeline
    placeholder) and a logits processor.
    """

    def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
        # Skip Qwen3ForCausalLM.__init__ and call the next initializer in the
        # MRO; this class constructs its own Qwen3LLMModel below.
        super(Qwen3ForCausalLM, self).__init__()
        hf_config = vllm_config.model_config.hf_config
        quant_cfg = vllm_config.quant_config

        self.config = hf_config
        self.quant_config = quant_cfg

        self.model = Qwen3LLMModel(
            vllm_config=vllm_config, prefix=maybe_prefix(prefix, "model")
        )

        # Only the last pipeline-parallel rank holds a real LM head.
        if not get_pp_group().is_last_rank:
            self.lm_head = PPMissingLayer()
        elif hf_config.tie_word_embeddings:
            # Reuse the input embedding table as the output projection.
            self.lm_head = self.model.embed_tokens
        else:
            self.lm_head = ParallelLMHead(
                hf_config.vocab_size,
                hf_config.hidden_size,
                quant_config=quant_cfg,
                prefix="lm_head",
            )

        self.logits_processor = LogitsProcessor(hf_config.vocab_size)
        self.make_empty_intermediate_tensors = (
            self.model.make_empty_intermediate_tensors
        )
1573
1574


1575
1576
1577
1578
1579
1580
@MULTIMODAL_REGISTRY.register_processor(
    Qwen3VLMultiModalProcessor,
    info=Qwen3VLProcessingInfo,
    dummy_inputs=Qwen3VLDummyInputsBuilder,
)
class Qwen3VLForConditionalGeneration(
1581
1582
    nn.Module,
    SupportsMultiModal,
1583
    SupportsEncoderCudaGraph,
1584
1585
1586
    SupportsLoRA,
    SupportsPP,
    SupportsMRoPE,
1587
    SupportsEagle,
1588
    SupportsEagle3,
1589
    SupportsMultiModalPruning,
1590
):
1591
1592
1593
1594
1595
1596
1597
1598
1599
1600
    # Fused parameter name -> names of the checkpoint shards packed into it.
    packed_modules_mapping = {
        "qkv_proj": ["q_proj", "k_proj", "v_proj"],
        "gate_up_proj": ["gate_proj", "up_proj"],
        # The vision tower's QKV is already stored packed in the checkpoint.
        "qkv": ["qkv"],
    }

    supports_encoder_tp_data = True

    # Rewrites checkpoint key prefixes onto this module's attribute layout
    # (vision tower under `visual.`, text model under `language_model.`).
    hf_to_vllm_mapper = WeightsMapper(
        orig_to_new_prefix={
            "model.visual.": "visual.",
            "lm_head.": "language_model.lm_head.",
            "model.language_model.": "language_model.model.",
        }
    )
1614
1615

    @classmethod
    def get_placeholder_str(cls, modality: str, i: int) -> str | None:
        """Return the prompt placeholder text for an image or video item.

        Any modality name starting with ``"image"`` or ``"video"`` is
        accepted. The item index ``i`` does not influence the result.

        Raises:
            ValueError: If ``modality`` is neither an image nor a video kind.
        """
        placeholders = (
            ("image", "<|vision_start|><|image_pad|><|vision_end|>"),
            ("video", "<|vision_start|><|video_pad|><|vision_end|>"),
        )
        for modality_prefix, placeholder in placeholders:
            if modality.startswith(modality_prefix):
                return placeholder
        raise ValueError("Only image or video modality is supported")

    def __init__(self, *, vllm_config: VllmConfig, prefix: str = "model"):
        """Build the vision tower, deepstack buffers and language model.

        Args:
            vllm_config: Global vLLM configuration. Its ``hf_config`` is read
                as a ``Qwen3VLConfig``.
            prefix: Parameter-name prefix for this module's submodules.
        """
        super().__init__()
        config: Qwen3VLConfig = vllm_config.model_config.hf_config
        quant_config = vllm_config.quant_config
        multimodal_config = vllm_config.model_config.multimodal_config

        self.config = config
        self._tokenizer = cached_tokenizer_from_config(vllm_config.model_config)
        self.multimodal_config = multimodal_config
        # When True, image/video encoding goes through
        # run_dp_sharded_mrope_vision_model instead of calling self.visual.
        self.use_data_parallel = multimodal_config.mm_encoder_tp_mode == "data"
        self.video_pruning_rate = multimodal_config.video_pruning_rate
        self.is_multimodal_pruning_enabled = (
            multimodal_config.is_multimodal_pruning_enabled()
        )

        # Deepstack: one extra visual feature level per entry of
        # deepstack_visual_indexes; each level is added to the hidden states
        # of the matching early decoder layer in Qwen3LLMModel.forward.
        self.use_deepstack = hasattr(config.vision_config, "deepstack_visual_indexes")
        self.deepstack_num_level = (
            len(config.vision_config.deepstack_visual_indexes)
            if self.use_deepstack
            else 0
        )
        self.visual_dim = config.vision_config.out_hidden_size
        self.multiscale_dim = self.visual_dim * self.deepstack_num_level

        with self._mark_tower_model(vllm_config, {"image", "video"}):
            self.visual = Qwen3_VisionTransformer(
                config.vision_config,
                norm_eps=getattr(config, "rms_norm_eps", 1e-6),
                quant_config=quant_config,
                prefix=maybe_prefix(prefix, "visual"),
            )

            # Persistent per-level buffers (max_num_batched_tokens rows each)
            # that hold deepstack embeddings for the current batch.
            if self.use_deepstack:
                self.deepstack_input_embeds = [
                    torch.zeros(
                        vllm_config.scheduler_config.max_num_batched_tokens,
                        config.text_config.hidden_size,
                    )
                    for _ in range(self.deepstack_num_level)
                ]

        with self._mark_language_model(vllm_config):
            self.language_model = Qwen3LLMForCausalLM(
                vllm_config=vllm_config.with_hf_config(config.text_config),
                prefix=maybe_prefix(prefix, "language_model"),
            )

        # Deepstack levels are consumed by decoder layers
        # [0, deepstack_num_level), so a non-first pipeline stage must start
        # after that range. `start_layer` lives on the inner Qwen3LLMModel;
        # Qwen3LLMForCausalLM bypasses its base __init__ and never sets one.
        if not get_pp_group().is_first_rank and self.use_deepstack:
            assert (
                self.language_model.model.start_layer >= self.deepstack_num_level
            ), (
                "start_layer should be greater than or equal to "
                "len(deepstack_visual_indexes)"
            )

        self.make_empty_intermediate_tensors = (
            self.language_model.make_empty_intermediate_tensors
        )
1685

1686
1687
1688
1689
1690
1691
1692
    def _get_deepstack_input_embeds(
        self,
        num_tokens: int,
    ) -> IntermediateTensors | None:
        """Return views of the first ``num_tokens`` rows of each deepstack buffer.

        Entries are keyed ``deepstack_input_embeds_{level}`` and are slices
        (views) of the persistent buffers. The buffers are NOT cleared here;
        ``_clear_deepstack_input_embeds`` does that.

        Returns:
            The per-level views, or None when no deepstack buffers exist
            (deepstack disabled, or the vision tower was skipped).
        """
        buffers = getattr(self, "deepstack_input_embeds", None)
        if not buffers:
            return None

        return IntermediateTensors(
            {
                f"deepstack_input_embeds_{idx}": buffers[idx][:num_tokens]
                for idx in range(self.deepstack_num_level)
            }
        )

    def _set_deepstack_input_embeds(self, deepstack_input_embeds: torch.Tensor) -> None:
        """Copy per-level deepstack embeddings into the persistent buffers.

        ``deepstack_input_embeds`` is indexed ``[level, token, ...]``. Its size
        along dim 1 is the token count, and ``deepstack_input_embeds[level]``
        is written into the buffer for that level.

        If the token count exceeds the current buffer length, every buffer is
        first replaced by a larger zero-filled one with the same device and
        dtype. Does nothing when no buffers exist.
        """
        if not getattr(self, "deepstack_input_embeds", None):
            return

        n_tokens = deepstack_input_embeds.size(1)
        current = self.deepstack_input_embeds[0]
        needs_growth = n_tokens > current.size(0)
        if needs_growth:
            hidden = self.config.text_config.hidden_size
            self.deepstack_input_embeds = [
                torch.zeros(
                    n_tokens, hidden, device=current.device, dtype=current.dtype
                )
                for _ in range(self.deepstack_num_level)
            ]

        for level in range(self.deepstack_num_level):
            target = self.deepstack_input_embeds[level]
            target[:n_tokens].copy_(deepstack_input_embeds[level])
1723
1724

    def _clear_deepstack_input_embeds(self, num_tokens: int) -> None:
        """Zero the first ``num_tokens`` rows of every deepstack buffer.

        Does nothing when no buffers exist or ``num_tokens`` is not positive.
        """
        has_buffers = bool(getattr(self, "deepstack_input_embeds", None))
        if not has_buffers or num_tokens <= 0:
            return
        for level in range(self.deepstack_num_level):
            self.deepstack_input_embeds[level][:num_tokens].zero_()

1733
1734
1735
1736
1737
1738
1739
1740
1741
1742
1743
1744
1745
1746
1747
1748
1749
1750
1751
1752
1753
1754
1755
1756
1757
1758
1759
1760
1761
1762
1763
1764
1765
1766
1767
1768
1769
1770
1771
1772
1773
1774
1775
1776
1777
1778
1779
1780
1781
1782
1783
1784
1785
1786
1787
1788
1789
1790
1791
1792
1793
1794
1795
1796
1797
1798
1799
1800
1801
1802
1803
1804
1805
1806
1807
1808
1809
1810
1811
1812
1813
1814
1815
1816
1817
1818
1819
1820
1821
1822
1823
1824
1825
1826
1827
1828
1829
1830
1831
1832
1833
1834
1835
1836
1837
1838
1839
1840
1841
1842
1843
1844
1845
1846
1847
1848
1849
1850
1851
1852
1853
1854
1855
1856
1857
1858
1859
1860
1861
1862
1863
1864
1865
1866
1867
1868
1869
1870
1871
1872
1873
1874
1875
1876
1877
1878
1879
1880
1881
1882
1883
1884
1885
1886
1887
1888
1889
1890
1891
1892
1893
1894
1895
1896
1897
1898
1899
1900
1901
1902
1903
1904
    # -- SupportsEncoderCudaGraph protocol methods --

    def get_encoder_cudagraph_config(self):
        """Describe how the vision encoder is captured into CUDA graphs.

        Only the image modality is configured. ``pixel_values`` is the input
        key, and the listed metadata keys are passed as ``buffer_keys``.
        """
        from vllm.v1.worker.gpu.mm.encoder_cudagraph_defs import (
            EncoderCudaGraphConfig,
        )

        metadata_keys = [
            "pos_embeds",
            "rotary_pos_emb_cos",
            "rotary_pos_emb_sin",
            "cu_seqlens",
            "max_seqlen",
            "sequence_lengths",
        ]
        return EncoderCudaGraphConfig(
            modalities=["image"],
            input_key="pixel_values",
            buffer_keys=metadata_keys,
            out_hidden_size=self.visual.out_hidden_size,
        )

    def get_encoder_cudagraph_budget_range(
        self,
        vllm_config,
    ) -> tuple[int, int]:
        """Return ``(min_budget, max_budget)`` encoder token budgets.

        The minimum is a fixed estimate of 64 output tokens. The maximum is
        the scheduler's ``max_num_batched_tokens``.
        """
        # 64 = 8 x 8 merged tokens: the estimate assumes a 224x224 image cut
        # into 16x16 patches with spatial_merge_size=2.
        # NOTE(review): the patch grid depends on the vision patch size;
        # confirm this lower bound against the model config.
        smallest_estimate = 64
        scheduler_cap = vllm_config.scheduler_config.max_num_batched_tokens
        return smallest_estimate, scheduler_cap

    def get_encoder_cudagraph_num_items(
        self,
        mm_kwargs: dict[str, Any],
    ) -> int:
        """Return the number of images described by ``mm_kwargs``."""
        image_grids = mm_kwargs["image_grid_thw"]
        return len(image_grids)

    def get_encoder_cudagraph_per_item_output_tokens(
        self,
        mm_kwargs: dict[str, Any],
    ) -> list[int]:
        """Return the number of encoder output tokens for each image.

        An image with grid ``(t, h, w)`` yields ``t * (h // m) * (w // m)``
        tokens, where ``m`` is the vision tower's spatial merge size.

        ``image_grid_thw`` may be a tensor or nested lists. Every count is
        converted to a Python ``int``, so the result honours the declared
        ``list[int]`` type instead of holding 0-d tensors.
        """
        m = self.visual.spatial_merge_size
        return [
            int(t) * (int(h) // m) * (int(w) // m)
            for t, h, w in mm_kwargs["image_grid_thw"]
        ]

    def get_encoder_cudagraph_per_item_input_sizes(
        self,
        mm_kwargs: dict[str, Any],
    ) -> list[int]:
        """Return the number of input patches (``t * h * w``) for each image.

        Counts are converted to Python ``int`` so the result matches the
        declared ``list[int]`` type even when ``image_grid_thw`` is a tensor.
        """
        return [int(t) * int(h) * int(w) for t, h, w in mm_kwargs["image_grid_thw"]]

    def select_encoder_cudagraph_items(
        self,
        mm_kwargs: dict[str, Any],
        indices: list[int],
    ) -> dict[str, Any]:
        """Build a sub-batch holding only the images at ``indices``.

        ``pixel_values`` is the row-wise concatenation of all images' patches.
        Each image's rows are found from cumulative ``t * h * w`` offsets and
        gathered in the order given by ``indices``. An empty ``indices``
        yields an empty ``pixel_values`` slice and an empty grid list.
        """
        grid_thw = mm_kwargs["image_grid_thw"]
        pixel_values = mm_kwargs["pixel_values"]

        if len(indices) == 0:
            return {
                "pixel_values": pixel_values[:0],
                "image_grid_thw": [],
            }

        # Image k occupies rows boundaries[k] .. boundaries[k + 1].
        boundaries = [0]
        running_total = 0
        for t, h, w in grid_thw:
            running_total = running_total + t * h * w
            boundaries.append(running_total)

        chunks = [pixel_values[boundaries[k] : boundaries[k + 1]] for k in indices]
        return {
            "pixel_values": torch.cat(chunks),
            "image_grid_thw": [grid_thw[k] for k in indices],
        }

    def prepare_encoder_cudagraph_capture_inputs(
        self,
        token_budget: int,
        max_batch_size: int,
        device: torch.device,
        dtype: torch.dtype,
    ):
        """Fabricate inputs and metadata for capturing one encoder graph.

        ``max_batch_size`` identical synthetic images are created. Each one
        yields ``token_budget // max_batch_size`` output tokens and comes with
        random ``pixel_values`` of matching shape.
        """
        from vllm.v1.worker.gpu.mm.encoder_cudagraph_defs import (
            EncoderCudaGraphCaptureInputs,
        )

        merge = self.visual.spatial_merge_size
        tokens_per_image = token_budget // max_batch_size

        # A [1, merge, tokens_per_image * merge] grid merges down to exactly
        # tokens_per_image output tokens.
        grid_config = [
            [1, merge, tokens_per_image * merge] for _ in range(max_batch_size)
        ]

        embed = self.visual.patch_embed
        patch_dim = (
            embed.proj.in_channels
            * embed.temporal_patch_size
            * embed.patch_size
            * embed.patch_size
        )
        num_patches = sum(t * h * w for t, h, w in grid_config)
        dummy_pixel_values = torch.randn(
            num_patches, patch_dim, device=device, dtype=dtype
        )

        # max_seqlen is read via .item() during capture and baked into the
        # graph (not replayed), so capture with the worst case: a single
        # image using the whole budget, i.e. token_budget * merge**2 patches.
        worst_case_seqlen = token_budget * (merge**2)
        buffers = self.visual.prepare_encoder_metadata(
            grid_config,
            max_batch_size=max_batch_size,
            max_seqlen_override=worst_case_seqlen,
            device=device,
        )

        return EncoderCudaGraphCaptureInputs(
            mm_kwargs={
                "pixel_values": dummy_pixel_values,
                "image_grid_thw": grid_config,
            },
            buffers=buffers,
        )

    def prepare_encoder_cudagraph_replay_buffers(
        self,
        mm_kwargs: dict[str, Any],
        max_batch_size: int,
    ):
        """Compute encoder metadata buffers for replaying a captured graph."""
        from vllm.v1.worker.gpu.mm.encoder_cudagraph_defs import (
            EncoderCudaGraphReplayBuffers,
        )

        return EncoderCudaGraphReplayBuffers(
            buffers=self.visual.prepare_encoder_metadata(
                mm_kwargs["image_grid_thw"],
                max_batch_size=max_batch_size,
            )
        )

    def encoder_cudagraph_forward(
        self,
        mm_kwargs: dict[str, Any],
        buffers: dict[str, torch.Tensor],
    ) -> torch.Tensor:
        """Run the vision tower using precomputed metadata ``buffers``."""
        return self.visual(
            mm_kwargs["pixel_values"],
            mm_kwargs["image_grid_thw"],
            encoder_metadata=buffers,
        )

    def encoder_eager_forward(
        self,
        mm_kwargs: dict[str, Any],
    ) -> torch.Tensor:
        """Run the vision tower eagerly, without graph metadata."""
        return self.visual(mm_kwargs["pixel_values"], mm_kwargs["image_grid_thw"])

1905
    def _parse_and_validate_image_input(
        self, **kwargs: object
    ) -> Qwen2_5_VLImageInputs | None:
        """Wrap image keyword inputs in the matching input schema.

        Raw ``pixel_values`` take precedence over precomputed
        ``image_embeds``, and ``image_grid_thw`` is attached to either form.
        Returns None when neither is supplied.
        """
        image_grid_thw = kwargs.get("image_grid_thw")

        pixel_values = kwargs.get("pixel_values")
        if pixel_values is not None:
            return Qwen2_5_VLImagePixelInputs(
                type="pixel_values",
                pixel_values=pixel_values,
                image_grid_thw=image_grid_thw,
            )

        image_embeds = kwargs.get("image_embeds")
        if image_embeds is not None:
            return Qwen2_5_VLImageEmbeddingInputs(
                type="image_embeds",
                image_embeds=image_embeds,
                image_grid_thw=image_grid_thw,
            )

        return None
1928
1929

    def _parse_and_validate_video_input(
        self, **kwargs: object
    ) -> Qwen2_5_VLVideoInputs | None:
        """Wrap video keyword inputs in the matching input schema.

        ``pixel_values_videos`` takes precedence over ``video_embeds``.
        ``video_grid_thw`` and ``timestamps`` are attached to either form, and
        ``second_per_grid_ts`` only to the pixel form. Returns None when
        neither pixels nor embeddings are supplied.
        """
        video_grid_thw = kwargs.get("video_grid_thw")
        timestamps = kwargs.get("timestamps")

        pixel_values_videos = kwargs.get("pixel_values_videos")
        if pixel_values_videos is not None:
            return Qwen2_5_VLVideoPixelInputs(
                type="pixel_values_videos",
                pixel_values_videos=pixel_values_videos,
                video_grid_thw=video_grid_thw,
                second_per_grid_ts=kwargs.get("second_per_grid_ts"),
                timestamps=timestamps,
            )

        video_embeds = kwargs.get("video_embeds")
        if video_embeds is not None:
            return Qwen2_5_VLVideoEmbeddingInputs(
                type="video_embeds",
                video_embeds=video_embeds,
                video_grid_thw=video_grid_thw,
                timestamps=timestamps,
            )

        return None
1957
1958

    def _process_image_input(
        self, image_input: Qwen2_5_VLImageInputs
    ) -> tuple[torch.Tensor, ...]:
        """Encode images (or take precomputed embeddings) and split per image.

        Returns one tensor per image with ``t*h*w // merge // merge`` rows.
        In encoder data-parallel mode, the sharded helper's result is
        returned as-is.
        """
        grid_thw = image_input["image_grid_thw"]
        assert grid_thw.ndim == 2

        if image_input["type"] != "image_embeds":
            pixels = image_input["pixel_values"].type(self.visual.dtype)
            if self.use_data_parallel:
                return run_dp_sharded_mrope_vision_model(
                    self.visual, pixels, grid_thw.tolist(), rope_type="rope_3d"
                )
            image_embeds = self.visual(pixels, grid_thw=grid_thw)
        else:
            image_embeds = image_input["image_embeds"].type(self.visual.dtype)

        # Each image's t*h*w patches are merged merge x merge into one token.
        merge = self.visual.spatial_merge_size
        tokens_per_image = (grid_thw.prod(-1) // merge // merge).tolist()
        return image_embeds.split(tokens_per_image)

    def _process_video_input(
        self, video_input: Qwen2_5_VLVideoInputs
    ) -> tuple[torch.Tensor, ...]:
        """Encode videos (or take precomputed embeddings) and split per video.

        Returns one tensor per video with ``t*h*w // merge // merge`` rows.
        In encoder data-parallel mode, the sharded helper's result is
        returned as-is.
        """
        grid_thw = video_input["video_grid_thw"]
        assert grid_thw.ndim == 2

        if video_input["type"] != "video_embeds":
            frames = video_input["pixel_values_videos"].type(self.visual.dtype)
            if self.use_data_parallel:
                return run_dp_sharded_mrope_vision_model(
                    self.visual, frames, grid_thw.tolist(), rope_type="rope_3d"
                )
            video_embeds = self.visual(frames, grid_thw=grid_thw)
        else:
            video_embeds = video_input["video_embeds"].type(self.visual.dtype)

        # Each video's t*h*w patches are merged merge x merge into one token.
        merge = self.visual.spatial_merge_size
        tokens_per_video = (grid_thw.prod(-1) // merge // merge).tolist()
        return video_embeds.split(tokens_per_video)

2005
2006
2007
2008
2009
2010
2011
2012
2013
2014
2015
2016
2017
2018
2019
2020
2021
    def _postprocess_image_embeds_evs(
        self,
        image_embeds_split: tuple[torch.Tensor, ...],
        image_input: Qwen2_5_VLImageInputs,
    ) -> tuple[torch.Tensor, ...]:
        """Append mrope position channels to each image's embeddings.

        This only takes effect when multimodal pruning is enabled; otherwise
        the input tuple is returned unchanged. Carrying each image's own mrope
        positions alongside its embeddings lets correct positions be
        recovered after video pruning.

        Args:
            image_embeds_split: One embedding tensor per image item.
            image_input: Image input data. Its 2-D ``image_grid_thw`` gives
                each image's ``(t, h, w)`` grid.

        Returns:
            One tensor per image item. With pruning enabled, each tensor has
            the columns from ``compute_mrope_for_media`` plus one zero-filled
            column concatenated after the embedding. This is meant to match
            the 5 extra channels of video embeddings.
            NOTE(review): confirm ``compute_mrope_for_media`` returns 4 columns.
        """
        if not self.is_multimodal_pruning_enabled:
            return image_embeds_split

        merge_size = self.visual.spatial_merge_size
        grid_thw_list = image_input["image_grid_thw"].tolist()
        image_embeds_out = []
        for emb, size in zip(image_embeds_split, grid_thw_list):
            positions = compute_mrope_for_media(size, merge_size).to(emb.device)
            # Dummy zero-filled extra channel after the mrope positions.
            dummy_channel = torch.zeros_like(positions[:, 0:1])
            positions = torch.cat([positions, dummy_channel], dim=1)
            image_embeds_out.append(torch.cat([emb, positions], dim=1))
        return tuple(image_embeds_out)
2045
2046
2047
2048
2049
2050
2051
2052
2053
2054
2055
2056
2057
2058
2059
2060

    def _postprocess_video_embeds_evs(
        self,
        video_embeds_split: tuple[torch.Tensor, ...],
        video_input: Qwen2_5_VLVideoInputs,
    ) -> tuple[torch.Tensor, ...]:
        """Optionally prune videos with EVS, then build their final embeddings.

        With multimodal pruning enabled, Efficient Video Sampling (EVS)
        computes a retention mask per video. Pruned tokens are dropped and the
        number of surviving tokens per frame is counted. Without pruning,
        every frame keeps ``emb.shape[0] // len(timestamps)`` tokens.

        Each (possibly pruned) video then goes through
        ``_create_final_video_embeddings``. That step merges the video with
        the text embeddings of its indicator tokens. Depending on the
        configuration, it also appends deepstack channels and, with pruning,
        5 mrope/flag channels.

        Args:
            video_embeds_split: One embedding tensor per video item.
            video_input: Parsed video input. Its 2-D ``video_grid_thw`` and
                per-video ``timestamps`` are read.

        Returns:
            One final embedding tensor per video item.
        """
        grid_thw = video_input["video_grid_thw"]
        assert grid_thw.ndim == 2
        grid_thw_list = grid_thw.tolist()
        merge_size = self.visual.spatial_merge_size

        video_embeds_out = []
        for video_idx, (emb, size) in enumerate(zip(video_embeds_split, grid_thw_list)):
            timestamps = video_input.timestamps[video_idx]
            num_timestamps = len(timestamps)
            t, h, w = size

            if self.is_multimodal_pruning_enabled:
                # Per-token keep mask for this video (presumably boolean,
                # True = keep; it is used to index and then summed).
                retention_mask = compute_retention_mask(
                    emb,
                    size,
                    spatial_merge_size=merge_size,
                    q=self.video_pruning_rate,
                )
                emb = emb[retention_mask]

                # Count retained tokens in each temporal slice of the merged
                # (t, h // merge, w // merge) grid.
                retention_mask_thw = retention_mask.reshape(
                    t, h // merge_size, w // merge_size
                )
                num_tokens_per_frame = (
                    retention_mask_thw.sum(dim=(1, 2)).long().tolist()
                )
            else:
                feature_size = emb.shape[0] // num_timestamps
                num_tokens_per_frame = [feature_size] * num_timestamps
                retention_mask = None

            video_embeds_out.append(
                self._create_final_video_embeddings(
                    video_embeddings=emb,
                    num_tokens_per_frame=num_tokens_per_frame,
                    timestamps=timestamps,
                    video_grid_thw=size,
                    retention_mask=retention_mask,
                )
            )

        return tuple(video_embeds_out)

    def _create_final_video_embeddings(
        self,
        video_embeddings: torch.Tensor,
        num_tokens_per_frame: list[int],
        timestamps: list[float],
        video_grid_thw: list[int],
        retention_mask: torch.Tensor | None,
    ) -> torch.Tensor:
        """Merge video embeddings with the text embeddings of indicator tokens.

        The video's replacement token sequence (vision start/end tokens,
        per-frame separator text and video placeholder tokens) is rebuilt with
        ``Qwen3VLMultiModalProcessor.get_video_repl`` and embedded by the
        language model. Rows at video placeholder positions are replaced by
        ``video_embeddings``; all other rows keep their text embeddings. The
        result replaces the placeholder embeddings in the LLM's input_embeds.

        Extra columns are concatenated along the last dimension:
        - with deepstack, the per-level deepstack embeddings of each token;
        - with multimodal pruning, the 5 position/flag columns produced by
          ``_get_expanded_positions``.

        Args:
            video_embeddings: Embeddings of the (possibly pruned) video tokens.
            num_tokens_per_frame: Number of video tokens kept in each frame.
            timestamps: Per-frame timestamps passed to ``get_video_repl``.
            video_grid_thw: The video's ``(t, h, w)`` patch grid.
            retention_mask: EVS retention mask, or None without pruning.

        Returns:
            A tensor with one row per replacement token.
        """
        # Separators are tokenized per frame while special tokens are
        # pre-tokenized, so the ids do not depend on num_tokens_per_frame.
        video_repl = Qwen3VLMultiModalProcessor.get_video_repl(
            tokens_per_frame=num_tokens_per_frame,
            tokenizer=self._tokenizer,
            timestamps=timestamps,
            vision_start_token_id=self.config.vision_start_token_id,
            vision_end_token_id=self.config.vision_end_token_id,
            video_token_id=self.config.video_token_id,
            select_token_id=self.is_multimodal_pruning_enabled,
        )

        # Build the ids on the video embeddings' device. They index the
        # language model's embedding table, and the masks derived from them
        # are applied to tensors on that device. A default (CPU) tensor
        # would mismatch a GPU-resident model.
        repl_token_ids = torch.tensor(
            video_repl.full, device=video_embeddings.device
        )
        embed_token_id = _cached_tensor(
            self.config.video_token_id, repl_token_ids.device
        )
        is_video_embed = torch.isin(repl_token_ids, embed_token_id)

        # Text embeddings for every replacement token (only the base
        # ``visual_dim`` channels; extra channels are appended below).
        text_embeddings = self.get_language_model().embed_input_ids(repl_token_ids)

        if self.use_deepstack:
            (
                deepstack_input_embeds,
                multimodal_embeddings,
            ) = self._compute_deepstack_embeds(
                inputs_embeds=text_embeddings,
                multimodal_embeddings=[video_embeddings],
                is_multimodal=is_video_embed,
            )
        else:
            deepstack_input_embeds = None
            multimodal_embeddings = [video_embeddings]

        merged_embeddings = _merge_multimodal_embeddings(
            inputs_embeds=text_embeddings,
            multimodal_embeddings=multimodal_embeddings,
            is_multimodal=is_video_embed,
        )

        to_concat = [merged_embeddings]
        if deepstack_input_embeds is not None:
            # Assumes (num_levels, num_tokens, dim); flatten to
            # (num_tokens, num_levels * dim) so each token row carries all
            # levels.
            to_concat.append(
                deepstack_input_embeds.permute(1, 0, 2).reshape(
                    deepstack_input_embeds.shape[1], -1
                )
            )

        if self.is_multimodal_pruning_enabled:
            is_vision_start = repl_token_ids.eq(self.config.vision_start_token_id)
            expanded_positions = self._get_expanded_positions(
                device=merged_embeddings.device,
                seq_len=merged_embeddings.shape[0],
                video_grid_thw=video_grid_thw,
                num_tokens_per_frame=num_tokens_per_frame,
                timestamps=timestamps,
                is_video_embed=is_video_embed,
                is_vision_start=is_vision_start,
                retention_mask=retention_mask,
            )
            to_concat.append(expanded_positions)

        return torch.cat(to_concat, dim=-1)

    def _get_expanded_positions(
        self,
        device: torch.device,
        seq_len: int,
        video_grid_thw: list[int],
        num_tokens_per_frame: list[int],
        timestamps: list[float],
        is_video_embed: torch.Tensor,
        is_vision_start: torch.Tensor,
        retention_mask: torch.Tensor,
    ) -> torch.Tensor:
        """Build a ``(seq_len, 5)`` long tensor of mrope positions and flags.

        Columns are ``[t_index, h_index, w_index, is_vision_start, is_video]``.
        The unpruned replacement token sequence is rebuilt (every frame keeps
        ``(h // merge) * (w // merge)`` video tokens). Its original mrope
        positions come from ``self.get_mrope_input_positions`` on a dummy
        video feature. Those positions are then scattered onto the pruned
        sequence:
        - video-token rows take the positions of the retained tokens
          (``retention_mask`` selects among the unpruned video tokens);
        - all other rows take the positions of the unpruned non-video tokens,
          in order.

        Args:
            device: Device for the returned tensor.
            seq_len: Length of the pruned replacement sequence.
            video_grid_thw: The video's ``(t, h, w)`` patch grid.
            num_tokens_per_frame: Per-frame kept-token counts; only its length
                (the frame count) is used here.
            timestamps: Per-frame timestamps passed to ``get_video_repl``.
            is_video_embed: Mask over the pruned sequence marking video tokens.
            is_vision_start: Mask over the pruned sequence marking
                vision-start tokens.
            retention_mask: Mask over the unpruned video tokens. It must not be
                None; this method is only called with pruning enabled.

        Returns:
            The ``(seq_len, 5)`` long tensor described above.
        """
        embed_token_id = _cached_tensor(self.config.video_token_id, device=device)

        # Expand positions to match the full sequence length
        # (includes both video tokens and indicator tokens)
        # Shape: [full_length, 5] where positions are filled for video tokens
        # and zeros for indicator tokens.
        # Channel 3 flags VISION_START tokens so that
        # recompute_mrope_positions can reliably count timestamp tokens
        # (even when early frames have all video tokens pruned).
        # Channel 4 flags video-embedding tokens.
        expanded_positions = torch.zeros(
            seq_len,
            5,  # [t_index, h_index, w_index, is_vision_start, is_video]
            device=device,
            dtype=torch.long,
        )
        _, h, w = video_grid_thw
        merge_size = self.visual.spatial_merge_size
        num_frames = len(num_tokens_per_frame)
        # Token ids of the same video with no pruning applied.
        unpruned_token_ids = Qwen3VLMultiModalProcessor.get_video_repl(
            tokens_per_frame=[(h // merge_size) * (w // merge_size)] * num_frames,
            tokenizer=self._tokenizer,
            timestamps=timestamps,
            vision_start_token_id=self.config.vision_start_token_id,
            vision_end_token_id=self.config.vision_end_token_id,
            video_token_id=self.config.video_token_id,
        ).full
        unpruned_token_ids_tensor = torch.tensor(unpruned_token_ids, device=device)
        # Minimal stand-in feature spanning the whole unpruned sequence so
        # get_mrope_input_positions can be reused. NOTE(review): `field=None`
        # presumably works because only `.data` of the grid element is read
        # (as in _iter_mm_grid_hw); confirm before relying on other fields.
        mm_feature = MultiModalFeatureSpec(
            data=MultiModalKwargsItem(
                {
                    "video_grid_thw": MultiModalFieldElem(
                        data=torch.tensor(video_grid_thw),
                        field=None,  # HACK.
                    ),
                }
            ),
            modality="video",
            identifier="DUMMY",
            mm_position=PlaceholderRange(offset=0, length=len(unpruned_token_ids)),
        )
        # Positions of the unpruned sequence, permuted to one row per token
        # (presumably (3, len) before the permute, given the `:3` column
        # assignments below -- confirm).
        original_mrope = (
            self.get_mrope_input_positions(
                input_tokens=unpruned_token_ids,
                mm_features=[mm_feature],
            )[0]
            .to(device)
            .permute(1, 0)
        )
        full_is_video_embed = unpruned_token_ids_tensor == embed_token_id
        # Retained video tokens keep their original (unpruned) positions.
        expanded_positions[is_video_embed, :3] = original_mrope[full_is_video_embed][
            retention_mask
        ]
        # Non-video tokens map one-to-one, in order; this requires the pruned
        # and unpruned sequences to contain the same number of them.
        expanded_positions[~is_video_embed, :3] = original_mrope[~full_is_video_embed]
        expanded_positions[..., 3] = is_vision_start
        expanded_positions[..., 4] = is_video_embed

        return expanded_positions
2274

2275
2276
2277
    def _parse_and_validate_multimodal_inputs(self, **kwargs: object) -> dict:
        """Parse keyword inputs into one entry per modality present.

        Each modality is parsed at most once, from the full ``kwargs``.
        Modalities appear in the result in the order their first triggering
        key occurs in ``kwargs``.
        """
        parser_for_key = {
            "pixel_values": ("image", self._parse_and_validate_image_input),
            "image_embeds": ("image", self._parse_and_validate_image_input),
            "pixel_values_videos": ("video", self._parse_and_validate_video_input),
            "video_embeds": ("video", self._parse_and_validate_video_input),
        }

        mm_input_by_modality = {}
        for key in kwargs:
            entry = parser_for_key.get(key)
            if entry is None:
                continue
            modality, parse = entry
            if modality in mm_input_by_modality:
                continue
            mm_input_by_modality[modality] = parse(**kwargs)
        return mm_input_by_modality

2294
2295
2296
2297
2298
2299
2300
2301
2302
2303
    @staticmethod
    def _iter_mm_grid_hw(
        input_tokens: list[int],
        mm_features: list[MultiModalFeatureSpec],
        video_token_id: int,
        vision_start_token_id: int,
        vision_end_token_id: int,
        spatial_merge_size: int,
    ) -> Iterator[tuple[int, int, int, int]]:
        """Walk the multimodal placeholders of a prompt in positional order.

        Images yield a single entry. Videos yield one entry per temporal frame
        (``video_grid_thw[0]`` entries). Each frame is located by its
        ``vision_start_token_id``, which is present even when EVS left the frame
        with no video tokens.

        Args:
            input_tokens: Token IDs of the whole prompt.
            mm_features: Multimodal feature specs. They are visited sorted by
                ``mm_position.offset``.
            video_token_id: Token ID of a video placeholder token.
            vision_start_token_id: Token ID that opens a vision segment.
            vision_end_token_id: Token ID that closes a vision segment.
            spatial_merge_size: Spatial merge factor used to turn the feature
                grid ``(h, w)`` into the logical LLM grid.

        Yields:
            ``(offset, llm_grid_h, llm_grid_w, actual_num_tokens)``:
            - ``offset``: index of the first image/video token. For a video
              frame with no video tokens, this is the index just after its
              ``vision_start_token_id``.
            - ``llm_grid_h`` / ``llm_grid_w``: logical grid size. With EVS these
              may not match the real token count.
            - ``actual_num_tokens``: number of placeholder tokens actually
              present for this item or frame.

        Raises:
            ValueError: For a modality other than "image" or "video", or when a
                start/end marker for a video frame cannot be found.
        """
        by_position = sorted(mm_features, key=lambda f: f.mm_position.offset)
        for feature in by_position:
            cursor = feature.mm_position.offset

            if feature.modality == "image":
                t, h, w = feature.data["image_grid_thw"].data.tolist()
                assert t == 1, f"Image must have 1 frame, got {t}"
                grid_h, grid_w = h // spatial_merge_size, w // spatial_merge_size
                yield cursor, grid_h, grid_w, grid_h * grid_w
                continue

            if feature.modality != "video":
                raise ValueError(f"Unsupported modality: {feature.modality}")

            num_frames, h, w = feature.data["video_grid_thw"].data.tolist()
            grid_h, grid_w = h // spatial_merge_size, w // spatial_merge_size
            for _ in range(num_frames):
                frame_start = input_tokens.index(vision_start_token_id, cursor)
                frame_end = input_tokens.index(vision_end_token_id, frame_start)
                try:
                    first_video_tok = input_tokens.index(
                        video_token_id, frame_start, frame_end
                    )
                except ValueError:
                    # EVS left this frame without video tokens: report zero tokens
                    # and point just past the vision-start marker.
                    yield frame_start + 1, grid_h, grid_w, 0
                else:
                    # Per `Qwen3VLMultiModalProcessor.get_video_repl`, every token
                    # from the first video token up to the vision-end marker is a
                    # video token, so the count is a simple difference.
                    yield first_video_tok, grid_h, grid_w, frame_end - first_video_tok
                cursor = frame_end + 1

    def get_mrope_input_positions(
        self,
        input_tokens: list[int],
        mm_features: list[MultiModalFeatureSpec],
    ) -> tuple[torch.Tensor, int]:
        """Return ``(positions, delta)`` M-RoPE data for ``input_tokens``.

        Instance-level convenience wrapper that supplies ``self.config`` to
        ``_get_mrope_input_positions``.
        """
        model_config = self.config
        return self._get_mrope_input_positions(
            input_tokens=input_tokens,
            mm_features=mm_features,
            config=model_config,
        )

    @staticmethod
    def _get_mrope_input_positions(
        input_tokens: list[int],
        mm_features: list[MultiModalFeatureSpec],
        config: Qwen3VLConfig,
    ) -> tuple[torch.Tensor, int]:
        """Compute 3-row M-RoPE positions for a prompt and the position delta.

        Each text token gets the same value in all three rows, and the value
        advances by one per token. Each image / video-frame placeholder gets the
        positions of a ``(1, llm_grid_h, llm_grid_w)`` index grid, offset past
        the preceding text. A placeholder holding more tokens than one grid
        (EVS "lumped" placeholder) is laid out as consecutive grids. Leftover
        tokens, or a placeholder smaller than one grid, take the leading
        positions of a partial grid. Exactly one position column is emitted per
        placeholder token.

        Args:
            input_tokens: Token IDs of the prompt.
            mm_features: Multimodal feature specs of the prompt.
            config: Model config providing ``video_token_id``,
                ``vision_start_token_id``, ``vision_end_token_id`` and
                ``vision_config.spatial_merge_size``.

        Returns:
            ``(positions, delta)``. ``positions`` is a ``(3, N)`` tensor built
            from the concatenated segments. ``delta`` is
            ``positions.max() + 1 - len(input_tokens)``; it is ``0`` (with an
            empty ``(3, 0)`` tensor) for an empty prompt.
        """
        llm_pos_ids_list: list[np.ndarray] = []
        st = 0
        for (
            offset,
            llm_grid_h,
            llm_grid_w,
            actual_num_tokens,
        ) in Qwen3VLForConditionalGeneration._iter_mm_grid_hw(
            input_tokens,
            mm_features,
            video_token_id=config.video_token_id,
            vision_start_token_id=config.vision_start_token_id,
            vision_end_token_id=config.vision_end_token_id,
            spatial_merge_size=config.vision_config.spatial_merge_size,
        ):
            # Skip frames with 0 tokens (EVS placeholder with tokens lumped elsewhere)
            if actual_num_tokens == 0:
                continue

            text_len = offset - st
            st_idx = llm_pos_ids_list[-1].max() + 1 if llm_pos_ids_list else 0
            llm_pos_ids_list.append(
                np.broadcast_to(np.arange(text_len), (3, text_len)) + st_idx
            )

            # Lay the placeholder tokens out grid by grid. Normally a placeholder
            # holds exactly one frame (h * w tokens). An EVS "lumped" placeholder
            # (see `Qwen3VLMultiModalProcessor.get_video_repl`) holds several
            # frames' worth. Any remainder, or a placeholder smaller than one
            # frame, gets the leading positions of a partial grid. Emitting
            # exactly `actual_num_tokens` positions keeps the result aligned
            # with `input_tokens`.
            tokens_per_frame = llm_grid_h * llm_grid_w
            num_full_frames, remainder = divmod(actual_num_tokens, tokens_per_frame)
            frame_grid = np.indices((1, llm_grid_h, llm_grid_w)).reshape(3, -1)
            for _ in range(num_full_frames):
                llm_pos_ids_list.append(frame_grid + text_len + st_idx)
                st_idx = llm_pos_ids_list[-1].max() + 1
                text_len = 0  # No text between frames within one placeholder.
            if remainder > 0:
                llm_pos_ids_list.append(frame_grid[:, :remainder] + text_len + st_idx)

            st = offset + actual_num_tokens

        if st < len(input_tokens):
            st_idx = llm_pos_ids_list[-1].max() + 1 if llm_pos_ids_list else 0
            text_len = len(input_tokens) - st
            llm_pos_ids_list.append(
                np.broadcast_to(np.arange(text_len), (3, text_len)) + st_idx
            )

        if not llm_pos_ids_list:
            # Empty prompt: np.concatenate would raise on an empty list.
            return torch.empty((3, 0), dtype=torch.long), 0

        llm_positions = np.concatenate(llm_pos_ids_list, axis=1).reshape(3, -1)
        mrope_position_delta = (llm_positions.max() + 1 - len(input_tokens)).item()
        return torch.from_numpy(llm_positions), mrope_position_delta

    def recompute_mrope_positions(
        self,
        input_ids: list[int],
        multimodal_embeddings: MultiModalEmbeddings,
        mrope_positions: torch.LongTensor,
        num_computed_tokens: int,
    ) -> tuple[MultiModalEmbeddings, torch.Tensor, int]:
        """Refresh M-RoPE positions after multimodal tokens were pruned.

        The stored ``mrope_positions`` were computed for the unpruned sequence.
        Once media tokens are pruned they no longer line up, so the part of
        them starting at ``num_computed_tokens`` is recomputed before it reaches
        the language model. The special token IDs come from ``self.config``.

        Args:
            input_ids: (N,) token IDs of the whole prompt.
            multimodal_embeddings: Tuple of multimodal embeddings that fit into
                the prefill chunk being processed.
            mrope_positions: Current (3, N) M-RoPE positions for the whole
                sequence.
            num_computed_tokens: Number of tokens already computed.

        Returns:
            ``(multimodal_embeddings, mrope_positions, mrope_position_delta)``.
        """
        cfg = self.config
        return self._recompute_mrope_positions(
            input_ids=input_ids,
            multimodal_embeddings=multimodal_embeddings,
            mrope_positions=mrope_positions,
            num_computed_tokens=num_computed_tokens,
            vision_start_token_id=cfg.vision_start_token_id,
            image_token_id=cfg.image_token_id,
            video_token_id=cfg.video_token_id,
        )

    @staticmethod
    def _recompute_mrope_positions(
        input_ids: list[int],
        multimodal_embeddings: MultiModalEmbeddings,
        mrope_positions: torch.LongTensor,
        num_computed_tokens: int,
        vision_start_token_id: int,
        image_token_id: int,
        video_token_id: int,
    ) -> tuple[MultiModalEmbeddings, torch.Tensor, int]:
        """Split position channels off the embeddings and recompute positions.

        The last 5 columns of every non-empty embedding are treated as position
        channels. They are removed from the embedding, transposed to
        ``(5, num_tokens)``, cast to ``long``, and passed to the module-level
        ``recompute_mrope_positions`` helper together with the prompt token IDs.

        Returns:
            ``(embeddings_without_positions, positions, mrope_position_delta)``.
        """
        # Work on the device of the embeddings when there are any.
        target_device = (
            multimodal_embeddings[0].device
            if len(multimodal_embeddings)
            else mrope_positions.device
        )
        token_ids = torch.as_tensor(input_ids, device=target_device, dtype=torch.long)

        stripped_embeds = []
        position_channels = []
        for emb in multimodal_embeddings:
            if emb.shape[0] == 0:
                # Empty frame (e.g. produced by unpacking): forwarded unchanged
                # and paired with an empty (5, 0) position tensor.
                # NOTE(review): a 2-D empty frame keeps all of its columns here,
                # so its width can differ from the stripped embeddings. Confirm
                # that downstream concatenation tolerates this.
                stripped_embeds.append(emb)
                position_channels.append(
                    torch.empty(5, 0, device=target_device, dtype=torch.long)
                )
                continue
            stripped_embeds.append(emb[:, :-5])
            position_channels.append(emb[:, -5:].permute(1, 0).long())

        new_positions, position_delta = recompute_mrope_positions(
            token_ids,
            position_channels,
            mrope_positions,
            num_computed_tokens,
            vision_start_token_id,
            image_token_id,
            video_token_id,
        )
        return tuple(stripped_embeds), new_positions, position_delta

    def embed_multimodal(self, **kwargs: object) -> MultiModalEmbeddings | None:
        """Encode every image and video item found in ``kwargs``.

        Returns:
            ``None`` when no multimodal input is present. Otherwise, a tuple
            with one embedding tensor per item. Items are grouped by modality
            in the order given by ``_parse_and_validate_multimodal_inputs``.
        """
        parsed_inputs = self._parse_and_validate_multimodal_inputs(**kwargs)
        if not parsed_inputs:
            return None

        collected: list[torch.Tensor] = []
        # Iterating the dict keeps the modalities in the order they were seen.
        for modality, mm_input in parsed_inputs.items():
            if modality == "image":
                # NOTE(review): unlike videos, this post-processing is not gated
                # on `is_multimodal_pruning_enabled` here. Confirm that
                # `_postprocess_image_embeds_evs` handles that itself.
                item_embeds = self._postprocess_image_embeds_evs(
                    self._process_image_input(mm_input), mm_input
                )
            elif modality == "video":
                item_embeds = self._process_video_input(mm_input)
                if self.is_multimodal_pruning_enabled:
                    item_embeds = self._postprocess_video_embeds_evs(
                        item_embeds, mm_input
                    )
            else:
                continue
            collected.extend(item_embeds)

        return tuple(collected)

    def _compute_deepstack_embeds(
        self,
        inputs_embeds: torch.Tensor,
        multimodal_embeddings: MultiModalEmbeddings,
        is_multimodal: torch.Tensor,
    ) -> tuple[torch.Tensor, MultiModalEmbeddings]:
        """Separate the deepstack (multiscale) features from the visual embeddings.

        Each multimodal embedding is split along its last axis into
        ``visual_dim`` main columns and ``multiscale_dim`` multiscale columns.
        The multiscale part is merged into a zero buffer of shape
        ``(num_tokens, deepstack_num_level * hidden)`` using
        ``_merge_multimodal_embeddings`` with ``is_multimodal``. The buffer is
        then viewed as ``(num_tokens, deepstack_num_level, visual_dim)`` and
        permuted to level-major order.

        Returns:
            ``(deepstack_input_embeds, main_embeddings)``, where
            ``main_embeddings`` holds one tensor per original item.
        """
        item_lengths = [len(item) for item in multimodal_embeddings]
        main_cat, multiscale_cat = torch.split(
            torch.cat(multimodal_embeddings, dim=0),
            [self.visual_dim, self.multiscale_dim],
            dim=-1,
        )
        main_per_item = torch.split(main_cat, item_lengths, dim=0)
        multiscale_per_item = torch.split(multiscale_cat, item_lengths, dim=0)

        num_tokens = inputs_embeds.size(0)
        zero_buffer = inputs_embeds.new_zeros(
            num_tokens, self.deepstack_num_level * inputs_embeds.size(1)
        )
        merged_multiscale = _merge_multimodal_embeddings(
            inputs_embeds=zero_buffer,
            multimodal_embeddings=multiscale_per_item,
            is_multimodal=is_multimodal,
        )
        per_level = merged_multiscale.view(
            num_tokens, self.deepstack_num_level, self.visual_dim
        ).permute(1, 0, 2)

        return per_level, main_per_item

    def embed_input_ids(
        self,
        input_ids: torch.Tensor,
        multimodal_embeddings: MultiModalEmbeddings | None = None,
        *,
        is_multimodal: torch.Tensor | None = None,
    ) -> torch.Tensor:
        """Embed ``input_ids`` and merge in any multimodal embeddings.

        Text is embedded with the language model's embedding layer. When
        multimodal embeddings are given, they are merged at the
        ``is_multimodal`` positions. With deepstack enabled, the multiscale
        features are first split off and stashed via
        ``_set_deepstack_input_embeds`` for the forward pass.
        """
        text_embeds = self._embed_text_input_ids(
            input_ids,
            self.language_model.embed_input_ids,
            is_multimodal=is_multimodal,
        )
        if multimodal_embeddings is None or len(multimodal_embeddings) == 0:
            return text_embeds

        mm_mask = _require_is_multimodal(is_multimodal)

        deepstack_input_embeds = None
        if self.use_deepstack:
            deepstack_input_embeds, multimodal_embeddings = (
                self._compute_deepstack_embeds(
                    inputs_embeds=text_embeds,
                    multimodal_embeddings=multimodal_embeddings,
                    is_multimodal=mm_mask,
                )
            )

        merged_embeds = _merge_multimodal_embeddings(
            inputs_embeds=text_embeds,
            multimodal_embeddings=multimodal_embeddings,
            is_multimodal=mm_mask,
        )

        if deepstack_input_embeds is not None:
            self._set_deepstack_input_embeds(deepstack_input_embeds)

        return merged_embeds

    def forward(
        self,
        input_ids: torch.Tensor | None,
        positions: torch.Tensor,
        intermediate_tensors: IntermediateTensors | None = None,
        inputs_embeds: torch.Tensor | None = None,
        **kwargs: object,
    ) -> torch.Tensor | IntermediateTensors:
        """Run the Qwen3VL language model forward pass.

        Args:
            input_ids: Flattened (concatenated) input IDs for the batch.
            positions: Flattened (concatenated) position IDs for the batch.
                **NOTE**: with M-RoPE enabled (the default for the open-source
                Qwen3VL models) the shape is `(3, seq_len)`; otherwise it is
                `(seq_len,)`.
            intermediate_tensors: Intermediate tensors from previous pipeline
                stages. When given, ``inputs_embeds`` is ignored.
            inputs_embeds: Pre-computed input embeddings.
            **kwargs: Accepted for interface compatibility; not read by this
                method.

        Returns:
            The output of ``self.language_model.model``.
        """
        if intermediate_tensors is not None:
            inputs_embeds = None

        # Only the first pipeline rank that received embeddings fetches (and
        # afterwards clears) the stashed deepstack embeddings.
        uses_stashed_deepstack = (
            inputs_embeds is not None and get_pp_group().is_first_rank
        )
        deepstack_input_embeds = (
            self._get_deepstack_input_embeds(inputs_embeds.size(0))
            if uses_stashed_deepstack
            else None
        )

        hidden_states = self.language_model.model(
            input_ids=input_ids,
            positions=positions,
            intermediate_tensors=intermediate_tensors,
            inputs_embeds=inputs_embeds,
            deepstack_input_embeds=deepstack_input_embeds,
        )

        if uses_stashed_deepstack:
            self._clear_deepstack_input_embeds(inputs_embeds.size(0))

        return hidden_states

    def compute_logits(
        self,
        hidden_states: torch.Tensor,
2711
    ) -> torch.Tensor | None:
2712
        return self.language_model.compute_logits(hidden_states)
2713

2714
    def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]:
        """Load ``(name, tensor)`` checkpoint pairs into this model.

        Names are translated with ``self.hf_to_vllm_mapper``. The return value
        is whatever ``AutoWeightsLoader.load_weights`` reports (annotated as
        ``set[str]``).
        """
        return AutoWeightsLoader(self).load_weights(
            weights, mapper=self.hf_to_vllm_mapper
        )

    def get_mm_mapping(self) -> MultiModelKeys:
        """Return the module-name prefixes of each part of this multimodal model.

        The parts are the language model, the vision-to-language connectors
        (the merger and the deepstack mergers) and the vision tower.
        """
        connector_prefixes = ["visual.merger", "visual.deepstack_merger_list"]
        return MultiModelKeys.from_string_field(
            language_model="language_model",
            connector=connector_prefixes,
            tower_model="visual.",
        )

    def get_num_mm_encoder_tokens(
        self,
        num_image_tokens: int,
    ) -> int:
        """Scale an LLM-side token count up to the vision-encoder token count.

        The count is multiplied by ``vision_config.spatial_merge_size ** 2``.
        """
        spatial_merge = self.config.vision_config.spatial_merge_size
        return num_image_tokens * spatial_merge**2

    def get_num_mm_connector_tokens(
        self,
        num_vision_tokens: int,
    ) -> int:
        """Scale a vision-encoder token count down to the connector output count.

        The count is floor-divided by ``vision_config.spatial_merge_size ** 2``.
        """
        spatial_merge = self.config.vision_config.spatial_merge_size
        return num_vision_tokens // spatial_merge**2


@lru_cache
def _cached_tensor(x, device) -> torch.Tensor:
    return torch.tensor(x, device=device)