# qwen2_vl.py 53.8 KB
# Newer Older
1
# SPDX-License-Identifier: Apache-2.0
2
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
3

4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
# Adapted from
# https://github.com/huggingface/transformers/blob/19e6e80e10118f855137b90740936c0b11ac397f/src/transformers/models/qwen2_vl/modeling_qwen2_vl.py
# Copyright 2024 The Qwen team.
# Copyright 2023 The vLLM team.
# Copyright 2022 EleutherAI and the HuggingFace Inc. team. All rights reserved.
#
# This code is based on EleutherAI's GPT-NeoX library and the GPT-NeoX
# and OPT implementations in this library. It has been modified from its
# original forms to accommodate minor architectural differences compared
# to GPT-NeoX and OPT used by the Meta AI team that trained the model.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Inference-only Qwen2-VL model compatible with HuggingFace weights."""
27

28
import math
29
from collections.abc import Callable, Iterable, Mapping, Sequence
30
from functools import partial
31
from typing import Annotated, Any, Literal, TypeAlias
32

33
import numpy as np
34
35
import torch
import torch.nn as nn
36
from einops import rearrange
37
from transformers import BatchFeature
38
from transformers.models.qwen2_vl import Qwen2VLImageProcessor, Qwen2VLProcessor
39
from transformers.models.qwen2_vl.configuration_qwen2_vl import (
40
41
42
    Qwen2VLConfig,
    Qwen2VLVisionConfig,
)
43
from transformers.models.qwen2_vl.image_processing_qwen2_vl import smart_resize
44
from transformers.models.qwen2_vl.video_processing_qwen2_vl import Qwen2VLVideoProcessor
45

46
from vllm.attention.backends.registry import AttentionBackendEnum
47
48
from vllm.attention.layers.mm_encoder_attention import MMEncoderAttention
from vllm.config import MultiModalConfig, VllmConfig
49
from vllm.config.multimodal import BaseDummyOptions
50
from vllm.distributed import parallel_state, tensor_model_parallel_all_gather
51
52
53
from vllm.distributed import utils as dist_utils
from vllm.logger import init_logger
from vllm.model_executor.layers.activation import QuickGELU
54
from vllm.model_executor.layers.conv import Conv3dLayer
55
56
57
58
from vllm.model_executor.layers.linear import (
    ColumnParallelLinear,
    RowParallelLinear,
)
59
from vllm.model_executor.layers.quantization import QuantizationConfig
60
from vllm.model_executor.layers.rotary_embedding import get_rope
61
from vllm.model_executor.layers.rotary_embedding.common import (
62
    ApplyRotaryEmb,
63
)
64
from vllm.model_executor.model_loader.weight_utils import default_weight_loader
65
from vllm.model_executor.models.module_mapping import MultiModelKeys
66
from vllm.multimodal import MULTIMODAL_REGISTRY
67
68
69
70
from vllm.multimodal.inputs import (
    ImageItem,
    ModalityData,
    MultiModalDataDict,
71
    MultiModalFeatureSpec,
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
    MultiModalFieldConfig,
    MultiModalKwargsItems,
    VideoItem,
)
from vllm.multimodal.parse import (
    DictEmbeddingItems,
    ImageSize,
    ModalityDataItems,
    MultiModalDataItems,
    MultiModalDataParser,
)
from vllm.multimodal.processing import (
    BaseMultiModalProcessor,
    BaseProcessingInfo,
    PromptReplacement,
    PromptUpdate,
)
89
from vllm.multimodal.profiling import BaseDummyInputsBuilder
90
from vllm.sequence import IntermediateTensors
91
from vllm.tokenizers import TokenizerLike
92
from vllm.utils.tensor_schema import TensorSchema, TensorShape
93

94
95
96
97
98
99
100
101
102
103
104
105
106
from .interfaces import (
    MultiModalEmbeddings,
    SupportsLoRA,
    SupportsMRoPE,
    SupportsMultiModal,
    SupportsPP,
)
from .utils import (
    AutoWeightsLoader,
    WeightsMapper,
    init_vllm_registered_model,
    maybe_prefix,
)
107
108
109
110
from .vision import (
    get_vit_attn_backend,
    run_dp_sharded_mrope_vision_model,
)
111

112
113
# Module-level logger named after this module.
logger = init_logger(__name__)

# For profile run
_MAX_FRAMES_PER_VIDEO = 14
116

117
118
119
# === Vision Inputs === #


120
class Qwen2VLImagePixelInputs(TensorSchema):
    """
    Schema for image inputs supplied as flattened pixel patches.

    Dimensions:
        - np: The total number of patches over each image over each prompt in
              the batch
        - ni: Number of images
        - cps: Number of channels * patch_size * patch_size

    Historical context:
        - pixel_values shape: (num_patches, num_channels * patch_size *
          patch_size)
        - image_grid_thw shape: (num_images, 3) in (grid_t, grid_h, grid_w)
          format
    """

    # Discriminator tag identifying this variant of the image-input union.
    type: Literal["pixel_values"]

    # Flattened patches, shape (np, cps).
    pixel_values: Annotated[
        torch.Tensor,
        TensorShape("np", "cps"),
    ]

    # One (grid_t, grid_h, grid_w) row per image, shape (ni, 3).
    image_grid_thw: Annotated[
        torch.Tensor,
        TensorShape("ni", 3),
    ]


class Qwen2VLImageEmbeddingInputs(TensorSchema):
    """
    Schema for image inputs supplied as embeddings rather than pixels.

    Dimensions:
        - nf: Number of image features
        - hs: Hidden size
        - ni: Number of images

    Historical context:
        - image_embeds shape: (num_image_features, hidden_size)
        - num_image_features varies based on the number and resolution of the
          images.
        - hidden_size must match the hidden size of language model backbone.
        - image_grid_thw shape: (num_images, 3) in (grid_t, grid_h, grid_w)
          format
    """

    # Discriminator tag identifying this variant of the image-input union.
    type: Literal["image_embeds"]

    # Image feature vectors, shape (nf, hs).
    image_embeds: Annotated[
        torch.Tensor,
        TensorShape("nf", "hs"),
    ]

    # One (grid_t, grid_h, grid_w) row per image, shape (ni, 3).
    image_grid_thw: Annotated[
        torch.Tensor,
        TensorShape("ni", 3),
    ]
175
176


177
# Union of the two image-input schemas: pixel patches or image embeddings.
Qwen2VLImageInputs: TypeAlias = Qwen2VLImagePixelInputs | Qwen2VLImageEmbeddingInputs
178
179


180
181
182
183
184
class Qwen2VLVideoPixelInputs(TensorSchema):
    """
    Schema for video inputs supplied as flattened pixel patches.

    Dimensions:
        - np: The total number of patches over each video over each prompt in
              the batch
        - ctps: Number of channels * temporal_patch_size * patch_size *
          patch_size
        - nv: Number of videos

    Historical context:
        - pixel_values_videos shape: (num_patches, num_channels *
          temporal_patch_size * patch_size * patch_size)
        - video_grid_thw shape: (num_videos, 3) in (grid_t, grid_h, grid_w)
          format
    """

    # Discriminator tag identifying this variant of the video-input union.
    type: Literal["pixel_values_videos"]

    # Flattened patches, shape (np, ctps).
    pixel_values_videos: Annotated[
        torch.Tensor,
        TensorShape("np", "ctps"),
    ]

    # One (grid_t, grid_h, grid_w) row per video, shape (nv, 3).
    video_grid_thw: Annotated[
        torch.Tensor,
        TensorShape("nv", 3),
    ]
207
208


209
210
211
212
213
214
class Qwen2VLVideoEmbeddingInputs(TensorSchema):
    """
    Schema for video inputs supplied as embeddings rather than pixels.

    Dimensions:
        - nf: Number of video features
        - hs: Hidden size
        - nv: Number of videos

    Historical context:
        - video_embeds shape: (num_video_features, hidden_size)
        - num_video_features varies based on the number and resolution of the
          videos.
        - hidden_size must match the hidden size of language model backbone.
        - video_grid_thw shape: (num_videos, 3) in (grid_t, grid_h, grid_w)
          format
    """

    # Discriminator tag identifying this variant of the video-input union.
    type: Literal["video_embeds"]

    # Video feature vectors, shape (nf, hs).
    video_embeds: Annotated[
        torch.Tensor,
        TensorShape("nf", "hs"),
    ]

    # One (grid_t, grid_h, grid_w) row per video, shape (nv, 3).
    video_grid_thw: Annotated[
        torch.Tensor,
        TensorShape("nv", 3),
    ]
236
237


238
# Union of the two video-input schemas: pixel patches or video embeddings.
Qwen2VLVideoInputs: TypeAlias = Qwen2VLVideoPixelInputs | Qwen2VLVideoEmbeddingInputs
239

240
241
242
243
244
245
246
# === Vision Encoder === #


class Qwen2VisionMLP(nn.Module):
    """Feed-forward block of the vision encoder: ``fc1`` -> activation -> ``fc2``.

    Args:
        in_features: Input width of ``fc1`` and output width of ``fc2``.
        hidden_features: Output width of ``fc1`` / input width of ``fc2``.
        act_layer: Module class instantiated (with no arguments) as the
            activation between the two linear layers.
        quant_config: Forwarded to both linear layers.
        multimodal_config: When provided and its ``mm_encoder_tp_mode`` is
            ``"data"``, both linear layers are built with ``disable_tp=True``.
        prefix: Parameter-name prefix; ``.fc1`` / ``.fc2`` are appended.
    """

    def __init__(
        self,
        in_features: int,
        hidden_features: int,
        act_layer: type[nn.Module] = QuickGELU,
        quant_config: QuantizationConfig | None = None,
        multimodal_config: MultiModalConfig | None = None,
        prefix: str = "",
    ):
        super().__init__()
        # Without a multimodal config the layers keep tensor parallelism on.
        use_data_parallel = (
            multimodal_config.mm_encoder_tp_mode == "data"
            if multimodal_config
            else False
        )
        self.fc1 = ColumnParallelLinear(
            in_features,
            hidden_features,
            quant_config=quant_config,
            prefix=f"{prefix}.fc1",
            disable_tp=use_data_parallel,
        )
        self.act = act_layer()
        self.fc2 = RowParallelLinear(
            hidden_features,
            in_features,
            quant_config=quant_config,
            prefix=f"{prefix}.fc2",
            disable_tp=use_data_parallel,
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Apply ``fc1``, the activation, then ``fc2`` and return the result."""
        # Both linear layers return a 2-tuple; only the first item is used.
        x_parallel, _ = self.fc1(x)
        x_parallel = self.act(x_parallel)
        x, _ = self.fc2(x_parallel)
        return x


class Qwen2VisionAttention(nn.Module):
    """Self-attention layer of the vision encoder with rotary embeddings.

    A fused ``qkv`` projection is split into query/key/value, rotary
    embeddings are applied to query and key, attention is computed by
    ``MMEncoderAttention``, and ``proj`` maps the result back to
    ``embed_dim``.

    Args:
        embed_dim: Input width of ``qkv`` and output width of ``proj``.
        num_heads: Total number of attention heads.
        projection_size: Width of each of q, k and v before head splitting.
        quant_config: Forwarded to both linear layers.
        multimodal_config: When provided and its ``mm_encoder_tp_mode`` is
            ``"data"``, tensor parallelism is disabled for this layer.
        prefix: Parameter-name prefix; ``.qkv`` / ``.proj`` are appended.
    """

    def __init__(
        self,
        embed_dim: int,
        num_heads: int,
        projection_size: int,
        quant_config: QuantizationConfig | None = None,
        multimodal_config: MultiModalConfig | None = None,
        prefix: str = "",
    ) -> None:
        super().__init__()
        # Per attention head and per partition values.
        use_data_parallel = (
            multimodal_config.mm_encoder_tp_mode == "data"
            if multimodal_config
            else False
        )
        # In data-parallel mode the layer behaves as a single partition.
        self.tp_size = (
            1
            if use_data_parallel
            else parallel_state.get_tensor_model_parallel_world_size()
        )
        self.tp_rank = parallel_state.get_tensor_model_parallel_rank()
        self.hidden_size_per_attention_head = dist_utils.divide(
            projection_size, num_heads
        )
        self.num_attention_heads_per_partition = dist_utils.divide(
            num_heads, self.tp_size
        )

        self.qkv = ColumnParallelLinear(
            input_size=embed_dim,
            output_size=3 * projection_size,
            quant_config=quant_config,
            prefix=f"{prefix}.qkv",
            disable_tp=use_data_parallel,
        )
        self.proj = RowParallelLinear(
            input_size=projection_size,
            output_size=embed_dim,
            quant_config=quant_config,
            prefix=f"{prefix}.proj",
            disable_tp=use_data_parallel,
        )

        self.attn = MMEncoderAttention(
            num_heads=self.num_attention_heads_per_partition,
            head_size=self.hidden_size_per_attention_head,
            scale=self.hidden_size_per_attention_head**-0.5,
            multimodal_config=multimodal_config,
        )

        self.apply_rotary_emb = ApplyRotaryEmb(enforce_enable=True)

    def split_qkv(self, qkv: torch.Tensor) -> tuple[torch.Tensor, ...]:
        """Split a fused QKV tensor into per-partition q, k and v.

        Args:
            qkv: Output of the ``qkv`` projection, ``[s, b, 3 * head * head_dim]``.

        Returns:
            ``(q, k, v)``, each viewed as ``[s, b, head, head_dim]`` using the
            per-partition head count.
        """
        # [s, b, 3 * head * head_dim]
        seq_len, bs, _ = qkv.shape
        # With tensor parallelism the shards are gathered first so that the
        # q/k/v chunking below operates on the full projection.
        if self.tp_size > 1:
            qkv = tensor_model_parallel_all_gather(qkv)

        # [s, b, 3 * head * head_dim] -> 3 * [s, b, head * head_dim]
        q, k, v = qkv.chunk(3, dim=2)

        # 3 * [s, b, head * head_dim]
        if self.tp_size > 1:
            # Keep only this rank's slice of each of q, k and v.
            splitter = partial(
                dist_utils.split_tensor_along_last_dim, num_partitions=self.tp_size
            )
            q = splitter(q)[self.tp_rank]
            k = splitter(k)[self.tp_rank]
            v = splitter(v)[self.tp_rank]

        # 3 * [s, b, head * head_dim] -> 3 * [s, b, head, head_dim]
        new_shape = (
            seq_len,
            bs,
            self.num_attention_heads_per_partition,
            self.hidden_size_per_attention_head,
        )
        q, k, v = (x.view(*new_shape) for x in (q, k, v))
        return q, k, v

    def forward(
        self,
        x: torch.Tensor,
        cu_seqlens: torch.Tensor,
        rotary_pos_emb_cos: torch.Tensor,
        rotary_pos_emb_sin: torch.Tensor,
        max_seqlen: int | None = None,  # Only used for Flash Attention
    ) -> torch.Tensor:
        """Run attention over ``x`` and return the projected output.

        Args:
            x: Input of shape ``[s, b, c]``.
            cu_seqlens: Forwarded unchanged to the attention module.
            rotary_pos_emb_cos: Cosine table passed to ``ApplyRotaryEmb``.
            rotary_pos_emb_sin: Sine table passed to ``ApplyRotaryEmb``.
            max_seqlen: Forwarded unchanged to the attention module.

        Returns:
            Output of ``proj`` applied to the ``[s, b, head * head_dim]``
            attention result.
        """
        # [s, b, c] --> [s, b, 3 * head * head_dim]
        x, _ = self.qkv(x)

        # [s, b, 3 * head * head_dim] -> 3 * [s, b, head, head_dim]
        q, k, v = self.split_qkv(x)

        q, k, v = (rearrange(x, "s b ... -> b s ...") for x in (q, k, v))

        # Stack q and k along the batch axis so the rotary embedding is
        # applied to both in a single call, then split them back apart.
        # [2 * b, s, heads, head_dim]
        qk_concat = torch.cat([q, k], dim=0)
        qk_rotated = self.apply_rotary_emb(
            qk_concat,
            rotary_pos_emb_cos,
            rotary_pos_emb_sin,
        )
        q, k = torch.chunk(qk_rotated, 2, dim=0)

        context_layer = self.attn(
            query=q,
            key=k,
            value=v,
            cu_seqlens=cu_seqlens,
            max_seqlen=max_seqlen,
        )

        # Back to sequence-first layout with the heads merged.
        context_layer = rearrange(context_layer, "b s h d -> s b (h d)").contiguous()

        output, _ = self.proj(context_layer)
        return output


class Qwen2VisionBlock(nn.Module):
    """Pre-norm transformer block: attention and MLP, each with a residual add.

    Args:
        dim: Width of the block; used for both norms, the attention layer and
            the MLP input/output.
        num_heads: Number of attention heads.
        mlp_ratio: MLP hidden width is ``int(dim * mlp_ratio)``.
        act_layer: Activation class forwarded to the MLP.
        norm_layer: Factory called with ``dim`` to build each norm; defaults
            to ``nn.LayerNorm`` with ``eps=1e-6``.
        quant_config: Forwarded to the attention layer and the MLP.
        multimodal_config: Forwarded to the attention layer and the MLP.
        prefix: Parameter-name prefix; ``.attn`` / ``.mlp`` are appended.
    """

    def __init__(
        self,
        dim: int,
        num_heads: int,
        mlp_ratio: float,
        act_layer: type[nn.Module] = QuickGELU,
        norm_layer: Callable[[int], nn.Module] | None = None,
        quant_config: QuantizationConfig | None = None,
        multimodal_config: MultiModalConfig | None = None,
        prefix: str = "",
    ) -> None:
        super().__init__()
        if norm_layer is None:
            norm_layer = partial(nn.LayerNorm, eps=1e-6)
        self.norm1 = norm_layer(dim)
        self.norm2 = norm_layer(dim)
        mlp_hidden_dim = int(dim * mlp_ratio)

        self.attn = Qwen2VisionAttention(
            embed_dim=dim,
            num_heads=num_heads,
            projection_size=dim,
            quant_config=quant_config,
            multimodal_config=multimodal_config,
            prefix=f"{prefix}.attn",
        )
        self.mlp = Qwen2VisionMLP(
            dim,
            mlp_hidden_dim,
            act_layer=act_layer,
            quant_config=quant_config,
            multimodal_config=multimodal_config,
            prefix=f"{prefix}.mlp",
        )

    def forward(
        self,
        x: torch.Tensor,
        cu_seqlens: torch.Tensor,
        rotary_pos_emb_cos: torch.Tensor,
        rotary_pos_emb_sin: torch.Tensor,
        max_seqlen: int | None = None,  # Only used for Flash Attention
    ) -> torch.Tensor:
        """Apply the attention and MLP sub-layers with residual connections.

        All arguments other than ``x`` are forwarded unchanged to the
        attention layer.
        """
        # Residual around attention, computed on the normalised input.
        x = x + self.attn(
            self.norm1(x),
            cu_seqlens=cu_seqlens,
            rotary_pos_emb_cos=rotary_pos_emb_cos,
            rotary_pos_emb_sin=rotary_pos_emb_sin,
            max_seqlen=max_seqlen,
        )

        # Residual around the MLP, computed on the normalised input.
        x = x + self.mlp(self.norm2(x))
        return x


class Qwen2VisionPatchEmbed(nn.Module):
    """Project flattened patches to ``embed_dim`` with a bias-free 3-D conv.

    The convolution's kernel and stride are both
    ``(temporal_patch_size, patch_size, patch_size)``.

    Args:
        patch_size: Spatial extent of a patch along height and width.
        temporal_patch_size: Temporal extent of a patch.
        in_channels: Number of input channels of the convolution.
        embed_dim: Number of output channels of the convolution.
    """

    def __init__(
        self,
        patch_size: int = 14,
        temporal_patch_size: int = 2,
        in_channels: int = 3,
        embed_dim: int = 1152,
    ) -> None:
        super().__init__()
        self.patch_size = patch_size
        self.temporal_patch_size = temporal_patch_size
        self.embed_dim = embed_dim

        kernel_size = (temporal_patch_size, patch_size, patch_size)
        self.proj = Conv3dLayer(
            in_channels,
            embed_dim,
            kernel_size=kernel_size,
            stride=kernel_size,
            bias=False,
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Embed a 2-D tensor of flattened patches.

        Args:
            x: Tensor with exactly two dimensions; the first is the number of
                patches ``L``.

        Returns:
            Tensor of shape ``(L, embed_dim)``.
        """
        # Unpacking two values also rejects inputs that are not 2-D.
        L, _ = x.shape
        # Restore the (channels, t, h, w) layout of each patch for the conv.
        x = x.view(L, -1, self.temporal_patch_size, self.patch_size, self.patch_size)
        x = self.proj(x).view(L, self.embed_dim)
        return x


class Qwen2VisionPatchMerger(nn.Module):
    """Merge groups of ``spatial_merge_size**2`` patch features into one.

    Features are normalised, regrouped into rows of width
    ``context_dim * spatial_merge_size**2`` and passed through a two-layer
    MLP whose output width is ``d_model``.

    Args:
        d_model: Output width of the final linear layer.
        context_dim: Width of each incoming patch feature.
        norm_layer: Factory called with ``context_dim`` to build the norm;
            defaults to ``nn.LayerNorm`` with ``eps=1e-6``.
        spatial_merge_size: Side length of the merge window.
        quant_config: Forwarded to both linear layers.
        multimodal_config: When provided and its ``mm_encoder_tp_mode`` is
            ``"data"``, both linear layers are built with ``disable_tp=True``.
        prefix: Parameter-name prefix; ``.mlp.0`` / ``.mlp.2`` are appended.
    """

    def __init__(
        self,
        d_model: int,
        context_dim: int,
        norm_layer: Callable[[int], nn.Module] | None = None,
        spatial_merge_size: int = 2,
        quant_config: QuantizationConfig | None = None,
        multimodal_config: MultiModalConfig | None = None,
        prefix: str = "",
    ) -> None:
        super().__init__()
        use_data_parallel = (
            multimodal_config.mm_encoder_tp_mode == "data"
            if multimodal_config
            else False
        )
        self.hidden_size = context_dim * (spatial_merge_size**2)
        if norm_layer is None:
            norm_layer = partial(nn.LayerNorm, eps=1e-6)
        self.ln_q = norm_layer(context_dim)
        # Kept as a ModuleList (not Sequential) because the parallel linear
        # layers return tuples; `forward` unpacks and calls them one by one.
        self.mlp = nn.ModuleList(
            [
                ColumnParallelLinear(
                    self.hidden_size,
                    self.hidden_size,
                    bias=True,
                    quant_config=quant_config,
                    prefix=f"{prefix}.mlp.0",
                    disable_tp=use_data_parallel,
                ),
                nn.GELU(),
                RowParallelLinear(
                    self.hidden_size,
                    d_model,
                    bias=True,
                    quant_config=quant_config,
                    prefix=f"{prefix}.mlp.2",
                    disable_tp=use_data_parallel,
                ),
            ]
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Normalise, regroup to ``(-1, hidden_size)`` and apply the MLP."""
        x = self.ln_q(x)
        x = x.view(-1, self.hidden_size)

        mlp_fc1, mlp_act, mlp_fc2 = self.mlp
        x_parallel, _ = mlp_fc1(x)
        x_parallel = mlp_act(x_parallel)
        out, _ = mlp_fc2(x_parallel)
        return out


class Qwen2VisionTransformer(nn.Module):
    """Vision encoder: patch embedding, transformer blocks, then patch merger.

    Args:
        vision_config: Supplies ``patch_size``, ``temporal_patch_size``,
            ``spatial_merge_size``, ``in_channels``, ``hidden_size``,
            ``embed_dim``, ``depth``, ``num_heads`` and ``mlp_ratio``.
        norm_eps: Epsilon of the ``nn.LayerNorm`` layers used by the blocks
            and the merger.
        quant_config: Forwarded to every block and to the merger.
        multimodal_config: Forwarded to every block and to the merger; also
            supplies the attention-backend override when present.
        prefix: Parameter-name prefix for sub-modules.
    """

    def __init__(
        self,
        vision_config: Qwen2VLVisionConfig,
        norm_eps: float = 1e-6,
        quant_config: QuantizationConfig | None = None,
        multimodal_config: MultiModalConfig | None = None,
        prefix: str = "",
    ) -> None:
        super().__init__()

        patch_size = vision_config.patch_size
        temporal_patch_size = vision_config.temporal_patch_size
        spatial_merge_size = vision_config.spatial_merge_size
        in_channels = vision_config.in_channels
        hidden_size = vision_config.hidden_size
        embed_dim = vision_config.embed_dim
        depth = vision_config.depth
        num_heads = vision_config.num_heads
        mlp_ratio = vision_config.mlp_ratio

        self.use_data_parallel = (
            multimodal_config.mm_encoder_tp_mode == "data"
            if multimodal_config
            else False
        )
        self.out_hidden_size = vision_config.hidden_size

        self.spatial_merge_size = spatial_merge_size
        self.num_heads = num_heads
        self.embed_dim = embed_dim

        self.patch_embed = Qwen2VisionPatchEmbed(
            patch_size=patch_size,
            temporal_patch_size=temporal_patch_size,
            in_channels=in_channels,
            embed_dim=embed_dim,
        )

        norm_layer = partial(nn.LayerNorm, eps=norm_eps)
        head_dim = embed_dim // num_heads
        self.rotary_pos_emb = get_rope(
            head_size=head_dim,
            max_position=8192,
            is_neox_style=True,
            rope_parameters={"partial_rotary_factor": 0.5},
        )

        self.blocks = nn.ModuleList(
            [
                Qwen2VisionBlock(
                    dim=embed_dim,
                    num_heads=num_heads,
                    mlp_ratio=mlp_ratio,
                    norm_layer=norm_layer,
                    quant_config=quant_config,
                    prefix=f"{prefix}.blocks.{layer_idx}",
                    multimodal_config=multimodal_config,
                )
                for layer_idx in range(depth)
            ]
        )
        self.merger = Qwen2VisionPatchMerger(
            d_model=hidden_size,
            context_dim=embed_dim,
            norm_layer=norm_layer,
            quant_config=quant_config,
            prefix=f"{prefix}.merger",
            multimodal_config=multimodal_config,
        )
        attn_backend_override = (
            multimodal_config.mm_encoder_attn_backend if multimodal_config else None
        )
        self.attn_backend = get_vit_attn_backend(
            head_size=head_dim,
            dtype=torch.get_default_dtype(),
            attn_backend_override=attn_backend_override,
        )

    @property
    def dtype(self) -> torch.dtype:
        """Dtype of the patch-embedding convolution weight."""
        return self.patch_embed.proj.weight.dtype

    @property
    def device(self) -> torch.device:
        """Device of the patch-embedding convolution weight."""
        return self.patch_embed.proj.weight.device

    def rot_pos_emb(
        self, grid_thw: list[list[int]]
    ) -> tuple[torch.Tensor, torch.Tensor]:
        """Build per-patch rotary cos/sin tables from ``(t, h, w)`` grids.

        For each grid, row and column indices are laid out in
        ``spatial_merge_size``-sized windows, paired as ``(h, w)`` and
        repeated ``t`` times.

        Args:
            grid_thw: One ``(t, h, w)`` triple per item.

        Returns:
            ``(cos, sin)`` gathered at the position ids and flattened from
            dimension 1 onward.
        """
        merge_size = self.spatial_merge_size
        pos_ids = []
        max_grid_size = 0
        for t, h, w in grid_thw:
            hpos_ids = torch.arange(h).unsqueeze(1).expand(-1, w)
            wpos_ids = torch.arange(w).unsqueeze(0).expand(h, -1)
            # Reorder so that ids of patches inside the same merge window are
            # adjacent in the flattened sequence.
            hpos_ids = (
                hpos_ids.reshape(
                    h // merge_size,
                    merge_size,
                    w // merge_size,
                    merge_size,
                )
                .permute(0, 2, 1, 3)
                .flatten()
            )
            wpos_ids = (
                wpos_ids.reshape(
                    h // merge_size,
                    merge_size,
                    w // merge_size,
                    merge_size,
                )
                .permute(0, 2, 1, 3)
                .flatten()
            )
            pos_ids.append(torch.stack([hpos_ids, wpos_ids], dim=-1).repeat(t, 1))
            max_grid_size = max(max_grid_size, h, w)
        pos_ids = torch.cat(pos_ids, dim=0)

        # Use pre-computed cos_sin_cache from RotaryEmbedding
        cos, sin = self.rotary_pos_emb.get_cos_sin(max_grid_size)

        cos_combined = cos[pos_ids].flatten(1)
        sin_combined = sin[pos_ids].flatten(1)
        return cos_combined, sin_combined

    def compute_attn_mask_seqlen(self, cu_seqlens: torch.Tensor) -> int | None:
        """Return the longest sequence length for the flash-attention backends.

        Returns ``None`` for any backend other than ``FLASH_ATTN`` or
        ``ROCM_AITER_FA``.
        """
        # NOTE(review): `.max()` is returned as-is (no `.item()`), so the value
        # is whatever `Tensor.max()` yields rather than a plain `int` as the
        # annotation suggests — confirm what downstream consumers expect.
        max_seqlen = None
        if self.attn_backend in {
            AttentionBackendEnum.FLASH_ATTN,
            AttentionBackendEnum.ROCM_AITER_FA,
        }:
            max_seqlen = (cu_seqlens[1:] - cu_seqlens[:-1]).max()
        return max_seqlen

    def forward(
        self,
        x: torch.Tensor,
        grid_thw: torch.Tensor | list[list[int]],
    ) -> torch.Tensor:
        """Encode flattened patches into merged vision features.

        Args:
            x: Flattened patches; moved to this module's device and dtype.
            grid_thw: One ``(t, h, w)`` row per item, as a tensor or a nested
                list.

        Returns:
            Output of the patch merger.
        """
        # patchify
        x = x.to(device=self.device, dtype=self.dtype)
        x = self.patch_embed(x)

        # Keep both a Python-list view (for rot_pos_emb) and a NumPy view
        # (for the vectorised cu_seqlens computation) of the grids.
        if isinstance(grid_thw, list):
            grid_thw_list = grid_thw
            grid_thw = np.array(grid_thw, dtype=np.int32)
        else:
            grid_thw_list = grid_thw.tolist()
            grid_thw = grid_thw.numpy()

        # compute position embedding
        rotary_pos_emb_cos, rotary_pos_emb_sin = self.rot_pos_emb(grid_thw_list)

        # compute cu_seqlens
        # Each grid contributes `t` sequences of length `h * w`; the running
        # sum with a leading zero gives the sequence boundaries.
        cu_seqlens = np.repeat(grid_thw[:, 1] * grid_thw[:, 2], grid_thw[:, 0]).cumsum(
            axis=0, dtype=np.int32
        )
        cu_seqlens = np.concatenate([np.zeros(1, dtype=np.int32), cu_seqlens])
        cu_seqlens = torch.from_numpy(cu_seqlens)

        # transformers
        x = x.unsqueeze(1)

        # pre-compute seqlens for attn mask to reduce cuMemcpy operations
        max_seqlen = self.compute_attn_mask_seqlen(cu_seqlens)
        cu_seqlens = cu_seqlens.to(self.device, non_blocking=True)
        for blk in self.blocks:
            x = blk(
                x,
                cu_seqlens=cu_seqlens,
                rotary_pos_emb_cos=rotary_pos_emb_cos,
                rotary_pos_emb_sin=rotary_pos_emb_sin,
                max_seqlen=max_seqlen,
            )

        # adapter
        x = self.merger(x)

        return x

    def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]:
        """Load named weights into this module's parameters.

        Names containing ``q_proj`` / ``k_proj`` / ``v_proj`` are rewritten to
        ``qkv_proj`` and loaded with the matching shard id; every other name
        is loaded directly, using the parameter's own ``weight_loader`` when
        it has one.

        Args:
            weights: Iterable of ``(name, tensor)`` pairs.

        Returns:
            The set of (possibly rewritten) parameter names that were loaded.

        Raises:
            KeyError: If a (rewritten) name is not one of this module's
                parameters.
        """
        stacked_params_mapping = [
            # (param_name, shard_name, shard_id)
            ("qkv_proj", "q_proj", "q"),
            ("qkv_proj", "k_proj", "k"),
            ("qkv_proj", "v_proj", "v"),
        ]
        params_dict = dict(self.named_parameters(remove_duplicate=False))
        loaded_params: set[str] = set()

        for name, loaded_weight in weights:
            for param_name, weight_name, shard_id in stacked_params_mapping:
                if weight_name not in name:
                    continue
                name = name.replace(weight_name, param_name)

                param = params_dict[name]
                weight_loader = param.weight_loader
                weight_loader(param, loaded_weight, shard_id)
                break
            else:
                # No stacked mapping matched: load the tensor as-is.
                param = params_dict[name]
                weight_loader = getattr(param, "weight_loader", default_weight_loader)
                weight_loader(param, loaded_weight)
            loaded_params.add(name)
        return loaded_params

750

751
def _create_qwen2vl_field_factory(
    spatial_merge_size: int,
) -> Callable[
    [Mapping[str, torch.Tensor]],
    Mapping[str, MultiModalFieldConfig],
]:
    """Return a function that builds field configs for Qwen2-VL HF inputs.

    Args:
        spatial_merge_size: Each embedding-field size is the matching pixel
            grid size floor-divided by this value twice.

    Returns:
        A callable mapping processor outputs to a dict of
        ``MultiModalFieldConfig`` entries for the image and video fields.
    """

    def _qwen2vl_field_config(hf_inputs: Mapping[str, torch.Tensor]):
        # A missing grid falls back to an empty (0, 3) tensor, so the derived
        # size tensors are empty for that modality.
        image_grid_thw = hf_inputs.get("image_grid_thw", torch.empty((0, 3)))
        # Patches per image: t * h * w.
        image_pixel_grid_sizes = image_grid_thw.prod(-1)
        image_embed_grid_sizes = (
            image_pixel_grid_sizes // spatial_merge_size // spatial_merge_size
        )

        video_grid_thw = hf_inputs.get("video_grid_thw", torch.empty((0, 3)))
        # Patches per video: t * h * w.
        video_grid_sizes = video_grid_thw.prod(-1)
        video_embed_grid_sizes = (
            video_grid_sizes // spatial_merge_size // spatial_merge_size
        )

        return dict(
            pixel_values=MultiModalFieldConfig.flat_from_sizes(
                "image", image_pixel_grid_sizes
            ),
            image_embeds=MultiModalFieldConfig.flat_from_sizes(
                "image", image_embed_grid_sizes
            ),
            image_grid_thw=MultiModalFieldConfig.batched("image", keep_on_cpu=True),
            pixel_values_videos=MultiModalFieldConfig.flat_from_sizes(
                "video", video_grid_sizes
            ),
            video_embeds=MultiModalFieldConfig.flat_from_sizes(
                "video", video_embed_grid_sizes
            ),
            video_grid_thw=MultiModalFieldConfig.batched("video", keep_on_cpu=True),
        )

    return _qwen2vl_field_config
788

789

# Roger Wang's avatar
# Roger Wang committed
790
class Qwen2VLMultiModalDataParser(MultiModalDataParser):
    """Data parser that also accepts dict inputs for images and videos.

    Dict inputs are wrapped in ``DictEmbeddingItems``; anything else is
    handled by the parent parser.

    Args:
        spatial_merge_size: Passed to ``_create_qwen2vl_field_factory`` when
            wrapping dict inputs.
        *args: Forwarded to ``MultiModalDataParser``.
        **kwargs: Forwarded to ``MultiModalDataParser``.
    """

    def __init__(self, spatial_merge_size: int, *args, **kwargs):
        self._spatial_merge_size = spatial_merge_size
        super().__init__(*args, **kwargs)

    def _parse_image_data(
        self,
        data: dict[str, torch.Tensor] | ModalityData[ImageItem],
    ) -> ModalityDataItems[Any, Any] | None:
        """Parse image data, wrapping dict inputs as embedding items.

        Dict inputs must provide ``image_embeds`` and ``image_grid_thw``.
        """
        if isinstance(data, dict):
            return DictEmbeddingItems(
                data,
                modality="image",
                required_fields={"image_embeds", "image_grid_thw"},
                fields_factory=_create_qwen2vl_field_factory(self._spatial_merge_size),
            )

        return super()._parse_image_data(data)

    def _parse_video_data(
        self,
        data: dict[str, torch.Tensor] | ModalityData[VideoItem],
    ) -> ModalityDataItems[Any, Any] | None:
        """Parse video data, wrapping dict inputs as embedding items.

        Dict inputs must provide ``video_embeds`` and ``video_grid_thw``.
        """
        if isinstance(data, dict):
            return DictEmbeddingItems(
                data,
                modality="video",
                required_fields={"video_embeds", "video_grid_thw"},
                fields_factory=_create_qwen2vl_field_factory(self._spatial_merge_size),
            )

        return super()._parse_video_data(data)


824
825
class Qwen2VLProcessingInfo(BaseProcessingInfo):
    """Processing metadata for Qwen2-VL.

    Gives access to the HF config and processor and computes how many vision
    tokens an image or video of a given size produces.
    """

    def get_hf_config(self):
        """Return the HF config, requested from the context as Qwen2VLConfig."""
        return self.ctx.get_hf_config(Qwen2VLConfig)

    def get_hf_processor(self, **kwargs: object) -> Qwen2VLProcessor:
        """Return the HF processor; ``use_fast`` defaults to ``True``."""
        return self.ctx.get_hf_processor(
            Qwen2VLProcessor,
            use_fast=kwargs.pop("use_fast", True),
            **kwargs,
        )

    def get_image_processor(self, **kwargs: object) -> Qwen2VLImageProcessor:
        """Return the image processor owned by the HF processor."""
        return self.get_hf_processor(**kwargs).image_processor

    def get_supported_mm_limits(self) -> Mapping[str, int | None]:
        """Return the per-modality limits (``None`` for image and video)."""
        return {"image": None, "video": None}

    def get_mm_max_tokens_per_item(
        self,
        seq_len: int,
        mm_counts: Mapping[str, int],
    ) -> Mapping[str, int]:
        """Return the maximum token count of a single image and a single video."""
        max_image_tokens = self.get_max_image_tokens()
        max_video_tokens = self.get_max_video_tokens(seq_len, mm_counts)
        return {"image": max_image_tokens, "video": max_video_tokens}

    def _get_vision_info(
        self,
        *,
        image_width: int,
        image_height: int,
        num_frames: int = 1,
        do_resize: bool = True,
        image_processor: Qwen2VLImageProcessor | None,
    ) -> tuple[ImageSize, int]:
        """Compute the preprocessed size and vision-token count of one item.

        Args:
            image_width: Width of the input image / video frame in pixels.
            image_height: Height of the input image / video frame in pixels.
            num_frames: Number of frames (1 for a still image).
            do_resize: Whether to apply ``smart_resize`` before counting.
            image_processor: Processor supplying ``min_pixels``/``max_pixels``;
                the default one is fetched when ``None``.

        Returns:
            ``(preprocessed_size, num_vision_tokens)`` where the token count is
            the number of patches divided by ``spatial_merge_size ** 2``.
        """
        if image_processor is None:
            image_processor = self.get_image_processor()

        hf_config = self.get_hf_config()
        vision_config = hf_config.vision_config
        patch_size = vision_config.patch_size
        merge_size = vision_config.spatial_merge_size
        temporal_patch_size = vision_config.temporal_patch_size

        if do_resize:
            resized_height, resized_width = smart_resize(
                height=image_height,
                width=image_width,
                factor=patch_size * merge_size,
                min_pixels=image_processor.min_pixels,
                max_pixels=image_processor.max_pixels,
            )
            preprocessed_size = ImageSize(width=resized_width, height=resized_height)
        else:
            preprocessed_size = ImageSize(width=image_width, height=image_height)

        # NOTE: Frames are padded to be divisible by `temporal_patch_size`
        # https://github.com/huggingface/transformers/blob/v4.48.3/src/transformers/models/qwen2_vl/image_processing_qwen2_vl.py#L294
        # `(-n) % k` is the distance from n up to the next multiple of k.
        # The former `n + n % k` only lands on a multiple when k == 2
        # (e.g. k=3, n=4 gave 5 and a grid_t of 1 instead of 2).
        padded_num_frames = num_frames + (-num_frames) % temporal_patch_size

        grid_t = max(padded_num_frames // temporal_patch_size, 1)
        grid_h = preprocessed_size.height // patch_size
        grid_w = preprocessed_size.width // patch_size

        num_patches = grid_t * grid_h * grid_w
        num_vision_tokens = num_patches // (merge_size**2)

        return preprocessed_size, num_vision_tokens

    def get_num_image_tokens(
        self,
        *,
        image_width: int,
        image_height: int,
        image_processor: Qwen2VLImageProcessor | None,
    ) -> int:
        """Return the number of vision tokens for a single image."""
        _, num_image_tokens = self._get_vision_info(
            image_width=image_width,
            image_height=image_height,
            num_frames=1,
            image_processor=image_processor,
        )
        return num_image_tokens

    def get_num_video_tokens(
        self,
        *,
        image_width: int,
        image_height: int,
        num_frames: int,
        image_processor: Qwen2VLImageProcessor | None,
    ) -> int:
        """Return the number of vision tokens for a video of ``num_frames``."""
        _, num_video_tokens = self._get_vision_info(
            image_width=image_width,
            image_height=image_height,
            num_frames=num_frames,
            image_processor=image_processor,
        )
        return num_video_tokens

    def get_image_size_with_most_features(self) -> ImageSize:
        """Return an image size that yields the largest number of merged patches."""
        # NOTE: Simply processing a huge size with _get_vision_info might not give a
        # size that maximizes the number of features, i.e., the number of (merged)
        # patches. This is because the number of patches limits the allowed aspect
        # ratios. For example, suppose the maximum number of patches is 1280. A square
        # image cannot be broken down into 1280 patches, so feeding a giant square image
        # into _get_vision_info will not yield a size that maximizes the number of
        # patches. Therefore, we directly factorize the maximum number of patches into
        # height and width. The tricky part is to avoid extreme aspect ratios (>200 for
        # qwen2-vl). If we can't find a suitable aspect ratio, we decrease the number of
        # patches and retry. This is safe because the processor does not accept extreme
        # aspect ratios, so there is no valid post-resize image with the number of
        # patches that yields extreme aspect ratios.

        hf_config = self.get_hf_config()
        vision_config = hf_config.vision_config
        patch_size = vision_config.patch_size
        merge_size = vision_config.spatial_merge_size
        image_processor = self.get_image_processor()
        max_pixels = image_processor.max_pixels or image_processor.size["longest_edge"]
        # Side length, in pixels, of one merged patch.
        unit = patch_size * merge_size
        max_seq_len = max_pixels // (unit * unit)

        def closest_factor_pair(n: int) -> tuple[int, int]:
            # Factor pair of n closest to a square; left <= right
            for d in range(math.isqrt(n), 0, -1):
                if n % d == 0:
                    return d, n // d
            return 1, n

        height_factor, width_factor = 1, max_seq_len
        # Walk down from the maximum until the aspect ratio is acceptable.
        for seq_len in range(max_seq_len, 0, -1):
            height_factor, width_factor = closest_factor_pair(seq_len)
            if width_factor / height_factor <= 200:
                break

        return ImageSize(width=unit * width_factor, height=unit * height_factor)

    def get_max_image_tokens(self) -> int:
        """Return the token count of the image size with the most features."""
        target_width, target_height = self.get_image_size_with_most_features()

        return self.get_num_image_tokens(
            image_width=target_width,
            image_height=target_height,
            image_processor=None,
        )

    def _get_max_video_frames(self, max_tokens: int, start_num_frames: int = 1) -> int:
        """Return the largest frame count whose tokens do not exceed ``max_tokens``.

        Counting starts at ``start_num_frames``, which is returned unchanged
        when even one additional frame would exceed the budget.
        """
        target_width, target_height = self.get_image_size_with_most_features()

        num_frames = start_num_frames

        while True:
            next_num_frames = num_frames + 1
            next_max_tokens = self.get_num_video_tokens(
                image_width=target_width,
                image_height=target_height,
                num_frames=next_num_frames,
                image_processor=None,
            )

            if next_max_tokens > max_tokens:
                break

            num_frames = next_num_frames

        return num_frames

    def get_num_frames_with_most_features(
        self,
        seq_len: int,
        mm_counts: Mapping[str, int],
        max_frames_per_video: int = _MAX_FRAMES_PER_VIDEO,
    ) -> int:
        """Return the per-video frame count to assume for the largest video.

        The frames that fit in ``seq_len`` are divided evenly among the
        videos in ``mm_counts``, capped at ``max_frames_per_video`` and
        floored at 1.
        """
        max_videos = mm_counts.get("video", 0)

        max_total_frames = self._get_max_video_frames(seq_len)
        max_frames_per_video = min(
            max_total_frames // max(max_videos, 1), max_frames_per_video
        )

        return max(max_frames_per_video, 1)

    def get_max_video_tokens(
        self,
        seq_len: int,
        mm_counts: Mapping[str, int],
    ) -> int:
        """Return the token count of the largest video considered."""
        target_width, target_height = self.get_image_size_with_most_features()

        return self.get_num_video_tokens(
            image_width=target_width,
            image_height=target_height,
            num_frames=self.get_num_frames_with_most_features(seq_len, mm_counts),
            image_processor=None,
        )

1021
1022

class Qwen2VLDummyInputsBuilder(BaseDummyInputsBuilder[Qwen2VLProcessingInfo]):
    """Builds dummy prompt text and dummy image/video data for Qwen2-VL."""

    def get_dummy_text(self, mm_counts: Mapping[str, int]) -> str:
        """Return one image token per image followed by one video token per video."""
        image_count = mm_counts.get("image", 0)
        video_count = mm_counts.get("video", 0)

        processor = self.info.get_hf_processor()
        image_token: str = processor.image_token
        video_token: str = processor.video_token

        image_part = image_token * image_count
        video_part = video_token * video_count
        return image_part + video_part

    def get_dummy_mm_data(
        self,
        seq_len: int,
        mm_counts: Mapping[str, int],
        mm_options: Mapping[str, BaseDummyOptions] | None = None,
    ) -> MultiModalDataDict:
        """Return dummy images and videos sized for the most features.

        Per-modality entries of ``mm_options``, when given, are forwarded as
        overrides to the dummy-data helpers.
        """
        image_count = mm_counts.get("image", 0)
        video_count = mm_counts.get("video", 0)

        width, height = self.info.get_image_size_with_most_features()
        frame_count = self.info.get_num_frames_with_most_features(seq_len, mm_counts)

        if mm_options:
            image_overrides = mm_options.get("image")
            video_overrides = mm_options.get("video")
        else:
            image_overrides = None
            video_overrides = None

        dummy_images = self._get_dummy_images(
            width=width,
            height=height,
            num_images=image_count,
            overrides=image_overrides,
        )
        dummy_videos = self._get_dummy_videos(
            width=width,
            height=height,
            num_frames=frame_count,
            num_videos=video_count,
            overrides=video_overrides,
        )
        return {"image": dummy_images, "video": dummy_videos}

1066

1067
class Qwen2VLMultiModalProcessor(BaseMultiModalProcessor[Qwen2VLProcessingInfo]):
    """Multimodal processor for Qwen2-VL image and video inputs."""

    def _get_data_parser(self) -> MultiModalDataParser:
        """Return a data parser configured with the model's spatial merge size."""
        vision_config = self.info.get_hf_config().vision_config
        return Qwen2VLMultiModalDataParser(vision_config.spatial_merge_size)

    def _get_prompt_updates(
        self,
        mm_items: MultiModalDataItems,
        hf_processor_mm_kwargs: Mapping[str, Any],
        out_mm_kwargs: MultiModalKwargsItems,
    ) -> Sequence[PromptUpdate]:
        """Return replacements that expand each placeholder token.

        Every image/video placeholder is replaced by as many copies of itself
        as its ``*_grid_thw`` product divided by ``merge_size ** 2``.
        """
        hf_processor = self.info.get_hf_processor(**hf_processor_mm_kwargs)
        image_processor = self.info.get_image_processor(**hf_processor_mm_kwargs)
        vocab = self.info.get_tokenizer().get_vocab()

        # Token id of the placeholder for each modality.
        placeholder_ids = {
            "image": vocab[hf_processor.image_token],
            "video": vocab[hf_processor.video_token],
        }

        # Number of patches folded into one placeholder token.
        patches_per_token = image_processor.merge_size**2

        def expand_placeholder(item_idx: int, modality: str):
            grid_thw = out_mm_kwargs[modality][item_idx][f"{modality}_grid_thw"].data
            assert isinstance(grid_thw, torch.Tensor)

            repeat = int(grid_thw.prod()) // patches_per_token
            return [placeholder_ids[modality]] * repeat

        return [
            PromptReplacement(
                modality=name,
                target=[placeholder_ids[name]],
                replacement=partial(expand_placeholder, modality=name),
            )
            for name in ("image", "video")
        ]

    def _get_mm_fields_config(
        self,
        hf_inputs: BatchFeature,
        hf_processor_mm_kwargs: Mapping[str, object],
    ) -> Mapping[str, MultiModalFieldConfig]:
        """Return the field config produced by the Qwen2-VL field factory."""
        vision_config = self.info.get_hf_config().vision_config
        field_factory = _create_qwen2vl_field_factory(vision_config.spatial_merge_size)
        return field_factory(hf_inputs)


@MULTIMODAL_REGISTRY.register_processor(
    Qwen2VLMultiModalProcessor,
    info=Qwen2VLProcessingInfo,
    dummy_inputs=Qwen2VLDummyInputsBuilder,
)
class Qwen2VLForConditionalGeneration(
    nn.Module, SupportsMultiModal, SupportsLoRA, SupportsPP, SupportsMRoPE
):
    """Qwen2-VL: an optional vision tower in front of a Qwen2 language model."""

    # To ensure correct weight loading and mapping.
    hf_to_vllm_mapper = WeightsMapper(
        orig_to_new_prefix={
            # mapping for new names in checkpoint saved after transformers v4.52
            "model.language_model.": "language_model.model.",
            "model.visual.": "visual.",
            # mapping for original checkpoint
            "lm_head.": "language_model.lm_head.",
            "model.": "language_model.model.",
        }
    )

    supports_encoder_tp_data = True

    def get_mrope_input_positions(
        self,
        input_tokens: list[int],
        mm_features: list[MultiModalFeatureSpec],
    ) -> tuple[torch.Tensor, int]:
        """Compute M-RoPE position ids for a prompt.

        Args:
            input_tokens: Prompt token ids, with placeholders already expanded.
            mm_features: Multimodal features providing ``image_grid_thw``,
                ``video_grid_thw`` and optionally ``second_per_grid_ts``.

        Returns:
            ``(positions, delta)`` where ``positions`` has three rows
            (temporal, height, width) and ``delta`` equals
            ``positions.max() + 1 - len(input_tokens)``.
        """
        gathered = MultiModalFeatureSpec.gather_kwargs(
            mm_features,
            {"image_grid_thw", "video_grid_thw", "second_per_grid_ts"},
        )
        image_grid_thw = [grid.tolist() for grid in gathered.get("image_grid_thw", [])]
        video_grid_thw = [grid.tolist() for grid in gathered.get("video_grid_thw", [])]
        second_per_grid_ts = gathered.get("second_per_grid_ts", [])

        hf_config = self.config
        image_token_id = hf_config.image_token_id
        video_token_id = hf_config.video_token_id
        vision_start_token_id = hf_config.vision_start_token_id
        spatial_merge_size = hf_config.vision_config.spatial_merge_size
        tokens_per_second = getattr(hf_config.vision_config, "tokens_per_second", 1.0)

        # The token right after each vision-start marker tells whether the
        # item is an image or a video.
        token_tensor = torch.tensor(input_tokens)
        vision_start_indices = torch.argwhere(
            token_tensor == vision_start_token_id
        ).squeeze(1)
        tokens_after_start = token_tensor[vision_start_indices + 1]
        num_images = int((tokens_after_start == image_token_id).sum())
        num_videos = int((tokens_after_start == video_token_id).sum())

        # Sentinel larger than any valid index, meaning "no such placeholder".
        not_found = len(input_tokens) + 1

        def next_placeholder(token_id: int, start: int, remaining: int) -> int:
            # Index of the first `token_id` at or after `start`, or the
            # sentinel when none is left to consume or none is present.
            if remaining <= 0:
                return not_found
            try:
                return input_tokens.index(token_id, start)
            except ValueError:
                return not_found

        position_chunks: list = []
        cursor = 0  # first token that has not been assigned a position yet
        image_idx, video_idx = 0, 0
        images_left, videos_left = num_images, num_videos

        for _ in range(num_images + num_videos):
            next_image = next_placeholder(image_token_id, cursor, images_left)
            next_video = next_placeholder(video_token_id, cursor, videos_left)

            if next_image < next_video:
                t, h, w = image_grid_thw[image_idx]
                # Images use a zero temporal scale.
                seconds_per_grid_t = 0.0
                image_idx += 1
                images_left -= 1
                placeholder_start = next_image
            else:
                t, h, w = video_grid_thw[video_idx]
                seconds_per_grid_t = 1.0
                if second_per_grid_ts:
                    seconds_per_grid_t = second_per_grid_ts[video_idx]
                video_idx += 1
                videos_left -= 1
                placeholder_start = next_video

            grid_t = t
            grid_h = h // spatial_merge_size
            grid_w = w // spatial_merge_size

            # Tokens between the cursor and the placeholder share one
            # position across all three rows.
            text_len = placeholder_start - cursor
            offset = position_chunks[-1].max() + 1 if len(position_chunks) > 0 else 0
            position_chunks.append(
                torch.arange(text_len).view(1, -1).expand(3, -1) + offset
            )

            frame_ids = torch.arange(grid_t).view(-1, 1).expand(-1, grid_h * grid_w)
            t_index = (
                (frame_ids * seconds_per_grid_t * tokens_per_second).long().flatten()
            )
            h_index = (
                torch.arange(grid_h).view(1, -1, 1).expand(grid_t, -1, grid_w).flatten()
            )
            w_index = (
                torch.arange(grid_w).view(1, 1, -1).expand(grid_t, grid_h, -1).flatten()
            )
            position_chunks.append(
                torch.stack([t_index, h_index, w_index]) + text_len + offset
            )

            cursor = placeholder_start + grid_t * grid_h * grid_w

        # Tokens after the last vision item.
        if cursor < len(input_tokens):
            offset = position_chunks[-1].max() + 1 if len(position_chunks) > 0 else 0
            text_len = len(input_tokens) - cursor
            position_chunks.append(
                torch.arange(text_len).view(1, -1).expand(3, -1) + offset
            )

        llm_positions = torch.cat(position_chunks, dim=1).reshape(3, -1)
        mrope_position_delta = (llm_positions.max() + 1 - len(input_tokens)).item()

        return llm_positions, mrope_position_delta

    @classmethod
    def get_placeholder_str(cls, modality: str, i: int) -> str | None:
        """Return the prompt placeholder for an image or video modality.

        Raises:
            ValueError: If ``modality`` starts with neither "image" nor "video".
        """
        known_placeholders = (
            ("image", "<|vision_start|><|image_pad|><|vision_end|>"),
            ("video", "<|vision_start|><|video_pad|><|vision_end|>"),
        )
        for prefix, placeholder in known_placeholders:
            if modality.startswith(prefix):
                return placeholder

        raise ValueError("Only image or video modality is supported")

    def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
        super().__init__()
        model_config = vllm_config.model_config
        config: Qwen2VLConfig = model_config.hf_config
        quant_config = vllm_config.quant_config
        multimodal_config = model_config.multimodal_config

        self.use_data_parallel = multimodal_config.mm_encoder_tp_mode == "data"
        self.config = config
        self.multimodal_config = multimodal_config

        # The vision tower is only built when some image or video limit is
        # truthy; otherwise `visual` stays None.
        needs_vision = any(
            multimodal_config.get_limit_per_prompt(name) for name in ("image", "video")
        )
        if not needs_vision:
            self.visual = None
        else:
            self.visual = Qwen2VisionTransformer(
                config.vision_config,
                norm_eps=getattr(config, "rms_norm_eps", 1e-6),
                quant_config=quant_config,
                multimodal_config=multimodal_config,
                prefix=maybe_prefix(prefix, "visual"),
            )

        self.language_model = init_vllm_registered_model(
            vllm_config=vllm_config,
            prefix=maybe_prefix(prefix, "language_model"),
            architectures=["Qwen2ForCausalLM"],
        )

        self.make_empty_intermediate_tensors = (
            self.language_model.make_empty_intermediate_tensors
        )

    def _parse_and_validate_image_input(
        self, **kwargs: object
    ) -> Qwen2VLImageInputs | None:
        """Build the image input from kwargs; pixel values win over embeddings.

        Returns ``None`` when neither ``pixel_values`` nor ``image_embeds``
        is provided.
        """
        pixel_values = kwargs.pop("pixel_values", None)
        image_embeds = kwargs.pop("image_embeds", None)
        image_grid_thw = kwargs.pop("image_grid_thw", None)

        if pixel_values is not None:
            return Qwen2VLImagePixelInputs(
                type="pixel_values",
                pixel_values=pixel_values,
                image_grid_thw=image_grid_thw,
            )

        if image_embeds is not None:
            return Qwen2VLImageEmbeddingInputs(
                type="image_embeds",
                image_embeds=image_embeds,
                image_grid_thw=image_grid_thw,
            )

        return None

    def _parse_and_validate_video_input(
        self, **kwargs: object
    ) -> Qwen2VLVideoInputs | None:
        """Build the video input from kwargs; pixel values win over embeddings.

        Returns ``None`` when neither ``pixel_values_videos`` nor
        ``video_embeds`` is provided.
        """
        pixel_values_videos = kwargs.pop("pixel_values_videos", None)
        video_embeds = kwargs.pop("video_embeds", None)
        video_grid_thw = kwargs.pop("video_grid_thw", None)

        if pixel_values_videos is not None:
            return Qwen2VLVideoPixelInputs(
                type="pixel_values_videos",
                pixel_values_videos=pixel_values_videos,
                video_grid_thw=video_grid_thw,
            )

        if video_embeds is not None:
            return Qwen2VLVideoEmbeddingInputs(
                type="video_embeds",
                video_embeds=video_embeds,
                video_grid_thw=video_grid_thw,
            )

        return None

    def _process_image_input(
        self, image_input: Qwen2VLImageInputs
    ) -> tuple[torch.Tensor, ...]:
        """Return one embedding tensor per image item."""
        grid_thw = image_input["image_grid_thw"]
        assert grid_thw.ndim == 2

        if image_input["type"] != "image_embeds":
            pixel_values = image_input["pixel_values"]
            if self.use_data_parallel:
                # The sharded helper's result is returned as-is.
                return run_dp_sharded_mrope_vision_model(
                    self.visual, pixel_values, grid_thw.tolist(), rope_type="rope_3d"
                )
            image_embeds = self.visual(pixel_values, grid_thw=grid_thw)
        else:
            image_embeds = image_input["image_embeds"]

        # Split concatenated embeddings for each image item.
        merge_size = self.visual.spatial_merge_size
        tokens_per_item = (grid_thw.prod(-1) // merge_size // merge_size).tolist()
        return image_embeds.split(tokens_per_item)

    def _process_video_input(
        self, video_input: Qwen2VLVideoInputs
    ) -> tuple[torch.Tensor, ...]:
        """Return one embedding tensor per video item."""
        grid_thw = video_input["video_grid_thw"]
        assert grid_thw.ndim == 2

        if video_input["type"] != "video_embeds":
            pixel_values_videos = video_input["pixel_values_videos"]
            if self.use_data_parallel:
                # The sharded helper's result is returned as-is.
                return run_dp_sharded_mrope_vision_model(
                    self.visual,
                    pixel_values_videos,
                    grid_thw.tolist(),
                    rope_type="rope_3d",
                )
            video_embeds = self.visual(pixel_values_videos, grid_thw=grid_thw)
        else:
            video_embeds = video_input["video_embeds"]

        # Split concatenated embeddings for each video item.
        merge_size = self.visual.spatial_merge_size
        tokens_per_item = (grid_thw.prod(-1) // merge_size // merge_size).tolist()
        return video_embeds.split(tokens_per_item)

    def _parse_and_validate_multimodal_inputs(self, **kwargs: object) -> dict:
        """Parse image/video inputs, keyed by modality in kwargs order."""
        parsers = (
            (
                "images",
                ("pixel_values", "image_embeds"),
                self._parse_and_validate_image_input,
            ),
            (
                "videos",
                ("pixel_values_videos", "video_embeds"),
                self._parse_and_validate_video_input,
            ),
        )
        modalities = {}

        # Preserve the order of modalities if there are multiple of them
        # from the order of kwargs.
        for input_key in kwargs:
            for name, trigger_keys, parse in parsers:
                if input_key in trigger_keys and name not in modalities:
                    modalities[name] = parse(**kwargs)

        return modalities

    def get_language_model(self) -> torch.nn.Module:
        """Return the wrapped language model."""
        return self.language_model

    def embed_multimodal(self, **kwargs: object) -> MultiModalEmbeddings:
        """Embed all image/video inputs found in kwargs.

        Returns an empty list when there is no multimodal input.
        """
        modalities = self._parse_and_validate_multimodal_inputs(**kwargs)
        if not modalities:
            return []

        # The result multimodal_embeddings is tuple of tensors, with each
        # tensor corresponding to a multimodal data item (image or video).
        multimodal_embeddings: tuple[torch.Tensor, ...] = ()

        # NOTE: It is important to iterate over the keys in this dictionary
        # to preserve the order of the modalities.
        for modality, mm_input in modalities.items():
            if modality == "images":
                multimodal_embeddings += tuple(self._process_image_input(mm_input))
            elif modality == "videos":
                multimodal_embeddings += tuple(self._process_video_input(mm_input))

        return multimodal_embeddings

    def forward(
        self,
        input_ids: torch.Tensor,
        positions: torch.Tensor,
        intermediate_tensors: IntermediateTensors | None = None,
        inputs_embeds: torch.Tensor | None = None,
        **kwargs: object,
    ) -> torch.Tensor | IntermediateTensors:
        """Run forward pass for Qwen2-VL.

        Args:
            input_ids: Flattened (concatenated) input_ids corresponding to a
                batch.
            positions: Flattened (concatenated) position ids corresponding to a
                batch.
                **NOTE**: If mrope is enabled (default setting for Qwen2-VL
                opensource models), the shape will be `(3, seq_len)`,
                otherwise it will be `(seq_len,)`.
            intermediate_tensors: Intermediate tensors from prior forward pass.
            inputs_embeds: Optional tensor of input embeddings.
        """
        # Input embeddings are dropped whenever intermediate tensors are given.
        effective_embeds = None if intermediate_tensors is not None else inputs_embeds

        return self.language_model.model(
            input_ids=input_ids,
            positions=positions,
            intermediate_tensors=intermediate_tensors,
            inputs_embeds=effective_embeds,
        )

    def compute_logits(
        self,
        hidden_states: torch.Tensor,
    ) -> torch.Tensor | None:
        """Delegate logit computation to the language model."""
        return self.language_model.compute_logits(hidden_states)

    def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]:
        """Load weights via ``hf_to_vllm_mapper``, skipping an absent vision tower."""
        skip_prefixes = ["visual."] if self.visual is None else []
        weights_loader = AutoWeightsLoader(self, skip_prefixes=skip_prefixes)
        return weights_loader.load_weights(weights, mapper=self.hf_to_vllm_mapper)

    def get_mm_mapping(self) -> MultiModelKeys:
        """
        Get the module prefix in multimodal models
        """
        return MultiModelKeys.from_string_field(
            language_model="language_model",
            connector="visual.merger.",
            tower_model="visual.",
        )

    def get_num_mm_encoder_tokens(
        self,
        num_image_tokens: int,
    ) -> int:
        """Return ``num_image_tokens`` multiplied by ``spatial_merge_size ** 2``."""
        merge_size = self.config.vision_config.spatial_merge_size
        return num_image_tokens * merge_size**2

    def get_num_mm_connector_tokens(
        self,
        num_vision_tokens: int,
    ) -> int:
        """Return ``num_vision_tokens`` floor-divided by ``spatial_merge_size ** 2``."""
        merge_size = self.config.vision_config.spatial_merge_size
        return num_vision_tokens // merge_size**2

1514
1515
1516
1517
1518
1519
1520
1521

class Tarsier2MultiModalProcessor(Qwen2VLMultiModalProcessor):
    """Multimodal processor for Tarsier2; inherits all Qwen2-VL behavior."""


class Tarsier2ImageProcessor(Qwen2VLImageProcessor):
    """Qwen2-VL image processor that also accepts a Tarsier2-style ``size``."""

    def __init__(
        self,
        size: dict[str, int] | None = None,
        **kwargs,
    ) -> None:
        """Initialize the processor, remapping ``size`` when needed.

        A ``size`` holding both ``min_pixels`` and ``max_pixels`` is converted
        to ``shortest_edge``/``longest_edge``; anything else is passed through
        unchanged.
        """
        uses_tarsier_keys = (
            size is not None and "min_pixels" in size and "max_pixels" in size
        )
        if uses_tarsier_keys:
            # Remap if Tarsier2-specific format is provided
            size = {
                "shortest_edge": size["min_pixels"],
                "longest_edge": size["max_pixels"],
            }
        super().__init__(size=size, **kwargs)


class Tarsier2Processor(Qwen2VLProcessor):
    """Qwen2-VL processor whose image/video processors come from one config."""

    def __init__(
        self,
        vision_config: dict,
        tokenizer: TokenizerLike,
        **kwargs,
    ):
        """Build the image and video processors from ``vision_config``.

        The image processor is stored on the instance before the base
        initializer runs; no chat template is passed on.
        """
        image_processor = Tarsier2ImageProcessor(**vision_config)
        self.image_processor = image_processor
        video_processor = Qwen2VLVideoProcessor(**vision_config)
        super().__init__(
            image_processor=image_processor,
            tokenizer=tokenizer,
            video_processor=video_processor,
            chat_template=None,
            **kwargs,
        )
1551
1552
1553
1554
1555


class Tarsier2ProcessingInfo(Qwen2VLProcessingInfo):
    """Processing info for Tarsier2, built on the Qwen2-VL implementation."""

    def get_hf_config(self) -> Qwen2VLConfig:
        """Load the config from the model path as a ``Qwen2VLConfig``."""
        model_path = self.ctx.model_config.model
        correct_config = Qwen2VLConfig.from_pretrained(model_path)

        return correct_config

    def get_hf_processor(self, **kwargs: object) -> Tarsier2Processor:
        """Build a ``Tarsier2Processor`` from the image-processor config."""
        return Tarsier2Processor(
            vision_config=self.ctx.get_hf_image_processor_config(),
            tokenizer=self.get_tokenizer(),
            **kwargs,
        )

    def get_image_processor(self, **kwargs: object) -> Tarsier2ImageProcessor:
        """Build a ``Tarsier2ImageProcessor`` from the image-processor config.

        ``**kwargs`` is accepted to match the base-class signature: the
        inherited ``_get_prompt_updates`` calls this method with the
        processor kwargs, which previously raised ``TypeError`` whenever any
        were given. Supplied keywords override same-named config entries.
        """
        processor_config = {**self.ctx.get_hf_image_processor_config(), **kwargs}
        return Tarsier2ImageProcessor(**processor_config)
1569
1570


1571
1572
1573
1574
1575
@MULTIMODAL_REGISTRY.register_processor(
    Tarsier2MultiModalProcessor,
    info=Tarsier2ProcessingInfo,
    dummy_inputs=Qwen2VLDummyInputsBuilder,
)
class Tarsier2ForConditionalGeneration(Qwen2VLForConditionalGeneration):
    """Tarsier2 model: Qwen2-VL with its own checkpoint prefix mapping."""

    # Replaces the parent's mapper: only `vision_tower.` is renamed.
    hf_to_vllm_mapper = WeightsMapper(
        orig_to_new_prefix={
            "vision_tower.": "visual.",
        }
    )

    def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]:
        """Load weights via ``hf_to_vllm_mapper``, skipping an absent vision tower."""
        skip_prefixes = ["visual."] if self.visual is None else []
        weights_loader = AutoWeightsLoader(self, skip_prefixes=skip_prefixes)
        return weights_loader.load_weights(weights, mapper=self.hf_to_vllm_mapper)