qwen3_vl.py 63.5 KB
Newer Older
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project

# Copyright 2025 The vLLM team.
# Copyright 2025 The Qwen Team.
# Copyright 2025 The HuggingFace Inc. team.
# All rights reserved.
#
# This code is based on EleutherAI's GPT-NeoX library and the GPT-NeoX
# and OPT implementations in this library. It has been modified from its
# original forms to accommodate minor architectural differences compared
# to GPT-NeoX and OPT used by the Meta AI team that trained the model.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Inference-only Qwen3VL model compatible with HuggingFace weights."""
26

27
from collections.abc import Callable, Iterable, Iterator, Mapping, Sequence
28
from functools import lru_cache, partial
29
from itertools import islice
30
from typing import Any
31
32
33
34
35

import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
36
from transformers import BatchFeature
37
from transformers.models.qwen2_vl import Qwen2VLImageProcessorFast
38
from transformers.models.qwen2_vl.image_processing_qwen2_vl import (
39
40
41
    smart_resize as image_smart_resize,
)
from transformers.models.qwen3_vl import Qwen3VLProcessor, Qwen3VLVideoProcessor
42
from transformers.models.qwen3_vl.configuration_qwen3_vl import (
43
44
45
    Qwen3VLConfig,
    Qwen3VLVisionConfig,
)
46
from transformers.models.qwen3_vl.video_processing_qwen3_vl import (
47
48
    smart_resize as video_smart_resize,
)
49
50
from transformers.video_utils import VideoMetadata

51
from vllm.attention.backends.registry import AttentionBackendEnum
52
53
54
from vllm.attention.layer import check_upstream_fa_availability
from vllm.compilation.decorators import support_torch_compile
from vllm.config import VllmConfig
55
from vllm.config.multimodal import BaseDummyOptions, VideoDummyOptions
56
57
58
from vllm.distributed import get_pp_group
from vllm.logger import init_logger
from vllm.model_executor.layers.activation import _ACTIVATION_REGISTRY
59
from vllm.model_executor.layers.conv import Conv3dLayer
60
61
62
63
from vllm.model_executor.layers.linear import (
    ColumnParallelLinear,
    RowParallelLinear,
)
64
65
from vllm.model_executor.layers.logits_processor import LogitsProcessor
from vllm.model_executor.layers.quantization import QuantizationConfig
66
from vllm.model_executor.layers.rotary_embedding import get_rope
67
68
69
70
from vllm.model_executor.layers.vocab_parallel_embedding import ParallelLMHead
from vllm.model_executor.model_loader.weight_utils import default_weight_loader
from vllm.model_executor.models.module_mapping import MultiModelKeys
from vllm.multimodal import MULTIMODAL_REGISTRY
71
72
from vllm.multimodal.inputs import (
    MultiModalDataDict,
73
    MultiModalFeatureSpec,
74
75
76
77
78
79
80
81
82
83
84
85
    MultiModalFieldConfig,
    MultiModalKwargsItem,
    MultiModalKwargsItems,
    VideoItem,
)
from vllm.multimodal.parse import ImageSize, MultiModalDataItems, MultiModalDataParser
from vllm.multimodal.processing import (
    BaseMultiModalProcessor,
    PromptReplacement,
    PromptUpdate,
    PromptUpdateDetails,
)
86
87
from vllm.multimodal.profiling import BaseDummyInputsBuilder
from vllm.sequence import IntermediateTensors
88
from vllm.utils.collection_utils import is_list_of
89

90
91
from .interfaces import (
    MultiModalEmbeddings,
92
    SupportsEagle3,
93
    SupportsLoRA,
94
    SupportsMRoPE,
95
96
97
98
99
100
101
102
103
104
105
106
    SupportsMultiModal,
    SupportsPP,
)
from .qwen2_5_vl import (
    Qwen2_5_VisionAttention,
    Qwen2_5_VLImageEmbeddingInputs,
    Qwen2_5_VLImageInputs,
    Qwen2_5_VLImagePixelInputs,
    Qwen2_5_VLVideoEmbeddingInputs,
    Qwen2_5_VLVideoInputs,
    Qwen2_5_VLVideoPixelInputs,
)
107
108
from .qwen2_vl import Qwen2VLProcessingInfo
from .qwen3 import Qwen3ForCausalLM, Qwen3Model
109
110
111
112
113
114
115
from .utils import (
    AutoWeightsLoader,
    PPMissingLayer,
    WeightsMapper,
    _merge_multimodal_embeddings,
    maybe_prefix,
)
116
117
118
119
from .vision import (
    get_vit_attn_backend,
    run_dp_sharded_mrope_vision_model,
)
120
121
122

# Module-level logger for this model file.
logger = init_logger(__name__)
# Official recommended max pixels is 24576 * 32 * 32.
# Upper bound on frames per video, passed as ``max_frames_per_video`` to
# ``get_num_frames_with_most_features``. Presumably each frame occupies at
# least a 32 x 32 pixel footprint, so at most 24576 frames fit in the pixel
# budget above — TODO confirm against the HF Qwen3-VL video processor config.
_MAX_FRAMES_PER_VIDEO = 24576

class Qwen3_VisionPatchEmbed(nn.Module):
    """Embed flattened pixel patches with a single 3D convolution.

    Each input row is one flattened patch. The convolution's kernel size and
    stride both equal the patch size, so each patch maps to exactly one
    output position of width ``hidden_size``.
    """

    def __init__(
        self,
        patch_size: int = 14,
        temporal_patch_size: int = 2,
        in_channels: int = 3,
        hidden_size: int = 1152,
    ) -> None:
        super().__init__()
        self.patch_size = patch_size
        self.temporal_patch_size = temporal_patch_size
        self.hidden_size = hidden_size

        # kernel == stride: every patch is covered by exactly one conv window.
        kernel_size = (temporal_patch_size, patch_size, patch_size)
        self.proj = Conv3dLayer(
            in_channels,
            hidden_size,
            kernel_size=kernel_size,
            stride=kernel_size,
            bias=True,
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Project ``x`` of shape ``(num_patches, patch_dim)`` to
        ``(num_patches, hidden_size)``.
        """
        num_patches, _ = x.shape
        # Restore the (channels, t, h, w) layout the conv expects; the channel
        # count is inferred from patch_dim.
        x = x.view(
            num_patches,
            -1,
            self.temporal_patch_size,
            self.patch_size,
            self.patch_size,
        )
        return self.proj(x).view(num_patches, self.hidden_size)


class Qwen3_VisionMLP(nn.Module):
    """Two-layer feed-forward network used inside each vision block.

    Computes ``linear_fc2(act_fn(linear_fc1(x)))``. ``linear_fc1`` is a
    ``ColumnParallelLinear`` and ``linear_fc2`` a ``RowParallelLinear``;
    ``use_data_parallel`` is forwarded as ``disable_tp`` to both.
    """

    def __init__(
        self,
        in_features: int,
        hidden_features: int,
        bias: bool = False,
        act_fn: Callable[[torch.Tensor], torch.Tensor] = F.silu,
        quant_config: QuantizationConfig | None = None,
        prefix: str = "",
        use_data_parallel: bool = False,
    ):
        super().__init__()
        # return_bias=False: each layer returns only the output tensor, so the
        # calls can be chained directly in forward().
        self.linear_fc1 = ColumnParallelLinear(
            in_features,
            hidden_features,
            bias=bias,
            quant_config=quant_config,
            return_bias=False,
            prefix=f"{prefix}.linear_fc1",
            disable_tp=use_data_parallel,
        )
        self.linear_fc2 = RowParallelLinear(
            hidden_features,
            in_features,
            bias=bias,
            quant_config=quant_config,
            return_bias=False,
            prefix=f"{prefix}.linear_fc2",
            disable_tp=use_data_parallel,
        )
        self.act_fn = act_fn

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Apply ``fc2(act_fn(fc1(x)))``."""
        return self.linear_fc2(self.act_fn(self.linear_fc1(x)))


class Qwen3_VisionBlock(nn.Module):
    """Pre-norm vision transformer layer.

    ``x = x + attn(norm1(x))`` followed by ``x = x + mlp(norm2(x))``.
    Attention is delegated to ``Qwen2_5_VisionAttention`` with rotary
    embeddings supplied by the caller.
    """

    def __init__(
        self,
        dim: int,
        num_heads: int,
        mlp_hidden_dim: int,
        act_fn: Callable[[torch.Tensor], torch.Tensor] = F.silu,
        norm_layer: Callable[[int], nn.Module] | None = None,
        quant_config: QuantizationConfig | None = None,
        prefix: str = "",
        use_data_parallel: bool = False,
        attn_backend: AttentionBackendEnum = AttentionBackendEnum.TORCH_SDPA,
        use_upstream_fa: bool = False,
    ) -> None:
        super().__init__()
        if norm_layer is None:
            norm_layer = partial(nn.LayerNorm, eps=1e-6)
        self.norm1 = norm_layer(dim)
        self.norm2 = norm_layer(dim)
        self.attn = Qwen2_5_VisionAttention(
            embed_dim=dim,
            num_heads=num_heads,
            projection_size=dim,
            quant_config=quant_config,
            prefix=f"{prefix}.attn",
            use_data_parallel=use_data_parallel,
            attn_backend=attn_backend,
            use_upstream_fa=use_upstream_fa,
        )
        self.mlp = Qwen3_VisionMLP(
            dim,
            mlp_hidden_dim,
            act_fn=act_fn,
            bias=True,
            quant_config=quant_config,
            prefix=f"{prefix}.mlp",
            use_data_parallel=use_data_parallel,
        )

    def forward(
        self,
        x: torch.Tensor,
        cu_seqlens: torch.Tensor,
        rotary_pos_emb_cos: torch.Tensor,
        rotary_pos_emb_sin: torch.Tensor,
        max_seqlen: torch.Tensor,  # Only used for Flash Attention
    ) -> torch.Tensor:
        """Run attention then MLP, each with a pre-norm residual connection.

        ``cu_seqlens`` holds cumulative sequence boundaries. ``max_seqlen``
        is the longest of those sequences; both are passed through to the
        attention layer unchanged.
        """
        x = x + self.attn(
            self.norm1(x),
            cu_seqlens=cu_seqlens,
            rotary_pos_emb_cos=rotary_pos_emb_cos,
            rotary_pos_emb_sin=rotary_pos_emb_sin,
            max_seqlen=max_seqlen,
        )
        x = x + self.mlp(self.norm2(x))
        return x


class Qwen3_VisionPatchMerger(nn.Module):
    """Merge each ``spatial_merge_size**2`` window of patches into one token.

    The window's features are concatenated to width
    ``context_dim * spatial_merge_size**2``, then passed through
    ``linear_fc1 -> GELU -> linear_fc2`` to width ``d_model``.

    With ``use_postshuffle_norm=True`` the LayerNorm is applied to the
    concatenated (post-merge) vector; otherwise it is applied per patch
    before merging.
    """

    def __init__(
        self,
        d_model: int,
        context_dim: int,
        norm_layer: Callable[[int], nn.Module] | None = None,
        spatial_merge_size: int = 2,
        use_postshuffle_norm: bool = False,
        quant_config: QuantizationConfig | None = None,
        prefix: str = "",
        use_data_parallel: bool = False,
    ) -> None:
        super().__init__()
        self.hidden_size = context_dim * (spatial_merge_size**2)

        self.use_postshuffle_norm = use_postshuffle_norm
        # Post-shuffle norm runs on the merged vector, so it needs the
        # merged width rather than the per-patch width.
        norm_dim = self.hidden_size if use_postshuffle_norm else context_dim

        if norm_layer is None:
            norm_layer = partial(nn.LayerNorm, eps=1e-6)
        self.norm = norm_layer(norm_dim)

        # return_bias is left at its default here, so forward() unpacks an
        # (output, bias) pair from each linear layer.
        self.linear_fc1 = ColumnParallelLinear(
            self.hidden_size,
            self.hidden_size,
            bias=True,
            quant_config=quant_config,
            prefix=f"{prefix}.linear_fc1",
            disable_tp=use_data_parallel,
        )
        self.act_fn = nn.GELU()
        self.linear_fc2 = RowParallelLinear(
            self.hidden_size,
            d_model,
            bias=True,
            quant_config=quant_config,
            prefix=f"{prefix}.linear_fc2",
            disable_tp=use_data_parallel,
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Normalize, merge windows via ``view(-1, hidden_size)``, and project.

        Assumes consecutive rows of ``x`` already belong to the same merge
        window (the callers order patches that way) — TODO confirm for new
        call sites.
        """
        if self.use_postshuffle_norm:
            x = self.norm(x.view(-1, self.hidden_size))
        else:
            x = self.norm(x).view(-1, self.hidden_size)

        x_parallel, _ = self.linear_fc1(x)
        x_parallel = self.act_fn(x_parallel)
        out, _ = self.linear_fc2(x_parallel)
        return out


class Qwen3_VisionTransformer(nn.Module):
    """Qwen3-VL vision encoder.

    Pipeline:

    1. 3D patch embedding.
    2. Bilinearly interpolated learned position embeddings.
    3. A stack of ``Qwen3_VisionBlock`` layers with rotary embeddings.
    4. A final patch merger.

    Hidden states after each layer listed in ``deepstack_visual_indexes`` are
    merged by a dedicated merger and concatenated onto the final output along
    the feature dimension.
    """

    def __init__(
        self,
        vision_config: Qwen3VLVisionConfig,
        norm_eps: float = 1e-6,
        quant_config: QuantizationConfig | None = None,
        prefix: str = "",
        use_data_parallel: bool = False,
        attn_backend_override: AttentionBackendEnum | None = None,
    ) -> None:
        super().__init__()
        self.hidden_size = vision_config.hidden_size
        self.num_heads = vision_config.num_heads
        self.num_position_embeddings = vision_config.num_position_embeddings
        self.patch_size = vision_config.patch_size
        self.spatial_merge_size = vision_config.spatial_merge_size
        self.spatial_merge_unit = self.spatial_merge_size**2
        self.temporal_patch_size = vision_config.temporal_patch_size
        self.deepstack_visual_indexes = vision_config.deepstack_visual_indexes
        self.use_data_parallel = use_data_parallel
        # The learned position table is treated as a square grid of side
        # sqrt(num_position_embeddings) (see fast_pos_embed_interpolate).
        self.num_grid_per_side = int(self.num_position_embeddings**0.5)

        # NOTE: This is used for creating empty tensor for all_gather for
        # DP ViT. Here out_hidden_size is enlarged due to deepstack
        self.out_hidden_size = vision_config.out_hidden_size * (
            1 + len(self.deepstack_visual_indexes)
        )

        self.patch_embed = Qwen3_VisionPatchEmbed(
            patch_size=self.patch_size,
            temporal_patch_size=self.temporal_patch_size,
            in_channels=vision_config.in_channels,
            hidden_size=self.hidden_size,
        )

        self.pos_embed = nn.Embedding(self.num_position_embeddings, self.hidden_size)

        norm_layer = partial(nn.LayerNorm, eps=norm_eps)
        head_dim = self.hidden_size // self.num_heads
        self.rotary_pos_emb = get_rope(
            head_size=head_dim,
            rotary_dim=head_dim // 2,
            max_position=8192,
            is_neox_style=True,
        )

        self.merger = Qwen3_VisionPatchMerger(
            d_model=vision_config.out_hidden_size,
            context_dim=self.hidden_size,
            norm_layer=norm_layer,
            spatial_merge_size=self.spatial_merge_size,
            quant_config=quant_config,
            prefix=f"{prefix}.merger",
            use_data_parallel=use_data_parallel,
        )

        self.deepstack_merger_list = nn.ModuleList(
            [
                Qwen3_VisionPatchMerger(
                    d_model=vision_config.out_hidden_size,
                    context_dim=self.hidden_size,
                    spatial_merge_size=self.spatial_merge_size,
                    use_postshuffle_norm=True,
                    norm_layer=norm_layer,
                    quant_config=quant_config,
                    prefix=f"{prefix}.deepstack_merger_list.{layer_idx}",
                    use_data_parallel=use_data_parallel,
                )
                for layer_idx in range(len(self.deepstack_visual_indexes))
            ]
        )

        self.attn_backend = get_vit_attn_backend(
            head_size=head_dim,
            dtype=torch.get_default_dtype(),
            attn_backend_override=attn_backend_override,
        )
        # Prefer upstream flash-attention when the selected backend is not
        # already a flash-attention variant and upstream FA is usable.
        # NOTE(review): this upgrade also fires when attn_backend_override
        # explicitly requested another backend (e.g. TORCH_SDPA) — confirm
        # that overriding the user's choice is intended.
        use_upstream_fa = False
        if (
            self.attn_backend != AttentionBackendEnum.FLASH_ATTN
            and self.attn_backend != AttentionBackendEnum.ROCM_AITER_FA
            and check_upstream_fa_availability(torch.get_default_dtype())
        ):
            self.attn_backend = AttentionBackendEnum.FLASH_ATTN
            use_upstream_fa = True

        if self.attn_backend not in {
            AttentionBackendEnum.FLASH_ATTN,
            AttentionBackendEnum.TORCH_SDPA,
            AttentionBackendEnum.ROCM_AITER_FA,
        }:
            raise RuntimeError(
                f"Qwen3-VL does not support {self.attn_backend} backend now."
            )
        self.blocks = nn.ModuleList(
            [
                Qwen3_VisionBlock(
                    dim=self.hidden_size,
                    num_heads=self.num_heads,
                    mlp_hidden_dim=vision_config.intermediate_size,
                    act_fn=_ACTIVATION_REGISTRY[vision_config.hidden_act],
                    norm_layer=norm_layer,
                    quant_config=quant_config,
                    prefix=f"{prefix}.blocks.{layer_idx}",
                    use_data_parallel=use_data_parallel,
                    attn_backend=self.attn_backend,
                    use_upstream_fa=use_upstream_fa,
                )
                for layer_idx in range(vision_config.depth)
            ]
        )

    @property
    def dtype(self) -> torch.dtype:
        return self.patch_embed.proj.weight.dtype

    @property
    def device(self) -> torch.device:
        return self.patch_embed.proj.weight.device

    @staticmethod
    @lru_cache(maxsize=1024)
    def rot_pos_ids(h: int, w: int, spatial_merge_size: int) -> torch.Tensor:
        """Return ``(h * w, 2)`` (row, col) ids for one frame.

        Patches are ordered so that each ``spatial_merge_size`` x
        ``spatial_merge_size`` window is contiguous, matching the order the
        merger consumes. Results are cached per ``(h, w, merge_size)``.
        """
        hpos_ids = np.broadcast_to(np.arange(h).reshape(h, 1), (h, w))
        h_div = h // spatial_merge_size
        w_div = w // spatial_merge_size
        hpos_ids = hpos_ids.reshape(
            h_div,
            spatial_merge_size,
            w_div,
            spatial_merge_size,
        )
        hpos_ids = hpos_ids.transpose(0, 2, 1, 3)
        hpos_ids = hpos_ids.flatten()

        wpos_ids = np.broadcast_to(np.arange(w).reshape(1, w), (h, w))
        wpos_ids = wpos_ids.reshape(
            h_div,
            spatial_merge_size,
            w_div,
            spatial_merge_size,
        )
        wpos_ids = wpos_ids.transpose(0, 2, 1, 3)
        wpos_ids = wpos_ids.flatten()

        return torch.from_numpy(np.stack([hpos_ids, wpos_ids], axis=-1))

    def rot_pos_emb(self, grid_thw: list[list[int]]):
        """Gather rotary cos/sin for every patch of every item.

        The row and column lookups are flattened together per patch.
        Returns ``(cos, sin)``.
        """
        max_grid_size = max(max(h, w) for _, h, w in grid_thw)
        # Every temporal slice of an item shares the same 2D position ids.
        pos_ids = [
            self.rot_pos_ids(h, w, self.spatial_merge_size)
            if t == 1
            else self.rot_pos_ids(h, w, self.spatial_merge_size).repeat(t, 1)
            for t, h, w in grid_thw
        ]
        pos_ids = torch.cat(pos_ids, dim=0).to(self.device, non_blocking=True)

        # Use pre-computed cos_sin_cache from RotaryEmbedding
        cos, sin = self.rotary_pos_emb.get_cos_sin(max_grid_size)

        cos_combined = cos[pos_ids].flatten(1)
        sin_combined = sin[pos_ids].flatten(1)

        return cos_combined, sin_combined

    def fast_pos_embed_interpolate(self, grid_thw: list[list[int]]) -> torch.Tensor:
        """Bilinearly resample the learned position grid to each item's size.

        For each ``(t, h, w)``:

        - the ``num_grid_per_side`` square table is sampled at ``h`` x ``w``
          evenly spaced points;
        - rows are reordered into spatial-merge-window order;
        - the result is repeated ``t`` times.

        Outputs for all items are concatenated along dim 0.
        """
        num_grid_per_side = self.num_grid_per_side
        m_size = self.spatial_merge_size
        hidden_dim = self.pos_embed.embedding_dim

        outputs = []
        for t, h, w in grid_thw:
            h_idxs = torch.linspace(
                0, num_grid_per_side - 1, h, dtype=torch.float32, device=self.device
            )
            w_idxs = torch.linspace(
                0, num_grid_per_side - 1, w, dtype=torch.float32, device=self.device
            )

            h_floor = h_idxs.to(torch.long)
            w_floor = w_idxs.to(torch.long)
            h_ceil = torch.clamp(h_floor + 1, max=num_grid_per_side - 1)
            w_ceil = torch.clamp(w_floor + 1, max=num_grid_per_side - 1)

            dh = h_idxs - h_floor
            dw = w_idxs - w_floor

            # Create meshgrid view for all h, w vars
            dh_grid, dw_grid = torch.meshgrid(dh, dw, indexing="ij")
            h_floor_grid, w_floor_grid = torch.meshgrid(h_floor, w_floor, indexing="ij")
            h_ceil_grid, w_ceil_grid = torch.meshgrid(h_ceil, w_ceil, indexing="ij")

            # Bilinear weights, reusing w11 = dh * dw to avoid recomputation:
            #   w00 = (1 - dh)(1 - dw), w01 = (1 - dh) dw,
            #   w10 = dh (1 - dw),      w11 = dh dw
            w11 = dh_grid * dw_grid
            w10 = dh_grid - w11
            w01 = dw_grid - w11
            w00 = 1 - dh_grid - w01

            h_grid = torch.stack([h_floor_grid, h_floor_grid, h_ceil_grid, h_ceil_grid])
            w_grid = torch.stack([w_floor_grid, w_ceil_grid, w_floor_grid, w_ceil_grid])
            h_grid_idx = h_grid * num_grid_per_side

            # Flat indices of the 4 neighbours for every output position.
            indices = (h_grid_idx + w_grid).reshape(4, -1)
            weights = torch.stack([w00, w01, w10, w11], dim=0).reshape(4, -1, 1)
            weights = weights.to(dtype=self.dtype)

            embeds = self.pos_embed(indices)
            embeds *= weights
            combined = embeds.sum(dim=0)

            # Reorder rows so each merge window is contiguous, then repeat
            # for every temporal slice.
            combined = combined.reshape(
                h // m_size, m_size, w // m_size, m_size, hidden_dim
            )
            combined = combined.permute(0, 2, 1, 3, 4).reshape(1, -1, hidden_dim)
            repeated = combined.expand(t, -1, -1).reshape(-1, hidden_dim)
            outputs.append(repeated)

        return torch.cat(outputs, dim=0)

    def compute_attn_mask_seqlen(
        self,
        cu_seqlens: torch.Tensor,
    ) -> torch.Tensor:
        """Return the longest sequence length for flash-attention backends.

        For other backends this returns a zero scalar, since the value is
        unused there.
        """
        max_seqlen = torch.zeros([], device=cu_seqlens.device)
        if (
            self.attn_backend == AttentionBackendEnum.FLASH_ATTN
            or self.attn_backend == AttentionBackendEnum.ROCM_AITER_FA
        ):
            max_seqlen = (cu_seqlens[1:] - cu_seqlens[:-1]).max()
        return max_seqlen

    def forward(
        self,
        x: torch.Tensor,
        grid_thw: torch.Tensor | list[list[int]],
    ) -> torch.Tensor:
        """Encode flattened patches into merged visual tokens.

        Args:
            x: Flattened pixel patches, one row per patch.
            grid_thw: Per-item ``(t, h, w)`` patch-grid sizes, as a nested
                list or a tensor of shape ``(num_items, 3)``.

        Returns:
            The final merger output with each deepstack merger output
            concatenated along dim 1, i.e. width ``self.out_hidden_size``.
        """
        hidden_states = x.to(device=self.device, dtype=self.dtype, non_blocking=True)
        hidden_states = self.patch_embed(hidden_states)

        # Keep a Python list for the per-item loops and a numpy array for the
        # vectorized cu_seqlens computation. Deriving the array from
        # ``tolist()`` works for tensors on any device, whereas
        # ``Tensor.numpy()`` raises for non-CPU tensors.
        if isinstance(grid_thw, list):
            grid_thw_list = grid_thw
        else:
            grid_thw_list = grid_thw.tolist()
        grid_thw_np = np.array(grid_thw_list, dtype=np.int32)

        pos_embeds = self.fast_pos_embed_interpolate(grid_thw_list)
        hidden_states = hidden_states + pos_embeds
        rotary_pos_emb_cos, rotary_pos_emb_sin = self.rot_pos_emb(grid_thw_list)

        # Each temporal slice of each item is its own attention sequence of
        # h * w patches; cu_seqlens holds the cumulative boundaries.
        cu_seqlens = np.repeat(
            grid_thw_np[:, 1] * grid_thw_np[:, 2], grid_thw_np[:, 0]
        ).cumsum(axis=0, dtype=np.int32)
        cu_seqlens = np.concatenate([np.zeros(1, dtype=np.int32), cu_seqlens])
        cu_seqlens = torch.from_numpy(cu_seqlens)

        # Adds a size-1 axis at dim 1; presumably the attention layer expects
        # a batch axis there — TODO confirm against Qwen2_5_VisionAttention.
        hidden_states = hidden_states.unsqueeze(1)
        # Computed on the host copy, before cu_seqlens is moved to device.
        max_seqlen = self.compute_attn_mask_seqlen(cu_seqlens)
        cu_seqlens = cu_seqlens.to(self.device, non_blocking=True)

        deepstack_feature_lists = []
        for layer_num, blk in enumerate(self.blocks):
            hidden_states = blk(
                hidden_states,
                cu_seqlens=cu_seqlens,
                rotary_pos_emb_cos=rotary_pos_emb_cos,
                rotary_pos_emb_sin=rotary_pos_emb_sin,
                max_seqlen=max_seqlen,
            )
            if layer_num in self.deepstack_visual_indexes:
                deepstack_merger_idx = self.deepstack_visual_indexes.index(layer_num)
                deepstack_feature = self.deepstack_merger_list[deepstack_merger_idx](
                    hidden_states
                )
                deepstack_feature_lists.append(deepstack_feature)
        hidden_states = self.merger(hidden_states)
        hidden_states = torch.cat(
            [hidden_states] + deepstack_feature_lists, dim=1
        )  # [seq_len, hidden_size * (1 + depth_of_deepstack)]
        return hidden_states

    def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]:
        """Load checkpoint weights into this module.

        Separate ``attn.q/k/v`` checkpoint tensors are routed into the fused
        ``attn.qkv`` parameter via its ``weight_loader`` with a shard id.
        All other tensors use the parameter's ``weight_loader``, or
        ``default_weight_loader`` if it has none.

        Returns:
            The set of parameter names that were loaded.
        """
        stacked_params_mapping = [
            # (param_name, shard_name, shard_id)
            ("attn.qkv.", "attn.q.", "q"),
            ("attn.qkv.", "attn.k.", "k"),
            ("attn.qkv.", "attn.v.", "v"),
        ]
        params_dict = dict(self.named_parameters(remove_duplicate=False))
        loaded_params: set[str] = set()

        for name, loaded_weight in weights:
            for param_name, weight_name, shard_id in stacked_params_mapping:
                if weight_name not in name:
                    continue
                name = name.replace(weight_name, param_name)

                param = params_dict[name]
                weight_loader = param.weight_loader
                weight_loader(param, loaded_weight, shard_id)
                break
            else:
                # Not a q/k/v shard: load the tensor as-is.
                param = params_dict[name]
                weight_loader = getattr(param, "weight_loader", default_weight_loader)
                weight_loader(param, loaded_weight)
            loaded_params.add(name)
        return loaded_params


class Qwen3VLProcessingInfo(Qwen2VLProcessingInfo):
    """Qwen3-VL processing info.

    Provides HF config/processor accessors, image and video token-count
    estimates, and per-frame-group video timestamp computation.
    """

    def get_hf_config(self):
        return self.ctx.get_hf_config(Qwen3VLConfig)

    def get_hf_processor(self, **kwargs: object) -> Qwen3VLProcessor:
        # Default to the fast HF processor unless the caller overrides it.
        return self.ctx.get_hf_processor(
            Qwen3VLProcessor,
            use_fast=kwargs.pop("use_fast", True),
            **kwargs,
        )

    def get_tokenizer(self):
        return self.ctx.tokenizer

    def get_image_processor(self, **kwargs: object) -> Qwen2VLImageProcessorFast:
        return self.get_hf_processor(**kwargs).image_processor

    def get_video_processor(self, **kwargs: object) -> Qwen3VLVideoProcessor:
        return self.get_hf_processor(**kwargs).video_processor

    def _get_vision_info(
        self,
        *,
        image_width: int,
        image_height: int,
        num_frames: int = 2,
        do_resize: bool = True,
        image_processor: Qwen2VLImageProcessorFast | Qwen3VLVideoProcessor | None,
    ) -> tuple[ImageSize, int]:
        """Compute the preprocessed size and vision-token count for an input.

        If ``image_processor`` is None, the video processor is used when
        ``num_frames > 1`` and the image processor otherwise. Resizing uses
        the video or image ``smart_resize`` depending on that processor.

        Returns:
            ``(preprocessed_size, num_vision_tokens)``.
        """
        if image_processor is None and num_frames > 1:
            image_processor = self.get_video_processor()
        elif image_processor is None:
            image_processor = self.get_image_processor()

        is_video = isinstance(image_processor, Qwen3VLVideoProcessor)

        vision_config = self.get_hf_config().vision_config
        patch_size = vision_config.patch_size
        merge_size = vision_config.spatial_merge_size
        temporal_patch_size = vision_config.temporal_patch_size

        if do_resize:
            if is_video:
                smart_resize = video_smart_resize
                extra_kwargs = {
                    "num_frames": num_frames,
                    "temporal_factor": temporal_patch_size,
                }
            else:
                smart_resize = image_smart_resize
                extra_kwargs = {}
            resized_height, resized_width = smart_resize(
                height=image_height,
                width=image_width,
                factor=patch_size * merge_size,
                min_pixels=image_processor.size["shortest_edge"],
                max_pixels=image_processor.size["longest_edge"],
                **extra_kwargs,
            )
            preprocessed_size = ImageSize(width=resized_width, height=resized_height)
        else:
            preprocessed_size = ImageSize(width=image_width, height=image_height)

        # Round num_frames up to a multiple of temporal_patch_size. For
        # temporal_patch_size == 2 this equals num_frames + num_frames % 2;
        # unlike that form, it is also correct for other patch sizes.
        padded_num_frames = num_frames + (-num_frames) % temporal_patch_size

        grid_t = max(padded_num_frames // temporal_patch_size, 1)
        grid_h = preprocessed_size.height // patch_size
        grid_w = preprocessed_size.width // patch_size

        num_patches = grid_t * grid_h * grid_w
        num_vision_tokens = num_patches // (merge_size**2)

        return preprocessed_size, num_vision_tokens

    def _get_max_video_frames(self, max_tokens: int, start_num_frames: int = 2) -> int:
        return super()._get_max_video_frames(
            max_tokens, start_num_frames=start_num_frames
        )

    def get_num_frames_with_most_features(
        self,
        seq_len: int,
        mm_counts: Mapping[str, int],
    ) -> int:
        return super().get_num_frames_with_most_features(
            seq_len, mm_counts, max_frames_per_video=_MAX_FRAMES_PER_VIDEO
        )

    def get_max_video_tokens(
        self,
        seq_len: int,
        mm_counts: Mapping[str, int],
    ) -> int:
        target_width, target_height = self.get_image_size_with_most_features()
        video_soft_tokens = self.get_num_video_tokens(
            image_width=target_width,
            image_height=target_height,
            num_frames=self.get_num_frames_with_most_features(seq_len, mm_counts),
            image_processor=None,
        )

        # NOTE: By default in Qwen3-VL, one video token is converted to
        # "<{timestamp} seconds>" (on average 9.5 tokens) + vision_start_token + video_token + vision_end_token # noqa: E501
        formatted_video_soft_tokens = video_soft_tokens * 12.5
        return int(formatted_video_soft_tokens)

    def _calculate_timestamps(
        self, indices: list[int] | torch.Tensor, video_fps: float, merge_size: int
    ) -> list[float]:
        """Return one timestamp (seconds) per group of ``merge_size`` frames.

        Each group's timestamp is the mean of its first and last frame
        timestamps (``index / video_fps``). If the number of indices is not a
        multiple of ``merge_size``, the last index is repeated on a copy to
        fill the final group.
        """
        if not isinstance(indices, list):
            indices = indices.tolist()
        if len(indices) % merge_size != 0:
            # don't update metadata's frames_indices directly
            indices = indices + [indices[-1]] * (merge_size - len(indices) % merge_size)
        timestamps = [idx / video_fps for idx in indices]
        timestamps = [
            (timestamps[i] + timestamps[i + merge_size - 1]) / 2
            for i in range(0, len(timestamps), merge_size)
        ]
        return timestamps

    def _get_video_second_idx(
        self,
        metadata: dict[str, Any],
        out_item: MultiModalKwargsItem,
        do_sample_frames: bool | None = None,
        sampled_fps: float | None = None,
    ) -> list[float]:
        """Return per-frame-group timestamps (seconds) for a video.

        ``out_item`` is not used by this implementation.

        Frame selection:

        - If frames were sampled by the HF processor (``do_sample_frames``,
          defaulting to ``metadata["do_sample_frames"]``), the sampled
          indices are recomputed from ``metadata["total_num_frames"]``, the
          original fps, and ``sampled_fps`` (or ``video_processor.fps``).
        - Otherwise ``metadata["frames_indices"]`` is used.
        """
        video_processor = self.get_video_processor()
        merge_size = video_processor.merge_size
        indices = metadata["frames_indices"]

        # metadata["fps"] refers to the true fps of the input video.
        video_fps = metadata["fps"]
        if do_sample_frames is None:
            do_sample_frames = metadata.get("do_sample_frames", False)

        # If video frames are sampled in HF processor (instead of vLLM
        # video loader), we need to re-calculate the indices from original
        # metadata.
        if do_sample_frames:
            # sampled_fps is the target sampling rate (falling back to the
            # video processor's fps); metadata["fps"] is the original fps.
            sampled_fps = sampled_fps if sampled_fps else video_processor.fps
            total_num_frames = metadata["total_num_frames"]
            num_frames = int(total_num_frames / metadata["fps"] * sampled_fps)
            # Clamp to [min_frames, max_frames], then to the available frames.
            num_frames = min(
                min(
                    max(num_frames, video_processor.min_frames),
                    video_processor.max_frames,
                ),
                total_num_frames,
            )
            indices = (
                np.linspace(0, total_num_frames - 1, num_frames)
                .round()
                .astype(int)
                .tolist()
            )

        timestamps = self._calculate_timestamps(indices, video_fps, merge_size)
        return timestamps


class Qwen3VLDummyInputsBuilder(BaseDummyInputsBuilder[Qwen3VLProcessingInfo]):
    """Builds dummy prompts and image/video data for Qwen3-VL profiling."""

    def get_dummy_text(self, mm_counts: Mapping[str, int]) -> str:
        """Return one vision placeholder string per requested image/video."""
        num_images = mm_counts.get("image", 0)
        num_videos = mm_counts.get("video", 0)

        image_token = "<|vision_start|><|image_pad|><|vision_end|>"
        video_token = "<|vision_start|><|video_pad|><|vision_end|>"

        return image_token * num_images + video_token * num_videos

    def get_dummy_mm_data(
        self,
        seq_len: int,
        mm_counts: Mapping[str, int],
        mm_options: Mapping[str, BaseDummyOptions] | None = None,
    ) -> MultiModalDataDict:
        """Build dummy images and videos sized for the most features.

        User overrides in ``mm_options`` may only shrink the dummy data:
        overrides above the model maximum (or a frame count below 2) are
        logged and ignored.
        """
        num_images = mm_counts.get("image", 0)
        num_videos = mm_counts.get("video", 0)

        image_overrides = mm_options.get("image") if mm_options else None
        video_overrides = mm_options.get("video") if mm_options else None

        target_width, target_height = self.info.get_image_size_with_most_features()
        target_num_frames = self.info.get_num_frames_with_most_features(
            seq_len, mm_counts
        )

        if video_overrides:
            assert isinstance(video_overrides, VideoDummyOptions)
            num_frames_override = video_overrides.num_frames
            if num_frames_override:
                if num_frames_override > target_num_frames:
                    logger.warning(
                        "video.num_frames override (%d) exceeds model's "
                        "maximum number of frames (%d), will be ignored",
                        num_frames_override,
                        target_num_frames,
                    )
                if num_frames_override < 2:
                    # Actually ignore the override, as the warning says,
                    # instead of letting it shrink the profile to 2 frames.
                    logger.warning(
                        "video.num_frames override (%d) cannot be less "
                        "than 2, will be ignored",
                        num_frames_override,
                    )
                else:
                    target_num_frames = min(target_num_frames, num_frames_override)
        target_num_frames = max(target_num_frames, 2)

        target_video_size, _ = self.info._get_vision_info(
            image_width=target_width,
            image_height=target_height,
            num_frames=target_num_frames,
            image_processor=self.info.get_video_processor(),
        )
        # NOTE: we need to do this check here since Qwen3-VL resizes video
        # frames depending on how many frames there are.
        width, height = target_video_size.width, target_video_size.height
        if video_overrides:
            assert isinstance(video_overrides, VideoDummyOptions)
            width_override = video_overrides.width
            if width_override:
                if width_override > width:
                    logger.warning(
                        "video.width override (%d) exceeds model's "
                        "maximum width (%d), will be ignored",
                        width_override,
                        width,
                    )
                width = min(width, width_override)
            height_override = video_overrides.height
            if height_override:
                if height_override > height:
                    logger.warning(
                        "video.height override (%d) exceeds model's "
                        "maximum height (%d), will be ignored",
                        height_override,
                        height,
                    )
                height = min(height, height_override)

        return {
            "image": self._get_dummy_images(
                width=target_width,
                height=target_height,
                num_images=num_images,
                overrides=image_overrides,
            ),
            "video": self._get_dummy_videos(
                width=width,
                height=height,
                num_frames=target_num_frames,
                num_videos=num_videos,
            ),
        }

    def _get_dummy_videos(
        self,
        *,
        width: int,
        height: int,
        num_frames: int,
        num_videos: int,
    ) -> list[VideoItem]:
        """Return ``num_videos`` all-white ``(frames, metadata)`` pairs.

        The metadata describes a 2 fps video whose frames are already
        sampled (``do_sample_frames`` is False).
        """
        # NOTE(review): the array is laid out as (T, width, height, C), while
        # frame arrays are commonly (T, H, W, C); confirm this is intended.
        video = np.full((num_frames, width, height, 3), 255, dtype=np.uint8)
        video_items = []
        for _ in range(num_videos):
            video_metadata = {
                "fps": 2.0,
                "duration": num_frames / 2.0,
                "total_num_frames": num_frames,
                "frames_indices": list(range(num_frames)),
                "video_backend": "opencv",
                "do_sample_frames": False,
            }
            video_items.append((video.copy(), video_metadata))
        return video_items


900
class Qwen3VLMultiModalProcessor(BaseMultiModalProcessor[Qwen3VLProcessingInfo]):
    """Multimodal processor for Qwen3-VL images and videos."""

    def _get_data_parser(self) -> MultiModalDataParser:
        """Return a parser that requires metadata alongside every video."""
        return MultiModalDataParser(video_needs_metadata=True)

    def _call_hf_processor(
        self,
        prompt: str,
        mm_data: Mapping[str, object],
        mm_kwargs: Mapping[str, object],
        tok_kwargs: Mapping[str, object],
    ) -> BatchFeature:
        """Run the HF processor, handling each video in its own call.

        Each video is processed against a single-placeholder prompt; the
        resulting token ids are decoded back to text and substituted for
        the next video placeholder in ``prompt``. Images and text are then
        processed together, and the video tensors are merged into the
        output.
        """
        mm_data = dict(mm_data)
        processor = self.info.get_hf_processor(**mm_kwargs)

        # Separate video processing from image processing, because the videos
        # are processed into several image patches.
        if videos := mm_data.pop("videos", []):
            video_grid_thw_lst = []
            pixel_values_videos_lst = []

            for video_array, metadata in videos:
                # NOTE: @JJJYmmm new attr metadata.frames_indices indicates
                # the sampled frames indices of pre-sampled videos, which is
                # used to calculate the timestamps. Make sure that
                # do_sample_frames in mm_kwargs is false for presampled videos.

                # NOTE: a copy of mm_kwargs is created to update
                # do_sample_frames, otherwise mm_hash for the object will be
                # incorrect.
                video_mm_kwargs = dict(mm_kwargs)
                if "do_sample_frames" not in video_mm_kwargs:
                    # qwen_vl_utils already has "do_sample_frames" in
                    # mm_kwargs, don't overwrite it.
                    video_mm_kwargs["do_sample_frames"] = metadata.get(
                        "do_sample_frames", False
                    )

                metadata = VideoMetadata(
                    **{k: metadata[k] for k in metadata if k != "do_sample_frames"}
                )

                video_mm_data = {
                    "videos": [[video_array]],
                    "video_metadata": [[metadata]],
                }

                single_video_outputs = super()._call_hf_processor(
                    prompt="<|vision_start|><|video_pad|><|vision_end|>",
                    mm_data=video_mm_data,
                    mm_kwargs=video_mm_kwargs,
                    tok_kwargs=tok_kwargs,
                )
                input_ids = single_video_outputs.pop("input_ids")
                # Substitute the expanded placeholder text for the next
                # unprocessed video placeholder in the prompt.
                video_placeholder = processor.tokenizer.batch_decode(input_ids)[0]
                prompt = prompt.replace(
                    "<|vision_start|><|video_pad|><|vision_end|>",
                    video_placeholder,
                    1,
                )

                video_grid_thw_lst.append(single_video_outputs["video_grid_thw"])
                pixel_values_videos_lst.append(
                    single_video_outputs["pixel_values_videos"]
                )

            video_outputs = dict(
                pixel_values_videos=torch.cat(pixel_values_videos_lst),
                video_grid_thw=torch.cat(video_grid_thw_lst),
            )
        else:
            video_outputs = dict()

        processed_outputs = super()._call_hf_processor(
            prompt=prompt,
            mm_data=mm_data,
            mm_kwargs=mm_kwargs,
            tok_kwargs=tok_kwargs,
        )
        combined_outputs = dict(
            processed_outputs,
            **video_outputs,
        )
        return BatchFeature(combined_outputs)

    def _get_mm_fields_config(
        self,
        hf_inputs: BatchFeature,
        hf_processor_mm_kwargs: Mapping[str, object],
    ) -> Mapping[str, MultiModalFieldConfig]:
        """Describe how processed tensors map back to individual items.

        Pixel values and embeddings are flat tensors split per item by
        ``prod(grid_thw)``; the grid tensors are batched one row per item.
        """
        image_grid_thw = hf_inputs.get("image_grid_thw", torch.empty((0, 3)))
        image_grid_sizes = image_grid_thw.prod(-1)

        video_grid_thw = hf_inputs.get("video_grid_thw", torch.empty((0, 3)))
        video_grid_sizes = video_grid_thw.prod(-1)

        return dict(
            pixel_values=MultiModalFieldConfig.flat_from_sizes(
                "image", image_grid_sizes
            ),
            image_embeds=MultiModalFieldConfig.flat_from_sizes(
                "image", image_grid_sizes
            ),
            image_grid_thw=MultiModalFieldConfig.batched("image"),
            pixel_values_videos=MultiModalFieldConfig.flat_from_sizes(
                "video", video_grid_sizes
            ),
            video_embeds=MultiModalFieldConfig.flat_from_sizes(
                "video", video_grid_sizes
            ),
            video_grid_thw=MultiModalFieldConfig.batched("video"),
        )

    def _get_prompt_updates(
        self,
        mm_items: MultiModalDataItems,
        hf_processor_mm_kwargs: Mapping[str, Any],
        out_mm_kwargs: MultiModalKwargsItems,
    ) -> Sequence[PromptUpdate]:
        """Return the replacements that expand image/video placeholders.

        An image becomes ``prod(grid_thw) // merge_size**2`` image tokens.
        A video becomes, per temporal patch, the tokenized
        ``"<{t:.1f} seconds>"`` text followed by
        ``vision_start + video tokens + vision_end``.
        """
        hf_processor = self.info.get_hf_processor(**hf_processor_mm_kwargs)
        image_processor = self.info.get_image_processor(**hf_processor_mm_kwargs)
        tokenizer = self.info.get_tokenizer()
        hf_config = self.info.get_hf_config()

        video_token_id = hf_config.video_token_id
        vision_start_token_id = hf_config.vision_start_token_id
        vision_end_token_id = hf_config.vision_end_token_id

        merge_length = image_processor.merge_size**2

        def get_image_replacement_qwen3vl(item_idx: int):
            out_item = out_mm_kwargs["image"][item_idx]
            grid_thw = out_item["image_grid_thw"].data
            assert isinstance(grid_thw, torch.Tensor)

            num_tokens = int(grid_thw.prod()) // merge_length
            return [hf_processor.image_token_id] * num_tokens

        def get_video_replacement_qwen3vl(item_idx: int):
            out_item = out_mm_kwargs["video"][item_idx]
            grid_thw = out_item["video_grid_thw"].data
            assert isinstance(grid_thw, torch.Tensor)

            _, metadata = mm_items["video"][item_idx]
            do_sample_frames = hf_processor_mm_kwargs.get("do_sample_frames")
            sampled_fps = hf_processor_mm_kwargs.get("fps")
            # A per-video fps list is indexed by the item being replaced.
            if is_list_of(sampled_fps, float):
                sampled_fps = sampled_fps[item_idx]
            timestamps = self.info._get_video_second_idx(
                metadata, out_item, do_sample_frames, sampled_fps
            )

            assert len(timestamps) == grid_thw[0], (
                f"The timestamps length ({len(timestamps)}) should be equal to "
                f"the video length ({grid_thw[0]})."
            )

            frames_idx_token = [
                tokenizer.encode(f"<{curr_time:.1f} seconds>", add_special_tokens=False)
                for curr_time in timestamps
            ]
            num_tokens_per_frame = int(grid_thw[1:].prod()) // merge_length
            placeholder = []
            for frame_idx in frames_idx_token:
                placeholder.extend(frame_idx)
                placeholder.extend(
                    [vision_start_token_id]
                    + [video_token_id] * num_tokens_per_frame
                    + [vision_end_token_id]
                )
            # Select only the video_token_id positions within the placeholder.
            return PromptUpdateDetails.select_token_id(placeholder, video_token_id)

        return [
            PromptReplacement(
                modality="image",
                target=hf_processor.image_token,
                replacement=get_image_replacement_qwen3vl,
            ),
            # NOTE: We match string on purpose since searching sequence of
            # token ids takes more time.
            PromptReplacement(
                modality="video",
                target="<|vision_start|><|video_pad|><|vision_end|>",
                replacement=get_video_replacement_qwen3vl,
            ),
        ]


@support_torch_compile(
    dynamic_arg_dims={
        "input_ids": 0,
        # positions is of shape (3, seq_len) if mrope is enabled for qwen2-vl,
        # otherwise (seq_len, ).
        "positions": -1,
        "intermediate_tensors": 0,
        "inputs_embeds": 0,
        # the same shape as input_embeds
        "deepstack_input_embeds": 0,
    }
)
class Qwen3LLMModel(Qwen3Model):
    """Qwen3 text decoder that adds deepstack visual features to early layers."""

    def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
        super().__init__(vllm_config=vllm_config, prefix=prefix)
        if not get_pp_group().is_first_rank:
            # forward() adds deepstack embeddings after layers
            # 0..len(deepstack_visual_indexes)-1, so no later pipeline stage
            # may own any of those layers.
            assert self.start_layer >= len(
                vllm_config.model_config.hf_config.vision_config.deepstack_visual_indexes
            ), (
                "start_layer should be greater than or equal to "
                "len(deepstack_visual_indexes)"
            )

    def forward(
        self,
        input_ids: torch.Tensor,
        positions: torch.Tensor,
        intermediate_tensors: IntermediateTensors | None = None,
        inputs_embeds: torch.Tensor | None = None,
        # args for deepstack
        deepstack_input_embeds: IntermediateTensors | None = None,
    ) -> torch.Tensor | IntermediateTensors:
        """Run this rank's decoder layers.

        Args:
            input_ids: Token ids, embedded on the first rank when
                ``inputs_embeds`` is not given.
            positions: Position ids passed to every layer.
            intermediate_tensors: ``hidden_states``/``residual`` from the
                previous pipeline stage (required on non-first ranks).
            inputs_embeds: Precomputed input embeddings (first rank only).
            deepstack_input_embeds: Tensors keyed
                ``deepstack_input_embeds_{i}``; tensor ``i`` is added to the
                hidden states after layer ``i``.

        Returns:
            ``IntermediateTensors`` on non-last ranks. On the last rank, the
            normed hidden states, paired with the collected auxiliary hidden
            states when any were requested.
        """
        if get_pp_group().is_first_rank:
            if inputs_embeds is not None:
                hidden_states = inputs_embeds
            else:
                hidden_states = self.embed_input_ids(input_ids)
            residual = None
        else:
            assert intermediate_tensors is not None
            hidden_states = intermediate_tensors["hidden_states"]
            residual = intermediate_tensors["residual"]

        aux_hidden_states = []
        for layer_idx, layer in islice(
            enumerate(self.layers), self.start_layer, self.end_layer
        ):
            if layer_idx in self.aux_hidden_state_layers:
                # On the first rank, residual is still None before the first
                # layer runs; hidden_states then holds the input embeddings,
                # and adding None would raise a TypeError.
                aux_hidden_states.append(
                    hidden_states if residual is None else hidden_states + residual
                )

            hidden_states, residual = layer(
                positions,
                hidden_states,
                residual,
            )

            if deepstack_input_embeds is not None and layer_idx < len(
                deepstack_input_embeds
            ):
                hidden_states = (
                    hidden_states
                    + deepstack_input_embeds[f"deepstack_input_embeds_{layer_idx}"]
                )

        if not get_pp_group().is_last_rank:
            return IntermediateTensors(
                {"hidden_states": hidden_states, "residual": residual}
            )
        hidden_states, _ = self.norm(hidden_states, residual)

        if len(aux_hidden_states) > 0:
            return hidden_states, aux_hidden_states
        return hidden_states


class Qwen3LLMForCausalLM(Qwen3ForCausalLM):
    """Qwen3 causal LM built on ``Qwen3LLMModel`` using ``text_config``.

    The text configuration is read from ``hf_config.text_config`` of the
    multimodal config.
    """

    def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
        # Deliberately bypass Qwen3ForCausalLM.__init__ and run the next
        # initializer in the MRO; the submodules are built explicitly below.
        super(Qwen3ForCausalLM, self).__init__()
        config = vllm_config.model_config.hf_config.text_config
        quant_config = vllm_config.quant_config

        self.config = config
        self.quant_config = quant_config

        self.model = Qwen3LLMModel(
            vllm_config=vllm_config, prefix=maybe_prefix(prefix, "model")
        )

        # Only the last pipeline stage owns the LM head.
        if get_pp_group().is_last_rank:
            if config.tie_word_embeddings:
                self.lm_head = self.model.embed_tokens
            else:
                # NOTE(review): unlike self.model, this prefix is not joined
                # with `prefix`; confirm whether quantization configs rely on
                # the bare "lm_head" name before changing it.
                self.lm_head = ParallelLMHead(
                    config.vocab_size,
                    config.hidden_size,
                    quant_config=quant_config,
                    prefix="lm_head",
                )
        else:
            self.lm_head = PPMissingLayer()

        self.logits_processor = LogitsProcessor(config.vocab_size)

        self.make_empty_intermediate_tensors = (
            self.model.make_empty_intermediate_tensors
        )
1190
1191


1192
1193
1194
1195
1196
1197
@MULTIMODAL_REGISTRY.register_processor(
    Qwen3VLMultiModalProcessor,
    info=Qwen3VLProcessingInfo,
    dummy_inputs=Qwen3VLDummyInputsBuilder,
)
class Qwen3VLForConditionalGeneration(
1198
1199
1200
1201
1202
1203
    nn.Module,
    SupportsMultiModal,
    SupportsLoRA,
    SupportsPP,
    SupportsMRoPE,
    SupportsEagle3,
1204
):
1205
    merge_by_field_config = True
1206
    multimodal_cpu_fields = {"image_grid_thw", "video_grid_thw"}
1207

1208
1209
1210
1211
1212
1213
1214
1215
1216
1217
1218
    packed_modules_mapping = {
        "qkv_proj": [
            "q_proj",
            "k_proj",
            "v_proj",
        ],
        "gate_up_proj": [
            "gate_proj",
            "up_proj",
        ],
    }
1219
1220
1221

    supports_encoder_tp_data = True

1222
1223
1224
1225
1226
1227
    # To ensure correct weight loading and mapping.
    hf_to_vllm_mapper = WeightsMapper(
        orig_to_new_prefix={
            "model.visual.": "visual.",
            "lm_head.": "language_model.lm_head.",
            "model.language_model.": "language_model.model.",
1228
1229
        }
    )
1230
1231

    @classmethod
    def get_placeholder_str(cls, modality: str, i: int) -> str | None:
        """Return the prompt placeholder for an image or video item.

        Args:
            modality: A modality name starting with ``"image"`` or ``"video"``.
            i: Item index (not used; every item shares the same placeholder).

        Raises:
            ValueError: For any other modality.
        """
        if modality.startswith("image"):
            return "<|vision_start|><|image_pad|><|vision_end|>"
        if modality.startswith("video"):
            return "<|vision_start|><|video_pad|><|vision_end|>"

        raise ValueError("Only image or video modality is supported")

    def __init__(self, *, vllm_config: VllmConfig, prefix: str = "model"):
        """Build the vision tower (when enabled), the LM and deepstack buffers."""
        super().__init__()
        config: Qwen3VLConfig = vllm_config.model_config.hf_config
        quant_config = vllm_config.quant_config
        multimodal_config = vllm_config.model_config.multimodal_config

        self.config = config
        self.multimodal_config = multimodal_config
        self.use_data_parallel = multimodal_config.mm_encoder_tp_mode == "data"

        # Skip the vision tower entirely when prompts may contain neither
        # images nor videos.
        if not multimodal_config.get_limit_per_prompt(
            "image"
        ) and not multimodal_config.get_limit_per_prompt("video"):
            self.visual = None
        else:
            self.visual = Qwen3_VisionTransformer(
                config.vision_config,
                norm_eps=getattr(config, "rms_norm_eps", 1e-6),
                quant_config=quant_config,
                prefix=maybe_prefix(prefix, "visual"),
                use_data_parallel=self.use_data_parallel,
                # multimodal_config is already dereferenced above, so a None
                # check here would be dead code.
                attn_backend_override=multimodal_config.mm_encoder_attn_backend,
            )

        self.language_model = Qwen3LLMForCausalLM(
            vllm_config=vllm_config, prefix=maybe_prefix(prefix, "language_model")
        )

        self.make_empty_intermediate_tensors = (
            self.language_model.make_empty_intermediate_tensors
        )

        self.use_deepstack = hasattr(config.vision_config, "deepstack_visual_indexes")
        self.deepstack_num_level = (
            len(config.vision_config.deepstack_visual_indexes)
            if self.use_deepstack
            else 0
        )
        # register buffer for deepstack: one (max_num_batched_tokens,
        # hidden_size) tensor per deepstack level.
        if self.use_deepstack and self.visual is not None:
            self.deepstack_input_embeds = [
                torch.zeros(
                    vllm_config.scheduler_config.max_num_batched_tokens,
                    config.text_config.hidden_size,
                )
                for _ in range(self.deepstack_num_level)
            ]
        else:
            self.deepstack_input_embeds = None
        # Vision outputs are [main | multiscale] along the last dim; these
        # widths are used to split them.
        self.visual_dim = config.vision_config.out_hidden_size
        self.multiscale_dim = self.visual_dim * self.deepstack_num_level
1295

1296
1297
1298
1299
1300
1301
1302
    def set_aux_hidden_state_layers(self, layers: tuple[int, ...]) -> None:
        """Choose which decoder layers record auxiliary hidden states."""
        text_decoder = self.language_model.model
        text_decoder.aux_hidden_state_layers = layers

    def get_eagle3_aux_hidden_state_layers(self) -> tuple[int, ...]:
        """Return default EAGLE-3 aux layers: an early, middle and late one."""
        depth = len(self.language_model.model.layers)
        early, middle, late = 2, depth // 2, depth - 3
        return (early, middle, late)

1303
    def _get_deepstack_input_embeds(self, num_tokens: int) -> IntermediateTensors:
        """Return slices of the first ``num_tokens`` rows of each deepstack buffer.

        The slices are keyed ``deepstack_input_embeds_{level}``. The buffers
        are NOT cleared here; callers use ``_clear_deepstack_input_embeds``.
        """
        return IntermediateTensors(
            {
                f"deepstack_input_embeds_{idx}": self.deepstack_input_embeds[idx][
                    :num_tokens
                ]
                for idx in range(self.deepstack_num_level)
            }
        )

    def _set_deepstack_input_embeds(self, deepstack_input_embeds: torch.Tensor) -> None:
        """Copy per-level deepstack embeddings into the persistent buffers.

        Args:
            deepstack_input_embeds: Indexed by level on dim 0 with the token
                count on dim 1 (presumably ``(levels, tokens, hidden)`` --
                TODO confirm).
        """
        num_tokens = deepstack_input_embeds.size(1)
        # Grow the buffers (keeping device and dtype) when this batch has
        # more tokens than the current capacity.
        if num_tokens > self.deepstack_input_embeds[0].size(0):
            self.deepstack_input_embeds = [
                torch.zeros(
                    num_tokens,
                    self.config.text_config.hidden_size,
                    device=self.deepstack_input_embeds[0].device,
                    dtype=self.deepstack_input_embeds[0].dtype,
                )
                for _ in range(self.deepstack_num_level)
            ]
        for idx in range(self.deepstack_num_level):
            self.deepstack_input_embeds[idx][:num_tokens].copy_(
                deepstack_input_embeds[idx]
            )
1331
1332
1333
1334
1335
1336
1337
1338

    def _clear_deepstack_input_embeds(self, num_tokens: int) -> None:
        """Zero the first ``num_tokens`` rows of every deepstack buffer."""
        if num_tokens <= 0:
            return
        for level in range(self.deepstack_num_level):
            self.deepstack_input_embeds[level][:num_tokens].zero_()

    def _parse_and_validate_image_input(
        self, **kwargs: object
    ) -> Qwen2_5_VLImageInputs | None:
        """Build typed image inputs from raw keyword arguments.

        Returns ``None`` when neither ``pixel_values`` nor ``image_embeds``
        is present. ``pixel_values`` wins when both are given.
        """
        pixel_values = kwargs.pop("pixel_values", None)
        image_embeds = kwargs.pop("image_embeds", None)
        image_grid_thw = kwargs.pop("image_grid_thw", None)

        if pixel_values is not None:
            return Qwen2_5_VLImagePixelInputs(
                type="pixel_values",
                pixel_values=pixel_values,
                image_grid_thw=image_grid_thw,
            )

        if image_embeds is not None:
            return Qwen2_5_VLImageEmbeddingInputs(
                type="image_embeds",
                image_embeds=image_embeds,
                image_grid_thw=image_grid_thw,
            )

        return None
1361
1362

    def _parse_and_validate_video_input(
        self, **kwargs: object
    ) -> Qwen2_5_VLVideoInputs | None:
        """Build typed video inputs from raw keyword arguments.

        Returns ``None`` when neither ``pixel_values_videos`` nor
        ``video_embeds`` is present. ``pixel_values_videos`` wins when both
        are given; ``second_per_grid_ts`` is only forwarded for pixel inputs.
        """
        pixel_values_videos = kwargs.pop("pixel_values_videos", None)
        video_embeds = kwargs.pop("video_embeds", None)
        video_grid_thw = kwargs.pop("video_grid_thw", None)
        second_per_grid_ts = kwargs.pop("second_per_grid_ts", None)

        if pixel_values_videos is not None:
            return Qwen2_5_VLVideoPixelInputs(
                type="pixel_values_videos",
                pixel_values_videos=pixel_values_videos,
                video_grid_thw=video_grid_thw,
                second_per_grid_ts=second_per_grid_ts,
            )

        if video_embeds is not None:
            return Qwen2_5_VLVideoEmbeddingInputs(
                type="video_embeds",
                video_embeds=video_embeds,
                video_grid_thw=video_grid_thw,
            )

        return None
1387
1388

    def _process_image_input(
        self, image_input: Qwen2_5_VLImageInputs
    ) -> tuple[torch.Tensor, ...]:
        """Turn parsed image inputs into one embedding tensor per image.

        Precomputed ``image_embeds`` are only cast to the vision tower's
        dtype. ``pixel_values`` are encoded by the vision tower, sharded
        across ranks when the encoder runs in data-parallel mode.
        """
        grid_thw = image_input["image_grid_thw"]
        assert grid_thw.ndim == 2

        if image_input["type"] == "image_embeds":
            image_embeds = image_input["image_embeds"].type(self.visual.dtype)
        else:
            pixel_values = image_input["pixel_values"].type(self.visual.dtype)
            if self.use_data_parallel:
                # Returned as-is, without the per-item split below.
                return run_dp_sharded_mrope_vision_model(
                    self.visual, pixel_values, grid_thw.tolist(), rope_type="rope_3d"
                )
            image_embeds = self.visual(pixel_values, grid_thw=grid_thw)

        # Split concatenated embeddings for each image item: each item owns
        # prod(t, h, w) / merge_size**2 rows.
        merge_size = self.visual.spatial_merge_size
        sizes = (grid_thw.prod(-1) // merge_size // merge_size).tolist()
        return image_embeds.split(sizes)

    def _process_video_input(
        self, video_input: Qwen2_5_VLVideoInputs
    ) -> tuple[torch.Tensor, ...]:
        """Turn parsed video inputs into one embedding tensor per video.

        Precomputed ``video_embeds`` are only cast to the vision tower's
        dtype. ``pixel_values_videos`` are encoded by the vision tower,
        sharded across ranks when the encoder runs in data-parallel mode.
        """
        grid_thw = video_input["video_grid_thw"]
        assert grid_thw.ndim == 2

        if video_input["type"] == "video_embeds":
            video_embeds = video_input["video_embeds"].type(self.visual.dtype)
        else:
            pixel_values_videos = video_input["pixel_values_videos"].type(
                self.visual.dtype
            )
            if self.use_data_parallel:
                grid_thw_list = grid_thw.tolist()
                # Returned as-is, without the per-item split below.
                return run_dp_sharded_mrope_vision_model(
                    self.visual, pixel_values_videos, grid_thw_list, rope_type="rope_3d"
                )
            video_embeds = self.visual(pixel_values_videos, grid_thw=grid_thw)

        # Split concatenated embeddings for each video item: each item owns
        # prod(t, h, w) / merge_size**2 rows.
        merge_size = self.visual.spatial_merge_size
        sizes = (grid_thw.prod(-1) // merge_size // merge_size).tolist()
        return video_embeds.split(sizes)

    def _parse_and_validate_multimodal_inputs(self, **kwargs: object) -> dict:
        """Group raw keyword inputs by modality in first-seen order.

        A modality is added the first time one of its data keys appears in
        ``kwargs``, so the returned dict preserves the order in which the
        modalities were passed.
        """
        mm_input_by_modality = {}
        for input_key in kwargs:
            if (
                input_key in ("pixel_values", "image_embeds")
                and "image" not in mm_input_by_modality
            ):
                mm_input_by_modality["image"] = self._parse_and_validate_image_input(
                    **kwargs
                )
            if (
                input_key in ("pixel_values_videos", "video_embeds")
                and "video" not in mm_input_by_modality
            ):
                mm_input_by_modality["video"] = self._parse_and_validate_video_input(
                    **kwargs
                )
        return mm_input_by_modality

1454
1455
1456
1457
1458
1459
1460
1461
1462
1463
1464
1465
1466
1467
1468
1469
1470
1471
1472
1473
1474
1475
    def iter_mm_grid_hw(
        self, input_tokens: list[int], mm_features: list[MultiModalFeatureSpec]
    ) -> Iterator[tuple[int, int, int]]:
        """Yield ``(token_offset, grid_h, grid_w)`` for every vision block.

        Features are visited in prompt order. An image yields a single block
        at its offset. A video yields one block per temporal patch, each
        located at the next ``video_token_id`` at or after the previous
        block's end. Grid sizes are divided by the spatial merge size.

        Raises:
            ValueError: For a modality other than image or video.
        """
        video_token_id = self.config.video_token_id
        merge = self.config.vision_config.spatial_merge_size
        ordered = sorted(mm_features, key=lambda feat: feat.mm_position.offset)
        for feature in ordered:
            start = feature.mm_position.offset
            if feature.modality == "image":
                t, h, w = feature.data["image_grid_thw"].data.tolist()
                assert t == 1, f"Image must have 1 frame, got {t}"
                yield start, h // merge, w // merge
                continue
            if feature.modality != "video":
                raise ValueError(f"Unsupported modality: {feature.modality}")

            t, h, w = feature.data["video_grid_thw"].data.tolist()
            grid_h = h // merge
            grid_w = w // merge
            tokens_per_frame = grid_h * grid_w
            for _ in range(t):
                # Skip the timestamp/marker tokens preceding each frame.
                start = input_tokens.index(video_token_id, start)
                yield start, grid_h, grid_w
                start += tokens_per_frame

1476
    def get_mrope_input_positions(
        self,
        input_tokens: list[int],
        mm_features: list[MultiModalFeatureSpec],
    ) -> tuple[torch.Tensor, int]:
        """Compute 3-row M-RoPE position ids for a prompt.

        Text tokens get the same increasing position on all three rows. The
        tokens of each vision block (from ``iter_mm_grid_hw``) get their
        ``(0, row, col)`` grid indices, offset by the position that follows
        the previous chunk.

        Returns:
            A ``(3, len(input_tokens))`` tensor of positions and the delta
            ``max_position + 1 - len(input_tokens)``.
        """
        llm_pos_ids_list = []
        st = 0

        for offset, llm_grid_h, llm_grid_w in self.iter_mm_grid_hw(
            input_tokens, mm_features
        ):
            # Text chunk between the previous block and this one.
            text_len = offset - st
            st_idx = llm_pos_ids_list[-1].max() + 1 if llm_pos_ids_list else 0
            llm_pos_ids_list.append(
                np.broadcast_to(np.arange(text_len), (3, text_len)) + st_idx
            )

            # Vision block: (t=0, h, w) grid indices, flattened row-major.
            grid_indices = np.indices((1, llm_grid_h, llm_grid_w)).reshape(3, -1)
            llm_pos_ids_list.append(grid_indices + text_len + st_idx)
            st = offset + llm_grid_h * llm_grid_w

        # Trailing text after the last vision block.
        if st < len(input_tokens):
            st_idx = llm_pos_ids_list[-1].max() + 1 if llm_pos_ids_list else 0
            text_len = len(input_tokens) - st
            llm_pos_ids_list.append(
                np.broadcast_to(np.arange(text_len), (3, text_len)) + st_idx
            )

        if not llm_pos_ids_list:
            # Empty prompt: np.concatenate would raise on an empty list.
            return torch.from_numpy(np.zeros((3, 0), dtype=np.int_)), 0

        llm_positions = np.concatenate(llm_pos_ids_list, axis=1).reshape(3, -1)
        mrope_position_delta = (llm_positions.max() + 1 - len(input_tokens)).item()
        return torch.from_numpy(llm_positions), mrope_position_delta
1506

1507
1508
1509
    def get_language_model(self) -> torch.nn.Module:
        """Return the wrapped text model (the ``language_model`` submodule)."""
        text_model = self.language_model
        return text_model

1510
    def embed_multimodal(self, **kwargs: object) -> MultiModalEmbeddings | None:
        """Encode every image and video input found in ``kwargs``.

        Returns:
            ``None`` when ``_parse_and_validate_multimodal_inputs`` yields no
            inputs. Otherwise, a tuple with one tensor per multimodal data
            item (image or video). Modalities appear in the iteration order
            of the parsed mapping, and items within a modality keep the order
            produced by ``_process_image_input`` / ``_process_video_input``.
            Modalities other than ``"image"`` and ``"video"`` are ignored.
        """
        parsed_inputs = self._parse_and_validate_multimodal_inputs(**kwargs)
        if not parsed_inputs:
            return None

        # Follow the mapping's own key order so the embeddings stay aligned
        # with the order in which the modalities were parsed.
        per_item_embeddings: list[torch.Tensor] = []
        for modality, modality_input in parsed_inputs.items():
            if modality == "image":
                image_embeds = self._process_image_input(modality_input)
                per_item_embeddings.extend(image_embeds)
            elif modality == "video":
                video_embeds = self._process_video_input(modality_input)
                per_item_embeddings.extend(video_embeds)

        return tuple(per_item_embeddings)

    def _compute_deepstack_embeds(
        self,
        inputs_embeds: torch.Tensor,
        multimodal_embeddings: MultiModalEmbeddings,
        is_multimodal: torch.Tensor,
    ) -> tuple[torch.Tensor, MultiModalEmbeddings]:
        """Split the deepstack (multiscale) features off the visual embeddings.

        The last axis of every item in ``multimodal_embeddings`` holds
        ``self.visual_dim`` main channels followed by ``self.multiscale_dim``
        multiscale channels. The multiscale channels are merged into a
        zero-initialised buffer via ``_merge_multimodal_embeddings`` using
        ``is_multimodal``. Presumably this fills the rows flagged as
        multimodal; confirm against that helper.

        Returns:
            ``(deepstack_input_embeds, main_embeddings)``:

            * ``deepstack_input_embeds`` is the merged buffer viewed as
              ``(num_tokens, deepstack_num_level, visual_dim)`` and permuted
              to ``(deepstack_num_level, num_tokens, visual_dim)``;
            * ``main_embeddings`` holds the main-channel slice of every item,
              split back to the original per-item lengths.
        """
        item_lengths = [len(item) for item in multimodal_embeddings]
        main_flat, multiscale_flat = torch.split(
            torch.cat(multimodal_embeddings, dim=0),
            [self.visual_dim, self.multiscale_dim],
            dim=-1,
        )
        main_per_item = torch.split(main_flat, item_lengths, dim=0)
        multiscale_per_item = torch.split(multiscale_flat, item_lengths, dim=0)

        num_tokens = inputs_embeds.size(0)
        # NOTE: the view below requires inputs_embeds.size(1) == visual_dim,
        # since the buffer width is deepstack_num_level * inputs_embeds.size(1).
        buffer = inputs_embeds.new_zeros(
            num_tokens, self.deepstack_num_level * inputs_embeds.size(1)
        )
        buffer = _merge_multimodal_embeddings(
            inputs_embeds=buffer,
            multimodal_embeddings=multiscale_per_item,
            is_multimodal=is_multimodal,
        )
        deepstack_input_embeds = buffer.view(
            num_tokens, self.deepstack_num_level, self.visual_dim
        ).permute(1, 0, 2)

        return deepstack_input_embeds, main_per_item

1572
    def embed_input_ids(
        self,
        input_ids: torch.Tensor,
        multimodal_embeddings: MultiModalEmbeddings | None = None,
        *,
        is_multimodal: torch.Tensor | None = None,
        handle_oov_mm_token: bool = False,
    ) -> torch.Tensor:
        """Embed ``input_ids`` and splice in any multimodal embeddings.

        Text tokens are embedded through ``self.language_model.embed_input_ids``.
        If ``multimodal_embeddings`` is non-empty, the embeddings are merged
        in at the positions indicated by ``is_multimodal``. When
        ``self.use_deepstack`` is set, the multiscale part is first separated
        by ``_compute_deepstack_embeds`` and stored with
        ``_set_deepstack_input_embeds``. It is presumably read back by
        ``_get_deepstack_input_embeds`` in ``forward``.

        Raises:
            ValueError: if multimodal embeddings are supplied without
                ``is_multimodal``.
        """
        text_embeds = self._embed_text_input_ids(
            input_ids,
            self.language_model.embed_input_ids,
            is_multimodal=is_multimodal,
            handle_oov_mm_token=handle_oov_mm_token,
        )

        # Nothing to merge: return the plain text embedding.
        if multimodal_embeddings is None or len(multimodal_embeddings) == 0:
            return text_embeds

        if is_multimodal is None:
            raise ValueError(
                "`embed_input_ids` now requires `is_multimodal` arg, "
                "please update your model runner according to "
                "https://github.com/vllm-project/vllm/pull/16229."
            )

        deepstack_input_embeds = None
        if self.use_deepstack:
            # Replaces multimodal_embeddings with its main-channel part only.
            deepstack_input_embeds, multimodal_embeddings = (
                self._compute_deepstack_embeds(
                    inputs_embeds=text_embeds,
                    multimodal_embeddings=multimodal_embeddings,
                    is_multimodal=is_multimodal,
                )
            )

        merged_embeds = _merge_multimodal_embeddings(
            inputs_embeds=text_embeds,
            multimodal_embeddings=multimodal_embeddings,
            is_multimodal=is_multimodal,
        )

        if deepstack_input_embeds is not None:
            self._set_deepstack_input_embeds(deepstack_input_embeds)

        return merged_embeds

    def forward(
        self,
        input_ids: torch.Tensor,
        positions: torch.Tensor,
        intermediate_tensors: IntermediateTensors | None = None,
        inputs_embeds: torch.Tensor | None = None,
        **kwargs: object,
    ) -> torch.Tensor | IntermediateTensors:
        """Run forward pass for Qwen3VL.

        Args:
            input_ids: Flattened (concatenated) input_ids corresponding to a
                batch.
            positions: Flattened (concatenated) position ids corresponding to a
                batch.
                **NOTE**: If mrope is enabled (default setting for Qwen3VL
                opensource models), the shape will be `(3, seq_len)`,
                otherwise it will be `(seq_len,)`.
            intermediate_tensors: Intermediate tensors from previous pipeline
                stages. When given, ``inputs_embeds`` is ignored.
            inputs_embeds: Pre-computed input embeddings.
            **kwargs: Additional keyword arguments (for example
                ``pixel_values``, ``image_grid_thw``, ``pixel_values_videos``,
                ``video_grid_thw``). They are accepted for interface
                compatibility but are not read by this method.

        Returns:
            Whatever ``self.language_model.model`` returns: hidden states, or
            intermediate tensors for a later pipeline stage.
        """
        if intermediate_tensors is not None:
            # Work from the previous stage's tensors, not from embeddings.
            inputs_embeds = None

        # Deepstack features are fetched only when embeddings are fed in on
        # the first pipeline-parallel rank.
        deepstack_input_embeds = None
        if (
            self.use_deepstack
            and inputs_embeds is not None
            and get_pp_group().is_first_rank
        ):
            deepstack_input_embeds = self._get_deepstack_input_embeds(
                inputs_embeds.size(0)
            )

        hidden_states = self.language_model.model(
            input_ids=input_ids,
            positions=positions,
            intermediate_tensors=intermediate_tensors,
            inputs_embeds=inputs_embeds,
            # args for deepstack
            deepstack_input_embeds=deepstack_input_embeds,
        )

        # Presumably resets the buffered deepstack rows for these tokens so
        # they do not leak into the next step; confirm in the helper.
        if inputs_embeds is not None and get_pp_group().is_first_rank:
            self._clear_deepstack_input_embeds(inputs_embeds.size(0))

        return hidden_states

    def compute_logits(
        self,
        hidden_states: torch.Tensor,
1683
    ) -> torch.Tensor | None:
1684
        return self.language_model.compute_logits(hidden_states)
1685

1686
    def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]:
        """Load ``(name, tensor)`` checkpoint pairs into this model.

        Names are translated with ``self.hf_to_vllm_mapper``. When the vision
        tower is absent (``self.visual is None``), every weight under the
        ``visual.`` prefix is skipped.

        Returns:
            The set of names reported by ``AutoWeightsLoader.load_weights``.
        """
        skip_prefixes = ["visual."] if self.visual is None else []
        loader = AutoWeightsLoader(self, skip_prefixes=skip_prefixes)
        return loader.load_weights(weights, mapper=self.hf_to_vllm_mapper)

    def get_mm_mapping(self) -> MultiModelKeys:
        """Report the module prefixes of this model's multimodal components.

        The three prefixes are:

        * the language model, under ``language_model``;
        * the vision-to-language connector, under ``visual.merger``;
        * the vision tower, under ``visual.``.
        """
        prefixes = dict(
            language_model="language_model",
            connector="visual.merger",
            tower_model="visual.",
        )
        return MultiModelKeys.from_string_field(**prefixes)