qwen2_vl.py 53.3 KB
Newer Older
1
# SPDX-License-Identifier: Apache-2.0
2
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
3

4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
# Adapted from
# https://github.com/huggingface/transformers/blob/19e6e80e10118f855137b90740936c0b11ac397f/src/transformers/models/qwen2_vl/modeling_qwen2_vl.py
# Copyright 2024 The Qwen team.
# Copyright 2023 The vLLM team.
# Copyright 2022 EleutherAI and the HuggingFace Inc. team. All rights reserved.
#
# This code is based on EleutherAI's GPT-NeoX library and the GPT-NeoX
# and OPT implementations in this library. It has been modified from its
# original forms to accommodate minor architectural differences compared
# to GPT-NeoX and OPT used by the Meta AI team that trained the model.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Inference-only Qwen2-VL model compatible with HuggingFace weights."""
27

28
from collections.abc import Callable, Iterable, Mapping, Sequence
29
from functools import partial
30
from typing import Annotated, Any, Literal, TypeAlias
31

32
import numpy as np
33
34
35
import torch
import torch.nn as nn
import torch.nn.functional as F
36
from einops import rearrange
37
from transformers import BatchFeature
38
from transformers.models.qwen2_vl import Qwen2VLImageProcessor, Qwen2VLProcessor
39
from transformers.models.qwen2_vl.configuration_qwen2_vl import (
40
41
42
    Qwen2VLConfig,
    Qwen2VLVisionConfig,
)
43
from transformers.models.qwen2_vl.image_processing_qwen2_vl import smart_resize
44
from transformers.models.qwen2_vl.video_processing_qwen2_vl import Qwen2VLVideoProcessor
45

46
from vllm.attention.backends.registry import AttentionBackendEnum
47
48
49
from vllm.attention.layer import (
    maybe_get_vit_flash_attn_backend,
)
50
from vllm.config import VllmConfig
51
from vllm.config.multimodal import BaseDummyOptions
52
from vllm.distributed import parallel_state
53
54
55
from vllm.distributed import utils as dist_utils
from vllm.logger import init_logger
from vllm.model_executor.layers.activation import QuickGELU
56
from vllm.model_executor.layers.conv import Conv3dLayer
57
58
59
60
from vllm.model_executor.layers.linear import (
    ColumnParallelLinear,
    RowParallelLinear,
)
61
from vllm.model_executor.layers.quantization import QuantizationConfig
62
from vllm.model_executor.layers.rotary_embedding import get_rope
63
from vllm.model_executor.layers.rotary_embedding.common import (
64
    apply_rotary_emb_torch,
65
66
    dispatch_rotary_emb_function,
)
67
from vllm.model_executor.model_loader.weight_utils import default_weight_loader
68
from vllm.model_executor.models.module_mapping import MultiModelKeys
69
from vllm.multimodal import MULTIMODAL_REGISTRY
70
71
72
73
from vllm.multimodal.inputs import (
    ImageItem,
    ModalityData,
    MultiModalDataDict,
74
    MultiModalFeatureSpec,
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
    MultiModalFieldConfig,
    MultiModalKwargsItems,
    VideoItem,
)
from vllm.multimodal.parse import (
    DictEmbeddingItems,
    ImageSize,
    ModalityDataItems,
    MultiModalDataItems,
    MultiModalDataParser,
)
from vllm.multimodal.processing import (
    BaseMultiModalProcessor,
    BaseProcessingInfo,
    PromptReplacement,
    PromptUpdate,
)
92
from vllm.multimodal.profiling import BaseDummyInputsBuilder
93
from vllm.sequence import IntermediateTensors
94
from vllm.tokenizers import TokenizerLike
95
from vllm.utils.tensor_schema import TensorSchema, TensorShape
96

97
98
99
100
101
102
103
104
105
106
107
108
109
from .interfaces import (
    MultiModalEmbeddings,
    SupportsLoRA,
    SupportsMRoPE,
    SupportsMultiModal,
    SupportsPP,
)
from .utils import (
    AutoWeightsLoader,
    WeightsMapper,
    init_vllm_registered_model,
    maybe_prefix,
)
110
111
112
113
from .vision import (
    get_vit_attn_backend,
    run_dp_sharded_mrope_vision_model,
)
114

115
116
logger = init_logger(__name__)

# For profile run: cap on the number of frames per video.
# NOTE(review): consumed outside this chunk; presumably bounds the dummy video
# inputs built for memory profiling — confirm at the use site.
_MAX_FRAMES_PER_VIDEO = 14
119

120
121
122
# === Vision Inputs === #


123
class Qwen2VLImagePixelInputs(TensorSchema):
    """
    Raw pixel-patch inputs for images.

    Dimensions:
        - np: The total number of patches over each image over each prompt in
              the batch
        - ni: Number of images
        - cps: Number of channels * temporal_patch_size * patch_size *
               patch_size (each row is one flattened spatio-temporal patch;
               the vision patch embedding views it as
               (channels, temporal_patch_size, patch_size, patch_size))

    Historical context:
        - pixel_values shape: (num_patches, num_channels *
          temporal_patch_size * patch_size * patch_size)
        - image_grid_thw shape: (num_images, 3) in (grid_t, grid_h, grid_w)
          format
    """

    type: Literal["pixel_values"]

    pixel_values: Annotated[
        torch.Tensor,
        TensorShape("np", "cps"),
    ]

    image_grid_thw: Annotated[
        torch.Tensor,
        TensorShape("ni", 3),
    ]


class Qwen2VLImageEmbeddingInputs(TensorSchema):
    """
    Precomputed image embeddings supplied instead of pixel values.

    Dimensions:
        - nf: Number of image features
        - hs: Hidden size
        - ni: Number of images

    Historical context:
        - image_embeds shape: (num_image_features, hidden_size)
        - num_image_features varies based on the number and resolution of the
          images.
        - hidden_size must match the hidden size of the language model
          backbone.
        - image_grid_thw shape: (num_images, 3) in (grid_t, grid_h, grid_w)
          format
    """

    type: Literal["image_embeds"]

    image_embeds: Annotated[
        torch.Tensor,
        TensorShape("nf", "hs"),
    ]

    image_grid_thw: Annotated[
        torch.Tensor,
        TensorShape("ni", 3),
    ]
178
179


180
# Image inputs are either raw pixel patches or precomputed embeddings; the
# `type` literal field on each schema distinguishes the two variants.
Qwen2VLImageInputs: TypeAlias = Qwen2VLImagePixelInputs | Qwen2VLImageEmbeddingInputs
181
182


183
184
185
186
187
class Qwen2VLVideoPixelInputs(TensorSchema):
    """
    Raw pixel-patch inputs for videos.

    Dimensions:
        - np: The total number of patches over each video over each prompt in
              the batch
        - ctps: Number of channels * temporal_patch_size * patch_size *
                patch_size
        - nv: Number of videos

    Historical context:
        - pixel_values_videos shape: (num_patches, num_channels *
          temporal_patch_size * patch_size * patch_size)
        - video_grid_thw shape: (num_videos, 3) in (grid_t, grid_h, grid_w)
          format
    """

    type: Literal["pixel_values_videos"]

    pixel_values_videos: Annotated[
        torch.Tensor,
        TensorShape("np", "ctps"),
    ]

    video_grid_thw: Annotated[
        torch.Tensor,
        TensorShape("nv", 3),
    ]
210
211


212
213
214
215
216
217
class Qwen2VLVideoEmbeddingInputs(TensorSchema):
    """
    Precomputed video embeddings supplied instead of pixel values.

    Dimensions:
        - nf: Number of video features
        - hs: Hidden size
        - nv: Number of videos

    Historical context:
        - video_embeds shape: (num_video_features, hidden_size)
        - num_video_features varies based on the number and resolution of the
          videos.
        - hidden_size must match the hidden size of the language model
          backbone.
        - video_grid_thw shape: (num_videos, 3) in (grid_t, grid_h, grid_w)
          format
    """

    type: Literal["video_embeds"]

    video_embeds: Annotated[
        torch.Tensor,
        TensorShape("nf", "hs"),
    ]

    video_grid_thw: Annotated[
        torch.Tensor,
        TensorShape("nv", 3),
    ]
239
240


241
# Video inputs are either raw pixel patches or precomputed embeddings; the
# `type` literal field on each schema distinguishes the two variants.
Qwen2VLVideoInputs: TypeAlias = Qwen2VLVideoPixelInputs | Qwen2VLVideoEmbeddingInputs
242

243
244
245
246
247
248
249
# === Vision Encoder === #


class Qwen2VisionMLP(nn.Module):
    """Two-layer feed-forward network used inside each vision block.

    ``fc1`` (``in_features -> hidden_features``) is a ColumnParallelLinear and
    ``fc2`` (``hidden_features -> in_features``) a RowParallelLinear; both are
    built with ``disable_tp=use_data_parallel``.
    """

    def __init__(
        self,
        in_features: int,
        hidden_features: int,
        act_layer: type[nn.Module] = QuickGELU,
        quant_config: QuantizationConfig | None = None,
        prefix: str = "",
        use_data_parallel: bool = False,
    ):
        super().__init__()
        self.fc1 = ColumnParallelLinear(
            in_features,
            hidden_features,
            quant_config=quant_config,
            prefix=f"{prefix}.fc1",
            disable_tp=use_data_parallel,
        )
        self.act = act_layer()
        self.fc2 = RowParallelLinear(
            hidden_features,
            in_features,
            quant_config=quant_config,
            prefix=f"{prefix}.fc2",
            disable_tp=use_data_parallel,
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Apply ``fc2(act(fc1(x)))``.

        The parallel linear layers return tuples; only the first element
        (the layer output) is used.
        """
        x_parallel, _ = self.fc1(x)
        x_parallel = self.act(x_parallel)
        x, _ = self.fc2(x_parallel)
        return x


280
281
def apply_rotary_pos_emb_vision(
    t: torch.Tensor, cos: torch.Tensor, sin: torch.Tensor
) -> torch.Tensor:
    """Apply rotary position embedding to ``t``.

    ``dispatch_rotary_emb_function`` selects the implementation; NeoX-style
    ``apply_rotary_emb_torch`` is passed as its default.

    Args:
        t: Tensor to rotate (q and k are concatenated by the caller).
        cos: Cosine table, as produced by the vision tower's ``rot_pos_emb``.
        sin: Sine table, as produced by the vision tower's ``rot_pos_emb``.

    Returns:
        The rotated tensor, cast back to ``t``'s dtype.
    """
    rotary_emb_function = dispatch_rotary_emb_function(
        default=partial(apply_rotary_emb_torch, is_neox_style=True)
    )
    # The dispatched kernel may compute in a different precision; restore dtype.
    return rotary_emb_function(t, cos, sin).type_as(t)


class Qwen2VisionAttention(nn.Module):
    """Multi-head self-attention over a packed sequence of vision patches.

    Patches from several images/videos are concatenated along the sequence
    axis. ``cu_seqlens`` holds the cumulative segment boundaries, and
    attention is restricted to each segment: either a varlen flash-attention
    call, or one SDPA call per segment. Rotary embeddings are applied to q and
    k before attention.

    Raises:
        RuntimeError: If the resolved backend is not FLASH_ATTN, TORCH_SDPA or
            ROCM_AITER_FA.
    """

    def __init__(
        self,
        embed_dim: int,
        num_heads: int,
        projection_size: int,
        quant_config: QuantizationConfig | None = None,
        prefix: str = "",
        use_data_parallel: bool = False,
        attn_backend_override: AttentionBackendEnum | None = None,
    ) -> None:
        super().__init__()
        # Per attention head and per partition values.
        # With data parallelism the layers are not sharded, so treat TP as 1.
        self.tp_size = (
            1
            if use_data_parallel
            else parallel_state.get_tensor_model_parallel_world_size()
        )
        self.tp_rank = parallel_state.get_tensor_model_parallel_rank()
        self.hidden_size_per_attention_head = dist_utils.divide(
            projection_size, num_heads
        )
        self.num_attention_heads_per_partition = dist_utils.divide(
            num_heads, self.tp_size
        )

        # Fused q/k/v projection; split into three equal chunks in split_qkv.
        self.qkv = ColumnParallelLinear(
            input_size=embed_dim,
            output_size=3 * projection_size,
            quant_config=quant_config,
            prefix=f"{prefix}.qkv",
            disable_tp=use_data_parallel,
        )
        self.proj = RowParallelLinear(
            input_size=projection_size,
            output_size=embed_dim,
            quant_config=quant_config,
            prefix=f"{prefix}.proj",
            disable_tp=use_data_parallel,
        )

        # Detect attention implementation.
        self.attn_backend = get_vit_attn_backend(
            head_size=self.hidden_size_per_attention_head,
            dtype=torch.get_default_dtype(),
            attn_backend_override=attn_backend_override,
        )

        # The helper may replace the backend chosen above; it also returns the
        # varlen flash-attention function used by the flash path in forward().
        self.attn_backend, self.flash_attn_varlen_func = (
            maybe_get_vit_flash_attn_backend(
                self.attn_backend,
                attn_backend_override=attn_backend_override,
            )
        )

        if self.attn_backend not in {
            AttentionBackendEnum.FLASH_ATTN,
            AttentionBackendEnum.TORCH_SDPA,
            AttentionBackendEnum.ROCM_AITER_FA,
        }:
            raise RuntimeError(
                f"Qwen2-VL does not support {self.attn_backend} backend now."
            )

        self.is_flash_attn_backend = self.attn_backend in {
            AttentionBackendEnum.FLASH_ATTN,
            AttentionBackendEnum.ROCM_AITER_FA,
        }

    def split_qkv(self, qkv: torch.Tensor) -> tuple[torch.Tensor, ...]:
        """Split the fused projection into per-head q, k and v.

        Args:
            qkv: ``[s, b, 3 * heads * head_dim]``.

        Returns:
            ``(q, k, v)``, each a ``[s, b, heads, head_dim]`` view, where
            ``heads`` is the number of heads on this partition.
        """
        # [s, b, 3 * head * head_dim]
        seq_len, bs, _ = qkv.shape

        # [s, b, 3 * head * head_dim] -> 3 * [s, b, head * head_dim]
        q, k, v = qkv.chunk(3, dim=2)

        # 3 * [s, b, head * head_dim] -> 3 * [s, b, head, head_dim]
        new_shape = (
            seq_len,
            bs,
            self.num_attention_heads_per_partition,
            self.hidden_size_per_attention_head,
        )
        q, k, v = (x.view(*new_shape) for x in (q, k, v))
        return q, k, v

    def forward(
        self,
        x: torch.Tensor,
        cu_seqlens: torch.Tensor,
        rotary_pos_emb_cos: torch.Tensor,
        rotary_pos_emb_sin: torch.Tensor,
        max_seqlen: int | None = None,  # Only used for Flash Attention
    ) -> torch.Tensor:
        """Run segment-restricted self-attention.

        Args:
            x: ``[s, b, c]`` hidden states.
            cu_seqlens: Cumulative segment boundaries; consecutive differences
                give the length of each segment.
            rotary_pos_emb_cos: Cosine table for the rotary embedding.
            rotary_pos_emb_sin: Sine table for the rotary embedding.
            max_seqlen: Longest segment length; only passed to the flash path.

        Returns:
            ``[s, b, c]`` output of the output projection.
        """
        # [s, b, c] --> [s, b, 3 * head * head_dim]
        x, _ = self.qkv(x)

        # [s, b, 3 * head * head_dim] -> 3 * [s, b, head, head_dim]
        q, k, v = self.split_qkv(x)
        batch_size = q.shape[1]

        q, k, v = (rearrange(x, "s b ... -> b s ...") for x in (q, k, v))

        # Rotate q and k in one call: [2 * b, s, heads, head_dim]
        qk_concat = torch.cat([q, k], dim=0)
        qk_rotated = apply_rotary_pos_emb_vision(
            qk_concat, rotary_pos_emb_cos, rotary_pos_emb_sin
        )
        q, k = torch.chunk(qk_rotated, 2, dim=0)

        if self.is_flash_attn_backend:
            # Varlen flash attention consumes packed [(b s), heads, head_dim].
            q, k, v = (rearrange(x, "b s ... -> (b s) ...") for x in [q, k, v])

            output = self.flash_attn_varlen_func(
                q,
                k,
                v,
                cu_seqlens_q=cu_seqlens,
                cu_seqlens_k=cu_seqlens,
                max_seqlen_q=max_seqlen,
                max_seqlen_k=max_seqlen,
                dropout_p=0.0,
                causal=False,
            )

            context_layer = rearrange(
                output, "(b s) h d -> s b (h d)", b=batch_size
            ).contiguous()
        elif self.attn_backend == AttentionBackendEnum.TORCH_SDPA:
            # Execute attention entry by entry for speed & less VRAM.
            from vllm.platforms import current_platform

            if current_platform.is_rocm():
                q = q.contiguous()
                k = k.contiguous()
                v = v.contiguous()
            outputs = []

            # One SDPA call per segment keeps attention within each item.
            lens = (cu_seqlens[1:] - cu_seqlens[:-1]).tolist()
            q_chunks = torch.split(q, lens, dim=1)
            k_chunks = torch.split(k, lens, dim=1)
            v_chunks = torch.split(v, lens, dim=1)
            for q_i, k_i, v_i in zip(q_chunks, k_chunks, v_chunks):
                q_i, k_i, v_i = (
                    rearrange(x, "b s h d -> b h s d") for x in [q_i, k_i, v_i]
                )
                output_i = F.scaled_dot_product_attention(q_i, k_i, v_i, dropout_p=0.0)
                output_i = rearrange(output_i, "b h s d -> b s h d ")
                outputs.append(output_i)
            context_layer = torch.cat(outputs, dim=1)
            context_layer = rearrange(
                context_layer, "b s h d -> s b (h d)"
            ).contiguous()

        output, _ = self.proj(context_layer)
        return output


class Qwen2VisionBlock(nn.Module):
    """Pre-norm vision transformer block.

    Computes ``x = x + attn(norm1(x))`` and then ``x = x + mlp(norm2(x))``.
    ``norm_layer`` defaults to ``LayerNorm(eps=1e-6)``. The attention
    projection size equals ``dim``, and the MLP hidden width is
    ``int(dim * mlp_ratio)``.
    """

    def __init__(
        self,
        dim: int,
        num_heads: int,
        mlp_ratio: float,
        act_layer: type[nn.Module] = QuickGELU,
        norm_layer: Callable[[int], nn.Module] | None = None,
        quant_config: QuantizationConfig | None = None,
        prefix: str = "",
        use_data_parallel: bool = False,
        attn_backend_override: AttentionBackendEnum | None = None,
    ) -> None:
        super().__init__()
        if norm_layer is None:
            norm_layer = partial(nn.LayerNorm, eps=1e-6)
        self.norm1 = norm_layer(dim)
        self.norm2 = norm_layer(dim)
        mlp_hidden_dim = int(dim * mlp_ratio)

        self.attn = Qwen2VisionAttention(
            embed_dim=dim,
            num_heads=num_heads,
            projection_size=dim,
            quant_config=quant_config,
            prefix=f"{prefix}.attn",
            use_data_parallel=use_data_parallel,
            attn_backend_override=attn_backend_override,
        )
        self.mlp = Qwen2VisionMLP(
            dim,
            mlp_hidden_dim,
            act_layer=act_layer,
            quant_config=quant_config,
            prefix=f"{prefix}.mlp",
            use_data_parallel=use_data_parallel,
        )

    def forward(
        self,
        x: torch.Tensor,
        cu_seqlens: torch.Tensor,
        rotary_pos_emb_cos: torch.Tensor,
        rotary_pos_emb_sin: torch.Tensor,
        max_seqlen: int | None = None,  # Only used for Flash Attention
    ) -> torch.Tensor:
        """Apply the attention and MLP residual sub-layers.

        The attention arguments are forwarded unchanged to
        ``Qwen2VisionAttention.forward``.
        """
        x = x + self.attn(
            self.norm1(x),
            cu_seqlens=cu_seqlens,
            rotary_pos_emb_cos=rotary_pos_emb_cos,
            rotary_pos_emb_sin=rotary_pos_emb_sin,
            max_seqlen=max_seqlen,
        )

        x = x + self.mlp(self.norm2(x))
        return x


class Qwen2VisionPatchEmbed(nn.Module):
    """Embed flattened spatio-temporal patches with a non-overlapping Conv3d.

    Each input row is one patch, flattened as
    ``(in_channels, temporal_patch_size, patch_size, patch_size)``. A
    bias-free Conv3d whose kernel equals its stride maps each patch to one
    ``embed_dim`` vector.
    """

    def __init__(
        self,
        patch_size: int = 14,
        temporal_patch_size: int = 2,
        in_channels: int = 3,
        embed_dim: int = 1152,
    ) -> None:
        super().__init__()
        self.patch_size = patch_size
        self.temporal_patch_size = temporal_patch_size
        self.embed_dim = embed_dim

        # Kernel == stride: every patch is convolved exactly once.
        kernel_size = (temporal_patch_size, patch_size, patch_size)
        self.proj = Conv3dLayer(
            in_channels,
            embed_dim,
            kernel_size=kernel_size,
            stride=kernel_size,
            bias=False,
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Map ``[num_patches, C * T * P * P]`` to ``[num_patches, embed_dim]``."""
        # Unpacking also enforces that the input is 2-D.
        seq_len, _ = x.shape
        x = x.view(
            seq_len, -1, self.temporal_patch_size, self.patch_size, self.patch_size
        )
        x = self.proj(x).view(seq_len, self.embed_dim)
        return x


class Qwen2VisionPatchMerger(nn.Module):
    """Merge groups of ``spatial_merge_size**2`` tokens and project to ``d_model``.

    Tokens are normalised over ``context_dim``. Each run of
    ``spatial_merge_size**2`` consecutive tokens is then concatenated via
    ``view(-1, hidden_size)`` and passed through Linear -> GELU -> Linear.

    NOTE(review): this assumes the tokens of each merge window are adjacent in
    the sequence — presumably guaranteed by the upstream patch ordering;
    confirm.
    """

    def __init__(
        self,
        d_model: int,
        context_dim: int,
        norm_layer: Callable[[int], nn.Module] | None = None,
        spatial_merge_size: int = 2,
        quant_config: QuantizationConfig | None = None,
        prefix: str = "",
        use_data_parallel: bool = False,
    ) -> None:
        super().__init__()
        # Width of one merged token: the concatenation of a whole window.
        self.hidden_size = context_dim * (spatial_merge_size**2)
        if norm_layer is None:
            norm_layer = partial(nn.LayerNorm, eps=1e-6)
        self.ln_q = norm_layer(context_dim)
        # Kept as a ModuleList so parameter names stay mlp.0 / mlp.2.
        self.mlp = nn.ModuleList(
            [
                ColumnParallelLinear(
                    self.hidden_size,
                    self.hidden_size,
                    bias=True,
                    quant_config=quant_config,
                    prefix=f"{prefix}.mlp.0",
                    disable_tp=use_data_parallel,
                ),
                nn.GELU(),
                RowParallelLinear(
                    self.hidden_size,
                    d_model,
                    bias=True,
                    quant_config=quant_config,
                    prefix=f"{prefix}.mlp.2",
                    disable_tp=use_data_parallel,
                ),
            ]
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Return ``[num_tokens // spatial_merge_size**2, d_model]``."""
        x = self.ln_q(x)
        x = x.view(-1, self.hidden_size)

        mlp_fc1, mlp_act, mlp_fc2 = self.mlp
        x_parallel, _ = mlp_fc1(x)
        x_parallel = mlp_act(x_parallel)
        out, _ = mlp_fc2(x_parallel)
        return out


class Qwen2VisionTransformer(nn.Module):
    """Qwen2-VL vision tower.

    Pipeline: Conv3d patch embedding -> ``depth`` pre-norm blocks using 2-D
    rotary position embeddings -> patch merger projecting to
    ``vision_config.hidden_size``.
    """

    def __init__(
        self,
        vision_config: Qwen2VLVisionConfig,
        norm_eps: float = 1e-6,
        quant_config: QuantizationConfig | None = None,
        prefix: str = "",
        use_data_parallel: bool = False,
        attn_backend_override: AttentionBackendEnum | None = None,
    ) -> None:
        super().__init__()

        patch_size = vision_config.patch_size
        temporal_patch_size = vision_config.temporal_patch_size
        spatial_merge_size = vision_config.spatial_merge_size
        in_channels = vision_config.in_channels
        hidden_size = vision_config.hidden_size
        embed_dim = vision_config.embed_dim
        depth = vision_config.depth
        num_heads = vision_config.num_heads
        mlp_ratio = vision_config.mlp_ratio

        self.use_data_parallel = use_data_parallel
        self.out_hidden_size = vision_config.hidden_size

        self.spatial_merge_size = spatial_merge_size
        self.num_heads = num_heads
        self.embed_dim = embed_dim

        self.patch_embed = Qwen2VisionPatchEmbed(
            patch_size=patch_size,
            temporal_patch_size=temporal_patch_size,
            in_channels=in_channels,
            embed_dim=embed_dim,
        )

        norm_layer = partial(nn.LayerNorm, eps=norm_eps)
        head_dim = embed_dim // num_heads
        # Half of each head is rotated by row position and half by column
        # position (see rot_pos_emb), hence rotary_dim = head_dim // 2.
        self.rotary_pos_emb = get_rope(
            head_size=head_dim,
            rotary_dim=head_dim // 2,
            max_position=8192,
            is_neox_style=True,
        )

        self.blocks = nn.ModuleList(
            [
                Qwen2VisionBlock(
                    dim=embed_dim,
                    num_heads=num_heads,
                    mlp_ratio=mlp_ratio,
                    norm_layer=norm_layer,
                    quant_config=quant_config,
                    prefix=f"{prefix}.blocks.{layer_idx}",
                    use_data_parallel=use_data_parallel,
                    attn_backend_override=attn_backend_override,
                )
                for layer_idx in range(depth)
            ]
        )
        self.merger = Qwen2VisionPatchMerger(
            d_model=hidden_size,
            context_dim=embed_dim,
            norm_layer=norm_layer,
            quant_config=quant_config,
            prefix=f"{prefix}.merger",
            use_data_parallel=use_data_parallel,
        )
        # NOTE(review): each block's attention additionally passes its backend
        # through maybe_get_vit_flash_attn_backend, which may change it; this
        # tower-level value (used only by compute_attn_mask_seqlen) does not.
        # Confirm the two always agree, otherwise flash blocks could receive
        # max_seqlen=None.
        self.attn_backend = get_vit_attn_backend(
            head_size=head_dim,
            dtype=torch.get_default_dtype(),
            attn_backend_override=attn_backend_override,
        )

    @property
    def dtype(self) -> torch.dtype:
        """Dtype of the patch-embedding weights (used to cast inputs)."""
        return self.patch_embed.proj.weight.dtype

    @property
    def device(self) -> torch.device:
        """Device of the patch-embedding weights (used to place inputs)."""
        return self.patch_embed.proj.weight.device

    def rot_pos_emb(
        self, grid_thw: list[list[int]]
    ) -> tuple[torch.Tensor, torch.Tensor]:
        """Gather rotary cos/sin rows for every patch position.

        For each ``(t, h, w)`` item, (row, col) ids are reordered so that each
        ``spatial_merge_size x spatial_merge_size`` window is contiguous (the
        reshape/permute below), and the result is repeated ``t`` times.

        Returns:
            ``(cos, sin)`` with one row per patch. The tables looked up for
            the row id and the column id are flattened together into that row.
        """
        pos_ids = []
        max_grid_size = 0
        for t, h, w in grid_thw:
            hpos_ids = torch.arange(h).unsqueeze(1).expand(-1, w)
            wpos_ids = torch.arange(w).unsqueeze(0).expand(h, -1)
            hpos_ids = (
                hpos_ids.reshape(
                    h // self.spatial_merge_size,
                    self.spatial_merge_size,
                    w // self.spatial_merge_size,
                    self.spatial_merge_size,
                )
                .permute(0, 2, 1, 3)
                .flatten()
            )
            wpos_ids = (
                wpos_ids.reshape(
                    h // self.spatial_merge_size,
                    self.spatial_merge_size,
                    w // self.spatial_merge_size,
                    self.spatial_merge_size,
                )
                .permute(0, 2, 1, 3)
                .flatten()
            )
            pos_ids.append(torch.stack([hpos_ids, wpos_ids], dim=-1).repeat(t, 1))
            max_grid_size = max(max_grid_size, h, w)
        pos_ids = torch.cat(pos_ids, dim=0)

        # Use pre-computed cos_sin_cache from RotaryEmbedding
        cos, sin = self.rotary_pos_emb.get_cos_sin(max_grid_size)

        cos_combined = cos[pos_ids].flatten(1)
        sin_combined = sin[pos_ids].flatten(1)
        return cos_combined, sin_combined

    def compute_attn_mask_seqlen(self, cu_seqlens: torch.Tensor) -> int | None:
        """Return the longest segment length for flash backends, else ``None``."""
        max_seqlen = None
        if self.attn_backend in {
            AttentionBackendEnum.FLASH_ATTN,
            AttentionBackendEnum.ROCM_AITER_FA,
        }:
            max_seqlen = (cu_seqlens[1:] - cu_seqlens[:-1]).max().item()
        return max_seqlen

    def forward(
        self,
        x: torch.Tensor,
        grid_thw: torch.Tensor | list[list[int]],
    ) -> torch.Tensor:
        """Encode flattened pixel patches into merged vision embeddings.

        Args:
            x: Flattened patches, one row per patch.
            grid_thw: Per-item ``(t, h, w)`` patch-grid sizes, as a list of
                lists or a ``[num_items, 3]`` tensor.

        Returns:
            Merger output, one row per merged token.
        """
        # patchify
        x = x.to(device=self.device, dtype=self.dtype)
        x = self.patch_embed(x)

        if isinstance(grid_thw, list):
            grid_thw_list = grid_thw
            grid_thw_np = np.array(grid_thw, dtype=np.int32)
        else:
            grid_thw_list = grid_thw.tolist()
            # .numpy() raises for non-CPU tensors; .cpu() is a no-op on CPU.
            grid_thw_np = grid_thw.cpu().numpy()

        # compute position embedding
        rotary_pos_emb_cos, rotary_pos_emb_sin = self.rot_pos_emb(grid_thw_list)

        # compute cu_seqlens: each of the t frames of an item is its own
        # segment of h * w patches.
        cu_seqlens = np.repeat(
            grid_thw_np[:, 1] * grid_thw_np[:, 2], grid_thw_np[:, 0]
        ).cumsum(axis=0, dtype=np.int32)
        cu_seqlens = np.concatenate([np.zeros(1, dtype=np.int32), cu_seqlens])
        cu_seqlens = torch.from_numpy(cu_seqlens)

        # Blocks operate on [s, b, c]; add a batch dimension of 1.
        x = x.unsqueeze(1)

        # pre-compute seqlens for attn mask to reduce cuMemcpy operations
        # (done while cu_seqlens is still on the CPU).
        max_seqlen = self.compute_attn_mask_seqlen(cu_seqlens)
        cu_seqlens = cu_seqlens.to(self.device, non_blocking=True)
        for blk in self.blocks:
            x = blk(
                x,
                cu_seqlens=cu_seqlens,
                rotary_pos_emb_cos=rotary_pos_emb_cos,
                rotary_pos_emb_sin=rotary_pos_emb_sin,
                max_seqlen=max_seqlen,
            )

        # adapter
        x = self.merger(x)

        return x

    def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]:
        """Load checkpoint tensors into this module's parameters.

        Returns:
            Names of the parameters that were loaded.

        Raises:
            KeyError: If a (possibly remapped) name has no matching parameter.
        """
        # NOTE(review): the attention layers in this file register their fused
        # projection as `qkv`, not `qkv_proj`, so this mapping presumably never
        # fires for Qwen2-VL checkpoints, which store fused `qkv` — confirm.
        stacked_params_mapping = [
            # (param_name, shard_name, shard_id)
            ("qkv_proj", "q_proj", "q"),
            ("qkv_proj", "k_proj", "k"),
            ("qkv_proj", "v_proj", "v"),
        ]
        params_dict = dict(self.named_parameters(remove_duplicate=False))
        loaded_params: set[str] = set()

        for name, loaded_weight in weights:
            for param_name, weight_name, shard_id in stacked_params_mapping:
                if weight_name not in name:
                    continue
                name = name.replace(weight_name, param_name)

                param = params_dict[name]
                weight_loader = param.weight_loader
                weight_loader(param, loaded_weight, shard_id)
                break
            else:
                param = params_dict[name]
                weight_loader = getattr(param, "weight_loader", default_weight_loader)
                weight_loader(param, loaded_weight)
            loaded_params.add(name)
        return loaded_params

787

788
def _create_qwen2vl_field_factory(
    spatial_merge_size: int,
) -> Callable[
    [Mapping[str, torch.Tensor]],
    Mapping[str, MultiModalFieldConfig],
]:
    """Build the callback that describes how HF processor outputs split per item.

    Pixel fields are flat-concatenated, so each item owns ``t * h * w`` rows.
    Embedding fields own ``t * h * w // spatial_merge_size**2`` rows, matching
    the vision merger, which fuses ``spatial_merge_size**2`` patches into one
    token. The ``*_grid_thw`` fields are batched one row per item and kept on
    the CPU.
    """

    def _qwen2vl_field_config(hf_inputs: Mapping[str, torch.Tensor]):
        # A missing modality behaves like an empty (0, 3) grid tensor.
        image_grid_thw = hf_inputs.get("image_grid_thw", torch.empty((0, 3)))
        image_pixel_grid_sizes = image_grid_thw.prod(-1)
        image_embed_grid_sizes = (
            image_pixel_grid_sizes // spatial_merge_size // spatial_merge_size
        )

        video_grid_thw = hf_inputs.get("video_grid_thw", torch.empty((0, 3)))
        video_grid_sizes = video_grid_thw.prod(-1)
        video_embed_grid_sizes = (
            video_grid_sizes // spatial_merge_size // spatial_merge_size
        )

        return dict(
            pixel_values=MultiModalFieldConfig.flat_from_sizes(
                "image", image_pixel_grid_sizes
            ),
            image_embeds=MultiModalFieldConfig.flat_from_sizes(
                "image", image_embed_grid_sizes
            ),
            image_grid_thw=MultiModalFieldConfig.batched("image", keep_on_cpu=True),
            pixel_values_videos=MultiModalFieldConfig.flat_from_sizes(
                "video", video_grid_sizes
            ),
            video_embeds=MultiModalFieldConfig.flat_from_sizes(
                "video", video_embed_grid_sizes
            ),
            video_grid_thw=MultiModalFieldConfig.batched("video", keep_on_cpu=True),
        )

    return _qwen2vl_field_config
825

826

Roger Wang's avatar
Roger Wang committed
827
class Qwen2VLMultiModalDataParser(MultiModalDataParser):
    """Data parser that also accepts precomputed embeddings as dicts.

    A dict for the image (or video) modality is wrapped in
    ``DictEmbeddingItems``. The dict must contain the ``*_embeds`` and
    ``*_grid_thw`` keys; all other inputs fall through to the base parser.
    """

    def __init__(self, spatial_merge_size: int, *args, **kwargs):
        # Needed to size embedding fields (see _create_qwen2vl_field_factory).
        self._spatial_merge_size = spatial_merge_size
        super().__init__(*args, **kwargs)

    def _parse_image_data(
        self,
        data: dict[str, torch.Tensor] | ModalityData[ImageItem],
    ) -> ModalityDataItems[Any, Any] | None:
        """Parse image data, treating a dict as precomputed embeddings."""
        if isinstance(data, dict):
            return DictEmbeddingItems(
                data,
                modality="image",
                required_fields={"image_embeds", "image_grid_thw"},
                fields_factory=_create_qwen2vl_field_factory(self._spatial_merge_size),
            )

        return super()._parse_image_data(data)

    def _parse_video_data(
        self,
        data: dict[str, torch.Tensor] | ModalityData[VideoItem],
    ) -> ModalityDataItems[Any, Any] | None:
        """Parse video data, treating a dict as precomputed embeddings."""
        if isinstance(data, dict):
            return DictEmbeddingItems(
                data,
                modality="video",
                required_fields={"video_embeds", "video_grid_thw"},
                fields_factory=_create_qwen2vl_field_factory(self._spatial_merge_size),
            )

        return super()._parse_video_data(data)


861
862
class Qwen2VLProcessingInfo(BaseProcessingInfo):
    """Model-specific processing info for Qwen2-VL.

    Exposes the HF config and processors, and computes how many placeholder
    tokens an image or video occupies. These counts are used for per-item
    token budgets and for sizing the dummy inputs used during profiling.
    """

    def get_hf_config(self) -> Qwen2VLConfig:
        return self.ctx.get_hf_config(Qwen2VLConfig)

    def get_hf_processor(self, **kwargs: object) -> Qwen2VLProcessor:
        # Default to the fast image processor unless the caller overrides it.
        return self.ctx.get_hf_processor(
            Qwen2VLProcessor,
            use_fast=kwargs.pop("use_fast", True),
            **kwargs,
        )

    def get_image_processor(self, **kwargs: object) -> Qwen2VLImageProcessor:
        return self.get_hf_processor(**kwargs).image_processor

    def get_supported_mm_limits(self) -> Mapping[str, int | None]:
        # No fixed per-prompt limit for either modality.
        return {"image": None, "video": None}

    def get_mm_max_tokens_per_item(
        self,
        seq_len: int,
        mm_counts: Mapping[str, int],
    ) -> Mapping[str, int]:
        max_image_tokens = self.get_max_image_tokens()
        max_video_tokens = self.get_max_video_tokens(seq_len, mm_counts)
        return {"image": max_image_tokens, "video": max_video_tokens}

    def _get_vision_info(
        self,
        *,
        image_width: int,
        image_height: int,
        num_frames: int = 1,
        do_resize: bool = True,
        image_processor: Qwen2VLImageProcessor | None,
    ) -> tuple[ImageSize, int]:
        """Return the preprocessed frame size and the vision-token count.

        Args:
            image_width: Original frame width in pixels.
            image_height: Original frame height in pixels.
            num_frames: Number of frames (1 for a still image).
            do_resize: Whether to apply ``smart_resize`` (bounded by the
                processor's ``min_pixels``/``max_pixels``) before patching.
            image_processor: Processor to read pixel bounds from; obtained
                via ``get_image_processor()`` when ``None``.

        Returns:
            ``(preprocessed_size, num_vision_tokens)``, where the token count
            is ``grid_t * grid_h * grid_w // spatial_merge_size**2``.
        """
        if image_processor is None:
            image_processor = self.get_image_processor()

        hf_config = self.get_hf_config()
        vision_config = hf_config.vision_config
        patch_size = vision_config.patch_size
        merge_size = vision_config.spatial_merge_size
        temporal_patch_size = vision_config.temporal_patch_size

        if do_resize:
            resized_height, resized_width = smart_resize(
                height=image_height,
                width=image_width,
                factor=patch_size * merge_size,
                min_pixels=image_processor.min_pixels,
                max_pixels=image_processor.max_pixels,
            )
            preprocessed_size = ImageSize(width=resized_width, height=resized_height)
        else:
            preprocessed_size = ImageSize(width=image_width, height=image_height)

        # NOTE: Frames are padded up to a multiple of `temporal_patch_size`
        # before patching, so the temporal grid is a ceiling division. This
        # is the same as `num_frames + num_frames % 2` for the usual
        # temporal_patch_size == 2, and stays correct for larger values.
        # https://github.com/huggingface/transformers/blob/v4.48.3/src/transformers/models/qwen2_vl/image_processing_qwen2_vl.py#L294
        grid_t = max(
            (num_frames + temporal_patch_size - 1) // temporal_patch_size, 1
        )
        grid_h = preprocessed_size.height // patch_size
        grid_w = preprocessed_size.width // patch_size

        num_patches = grid_t * grid_h * grid_w
        num_vision_tokens = num_patches // (merge_size**2)

        return preprocessed_size, num_vision_tokens

    def get_num_image_tokens(
        self,
        *,
        image_width: int,
        image_height: int,
        image_processor: Qwen2VLImageProcessor | None,
    ) -> int:
        """Number of placeholder tokens for a single image of the given size."""
        _, num_image_tokens = self._get_vision_info(
            image_width=image_width,
            image_height=image_height,
            num_frames=1,
            image_processor=image_processor,
        )
        return num_image_tokens

    def get_num_video_tokens(
        self,
        *,
        image_width: int,
        image_height: int,
        num_frames: int,
        image_processor: Qwen2VLImageProcessor | None,
    ) -> int:
        """Number of placeholder tokens for a video of the given size/length."""
        _, num_video_tokens = self._get_vision_info(
            image_width=image_width,
            image_height=image_height,
            num_frames=num_frames,
            image_processor=image_processor,
        )
        return num_video_tokens

    def get_image_size_with_most_features(self) -> ImageSize:
        """Largest preprocessed image size allowed by ``smart_resize``."""
        # An oversized input is clamped by `max_pixels`, yielding the maximum
        # size the processor will ever produce.
        max_image_size, _ = self._get_vision_info(
            image_width=9999999,
            image_height=9999999,
            num_frames=1,
            image_processor=None,
        )
        return max_image_size

    def get_max_image_tokens(self) -> int:
        target_width, target_height = self.get_image_size_with_most_features()

        return self.get_num_image_tokens(
            image_width=target_width,
            image_height=target_height,
            image_processor=None,
        )

    def _get_max_video_frames(self, max_tokens: int, start_num_frames: int = 1) -> int:
        """Largest frame count (>= start_num_frames) fitting in ``max_tokens``.

        Frames are added one at a time at the largest frame size until the
        next frame count would exceed ``max_tokens``.
        """
        target_width, target_height = self.get_image_size_with_most_features()

        num_frames = start_num_frames

        while True:
            next_num_frames = num_frames + 1
            next_max_tokens = self.get_num_video_tokens(
                image_width=target_width,
                image_height=target_height,
                num_frames=next_num_frames,
                image_processor=None,
            )

            if next_max_tokens > max_tokens:
                break

            num_frames = next_num_frames

        return num_frames

    def get_num_frames_with_most_features(
        self,
        seq_len: int,
        mm_counts: Mapping[str, int],
        max_frames_per_video: int = _MAX_FRAMES_PER_VIDEO,
    ) -> int:
        """Per-video frame count used for profiling, at least 1.

        The frame budget that fits in ``seq_len`` is split evenly across the
        requested number of videos and capped at ``max_frames_per_video``.
        """
        max_videos = mm_counts.get("video", 0)

        max_total_frames = self._get_max_video_frames(seq_len)
        max_frames_per_video = min(
            max_total_frames // max(max_videos, 1), max_frames_per_video
        )

        return max(max_frames_per_video, 1)

    def get_max_video_tokens(
        self,
        seq_len: int,
        mm_counts: Mapping[str, int],
    ) -> int:
        target_width, target_height = self.get_image_size_with_most_features()

        return self.get_num_video_tokens(
            image_width=target_width,
            image_height=target_height,
            num_frames=self.get_num_frames_with_most_features(seq_len, mm_counts),
            image_processor=None,
        )

1029
1030

class Qwen2VLDummyInputsBuilder(BaseDummyInputsBuilder[Qwen2VLProcessingInfo]):
    """Builds worst-case dummy prompts and media for Qwen2-VL profiling."""

    def get_dummy_text(self, mm_counts: Mapping[str, int]) -> str:
        """Return one image placeholder per image, followed by one video placeholder per video."""
        image_count = mm_counts.get("image", 0)
        video_count = mm_counts.get("video", 0)

        processor = self.info.get_hf_processor()
        pieces = (
            (processor.image_token, image_count),
            (processor.video_token, video_count),
        )
        return "".join(token * count for token, count in pieces)

    def get_dummy_mm_data(
        self,
        seq_len: int,
        mm_counts: Mapping[str, int],
        mm_options: Mapping[str, BaseDummyOptions] | None = None,
    ) -> MultiModalDataDict:
        """Return dummy images/videos at the size and length with most tokens."""
        image_count = mm_counts.get("image", 0)
        video_count = mm_counts.get("video", 0)

        width, height = self.info.get_image_size_with_most_features()
        frame_count = self.info.get_num_frames_with_most_features(seq_len, mm_counts)

        # An absent or empty options mapping means "no overrides".
        options = mm_options or {}

        dummy_images = self._get_dummy_images(
            width=width,
            height=height,
            num_images=image_count,
            overrides=options.get("image"),
        )
        dummy_videos = self._get_dummy_videos(
            width=width,
            height=height,
            num_frames=frame_count,
            num_videos=video_count,
            overrides=options.get("video"),
        )
        return {"image": dummy_images, "video": dummy_videos}

1074

1075
class Qwen2VLMultiModalProcessor(BaseMultiModalProcessor[Qwen2VLProcessingInfo]):
    """Multimodal processor that expands Qwen2-VL image/video placeholders."""

    def _get_data_parser(self) -> MultiModalDataParser:
        merge_size = self.info.get_hf_config().vision_config.spatial_merge_size
        return Qwen2VLMultiModalDataParser(merge_size)

    def _get_prompt_updates(
        self,
        mm_items: MultiModalDataItems,
        hf_processor_mm_kwargs: Mapping[str, Any],
        out_mm_kwargs: MultiModalKwargsItems,
    ) -> Sequence[PromptUpdate]:
        """Replace each single placeholder token with one token per merged patch."""
        hf_processor = self.info.get_hf_processor(**hf_processor_mm_kwargs)
        image_processor = self.info.get_image_processor(**hf_processor_mm_kwargs)
        vocab = self.info.get_tokenizer().get_vocab()

        token_id_by_modality = {
            "image": vocab[hf_processor.image_token],
            "video": vocab[hf_processor.video_token],
        }
        # Each output token covers merge_size x merge_size patches.
        patches_per_token = image_processor.merge_size**2

        def expand_placeholder(item_idx: int, modality: str):
            item = out_mm_kwargs[modality][item_idx]
            thw = item[f"{modality}_grid_thw"].data
            assert isinstance(thw, torch.Tensor)

            token_count = int(thw.prod()) // patches_per_token
            return [token_id_by_modality[modality]] * token_count

        updates = []
        for modality in ("image", "video"):
            updates.append(
                PromptReplacement(
                    modality=modality,
                    target=[token_id_by_modality[modality]],
                    replacement=partial(expand_placeholder, modality=modality),
                )
            )
        return updates

    def _get_mm_fields_config(
        self,
        hf_inputs: BatchFeature,
        hf_processor_mm_kwargs: Mapping[str, object],
    ) -> Mapping[str, MultiModalFieldConfig]:
        merge_size = self.info.get_hf_config().vision_config.spatial_merge_size
        build_fields = _create_qwen2vl_field_factory(merge_size)
        return build_fields(hf_inputs)


@MULTIMODAL_REGISTRY.register_processor(
    Qwen2VLMultiModalProcessor,
    info=Qwen2VLProcessingInfo,
    dummy_inputs=Qwen2VLDummyInputsBuilder,
)
class Qwen2VLForConditionalGeneration(
    nn.Module, SupportsMultiModal, SupportsLoRA, SupportsPP, SupportsMRoPE
):
    """Qwen2-VL: a vision transformer feeding a Qwen2 causal language model."""

    # To ensure correct weight loading and mapping.
    hf_to_vllm_mapper = WeightsMapper(
        orig_to_new_prefix={
            # mapping for new names in checkpoint saved after transformers v4.52
            "model.language_model.": "language_model.model.",
            "model.visual.": "visual.",
            # mapping for original checkpoint
            "lm_head.": "language_model.lm_head.",
            "model.": "language_model.model.",
        }
    )

    supports_encoder_tp_data = True

    def get_mrope_input_positions(
        self,
        input_tokens: list[int],
        mm_features: list[MultiModalFeatureSpec],
    ) -> tuple[torch.Tensor, int]:
        """Compute 3-axis (t, h, w) M-RoPE positions for a prompt.

        Text tokens receive the same position on all three axes. Each vision
        placeholder run receives a grid of positions built from its
        ``*_grid_thw``, after dividing h and w by ``spatial_merge_size``,
        offset to follow the preceding text. For videos, the temporal axis is
        scaled by ``second_per_grid_ts`` and ``tokens_per_second``.

        Args:
            input_tokens: Token ids of the full prompt.
            mm_features: Multimodal feature specs of the prompt's items; their
                ``image_grid_thw``, ``video_grid_thw`` and
                ``second_per_grid_ts`` kwargs are gathered here.

        Returns:
            ``(positions, delta)``. ``positions`` has shape ``(3, N)``; N is
            expected to equal ``len(input_tokens)`` when the placeholder runs
            match their grids. ``delta`` is
            ``positions.max() + 1 - len(input_tokens)``.
        """
        kwargs = MultiModalFeatureSpec.gather_kwargs(
            mm_features,
            {"image_grid_thw", "video_grid_thw", "second_per_grid_ts"},
        )
        image_grid_thw = [item.tolist() for item in kwargs.get("image_grid_thw", [])]
        video_grid_thw = [item.tolist() for item in kwargs.get("video_grid_thw", [])]
        second_per_grid_ts = kwargs.get("second_per_grid_ts", [])

        hf_config = self.config
        image_token_id = hf_config.image_token_id
        video_token_id = hf_config.video_token_id
        vision_start_token_id = hf_config.vision_start_token_id
        spatial_merge_size = hf_config.vision_config.spatial_merge_size
        tokens_per_second = getattr(hf_config.vision_config, "tokens_per_second", 1.0)

        input_tokens_tensor = torch.tensor(input_tokens)
        vision_start_indices = torch.argwhere(
            input_tokens_tensor == vision_start_token_id
        ).squeeze(1)
        # A vision-start token in the last position has no following modality
        # token; drop it so the `+ 1` lookup below cannot index out of bounds.
        vision_start_indices = vision_start_indices[
            vision_start_indices + 1 < len(input_tokens)
        ]
        # The token right after each vision-start tells whether the run is an
        # image or a video.
        vision_tokens = input_tokens_tensor[vision_start_indices + 1]
        image_nums = int((vision_tokens == image_token_id).sum())
        video_nums = int((vision_tokens == video_token_id).sum())
        llm_pos_ids_list: list[torch.Tensor] = []

        st = 0
        remain_images, remain_videos = image_nums, video_nums

        image_index, video_index = 0, 0
        for _ in range(image_nums + video_nums):
            video_second_per_grid_t = 0.0
            # Locate the next image and video placeholders; a position past
            # the end means "none left".
            if remain_images > 0:
                try:
                    ed_image = input_tokens.index(image_token_id, st)
                except ValueError:
                    ed_image = len(input_tokens) + 1
            else:
                ed_image = len(input_tokens) + 1
            if remain_videos > 0:
                try:
                    ed_video = input_tokens.index(video_token_id, st)
                except ValueError:
                    ed_video = len(input_tokens) + 1
            else:
                ed_video = len(input_tokens) + 1
            if ed_image < ed_video:
                t, h, w = image_grid_thw[image_index]
                image_index += 1
                remain_images -= 1
                ed = ed_image
            else:
                t, h, w = video_grid_thw[video_index]
                video_second_per_grid_t = 1.0
                if second_per_grid_ts:
                    video_second_per_grid_t = second_per_grid_ts[video_index]
                video_index += 1
                remain_videos -= 1
                ed = ed_video

            llm_grid_t, llm_grid_h, llm_grid_w = (
                t,
                h // spatial_merge_size,
                w // spatial_merge_size,
            )
            text_len = ed - st

            # Text before this vision run: same position on all three axes.
            st_idx = llm_pos_ids_list[-1].max() + 1 if len(llm_pos_ids_list) > 0 else 0
            llm_pos_ids_list.append(
                torch.arange(text_len).view(1, -1).expand(3, -1) + st_idx
            )

            # Temporal index is 0 for images (video_second_per_grid_t == 0.0);
            # for videos it is scaled to a tokens-per-second time base.
            t_index = (
                (
                    torch.arange(llm_grid_t)
                    .view(-1, 1)
                    .expand(-1, llm_grid_h * llm_grid_w)
                    * video_second_per_grid_t
                    * tokens_per_second
                )
                .long()
                .flatten()
            )

            h_index = (
                torch.arange(llm_grid_h)
                .view(1, -1, 1)
                .expand(llm_grid_t, -1, llm_grid_w)
                .flatten()
            )
            w_index = (
                torch.arange(llm_grid_w)
                .view(1, 1, -1)
                .expand(llm_grid_t, llm_grid_h, -1)
                .flatten()
            )
            llm_pos_ids_list.append(
                torch.stack([t_index, h_index, w_index]) + text_len + st_idx
            )
            st = ed + llm_grid_t * llm_grid_h * llm_grid_w

        # Trailing text after the last vision run.
        if st < len(input_tokens):
            st_idx = llm_pos_ids_list[-1].max() + 1 if len(llm_pos_ids_list) > 0 else 0
            text_len = len(input_tokens) - st
            llm_pos_ids_list.append(
                torch.arange(text_len).view(1, -1).expand(3, -1) + st_idx
            )

        llm_positions = torch.cat(llm_pos_ids_list, dim=1).reshape(3, -1)
        mrope_position_delta = (llm_positions.max() + 1 - len(input_tokens)).item()

        return llm_positions, mrope_position_delta

    @classmethod
    def get_placeholder_str(cls, modality: str, i: int) -> str | None:
        """Return the prompt placeholder for one image or video item."""
        if modality.startswith("image"):
            return "<|vision_start|><|image_pad|><|vision_end|>"
        if modality.startswith("video"):
            return "<|vision_start|><|video_pad|><|vision_end|>"

        raise ValueError("Only image or video modality is supported")

    def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
        super().__init__()
        config: Qwen2VLConfig = vllm_config.model_config.hf_config
        quant_config = vllm_config.quant_config
        multimodal_config = vllm_config.model_config.multimodal_config

        self.use_data_parallel = multimodal_config.mm_encoder_tp_mode == "data"
        self.config = config
        self.multimodal_config = multimodal_config

        # Only build the vision tower when images or videos may be submitted.
        if multimodal_config.get_limit_per_prompt(
            "image"
        ) or multimodal_config.get_limit_per_prompt("video"):
            self.visual = Qwen2VisionTransformer(
                config.vision_config,
                norm_eps=getattr(config, "rms_norm_eps", 1e-6),
                quant_config=quant_config,
                prefix=maybe_prefix(prefix, "visual"),
                use_data_parallel=self.use_data_parallel,
                # `multimodal_config` was already dereferenced above, so no
                # None check is needed here.
                attn_backend_override=multimodal_config.mm_encoder_attn_backend,
            )
        else:
            self.visual = None

        self.language_model = init_vllm_registered_model(
            vllm_config=vllm_config,
            prefix=maybe_prefix(prefix, "language_model"),
            architectures=["Qwen2ForCausalLM"],
        )

        self.make_empty_intermediate_tensors = (
            self.language_model.make_empty_intermediate_tensors
        )

    def _parse_and_validate_image_input(
        self, **kwargs: object
    ) -> Qwen2VLImageInputs | None:
        """Build image inputs from kwargs; pixel values take precedence."""
        pixel_values = kwargs.pop("pixel_values", None)
        image_embeds = kwargs.pop("image_embeds", None)
        image_grid_thw = kwargs.pop("image_grid_thw", None)

        if pixel_values is None and image_embeds is None:
            return None

        if pixel_values is not None:
            return Qwen2VLImagePixelInputs(
                type="pixel_values",
                pixel_values=pixel_values,
                image_grid_thw=image_grid_thw,
            )

        if image_embeds is not None:
            return Qwen2VLImageEmbeddingInputs(
                type="image_embeds",
                image_embeds=image_embeds,
                image_grid_thw=image_grid_thw,
            )

    def _parse_and_validate_video_input(
        self, **kwargs: object
    ) -> Qwen2VLVideoInputs | None:
        """Build video inputs from kwargs; pixel values take precedence."""
        pixel_values_videos = kwargs.pop("pixel_values_videos", None)
        video_embeds = kwargs.pop("video_embeds", None)
        video_grid_thw = kwargs.pop("video_grid_thw", None)

        if pixel_values_videos is None and video_embeds is None:
            return None

        if pixel_values_videos is not None:
            return Qwen2VLVideoPixelInputs(
                type="pixel_values_videos",
                pixel_values_videos=pixel_values_videos,
                video_grid_thw=video_grid_thw,
            )

        if video_embeds is not None:
            return Qwen2VLVideoEmbeddingInputs(
                type="video_embeds",
                video_embeds=video_embeds,
                video_grid_thw=video_grid_thw,
            )

    def _process_image_input(
        self, image_input: Qwen2VLImageInputs
    ) -> tuple[torch.Tensor, ...]:
        """Return one embedding tensor per image item."""
        grid_thw = image_input["image_grid_thw"]
        assert grid_thw.ndim == 2

        if image_input["type"] == "image_embeds":
            image_embeds = image_input["image_embeds"]
        else:
            pixel_values = image_input["pixel_values"]

            if self.use_data_parallel:
                # NOTE(review): returned as-is, so presumably already split
                # per item by the sharded helper - verify.
                return run_dp_sharded_mrope_vision_model(
                    self.visual, pixel_values, grid_thw.tolist(), rope_type="rope_3d"
                )
            else:
                image_embeds = self.visual(pixel_values, grid_thw=grid_thw)

        # Split concatenated embeddings for each image item.
        merge_size = self.visual.spatial_merge_size
        sizes = (grid_thw.prod(-1) // merge_size // merge_size).tolist()
        return image_embeds.split(sizes)

    def _process_video_input(
        self, video_input: Qwen2VLVideoInputs
    ) -> tuple[torch.Tensor, ...]:
        """Return one embedding tensor per video item."""
        grid_thw = video_input["video_grid_thw"]
        assert grid_thw.ndim == 2

        if video_input["type"] == "video_embeds":
            video_embeds = video_input["video_embeds"]
        else:
            pixel_values_videos = video_input["pixel_values_videos"]
            if self.use_data_parallel:
                return run_dp_sharded_mrope_vision_model(
                    self.visual,
                    pixel_values_videos,
                    grid_thw.tolist(),
                    rope_type="rope_3d",
                )
            else:
                video_embeds = self.visual(pixel_values_videos, grid_thw=grid_thw)

        # Split concatenated embeddings for each video item.
        merge_size = self.visual.spatial_merge_size
        sizes = (grid_thw.prod(-1) // merge_size // merge_size).tolist()
        return video_embeds.split(sizes)

    def _parse_and_validate_multimodal_inputs(self, **kwargs: object) -> dict:
        modalities = {}

        # Preserve the order of modalities if there are multiple of them
        # from the order of kwargs.
        for input_key in kwargs:
            if (
                input_key in ("pixel_values", "image_embeds")
                and "images" not in modalities
            ):
                modalities["images"] = self._parse_and_validate_image_input(**kwargs)
            if (
                input_key in ("pixel_values_videos", "video_embeds")
                and "videos" not in modalities
            ):
                modalities["videos"] = self._parse_and_validate_video_input(**kwargs)

        return modalities

    def get_language_model(self) -> torch.nn.Module:
        return self.language_model

    def embed_multimodal(self, **kwargs: object) -> MultiModalEmbeddings:
        modalities = self._parse_and_validate_multimodal_inputs(**kwargs)
        if not modalities:
            return []

        # The result multimodal_embeddings is tuple of tensors, with each
        # tensor corresponding to a multimodal data item (image or video).
        multimodal_embeddings: tuple[torch.Tensor, ...] = ()

        # NOTE: It is important to iterate over the keys in this dictionary
        # to preserve the order of the modalities.
        for modality in modalities:
            if modality == "images":
                image_input = modalities["images"]
                image_embeddings = self._process_image_input(image_input)
                multimodal_embeddings += tuple(image_embeddings)
            if modality == "videos":
                video_input = modalities["videos"]
                video_embeddings = self._process_video_input(video_input)
                multimodal_embeddings += tuple(video_embeddings)

        return multimodal_embeddings

    def forward(
        self,
        input_ids: torch.Tensor,
        positions: torch.Tensor,
        intermediate_tensors: IntermediateTensors | None = None,
        inputs_embeds: torch.Tensor | None = None,
        **kwargs: object,
    ) -> torch.Tensor | IntermediateTensors:
        """Run forward pass for Qwen2-VL.

        Args:
            input_ids: Flattened (concatenated) input_ids corresponding to a
                batch.
            positions: Flattened (concatenated) position ids corresponding to a
                batch.
                **NOTE**: If mrope is enabled (default setting for Qwen2-VL
                opensource models), the shape will be `(3, seq_len)`,
                otherwise it will be `(seq_len,)`.
            intermediate_tensors: Intermediate tensors from prior forward pass.
            inputs_embeds: Optional tensor of input embeddings.
        """

        # When intermediate tensors are supplied, `inputs_embeds` is
        # discarded and the language model runs from those tensors instead.
        if intermediate_tensors is not None:
            inputs_embeds = None

        hidden_states = self.language_model.model(
            input_ids=input_ids,
            positions=positions,
            intermediate_tensors=intermediate_tensors,
            inputs_embeds=inputs_embeds,
        )
        return hidden_states

    def compute_logits(
        self,
        hidden_states: torch.Tensor,
    ) -> torch.Tensor | None:
        return self.language_model.compute_logits(hidden_states)

    def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]:
        """Load checkpoint weights, skipping the vision tower if it is disabled."""
        skip_prefixes = []
        if self.visual is None:
            skip_prefixes.extend(["visual."])
        loader = AutoWeightsLoader(self, skip_prefixes=skip_prefixes)
        return loader.load_weights(weights, mapper=self.hf_to_vllm_mapper)

    def get_mm_mapping(self) -> MultiModelKeys:
        """
        Get the module prefix in multimodal models
        """
        return MultiModelKeys.from_string_field(
            language_model="language_model",
            connector="visual.merger.",
            tower_model="visual.",
        )
1508
1509
1510
1511
1512
1513
1514
1515
1516


class Tarsier2MultiModalProcessor(Qwen2VLMultiModalProcessor):
    """Tarsier2 reuses the Qwen2-VL multimodal processing pipeline unchanged."""


class Tarsier2ImageProcessor(Qwen2VLImageProcessor):
    """Qwen2-VL image processor that understands Tarsier2's ``size`` format.

    Tarsier2 configs express ``size`` as ``{"min_pixels", "max_pixels"}``.
    When both keys are present, they are translated to the
    ``shortest_edge``/``longest_edge`` keys passed on to the parent
    processor; any other ``size`` is forwarded untouched.
    """

    def __init__(
        self,
        size: dict[str, int] | None = None,
        **kwargs,
    ) -> None:
        uses_pixel_bounds = (
            size is not None and "min_pixels" in size and "max_pixels" in size
        )
        if uses_pixel_bounds:
            # Build a new dict so the caller's `size` mapping is not mutated.
            size = {
                "shortest_edge": size["min_pixels"],
                "longest_edge": size["max_pixels"],
            }
        super().__init__(size=size, **kwargs)


class Tarsier2Processor(Qwen2VLProcessor):
    """Qwen2-VL processor assembled from a Tarsier2 image-processor config.

    The same ``vision_config`` dict configures both the image processor and
    the video processor. No chat template is attached.
    """

    def __init__(
        self,
        vision_config: dict,
        tokenizer: TokenizerLike,
        **kwargs,
    ):
        # Set before calling the parent constructor, as the original code does.
        self.image_processor = Tarsier2ImageProcessor(**vision_config)
        video_processor = Qwen2VLVideoProcessor(**vision_config)

        super().__init__(
            image_processor=self.image_processor,
            tokenizer=tokenizer,
            video_processor=video_processor,
            chat_template=None,
            **kwargs,
        )
1546
1547
1548
1549
1550


class Tarsier2ProcessingInfo(Qwen2VLProcessingInfo):
    """Processing info for Tarsier2, built on the Qwen2-VL implementation."""

    def get_hf_config(self) -> Qwen2VLConfig:
        # Re-read the checkpoint's config explicitly as a Qwen2VLConfig.
        # NOTE(review): presumably the config registered in the model context
        # is not a Qwen2VLConfig for Tarsier2 checkpoints - confirm.
        model_path = self.ctx.model_config.model
        correct_config = Qwen2VLConfig.from_pretrained(model_path)

        return correct_config

    def get_hf_processor(self, **kwargs: object) -> Tarsier2Processor:
        return Tarsier2Processor(
            vision_config=self.ctx.get_hf_image_processor_config(),
            tokenizer=self.get_tokenizer(),
            **kwargs,
        )

    def get_image_processor(self, **kwargs: object) -> Tarsier2ImageProcessor:
        """Return the Tarsier2 image processor.

        ``**kwargs`` keeps this override call-compatible with
        ``Qwen2VLProcessingInfo.get_image_processor``. Callers such as
        ``Qwen2VLMultiModalProcessor._get_prompt_updates`` pass processor
        kwargs, and the previous no-argument signature raised ``TypeError``.
        The kwargs do not affect the result, because ``Tarsier2Processor``
        also builds its image processor solely from the checkpoint's
        image-processor config.
        """
        return Tarsier2ImageProcessor(**self.ctx.get_hf_image_processor_config())
1564
1565


1566
1567
1568
1569
1570
@MULTIMODAL_REGISTRY.register_processor(
    Tarsier2MultiModalProcessor,
    info=Tarsier2ProcessingInfo,
    dummy_inputs=Qwen2VLDummyInputsBuilder,
)
class Tarsier2ForConditionalGeneration(Qwen2VLForConditionalGeneration):
    """Tarsier2 model: the Qwen2-VL architecture with its own weight layout."""

    # Tarsier2 checkpoints store the vision encoder under `vision_tower.`.
    hf_to_vllm_mapper = WeightsMapper(
        orig_to_new_prefix={
            "vision_tower.": "visual.",
        }
    )

    def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]:
        """Load checkpoint weights using the Tarsier2 prefix mapping.

        This delegates to the parent implementation instead of duplicating it.
        The parent skips ``visual.`` when the vision tower is disabled and maps
        names through ``self.hf_to_vllm_mapper``, which resolves to the
        Tarsier2 mapping defined on this class.
        """
        return super().load_weights(weights)