qwen2_vl.py 55.5 KB
Newer Older
1
# SPDX-License-Identifier: Apache-2.0
2
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
3

4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
# Adapted from
# https://github.com/huggingface/transformers/blob/19e6e80e10118f855137b90740936c0b11ac397f/src/transformers/models/qwen2_vl/modeling_qwen2_vl.py
# Copyright 2024 The Qwen team.
# Copyright 2023 The vLLM team.
# Copyright 2022 EleutherAI and the HuggingFace Inc. team. All rights reserved.
#
# This code is based on EleutherAI's GPT-NeoX library and the GPT-NeoX
# and OPT implementations in this library. It has been modified from its
# original forms to accommodate minor architectural differences compared
# to GPT-NeoX and OPT used by the Meta AI team that trained the model.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Inference-only Qwen2-VL model compatible with HuggingFace weights."""
27

28
import math
29
from collections.abc import Callable, Iterable, Mapping, Sequence
30
from functools import partial
31
from typing import Annotated, Any, Literal, TypeAlias
32

33
import numpy as np
34
35
36
import torch
import torch.nn as nn
import torch.nn.functional as F
37
from einops import rearrange
38
from transformers import BatchFeature
39
from transformers.models.qwen2_vl import Qwen2VLImageProcessor, Qwen2VLProcessor
40
from transformers.models.qwen2_vl.configuration_qwen2_vl import (
41
42
43
    Qwen2VLConfig,
    Qwen2VLVisionConfig,
)
44
from transformers.models.qwen2_vl.image_processing_qwen2_vl import smart_resize
45
from transformers.models.qwen2_vl.video_processing_qwen2_vl import Qwen2VLVideoProcessor
46

47
from vllm.attention.backends.registry import AttentionBackendEnum
48
49
50
from vllm.attention.layer import (
    maybe_get_vit_flash_attn_backend,
)
51
from vllm.config import VllmConfig
52
from vllm.config.multimodal import BaseDummyOptions
53
from vllm.distributed import parallel_state, tensor_model_parallel_all_gather
54
55
56
from vllm.distributed import utils as dist_utils
from vllm.logger import init_logger
from vllm.model_executor.layers.activation import QuickGELU
57
from vllm.model_executor.layers.conv import Conv3dLayer
58
59
60
61
from vllm.model_executor.layers.linear import (
    ColumnParallelLinear,
    RowParallelLinear,
)
62
from vllm.model_executor.layers.quantization import QuantizationConfig
63
from vllm.model_executor.layers.rotary_embedding import get_rope
64
from vllm.model_executor.layers.rotary_embedding.common import (
65
    apply_rotary_emb_torch,
66
67
    dispatch_rotary_emb_function,
)
68
from vllm.model_executor.model_loader.weight_utils import default_weight_loader
69
from vllm.model_executor.models.module_mapping import MultiModelKeys
70
from vllm.multimodal import MULTIMODAL_REGISTRY
71
72
73
74
from vllm.multimodal.inputs import (
    ImageItem,
    ModalityData,
    MultiModalDataDict,
75
    MultiModalFeatureSpec,
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
    MultiModalFieldConfig,
    MultiModalKwargsItems,
    VideoItem,
)
from vllm.multimodal.parse import (
    DictEmbeddingItems,
    ImageSize,
    ModalityDataItems,
    MultiModalDataItems,
    MultiModalDataParser,
)
from vllm.multimodal.processing import (
    BaseMultiModalProcessor,
    BaseProcessingInfo,
    PromptReplacement,
    PromptUpdate,
)
93
from vllm.multimodal.profiling import BaseDummyInputsBuilder
94
from vllm.sequence import IntermediateTensors
95
from vllm.tokenizers import TokenizerLike
96
from vllm.utils.tensor_schema import TensorSchema, TensorShape
97

98
99
100
101
102
103
104
105
106
107
108
109
110
from .interfaces import (
    MultiModalEmbeddings,
    SupportsLoRA,
    SupportsMRoPE,
    SupportsMultiModal,
    SupportsPP,
)
from .utils import (
    AutoWeightsLoader,
    WeightsMapper,
    init_vllm_registered_model,
    maybe_prefix,
)
111
112
113
114
from .vision import (
    get_vit_attn_backend,
    run_dp_sharded_mrope_vision_model,
)
115

116
117
logger = init_logger(__name__)

# Used by the profile run: cap on the number of frames assumed per video.
_MAX_FRAMES_PER_VIDEO = 14
120

121
122
123
# === Vision Inputs === #


124
class Qwen2VLImagePixelInputs(TensorSchema):
    """
    Image inputs given as flattened pixel patches.

    Dimensions:
        - np: Total patch count, summed over every image of every prompt in
              the batch
        - ni: Number of images
        - cps: num_channels * patch_size * patch_size

    Historical context:
        - pixel_values shape: (num_patches, num_channels * patch_size *
          patch_size)
        - image_grid_thw shape: (num_images, 3), laid out as
          (grid_t, grid_h, grid_w)
    """

    type: Literal["pixel_values"]

    pixel_values: Annotated[torch.Tensor, TensorShape("np", "cps")]

    image_grid_thw: Annotated[torch.Tensor, TensorShape("ni", 3)]


class Qwen2VLImageEmbeddingInputs(TensorSchema):
    """
    Image inputs supplied as precomputed embeddings.

    Dimensions:
        - nf: Number of image features
        - hs: Hidden size
        - ni: Number of images

    Historical context:
        - image_embeds shape: (num_image_features, hidden_size)
        - num_image_features depends on how many images there are and on
          their resolution.
        - hidden_size must equal the hidden size of the language model
          backbone.
        - image_grid_thw shape: (num_images, 3), laid out as
          (grid_t, grid_h, grid_w)
    """

    type: Literal["image_embeds"]

    image_embeds: Annotated[torch.Tensor, TensorShape("nf", "hs")]

    image_grid_thw: Annotated[torch.Tensor, TensorShape("ni", 3)]


# Union of the pixel-based and embedding-based image input schemas.
Qwen2VLImageInputs: TypeAlias = Qwen2VLImagePixelInputs | Qwen2VLImageEmbeddingInputs
182
183


184
185
186
187
188
class Qwen2VLVideoPixelInputs(TensorSchema):
    """
    Video inputs given as flattened spatio-temporal pixel patches.

    Dimensions:
        - np: Total patch count, summed over every video of every prompt in
              the batch
        - ctps: num_channels * temporal_patch_size * patch_size * patch_size
        - nv: Number of videos

    Historical context:
        - pixel_values_videos shape: (num_patches, num_channels *
          temporal_patch_size * patch_size * patch_size)
        - video_grid_thw shape: (num_videos, 3), laid out as
          (grid_t, grid_h, grid_w)
    """

    type: Literal["pixel_values_videos"]

    pixel_values_videos: Annotated[torch.Tensor, TensorShape("np", "ctps")]

    video_grid_thw: Annotated[torch.Tensor, TensorShape("nv", 3)]
211
212


213
214
215
216
217
218
class Qwen2VLVideoEmbeddingInputs(TensorSchema):
    """
    Video inputs supplied as precomputed embeddings.

    Dimensions:
        - nf: Number of video features
        - hs: Hidden size
        - nv: Number of videos

    Historical context:
        - video_embeds shape: (num_video_features, hidden_size)
        - num_video_features depends on how many videos there are and on
          their resolution.
        - hidden_size must equal the hidden size of the language model
          backbone.
        - video_grid_thw shape: (num_videos, 3), laid out as
          (grid_t, grid_h, grid_w)
    """

    type: Literal["video_embeds"]

    video_embeds: Annotated[torch.Tensor, TensorShape("nf", "hs")]

    video_grid_thw: Annotated[torch.Tensor, TensorShape("nv", 3)]


# Union of the pixel-based and embedding-based video input schemas.
Qwen2VLVideoInputs: TypeAlias = Qwen2VLVideoPixelInputs | Qwen2VLVideoEmbeddingInputs
243

244
245
246
247
248
249
250
# === Vision Encoder === #


class Qwen2VisionMLP(nn.Module):
    """Two-layer feed-forward block of the vision encoder.

    ``fc1`` (column-parallel) maps ``in_features -> hidden_features``, ``act``
    is applied, and ``fc2`` (row-parallel) maps back to ``in_features``.
    With ``use_data_parallel`` both linears are built with ``disable_tp``.
    """

    def __init__(
        self,
        in_features: int,
        hidden_features: int,
        act_layer: type[nn.Module] = QuickGELU,
        quant_config: QuantizationConfig | None = None,
        prefix: str = "",
        use_data_parallel: bool = False,
    ):
        super().__init__()
        self.fc1 = ColumnParallelLinear(
            input_size=in_features,
            output_size=hidden_features,
            quant_config=quant_config,
            prefix=f"{prefix}.fc1",
            disable_tp=use_data_parallel,
        )
        self.act = act_layer()
        self.fc2 = RowParallelLinear(
            input_size=hidden_features,
            output_size=in_features,
            quant_config=quant_config,
            prefix=f"{prefix}.fc2",
            disable_tp=use_data_parallel,
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # The parallel linears return 2-tuples; only the first item is used.
        hidden, _ = self.fc1(x)
        activated = self.act(hidden)
        projected, _ = self.fc2(activated)
        return projected


281
282
def apply_rotary_pos_emb_vision(
    t: torch.Tensor, cos: torch.Tensor, sin: torch.Tensor
) -> torch.Tensor:
    """Apply NeoX-style rotary position embedding to ``t``.

    The implementation is picked by ``dispatch_rotary_emb_function``, with the
    pure-torch ``apply_rotary_emb_torch`` (``is_neox_style=True``) supplied as
    the default. The result is converted back to ``t``'s tensor type.
    """
    torch_fallback = partial(apply_rotary_emb_torch, is_neox_style=True)
    rotary_fn = dispatch_rotary_emb_function(default=torch_fallback)
    rotated = rotary_fn(t, cos, sin)
    return rotated.type_as(t)


class Qwen2VisionAttention(nn.Module):
    """Multi-head self-attention used inside each Qwen2-VL vision block.

    A fused ``qkv`` projection produces queries/keys/values; queries and keys
    get rotary position embedding; attention is computed either with a
    variable-length flash-attention function (FLASH_ATTN / ROCM_AITER_FA) or
    with a per-sequence ``F.scaled_dot_product_attention`` loop (TORCH_SDPA);
    the result goes through the ``proj`` output projection. Any other backend
    is rejected with ``RuntimeError`` at construction time.
    """

    def __init__(
        self,
        embed_dim: int,
        num_heads: int,
        projection_size: int,
        quant_config: QuantizationConfig | None = None,
        prefix: str = "",
        use_data_parallel: bool = False,
        attn_backend_override: AttentionBackendEnum | None = None,
    ) -> None:
        super().__init__()
        # Per attention head and per partition values. In data-parallel mode
        # the layer is not sharded, so it acts as a single partition.
        self.tp_size = (
            1
            if use_data_parallel
            else parallel_state.get_tensor_model_parallel_world_size()
        )
        self.tp_rank = parallel_state.get_tensor_model_parallel_rank()
        self.hidden_size_per_attention_head = dist_utils.divide(
            projection_size, num_heads
        )
        self.num_attention_heads_per_partition = dist_utils.divide(
            num_heads, self.tp_size
        )

        self.qkv = ColumnParallelLinear(
            input_size=embed_dim,
            output_size=3 * projection_size,
            quant_config=quant_config,
            prefix=f"{prefix}.qkv",
            disable_tp=use_data_parallel,
        )
        self.proj = RowParallelLinear(
            input_size=projection_size,
            output_size=embed_dim,
            quant_config=quant_config,
            prefix=f"{prefix}.proj",
            disable_tp=use_data_parallel,
        )

        # Detect attention implementation.
        self.attn_backend = get_vit_attn_backend(
            head_size=self.hidden_size_per_attention_head,
            dtype=torch.get_default_dtype(),
            attn_backend_override=attn_backend_override,
        )

        # Returns a (possibly updated) backend together with the varlen
        # flash-attention callable used by the flash branch of forward().
        self.attn_backend, self.flash_attn_varlen_func = (
            maybe_get_vit_flash_attn_backend(
                self.attn_backend,
                attn_backend_override=attn_backend_override,
            )
        )

        if self.attn_backend not in {
            AttentionBackendEnum.FLASH_ATTN,
            AttentionBackendEnum.TORCH_SDPA,
            AttentionBackendEnum.ROCM_AITER_FA,
        }:
            raise RuntimeError(
                f"Qwen2-VL does not support {self.attn_backend} backend now."
            )

        self.is_flash_attn_backend = self.attn_backend in {
            AttentionBackendEnum.FLASH_ATTN,
            AttentionBackendEnum.ROCM_AITER_FA,
        }

    def split_qkv(self, qkv: torch.Tensor) -> tuple[torch.Tensor, ...]:
        """Split the fused projection output into per-head q, k and v.

        Under tensor parallelism the local ``qkv`` shard is all-gathered,
        chunked into q/k/v, and each of those is split again so this rank
        keeps only its own slice.
        NOTE(review): presumably required because the fused output is sharded
        as one contiguous tensor whose per-rank slices do not line up with
        the q/k/v boundaries — confirm against ColumnParallelLinear.
        """
        # [s, b, 3 * head * head_dim]
        seq_len, bs, _ = qkv.shape
        if self.tp_size > 1:
            qkv = tensor_model_parallel_all_gather(qkv)

        # [s, b, 3 * head * head_dim] -> 3 * [s, b, head * head_dim]
        q, k, v = qkv.chunk(3, dim=2)

        # 3 * [s, b, head * head_dim]
        if self.tp_size > 1:
            splitter = partial(
                dist_utils.split_tensor_along_last_dim, num_partitions=self.tp_size
            )
            q = splitter(q)[self.tp_rank]
            k = splitter(k)[self.tp_rank]
            v = splitter(v)[self.tp_rank]

        # 3 * [s, b, head * head_dim] -> 3 * [s, b, head, head_dim]
        new_shape = (
            seq_len,
            bs,
            self.num_attention_heads_per_partition,
            self.hidden_size_per_attention_head,
        )
        q, k, v = (t.view(*new_shape) for t in (q, k, v))
        return q, k, v

    def forward(
        self,
        x: torch.Tensor,
        cu_seqlens: torch.Tensor,
        rotary_pos_emb_cos: torch.Tensor,
        rotary_pos_emb_sin: torch.Tensor,
        max_seqlen: int | None = None,  # Only used for Flash Attention
    ) -> torch.Tensor:
        """Attend within each packed sequence and project back to embed_dim.

        Args:
            x: Input laid out as ``[s, b, c]`` (per the shape comments below).
            cu_seqlens: Cumulative span boundaries; attention is computed
                independently inside each ``[cu_seqlens[i], cu_seqlens[i+1])``
                span of the sequence axis.
            rotary_pos_emb_cos: Cosine table for ``apply_rotary_pos_emb_vision``.
            rotary_pos_emb_sin: Sine table for ``apply_rotary_pos_emb_vision``.
            max_seqlen: Longest span in ``cu_seqlens``; only read by the flash
                backends. When omitted there, it is computed from
                ``cu_seqlens``.

        Returns:
            ``self.proj`` applied to the ``[s, b, heads * head_dim]`` context.
        """
        # [s, b, c] --> [s, b, 3 * head * head_dim]
        x, _ = self.qkv(x)

        # [s, b, 3 * head * head_dim] -> 3 * [s, b, head, head_dim]
        q, k, v = self.split_qkv(x)
        batch_size = q.shape[1]

        q, k, v = (rearrange(t, "s b ... -> b s ...") for t in (q, k, v))

        # Rotate q and k with a single call: [2 * b, s, heads, head_dim]
        qk_concat = torch.cat([q, k], dim=0)
        qk_rotated = apply_rotary_pos_emb_vision(
            qk_concat, rotary_pos_emb_cos, rotary_pos_emb_sin
        )
        q, k = torch.chunk(qk_rotated, 2, dim=0)

        if self.is_flash_attn_backend:
            if max_seqlen is None:
                # max_seqlen defaults to None, but the varlen flash kernel
                # needs a concrete value; derive it from cu_seqlens.
                max_seqlen = (cu_seqlens[1:] - cu_seqlens[:-1]).max().item()

            q, k, v = (rearrange(t, "b s ... -> (b s) ...") for t in [q, k, v])

            output = self.flash_attn_varlen_func(
                q,
                k,
                v,
                cu_seqlens_q=cu_seqlens,
                cu_seqlens_k=cu_seqlens,
                max_seqlen_q=max_seqlen,
                max_seqlen_k=max_seqlen,
                dropout_p=0.0,
                causal=False,
            )

            context_layer = rearrange(
                output, "(b s) h d -> s b (h d)", b=batch_size
            ).contiguous()
        elif self.attn_backend == AttentionBackendEnum.TORCH_SDPA:
            # Execute attention entry by entry for speed & less VRAM.
            from vllm.platforms import current_platform

            if current_platform.is_rocm():
                q = q.contiguous()
                k = k.contiguous()
                v = v.contiguous()
            outputs = []

            lens = (cu_seqlens[1:] - cu_seqlens[:-1]).tolist()
            q_chunks = torch.split(q, lens, dim=1)
            k_chunks = torch.split(k, lens, dim=1)
            v_chunks = torch.split(v, lens, dim=1)
            for q_i, k_i, v_i in zip(q_chunks, k_chunks, v_chunks):
                q_i, k_i, v_i = (
                    rearrange(t, "b s h d -> b h s d") for t in [q_i, k_i, v_i]
                )
                output_i = F.scaled_dot_product_attention(q_i, k_i, v_i, dropout_p=0.0)
                output_i = rearrange(output_i, "b h s d -> b s h d ")
                outputs.append(output_i)
            context_layer = torch.cat(outputs, dim=1)
            context_layer = rearrange(
                context_layer, "b s h d -> s b (h d)"
            ).contiguous()

        output, _ = self.proj(context_layer)
        return output


class Qwen2VisionBlock(nn.Module):
    """One pre-norm transformer layer of the Qwen2-VL vision encoder.

    ``forward`` computes ``x = x + attn(norm1(x))`` followed by
    ``x = x + mlp(norm2(x))``.
    """

    def __init__(
        self,
        dim: int,
        num_heads: int,
        mlp_ratio: float,
        act_layer: type[nn.Module] = QuickGELU,
        norm_layer: Callable[[int], nn.Module] | None = None,
        quant_config: QuantizationConfig | None = None,
        prefix: str = "",
        use_data_parallel: bool = False,
        attn_backend_override: AttentionBackendEnum | None = None,
    ) -> None:
        super().__init__()
        # Fall back to LayerNorm(eps=1e-6) when no norm factory is given.
        make_norm = (
            partial(nn.LayerNorm, eps=1e-6) if norm_layer is None else norm_layer
        )
        self.norm1 = make_norm(dim)
        self.norm2 = make_norm(dim)

        ffn_width = int(dim * mlp_ratio)

        self.attn = Qwen2VisionAttention(
            embed_dim=dim,
            num_heads=num_heads,
            projection_size=dim,
            quant_config=quant_config,
            prefix=f"{prefix}.attn",
            use_data_parallel=use_data_parallel,
            attn_backend_override=attn_backend_override,
        )
        self.mlp = Qwen2VisionMLP(
            in_features=dim,
            hidden_features=ffn_width,
            act_layer=act_layer,
            quant_config=quant_config,
            prefix=f"{prefix}.mlp",
            use_data_parallel=use_data_parallel,
        )

    def forward(
        self,
        x: torch.Tensor,
        cu_seqlens: torch.Tensor,
        rotary_pos_emb_cos: torch.Tensor,
        rotary_pos_emb_sin: torch.Tensor,
        max_seqlen: int | None = None,  # Only used for Flash Attention
    ) -> torch.Tensor:
        attn_out = self.attn(
            self.norm1(x),
            cu_seqlens=cu_seqlens,
            rotary_pos_emb_cos=rotary_pos_emb_cos,
            rotary_pos_emb_sin=rotary_pos_emb_sin,
            max_seqlen=max_seqlen,
        )
        x = x + attn_out
        mlp_out = self.mlp(self.norm2(x))
        return x + mlp_out


class Qwen2VisionPatchEmbed(nn.Module):
    """Embed flattened spatio-temporal patches with a 3-D convolution.

    Each input row is reshaped into a
    ``(channels, temporal_patch_size, patch_size, patch_size)`` volume. The
    convolution's kernel equals its stride, so every row collapses to a
    single ``embed_dim`` vector.
    """

    def __init__(
        self,
        patch_size: int = 14,
        temporal_patch_size: int = 2,
        in_channels: int = 3,
        embed_dim: int = 1152,
    ) -> None:
        super().__init__()
        self.patch_size = patch_size
        self.temporal_patch_size = temporal_patch_size
        self.embed_dim = embed_dim

        patch_volume = (temporal_patch_size, patch_size, patch_size)
        self.proj = Conv3dLayer(
            in_channels,
            embed_dim,
            kernel_size=patch_volume,
            stride=patch_volume,
            bias=False,
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Map a 2-D ``(num_patches, C*T*P*P)`` input to ``(num_patches, embed_dim)``."""
        # Two-way unpack: rejects inputs that are not 2-D.
        num_patches, _ = x.shape
        voxels = x.view(
            num_patches,
            -1,
            self.temporal_patch_size,
            self.patch_size,
            self.patch_size,
        )
        return self.proj(voxels).view(num_patches, self.embed_dim)


class Qwen2VisionPatchMerger(nn.Module):
    """Fuse groups of patch features and project them to ``d_model``.

    The input is layer-normed at ``context_dim``, then viewed as rows of width
    ``context_dim * spatial_merge_size**2`` (so every ``spatial_merge_size**2``
    consecutive feature vectors become one row) and sent through a two-layer
    MLP: column-parallel linear -> GELU -> row-parallel linear.
    """

    def __init__(
        self,
        d_model: int,
        context_dim: int,
        norm_layer: Callable[[int], nn.Module] | None = None,
        spatial_merge_size: int = 2,
        quant_config: QuantizationConfig | None = None,
        prefix: str = "",
        use_data_parallel: bool = False,
    ) -> None:
        super().__init__()
        merged_width = context_dim * (spatial_merge_size**2)
        self.hidden_size = merged_width

        build_norm = norm_layer
        if build_norm is None:
            build_norm = partial(nn.LayerNorm, eps=1e-6)
        self.ln_q = build_norm(context_dim)

        # List positions 0/1/2 must line up with the ".mlp.0" and ".mlp.2"
        # prefixes handed to the two linear layers.
        up_proj = ColumnParallelLinear(
            merged_width,
            merged_width,
            bias=True,
            quant_config=quant_config,
            prefix=f"{prefix}.mlp.0",
            disable_tp=use_data_parallel,
        )
        activation = nn.GELU()
        down_proj = RowParallelLinear(
            merged_width,
            d_model,
            bias=True,
            quant_config=quant_config,
            prefix=f"{prefix}.mlp.2",
            disable_tp=use_data_parallel,
        )
        self.mlp = nn.ModuleList([up_proj, activation, down_proj])

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        normed = self.ln_q(x)
        grouped = normed.view(-1, self.hidden_size)

        up_proj, activation, down_proj = self.mlp
        # The parallel linears return 2-tuples; only the first item is used.
        hidden, _ = up_proj(grouped)
        merged, _ = down_proj(activation(hidden))
        return merged


class Qwen2VisionTransformer(nn.Module):
    """Vision encoder of Qwen2-VL.

    Pipeline: ``Qwen2VisionPatchEmbed`` -> ``depth`` x ``Qwen2VisionBlock``
    (with rotary embedding built from per-patch (h, w) indices) ->
    ``Qwen2VisionPatchMerger`` projecting to ``vision_config.hidden_size``.
    """

    def __init__(
        self,
        vision_config: Qwen2VLVisionConfig,
        norm_eps: float = 1e-6,
        quant_config: QuantizationConfig | None = None,
        prefix: str = "",
        use_data_parallel: bool = False,
        attn_backend_override: AttentionBackendEnum | None = None,
    ) -> None:
        super().__init__()

        patch_size = vision_config.patch_size
        temporal_patch_size = vision_config.temporal_patch_size
        spatial_merge_size = vision_config.spatial_merge_size
        in_channels = vision_config.in_channels
        hidden_size = vision_config.hidden_size
        embed_dim = vision_config.embed_dim
        depth = vision_config.depth
        num_heads = vision_config.num_heads
        mlp_ratio = vision_config.mlp_ratio

        self.use_data_parallel = use_data_parallel
        self.out_hidden_size = vision_config.hidden_size

        self.spatial_merge_size = spatial_merge_size
        self.num_heads = num_heads
        self.embed_dim = embed_dim

        self.patch_embed = Qwen2VisionPatchEmbed(
            patch_size=patch_size,
            temporal_patch_size=temporal_patch_size,
            in_channels=in_channels,
            embed_dim=embed_dim,
        )

        norm_layer = partial(nn.LayerNorm, eps=norm_eps)
        head_dim = embed_dim // num_heads
        # partial_rotary_factor=0.5: presumably so that the h- and w-index
        # lookups concatenated in rot_pos_emb together span the rotary width.
        self.rotary_pos_emb = get_rope(
            head_size=head_dim,
            max_position=8192,
            is_neox_style=True,
            rope_parameters={"partial_rotary_factor": 0.5},
        )

        self.blocks = nn.ModuleList(
            [
                Qwen2VisionBlock(
                    dim=embed_dim,
                    num_heads=num_heads,
                    mlp_ratio=mlp_ratio,
                    norm_layer=norm_layer,
                    quant_config=quant_config,
                    prefix=f"{prefix}.blocks.{layer_idx}",
                    use_data_parallel=use_data_parallel,
                    attn_backend_override=attn_backend_override,
                )
                for layer_idx in range(depth)
            ]
        )
        self.merger = Qwen2VisionPatchMerger(
            d_model=hidden_size,
            context_dim=embed_dim,
            norm_layer=norm_layer,
            quant_config=quant_config,
            prefix=f"{prefix}.merger",
            use_data_parallel=use_data_parallel,
        )
        # Read by compute_attn_mask_seqlen to decide whether max_seqlen
        # must be precomputed.
        self.attn_backend = get_vit_attn_backend(
            head_size=head_dim,
            dtype=torch.get_default_dtype(),
            attn_backend_override=attn_backend_override,
        )

    @property
    def dtype(self) -> torch.dtype:
        """Dtype of the patch-embedding weights."""
        return self.patch_embed.proj.weight.dtype

    @property
    def device(self) -> torch.device:
        """Device holding the patch-embedding weights."""
        return self.patch_embed.proj.weight.device

    def rot_pos_emb(
        self, grid_thw: list[list[int]]
    ) -> tuple[torch.Tensor, torch.Tensor]:
        """Build per-patch rotary cos/sin rows for the given ``(t, h, w)`` grids.

        Every patch gets an (h-index, w-index) pair. The pairs are reordered
        so that each ``spatial_merge_size x spatial_merge_size`` window is
        contiguous, then repeated ``t`` times. The cos/sin tables are indexed
        with both coordinates and flattened from dim 1, placing the h- and
        w-lookups side by side.

        Returns:
            ``(cos, sin)`` with one row per patch across all grids.
        """
        pos_ids = []
        max_grid_size = 0
        for t, h, w in grid_thw:
            hpos_ids = torch.arange(h).unsqueeze(1).expand(-1, w)
            wpos_ids = torch.arange(w).unsqueeze(0).expand(h, -1)
            hpos_ids = (
                hpos_ids.reshape(
                    h // self.spatial_merge_size,
                    self.spatial_merge_size,
                    w // self.spatial_merge_size,
                    self.spatial_merge_size,
                )
                .permute(0, 2, 1, 3)
                .flatten()
            )
            wpos_ids = (
                wpos_ids.reshape(
                    h // self.spatial_merge_size,
                    self.spatial_merge_size,
                    w // self.spatial_merge_size,
                    self.spatial_merge_size,
                )
                .permute(0, 2, 1, 3)
                .flatten()
            )
            pos_ids.append(torch.stack([hpos_ids, wpos_ids], dim=-1).repeat(t, 1))
            max_grid_size = max(max_grid_size, h, w)
        pos_ids = torch.cat(pos_ids, dim=0)

        # Use pre-computed cos_sin_cache from RotaryEmbedding
        cos, sin = self.rotary_pos_emb.get_cos_sin(max_grid_size)

        cos_combined = cos[pos_ids].flatten(1)
        sin_combined = sin[pos_ids].flatten(1)
        return cos_combined, sin_combined

    def compute_attn_mask_seqlen(self, cu_seqlens: torch.Tensor) -> int | None:
        """Return the longest span in ``cu_seqlens`` for flash backends.

        Returns ``None`` for other backends, which do not use it.
        """
        max_seqlen = None
        if self.attn_backend in {
            AttentionBackendEnum.FLASH_ATTN,
            AttentionBackendEnum.ROCM_AITER_FA,
        }:
            max_seqlen = (cu_seqlens[1:] - cu_seqlens[:-1]).max().item()
        return max_seqlen

    def forward(
        self,
        x: torch.Tensor,
        grid_thw: torch.Tensor | list[list[int]],
    ) -> torch.Tensor:
        """Encode flattened patches into merged vision features.

        Args:
            x: Flattened pixel patches, one row per patch; moved to the
                encoder's device and dtype before embedding.
            grid_thw: Per-item ``(t, h, w)`` patch grids, either as a nested
                list or as a tensor on any device.

        Returns:
            The output of ``self.merger``.
        """
        # patchify
        x = x.to(device=self.device, dtype=self.dtype)
        x = self.patch_embed(x)

        if isinstance(grid_thw, list):
            grid_thw_list = grid_thw
            grid_thw = np.array(grid_thw, dtype=np.int32)
        else:
            grid_thw_list = grid_thw.tolist()
            # .numpy() only works on CPU tensors; .cpu() is a no-op for a
            # tensor that is already on the CPU.
            grid_thw = grid_thw.cpu().numpy()

        # compute position embedding
        rotary_pos_emb_cos, rotary_pos_emb_sin = self.rot_pos_emb(grid_thw_list)

        # compute cu_seqlens: each of an item's t frames contributes one span
        # of h * w patches.
        cu_seqlens = np.repeat(grid_thw[:, 1] * grid_thw[:, 2], grid_thw[:, 0]).cumsum(
            axis=0, dtype=np.int32
        )
        cu_seqlens = np.concatenate([np.zeros(1, dtype=np.int32), cu_seqlens])
        cu_seqlens = torch.from_numpy(cu_seqlens)

        # transformers
        x = x.unsqueeze(1)

        # pre-compute seqlens for attn mask to reduce cuMemcpy operations
        max_seqlen = self.compute_attn_mask_seqlen(cu_seqlens)
        cu_seqlens = cu_seqlens.to(self.device, non_blocking=True)
        for blk in self.blocks:
            x = blk(
                x,
                cu_seqlens=cu_seqlens,
                rotary_pos_emb_cos=rotary_pos_emb_cos,
                rotary_pos_emb_sin=rotary_pos_emb_sin,
                max_seqlen=max_seqlen,
            )

        # adapter
        x = self.merger(x)

        return x

    def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]:
        """Load checkpoint tensors into this module's parameters.

        Names containing a shard name from ``stacked_params_mapping`` are
        renamed and loaded as one shard of the fused parameter; everything
        else uses the parameter's ``weight_loader`` (or
        ``default_weight_loader``).
        NOTE(review): the attention layer in this file names its fused
        projection ``qkv``, not ``qkv_proj``, so this mapping only matters for
        checkpoints that use q_proj/k_proj/v_proj names — confirm.

        Returns:
            The set of parameter names that were loaded.
        """
        stacked_params_mapping = [
            # (param_name, shard_name, shard_id)
            ("qkv_proj", "q_proj", "q"),
            ("qkv_proj", "k_proj", "k"),
            ("qkv_proj", "v_proj", "v"),
        ]
        params_dict = dict(self.named_parameters(remove_duplicate=False))
        loaded_params: set[str] = set()

        for name, loaded_weight in weights:
            for param_name, weight_name, shard_id in stacked_params_mapping:
                if weight_name not in name:
                    continue
                name = name.replace(weight_name, param_name)

                param = params_dict[name]
                weight_loader = param.weight_loader
                weight_loader(param, loaded_weight, shard_id)
                break
            else:
                param = params_dict[name]
                weight_loader = getattr(param, "weight_loader", default_weight_loader)
                weight_loader(param, loaded_weight)
            loaded_params.add(name)
        return loaded_params

799

800
def _create_qwen2vl_field_factory(
    spatial_merge_size: int,
) -> Callable[
    [Mapping[str, torch.Tensor]],
    Mapping[str, MultiModalFieldConfig],
]:
    """Return a callable mapping Qwen2-VL processor outputs to field configs.

    The returned function reads ``image_grid_thw`` / ``video_grid_thw`` from
    ``hf_inputs`` (an absent key is treated as an empty ``(0, 3)`` tensor) and
    describes how each flat tensor is sliced per item:

    * ``pixel_values`` / ``pixel_values_videos``: ``t * h * w`` rows per item.
    * ``image_embeds`` / ``video_embeds``: that count floor-divided by
      ``spatial_merge_size`` twice.
    * ``image_grid_thw`` / ``video_grid_thw``: batched per item, kept on CPU.
    """

    def _qwen2vl_field_config(hf_inputs: Mapping[str, torch.Tensor]):
        image_thw = hf_inputs.get("image_grid_thw", torch.empty((0, 3)))
        image_patch_counts = image_thw.prod(-1)
        image_token_counts = (
            image_patch_counts // spatial_merge_size // spatial_merge_size
        )

        video_thw = hf_inputs.get("video_grid_thw", torch.empty((0, 3)))
        video_patch_counts = video_thw.prod(-1)
        video_token_counts = (
            video_patch_counts // spatial_merge_size // spatial_merge_size
        )

        flat = MultiModalFieldConfig.flat_from_sizes
        batched = MultiModalFieldConfig.batched
        return {
            "pixel_values": flat("image", image_patch_counts),
            "image_embeds": flat("image", image_token_counts),
            "image_grid_thw": batched("image", keep_on_cpu=True),
            "pixel_values_videos": flat("video", video_patch_counts),
            "video_embeds": flat("video", video_token_counts),
            "video_grid_thw": batched("video", keep_on_cpu=True),
        }

    return _qwen2vl_field_config
837

838

class Qwen2VLMultiModalDataParser(MultiModalDataParser):
    """Multimodal data parser that also accepts precomputed-embedding dicts.

    Image or video data passed as a ``dict`` is wrapped in
    ``DictEmbeddingItems`` (the ``*_embeds`` and ``*_grid_thw`` keys are
    required) with the Qwen2-VL field factory; any other input is delegated
    to ``MultiModalDataParser``.
    """

    def __init__(self, spatial_merge_size: int, *args, **kwargs):
        self._spatial_merge_size = spatial_merge_size
        super().__init__(*args, **kwargs)

    def _parse_image_data(
        self,
        data: dict[str, torch.Tensor] | ModalityData[ImageItem],
    ) -> ModalityDataItems[Any, Any] | None:
        if not isinstance(data, dict):
            return super()._parse_image_data(data)

        factory = _create_qwen2vl_field_factory(self._spatial_merge_size)
        return DictEmbeddingItems(
            data,
            modality="image",
            required_fields={"image_embeds", "image_grid_thw"},
            fields_factory=factory,
        )

    def _parse_video_data(
        self,
        data: dict[str, torch.Tensor] | ModalityData[VideoItem],
    ) -> ModalityDataItems[Any, Any] | None:
        if not isinstance(data, dict):
            return super()._parse_video_data(data)

        factory = _create_qwen2vl_field_factory(self._spatial_merge_size)
        return DictEmbeddingItems(
            data,
            modality="video",
            required_fields={"video_embeds", "video_grid_thw"},
            fields_factory=factory,
        )


873
874
class Qwen2VLProcessingInfo(BaseProcessingInfo):
    """Config/processor access and vision-token accounting for Qwen2-VL."""

    def get_hf_config(self):
        return self.ctx.get_hf_config(Qwen2VLConfig)

    def get_hf_processor(self, **kwargs: object) -> Qwen2VLProcessor:
        return self.ctx.get_hf_processor(
            Qwen2VLProcessor,
            use_fast=kwargs.pop("use_fast", True),
            **kwargs,
        )

    def get_image_processor(self, **kwargs: object) -> Qwen2VLImageProcessor:
        return self.get_hf_processor(**kwargs).image_processor

    def get_supported_mm_limits(self) -> Mapping[str, int | None]:
        return {"image": None, "video": None}

    def get_mm_max_tokens_per_item(
        self,
        seq_len: int,
        mm_counts: Mapping[str, int],
    ) -> Mapping[str, int]:
        max_image_tokens = self.get_max_image_tokens()
        max_video_tokens = self.get_max_video_tokens(seq_len, mm_counts)
        return {"image": max_image_tokens, "video": max_video_tokens}

    def _get_vision_info(
        self,
        *,
        image_width: int,
        image_height: int,
        num_frames: int = 1,
        do_resize: bool = True,
        image_processor: Qwen2VLImageProcessor | None,
    ) -> tuple[ImageSize, int]:
        """Compute the preprocessed size and number of vision tokens.

        Args:
            image_width: Input width in pixels.
            image_height: Input height in pixels.
            num_frames: Number of frames (1 for a still image).
            do_resize: If True, apply ``smart_resize`` bounded by the image
                processor's ``min_pixels``/``max_pixels`` first.
            image_processor: Processor supplying the pixel bounds; fetched via
                ``get_image_processor()`` when None.

        Returns:
            ``(preprocessed_size, num_vision_tokens)`` where the token count is
            the number of patches divided by ``spatial_merge_size**2``.
        """
        if image_processor is None:
            image_processor = self.get_image_processor()

        hf_config = self.get_hf_config()
        vision_config = hf_config.vision_config
        patch_size = vision_config.patch_size
        merge_size = vision_config.spatial_merge_size
        temporal_patch_size = vision_config.temporal_patch_size

        if do_resize:
            resized_height, resized_width = smart_resize(
                height=image_height,
                width=image_width,
                factor=patch_size * merge_size,
                min_pixels=image_processor.min_pixels,
                max_pixels=image_processor.max_pixels,
            )
            preprocessed_size = ImageSize(width=resized_width, height=resized_height)
        else:
            preprocessed_size = ImageSize(width=image_width, height=image_height)

        # NOTE: Frames are padded to be divisible by `temporal_patch_size`
        # https://github.com/huggingface/transformers/blob/v4.48.3/src/transformers/models/qwen2_vl/image_processing_qwen2_vl.py#L294
        # Round up to the next multiple of `temporal_patch_size`. The previous
        # `num_frames + num_frames % temporal_patch_size` only did this for
        # temporal_patch_size <= 2; results are identical for those values.
        padded_num_frames = (
            (num_frames + temporal_patch_size - 1) // temporal_patch_size
        ) * temporal_patch_size

        grid_t = max(padded_num_frames // temporal_patch_size, 1)
        grid_h = preprocessed_size.height // patch_size
        grid_w = preprocessed_size.width // patch_size

        num_patches = grid_t * grid_h * grid_w
        num_vision_tokens = num_patches // (merge_size**2)

        return preprocessed_size, num_vision_tokens

    def get_num_image_tokens(
        self,
        *,
        image_width: int,
        image_height: int,
        image_processor: Qwen2VLImageProcessor | None,
    ) -> int:
        """Number of vision tokens for a single image of the given size."""
        _, num_image_tokens = self._get_vision_info(
            image_width=image_width,
            image_height=image_height,
            num_frames=1,
            image_processor=image_processor,
        )
        return num_image_tokens

    def get_num_video_tokens(
        self,
        *,
        image_width: int,
        image_height: int,
        num_frames: int,
        image_processor: Qwen2VLImageProcessor | None,
    ) -> int:
        """Number of vision tokens for a video with ``num_frames`` frames."""
        _, num_video_tokens = self._get_vision_info(
            image_width=image_width,
            image_height=image_height,
            num_frames=num_frames,
            image_processor=image_processor,
        )
        return num_video_tokens

    def get_image_size_with_most_features(self) -> ImageSize:
        """Return an image size whose (merged) patch count is maximal."""
        # NOTE: Simply processing a huge size with _get_vision_info might not give a
        # size that maximizes the number of features, i.e., the number of (merged)
        # patches. This is because the number of patches limits the allowed aspect
        # ratios. For example, suppose the maximum number of patches is 1280. A square
        # image cannot be broken down into 1280 patches, so feeding a giant square image
        # into _get_vision_info will not yield a size that maximizes the number of
        # patches. Therefore, we directly factorize the maximum number of patches into
        # height and width. The tricky part is to avoid extreme aspect ratios (>200 for
        # qwen2-vl). If we can't find a suitable aspect ratio, we decrease the number of
        # patches and retry. This is safe because the processor does not accept extreme
        # aspect ratios, so there is no valid post-resize image with the number of
        # patches that yields extreme aspect ratios.

        hf_config = self.get_hf_config()
        vision_config = hf_config.vision_config
        patch_size = vision_config.patch_size
        merge_size = vision_config.spatial_merge_size
        image_processor = self.get_image_processor()
        max_pixels = image_processor.max_pixels or image_processor.size["longest_edge"]
        unit = patch_size * merge_size
        max_seq_len = max_pixels // (unit * unit)

        def closest_factor_pair(n: int) -> tuple[int, int]:
            # Factor pair (left, right) with left <= right and left as large
            # as possible, i.e. the most square factorization of n.
            for d in range(math.isqrt(n), 0, -1):
                if n % d == 0:
                    return d, n // d
            return 1, n

        height_factor, width_factor = 1, max_seq_len
        for seq_len in range(max_seq_len, 0, -1):
            height_factor, width_factor = closest_factor_pair(seq_len)
            if width_factor / height_factor <= 200:
                break

        return ImageSize(width=unit * width_factor, height=unit * height_factor)

    def get_max_image_tokens(self) -> int:
        """Token count of the image size returned by
        ``get_image_size_with_most_features``."""
        target_width, target_height = self.get_image_size_with_most_features()

        return self.get_num_image_tokens(
            image_width=target_width,
            image_height=target_height,
            image_processor=None,
        )

    def _get_max_video_frames(self, max_tokens: int, start_num_frames: int = 1) -> int:
        """Largest frame count (>= ``start_num_frames``) whose token count at
        the max-feature image size does not exceed ``max_tokens``."""
        target_width, target_height = self.get_image_size_with_most_features()
        # The image processor does not change between iterations; fetch it
        # once instead of once per candidate frame count.
        image_processor = self.get_image_processor()

        num_frames = start_num_frames

        while True:
            next_num_frames = num_frames + 1
            next_max_tokens = self.get_num_video_tokens(
                image_width=target_width,
                image_height=target_height,
                num_frames=next_num_frames,
                image_processor=image_processor,
            )

            if next_max_tokens > max_tokens:
                break

            num_frames = next_num_frames

        return num_frames

    def get_num_frames_with_most_features(
        self,
        seq_len: int,
        mm_counts: Mapping[str, int],
        max_frames_per_video: int = _MAX_FRAMES_PER_VIDEO,
    ) -> int:
        """Frames per video when the ``seq_len`` budget is split evenly across
        the requested videos, capped at ``max_frames_per_video`` (min 1)."""
        max_videos = mm_counts.get("video", 0)

        max_total_frames = self._get_max_video_frames(seq_len)
        max_frames_per_video = min(
            max_total_frames // max(max_videos, 1), max_frames_per_video
        )

        return max(max_frames_per_video, 1)

    def get_max_video_tokens(
        self,
        seq_len: int,
        mm_counts: Mapping[str, int],
    ) -> int:
        """Token count of a max-feature-size video with the frame count from
        ``get_num_frames_with_most_features``."""
        target_width, target_height = self.get_image_size_with_most_features()

        return self.get_num_video_tokens(
            image_width=target_width,
            image_height=target_height,
            num_frames=self.get_num_frames_with_most_features(seq_len, mm_counts),
            image_processor=None,
        )

1070
1071

class Qwen2VLDummyInputsBuilder(BaseDummyInputsBuilder[Qwen2VLProcessingInfo]):
    """Builds maximal-size dummy prompts and media for Qwen2-VL profiling."""

    def get_dummy_text(self, mm_counts: Mapping[str, int]) -> str:
        """Return one image token per image followed by one video token per video."""
        processor = self.info.get_hf_processor()
        image_part = processor.image_token * mm_counts.get("image", 0)
        video_part = processor.video_token * mm_counts.get("video", 0)
        return image_part + video_part

    def get_dummy_mm_data(
        self,
        seq_len: int,
        mm_counts: Mapping[str, int],
        mm_options: Mapping[str, BaseDummyOptions] | None = None,
    ) -> MultiModalDataDict:
        """Create dummy images/videos at the size that maximizes token count."""
        width, height = self.info.get_image_size_with_most_features()
        frames = self.info.get_num_frames_with_most_features(seq_len, mm_counts)
        options = mm_options or {}

        dummy_images = self._get_dummy_images(
            width=width,
            height=height,
            num_images=mm_counts.get("image", 0),
            overrides=options.get("image"),
        )
        dummy_videos = self._get_dummy_videos(
            width=width,
            height=height,
            num_frames=frames,
            num_videos=mm_counts.get("video", 0),
            overrides=options.get("video"),
        )
        return {"image": dummy_images, "video": dummy_videos}

1115

1116
class Qwen2VLMultiModalProcessor(BaseMultiModalProcessor[Qwen2VLProcessingInfo]):
    """Multimodal processor that expands Qwen2-VL image/video placeholders."""

    def _get_data_parser(self) -> MultiModalDataParser:
        vision_config = self.info.get_hf_config().vision_config
        return Qwen2VLMultiModalDataParser(vision_config.spatial_merge_size)

    def _get_prompt_updates(
        self,
        mm_items: MultiModalDataItems,
        hf_processor_mm_kwargs: Mapping[str, Any],
        out_mm_kwargs: MultiModalKwargsItems,
    ) -> Sequence[PromptUpdate]:
        """Expand each placeholder token to one token per merged patch group.

        For an item with grid ``grid_thw`` the replacement is
        ``int(grid_thw.prod()) // merge_size**2`` copies of the placeholder.
        """
        hf_processor = self.info.get_hf_processor(**hf_processor_mm_kwargs)
        image_processor = self.info.get_image_processor(**hf_processor_mm_kwargs)
        vocab = self.info.get_tokenizer().get_vocab()

        token_ids = {
            "image": vocab[hf_processor.image_token],
            "video": vocab[hf_processor.video_token],
        }
        patches_per_token = image_processor.merge_size**2

        def make_replacement(modality: str) -> Callable[[int], list[int]]:
            token_id = token_ids[modality]
            grid_key = f"{modality}_grid_thw"

            def replacement(item_idx: int) -> list[int]:
                grid_thw = out_mm_kwargs[modality][item_idx][grid_key].data
                assert isinstance(grid_thw, torch.Tensor)
                count = int(grid_thw.prod()) // patches_per_token
                return [token_id] * count

            return replacement

        return [
            PromptReplacement(
                modality=modality,
                target=[token_ids[modality]],
                replacement=make_replacement(modality),
            )
            for modality in ("image", "video")
        ]

    def _get_mm_fields_config(
        self,
        hf_inputs: BatchFeature,
        hf_processor_mm_kwargs: Mapping[str, object],
    ) -> Mapping[str, MultiModalFieldConfig]:
        merge_size = self.info.get_hf_config().vision_config.spatial_merge_size
        build_fields = _create_qwen2vl_field_factory(merge_size)
        return build_fields(hf_inputs)


@MULTIMODAL_REGISTRY.register_processor(
    Qwen2VLMultiModalProcessor,
    info=Qwen2VLProcessingInfo,
    dummy_inputs=Qwen2VLDummyInputsBuilder,
)
class Qwen2VLForConditionalGeneration(
    nn.Module, SupportsMultiModal, SupportsLoRA, SupportsPP, SupportsMRoPE
):
    """Qwen2-VL: a Qwen2 causal LM fed by a vision transformer.

    The vision tower (``self.visual``) is only built when the multimodal config
    allows images or videos per prompt; otherwise it is ``None`` and weights
    under ``visual.`` are skipped when loading.
    """

    # To ensure correct weight loading and mapping.
    hf_to_vllm_mapper = WeightsMapper(
        orig_to_new_prefix={
            # mapping for new names in checkpoint saved after transformers v4.52
            "model.language_model.": "language_model.model.",
            "model.visual.": "visual.",
            # mapping for original checkpoint
            "lm_head.": "language_model.lm_head.",
            "model.": "language_model.model.",
        }
    )

    supports_encoder_tp_data = True

    def get_mrope_input_positions(
        self,
        input_tokens: list[int],
        mm_features: list[MultiModalFeatureSpec],
    ) -> tuple[torch.Tensor, int]:
        """Compute 3-axis (t, h, w) M-RoPE positions for a token sequence.

        Text tokens advance all three axes together. Each vision item gets
        (t, h, w) grid indices offset past the text that precedes it.

        Returns:
            ``(positions, delta)`` where ``positions`` has shape
            ``(3, len(input_tokens))`` and ``delta`` is
            ``positions.max() + 1 - len(input_tokens)``.
        """
        kwargs = MultiModalFeatureSpec.gather_kwargs(
            mm_features,
            {"image_grid_thw", "video_grid_thw", "second_per_grid_ts"},
        )
        image_grid_thw = [item.tolist() for item in kwargs.get("image_grid_thw", [])]
        video_grid_thw = [item.tolist() for item in kwargs.get("video_grid_thw", [])]
        second_per_grid_ts = kwargs.get("second_per_grid_ts", [])

        hf_config = self.config
        image_token_id = hf_config.image_token_id
        video_token_id = hf_config.video_token_id
        vision_start_token_id = hf_config.vision_start_token_id
        spatial_merge_size = hf_config.vision_config.spatial_merge_size
        tokens_per_second = getattr(hf_config.vision_config, "tokens_per_second", 1.0)

        input_tokens_tensor = torch.tensor(input_tokens)
        vision_start_indices = torch.argwhere(
            input_tokens_tensor == vision_start_token_id
        ).squeeze(1)
        # The token right after each vision-start token identifies whether the
        # item is an image or a video.
        vision_tokens = input_tokens_tensor[vision_start_indices + 1]
        image_nums = (vision_tokens == image_token_id).sum()
        video_nums = (vision_tokens == video_token_id).sum()
        llm_pos_ids_list: list = []

        st = 0
        remain_images, remain_videos = image_nums, video_nums

        image_index, video_index = 0, 0
        for _ in range(image_nums + video_nums):
            # Images use a temporal scale of 0.0, so all their t-indices are 0.
            video_second_per_grid_t = 0.0
            if remain_images > 0:
                try:
                    ed_image = input_tokens.index(image_token_id, st)
                except ValueError:
                    ed_image = len(input_tokens) + 1
            else:
                ed_image = len(input_tokens) + 1
            if remain_videos > 0:
                try:
                    ed_video = input_tokens.index(video_token_id, st)
                except ValueError:
                    ed_video = len(input_tokens) + 1
            else:
                ed_video = len(input_tokens) + 1
            # Process whichever item occurs first after position `st`.
            if ed_image < ed_video:
                t, h, w = image_grid_thw[image_index]
                image_index += 1
                remain_images -= 1
                ed = ed_image
            else:
                t, h, w = video_grid_thw[video_index]
                video_second_per_grid_t = 1.0
                if second_per_grid_ts:
                    video_second_per_grid_t = second_per_grid_ts[video_index]
                video_index += 1
                remain_videos -= 1
                ed = ed_video

            llm_grid_t, llm_grid_h, llm_grid_w = (
                t,
                h // spatial_merge_size,
                w // spatial_merge_size,
            )
            text_len = ed - st

            # Text preceding this item: identical positions on all 3 axes.
            st_idx = llm_pos_ids_list[-1].max() + 1 if len(llm_pos_ids_list) > 0 else 0
            llm_pos_ids_list.append(
                torch.arange(text_len).view(1, -1).expand(3, -1) + st_idx
            )

            t_index = (
                (
                    torch.arange(llm_grid_t)
                    .view(-1, 1)
                    .expand(-1, llm_grid_h * llm_grid_w)
                    * video_second_per_grid_t
                    * tokens_per_second
                )
                .long()
                .flatten()
            )

            h_index = (
                torch.arange(llm_grid_h)
                .view(1, -1, 1)
                .expand(llm_grid_t, -1, llm_grid_w)
                .flatten()
            )
            w_index = (
                torch.arange(llm_grid_w)
                .view(1, 1, -1)
                .expand(llm_grid_t, llm_grid_h, -1)
                .flatten()
            )
            llm_pos_ids_list.append(
                torch.stack([t_index, h_index, w_index]) + text_len + st_idx
            )
            # Skip past this item's merged vision tokens.
            st = ed + llm_grid_t * llm_grid_h * llm_grid_w

        # Trailing text after the last vision item.
        if st < len(input_tokens):
            st_idx = llm_pos_ids_list[-1].max() + 1 if len(llm_pos_ids_list) > 0 else 0
            text_len = len(input_tokens) - st
            llm_pos_ids_list.append(
                torch.arange(text_len).view(1, -1).expand(3, -1) + st_idx
            )

        llm_positions = torch.cat(llm_pos_ids_list, dim=1).reshape(3, -1)
        mrope_position_delta = (llm_positions.max() + 1 - len(input_tokens)).item()

        return llm_positions, mrope_position_delta

    @classmethod
    def get_placeholder_str(cls, modality: str, i: int) -> str | None:
        if modality.startswith("image"):
            return "<|vision_start|><|image_pad|><|vision_end|>"
        if modality.startswith("video"):
            return "<|vision_start|><|video_pad|><|vision_end|>"

        raise ValueError("Only image or video modality is supported")

    def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
        super().__init__()
        config: Qwen2VLConfig = vllm_config.model_config.hf_config
        quant_config = vllm_config.quant_config
        multimodal_config = vllm_config.model_config.multimodal_config

        self.use_data_parallel = multimodal_config.mm_encoder_tp_mode == "data"
        self.config = config
        self.multimodal_config = multimodal_config

        # Only build the vision tower when images or videos can appear.
        if multimodal_config.get_limit_per_prompt(
            "image"
        ) or multimodal_config.get_limit_per_prompt("video"):
            # `multimodal_config` is dereferenced unconditionally above, so it
            # cannot be None here; the former `is not None` guard was dead.
            self.visual = Qwen2VisionTransformer(
                config.vision_config,
                norm_eps=getattr(config, "rms_norm_eps", 1e-6),
                quant_config=quant_config,
                prefix=maybe_prefix(prefix, "visual"),
                use_data_parallel=self.use_data_parallel,
                attn_backend_override=multimodal_config.mm_encoder_attn_backend,
            )
        else:
            self.visual = None

        self.language_model = init_vllm_registered_model(
            vllm_config=vllm_config,
            prefix=maybe_prefix(prefix, "language_model"),
            architectures=["Qwen2ForCausalLM"],
        )

        self.make_empty_intermediate_tensors = (
            self.language_model.make_empty_intermediate_tensors
        )

    def _parse_and_validate_image_input(
        self, **kwargs: object
    ) -> Qwen2VLImageInputs | None:
        """Build image inputs from kwargs; pixel values take precedence over
        embeddings. Returns None when neither is provided."""
        pixel_values = kwargs.pop("pixel_values", None)
        image_embeds = kwargs.pop("image_embeds", None)
        image_grid_thw = kwargs.pop("image_grid_thw", None)

        if pixel_values is None and image_embeds is None:
            return None

        if pixel_values is not None:
            return Qwen2VLImagePixelInputs(
                type="pixel_values",
                pixel_values=pixel_values,
                image_grid_thw=image_grid_thw,
            )

        if image_embeds is not None:
            return Qwen2VLImageEmbeddingInputs(
                type="image_embeds",
                image_embeds=image_embeds,
                image_grid_thw=image_grid_thw,
            )

    def _parse_and_validate_video_input(
        self, **kwargs: object
    ) -> Qwen2VLVideoInputs | None:
        """Build video inputs from kwargs; pixel values take precedence over
        embeddings. Returns None when neither is provided."""
        pixel_values_videos = kwargs.pop("pixel_values_videos", None)
        video_embeds = kwargs.pop("video_embeds", None)
        video_grid_thw = kwargs.pop("video_grid_thw", None)

        if pixel_values_videos is None and video_embeds is None:
            return None

        if pixel_values_videos is not None:
            return Qwen2VLVideoPixelInputs(
                type="pixel_values_videos",
                pixel_values_videos=pixel_values_videos,
                video_grid_thw=video_grid_thw,
            )

        if video_embeds is not None:
            return Qwen2VLVideoEmbeddingInputs(
                type="video_embeds",
                video_embeds=video_embeds,
                video_grid_thw=video_grid_thw,
            )

    def _process_image_input(
        self, image_input: Qwen2VLImageInputs
    ) -> tuple[torch.Tensor, ...]:
        """Return one embedding tensor per image item."""
        grid_thw = image_input["image_grid_thw"]
        assert grid_thw.ndim == 2

        if image_input["type"] == "image_embeds":
            image_embeds = image_input["image_embeds"]
        else:
            pixel_values = image_input["pixel_values"]

            if self.use_data_parallel:
                return run_dp_sharded_mrope_vision_model(
                    self.visual, pixel_values, grid_thw.tolist(), rope_type="rope_3d"
                )
            else:
                image_embeds = self.visual(pixel_values, grid_thw=grid_thw)

        # Split concatenated embeddings for each image item.
        merge_size = self.visual.spatial_merge_size
        sizes = (grid_thw.prod(-1) // merge_size // merge_size).tolist()
        return image_embeds.split(sizes)

    def _process_video_input(
        self, video_input: Qwen2VLVideoInputs
    ) -> tuple[torch.Tensor, ...]:
        """Return one embedding tensor per video item."""
        grid_thw = video_input["video_grid_thw"]
        assert grid_thw.ndim == 2

        if video_input["type"] == "video_embeds":
            video_embeds = video_input["video_embeds"]
        else:
            pixel_values_videos = video_input["pixel_values_videos"]
            if self.use_data_parallel:
                return run_dp_sharded_mrope_vision_model(
                    self.visual,
                    pixel_values_videos,
                    grid_thw.tolist(),
                    rope_type="rope_3d",
                )
            else:
                video_embeds = self.visual(pixel_values_videos, grid_thw=grid_thw)

        # Split concatenated embeddings for each video item.
        merge_size = self.visual.spatial_merge_size
        sizes = (grid_thw.prod(-1) // merge_size // merge_size).tolist()
        return video_embeds.split(sizes)

    def _parse_and_validate_multimodal_inputs(self, **kwargs: object) -> dict:
        """Map "images"/"videos" to parsed inputs, ordered by first kwarg seen."""
        modalities = {}

        # Preserve the order of modalities if there are multiple of them
        # from the order of kwargs.
        for input_key in kwargs:
            if (
                input_key in ("pixel_values", "image_embeds")
                and "images" not in modalities
            ):
                image_input = self._parse_and_validate_image_input(**kwargs)
                # The parser returns None when both values are None; don't
                # record that, or `embed_multimodal` would try to process None.
                if image_input is not None:
                    modalities["images"] = image_input
            if (
                input_key in ("pixel_values_videos", "video_embeds")
                and "videos" not in modalities
            ):
                video_input = self._parse_and_validate_video_input(**kwargs)
                if video_input is not None:
                    modalities["videos"] = video_input

        return modalities

    def get_language_model(self) -> torch.nn.Module:
        return self.language_model

    def embed_multimodal(self, **kwargs: object) -> MultiModalEmbeddings:
        """Encode all image/video inputs; returns [] when there are none."""
        modalities = self._parse_and_validate_multimodal_inputs(**kwargs)
        if not modalities:
            return []

        # The result multimodal_embeddings is tuple of tensors, with each
        # tensor corresponding to a multimodal data item (image or video).
        multimodal_embeddings: tuple[torch.Tensor, ...] = ()

        # NOTE: It is important to iterate over the keys in this dictionary
        # to preserve the order of the modalities.
        for modality in modalities:
            if modality == "images":
                image_input = modalities["images"]
                image_embeddings = self._process_image_input(image_input)
                multimodal_embeddings += tuple(image_embeddings)
            if modality == "videos":
                video_input = modalities["videos"]
                video_embeddings = self._process_video_input(video_input)
                multimodal_embeddings += tuple(video_embeddings)

        return multimodal_embeddings

    def forward(
        self,
        input_ids: torch.Tensor,
        positions: torch.Tensor,
        intermediate_tensors: IntermediateTensors | None = None,
        inputs_embeds: torch.Tensor | None = None,
        **kwargs: object,
    ) -> torch.Tensor | IntermediateTensors:
        """Run forward pass for Qwen2-VL.

        Args:
            input_ids: Flattened (concatenated) input_ids corresponding to a
                batch.
            positions: Flattened (concatenated) position ids corresponding to a
                batch.
                **NOTE**: If mrope is enabled (default setting for Qwen2-VL
                opensource models), the shape will be `(3, seq_len)`,
                otherwise it will be `(seq_len,)`.
            intermediate_tensors: Intermediate tensors from prior forward pass.
                When given, ``inputs_embeds`` is ignored.
            inputs_embeds: Optional tensor of input embeddings.
        """

        if intermediate_tensors is not None:
            inputs_embeds = None

        hidden_states = self.language_model.model(
            input_ids=input_ids,
            positions=positions,
            intermediate_tensors=intermediate_tensors,
            inputs_embeds=inputs_embeds,
        )
        return hidden_states

    def compute_logits(
        self,
        hidden_states: torch.Tensor,
    ) -> torch.Tensor | None:
        return self.language_model.compute_logits(hidden_states)

    def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]:
        # Vision weights have no destination when the tower was not built.
        skip_prefixes = []
        if self.visual is None:
            skip_prefixes.extend(["visual."])
        loader = AutoWeightsLoader(self, skip_prefixes=skip_prefixes)
        return loader.load_weights(weights, mapper=self.hf_to_vllm_mapper)

    def get_mm_mapping(self) -> MultiModelKeys:
        """
        Get the module prefix in multimodal models
        """
        return MultiModelKeys.from_string_field(
            language_model="language_model",
            connector="visual.merger.",
            tower_model="visual.",
        )
1549
1550
1551
1552
1553
1554
1555
1556
1557


class Tarsier2MultiModalProcessor(Qwen2VLMultiModalProcessor):
    """Tarsier2 reuses the Qwen2-VL multimodal processor without changes."""


class Tarsier2ImageProcessor(Qwen2VLImageProcessor):
    """Qwen2-VL image processor that accepts Tarsier2's ``size`` format.

    If ``size`` contains both ``min_pixels`` and ``max_pixels``, it is rewritten
    to ``{"shortest_edge": min_pixels, "longest_edge": max_pixels}`` before
    being passed to :class:`Qwen2VLImageProcessor`; otherwise it is forwarded
    unchanged.
    """

    def __init__(
        self,
        size: dict[str, int] | None = None,
        **kwargs,
    ) -> None:
        is_tarsier2_format = size is not None and all(
            key in size for key in ("min_pixels", "max_pixels")
        )
        if is_tarsier2_format:
            size = {
                "shortest_edge": size["min_pixels"],
                "longest_edge": size["max_pixels"],
            }
        super().__init__(size=size, **kwargs)


class Tarsier2Processor(Qwen2VLProcessor):
    """Qwen2-VL processor assembled from a Tarsier2 image-processor config.

    Both the image and video processors are built from ``vision_config``;
    no chat template is attached.
    """

    def __init__(
        self,
        vision_config: dict,
        tokenizer: TokenizerLike,
        **kwargs,
    ):
        image_processor = Tarsier2ImageProcessor(**vision_config)
        # Assigned before the base initializer, as the base call reads it.
        self.image_processor = image_processor
        video_processor = Qwen2VLVideoProcessor(**vision_config)
        super().__init__(
            image_processor=self.image_processor,
            tokenizer=tokenizer,
            video_processor=video_processor,
            chat_template=None,
            **kwargs,
        )
1587
1588
1589
1590
1591


class Tarsier2ProcessingInfo(Qwen2VLProcessingInfo):
    """Processing info for Tarsier2 checkpoints.

    The Qwen2-VL config is loaded directly from the model path with
    ``Qwen2VLConfig.from_pretrained``, and the image processor is built from
    the checkpoint's image-processor config.
    """

    def get_hf_config(self) -> Qwen2VLConfig:
        # `from_pretrained` re-reads the config from the model path. This
        # method is called repeatedly (e.g. by `_get_vision_info` on each
        # iteration of `_get_max_video_frames`), so the result is cached on
        # this instance.
        cached = getattr(self, "_tarsier2_hf_config", None)
        if cached is None:
            model_path = self.ctx.model_config.model
            cached = Qwen2VLConfig.from_pretrained(model_path)
            self._tarsier2_hf_config = cached
        return cached

    def get_hf_processor(self, **kwargs: object) -> Tarsier2Processor:
        return Tarsier2Processor(
            vision_config=self.ctx.get_hf_image_processor_config(),
            tokenizer=self.get_tokenizer(),
            **kwargs,
        )

    def get_image_processor(self, **kwargs: object) -> Tarsier2ImageProcessor:
        # Accept keyword arguments so this override matches the base-class
        # signature: `Qwen2VLMultiModalProcessor._get_prompt_updates` calls
        # `get_image_processor(**hf_processor_mm_kwargs)`, which previously
        # raised TypeError here. As in `get_hf_processor`, the image processor
        # is built only from the checkpoint's image-processor config, so the
        # kwargs are not used.
        del kwargs
        return Tarsier2ImageProcessor(**self.ctx.get_hf_image_processor_config())
1605
1606


1607
1608
1609
1610
1611
@MULTIMODAL_REGISTRY.register_processor(
    Tarsier2MultiModalProcessor,
    info=Tarsier2ProcessingInfo,
    dummy_inputs=Qwen2VLDummyInputsBuilder,
)
class Tarsier2ForConditionalGeneration(Qwen2VLForConditionalGeneration):
    """Qwen2-VL model for Tarsier2 checkpoints.

    Checkpoint names under ``vision_tower.`` are mapped to ``visual.``.
    """

    hf_to_vllm_mapper = WeightsMapper(
        orig_to_new_prefix={
            "vision_tower.": "visual.",
        }
    )

    def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]:
        # Without a vision tower there is nowhere to put `visual.` weights.
        skip_prefixes = ["visual."] if self.visual is None else []
        return AutoWeightsLoader(self, skip_prefixes=skip_prefixes).load_weights(
            weights, mapper=self.hf_to_vllm_mapper
        )