qwen2_vl.py 51.6 KB
Newer Older
1
# SPDX-License-Identifier: Apache-2.0
2
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
3

4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
# Adapted from
# https://github.com/huggingface/transformers/blob/19e6e80e10118f855137b90740936c0b11ac397f/src/transformers/models/qwen2_vl/modeling_qwen2_vl.py
# Copyright 2024 The Qwen team.
# Copyright 2023 The vLLM team.
# Copyright 2022 EleutherAI and the HuggingFace Inc. team. All rights reserved.
#
# This code is based on EleutherAI's GPT-NeoX library and the GPT-NeoX
# and OPT implementations in this library. It has been modified from its
# original forms to accommodate minor architectural differences compared
# to GPT-NeoX and OPT used by the Meta AI team that trained the model.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Inference-only Qwen2-VL model compatible with HuggingFace weights."""
27

28
import math
29
from collections.abc import Callable, Iterable, Iterator, Mapping, Sequence
30
from functools import partial
31
from typing import Annotated, Any, Literal, TypeAlias
32

33
import numpy as np
34
35
import torch
import torch.nn as nn
36
from einops import rearrange
37
from transformers import BatchFeature
38
from transformers.models.qwen2_vl import Qwen2VLImageProcessor, Qwen2VLProcessor
39
from transformers.models.qwen2_vl.configuration_qwen2_vl import (
40
41
42
    Qwen2VLConfig,
    Qwen2VLVisionConfig,
)
43
from transformers.models.qwen2_vl.image_processing_qwen2_vl import smart_resize
44
from transformers.models.qwen2_vl.video_processing_qwen2_vl import Qwen2VLVideoProcessor
45

46
from vllm.config import VllmConfig
47
from vllm.config.multimodal import BaseDummyOptions
48
from vllm.distributed import parallel_state, tensor_model_parallel_all_gather
49
50
51
from vllm.distributed import utils as dist_utils
from vllm.logger import init_logger
from vllm.model_executor.layers.activation import QuickGELU
52
from vllm.model_executor.layers.attention import MMEncoderAttention
53
from vllm.model_executor.layers.conv import Conv3dLayer
54
55
56
57
from vllm.model_executor.layers.linear import (
    ColumnParallelLinear,
    RowParallelLinear,
)
58
from vllm.model_executor.layers.quantization import QuantizationConfig
59
from vllm.model_executor.layers.rotary_embedding import get_rope
60
from vllm.model_executor.layers.rotary_embedding.common import (
61
    ApplyRotaryEmb,
62
)
63
from vllm.model_executor.model_loader.weight_utils import default_weight_loader
64
from vllm.model_executor.models.module_mapping import MultiModelKeys
65
from vllm.multimodal import MULTIMODAL_REGISTRY
66
67
68
69
from vllm.multimodal.inputs import (
    ImageItem,
    ModalityData,
    MultiModalDataDict,
70
    MultiModalFeatureSpec,
71
72
73
74
75
76
77
78
79
80
81
82
    MultiModalFieldConfig,
    MultiModalKwargsItems,
    VideoItem,
)
from vllm.multimodal.parse import (
    DictEmbeddingItems,
    ImageSize,
    ModalityDataItems,
    MultiModalDataItems,
    MultiModalDataParser,
)
from vllm.multimodal.processing import (
83
    BaseDummyInputsBuilder,
84
85
86
87
88
    BaseMultiModalProcessor,
    BaseProcessingInfo,
    PromptReplacement,
    PromptUpdate,
)
89
from vllm.sequence import IntermediateTensors
90
from vllm.tokenizers import TokenizerLike
91
from vllm.utils.tensor_schema import TensorSchema, TensorShape
92
from vllm.v1.attention.backends.registry import AttentionBackendEnum
93

94
95
96
97
98
99
100
101
102
103
104
105
106
from .interfaces import (
    MultiModalEmbeddings,
    SupportsLoRA,
    SupportsMRoPE,
    SupportsMultiModal,
    SupportsPP,
)
from .utils import (
    AutoWeightsLoader,
    WeightsMapper,
    init_vllm_registered_model,
    maybe_prefix,
)
107
108
from .vision import (
    get_vit_attn_backend,
109
    is_vit_use_data_parallel,
110
111
    run_dp_sharded_mrope_vision_model,
)
112

113
114
# Module-level logger for this file.
logger = init_logger(__name__)

115
# Maximum number of frames per video assumed for the profile run
# (its consumers are outside this view).
116
_MAX_FRAMES_PER_VIDEO = 14
117

118
119
120
# === Vision Inputs === #


121
class Qwen2VLImagePixelInputs(TensorSchema):
    """
    Raw pixel-patch inputs for one or more images.

    Dimensions:
        - np: The total number of patches over each image over each prompt in
              the batch
        - ni: Number of images
        - cps: Number of channels * patch_size * patch_size
          (NOTE(review): Qwen2VisionPatchEmbed views each row as
          (channels, temporal_patch_size, patch_size, patch_size), so the row
          width likely also includes temporal_patch_size -- confirm.)

    Historical context:
        - pixel_values shape: (num_patches, num_channels * patch_size *
          patch_size)
        - image_grid_thw shape: (num_images, 3) in (grid_t, grid_h, grid_w)
          format
    """

    type: Literal["pixel_values"]

    pixel_values: Annotated[
        torch.Tensor,
        TensorShape("np", "cps"),
    ]

    image_grid_thw: Annotated[
        torch.Tensor,
        TensorShape("ni", 3),
    ]


class Qwen2VLImageEmbeddingInputs(TensorSchema):
    """
    Precomputed image embeddings supplied instead of pixel values.

    Dimensions:
        - nf: Number of image features
        - hs: Hidden size
        - ni: Number of images

    Historical context:
        - image_embeds shape: (num_image_features, hidden_size)
        - num_image_features varies based on the number and resolution of the
          images.
        - hidden_size must match the hidden size of language model backbone.
        - image_grid_thw shape: (num_images, 3) in (grid_t, grid_h, grid_w)
          format
    """

    type: Literal["image_embeds"]

    image_embeds: Annotated[
        torch.Tensor,
        TensorShape("nf", "hs"),
    ]

    image_grid_thw: Annotated[
        torch.Tensor,
        TensorShape("ni", 3),
    ]


# Image input: either raw pixel patches or precomputed embeddings.
Qwen2VLImageInputs: TypeAlias = Qwen2VLImagePixelInputs | Qwen2VLImageEmbeddingInputs
179
180


181
182
183
184
185
class Qwen2VLVideoPixelInputs(TensorSchema):
    """
    Raw pixel-patch inputs for one or more videos.

    Dimensions:
        - np: The total number of patches over each video over each prompt in
              the batch
        - ctps: Number of channels * temporal_patch_size * patch_size *
          patch_size
        - nv: Number of videos

    Historical context:
        - pixel_values_videos shape: (num_patches, num_channels *
          temporal_patch_size * patch_size * patch_size)
        - video_grid_thw shape: (num_videos, 3) in (grid_t, grid_h, grid_w)
          format
    """

    type: Literal["pixel_values_videos"]

    pixel_values_videos: Annotated[
        torch.Tensor,
        TensorShape("np", "ctps"),
    ]

    video_grid_thw: Annotated[
        torch.Tensor,
        TensorShape("nv", 3),
    ]
208
209


210
211
212
213
214
215
class Qwen2VLVideoEmbeddingInputs(TensorSchema):
    """
    Precomputed video embeddings supplied instead of pixel values.

    Dimensions:
        - nf: Number of video features
        - hs: Hidden size
        - nv: Number of videos

    Historical context:
        - video_embeds shape: (num_video_features, hidden_size)
        - num_video_features varies based on the number and resolution of the
          videos.
        - hidden_size must match the hidden size of language model backbone.
        - video_grid_thw shape: (num_videos, 3) in (grid_t, grid_h, grid_w)
          format
    """

    type: Literal["video_embeds"]

    video_embeds: Annotated[
        torch.Tensor,
        TensorShape("nf", "hs"),
    ]

    video_grid_thw: Annotated[
        torch.Tensor,
        TensorShape("nv", 3),
    ]


# Video input: either raw pixel patches or precomputed embeddings.
Qwen2VLVideoInputs: TypeAlias = Qwen2VLVideoPixelInputs | Qwen2VLVideoEmbeddingInputs
240

241
242
243
244
245
246
247
# === Vision Encoder === #


class Qwen2VisionMLP(nn.Module):
    """Two-layer feed-forward sub-block of a vision transformer layer.

    ``fc1`` is column-parallel and ``fc2`` row-parallel; tensor parallelism is
    disabled on both when the ViT runs in data-parallel mode.
    """

    def __init__(
        self,
        in_features: int,
        hidden_features: int,
        act_layer: type[nn.Module] = QuickGELU,
        quant_config: QuantizationConfig | None = None,
        prefix: str = "",
    ):
        super().__init__()
        disable_tp = is_vit_use_data_parallel()

        # in_features -> hidden_features
        self.fc1 = ColumnParallelLinear(
            in_features,
            hidden_features,
            quant_config=quant_config,
            prefix=f"{prefix}.fc1",
            disable_tp=disable_tp,
        )
        self.act = act_layer()
        # hidden_features -> in_features
        self.fc2 = RowParallelLinear(
            hidden_features,
            in_features,
            quant_config=quant_config,
            prefix=f"{prefix}.fc2",
            disable_tp=disable_tp,
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # The parallel linear layers return (output, bias); the bias is unused.
        hidden, _ = self.fc1(x)
        out, _ = self.fc2(self.act(hidden))
        return out


class Qwen2VisionAttention(nn.Module):
    """Self-attention layer of the Qwen2-VL vision encoder.

    Uses a fused QKV projection. Rotary embeddings are applied to queries and
    keys, and the attention itself is delegated to ``MMEncoderAttention`` with
    sequence boundaries given by ``cu_seqlens``.
    """

    def __init__(
        self,
        embed_dim: int,
        num_heads: int,
        projection_size: int,
        quant_config: QuantizationConfig | None = None,
        prefix: str = "",
    ) -> None:
        super().__init__()
        use_data_parallel = is_vit_use_data_parallel()

        # A data-parallel ViT keeps full weights on every rank, so it behaves
        # as if the tensor-parallel world size were 1.
        if use_data_parallel:
            self.tp_size = 1
        else:
            self.tp_size = parallel_state.get_tensor_model_parallel_world_size()
        self.tp_rank = parallel_state.get_tensor_model_parallel_rank()

        # Per-head and per-partition sizes.
        self.hidden_size_per_attention_head = dist_utils.divide(
            projection_size, num_heads
        )
        self.num_attention_heads_per_partition = dist_utils.divide(
            num_heads, self.tp_size
        )

        self.qkv = ColumnParallelLinear(
            input_size=embed_dim,
            output_size=3 * projection_size,
            quant_config=quant_config,
            prefix=f"{prefix}.qkv",
            disable_tp=use_data_parallel,
        )
        self.proj = RowParallelLinear(
            input_size=projection_size,
            output_size=embed_dim,
            quant_config=quant_config,
            prefix=f"{prefix}.proj",
            disable_tp=use_data_parallel,
        )

        head_dim = self.hidden_size_per_attention_head
        self.attn = MMEncoderAttention(
            num_heads=self.num_attention_heads_per_partition,
            head_size=head_dim,
            scale=head_dim**-0.5,
            prefix=f"{prefix}.attn",
        )

        self.apply_rotary_emb = ApplyRotaryEmb(enforce_enable=True)

    def split_qkv(self, qkv: torch.Tensor) -> tuple[torch.Tensor, ...]:
        """Split a fused ``[s, b, 3 * head * head_dim]`` activation.

        Returns q, k, v, each viewed as ``[s, b, heads_per_partition, head_dim]``.
        Under tensor parallelism the shards are all-gathered first, then each
        of q/k/v is re-split so that only this rank's heads are kept.
        """
        seq_len, batch_size, _ = qkv.shape
        sharded = self.tp_size > 1

        if sharded:
            qkv = tensor_model_parallel_all_gather(qkv)

        # [s, b, 3 * head * head_dim] -> 3 * [s, b, head * head_dim]
        q, k, v = qkv.chunk(3, dim=2)

        if sharded:
            q, k, v = (
                dist_utils.split_tensor_along_last_dim(
                    part, num_partitions=self.tp_size
                )[self.tp_rank]
                for part in (q, k, v)
            )

        # 3 * [s, b, head * head_dim] -> 3 * [s, b, head, head_dim]
        per_head_shape = (
            seq_len,
            batch_size,
            self.num_attention_heads_per_partition,
            self.hidden_size_per_attention_head,
        )
        return (
            q.view(*per_head_shape),
            k.view(*per_head_shape),
            v.view(*per_head_shape),
        )

    def forward(
        self,
        x: torch.Tensor,
        cu_seqlens: torch.Tensor,
        rotary_pos_emb_cos: torch.Tensor,
        rotary_pos_emb_sin: torch.Tensor,
        max_seqlen: int | None = None,  # Only used for Flash Attention
    ) -> torch.Tensor:
        # [s, b, c] -> [s, b, 3 * head * head_dim]
        fused, _ = self.qkv(x)

        # -> 3 * [s, b, head, head_dim], then to batch-first layout.
        q, k, v = self.split_qkv(fused)
        q, k, v = (rearrange(t, "s b ... -> b s ...") for t in (q, k, v))

        # Rotate q and k in one call by stacking them along the batch axis:
        # [2 * b, s, heads, head_dim].
        rotated = self.apply_rotary_emb(
            torch.cat([q, k], dim=0),
            rotary_pos_emb_cos,
            rotary_pos_emb_sin,
        )
        q, k = torch.chunk(rotated, 2, dim=0)

        context = self.attn(
            query=q,
            key=k,
            value=v,
            cu_seqlens=cu_seqlens,
            max_seqlen=max_seqlen,
        )
        context = rearrange(context, "b s h d -> s b (h d)").contiguous()

        projected, _ = self.proj(context)
        return projected


class Qwen2VisionBlock(nn.Module):
    """Pre-norm vision transformer layer.

    Applies attention, then an MLP, each on a normalized input and each added
    back through a residual connection.
    """

    def __init__(
        self,
        dim: int,
        num_heads: int,
        mlp_ratio: float,
        act_layer: type[nn.Module] = QuickGELU,
        norm_layer: Callable[[int], nn.Module] | None = None,
        quant_config: QuantizationConfig | None = None,
        prefix: str = "",
    ) -> None:
        super().__init__()
        make_norm = (
            norm_layer if norm_layer is not None else partial(nn.LayerNorm, eps=1e-6)
        )
        self.norm1 = make_norm(dim)
        self.norm2 = make_norm(dim)
        hidden_dim = int(dim * mlp_ratio)

        self.attn = Qwen2VisionAttention(
            embed_dim=dim,
            num_heads=num_heads,
            projection_size=dim,
            quant_config=quant_config,
            prefix=f"{prefix}.attn",
        )
        self.mlp = Qwen2VisionMLP(
            dim,
            hidden_dim,
            act_layer=act_layer,
            quant_config=quant_config,
            prefix=f"{prefix}.mlp",
        )

    def forward(
        self,
        x: torch.Tensor,
        cu_seqlens: torch.Tensor,
        rotary_pos_emb_cos: torch.Tensor,
        rotary_pos_emb_sin: torch.Tensor,
        max_seqlen: int | None = None,  # Only used for Flash Attention
    ) -> torch.Tensor:
        attn_out = self.attn(
            self.norm1(x),
            cu_seqlens=cu_seqlens,
            rotary_pos_emb_cos=rotary_pos_emb_cos,
            rotary_pos_emb_sin=rotary_pos_emb_sin,
            max_seqlen=max_seqlen,
        )
        x = x + attn_out
        return x + self.mlp(self.norm2(x))


class Qwen2VisionPatchEmbed(nn.Module):
    """Project flattened pixel patches to ``embed_dim``.

    Each input row is reshaped to
    ``(in_channels, temporal_patch_size, patch_size, patch_size)``. It then
    goes through a bias-free 3-D convolution whose kernel equals its stride
    and matches that shape, yielding one ``embed_dim`` vector per row.
    """

    def __init__(
        self,
        patch_size: int = 14,
        temporal_patch_size: int = 2,
        in_channels: int = 3,
        embed_dim: int = 1152,
    ) -> None:
        super().__init__()
        self.patch_size = patch_size
        self.temporal_patch_size = temporal_patch_size
        self.embed_dim = embed_dim

        self.proj = Conv3dLayer(
            in_channels,
            embed_dim,
            kernel_size=(temporal_patch_size, patch_size, patch_size),
            stride=(temporal_patch_size, patch_size, patch_size),
            bias=False,
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # Unpacking also enforces a 2-D input of shape [num_patches, row_width].
        num_patches, _ = x.shape
        patches = x.view(
            num_patches,
            -1,
            self.temporal_patch_size,
            self.patch_size,
            self.patch_size,
        )
        return self.proj(patches).view(num_patches, self.embed_dim)


class Qwen2VisionPatchMerger(nn.Module):
    """Merge neighbouring patch features and project them to ``d_model``.

    Features are layer-normalized over ``context_dim``. Every
    ``spatial_merge_size**2`` consecutive rows are then concatenated into one
    row of width ``context_dim * spatial_merge_size**2``, and a two-layer MLP
    (Linear -> GELU -> Linear) maps the result to ``d_model``.
    """

    def __init__(
        self,
        d_model: int,
        context_dim: int,
        norm_layer: Callable[[int], nn.Module] | None = None,
        spatial_merge_size: int = 2,
        quant_config: QuantizationConfig | None = None,
        prefix: str = "",
    ) -> None:
        super().__init__()
        use_data_parallel = is_vit_use_data_parallel()

        # Width of one merged row.
        self.hidden_size = context_dim * (spatial_merge_size**2)

        if norm_layer is None:
            norm_layer = partial(nn.LayerNorm, eps=1e-6)
        self.ln_q = norm_layer(context_dim)

        # Indices 0 and 2 are referenced by the ".mlp.0" / ".mlp.2" prefixes.
        fc1 = ColumnParallelLinear(
            self.hidden_size,
            self.hidden_size,
            bias=True,
            quant_config=quant_config,
            prefix=f"{prefix}.mlp.0",
            disable_tp=use_data_parallel,
        )
        act = nn.GELU()
        fc2 = RowParallelLinear(
            self.hidden_size,
            d_model,
            bias=True,
            quant_config=quant_config,
            prefix=f"{prefix}.mlp.2",
            disable_tp=use_data_parallel,
        )
        self.mlp = nn.ModuleList([fc1, act, fc2])

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        merged = self.ln_q(x).view(-1, self.hidden_size)

        fc1, act, fc2 = self.mlp
        # Parallel linear layers return (output, bias); the bias is unused.
        hidden, _ = fc1(merged)
        projected, _ = fc2(act(hidden))
        return projected


class Qwen2VisionTransformer(nn.Module):
    """Qwen2-VL vision encoder.

    The pipeline is:

    1. ``Qwen2VisionPatchEmbed``.
    2. ``depth`` x ``Qwen2VisionBlock``, with variable-length attention
       delimited by ``cu_seqlens``.
    3. ``Qwen2VisionPatchMerger``, which concatenates
       ``spatial_merge_size**2`` neighbouring patches and projects them to
       ``vision_config.hidden_size``.
    """

    def __init__(
        self,
        vision_config: Qwen2VLVisionConfig,
        norm_eps: float = 1e-6,
        quant_config: QuantizationConfig | None = None,
        prefix: str = "",
    ) -> None:
        super().__init__()

        patch_size = vision_config.patch_size
        temporal_patch_size = vision_config.temporal_patch_size
        spatial_merge_size = vision_config.spatial_merge_size
        in_channels = vision_config.in_channels
        hidden_size = vision_config.hidden_size
        embed_dim = vision_config.embed_dim
        depth = vision_config.depth
        num_heads = vision_config.num_heads
        mlp_ratio = vision_config.mlp_ratio

        self.use_data_parallel = is_vit_use_data_parallel()
        # Width of the merger output (the d_model passed to the merger below).
        self.out_hidden_size = vision_config.hidden_size

        self.spatial_merge_size = spatial_merge_size
        self.num_heads = num_heads
        self.embed_dim = embed_dim

        self.patch_embed = Qwen2VisionPatchEmbed(
            patch_size=patch_size,
            temporal_patch_size=temporal_patch_size,
            in_channels=in_channels,
            embed_dim=embed_dim,
        )

        norm_layer = partial(nn.LayerNorm, eps=norm_eps)
        head_dim = embed_dim // num_heads
        # Rotary table shared by all blocks; rot_pos_emb() gathers its rows
        # using each patch's (row, col) position.
        self.rotary_pos_emb = get_rope(
            head_size=head_dim,
            max_position=8192,
            is_neox_style=True,
            rope_parameters={"partial_rotary_factor": 0.5},
        )

        self.blocks = nn.ModuleList(
            [
                Qwen2VisionBlock(
                    dim=embed_dim,
                    num_heads=num_heads,
                    mlp_ratio=mlp_ratio,
                    norm_layer=norm_layer,
                    quant_config=quant_config,
                    prefix=f"{prefix}.blocks.{layer_idx}",
                )
                for layer_idx in range(depth)
            ]
        )
        # Pass the configured merge size explicitly. Otherwise the merger falls
        # back to its default of 2 and its grouping would disagree with the
        # spatial_merge_size used by rot_pos_emb() whenever the config differs.
        self.merger = Qwen2VisionPatchMerger(
            d_model=hidden_size,
            context_dim=embed_dim,
            norm_layer=norm_layer,
            spatial_merge_size=spatial_merge_size,
            quant_config=quant_config,
            prefix=f"{prefix}.merger",
        )
        self.attn_backend = get_vit_attn_backend(
            head_size=head_dim,
            dtype=torch.get_default_dtype(),
        )

    @property
    def dtype(self) -> torch.dtype:
        """Dtype of the patch-embedding weights; inputs are cast to it."""
        return self.patch_embed.proj.weight.dtype

    @property
    def device(self) -> torch.device:
        """Device of the patch-embedding weights; inputs are moved to it."""
        return self.patch_embed.proj.weight.device

    def rot_pos_emb(
        self, grid_thw: list[list[int]]
    ) -> tuple[torch.Tensor, torch.Tensor]:
        """Return rotary (cos, sin) rows for every patch, in patch order.

        Args:
            grid_thw: One ``[t, h, w]`` patch-grid triple per image/video.

        Returns:
            ``(cos, sin)``, each with one row per patch. A patch's row
            concatenates the table entries for its row index and its column
            index.
        """
        pos_ids = []
        max_grid_size = 0
        for t, h, w in grid_thw:
            hpos_ids = torch.arange(h).unsqueeze(1).expand(-1, w)
            wpos_ids = torch.arange(w).unsqueeze(0).expand(h, -1)
            # Reorder positions so each spatial_merge_size x spatial_merge_size
            # window is contiguous. The merger concatenates consecutive rows
            # in exactly these groups.
            hpos_ids = (
                hpos_ids.reshape(
                    h // self.spatial_merge_size,
                    self.spatial_merge_size,
                    w // self.spatial_merge_size,
                    self.spatial_merge_size,
                )
                .permute(0, 2, 1, 3)
                .flatten()
            )
            wpos_ids = (
                wpos_ids.reshape(
                    h // self.spatial_merge_size,
                    self.spatial_merge_size,
                    w // self.spatial_merge_size,
                    self.spatial_merge_size,
                )
                .permute(0, 2, 1, 3)
                .flatten()
            )
            # Every temporal step reuses the same spatial positions.
            pos_ids.append(torch.stack([hpos_ids, wpos_ids], dim=-1).repeat(t, 1))
            max_grid_size = max(max_grid_size, h, w)
        pos_ids = torch.cat(pos_ids, dim=0)

        # Use pre-computed cos_sin_cache from RotaryEmbedding
        cos, sin = self.rotary_pos_emb.get_cos_sin(max_grid_size)

        # Indexing with [num_patches, 2] ids and flattening from dim 1 joins
        # the row-position and column-position entries for each patch.
        cos_combined = cos[pos_ids].flatten(1)
        sin_combined = sin[pos_ids].flatten(1)
        return cos_combined, sin_combined

    def compute_attn_mask_seqlen(
        self, cu_seqlens: torch.Tensor
    ) -> torch.Tensor | None:
        """Return the longest per-sequence length for flash-attention backends.

        For other backends this returns ``None``. ``forward`` calls this while
        ``cu_seqlens`` is still a CPU tensor. The result is the 0-d tensor
        produced by ``.max()``, not a Python ``int``.
        """
        max_seqlen = None
        if self.attn_backend in {
            AttentionBackendEnum.FLASH_ATTN,
            AttentionBackendEnum.ROCM_AITER_FA,
        }:
            max_seqlen = (cu_seqlens[1:] - cu_seqlens[:-1]).max()
        return max_seqlen

    def forward(
        self,
        x: torch.Tensor,
        grid_thw: torch.Tensor | list[list[int]],
    ) -> torch.Tensor:
        """Encode flattened pixel patches into merged vision embeddings.

        Args:
            x: Flattened patches, one row per patch (see Qwen2VisionPatchEmbed).
            grid_thw: ``[t, h, w]`` per item, either as a nested list or as a
                tensor. A tensor must live on CPU because ``.numpy()`` is
                called on it.

        Returns:
            The merger output. Each output row corresponds to one group of
            ``spatial_merge_size**2`` patches.
        """
        # patchify
        x = x.to(device=self.device, dtype=self.dtype)
        x = self.patch_embed(x)

        if isinstance(grid_thw, list):
            grid_thw_list = grid_thw
            grid_thw = np.array(grid_thw, dtype=np.int32)
        else:
            grid_thw_list = grid_thw.tolist()
            grid_thw = grid_thw.numpy()

        # compute position embedding
        rotary_pos_emb_cos, rotary_pos_emb_sin = self.rot_pos_emb(grid_thw_list)

        # compute cu_seqlens: one h*w-long sequence per temporal step, so
        # attention never crosses frames or items.
        cu_seqlens = np.repeat(grid_thw[:, 1] * grid_thw[:, 2], grid_thw[:, 0]).cumsum(
            axis=0, dtype=np.int32
        )
        cu_seqlens = np.concatenate([np.zeros(1, dtype=np.int32), cu_seqlens])
        cu_seqlens = torch.from_numpy(cu_seqlens)

        # transformers: add a singleton batch dim -> [s, 1, c]
        x = x.unsqueeze(1)

        # pre-compute seqlens for attn mask to reduce cuMemcpy operations
        max_seqlen = self.compute_attn_mask_seqlen(cu_seqlens)
        cu_seqlens = cu_seqlens.to(self.device, non_blocking=True)
        for blk in self.blocks:
            x = blk(
                x,
                cu_seqlens=cu_seqlens,
                rotary_pos_emb_cos=rotary_pos_emb_cos,
                rotary_pos_emb_sin=rotary_pos_emb_sin,
                max_seqlen=max_seqlen,
            )

        # adapter
        x = self.merger(x)

        return x

    def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]:
        """Load checkpoint tensors into this module's parameters.

        Returns:
            The set of parameter names that were loaded.
        """
        # NOTE(review): Qwen2VisionAttention registers a fused ``qkv`` layer,
        # not ``qkv_proj``. A checkpoint name containing q_proj/k_proj/v_proj
        # would be remapped to a ``qkv_proj`` key that params_dict may lack.
        # Confirm whether this mapping is still needed.
        stacked_params_mapping = [
            # (param_name, shard_name, shard_id)
            ("qkv_proj", "q_proj", "q"),
            ("qkv_proj", "k_proj", "k"),
            ("qkv_proj", "v_proj", "v"),
        ]
        params_dict = dict(self.named_parameters(remove_duplicate=False))
        loaded_params: set[str] = set()

        for name, loaded_weight in weights:
            for param_name, weight_name, shard_id in stacked_params_mapping:
                if weight_name not in name:
                    continue
                name = name.replace(weight_name, param_name)

                param = params_dict[name]
                weight_loader = param.weight_loader
                weight_loader(param, loaded_weight, shard_id)
                break
            else:
                param = params_dict[name]
                weight_loader = getattr(param, "weight_loader", default_weight_loader)
                weight_loader(param, loaded_weight)
            loaded_params.add(name)
        return loaded_params

722

723
def _create_qwen2vl_field_factory(
    spatial_merge_size: int,
) -> Callable[
    [Mapping[str, torch.Tensor]],
    Mapping[str, MultiModalFieldConfig],
]:
    """Build the ``fields_factory`` describing Qwen2-VL processor outputs.

    The returned callable maps HF processor outputs to per-field configs:

    * ``pixel_values`` / ``pixel_values_videos`` are flat tensors split by each
      item's patch count ``t * h * w``.
    * ``image_embeds`` / ``video_embeds`` are flat tensors split by that count
      divided by ``spatial_merge_size`` twice.
    * ``*_grid_thw`` is batched per item and kept on CPU.

    A missing ``*_grid_thw`` entry is treated as an empty ``(0, 3)`` grid.
    """

    def _qwen2vl_field_config(hf_inputs: Mapping[str, torch.Tensor]):
        def _patch_and_token_counts(
            grid_key: str,
        ) -> tuple[torch.Tensor, torch.Tensor]:
            grid_thw = hf_inputs.get(grid_key, torch.empty((0, 3)))
            patch_counts = grid_thw.prod(-1)
            token_counts = patch_counts // spatial_merge_size // spatial_merge_size
            return patch_counts, token_counts

        image_patches, image_tokens = _patch_and_token_counts("image_grid_thw")
        video_patches, video_tokens = _patch_and_token_counts("video_grid_thw")

        return {
            "pixel_values": MultiModalFieldConfig.flat_from_sizes(
                "image", image_patches
            ),
            "image_embeds": MultiModalFieldConfig.flat_from_sizes(
                "image", image_tokens
            ),
            "image_grid_thw": MultiModalFieldConfig.batched(
                "image", keep_on_cpu=True
            ),
            "pixel_values_videos": MultiModalFieldConfig.flat_from_sizes(
                "video", video_patches
            ),
            "video_embeds": MultiModalFieldConfig.flat_from_sizes(
                "video", video_tokens
            ),
            "video_grid_thw": MultiModalFieldConfig.batched(
                "video", keep_on_cpu=True
            ),
        }

    return _qwen2vl_field_config
760

761

Roger Wang's avatar
Roger Wang committed
762
class Qwen2VLMultiModalDataParser(MultiModalDataParser):
    """Multimodal data parser that also accepts precomputed embeddings.

    Image or video data given as a ``dict`` is wrapped in
    ``DictEmbeddingItems``, which requires the ``*_embeds`` and
    ``*_grid_thw`` fields. Any other input is handled by the base-class
    parser.
    """

    def __init__(self, spatial_merge_size: int, *args, **kwargs):
        # Merge size used to size the embedding fields of dict inputs.
        self._spatial_merge_size = spatial_merge_size
        super().__init__(*args, **kwargs)

    def _wrap_embedding_dict(
        self,
        data: dict[str, torch.Tensor],
        modality: str,
        required_fields: set[str],
    ) -> ModalityDataItems[Any, Any]:
        """Wrap a dict of precomputed ``modality`` embeddings."""
        fields_factory = _create_qwen2vl_field_factory(self._spatial_merge_size)
        return DictEmbeddingItems(
            data,
            modality=modality,
            required_fields=required_fields,
            fields_factory=fields_factory,
        )

    def _parse_image_data(
        self,
        data: dict[str, torch.Tensor] | ModalityData[ImageItem],
    ) -> ModalityDataItems[Any, Any] | None:
        if not isinstance(data, dict):
            return super()._parse_image_data(data)
        return self._wrap_embedding_dict(
            data, "image", {"image_embeds", "image_grid_thw"}
        )

    def _parse_video_data(
        self,
        data: dict[str, torch.Tensor] | ModalityData[VideoItem],
    ) -> ModalityDataItems[Any, Any] | None:
        if not isinstance(data, dict):
            return super()._parse_video_data(data)
        return self._wrap_embedding_dict(
            data, "video", {"video_embeds", "video_grid_thw"}
        )


796
797
class Qwen2VLProcessingInfo(BaseProcessingInfo):
    """Processing information for Qwen2-VL.

    Exposes the HF config and processors and computes how many placeholder
    tokens an image or video of a given size expands to. These counts are
    used for token budgets and dummy-input sizing.
    """

    def get_hf_config(self):
        return self.ctx.get_hf_config(Qwen2VLConfig)

    def get_hf_processor(self, **kwargs: object) -> Qwen2VLProcessor:
        # Use the fast processor unless the caller explicitly overrides it.
        return self.ctx.get_hf_processor(
            Qwen2VLProcessor,
            use_fast=kwargs.pop("use_fast", True),
            **kwargs,
        )

    def get_image_processor(self, **kwargs: object) -> Qwen2VLImageProcessor:
        return self.get_hf_processor(**kwargs).image_processor

    def get_data_parser(self):
        return Qwen2VLMultiModalDataParser(
            self.get_hf_config().vision_config.spatial_merge_size,
            expected_hidden_size=self._get_expected_hidden_size(),
        )

    def get_supported_mm_limits(self) -> Mapping[str, int | None]:
        # No fixed per-prompt limit for either modality.
        return {"image": None, "video": None}

    def get_mm_max_tokens_per_item(
        self,
        seq_len: int,
        mm_counts: Mapping[str, int],
    ) -> Mapping[str, int]:
        max_image_tokens = self.get_max_image_tokens()
        max_video_tokens = self.get_max_video_tokens(seq_len, mm_counts)
        return {"image": max_image_tokens, "video": max_video_tokens}

    def _get_vision_info(
        self,
        *,
        image_width: int,
        image_height: int,
        num_frames: int = 1,
        do_resize: bool = True,
        image_processor: Qwen2VLImageProcessor | None,
    ) -> tuple[ImageSize, int]:
        """Compute the preprocessed size and number of merged vision tokens.

        Args:
            image_width: Input width in pixels.
            image_height: Input height in pixels.
            num_frames: Number of frames (1 for a still image).
            do_resize: Apply ``smart_resize`` bounded by the image processor's
                ``size["shortest_edge"]`` / ``size["longest_edge"]``.
            image_processor: Source of the pixel bounds; fetched with
                ``get_image_processor()`` when ``None``.

        Returns:
            ``(preprocessed_size, num_vision_tokens)`` where the token count
            is the number of patches divided by ``spatial_merge_size**2``.
        """
        if image_processor is None:
            image_processor = self.get_image_processor()

        hf_config = self.get_hf_config()
        vision_config = hf_config.vision_config
        patch_size = vision_config.patch_size
        merge_size = vision_config.spatial_merge_size
        temporal_patch_size = vision_config.temporal_patch_size

        if do_resize:
            resized_height, resized_width = smart_resize(
                height=image_height,
                width=image_width,
                factor=patch_size * merge_size,
                min_pixels=image_processor.size["shortest_edge"],
                max_pixels=image_processor.size["longest_edge"],
            )
            preprocessed_size = ImageSize(width=resized_width, height=resized_height)
        else:
            preprocessed_size = ImageSize(width=image_width, height=image_height)

        # NOTE: Frames are padded to be divisible by `temporal_patch_size`
        # https://github.com/huggingface/transformers/blob/v4.48.3/src/transformers/models/qwen2_vl/image_processing_qwen2_vl.py#L294
        # Pad up to the next multiple so grid_t == ceil(num_frames / tps).
        # The previous `num_frames + num_frames % tps` only reached a multiple
        # for tps <= 2 (e.g. tps=3, 4 frames -> 5, i.e. one temporal patch
        # instead of two); results for tps 1 and 2 are unchanged.
        padded_num_frames = num_frames + (-num_frames) % temporal_patch_size

        grid_t = max(padded_num_frames // temporal_patch_size, 1)
        grid_h = preprocessed_size.height // patch_size
        grid_w = preprocessed_size.width // patch_size

        num_patches = grid_t * grid_h * grid_w
        num_vision_tokens = num_patches // (merge_size**2)

        return preprocessed_size, num_vision_tokens

    def get_num_image_tokens(
        self,
        *,
        image_width: int,
        image_height: int,
        image_processor: Qwen2VLImageProcessor | None,
    ) -> int:
        """Number of placeholder tokens for one image of the given size."""
        _, num_image_tokens = self._get_vision_info(
            image_width=image_width,
            image_height=image_height,
            num_frames=1,
            image_processor=image_processor,
        )
        return num_image_tokens

    def get_num_video_tokens(
        self,
        *,
        image_width: int,
        image_height: int,
        num_frames: int,
        image_processor: Qwen2VLImageProcessor | None,
    ) -> int:
        """Number of placeholder tokens for one video of the given size."""
        _, num_video_tokens = self._get_vision_info(
            image_width=image_width,
            image_height=image_height,
            num_frames=num_frames,
            image_processor=image_processor,
        )
        return num_video_tokens

    def get_image_size_with_most_features(
        self, max_pixels: int | None = None
    ) -> ImageSize:
        """Return an image size that maximizes the number of merged patches.

        Args:
            max_pixels: Pixel budget; defaults to the image processor's
                ``size["longest_edge"]``.
        """
        # NOTE: Simply processing a huge size with _get_vision_info might not give a
        # size that maximizes the number of features, i.e., the number of (merged)
        # patches. This is because the number of patches limits the allowed aspect
        # ratios. For example, suppose the maximum number of patches is 1280. A square
        # image cannot be broken down into 1280 patches, so feeding a giant square image
        # into _get_vision_info will not yield a size that maximizes the number of
        # patches. Therefore, we directly factorize the maximum number of patches into
        # height and width. The tricky part is to avoid extreme aspect ratios (>200 for
        # qwen2-vl). If we can't find a suitable aspect ratio, we decrease the number of
        # patches and retry. This is safe because the processor does not accept extreme
        # aspect ratios, so there is no valid post-resize image with the number of
        # patches that yields extreme aspect ratios.

        hf_config = self.get_hf_config()
        vision_config = hf_config.vision_config
        patch_size = vision_config.patch_size
        merge_size = vision_config.spatial_merge_size

        if max_pixels is None:
            image_processor = self.get_image_processor()
            max_pixels = image_processor.size["longest_edge"]

        unit = patch_size * merge_size
        # Keep at least one merged patch. Without the clamp, a budget smaller
        # than unit*unit produced a zero-width size, which then yields zero
        # tokens per frame downstream.
        max_seq_len = max(max_pixels // (unit * unit), 1)

        def closest_factor_pair(n: int) -> tuple[int, int]:
            # Most balanced factorization with left <= right.
            for d in range(math.isqrt(n), 0, -1):
                if n % d == 0:
                    return d, n // d
            return 1, n

        height_factor, width_factor = 1, max_seq_len
        for seq_len in range(max_seq_len, 0, -1):
            height_factor, width_factor = closest_factor_pair(seq_len)
            if width_factor / height_factor <= 200:
                break

        return ImageSize(width=unit * width_factor, height=unit * height_factor)

    def get_max_image_tokens(self) -> int:
        target_width, target_height = self.get_image_size_with_most_features()

        return self.get_num_image_tokens(
            image_width=target_width,
            image_height=target_height,
            image_processor=None,
        )

    def _get_max_video_frames(self, max_tokens: int, start_num_frames: int = 1) -> int:
        """Largest frame count whose video tokens stay within ``max_tokens``.

        Frames use the size from ``get_image_size_with_most_features``.
        ``start_num_frames`` is returned as-is if even one more frame would
        exceed the budget.
        """
        target_width, target_height = self.get_image_size_with_most_features()
        # Loop-invariant: fetch the processor once rather than per iteration.
        image_processor = self.get_image_processor()

        num_frames = start_num_frames

        while True:
            next_num_frames = num_frames + 1
            next_max_tokens = self.get_num_video_tokens(
                image_width=target_width,
                image_height=target_height,
                num_frames=next_num_frames,
                image_processor=image_processor,
            )

            if next_max_tokens > max_tokens:
                break

            num_frames = next_num_frames

        return num_frames

    def get_num_frames_with_most_features(
        self,
        seq_len: int,
        mm_counts: Mapping[str, int],
        max_frames_per_video: int = _MAX_FRAMES_PER_VIDEO,
    ) -> int:
        """Frames per video when ``seq_len`` is split across all videos.

        The result is capped at ``max_frames_per_video`` and never below 1.
        """
        max_videos = mm_counts.get("video", 0)

        max_total_frames = self._get_max_video_frames(seq_len)
        max_frames_per_video = min(
            max_total_frames // max(max_videos, 1), max_frames_per_video
        )

        return max(max_frames_per_video, 1)

    def get_max_video_tokens(
        self,
        seq_len: int,
        mm_counts: Mapping[str, int],
    ) -> int:
        target_width, target_height = self.get_image_size_with_most_features()

        return self.get_num_video_tokens(
            image_width=target_width,
            image_height=target_height,
            num_frames=self.get_num_frames_with_most_features(seq_len, mm_counts),
            image_processor=None,
        )

1002
1003

class Qwen2VLDummyInputsBuilder(BaseDummyInputsBuilder[Qwen2VLProcessingInfo]):
    """Builds the dummy prompt text and media used to profile Qwen2-VL."""

    def get_dummy_text(self, mm_counts: Mapping[str, int]) -> str:
        """One image token per image followed by one video token per video."""
        processor = self.info.get_hf_processor()
        image_placeholder: str = processor.image_token
        video_placeholder: str = processor.video_token

        return (
            image_placeholder * mm_counts.get("image", 0)
            + video_placeholder * mm_counts.get("video", 0)
        )

    def get_dummy_mm_data(
        self,
        seq_len: int,
        mm_counts: Mapping[str, int],
        mm_options: Mapping[str, BaseDummyOptions] | None = None,
        mm_processor_kwargs: Mapping[str, object] | None = None,
    ) -> MultiModalDataDict:
        """Largest-feature dummy images and videos for the requested counts.

        A ``max_pixels`` entry in ``mm_processor_kwargs`` bounds the dummy
        frame size; per-modality ``mm_options`` are forwarded as overrides.
        """
        processor_kwargs = mm_processor_kwargs or {}
        width, height = self.info.get_image_size_with_most_features(
            max_pixels=processor_kwargs.get("max_pixels", None)
        )
        frames_per_video = self.info.get_num_frames_with_most_features(
            seq_len, mm_counts
        )

        if mm_options:
            image_overrides = mm_options.get("image")
            video_overrides = mm_options.get("video")
        else:
            image_overrides = video_overrides = None

        dummy_images = self._get_dummy_images(
            width=width,
            height=height,
            num_images=mm_counts.get("image", 0),
            overrides=image_overrides,
        )
        dummy_videos = self._get_dummy_videos(
            width=width,
            height=height,
            num_frames=frames_per_video,
            num_videos=mm_counts.get("video", 0),
            overrides=video_overrides,
        )
        return {"image": dummy_images, "video": dummy_videos}

1051

1052
class Qwen2VLMultiModalProcessor(BaseMultiModalProcessor[Qwen2VLProcessingInfo]):
    """Multimodal processor that expands Qwen2-VL image/video placeholders."""

    def _get_prompt_updates(
        self,
        mm_items: MultiModalDataItems,
        hf_processor_mm_kwargs: Mapping[str, Any],
        out_mm_kwargs: MultiModalKwargsItems,
    ) -> Sequence[PromptUpdate]:
        """Replace each placeholder token with one token per merged patch."""
        hf_processor = self.info.get_hf_processor(**hf_processor_mm_kwargs)
        image_processor = self.info.get_image_processor(**hf_processor_mm_kwargs)
        vocab = self.info.get_tokenizer().get_vocab()

        token_ids = {
            "image": vocab[hf_processor.image_token],
            "video": vocab[hf_processor.video_token],
        }
        # Each output token covers merge_size x merge_size patches.
        patches_per_token = image_processor.merge_size**2

        def get_replacement_qwen2vl(item_idx: int, modality: str):
            item = out_mm_kwargs[modality][item_idx]
            grid_thw = item[f"{modality}_grid_thw"].data
            assert isinstance(grid_thw, torch.Tensor)

            token_count = int(grid_thw.prod()) // patches_per_token
            return [token_ids[modality]] * token_count

        updates = []
        for modality in ("image", "video"):
            updates.append(
                PromptReplacement(
                    modality=modality,
                    target=[token_ids[modality]],
                    replacement=partial(get_replacement_qwen2vl, modality=modality),
                )
            )
        return updates

    def _get_mm_fields_config(
        self,
        hf_inputs: BatchFeature,
        hf_processor_mm_kwargs: Mapping[str, object],
    ) -> Mapping[str, MultiModalFieldConfig]:
        merge_size = self.info.get_hf_config().vision_config.spatial_merge_size
        field_factory = _create_qwen2vl_field_factory(merge_size)
        return field_factory(hf_inputs)


@MULTIMODAL_REGISTRY.register_processor(
    Qwen2VLMultiModalProcessor,
    info=Qwen2VLProcessingInfo,
    dummy_inputs=Qwen2VLDummyInputsBuilder,
)
class Qwen2VLForConditionalGeneration(
    nn.Module, SupportsMultiModal, SupportsLoRA, SupportsPP, SupportsMRoPE
):
    """Qwen2-VL: a Qwen2 causal LM fed by a vision transformer tower."""

    # Checkpoint-prefix -> module-prefix rewrites applied in `load_weights`.
    # NOTE(review): the specific `model.*` prefixes are listed before the
    # generic `model.` one; presumably matching is order-sensitive — keep it.
    hf_to_vllm_mapper = WeightsMapper(
        orig_to_new_prefix={
            # mapping for new names in checkpoint saved after transformers v4.52
            "model.language_model.": "language_model.model.",
            "model.visual.": "visual.",
            # mapping for original checkpoint
            "lm_head.": "language_model.lm_head.",
            "model.": "language_model.model.",
        }
    )

    supports_encoder_tp_data = True

    def iter_mm_grid_thw(
        self, mm_features: list[MultiModalFeatureSpec]
    ) -> Iterator[tuple[int, int, int, int, float]]:
        """
        Walk multimodal features in prompt order and yield their LLM grids.

        Args:
            mm_features: List of multimodal feature specifications

        Yields:
            ``(offset, grid_t, grid_h, grid_w, t_factor)`` per item, where the
            spatial sizes are already divided by ``spatial_merge_size`` and
            ``t_factor`` scales the temporal position (1.0 for images).
        """
        vision_config = self.config.vision_config
        merge = vision_config.spatial_merge_size
        tokens_per_second = getattr(vision_config, "tokens_per_second", 1.0)

        for feature in sorted(mm_features, key=lambda f: f.mm_position.offset):
            offset = feature.mm_position.offset
            modality = feature.modality

            if modality == "image":
                t, h, w = feature.data["image_grid_thw"].data.tolist()
                assert t == 1, f"Image must have 1 frame, got {t}"
                yield offset, 1, h // merge, w // merge, 1.0
            elif modality == "video":
                t, h, w = feature.data["video_grid_thw"].data.tolist()
                seconds_per_grid = 1.0
                if feature.data.get("second_per_grid_ts", None):
                    seconds_per_grid = feature.data["second_per_grid_ts"].data.item()
                yield (
                    offset,
                    t,
                    h // merge,
                    w // merge,
                    seconds_per_grid * tokens_per_second,
                )
            else:
                raise ValueError(f"Unsupported modality: {modality}")

    def get_mrope_input_positions(
        self,
        input_tokens: list[int],
        mm_features: list[MultiModalFeatureSpec],
    ) -> tuple[torch.Tensor, int]:
        """Build 3-axis M-RoPE position ids for a prompt.

        Text runs get identical positions on all three axes. Each vision item
        gets (t, h, w) grid indices offset past the preceding text.

        Returns:
            ``(positions, delta)``: a ``(3, N)`` tensor and the amount by which
            ``positions.max() + 1`` exceeds ``len(input_tokens)``.
        """
        chunks: list = []
        cursor = 0

        def next_start():
            # One past the largest position emitted so far.
            return chunks[-1].max() + 1 if chunks else 0

        for offset, grid_t, grid_h, grid_w, t_factor in self.iter_mm_grid_thw(
            mm_features
        ):
            text_len = offset - cursor
            base = next_start()
            chunks.append(np.broadcast_to(np.arange(text_len), (3, text_len)) + base)

            grid = np.indices((grid_t, grid_h, grid_w))
            if t_factor != 1.0:
                grid[0] = (grid[0] * t_factor).astype(np.int64)
            chunks.append(grid.reshape(3, -1) + text_len + base)
            cursor = offset + grid_t * grid_h * grid_w

        if cursor < len(input_tokens):
            base = next_start()
            tail_len = len(input_tokens) - cursor
            chunks.append(np.broadcast_to(np.arange(tail_len), (3, tail_len)) + base)

        positions = np.concatenate(chunks, axis=1).reshape(3, -1)
        delta = (positions.max() + 1 - len(input_tokens)).item()

        return torch.from_numpy(positions), delta

    @classmethod
    def get_placeholder_str(cls, modality: str, i: int) -> str | None:
        """Prompt text that stands in for one image or video item."""
        placeholders = (
            ("image", "<|vision_start|><|image_pad|><|vision_end|>"),
            ("video", "<|vision_start|><|video_pad|><|vision_end|>"),
        )
        for prefix, placeholder in placeholders:
            if modality.startswith(prefix):
                return placeholder

        raise ValueError("Only image or video modality is supported")

    def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
        super().__init__()
        model_config = vllm_config.model_config
        config: Qwen2VLConfig = model_config.hf_config
        quant_config = vllm_config.quant_config
        multimodal_config = model_config.multimodal_config

        self.use_data_parallel = multimodal_config.mm_encoder_tp_mode == "data"
        self.config = config
        self.multimodal_config = multimodal_config

        with self._mark_tower_model(vllm_config, {"image", "video"}):
            self.visual = Qwen2VisionTransformer(
                config.vision_config,
                norm_eps=getattr(config, "rms_norm_eps", 1e-6),
                quant_config=quant_config,
                prefix=maybe_prefix(prefix, "visual"),
            )

        with self._mark_language_model(vllm_config):
            self.language_model = init_vllm_registered_model(
                vllm_config=vllm_config,
                prefix=maybe_prefix(prefix, "language_model"),
                architectures=["Qwen2ForCausalLM"],
            )

        self.make_empty_intermediate_tensors = (
            self.language_model.make_empty_intermediate_tensors
        )

    def _parse_and_validate_image_input(
        self, **kwargs: object
    ) -> Qwen2VLImageInputs | None:
        """Pop image kwargs; pixel values take precedence over embeddings."""
        pixel_values = kwargs.pop("pixel_values", None)
        image_embeds = kwargs.pop("image_embeds", None)
        image_grid_thw = kwargs.pop("image_grid_thw", None)

        if pixel_values is not None:
            return Qwen2VLImagePixelInputs(
                type="pixel_values",
                pixel_values=pixel_values,
                image_grid_thw=image_grid_thw,
            )
        if image_embeds is not None:
            return Qwen2VLImageEmbeddingInputs(
                type="image_embeds",
                image_embeds=image_embeds,
                image_grid_thw=image_grid_thw,
            )
        return None

    def _parse_and_validate_video_input(
        self, **kwargs: object
    ) -> Qwen2VLVideoInputs | None:
        """Pop video kwargs; pixel values take precedence over embeddings."""
        pixel_values_videos = kwargs.pop("pixel_values_videos", None)
        video_embeds = kwargs.pop("video_embeds", None)
        video_grid_thw = kwargs.pop("video_grid_thw", None)

        if pixel_values_videos is not None:
            return Qwen2VLVideoPixelInputs(
                type="pixel_values_videos",
                pixel_values_videos=pixel_values_videos,
                video_grid_thw=video_grid_thw,
            )
        if video_embeds is not None:
            return Qwen2VLVideoEmbeddingInputs(
                type="video_embeds",
                video_embeds=video_embeds,
                video_grid_thw=video_grid_thw,
            )
        return None

    def _split_embeds_by_grid(
        self, embeds: torch.Tensor, grid_thw: torch.Tensor
    ) -> tuple[torch.Tensor, ...]:
        """Split concatenated embeddings into one tensor per grid row."""
        merge_size = self.visual.spatial_merge_size
        sizes = (grid_thw.prod(-1) // merge_size // merge_size).tolist()
        return embeds.split(sizes)

    def _process_image_input(
        self, image_input: Qwen2VLImageInputs
    ) -> tuple[torch.Tensor, ...]:
        grid_thw = image_input["image_grid_thw"]
        assert grid_thw.ndim == 2

        if image_input["type"] != "image_embeds":
            pixel_values = image_input["pixel_values"]
            if self.use_data_parallel:
                return run_dp_sharded_mrope_vision_model(
                    self.visual, pixel_values, grid_thw.tolist(), rope_type="rope_3d"
                )
            image_embeds = self.visual(pixel_values, grid_thw=grid_thw)
        else:
            image_embeds = image_input["image_embeds"]

        return self._split_embeds_by_grid(image_embeds, grid_thw)

    def _process_video_input(
        self, video_input: Qwen2VLVideoInputs
    ) -> tuple[torch.Tensor, ...]:
        grid_thw = video_input["video_grid_thw"]
        assert grid_thw.ndim == 2

        if video_input["type"] != "video_embeds":
            pixel_values_videos = video_input["pixel_values_videos"]
            if self.use_data_parallel:
                return run_dp_sharded_mrope_vision_model(
                    self.visual,
                    pixel_values_videos,
                    grid_thw.tolist(),
                    rope_type="rope_3d",
                )
            video_embeds = self.visual(pixel_values_videos, grid_thw=grid_thw)
        else:
            video_embeds = video_input["video_embeds"]

        return self._split_embeds_by_grid(video_embeds, grid_thw)

    def _parse_and_validate_multimodal_inputs(self, **kwargs: object) -> dict:
        """Parse image/video inputs keyed by first appearance in ``kwargs``."""
        parsers = (
            (
                "images",
                ("pixel_values", "image_embeds"),
                self._parse_and_validate_image_input,
            ),
            (
                "videos",
                ("pixel_values_videos", "video_embeds"),
                self._parse_and_validate_video_input,
            ),
        )
        modalities = {}

        # Dict insertion order mirrors the kwargs order, so embeddings are
        # later produced in the same modality order.
        for input_key in kwargs:
            for name, trigger_keys, parse in parsers:
                if input_key in trigger_keys and name not in modalities:
                    modalities[name] = parse(**kwargs)

        return modalities

    def embed_multimodal(self, **kwargs: object) -> MultiModalEmbeddings:
        """Return one embedding tensor per multimodal item, in input order."""
        modalities = self._parse_and_validate_multimodal_inputs(**kwargs)
        if not modalities:
            return []

        processors = {
            "images": self._process_image_input,
            "videos": self._process_video_input,
        }
        multimodal_embeddings: tuple[torch.Tensor, ...] = ()

        # Iterate in the parsed dict's order to keep modalities in sequence.
        for modality, parsed_input in modalities.items():
            multimodal_embeddings += tuple(processors[modality](parsed_input))

        return multimodal_embeddings

    def forward(
        self,
        input_ids: torch.Tensor | None,
        positions: torch.Tensor,
        intermediate_tensors: IntermediateTensors | None = None,
        inputs_embeds: torch.Tensor | None = None,
        **kwargs: object,
    ) -> torch.Tensor | IntermediateTensors:
        """Run forward pass for Qwen2-VL.

        Args:
            input_ids: Flattened (concatenated) input_ids corresponding to a
                batch.
            positions: Flattened (concatenated) position ids corresponding to a
                batch.
                **NOTE**: If mrope is enabled (default setting for Qwen2-VL
                opensource models), the shape will be `(3, seq_len)`,
                otherwise it will be `(seq_len,)`.
            intermediate_tensors: Intermediate tensors from prior forward pass.
                When given, ``inputs_embeds`` is ignored.
            inputs_embeds: Optional tensor of input embeddings.
        """
        if intermediate_tensors is not None:
            inputs_embeds = None

        return self.language_model.model(
            input_ids=input_ids,
            positions=positions,
            intermediate_tensors=intermediate_tensors,
            inputs_embeds=inputs_embeds,
        )

    def compute_logits(
        self,
        hidden_states: torch.Tensor,
    ) -> torch.Tensor | None:
        return self.language_model.compute_logits(hidden_states)

    def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]:
        return AutoWeightsLoader(self).load_weights(
            weights, mapper=self.hf_to_vllm_mapper
        )

    def get_mm_mapping(self) -> MultiModelKeys:
        """
        Get the module prefix in multimodal models
        """
        return MultiModelKeys.from_string_field(
            language_model="language_model",
            connector="visual.merger.",
            tower_model="visual.",
        )

    def get_num_mm_encoder_tokens(
        self,
        num_image_tokens: int,
    ) -> int:
        """Scale merged (LLM-side) tokens up by ``spatial_merge_size**2``."""
        merge_size = self.config.vision_config.spatial_merge_size
        return num_image_tokens * merge_size**2

    def get_num_mm_connector_tokens(
        self,
        num_vision_tokens: int,
    ) -> int:
        """Scale encoder tokens down by ``spatial_merge_size**2`` (floor)."""
        merge_size = self.config.vision_config.spatial_merge_size
        return num_vision_tokens // merge_size**2

1445
1446
1447
1448
1449
1450
1451
1452

class Tarsier2MultiModalProcessor(Qwen2VLMultiModalProcessor):
    """Tarsier2 reuses the Qwen2-VL multimodal processing unchanged."""


class Tarsier2ImageProcessor(Qwen2VLImageProcessor):
    """Qwen2-VL image processor that accepts Tarsier2's ``size`` format.

    A ``size`` containing both ``min_pixels`` and ``max_pixels`` is rewritten
    as ``{"shortest_edge": min_pixels, "longest_edge": max_pixels}``; any
    other ``size`` (including ``None``) is passed through unchanged.
    """

    def __init__(
        self,
        size: dict[str, int] | None = None,
        **kwargs,
    ) -> None:
        uses_tarsier2_keys = (
            size is not None and "min_pixels" in size and "max_pixels" in size
        )
        if uses_tarsier2_keys:
            size = {
                "shortest_edge": size["min_pixels"],
                "longest_edge": size["max_pixels"],
            }
        super().__init__(size=size, **kwargs)


class Tarsier2Processor(Qwen2VLProcessor):
    """Qwen2-VL processor built from a Tarsier2 image-processor config dict.

    The same ``vision_config`` dict configures both the image processor and
    the video processor; no chat template is attached.
    """

    def __init__(
        self,
        vision_config: dict,
        tokenizer: TokenizerLike,
        **kwargs,
    ):
        self.image_processor = Tarsier2ImageProcessor(**vision_config)
        video_processor = Qwen2VLVideoProcessor(**vision_config)
        super().__init__(
            image_processor=self.image_processor,
            tokenizer=tokenizer,
            video_processor=video_processor,
            chat_template=None,
            **kwargs,
        )
1482
1483
1484
1485
1486


class Tarsier2ProcessingInfo(Qwen2VLProcessingInfo):
    """Processing info for Tarsier2 checkpoints served as Qwen2-VL."""

    def get_hf_config(self) -> Qwen2VLConfig:
        """Return the model's config loaded as a ``Qwen2VLConfig``.

        The config is read with ``Qwen2VLConfig.from_pretrained`` from the
        model path rather than taken from ``ctx``. It is cached on the
        instance because this method is called repeatedly, e.g. once per
        frame while searching the video frame budget.
        """
        # NOTE(review): presumably the ctx-level config of Tarsier2 is not a
        # Qwen2VLConfig, hence the explicit reload — confirm.
        cached = getattr(self, "_tarsier2_hf_config", None)
        if cached is None:
            model_path = self.ctx.model_config.model
            cached = Qwen2VLConfig.from_pretrained(model_path)
            self._tarsier2_hf_config = cached
        return cached

    def get_hf_processor(self, **kwargs: object) -> Tarsier2Processor:
        return Tarsier2Processor(
            vision_config=self.ctx.get_hf_image_processor_config(),
            tokenizer=self.get_tokenizer(),
            **kwargs,
        )

    def get_image_processor(self, **kwargs: object) -> Tarsier2ImageProcessor:
        """Build the image processor from the checkpoint's config.

        Accepts ``**kwargs`` like the base class (the inherited
        ``_get_prompt_updates`` passes the processor kwargs here). They
        override matching entries of the stored image-processor config.
        """
        processor_config = dict(self.ctx.get_hf_image_processor_config())
        processor_config.update(kwargs)
        return Tarsier2ImageProcessor(**processor_config)
1500
1501


1502
1503
1504
1505
1506
@MULTIMODAL_REGISTRY.register_processor(
    Tarsier2MultiModalProcessor,
    info=Tarsier2ProcessingInfo,
    dummy_inputs=Qwen2VLDummyInputsBuilder,
)
class Tarsier2ForConditionalGeneration(Qwen2VLForConditionalGeneration):
    """Qwen2-VL model loaded from Tarsier2 checkpoints."""

    # Tarsier2 checkpoints store the vision encoder under `vision_tower.`.
    hf_to_vllm_mapper = WeightsMapper(
        orig_to_new_prefix={"vision_tower.": "visual."}
    )

    def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]:
        """Load weights, skipping `visual.` ones when no vision tower exists."""
        skip_prefixes = ["visual."] if self.visual is None else []
        loader = AutoWeightsLoader(self, skip_prefixes=skip_prefixes)
        return loader.load_weights(weights, mapper=self.hf_to_vllm_mapper)