qwen2_vl.py 51.3 KB
Newer Older
1
# SPDX-License-Identifier: Apache-2.0
2
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
3

4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
# Adapted from
# https://github.com/huggingface/transformers/blob/19e6e80e10118f855137b90740936c0b11ac397f/src/transformers/models/qwen2_vl/modeling_qwen2_vl.py
# Copyright 2024 The Qwen team.
# Copyright 2023 The vLLM team.
# Copyright 2022 EleutherAI and the HuggingFace Inc. team. All rights reserved.
#
# This code is based on EleutherAI's GPT-NeoX library and the GPT-NeoX
# and OPT implementations in this library. It has been modified from its
# original forms to accommodate minor architectural differences compared
# to GPT-NeoX and OPT used by the Meta AI team that trained the model.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Inference-only Qwen2-VL model compatible with HuggingFace weights."""
27

28
import math
29
from collections.abc import Callable, Iterable, Iterator, Mapping, Sequence
30
from functools import partial
31
from typing import Annotated, Any, Literal, TypeAlias
32

33
import numpy as np
34
35
import torch
import torch.nn as nn
36
from einops import rearrange
37
from transformers import BatchFeature
38
from transformers.models.qwen2_vl import Qwen2VLImageProcessor, Qwen2VLProcessor
39
from transformers.models.qwen2_vl.configuration_qwen2_vl import (
40
41
42
    Qwen2VLConfig,
    Qwen2VLVisionConfig,
)
43
from transformers.models.qwen2_vl.image_processing_qwen2_vl import smart_resize
44
from transformers.models.qwen2_vl.video_processing_qwen2_vl import Qwen2VLVideoProcessor
45

46
from vllm.config import VllmConfig
47
from vllm.config.multimodal import BaseDummyOptions
48
from vllm.distributed import parallel_state, tensor_model_parallel_all_gather
49
50
51
from vllm.distributed import utils as dist_utils
from vllm.logger import init_logger
from vllm.model_executor.layers.activation import QuickGELU
52
from vllm.model_executor.layers.attention import MMEncoderAttention
53
from vllm.model_executor.layers.conv import Conv3dLayer
54
55
56
57
from vllm.model_executor.layers.linear import (
    ColumnParallelLinear,
    RowParallelLinear,
)
58
from vllm.model_executor.layers.quantization import QuantizationConfig
59
from vllm.model_executor.layers.rotary_embedding import get_rope
60
from vllm.model_executor.layers.rotary_embedding.common import (
61
    ApplyRotaryEmb,
62
)
63
from vllm.model_executor.model_loader.weight_utils import default_weight_loader
64
from vllm.model_executor.models.module_mapping import MultiModelKeys
65
from vllm.multimodal import MULTIMODAL_REGISTRY
66
67
68
69
from vllm.multimodal.inputs import (
    ImageItem,
    ModalityData,
    MultiModalDataDict,
70
    MultiModalFeatureSpec,
71
72
73
74
75
76
77
78
79
80
81
82
    MultiModalFieldConfig,
    MultiModalKwargsItems,
    VideoItem,
)
from vllm.multimodal.parse import (
    DictEmbeddingItems,
    ImageSize,
    ModalityDataItems,
    MultiModalDataItems,
    MultiModalDataParser,
)
from vllm.multimodal.processing import (
83
    BaseDummyInputsBuilder,
84
85
86
87
88
    BaseMultiModalProcessor,
    BaseProcessingInfo,
    PromptReplacement,
    PromptUpdate,
)
89
from vllm.sequence import IntermediateTensors
90
from vllm.tokenizers import TokenizerLike
91
from vllm.utils.tensor_schema import TensorSchema, TensorShape
92
from vllm.v1.attention.backends.registry import AttentionBackendEnum
93

94
95
96
97
98
99
100
101
102
103
104
105
106
from .interfaces import (
    MultiModalEmbeddings,
    SupportsLoRA,
    SupportsMRoPE,
    SupportsMultiModal,
    SupportsPP,
)
from .utils import (
    AutoWeightsLoader,
    WeightsMapper,
    init_vllm_registered_model,
    maybe_prefix,
)
107
108
from .vision import (
    get_vit_attn_backend,
109
    is_vit_use_data_parallel,
110
111
    run_dp_sharded_mrope_vision_model,
)
112

113
114
logger = init_logger(__name__)

# Cap on the number of frames per video assumed during the profile run.
_MAX_FRAMES_PER_VIDEO = 14
117

118
119
120
# === Vision Inputs === #


121
class Qwen2VLImagePixelInputs(TensorSchema):
    """
    Image inputs given as flattened pixel patches.

    Dimensions:
        - np: The total number of patches over each image over each prompt in
              the batch
        - ni: Number of images
        - cps: Number of channels * temporal_patch_size * patch_size *
               patch_size. Image rows use the same flattened per-patch layout
               as video rows: Qwen2VisionPatchEmbed views every row as
               (channels, temporal_patch_size, patch_size, patch_size).

    Historical context:
        - pixel_values shape: (num_patches, num_channels *
          temporal_patch_size * patch_size * patch_size)
        - image_grid_thw shape: (num_images, 3) in (grid_t, grid_h, grid_w)
          format
    """

    type: Literal["pixel_values"]

    pixel_values: Annotated[
        torch.Tensor,
        TensorShape("np", "cps"),
    ]

    image_grid_thw: Annotated[
        torch.Tensor,
        TensorShape("ni", 3),
    ]


class Qwen2VLImageEmbeddingInputs(TensorSchema):
    """
    Image inputs given as precomputed embeddings instead of raw pixels.

    Dimensions:
        - nf: Number of image features
        - hs: Hidden size
        - ni: Number of images

    Historical context:
        - image_embeds shape: (num_image_features, hidden_size); the feature
          count depends on how many images there are and on their resolution.
        - hidden_size has to equal the hidden size of the language model
          backbone.
        - image_grid_thw shape: (num_images, 3), laid out as
          (grid_t, grid_h, grid_w).
    """

    type: Literal["image_embeds"]
    image_embeds: Annotated[torch.Tensor, TensorShape("nf", "hs")]
    image_grid_thw: Annotated[torch.Tensor, TensorShape("ni", 3)]
176
177


178
# Image inputs arrive either as pixel patches or as precomputed embeddings.
Qwen2VLImageInputs: TypeAlias = (
    Qwen2VLImagePixelInputs | Qwen2VLImageEmbeddingInputs
)
179
180


181
182
183
184
185
class Qwen2VLVideoPixelInputs(TensorSchema):
    """
    Video inputs given as flattened spatio-temporal pixel patches.

    Dimensions:
        - np: The total number of patches over each video over each prompt in
              the batch
        - ctps: Number of channels * temporal_patch_size * patch_size *
                patch_size
        - nv: Number of videos

    Historical context:
        - pixel_values_videos shape: (num_patches, num_channels *
          temporal_patch_size * patch_size * patch_size)
        - video_grid_thw shape: (num_videos, 3), laid out as
          (grid_t, grid_h, grid_w).
    """

    type: Literal["pixel_values_videos"]
    pixel_values_videos: Annotated[torch.Tensor, TensorShape("np", "ctps")]
    video_grid_thw: Annotated[torch.Tensor, TensorShape("nv", 3)]
208
209


210
211
212
213
214
215
class Qwen2VLVideoEmbeddingInputs(TensorSchema):
    """
    Video inputs given as precomputed embeddings instead of raw pixels.

    Dimensions:
        - nf: Number of video features
        - hs: Hidden size
        - nv: Number of videos

    Historical context:
        - video_embeds shape: (num_video_features, hidden_size); the feature
          count depends on how many videos there are and on their resolution.
        - hidden_size has to equal the hidden size of the language model
          backbone.
        - video_grid_thw shape: (num_videos, 3), laid out as
          (grid_t, grid_h, grid_w).
    """

    type: Literal["video_embeds"]
    video_embeds: Annotated[torch.Tensor, TensorShape("nf", "hs")]
    video_grid_thw: Annotated[torch.Tensor, TensorShape("nv", 3)]
237
238


239
# Video inputs arrive either as pixel patches or as precomputed embeddings.
Qwen2VLVideoInputs: TypeAlias = (
    Qwen2VLVideoPixelInputs | Qwen2VLVideoEmbeddingInputs
)
240

241
242
243
244
245
246
247
# === Vision Encoder === #


class Qwen2VisionMLP(nn.Module):
    """Feed-forward sub-layer of a vision block: fc1 -> activation -> fc2.

    fc1 is column-parallel and fc2 is row-parallel; tensor parallelism is
    disabled on both when the ViT runs in data-parallel mode.
    """

    def __init__(
        self,
        in_features: int,
        hidden_features: int,
        act_layer: type[nn.Module] = QuickGELU,
        quant_config: QuantizationConfig | None = None,
        prefix: str = "",
    ):
        super().__init__()
        dp_mode = is_vit_use_data_parallel()
        shared_kwargs = dict(quant_config=quant_config, disable_tp=dp_mode)

        self.fc1 = ColumnParallelLinear(
            in_features,
            hidden_features,
            prefix=f"{prefix}.fc1",
            **shared_kwargs,
        )
        self.act = act_layer()
        self.fc2 = RowParallelLinear(
            hidden_features,
            in_features,
            prefix=f"{prefix}.fc2",
            **shared_kwargs,
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # The parallel linear layers return a pair; only the first item
        # (the activation) is used here.
        hidden, _ = self.fc1(x)
        out, _ = self.fc2(self.act(hidden))
        return out


class Qwen2VisionAttention(nn.Module):
    """Self-attention layer of the Qwen2-VL vision tower.

    Projects the input with a fused QKV linear layer, applies rotary position
    embeddings to queries and keys, attends within the segments delimited by
    ``cu_seqlens`` via ``MMEncoderAttention``, and projects the result back to
    ``embed_dim``.
    """

    def __init__(
        self,
        embed_dim: int,
        num_heads: int,
        projection_size: int,
        quant_config: QuantizationConfig | None = None,
        prefix: str = "",
    ) -> None:
        super().__init__()
        # Per attention head and per partition values.
        use_data_parallel = is_vit_use_data_parallel()
        # In ViT data-parallel mode TP is disabled on the linear layers below,
        # so heads are not partitioned across ranks (tp_size = 1).
        self.tp_size = (
            1
            if use_data_parallel
            else parallel_state.get_tensor_model_parallel_world_size()
        )
        self.tp_rank = parallel_state.get_tensor_model_parallel_rank()
        self.hidden_size_per_attention_head = dist_utils.divide(
            projection_size, num_heads
        )
        self.num_attention_heads_per_partition = dist_utils.divide(
            num_heads, self.tp_size
        )

        self.qkv = ColumnParallelLinear(
            input_size=embed_dim,
            output_size=3 * projection_size,
            quant_config=quant_config,
            prefix=f"{prefix}.qkv",
            disable_tp=use_data_parallel,
        )
        self.proj = RowParallelLinear(
            input_size=projection_size,
            output_size=embed_dim,
            quant_config=quant_config,
            prefix=f"{prefix}.proj",
            disable_tp=use_data_parallel,
        )

        self.attn = MMEncoderAttention(
            num_heads=self.num_attention_heads_per_partition,
            head_size=self.hidden_size_per_attention_head,
            scale=self.hidden_size_per_attention_head**-0.5,
        )

        self.apply_rotary_emb = ApplyRotaryEmb(enforce_enable=True)

    def split_qkv(self, qkv: torch.Tensor) -> tuple[torch.Tensor, ...]:
        """Split a fused QKV activation into per-head q, k and v.

        Args:
            qkv: Output of ``self.qkv`` shaped ``[s, b, 3 * head * head_dim]``
                (this rank's shard when ``tp_size > 1``).

        Returns:
            ``(q, k, v)``, each shaped
            ``[s, b, num_attention_heads_per_partition, head_dim]``.
        """
        # [s, b, 3 * head * head_dim]
        seq_len, bs, _ = qkv.shape
        if self.tp_size > 1:
            # Reassemble the full fused activation so q, k and v can be
            # separated before this rank's slice is taken below.
            qkv = tensor_model_parallel_all_gather(qkv)

        # [s, b, 3 * head * head_dim] -> 3 * [s, b, head * head_dim]
        q, k, v = qkv.chunk(3, dim=2)

        # 3 * [s, b, head * head_dim]
        if self.tp_size > 1:
            splitter = partial(
                dist_utils.split_tensor_along_last_dim, num_partitions=self.tp_size
            )
            q = splitter(q)[self.tp_rank]
            k = splitter(k)[self.tp_rank]
            v = splitter(v)[self.tp_rank]

        # 3 * [s, b, head * head_dim] -> 3 * [s, b, head, head_dim]
        new_shape = (
            seq_len,
            bs,
            self.num_attention_heads_per_partition,
            self.hidden_size_per_attention_head,
        )
        q, k, v = (x.view(*new_shape) for x in (q, k, v))
        return q, k, v

    def forward(
        self,
        x: torch.Tensor,
        cu_seqlens: torch.Tensor,
        rotary_pos_emb_cos: torch.Tensor,
        rotary_pos_emb_sin: torch.Tensor,
        max_seqlen: torch.Tensor | None = None,  # Only used for Flash Attention
    ) -> torch.Tensor:
        """Run attention over the packed sequence ``x``.

        Args:
            x: Hidden states shaped ``[s, b, embed_dim]``.
            cu_seqlens: Cumulative segment boundaries; attention is computed
                within each segment.
            rotary_pos_emb_cos: Rotary cosine table, one row per position.
            rotary_pos_emb_sin: Rotary sine table, one row per position.
            max_seqlen: Longest segment length. In this file it is produced by
                ``Qwen2VisionTransformer.compute_attn_mask_seqlen`` as a 0-d
                tensor, or ``None`` for backends that do not use it.

        Returns:
            Output shaped ``[s, b, embed_dim]``.
        """
        # [s, b, c] --> [s, b, 3 * head * head_dim]
        x, _ = self.qkv(x)

        # [s, b, 3 * head * head_dim] -> 3 * [s, b, head, head_dim]
        q, k, v = self.split_qkv(x)

        q, k, v = (rearrange(x, "s b ... -> b s ...") for x in (q, k, v))

        # Rotate q and k in a single call by stacking them along the batch dim.
        # [2 * b, s, heads, head_dim]
        qk_concat = torch.cat([q, k], dim=0)
        qk_rotated = self.apply_rotary_emb(
            qk_concat,
            rotary_pos_emb_cos,
            rotary_pos_emb_sin,
        )
        q, k = torch.chunk(qk_rotated, 2, dim=0)

        context_layer = self.attn(
            query=q,
            key=k,
            value=v,
            cu_seqlens=cu_seqlens,
            max_seqlen=max_seqlen,
        )

        context_layer = rearrange(context_layer, "b s h d -> s b (h d)").contiguous()

        output, _ = self.proj(context_layer)
        return output


class Qwen2VisionBlock(nn.Module):
    """Pre-norm vision transformer block.

    Computes ``x = x + attn(norm1(x))`` and then ``x = x + mlp(norm2(x))``.
    """

    def __init__(
        self,
        dim: int,
        num_heads: int,
        mlp_ratio: float,
        act_layer: type[nn.Module] = QuickGELU,
        norm_layer: Callable[[int], nn.Module] | None = None,
        quant_config: QuantizationConfig | None = None,
        prefix: str = "",
    ) -> None:
        super().__init__()
        if norm_layer is None:
            norm_layer = partial(nn.LayerNorm, eps=1e-6)
        self.norm1 = norm_layer(dim)
        self.norm2 = norm_layer(dim)
        mlp_hidden_dim = int(dim * mlp_ratio)

        self.attn = Qwen2VisionAttention(
            embed_dim=dim,
            num_heads=num_heads,
            projection_size=dim,
            quant_config=quant_config,
            prefix=f"{prefix}.attn",
        )
        self.mlp = Qwen2VisionMLP(
            dim,
            mlp_hidden_dim,
            act_layer=act_layer,
            quant_config=quant_config,
            prefix=f"{prefix}.mlp",
        )

    def forward(
        self,
        x: torch.Tensor,
        cu_seqlens: torch.Tensor,
        rotary_pos_emb_cos: torch.Tensor,
        rotary_pos_emb_sin: torch.Tensor,
        max_seqlen: torch.Tensor | None = None,  # Only used for Flash Attention
    ) -> torch.Tensor:
        """Apply the attention and MLP residual sub-layers.

        ``max_seqlen`` is forwarded unchanged to the attention layer. In this
        file it is the 0-d tensor returned by
        ``Qwen2VisionTransformer.compute_attn_mask_seqlen``, or ``None``.
        """
        x = x + self.attn(
            self.norm1(x),
            cu_seqlens=cu_seqlens,
            rotary_pos_emb_cos=rotary_pos_emb_cos,
            rotary_pos_emb_sin=rotary_pos_emb_sin,
            max_seqlen=max_seqlen,
        )

        x = x + self.mlp(self.norm2(x))
        return x


class Qwen2VisionPatchEmbed(nn.Module):
    """Embed flattened spatio-temporal patches with a strided 3D convolution.

    Every input row is one patch flattened from
    ``(in_channels, temporal_patch_size, patch_size, patch_size)``. Because the
    kernel equals the stride, each patch maps to exactly one ``embed_dim``
    vector.
    """

    def __init__(
        self,
        patch_size: int = 14,
        temporal_patch_size: int = 2,
        in_channels: int = 3,
        embed_dim: int = 1152,
    ) -> None:
        super().__init__()
        self.patch_size = patch_size
        self.temporal_patch_size = temporal_patch_size
        self.embed_dim = embed_dim

        # Kernel == stride: non-overlapping patches, one output voxel each.
        patch_kernel = (temporal_patch_size, patch_size, patch_size)
        self.proj = Conv3dLayer(
            in_channels,
            embed_dim,
            kernel_size=patch_kernel,
            stride=patch_kernel,
            bias=False,
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # Unpacking enforces a 2-D (num_patches, flattened_patch) input.
        num_patches, _ = x.shape
        patches = x.view(
            num_patches,
            -1,
            self.temporal_patch_size,
            self.patch_size,
            self.patch_size,
        )
        return self.proj(patches).view(num_patches, self.embed_dim)


class Qwen2VisionPatchMerger(nn.Module):
    """Fuse groups of patch embeddings and project them to ``d_model``.

    After layer norm, every run of ``spatial_merge_size**2`` consecutive
    embeddings of width ``context_dim`` is concatenated into one vector. That
    vector is passed through a Linear -> GELU -> Linear MLP.
    """

    def __init__(
        self,
        d_model: int,
        context_dim: int,
        norm_layer: Callable[[int], nn.Module] | None = None,
        spatial_merge_size: int = 2,
        quant_config: QuantizationConfig | None = None,
        prefix: str = "",
    ) -> None:
        super().__init__()
        dp_mode = is_vit_use_data_parallel()
        self.hidden_size = context_dim * (spatial_merge_size**2)

        make_norm = (
            norm_layer
            if norm_layer is not None
            else partial(nn.LayerNorm, eps=1e-6)
        )
        self.ln_q = make_norm(context_dim)

        up_proj = ColumnParallelLinear(
            self.hidden_size,
            self.hidden_size,
            bias=True,
            quant_config=quant_config,
            prefix=f"{prefix}.mlp.0",
            disable_tp=dp_mode,
        )
        activation = nn.GELU()
        down_proj = RowParallelLinear(
            self.hidden_size,
            d_model,
            bias=True,
            quant_config=quant_config,
            prefix=f"{prefix}.mlp.2",
            disable_tp=dp_mode,
        )
        # Index positions 0/1/2 match the "mlp.0" / "mlp.2" weight prefixes.
        self.mlp = nn.ModuleList([up_proj, activation, down_proj])

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        grouped = self.ln_q(x).view(-1, self.hidden_size)

        up_proj, activation, down_proj = self.mlp
        hidden, _ = up_proj(grouped)
        out, _ = down_proj(activation(hidden))
        return out


class Qwen2VisionTransformer(nn.Module):
    """Qwen2-VL vision tower.

    Pipeline: patch embedding -> ``depth`` vision blocks with 2D rotary
    position embeddings -> patch merger that projects to
    ``vision_config.hidden_size``.
    """

    def __init__(
        self,
        vision_config: Qwen2VLVisionConfig,
        norm_eps: float = 1e-6,
        quant_config: QuantizationConfig | None = None,
        prefix: str = "",
    ) -> None:
        super().__init__()

        patch_size = vision_config.patch_size
        temporal_patch_size = vision_config.temporal_patch_size
        spatial_merge_size = vision_config.spatial_merge_size
        in_channels = vision_config.in_channels
        hidden_size = vision_config.hidden_size
        embed_dim = vision_config.embed_dim
        depth = vision_config.depth
        num_heads = vision_config.num_heads
        mlp_ratio = vision_config.mlp_ratio

        self.use_data_parallel = is_vit_use_data_parallel()
        self.out_hidden_size = vision_config.hidden_size

        self.spatial_merge_size = spatial_merge_size
        self.num_heads = num_heads
        self.embed_dim = embed_dim

        self.patch_embed = Qwen2VisionPatchEmbed(
            patch_size=patch_size,
            temporal_patch_size=temporal_patch_size,
            in_channels=in_channels,
            embed_dim=embed_dim,
        )

        norm_layer = partial(nn.LayerNorm, eps=norm_eps)
        head_dim = embed_dim // num_heads
        # Only half of each head's dims are rotary (partial_rotary_factor=0.5).
        # rot_pos_emb concatenates the row and column lookups to fill the
        # rotary span.
        self.rotary_pos_emb = get_rope(
            head_size=head_dim,
            max_position=8192,
            is_neox_style=True,
            rope_parameters={"partial_rotary_factor": 0.5},
        )

        self.blocks = nn.ModuleList(
            [
                Qwen2VisionBlock(
                    dim=embed_dim,
                    num_heads=num_heads,
                    mlp_ratio=mlp_ratio,
                    norm_layer=norm_layer,
                    quant_config=quant_config,
                    prefix=f"{prefix}.blocks.{layer_idx}",
                )
                for layer_idx in range(depth)
            ]
        )
        self.merger = Qwen2VisionPatchMerger(
            d_model=hidden_size,
            context_dim=embed_dim,
            norm_layer=norm_layer,
            quant_config=quant_config,
            prefix=f"{prefix}.merger",
        )
        self.attn_backend = get_vit_attn_backend(
            head_size=head_dim,
            dtype=torch.get_default_dtype(),
        )

    @property
    def dtype(self) -> torch.dtype:
        """Dtype of the patch-embedding weight (used to cast inputs)."""
        return self.patch_embed.proj.weight.dtype

    @property
    def device(self) -> torch.device:
        """Device of the patch-embedding weight (used to place inputs)."""
        return self.patch_embed.proj.weight.device

    def rot_pos_emb(
        self, grid_thw: list[list[int]]
    ) -> tuple[torch.Tensor, torch.Tensor]:
        """Look up rotary cos/sin values for every patch.

        For each ``(t, h, w)`` item, (row, col) ids of the ``h x w`` grid are
        reordered so that each ``spatial_merge_size x spatial_merge_size``
        window is contiguous, then repeated for the ``t`` frames. The row and
        column lookups are concatenated along the last dimension.

        Args:
            grid_thw: One ``[t, h, w]`` triple per image/video.

        Returns:
            ``(cos, sin)``, each with one row per patch across all items.
        """
        pos_ids = []
        max_grid_size = 0
        for t, h, w in grid_thw:
            hpos_ids = torch.arange(h).unsqueeze(1).expand(-1, w)
            wpos_ids = torch.arange(w).unsqueeze(0).expand(h, -1)
            hpos_ids = (
                hpos_ids.reshape(
                    h // self.spatial_merge_size,
                    self.spatial_merge_size,
                    w // self.spatial_merge_size,
                    self.spatial_merge_size,
                )
                .permute(0, 2, 1, 3)
                .flatten()
            )
            wpos_ids = (
                wpos_ids.reshape(
                    h // self.spatial_merge_size,
                    self.spatial_merge_size,
                    w // self.spatial_merge_size,
                    self.spatial_merge_size,
                )
                .permute(0, 2, 1, 3)
                .flatten()
            )
            pos_ids.append(torch.stack([hpos_ids, wpos_ids], dim=-1).repeat(t, 1))
            max_grid_size = max(max_grid_size, h, w)
        pos_ids = torch.cat(pos_ids, dim=0)

        # Use pre-computed cos_sin_cache from RotaryEmbedding
        cos, sin = self.rotary_pos_emb.get_cos_sin(max_grid_size)

        cos_combined = cos[pos_ids].flatten(1)
        sin_combined = sin[pos_ids].flatten(1)
        return cos_combined, sin_combined

    def compute_attn_mask_seqlen(
        self, cu_seqlens: torch.Tensor
    ) -> torch.Tensor | None:
        """Return the longest segment length described by ``cu_seqlens``.

        Computed only for the FLASH_ATTN and ROCM_AITER_FA backends. The value
        is the 0-d tensor produced by ``Tensor.max()``, not a Python int.
        Other backends get ``None``.
        """
        max_seqlen = None
        if self.attn_backend in {
            AttentionBackendEnum.FLASH_ATTN,
            AttentionBackendEnum.ROCM_AITER_FA,
        }:
            max_seqlen = (cu_seqlens[1:] - cu_seqlens[:-1]).max()
        return max_seqlen

    def forward(
        self,
        x: torch.Tensor,
        grid_thw: torch.Tensor | list[list[int]],
    ) -> torch.Tensor:
        """Encode flattened pixel patches into merged vision embeddings.

        Args:
            x: Flattened patches, one row per patch, in the layout expected by
                ``Qwen2VisionPatchEmbed``.
            grid_thw: ``[t, h, w]`` per item, as a list or a tensor. A tensor
                must live on CPU because it is converted with ``.numpy()``.

        Returns:
            Merged embeddings produced by the patch merger.
        """
        # patchify
        x = x.to(device=self.device, dtype=self.dtype)
        x = self.patch_embed(x)

        if isinstance(grid_thw, list):
            grid_thw_list = grid_thw
            grid_thw = np.array(grid_thw, dtype=np.int32)
        else:
            grid_thw_list = grid_thw.tolist()
            grid_thw = grid_thw.numpy()

        # compute position embedding
        rotary_pos_emb_cos, rotary_pos_emb_sin = self.rot_pos_emb(grid_thw_list)

        # compute cu_seqlens: one segment of h * w patches per temporal frame
        cu_seqlens = np.repeat(grid_thw[:, 1] * grid_thw[:, 2], grid_thw[:, 0]).cumsum(
            axis=0, dtype=np.int32
        )
        cu_seqlens = np.concatenate([np.zeros(1, dtype=np.int32), cu_seqlens])
        cu_seqlens = torch.from_numpy(cu_seqlens)

        # transformers
        x = x.unsqueeze(1)

        # pre-compute seqlens for attn mask to reduce cuMemcpy operations
        # (computed while cu_seqlens is still a CPU tensor)
        max_seqlen = self.compute_attn_mask_seqlen(cu_seqlens)
        cu_seqlens = cu_seqlens.to(self.device, non_blocking=True)
        for blk in self.blocks:
            x = blk(
                x,
                cu_seqlens=cu_seqlens,
                rotary_pos_emb_cos=rotary_pos_emb_cos,
                rotary_pos_emb_sin=rotary_pos_emb_sin,
                max_seqlen=max_seqlen,
            )

        # adapter
        x = self.merger(x)

        return x

    def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]:
        """Load checkpoint tensors into this module's parameters.

        Returns:
            The set of parameter names that were loaded.
        """
        # NOTE(review): attention parameters in this tower are named `qkv`,
        # not `qkv_proj`. This mapping only triggers for checkpoint names
        # containing q_proj/k_proj/v_proj; confirm whether it is ever used.
        stacked_params_mapping = [
            # (param_name, shard_name, shard_id)
            ("qkv_proj", "q_proj", "q"),
            ("qkv_proj", "k_proj", "k"),
            ("qkv_proj", "v_proj", "v"),
        ]
        params_dict = dict(self.named_parameters(remove_duplicate=False))
        loaded_params: set[str] = set()

        for name, loaded_weight in weights:
            for param_name, weight_name, shard_id in stacked_params_mapping:
                if weight_name not in name:
                    continue
                name = name.replace(weight_name, param_name)

                param = params_dict[name]
                weight_loader = param.weight_loader
                weight_loader(param, loaded_weight, shard_id)
                break
            else:
                param = params_dict[name]
                weight_loader = getattr(param, "weight_loader", default_weight_loader)
                weight_loader(param, loaded_weight)
            loaded_params.add(name)
        return loaded_params

721

722
def _create_qwen2vl_field_factory(
    spatial_merge_size: int,
) -> Callable[
    [Mapping[str, torch.Tensor]],
    Mapping[str, MultiModalFieldConfig],
]:
    """Build the field-config callback for Qwen2-VL multimodal inputs.

    The callback splits the flat ``pixel_values``/``pixel_values_videos``
    tensors per item by ``t * h * w``. It splits the ``*_embeds`` tensors by
    ``t * h * w`` floor-divided by ``spatial_merge_size`` twice. The
    ``*_grid_thw`` tensors are batched per item and kept on CPU.
    """

    def _merged_token_counts(patch_counts: torch.Tensor) -> torch.Tensor:
        # Patches per item -> embeddings per item after spatial merging.
        return patch_counts // spatial_merge_size // spatial_merge_size

    def _qwen2vl_field_config(hf_inputs: Mapping[str, torch.Tensor]):
        image_patch_counts = hf_inputs.get(
            "image_grid_thw", torch.empty((0, 3))
        ).prod(-1)
        video_patch_counts = hf_inputs.get(
            "video_grid_thw", torch.empty((0, 3))
        ).prod(-1)
        image_token_counts = _merged_token_counts(image_patch_counts)
        video_token_counts = _merged_token_counts(video_patch_counts)

        flat = MultiModalFieldConfig.flat_from_sizes
        return {
            "pixel_values": flat("image", image_patch_counts),
            "image_embeds": flat("image", image_token_counts),
            "image_grid_thw": MultiModalFieldConfig.batched(
                "image", keep_on_cpu=True
            ),
            "pixel_values_videos": flat("video", video_patch_counts),
            "video_embeds": flat("video", video_token_counts),
            "video_grid_thw": MultiModalFieldConfig.batched(
                "video", keep_on_cpu=True
            ),
        }

    return _qwen2vl_field_config
759

760

Roger Wang's avatar
Roger Wang committed
761
class Qwen2VLMultiModalDataParser(MultiModalDataParser):
    """Data parser that also accepts precomputed embeddings passed as dicts.

    A dict given for the image or video modality is wrapped in
    ``DictEmbeddingItems``, which requires the ``*_embeds`` and ``*_grid_thw``
    keys. Any other data is delegated to the base parser.
    """

    def __init__(self, spatial_merge_size: int, *args, **kwargs):
        self._spatial_merge_size = spatial_merge_size
        super().__init__(*args, **kwargs)

    def _wrap_embedding_dict(
        self,
        data: dict[str, torch.Tensor],
        modality: str,
        required_fields: set[str],
    ) -> DictEmbeddingItems:
        # Shared wrapper for the image and video dict-embedding paths.
        return DictEmbeddingItems(
            data,
            modality=modality,
            required_fields=required_fields,
            fields_factory=_create_qwen2vl_field_factory(self._spatial_merge_size),
        )

    def _parse_image_data(
        self,
        data: dict[str, torch.Tensor] | ModalityData[ImageItem],
    ) -> ModalityDataItems[Any, Any] | None:
        if not isinstance(data, dict):
            return super()._parse_image_data(data)
        return self._wrap_embedding_dict(
            data, "image", {"image_embeds", "image_grid_thw"}
        )

    def _parse_video_data(
        self,
        data: dict[str, torch.Tensor] | ModalityData[VideoItem],
    ) -> ModalityDataItems[Any, Any] | None:
        if not isinstance(data, dict):
            return super()._parse_video_data(data)
        return self._wrap_embedding_dict(
            data, "video", {"video_embeds", "video_grid_thw"}
        )


795
796
class Qwen2VLProcessingInfo(BaseProcessingInfo):
    """Config/processor accessors and token-budget arithmetic for Qwen2-VL."""

    def get_hf_config(self):
        """Return the model's ``Qwen2VLConfig`` from the processing context."""
        return self.ctx.get_hf_config(Qwen2VLConfig)

    def get_hf_processor(self, **kwargs: object) -> Qwen2VLProcessor:
        """Return the HF processor; the fast variant is used unless overridden."""
        # Popped first so it is not forwarded twice via **kwargs.
        use_fast = kwargs.pop("use_fast", True)
        return self.ctx.get_hf_processor(
            Qwen2VLProcessor,
            use_fast=use_fast,
            **kwargs,
        )

    def get_image_processor(self, **kwargs: object) -> Qwen2VLImageProcessor:
        """Return the image processor owned by the HF processor."""
        hf_processor = self.get_hf_processor(**kwargs)
        return hf_processor.image_processor

    def get_data_parser(self):
        """Build the data parser, wiring in the vision spatial merge size."""
        merge_size = self.get_hf_config().vision_config.spatial_merge_size
        return Qwen2VLMultiModalDataParser(
            merge_size,
            expected_hidden_size=self._get_expected_hidden_size(),
        )

    def get_supported_mm_limits(self) -> Mapping[str, int | None]:
        """Images and videos are both supported with no per-prompt limit."""
        return {"image": None, "video": None}

    def get_mm_max_tokens_per_item(
        self,
        seq_len: int,
        mm_counts: Mapping[str, int],
    ) -> Mapping[str, int]:
        """Return the maximum placeholder-token count for one item per modality."""
        return {
            "image": self.get_max_image_tokens(),
            "video": self.get_max_video_tokens(seq_len, mm_counts),
        }

    def _get_vision_info(
        self,
        *,
        image_width: int,
        image_height: int,
        num_frames: int = 1,
        do_resize: bool = True,
        image_processor: Qwen2VLImageProcessor | None,
    ) -> tuple[ImageSize, int]:
        """Compute the post-resize size and the merged vision-token count.

        When ``do_resize`` is set, the size is snapped with ``smart_resize``
        using the processor's ``shortest_edge``/``longest_edge`` pixel bounds.
        The token count is ``grid_t * grid_h * grid_w // merge_size**2``.
        """
        if image_processor is None:
            image_processor = self.get_image_processor()

        vision_config = self.get_hf_config().vision_config
        patch_size = vision_config.patch_size
        merge_size = vision_config.spatial_merge_size
        temporal_patch_size = vision_config.temporal_patch_size

        if do_resize:
            height, width = smart_resize(
                height=image_height,
                width=image_width,
                factor=patch_size * merge_size,
                min_pixels=image_processor.size["shortest_edge"],
                max_pixels=image_processor.size["longest_edge"],
            )
        else:
            height, width = image_height, image_width
        preprocessed_size = ImageSize(width=width, height=height)

        # Frames are padded toward a multiple of `temporal_patch_size`, see
        # https://github.com/huggingface/transformers/blob/v4.48.3/src/transformers/models/qwen2_vl/image_processing_qwen2_vl.py#L294
        # NOTE(review): `num_frames % temporal_patch_size` yields a multiple
        # only when temporal_patch_size == 2; confirm if other sizes are used.
        padded_num_frames = num_frames + num_frames % temporal_patch_size
        grid_t = max(padded_num_frames // temporal_patch_size, 1)
        grid_h = height // patch_size
        grid_w = width // patch_size

        num_vision_tokens = (grid_t * grid_h * grid_w) // (merge_size**2)
        return preprocessed_size, num_vision_tokens

    def get_num_image_tokens(
        self,
        *,
        image_width: int,
        image_height: int,
        image_processor: Qwen2VLImageProcessor | None,
    ) -> int:
        """Number of placeholder tokens for a single image (one frame)."""
        vision_info = self._get_vision_info(
            image_width=image_width,
            image_height=image_height,
            num_frames=1,
            image_processor=image_processor,
        )
        return vision_info[1]

    def get_num_video_tokens(
        self,
        *,
        image_width: int,
        image_height: int,
        num_frames: int,
        image_processor: Qwen2VLImageProcessor | None,
    ) -> int:
        """Number of placeholder tokens for a video of ``num_frames`` frames."""
        vision_info = self._get_vision_info(
            image_width=image_width,
            image_height=image_height,
            num_frames=num_frames,
            image_processor=image_processor,
        )
        return vision_info[1]

    def get_image_size_with_most_features(
        self, max_pixels: int | None = None
    ) -> ImageSize:
        """Return an image size whose merged-patch count is as large as allowed.

        Feeding a huge image through ``_get_vision_info`` does not necessarily
        maximize the number of (merged) patches, because the patch budget
        constrains which aspect ratios are reachable. For example, with a
        budget of 1280 patches a square image cannot be split into exactly
        1280 patches. Instead, the patch budget is factorized directly into
        height x width. Factorizations with an aspect ratio above 200 (the
        limit for qwen2-vl) are rejected, and the budget is reduced until a
        valid one is found. This is safe because the processor rejects such
        extreme aspect ratios anyway.
        """
        vision_config = self.get_hf_config().vision_config
        patch_size = vision_config.patch_size
        merge_size = vision_config.spatial_merge_size

        if max_pixels is None:
            max_pixels = self.get_image_processor().size["longest_edge"]

        # One merged patch covers a (unit x unit) pixel square.
        unit = patch_size * merge_size
        max_seq_len = max_pixels // (unit * unit)

        def closest_factor_pair(n: int) -> tuple[int, int]:
            """Return (small, large) factors of n that are closest together."""
            divisor = next(
                (d for d in range(math.isqrt(n), 0, -1) if n % d == 0), 1
            )
            return divisor, n // divisor

        factor_pairs = (
            closest_factor_pair(count) for count in range(max_seq_len, 0, -1)
        )
        height_factor, width_factor = next(
            (
                (short, long)
                for short, long in factor_pairs
                if long / short <= 200
            ),
            (1, max_seq_len),
        )

        return ImageSize(width=unit * width_factor, height=unit * height_factor)

    def get_max_image_tokens(self) -> int:
        """Token count of the largest-feature image."""
        width, height = self.get_image_size_with_most_features()
        return self.get_num_image_tokens(
            image_width=width,
            image_height=height,
            image_processor=None,
        )

    def _get_max_video_frames(self, max_tokens: int, start_num_frames: int = 1) -> int:
        """Largest frame count (>= start) whose video fits in ``max_tokens``.

        Frames are added one at a time while the token count for one more
        frame stays within the budget.
        """
        width, height = self.get_image_size_with_most_features()

        frames = start_num_frames
        while (
            self.get_num_video_tokens(
                image_width=width,
                image_height=height,
                num_frames=frames + 1,
                image_processor=None,
            )
            <= max_tokens
        ):
            frames += 1

        return frames

    def get_num_frames_with_most_features(
        self,
        seq_len: int,
        mm_counts: Mapping[str, int],
        max_frames_per_video: int = _MAX_FRAMES_PER_VIDEO,
    ) -> int:
        """Frames per video when ``seq_len`` is shared among all videos.

        The result is capped at ``max_frames_per_video`` and is at least 1.
        """
        video_count = max(mm_counts.get("video", 0), 1)
        total_frames = self._get_max_video_frames(seq_len)

        per_video = min(total_frames // video_count, max_frames_per_video)
        return max(per_video, 1)

    def get_max_video_tokens(
        self,
        seq_len: int,
        mm_counts: Mapping[str, int],
    ) -> int:
        """Token count of the largest-feature video for the given budget."""
        width, height = self.get_image_size_with_most_features()
        frames = self.get_num_frames_with_most_features(seq_len, mm_counts)
        return self.get_num_video_tokens(
            image_width=width,
            image_height=height,
            num_frames=frames,
            image_processor=None,
        )

1001
1002

class Qwen2VLDummyInputsBuilder(BaseDummyInputsBuilder[Qwen2VLProcessingInfo]):
    """Builds placeholder prompts and media for Qwen2-VL profiling."""

    def get_dummy_text(self, mm_counts: Mapping[str, int]) -> str:
        """Return one image/video placeholder token per requested item."""
        image_count = mm_counts.get("image", 0)
        video_count = mm_counts.get("video", 0)

        processor = self.info.get_hf_processor()
        image_token: str = processor.image_token
        video_token: str = processor.video_token

        image_part = image_token * image_count
        video_part = video_token * video_count
        return image_part + video_part

    def get_dummy_mm_data(
        self,
        seq_len: int,
        mm_counts: Mapping[str, int],
        mm_options: Mapping[str, BaseDummyOptions] | None = None,
    ) -> MultiModalDataDict:
        """Return dummy images/videos at the largest-feature size and length."""
        image_count = mm_counts.get("image", 0)
        video_count = mm_counts.get("video", 0)

        width, height = self.info.get_image_size_with_most_features()
        frame_count = self.info.get_num_frames_with_most_features(seq_len, mm_counts)

        if mm_options:
            image_overrides = mm_options.get("image")
            video_overrides = mm_options.get("video")
        else:
            image_overrides = video_overrides = None

        dummy_images = self._get_dummy_images(
            width=width,
            height=height,
            num_images=image_count,
            overrides=image_overrides,
        )
        dummy_videos = self._get_dummy_videos(
            width=width,
            height=height,
            num_frames=frame_count,
            num_videos=video_count,
            overrides=video_overrides,
        )
        return {"image": dummy_images, "video": dummy_videos}

1046

1047
class Qwen2VLMultiModalProcessor(BaseMultiModalProcessor[Qwen2VLProcessingInfo]):
    """Multimodal processor that expands Qwen2-VL image/video placeholders."""

    def _get_prompt_updates(
        self,
        mm_items: MultiModalDataItems,
        hf_processor_mm_kwargs: Mapping[str, Any],
        out_mm_kwargs: MultiModalKwargsItems,
    ) -> Sequence[PromptUpdate]:
        """Replace each single placeholder token with one token per merged patch.

        The per-item count is ``prod(grid_thw) // merge_size**2``, read from
        the processed ``{modality}_grid_thw`` tensor of that item.
        """
        hf_processor = self.info.get_hf_processor(**hf_processor_mm_kwargs)
        image_processor = self.info.get_image_processor(**hf_processor_mm_kwargs)
        vocab = self.info.get_tokenizer().get_vocab()

        token_id_for = {
            "image": vocab[hf_processor.image_token],
            "video": vocab[hf_processor.video_token],
        }
        merge_area = image_processor.merge_size**2

        def get_replacement_qwen2vl(item_idx: int, modality: str):
            item_kwargs = out_mm_kwargs[modality][item_idx]
            grid_thw = item_kwargs[f"{modality}_grid_thw"].data
            assert isinstance(grid_thw, torch.Tensor)

            token_count = int(grid_thw.prod()) // merge_area
            return [token_id_for[modality]] * token_count

        updates = []
        for modality in ("image", "video"):
            updates.append(
                PromptReplacement(
                    modality=modality,
                    target=[token_id_for[modality]],
                    replacement=partial(get_replacement_qwen2vl, modality=modality),
                )
            )
        return updates

    def _get_mm_fields_config(
        self,
        hf_inputs: BatchFeature,
        hf_processor_mm_kwargs: Mapping[str, object],
    ) -> Mapping[str, MultiModalFieldConfig]:
        """Build field configs using the model's vision spatial merge size."""
        merge_size = self.info.get_hf_config().vision_config.spatial_merge_size
        field_factory = _create_qwen2vl_field_factory(merge_size)
        return field_factory(hf_inputs)


@MULTIMODAL_REGISTRY.register_processor(
    Qwen2VLMultiModalProcessor,
    info=Qwen2VLProcessingInfo,
    dummy_inputs=Qwen2VLDummyInputsBuilder,
)
class Qwen2VLForConditionalGeneration(
    nn.Module, SupportsMultiModal, SupportsLoRA, SupportsPP, SupportsMRoPE
):
    """Qwen2-VL: a Qwen2 causal LM combined with a Qwen2 vision transformer."""

    # Checkpoint-name prefix remapping applied in `load_weights`.
    hf_to_vllm_mapper = WeightsMapper(
        orig_to_new_prefix={
            # Layout of checkpoints saved after transformers v4.52.
            "model.language_model.": "language_model.model.",
            "model.visual.": "visual.",
            # Layout of the original checkpoints.
            "lm_head.": "language_model.lm_head.",
            "model.": "language_model.model.",
        }
    )

    supports_encoder_tp_data = True

    def iter_mm_grid_thw(
        self, mm_features: list[MultiModalFeatureSpec]
    ) -> Iterator[tuple[int, int, int, int, float]]:
        """Yield grid information for each multimodal item, ordered by offset.

        Args:
            mm_features: Multimodal feature specifications.

        Yields:
            ``(offset, grid_t, grid_h, grid_w, t_factor)``, where the spatial
            grid sizes are already divided by the spatial merge size. Images
            always report ``grid_t == 1`` and ``t_factor == 1.0``. For videos,
            ``t_factor`` is ``second_per_grid_ts * tokens_per_second``, and
            each of the two defaults to 1.0 when absent.

        Raises:
            ValueError: for a modality other than image or video.
        """
        vision_config = self.config.vision_config
        merge = vision_config.spatial_merge_size
        tokens_per_second = getattr(vision_config, "tokens_per_second", 1.0)

        ordered = sorted(mm_features, key=lambda feat: feat.mm_position.offset)
        for feat in ordered:
            offset = feat.mm_position.offset

            if feat.modality == "image":
                t, h, w = feat.data["image_grid_thw"].data.tolist()
                assert t == 1, f"Image must have 1 frame, got {t}"
                yield offset, 1, h // merge, w // merge, 1.0
                continue

            if feat.modality != "video":
                raise ValueError(f"Unsupported modality: {feat.modality}")

            t, h, w = feat.data["video_grid_thw"].data.tolist()
            seconds_elem = feat.data.get("second_per_grid_ts", None)
            seconds_per_grid = seconds_elem.data.item() if seconds_elem else 1.0
            yield (
                offset,
                t,
                h // merge,
                w // merge,
                seconds_per_grid * tokens_per_second,
            )

    def get_mrope_input_positions(
        self,
        input_tokens: list[int],
        mm_features: list[MultiModalFeatureSpec],
    ) -> tuple[torch.Tensor, int]:
        """Build 3-row M-RoPE position ids for a prompt.

        Text spans get the same increasing positions on all three rows. Each
        multimodal item gets its (t, h, w) grid indices, shifted to start
        right after the preceding text. The temporal index is scaled by
        ``t_factor`` when that differs from 1.0.

        Returns:
            A ``(3, N)`` position tensor and the delta
            ``max_position + 1 - len(input_tokens)``.
        """
        segments: list = []
        cursor = 0

        def next_start():
            # Continue numbering after the largest position emitted so far.
            return segments[-1].max() + 1 if segments else 0

        for offset, grid_t, grid_h, grid_w, t_factor in self.iter_mm_grid_thw(
            mm_features
        ):
            text_len = offset - cursor
            base = next_start()
            segments.append(np.broadcast_to(np.arange(text_len), (3, text_len)) + base)

            grid = np.indices((grid_t, grid_h, grid_w))
            if t_factor != 1.0:
                grid[0] = (grid[0] * t_factor).astype(np.int64)
            segments.append(grid.reshape(3, -1) + text_len + base)

            cursor = offset + grid_t * grid_h * grid_w

        total_tokens = len(input_tokens)
        if cursor < total_tokens:
            base = next_start()
            text_len = total_tokens - cursor
            segments.append(np.broadcast_to(np.arange(text_len), (3, text_len)) + base)

        positions = np.concatenate(segments, axis=1).reshape(3, -1)
        delta = (positions.max() + 1 - total_tokens).item()

        return torch.from_numpy(positions), delta

    @classmethod
    def get_placeholder_str(cls, modality: str, i: int) -> str | None:
        """Return the chat placeholder for an image*/video* modality name."""
        placeholders = (
            ("image", "<|vision_start|><|image_pad|><|vision_end|>"),
            ("video", "<|vision_start|><|video_pad|><|vision_end|>"),
        )
        for prefix, placeholder in placeholders:
            if modality.startswith(prefix):
                return placeholder

        raise ValueError("Only image or video modality is supported")

    def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
        super().__init__()
        model_config = vllm_config.model_config
        config: Qwen2VLConfig = model_config.hf_config
        quant_config = vllm_config.quant_config
        multimodal_config = model_config.multimodal_config

        # "data" mode shards the vision encoder across ranks by input.
        self.use_data_parallel = multimodal_config.mm_encoder_tp_mode == "data"
        self.config = config
        self.multimodal_config = multimodal_config

        with self._mark_tower_model(vllm_config, {"image", "video"}):
            self.visual = Qwen2VisionTransformer(
                config.vision_config,
                norm_eps=getattr(config, "rms_norm_eps", 1e-6),
                quant_config=quant_config,
                prefix=maybe_prefix(prefix, "visual"),
            )

        with self._mark_language_model(vllm_config):
            self.language_model = init_vllm_registered_model(
                vllm_config=vllm_config,
                prefix=maybe_prefix(prefix, "language_model"),
                architectures=["Qwen2ForCausalLM"],
            )

        self.make_empty_intermediate_tensors = (
            self.language_model.make_empty_intermediate_tensors
        )

    def _parse_and_validate_image_input(
        self, **kwargs: object
    ) -> Qwen2VLImageInputs | None:
        """Wrap image kwargs; pixel values take precedence over embeddings."""
        pixel_values = kwargs.pop("pixel_values", None)
        image_embeds = kwargs.pop("image_embeds", None)
        image_grid_thw = kwargs.pop("image_grid_thw", None)

        if pixel_values is not None:
            return Qwen2VLImagePixelInputs(
                type="pixel_values",
                pixel_values=pixel_values,
                image_grid_thw=image_grid_thw,
            )
        if image_embeds is not None:
            return Qwen2VLImageEmbeddingInputs(
                type="image_embeds",
                image_embeds=image_embeds,
                image_grid_thw=image_grid_thw,
            )
        return None

    def _parse_and_validate_video_input(
        self, **kwargs: object
    ) -> Qwen2VLVideoInputs | None:
        """Wrap video kwargs; pixel values take precedence over embeddings."""
        pixel_values_videos = kwargs.pop("pixel_values_videos", None)
        video_embeds = kwargs.pop("video_embeds", None)
        video_grid_thw = kwargs.pop("video_grid_thw", None)

        if pixel_values_videos is not None:
            return Qwen2VLVideoPixelInputs(
                type="pixel_values_videos",
                pixel_values_videos=pixel_values_videos,
                video_grid_thw=video_grid_thw,
            )
        if video_embeds is not None:
            return Qwen2VLVideoEmbeddingInputs(
                type="video_embeds",
                video_embeds=video_embeds,
                video_grid_thw=video_grid_thw,
            )
        return None

    def _split_embeds_by_grid(
        self, embeds: torch.Tensor, grid_thw: torch.Tensor
    ) -> tuple[torch.Tensor, ...]:
        """Split concatenated embeddings into one chunk per item.

        Each item's row count is ``prod(grid_thw_row) // merge_size**2``.
        """
        merge = self.visual.spatial_merge_size
        sizes = (grid_thw.prod(-1) // merge // merge).tolist()
        return embeds.split(sizes)

    def _process_image_input(
        self, image_input: Qwen2VLImageInputs
    ) -> tuple[torch.Tensor, ...]:
        """Return per-image embeddings, running the vision tower if needed."""
        grid_thw = image_input["image_grid_thw"]
        assert grid_thw.ndim == 2

        if image_input["type"] == "image_embeds":
            embeds = image_input["image_embeds"]
        else:
            pixel_values = image_input["pixel_values"]
            if self.use_data_parallel:
                # The DP-sharded path returns its result without the split below.
                return run_dp_sharded_mrope_vision_model(
                    self.visual, pixel_values, grid_thw.tolist(), rope_type="rope_3d"
                )
            embeds = self.visual(pixel_values, grid_thw=grid_thw)

        return self._split_embeds_by_grid(embeds, grid_thw)

    def _process_video_input(
        self, video_input: Qwen2VLVideoInputs
    ) -> tuple[torch.Tensor, ...]:
        """Return per-video embeddings, running the vision tower if needed."""
        grid_thw = video_input["video_grid_thw"]
        assert grid_thw.ndim == 2

        if video_input["type"] == "video_embeds":
            embeds = video_input["video_embeds"]
        else:
            pixel_values_videos = video_input["pixel_values_videos"]
            if self.use_data_parallel:
                # The DP-sharded path returns its result without the split below.
                return run_dp_sharded_mrope_vision_model(
                    self.visual,
                    pixel_values_videos,
                    grid_thw.tolist(),
                    rope_type="rope_3d",
                )
            embeds = self.visual(pixel_values_videos, grid_thw=grid_thw)

        return self._split_embeds_by_grid(embeds, grid_thw)

    def _parse_and_validate_multimodal_inputs(self, **kwargs: object) -> dict:
        """Parse image and video inputs, in the order they appear in kwargs.

        The first matching kwarg of each modality triggers parsing of that
        modality. The resulting dict's key order therefore follows kwargs.
        """
        image_keys = ("pixel_values", "image_embeds")
        video_keys = ("pixel_values_videos", "video_embeds")

        parsed = {}
        for key in kwargs:
            if "images" not in parsed and key in image_keys:
                parsed["images"] = self._parse_and_validate_image_input(**kwargs)
            if "videos" not in parsed and key in video_keys:
                parsed["videos"] = self._parse_and_validate_video_input(**kwargs)

        return parsed

    def embed_multimodal(self, **kwargs: object) -> MultiModalEmbeddings:
        """Return a tuple with one embedding tensor per image/video item.

        Items are grouped by modality in the order the modalities were first
        seen in kwargs. Returns an empty list when no multimodal input is
        present.
        """
        parsed = self._parse_and_validate_multimodal_inputs(**kwargs)
        if not parsed:
            return []

        embeddings: tuple[torch.Tensor, ...] = ()
        for modality, mm_input in parsed.items():
            if modality == "images":
                embeddings += tuple(self._process_image_input(mm_input))
            if modality == "videos":
                embeddings += tuple(self._process_video_input(mm_input))

        return embeddings

    def forward(
        self,
        input_ids: torch.Tensor | None,
        positions: torch.Tensor,
        intermediate_tensors: IntermediateTensors | None = None,
        inputs_embeds: torch.Tensor | None = None,
        **kwargs: object,
    ) -> torch.Tensor | IntermediateTensors:
        """Run the Qwen2-VL language model forward pass.

        Args:
            input_ids: Flattened (concatenated) input ids of a batch.
            positions: Flattened (concatenated) position ids of a batch.
                **NOTE**: With mrope enabled (the default for the open-source
                Qwen2-VL models), the shape is `(3, seq_len)`; otherwise it is
                `(seq_len,)`.
            intermediate_tensors: Intermediate tensors from a prior pipeline
                stage. When given, `inputs_embeds` is ignored.
            inputs_embeds: Optional input embeddings.
        """
        if intermediate_tensors is not None:
            inputs_embeds = None

        return self.language_model.model(
            input_ids=input_ids,
            positions=positions,
            intermediate_tensors=intermediate_tensors,
            inputs_embeds=inputs_embeds,
        )

    def compute_logits(
        self,
        hidden_states: torch.Tensor,
    ) -> torch.Tensor | None:
        """Delegate logit computation to the language model."""
        return self.language_model.compute_logits(hidden_states)

    def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]:
        """Load checkpoint weights after applying `hf_to_vllm_mapper`."""
        return AutoWeightsLoader(self).load_weights(
            weights, mapper=self.hf_to_vllm_mapper
        )

    def get_mm_mapping(self) -> MultiModelKeys:
        """
        Get the module prefix in multimodal models
        """
        return MultiModelKeys.from_string_field(
            language_model="language_model",
            connector="visual.merger.",
            tower_model="visual.",
        )

    def get_num_mm_encoder_tokens(
        self,
        num_image_tokens: int,
    ) -> int:
        """Encoder patches for a number of merged image tokens (x merge**2)."""
        merge = self.config.vision_config.spatial_merge_size
        return num_image_tokens * (merge * merge)

    def get_num_mm_connector_tokens(
        self,
        num_vision_tokens: int,
    ) -> int:
        """Merged tokens produced from encoder patches (// merge**2)."""
        merge = self.config.vision_config.spatial_merge_size
        return num_vision_tokens // (merge * merge)

1440
1441
1442
1443
1444
1445
1446
1447

class Tarsier2MultiModalProcessor(Qwen2VLMultiModalProcessor):
    """Tarsier2 reuses the Qwen2-VL multimodal processor unchanged."""


class Tarsier2ImageProcessor(Qwen2VLImageProcessor):
    """Qwen2-VL image processor that also accepts Tarsier2-style ``size``.

    Tarsier2 configs express ``size`` as ``{"min_pixels", "max_pixels"}``.
    When both keys are present, they are translated to the
    ``{"shortest_edge", "longest_edge"}`` layout before delegating. Any other
    ``size`` is passed through unchanged.
    """

    def __init__(
        self,
        size: dict[str, int] | None = None,
        **kwargs,
    ) -> None:
        has_tarsier2_keys = (
            size is not None and "min_pixels" in size and "max_pixels" in size
        )
        if has_tarsier2_keys:
            size = {
                "shortest_edge": size["min_pixels"],
                "longest_edge": size["max_pixels"],
            }
        super().__init__(size=size, **kwargs)


class Tarsier2Processor(Qwen2VLProcessor):
    """Qwen2-VL processor assembled from a Tarsier2 image-processor config.

    The same ``vision_config`` dict configures both the image processor and
    the video processor. No chat template is attached.
    """

    def __init__(
        self,
        vision_config: dict,
        tokenizer: TokenizerLike,
        **kwargs,
    ):
        # The image processor is attached before the base initializer runs,
        # and that attribute is what gets handed to it.
        self.image_processor = Tarsier2ImageProcessor(**vision_config)
        video_processor = Qwen2VLVideoProcessor(**vision_config)
        super().__init__(
            image_processor=self.image_processor,
            tokenizer=tokenizer,
            video_processor=video_processor,
            chat_template=None,
            **kwargs,
        )
1477
1478
1479
1480
1481


class Tarsier2ProcessingInfo(Qwen2VLProcessingInfo):
    """Processing info for Tarsier2 checkpoints (Qwen2-VL architecture)."""

    def get_hf_config(self) -> Qwen2VLConfig:
        """Load a ``Qwen2VLConfig`` directly from the model path.

        The config is re-read with ``Qwen2VLConfig.from_pretrained`` instead
        of using the context's resolved HF config.
        """
        model_path = self.ctx.model_config.model
        correct_config = Qwen2VLConfig.from_pretrained(model_path)

        return correct_config

    def get_hf_processor(self, **kwargs: object) -> Tarsier2Processor:
        """Build a ``Tarsier2Processor`` from the image-processor config."""
        return Tarsier2Processor(
            vision_config=self.ctx.get_hf_image_processor_config(),
            tokenizer=self.get_tokenizer(),
            **kwargs,
        )

    def get_image_processor(self, **kwargs: object) -> Tarsier2ImageProcessor:
        """Build a ``Tarsier2ImageProcessor`` from the image-processor config.

        Keyword arguments are accepted to match the parent signature:
        ``Qwen2VLMultiModalProcessor._get_prompt_updates`` calls
        ``self.info.get_image_processor(**hf_processor_mm_kwargs)``, which
        previously raised ``TypeError`` here whenever any processor kwargs
        were supplied. The kwargs override entries of the stored config.
        """
        # Copy so the context's config mapping is never mutated.
        processor_config = dict(self.ctx.get_hf_image_processor_config())
        processor_config.update(kwargs)
        return Tarsier2ImageProcessor(**processor_config)
1495
1496


1497
1498
1499
1500
1501
@MULTIMODAL_REGISTRY.register_processor(
    Tarsier2MultiModalProcessor,
    info=Tarsier2ProcessingInfo,
    dummy_inputs=Qwen2VLDummyInputsBuilder,
)
class Tarsier2ForConditionalGeneration(Qwen2VLForConditionalGeneration):
    """Qwen2-VL model variant that loads Tarsier2 checkpoints."""

    # Tarsier2 checkpoints prefix vision weights with "vision_tower.".
    hf_to_vllm_mapper = WeightsMapper(
        orig_to_new_prefix={
            "vision_tower.": "visual.",
        }
    )

    def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]:
        """Load weights, skipping "visual." weights when no tower was built."""
        skip_prefixes = ["visual."] if self.visual is None else []
        loader = AutoWeightsLoader(self, skip_prefixes=skip_prefixes)
        return loader.load_weights(weights, mapper=self.hf_to_vllm_mapper)