qwen2_vl.py 55.1 KB
Newer Older
1
# SPDX-License-Identifier: Apache-2.0
2
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
3

4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
# Adapted from
# https://github.com/huggingface/transformers/blob/19e6e80e10118f855137b90740936c0b11ac397f/src/transformers/models/qwen2_vl/modeling_qwen2_vl.py
# Copyright 2024 The Qwen team.
# Copyright 2023 The vLLM team.
# Copyright 2022 EleutherAI and the HuggingFace Inc. team. All rights reserved.
#
# This code is based on EleutherAI's GPT-NeoX library and the GPT-NeoX
# and OPT implementations in this library. It has been modified from its
# original forms to accommodate minor architectural differences compared
# to GPT-NeoX and OPT used by the Meta AI team that trained the model.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Inference-only Qwen2-VL model compatible with HuggingFace weights."""
27

28
from collections.abc import Callable, Iterable, Mapping, Sequence
29
from functools import partial
30
from typing import Annotated, Any, Literal, TypeAlias
31
32
33
34

import torch
import torch.nn as nn
import torch.nn.functional as F
35
from einops import rearrange
36
from transformers import BatchFeature
37
from transformers.models.qwen2_vl import Qwen2VLImageProcessor, Qwen2VLProcessor
38
from transformers.models.qwen2_vl.configuration_qwen2_vl import (
39
40
41
    Qwen2VLConfig,
    Qwen2VLVisionConfig,
)
42
from transformers.models.qwen2_vl.image_processing_qwen2_vl import smart_resize
43
from transformers.models.qwen2_vl.video_processing_qwen2_vl import Qwen2VLVideoProcessor
44

45
from vllm.attention.backends.registry import AttentionBackendEnum
46
47
48
49
from vllm.attention.layer import (
    check_upstream_fa_availability,
    maybe_get_vit_flash_attn_backend,
)
50
from vllm.config import VllmConfig
51
from vllm.config.multimodal import BaseDummyOptions
52
from vllm.distributed import parallel_state
53
54
55
from vllm.distributed import utils as dist_utils
from vllm.logger import init_logger
from vllm.model_executor.layers.activation import QuickGELU
56
from vllm.model_executor.layers.conv import Conv3dLayer
57
58
59
60
from vllm.model_executor.layers.linear import (
    ColumnParallelLinear,
    RowParallelLinear,
)
61
from vllm.model_executor.layers.quantization import QuantizationConfig
62
from vllm.model_executor.layers.rotary_embedding import get_rope
63
from vllm.model_executor.layers.rotary_embedding.common import (
64
    apply_rotary_emb_torch,
65
66
    dispatch_rotary_emb_function,
)
67
from vllm.model_executor.model_loader.weight_utils import default_weight_loader
68
from vllm.model_executor.models.module_mapping import MultiModelKeys
69
from vllm.multimodal import MULTIMODAL_REGISTRY
70
71
72
73
from vllm.multimodal.inputs import (
    ImageItem,
    ModalityData,
    MultiModalDataDict,
74
    MultiModalFeatureSpec,
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
    MultiModalFieldConfig,
    MultiModalKwargsItems,
    VideoItem,
)
from vllm.multimodal.parse import (
    DictEmbeddingItems,
    ImageSize,
    ModalityDataItems,
    MultiModalDataItems,
    MultiModalDataParser,
)
from vllm.multimodal.processing import (
    BaseMultiModalProcessor,
    BaseProcessingInfo,
    PromptReplacement,
    PromptUpdate,
)
92
from vllm.multimodal.profiling import BaseDummyInputsBuilder
93
from vllm.sequence import IntermediateTensors
94
from vllm.transformers_utils.tokenizer import AnyTokenizer
95
from vllm.utils.tensor_schema import TensorSchema, TensorShape
96

97
98
99
100
101
102
103
104
105
106
107
108
109
from .interfaces import (
    MultiModalEmbeddings,
    SupportsLoRA,
    SupportsMRoPE,
    SupportsMultiModal,
    SupportsPP,
)
from .utils import (
    AutoWeightsLoader,
    WeightsMapper,
    init_vllm_registered_model,
    maybe_prefix,
)
110
111
112
113
from .vision import (
    get_vit_attn_backend,
    run_dp_sharded_mrope_vision_model,
)
114

115
116
# Module-level logger, named after this module via ``__name__``.
logger = init_logger(__name__)

117
# For profile run
118
# NOTE(review): presumably the per-video frame cap applied when building
# dummy inputs for the profile run -- confirm against the code that reads it.
_MAX_FRAMES_PER_VIDEO = 14
119

120
121
122
# === Vision Inputs === #


123
class Qwen2VLImagePixelInputs(TensorSchema):
    """
    Schema for image inputs supplied as raw pixel patches.

    Dimensions:
        - np: The total number of patches over each image over each prompt in
              the batch
        - ni: Number of images
        - cps: Number of channels * patch_size * patch_size

    Historical context:
        - pixel_values shape: (num_patches, num_channels * patch_size *
          patch_size)
        - image_grid_thw shape: (num_images, 3) in (grid_t, grid_h, grid_w)
          format
    """

    type: Literal["pixel_values"]

    # Flattened patches, one row per patch.
    pixel_values: Annotated[torch.Tensor, TensorShape("np", "cps")]

    # One (grid_t, grid_h, grid_w) row per image.
    image_grid_thw: Annotated[torch.Tensor, TensorShape("ni", 3)]


class Qwen2VLImageEmbeddingInputs(TensorSchema):
    """
    Schema for image inputs supplied as precomputed embeddings.

    Dimensions:
        - nf: Number of image features
        - hs: Hidden size
        - ni: Number of images

    Historical context:
        - image_embeds shape: (num_image_features, hidden_size)
        - num_image_features varies based on the number and resolution of the
          images.
        - hidden_size must match the hidden size of language model backbone.
        - image_grid_thw shape: (num_images, 3) in (grid_t, grid_h, grid_w)
          format
    """

    type: Literal["image_embeds"]

    # One row per image feature.
    image_embeds: Annotated[torch.Tensor, TensorShape("nf", "hs")]

    # One (grid_t, grid_h, grid_w) row per image.
    image_grid_thw: Annotated[torch.Tensor, TensorShape("ni", 3)]
178
179


180
# An image input is either raw pixel patches or precomputed embeddings.
Qwen2VLImageInputs: TypeAlias = Qwen2VLImagePixelInputs | Qwen2VLImageEmbeddingInputs
181
182


183
184
185
186
187
class Qwen2VLVideoPixelInputs(TensorSchema):
    """
    Schema for video inputs supplied as raw pixel patches.

    Dimensions:
        - np: The total number of patches over each video over each prompt in
              the batch
        - ctps: Number of channels * temporal_patch_size * patch_size *
          patch_size
        - nv: Number of videos

    Historical context:
        - pixel_values_videos shape: (num_patches, num_channels *
          temporal_patch_size * patch_size * patch_size)
        - video_grid_thw shape: (num_videos, 3) in (grid_t, grid_h, grid_w)
          format
    """

    type: Literal["pixel_values_videos"]

    # Flattened patches, one row per patch.
    pixel_values_videos: Annotated[torch.Tensor, TensorShape("np", "ctps")]

    # One (grid_t, grid_h, grid_w) row per video.
    video_grid_thw: Annotated[torch.Tensor, TensorShape("nv", 3)]
210
211


212
213
214
215
216
217
class Qwen2VLVideoEmbeddingInputs(TensorSchema):
    """
    Schema for video inputs supplied as precomputed embeddings.

    Dimensions:
        - nf: Number of video features
        - hs: Hidden size
        - nv: Number of videos

    Historical context:
        - video_embeds shape: (num_video_features, hidden_size)
        - num_video_features varies based on the number and resolution of the
          videos.
        - hidden_size must match the hidden size of language model backbone.
        - video_grid_thw shape: (num_videos, 3) in (grid_t, grid_h, grid_w)
          format
    """

    type: Literal["video_embeds"]

    # One row per video feature.
    video_embeds: Annotated[torch.Tensor, TensorShape("nf", "hs")]

    # One (grid_t, grid_h, grid_w) row per video.
    video_grid_thw: Annotated[torch.Tensor, TensorShape("nv", 3)]
239
240


241
# A video input is either raw pixel patches or precomputed embeddings.
Qwen2VLVideoInputs: TypeAlias = Qwen2VLVideoPixelInputs | Qwen2VLVideoEmbeddingInputs
242

243
244
245
246
247
248
249
# === Vision Encoder === #


class Qwen2VisionMLP(nn.Module):
    """Two-layer feed-forward network of a vision block.

    ``fc1`` (column-parallel) maps ``in_features`` to ``hidden_features``,
    the activation built from ``act_layer`` is applied, and ``fc2``
    (row-parallel) maps back to ``in_features``.
    """

    def __init__(
        self,
        in_features: int,
        hidden_features: int,
        act_layer: type[nn.Module] = QuickGELU,
        quant_config: QuantizationConfig | None = None,
        prefix: str = "",
        use_data_parallel: bool = False,
    ):
        super().__init__()
        # Both projections share the quantization config and the switch that
        # disables tensor parallelism when the encoder runs data-parallel.
        shared_kwargs = dict(
            quant_config=quant_config,
            disable_tp=use_data_parallel,
        )
        self.fc1 = ColumnParallelLinear(
            in_features,
            hidden_features,
            prefix=f"{prefix}.fc1",
            **shared_kwargs,
        )
        self.act = act_layer()
        self.fc2 = RowParallelLinear(
            hidden_features,
            in_features,
            prefix=f"{prefix}.fc2",
            **shared_kwargs,
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Apply fc1 -> activation -> fc2 and return the fc2 output."""
        # The parallel linear layers return a tuple; only the first element
        # (the output tensor) is used here.
        hidden, _ = self.fc1(x)
        projected, _ = self.fc2(self.act(hidden))
        return projected


280
281
def apply_rotary_pos_emb_vision(
    t: torch.Tensor, cos: torch.Tensor, sin: torch.Tensor
) -> torch.Tensor:
    """Apply rotary position embedding to ``t`` using ``cos`` and ``sin``.

    The implementation is chosen by ``dispatch_rotary_emb_function``; the
    torch implementation in NeoX style is passed as its default. The result
    is cast back to the dtype of ``t``.
    """
    torch_fallback = partial(apply_rotary_emb_torch, is_neox_style=True)
    rotate = dispatch_rotary_emb_function(default=torch_fallback)
    return rotate(t, cos, sin).type_as(t)


class Qwen2VisionAttention(nn.Module):
    """Self-attention layer of the vision encoder.

    A fused ``qkv`` projection is split into per-head query/key/value,
    rotary position embedding is applied to query and key, and attention is
    evaluated independently inside each ``cu_seqlens`` segment by the
    selected backend (flash-attention, torch SDPA, or xFormers). The result
    is passed through the ``proj`` output projection.
    """

    def __init__(
        self,
        embed_dim: int,
        num_heads: int,
        projection_size: int,
        quant_config: QuantizationConfig | None = None,
        prefix: str = "",
        use_data_parallel: bool = False,
        attn_backend_override: AttentionBackendEnum | None = None,
    ) -> None:
        super().__init__()
        # Per attention head and per partition values.
        # With data parallelism the layer is not sharded, so the effective
        # tensor-parallel size is 1.
        self.tp_size = (
            1
            if use_data_parallel
            else parallel_state.get_tensor_model_parallel_world_size()
        )
        self.tp_rank = parallel_state.get_tensor_model_parallel_rank()
        self.hidden_size_per_attention_head = dist_utils.divide(
            projection_size, num_heads
        )
        self.num_attention_heads_per_partition = dist_utils.divide(
            num_heads, self.tp_size
        )

        self.qkv = ColumnParallelLinear(
            input_size=embed_dim,
            output_size=3 * projection_size,
            quant_config=quant_config,
            prefix=f"{prefix}.qkv",
            disable_tp=use_data_parallel,
        )
        self.proj = RowParallelLinear(
            input_size=projection_size,
            output_size=embed_dim,
            quant_config=quant_config,
            prefix=f"{prefix}.proj",
            disable_tp=use_data_parallel,
        )

        # Detect attention implementation.
        self.attn_backend = get_vit_attn_backend(
            head_size=self.hidden_size_per_attention_head,
            dtype=torch.get_default_dtype(),
            attn_backend_override=attn_backend_override,
        )
        self.use_upstream_fa = False

        # The helper may replace the detected backend and also returns the
        # varlen flash-attention callable used by ``forward``.
        self.attn_backend, self.flash_attn_varlen_func = (
            maybe_get_vit_flash_attn_backend(
                self.attn_backend,
                self.use_upstream_fa,
                attn_backend_override=attn_backend_override,
            )
        )

        if self.attn_backend not in {
            AttentionBackendEnum.FLASH_ATTN,
            AttentionBackendEnum.TORCH_SDPA,
            AttentionBackendEnum.XFORMERS,
            AttentionBackendEnum.ROCM_AITER_FA,
        }:
            raise RuntimeError(
                f"Qwen2-VL does not support {self.attn_backend} backend now."
            )

        self.is_flash_attn_backend = self.attn_backend in {
            AttentionBackendEnum.FLASH_ATTN,
            AttentionBackendEnum.ROCM_AITER_FA,
        }

    def split_qkv(self, qkv: torch.Tensor) -> tuple[torch.Tensor, ...]:
        """Split a fused qkv tensor into per-head query, key and value."""
        # [s, b, 3 * head * head_dim]
        seq_len, bs, _ = qkv.shape

        # [s, b, 3 * head * head_dim] -> 3 * [s, b, head * head_dim]
        q, k, v = qkv.chunk(3, dim=2)

        # 3 * [s, b, head * head_dim] -> 3 * [s, b, head, head_dim]
        new_shape = (
            seq_len,
            bs,
            self.num_attention_heads_per_partition,
            self.hidden_size_per_attention_head,
        )
        q, k, v = (x.view(*new_shape) for x in (q, k, v))
        return q, k, v

    def forward(
        self,
        x: torch.Tensor,
        cu_seqlens: torch.Tensor,
        rotary_pos_emb_cos: torch.Tensor,
        rotary_pos_emb_sin: torch.Tensor,
        max_seqlen: int | None = None,  # Only used for Flash Attention
        seqlens: list[int] | None = None,  # Only used for xFormers
    ) -> torch.Tensor:
        """Run self-attention over ``x``.

        Args:
            x: Input of shape ``[s, b, c]``.
            cu_seqlens: Cumulative segment boundaries along the sequence
                axis; attention never crosses a boundary.
            rotary_pos_emb_cos: Cosine table passed to the rotary embedding.
            rotary_pos_emb_sin: Sine table passed to the rotary embedding.
            max_seqlen: Longest segment length (flash-attention only).
            seqlens: Per-segment lengths (xFormers only).

        Returns:
            The output of the ``proj`` projection.

        Raises:
            RuntimeError: If the configured backend has no branch here.
        """
        # [s, b, c] --> [s, b, 3 * head * head_dim]
        x, _ = self.qkv(x)

        # [s, b, 3 * head * head_dim] -> 3 * [s, b, head, head_dim]
        q, k, v = self.split_qkv(x)
        batch_size = q.shape[1]

        q, k, v = (rearrange(x, "s b ... -> b s ...") for x in (q, k, v))

        # Rotate q and k in one call by stacking them on the batch axis.
        # [2 * b, s, heads, head_dim]
        qk_concat = torch.cat([q, k], dim=0)
        qk_rotated = apply_rotary_pos_emb_vision(
            qk_concat, rotary_pos_emb_cos, rotary_pos_emb_sin
        )
        q, k = torch.chunk(qk_rotated, 2, dim=0)

        if self.is_flash_attn_backend:
            q, k, v = (rearrange(x, "b s ... -> (b s) ...") for x in [q, k, v])

            output = self.flash_attn_varlen_func(
                q,
                k,
                v,
                cu_seqlens_q=cu_seqlens,
                cu_seqlens_k=cu_seqlens,
                max_seqlen_q=max_seqlen,
                max_seqlen_k=max_seqlen,
                dropout_p=0.0,
                causal=False,
            )

            context_layer = rearrange(
                output, "(b s) h d -> s b (h d)", b=batch_size
            ).contiguous()
        elif self.attn_backend == AttentionBackendEnum.TORCH_SDPA:
            # Execute attention entry by entry for speed & less VRAM.
            from vllm.platforms import current_platform

            if current_platform.is_rocm():
                q = q.contiguous()
                k = k.contiguous()
                v = v.contiguous()

            # Convert the boundaries to Python ints once instead of
            # converting a 0-dim tensor for every slice bound.
            boundaries = cu_seqlens.tolist()
            outputs = []
            for start_idx, end_idx in zip(boundaries[:-1], boundaries[1:]):
                q_i, k_i, v_i = (
                    rearrange(x[:, start_idx:end_idx], "b s h d -> b h s d")
                    for x in (q, k, v)
                )
                output_i = F.scaled_dot_product_attention(q_i, k_i, v_i, dropout_p=0.0)
                output_i = rearrange(output_i, "b h s d -> b s h d ")
                outputs.append(output_i)
            context_layer = torch.cat(outputs, dim=1)
            context_layer = rearrange(
                context_layer, "b s h d -> s b (h d)"
            ).contiguous()
        elif self.attn_backend == AttentionBackendEnum.XFORMERS:
            from xformers import ops as xops
            from xformers.ops.fmha.attn_bias import BlockDiagonalMask

            attn_bias = BlockDiagonalMask.from_seqlens(
                q_seqlen=seqlens, kv_seqlen=None, device=q.device
            )

            context_layer = xops.memory_efficient_attention_forward(
                q, k, v, attn_bias=attn_bias, p=0, scale=None
            )
            context_layer = rearrange(
                context_layer, "b s h d -> s b (h d)"
            ).contiguous()
        else:
            # __init__ already rejects other backends; fail with the same
            # message instead of an unbound ``context_layer``.
            raise RuntimeError(
                f"Qwen2-VL does not support {self.attn_backend} backend now."
            )

        output, _ = self.proj(context_layer)
        return output


class Qwen2VisionBlock(nn.Module):
    """One pre-norm encoder block: attention and MLP, each with a residual.

    ``norm1`` feeds the attention sub-layer and ``norm2`` feeds the MLP; the
    output of each sub-layer is added back to its input.
    """

    def __init__(
        self,
        dim: int,
        num_heads: int,
        mlp_ratio: float,
        act_layer: type[nn.Module] = QuickGELU,
        norm_layer: Callable[[int], nn.Module] | None = None,
        quant_config: QuantizationConfig | None = None,
        prefix: str = "",
        use_data_parallel: bool = False,
        attn_backend_override: AttentionBackendEnum | None = None,
    ) -> None:
        super().__init__()
        # Fall back to LayerNorm(eps=1e-6) when no factory is supplied.
        make_norm = (
            norm_layer
            if norm_layer is not None
            else partial(nn.LayerNorm, eps=1e-6)
        )
        self.norm1 = make_norm(dim)
        self.norm2 = make_norm(dim)

        self.attn = Qwen2VisionAttention(
            embed_dim=dim,
            num_heads=num_heads,
            projection_size=dim,
            quant_config=quant_config,
            prefix=f"{prefix}.attn",
            use_data_parallel=use_data_parallel,
            attn_backend_override=attn_backend_override,
        )
        self.mlp = Qwen2VisionMLP(
            dim,
            int(dim * mlp_ratio),
            act_layer=act_layer,
            quant_config=quant_config,
            prefix=f"{prefix}.mlp",
            use_data_parallel=use_data_parallel,
        )

    def forward(
        self,
        x: torch.Tensor,
        cu_seqlens: torch.Tensor,
        rotary_pos_emb_cos: torch.Tensor,
        rotary_pos_emb_sin: torch.Tensor,
        max_seqlen: int | None = None,  # Only used for Flash Attention
        seqlens: list[int] | None = None,  # Only used for xFormers
    ) -> torch.Tensor:
        """Apply the attention and MLP sub-layers with residual adds."""
        attn_out = self.attn(
            self.norm1(x),
            cu_seqlens=cu_seqlens,
            rotary_pos_emb_cos=rotary_pos_emb_cos,
            rotary_pos_emb_sin=rotary_pos_emb_sin,
            max_seqlen=max_seqlen,
            seqlens=seqlens,
        )
        x = x + attn_out

        mlp_out = self.mlp(self.norm2(x))
        return x + mlp_out


class Qwen2VisionPatchEmbed(nn.Module):
    """Embed flattened patches with a bias-free 3D convolution.

    The convolution's kernel and stride both equal
    ``(temporal_patch_size, patch_size, patch_size)``.
    """

    def __init__(
        self,
        patch_size: int = 14,
        temporal_patch_size: int = 2,
        in_channels: int = 3,
        embed_dim: int = 1152,
    ) -> None:
        super().__init__()
        self.patch_size = patch_size
        self.temporal_patch_size = temporal_patch_size
        self.embed_dim = embed_dim

        patch_volume = (temporal_patch_size, patch_size, patch_size)
        self.proj = Conv3dLayer(
            in_channels,
            embed_dim,
            kernel_size=patch_volume,
            stride=patch_volume,
            bias=False,
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Map a 2D tensor of flattened patches to ``(rows, embed_dim)``."""
        # The two-way unpack also requires ``x`` to be exactly 2D.
        num_patches, _ = x.shape
        patches = x.view(
            num_patches,
            -1,
            self.temporal_patch_size,
            self.patch_size,
            self.patch_size,
        )
        return self.proj(patches).view(num_patches, self.embed_dim)


class Qwen2VisionPatchMerger(nn.Module):
    """Merge groups of patch features and project them to ``d_model``.

    Features are layer-normalised, viewed as rows of width
    ``context_dim * spatial_merge_size**2``, and passed through a
    linear -> GELU -> linear stack.
    """

    def __init__(
        self,
        d_model: int,
        context_dim: int,
        norm_layer: Callable[[int], nn.Module] | None = None,
        spatial_merge_size: int = 2,
        quant_config: QuantizationConfig | None = None,
        prefix: str = "",
        use_data_parallel: bool = False,
    ) -> None:
        super().__init__()
        self.hidden_size = context_dim * (spatial_merge_size**2)

        # Fall back to LayerNorm(eps=1e-6) when no factory is supplied.
        make_norm = (
            norm_layer
            if norm_layer is not None
            else partial(nn.LayerNorm, eps=1e-6)
        )
        self.ln_q = make_norm(context_dim)

        # Settings shared by both linear layers of the stack.
        linear_kwargs = dict(
            bias=True,
            quant_config=quant_config,
            disable_tp=use_data_parallel,
        )
        first_linear = ColumnParallelLinear(
            self.hidden_size,
            self.hidden_size,
            prefix=f"{prefix}.mlp.0",
            **linear_kwargs,
        )
        second_linear = RowParallelLinear(
            self.hidden_size,
            d_model,
            prefix=f"{prefix}.mlp.2",
            **linear_kwargs,
        )
        self.mlp = nn.ModuleList([first_linear, nn.GELU(), second_linear])

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Normalise, regroup to ``hidden_size``-wide rows and project."""
        first_linear, activation, second_linear = self.mlp
        merged = self.ln_q(x).view(-1, self.hidden_size)

        # The parallel linear layers return a tuple; keep only the output.
        hidden, _ = first_linear(merged)
        out, _ = second_linear(activation(hidden))
        return out


class Qwen2VisionTransformer(nn.Module):
    """Vision encoder: patch embedding, encoder blocks, then patch merger.

    Rotary position tables are derived from the (t, h, w) grid of every
    image/video, and ``cu_seqlens`` confines attention to a single frame
    (``h * w`` patches) at a time.
    """

    def __init__(
        self,
        vision_config: Qwen2VLVisionConfig,
        norm_eps: float = 1e-6,
        quant_config: QuantizationConfig | None = None,
        prefix: str = "",
        use_data_parallel: bool = False,
        attn_backend_override: AttentionBackendEnum | None = None,
    ) -> None:
        super().__init__()

        embed_dim = vision_config.embed_dim
        num_heads = vision_config.num_heads
        head_dim = embed_dim // num_heads
        default_dtype = torch.get_default_dtype()

        self.use_data_parallel = use_data_parallel
        self.out_hidden_size = vision_config.hidden_size
        self.spatial_merge_size = vision_config.spatial_merge_size
        self.num_heads = num_heads
        self.embed_dim = embed_dim

        self.patch_embed = Qwen2VisionPatchEmbed(
            patch_size=vision_config.patch_size,
            temporal_patch_size=vision_config.temporal_patch_size,
            in_channels=vision_config.in_channels,
            embed_dim=embed_dim,
        )

        make_norm = partial(nn.LayerNorm, eps=norm_eps)
        self.rotary_pos_emb = get_rope(
            head_size=head_dim,
            rotary_dim=head_dim // 2,
            max_position=8192,
            base=10000.0,
            is_neox_style=True,
        )

        encoder_blocks = []
        for layer_idx in range(vision_config.depth):
            encoder_blocks.append(
                Qwen2VisionBlock(
                    dim=embed_dim,
                    num_heads=num_heads,
                    mlp_ratio=vision_config.mlp_ratio,
                    norm_layer=make_norm,
                    quant_config=quant_config,
                    prefix=f"{prefix}.blocks.{layer_idx}",
                    use_data_parallel=use_data_parallel,
                    attn_backend_override=attn_backend_override,
                )
            )
        self.blocks = nn.ModuleList(encoder_blocks)

        self.merger = Qwen2VisionPatchMerger(
            d_model=vision_config.hidden_size,
            context_dim=embed_dim,
            norm_layer=make_norm,
            quant_config=quant_config,
            prefix=f"{prefix}.merger",
            use_data_parallel=use_data_parallel,
        )

        # Backend used only to decide which sequence-length metadata
        # ``compute_attn_mask_seqlen`` prepares for the blocks.
        # NOTE(review): the promotion to FLASH_ATTN below does not consult
        # ``attn_backend_override``, so this value may differ from the
        # backend each attention layer selected -- confirm this is intended.
        backend = get_vit_attn_backend(
            head_size=head_dim,
            dtype=default_dtype,
            attn_backend_override=attn_backend_override,
        )
        if (
            backend != AttentionBackendEnum.FLASH_ATTN
            and check_upstream_fa_availability(default_dtype)
        ):
            backend = AttentionBackendEnum.FLASH_ATTN
        self.attn_backend = backend

    @property
    def dtype(self) -> torch.dtype:
        """Dtype of the patch-embedding projection weight."""
        return self.patch_embed.proj.weight.dtype

    @property
    def device(self) -> torch.device:
        """Device of the patch-embedding projection weight."""
        return self.patch_embed.proj.weight.device

    def rot_pos_emb(
        self, grid_thw: list[list[int]]
    ) -> tuple[torch.Tensor, torch.Tensor]:
        """Build the rotary cos/sin tables for every patch of every grid.

        Args:
            grid_thw: One ``(t, h, w)`` triple per image/video.

        Returns:
            ``(cos, sin)`` gathered at each patch's (row, column) index and
            flattened from dim 1.
        """
        merge = self.spatial_merge_size

        def to_merge_order(ids: torch.Tensor, h: int, w: int) -> torch.Tensor:
            # Reorder an (h, w) index map so that the patches of each
            # merge x merge window become consecutive.
            windows = ids.reshape(h // merge, merge, w // merge, merge)
            return windows.permute(0, 2, 1, 3).flatten()

        per_grid_ids = []
        max_grid_size = 0
        for t, h, w in grid_thw:
            row_ids = torch.arange(h).unsqueeze(1).expand(-1, w)
            col_ids = torch.arange(w).unsqueeze(0).expand(h, -1)
            hw_ids = torch.stack(
                [to_merge_order(row_ids, h, w), to_merge_order(col_ids, h, w)],
                dim=-1,
            )
            # The same spatial indices are repeated for each of the t frames.
            per_grid_ids.append(hw_ids.repeat(t, 1))
            max_grid_size = max(max_grid_size, h, w)
        pos_ids = torch.cat(per_grid_ids, dim=0)

        # Use pre-computed cos_sin_cache from RotaryEmbedding
        cos, sin = self.rotary_pos_emb.get_cos_sin(max_grid_size)

        return cos[pos_ids].flatten(1), sin[pos_ids].flatten(1)

    def compute_attn_mask_seqlen(
        self, cu_seqlens: torch.Tensor
    ) -> tuple[int | None, list[int] | None]:
        """Derive the backend-specific sequence-length metadata.

        Returns:
            ``(max_seqlen, seqlens)``: ``max_seqlen`` is set for the
            flash-attention backends, ``seqlens`` for xFormers; anything
            not needed is ``None``.
        """
        if self.attn_backend in {
            AttentionBackendEnum.FLASH_ATTN,
            AttentionBackendEnum.ROCM_AITER_FA,
        }:
            return (cu_seqlens[1:] - cu_seqlens[:-1]).max().item(), None
        if self.attn_backend == AttentionBackendEnum.XFORMERS:
            return None, (cu_seqlens[1:] - cu_seqlens[:-1]).tolist()
        return None, None

    def forward(
        self,
        x: torch.Tensor,
        grid_thw: torch.Tensor | list[list[int]],
    ) -> torch.Tensor:
        """Encode flattened patches into merged visual features.

        Args:
            x: Flattened patches, moved to this module's device and dtype.
            grid_thw: ``(t, h, w)`` per image/video, as a tensor or a list.

        Returns:
            The output of the patch merger.
        """
        # patchify
        x = self.patch_embed(x.to(device=self.device, dtype=self.dtype))

        # Keep both a list view (for Python loops) and a tensor view.
        if isinstance(grid_thw, list):
            grid_thw_list = grid_thw
            grid_thw = torch.tensor(grid_thw, dtype=torch.int32)
        else:
            grid_thw_list = grid_thw.tolist()

        # compute position embedding
        rotary_pos_emb_cos, rotary_pos_emb_sin = self.rot_pos_emb(grid_thw_list)

        # compute cu_seqlens: each frame contributes one h * w segment
        patches_per_frame = grid_thw[:, 1] * grid_thw[:, 2]
        cu_seqlens = torch.repeat_interleave(
            patches_per_frame, grid_thw[:, 0]
        ).cumsum(dim=0, dtype=torch.int32)
        cu_seqlens = torch.cat([cu_seqlens.new_zeros(1), cu_seqlens])
        cu_seqlens = cu_seqlens.to(self.device, non_blocking=True)

        # transformers
        x = x.unsqueeze(1)

        # pre-compute seqlens for attn mask to reduce cuMemcpy operations
        max_seqlen, seqlens = self.compute_attn_mask_seqlen(cu_seqlens)
        for block in self.blocks:
            x = block(
                x,
                cu_seqlens=cu_seqlens,
                rotary_pos_emb_cos=rotary_pos_emb_cos,
                rotary_pos_emb_sin=rotary_pos_emb_sin,
                max_seqlen=max_seqlen,
                seqlens=seqlens,
            )

        # adapter
        return self.merger(x)

    def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]:
        """Load ``(name, tensor)`` pairs into this module's parameters.

        Names containing a shard name from the stacked mapping are renamed
        to the fused parameter and loaded with the shard id; all others use
        the parameter's ``weight_loader`` or ``default_weight_loader``.

        Returns:
            The set of parameter names that were loaded.
        """
        stacked_params_mapping = [
            # (param_name, shard_name, shard_id)
            ("qkv_proj", "q_proj", "q"),
            ("qkv_proj", "k_proj", "k"),
            ("qkv_proj", "v_proj", "v"),
        ]
        params_dict = dict(self.named_parameters(remove_duplicate=False))
        loaded_params: set[str] = set()

        for name, loaded_weight in weights:
            # The first mapping entry whose shard name occurs in the name.
            stacked = next(
                (entry for entry in stacked_params_mapping if entry[1] in name),
                None,
            )
            if stacked is None:
                param = params_dict[name]
                loader = getattr(param, "weight_loader", default_weight_loader)
                loader(param, loaded_weight)
            else:
                param_name, weight_name, shard_id = stacked
                name = name.replace(weight_name, param_name)
                param = params_dict[name]
                param.weight_loader(param, loaded_weight, shard_id)
            loaded_params.add(name)
        return loaded_params

816

817
def _create_qwen2vl_field_factory(
    spatial_merge_size: int,
) -> Callable[
    [Mapping[str, torch.Tensor]],
    Mapping[str, MultiModalFieldConfig],
]:
    """Return a function that maps HF inputs to multimodal field configs.

    The returned function sizes pixel fields by the product of each grid
    row and embedding fields by that product floor-divided twice by
    ``spatial_merge_size``.
    """

    def _qwen2vl_field_config(hf_inputs: Mapping[str, torch.Tensor]):
        def grid_sizes(grid_key: str) -> tuple[torch.Tensor, torch.Tensor]:
            # A missing grid is treated as an empty (0, 3) tensor.
            grid_thw = hf_inputs.get(grid_key, torch.empty((0, 3)))
            pixel_sizes = grid_thw.prod(-1)
            embed_sizes = pixel_sizes // spatial_merge_size // spatial_merge_size
            return pixel_sizes, embed_sizes

        image_pixel_sizes, image_embed_sizes = grid_sizes("image_grid_thw")
        video_pixel_sizes, video_embed_sizes = grid_sizes("video_grid_thw")

        flat = MultiModalFieldConfig.flat_from_sizes
        batched = MultiModalFieldConfig.batched
        return {
            "pixel_values": flat("image", image_pixel_sizes),
            "image_embeds": flat("image", image_embed_sizes),
            "image_grid_thw": batched("image"),
            "pixel_values_videos": flat("video", video_pixel_sizes),
            "video_embeds": flat("video", video_embed_sizes),
            "video_grid_thw": batched("video"),
        }

    return _qwen2vl_field_config
854

855

Roger Wang's avatar
Roger Wang committed
856
class Qwen2VLMultiModalDataParser(MultiModalDataParser):
    """Data parser that also accepts dicts of precomputed Qwen2-VL embeddings.

    A ``dict`` input for a modality is wrapped in ``DictEmbeddingItems`` and
    must contain ``<modality>_embeds`` and ``<modality>_grid_thw``; any other
    input is delegated to the base parser.
    """

    def __init__(self, spatial_merge_size: int, *args, **kwargs):
        # Stored before delegating so it is available to the parse hooks.
        self._spatial_merge_size = spatial_merge_size
        super().__init__(*args, **kwargs)

    def _embedding_items(
        self,
        data: dict[str, torch.Tensor],
        modality: str,
    ) -> ModalityDataItems[Any, Any]:
        """Wrap a dict of precomputed embeddings for ``modality``."""
        return DictEmbeddingItems(
            data,
            modality=modality,
            required_fields={f"{modality}_embeds", f"{modality}_grid_thw"},
            fields_factory=_create_qwen2vl_field_factory(self._spatial_merge_size),
        )

    def _parse_image_data(
        self,
        data: dict[str, torch.Tensor] | ModalityData[ImageItem],
    ) -> ModalityDataItems[Any, Any] | None:
        if not isinstance(data, dict):
            return super()._parse_image_data(data)

        return self._embedding_items(data, "image")

    def _parse_video_data(
        self,
        data: dict[str, torch.Tensor] | ModalityData[VideoItem],
    ) -> ModalityDataItems[Any, Any] | None:
        if not isinstance(data, dict):
            return super()._parse_video_data(data)

        return self._embedding_items(data, "video")


890
891
class Qwen2VLProcessingInfo(BaseProcessingInfo):
    """Processing metadata for Qwen2-VL.

    Gives access to the HF config/processor through ``self.ctx`` and computes
    how many vision tokens an image or video of a given size expands to.
    """

    def get_hf_config(self):
        """Return the HF config, requested from the context as Qwen2VLConfig."""
        return self.ctx.get_hf_config(Qwen2VLConfig)

    def get_hf_processor(self, **kwargs: object) -> Qwen2VLProcessor:
        """Return the HF processor built by the context.

        ``use_fast`` defaults to ``True`` unless the caller supplies it; all
        other keyword arguments are forwarded unchanged.
        """
        return self.ctx.get_hf_processor(
            Qwen2VLProcessor,
            use_fast=kwargs.pop("use_fast", True),
            **kwargs,
        )

    def get_image_processor(self, **kwargs: object) -> Qwen2VLImageProcessor:
        """Return the image processor owned by the HF processor."""
        return self.get_hf_processor(**kwargs).image_processor

    def get_supported_mm_limits(self) -> Mapping[str, int | None]:
        """Declare the supported modalities ("image" and "video").

        NOTE(review): ``None`` presumably means "no fixed per-prompt cap";
        confirm against ``BaseProcessingInfo``.
        """
        return {"image": None, "video": None}

    def get_mm_max_tokens_per_item(
        self,
        seq_len: int,
        mm_counts: Mapping[str, int],
    ) -> Mapping[str, int]:
        """Return the maximum token count of a single image and a single video."""
        max_image_tokens = self.get_max_image_tokens()
        max_video_tokens = self.get_max_video_tokens(seq_len, mm_counts)
        return {"image": max_image_tokens, "video": max_video_tokens}

    def _get_vision_info(
        self,
        *,
        image_width: int,
        image_height: int,
        num_frames: int = 1,
        do_resize: bool = True,
        image_processor: Qwen2VLImageProcessor | None,
    ) -> tuple[ImageSize, int]:
        """Compute the preprocessed size and vision-token count of one item.

        Args:
            image_width: Input width in pixels.
            image_height: Input height in pixels.
            num_frames: Number of frames (1 for a still image).
            do_resize: When true, the size is passed through ``smart_resize``
                using the processor's ``min_pixels``/``max_pixels``;
                otherwise the input size is used as-is.
            image_processor: Processor supplying the pixel bounds; fetched
                with ``get_image_processor()`` when ``None``.

        Returns:
            ``(preprocessed_size, num_vision_tokens)``.
        """
        if image_processor is None:
            image_processor = self.get_image_processor()

        hf_config = self.get_hf_config()
        vision_config = hf_config.vision_config
        patch_size = vision_config.patch_size
        merge_size = vision_config.spatial_merge_size
        temporal_patch_size = vision_config.temporal_patch_size

        if do_resize:
            resized_height, resized_width = smart_resize(
                height=image_height,
                width=image_width,
                factor=patch_size * merge_size,
                min_pixels=image_processor.min_pixels,
                max_pixels=image_processor.max_pixels,
            )
            preprocessed_size = ImageSize(width=resized_width, height=resized_height)
        else:
            preprocessed_size = ImageSize(width=image_width, height=image_height)

        # NOTE: Frames are padded to be divisible by `temporal_patch_size`
        # https://github.com/huggingface/transformers/blob/v4.48.3/src/transformers/models/qwen2_vl/image_processing_qwen2_vl.py#L294
        # Round up to the next multiple. The previous form
        # `num_frames + num_frames % temporal_patch_size` is only a multiple
        # when temporal_patch_size is 1 or 2 (e.g. 4 frames, size 3 -> 5).
        padded_num_frames = num_frames + (-num_frames % temporal_patch_size)

        grid_t = max(padded_num_frames // temporal_patch_size, 1)
        grid_h = preprocessed_size.height // patch_size
        grid_w = preprocessed_size.width // patch_size

        num_patches = grid_t * grid_h * grid_w
        num_vision_tokens = num_patches // (merge_size**2)

        return preprocessed_size, num_vision_tokens

    def get_num_image_tokens(
        self,
        *,
        image_width: int,
        image_height: int,
        image_processor: Qwen2VLImageProcessor | None,
    ) -> int:
        """Return the vision-token count of a single image of the given size."""
        _, num_image_tokens = self._get_vision_info(
            image_width=image_width,
            image_height=image_height,
            num_frames=1,
            image_processor=image_processor,
        )
        return num_image_tokens

    def get_num_video_tokens(
        self,
        *,
        image_width: int,
        image_height: int,
        num_frames: int,
        image_processor: Qwen2VLImageProcessor | None,
    ) -> int:
        """Return the vision-token count of a video with ``num_frames`` frames."""
        _, num_video_tokens = self._get_vision_info(
            image_width=image_width,
            image_height=image_height,
            num_frames=num_frames,
            image_processor=image_processor,
        )
        return num_video_tokens

    def get_image_size_with_most_features(self) -> ImageSize:
        """Return the size an oversized (9999999 x 9999999) input is resized to."""
        max_image_size, _ = self._get_vision_info(
            image_width=9999999,
            image_height=9999999,
            num_frames=1,
            image_processor=None,
        )
        return max_image_size

    def get_max_image_tokens(self) -> int:
        """Return the token count of an image at the largest resized size."""
        target_width, target_height = self.get_image_size_with_most_features()

        return self.get_num_image_tokens(
            image_width=target_width,
            image_height=target_height,
            image_processor=None,
        )

    def _get_max_video_frames(self, max_tokens: int, start_num_frames: int = 1) -> int:
        """Return the largest frame count that stays within ``max_tokens``.

        Frames are counted at the largest resized image size, starting from
        ``start_num_frames``; the result is never below ``start_num_frames``,
        even if that count already exceeds the budget.
        """
        target_width, target_height = self.get_image_size_with_most_features()

        num_frames = start_num_frames

        while True:
            next_num_frames = num_frames + 1
            next_max_tokens = self.get_num_video_tokens(
                image_width=target_width,
                image_height=target_height,
                num_frames=next_num_frames,
                image_processor=None,
            )

            if next_max_tokens > max_tokens:
                break

            num_frames = next_num_frames

        return num_frames

    def get_num_frames_with_most_features(
        self,
        seq_len: int,
        mm_counts: Mapping[str, int],
        max_frames_per_video: int = _MAX_FRAMES_PER_VIDEO,
    ) -> int:
        """Return the per-video frame count used for worst-case sizing.

        The frames that fit in ``seq_len`` tokens are shared evenly between
        the configured number of videos, capped at ``max_frames_per_video``
        and floored at 1.
        """
        max_videos = mm_counts.get("video", 0)

        max_total_frames = self._get_max_video_frames(seq_len)
        max_frames_per_video = min(
            max_total_frames // max(max_videos, 1), max_frames_per_video
        )

        return max(max_frames_per_video, 1)

    def get_max_video_tokens(
        self,
        seq_len: int,
        mm_counts: Mapping[str, int],
    ) -> int:
        """Return the token count of a worst-case video for this budget."""
        target_width, target_height = self.get_image_size_with_most_features()

        return self.get_num_video_tokens(
            image_width=target_width,
            image_height=target_height,
            num_frames=self.get_num_frames_with_most_features(seq_len, mm_counts),
            image_processor=None,
        )

1058
1059

class Qwen2VLDummyInputsBuilder(BaseDummyInputsBuilder[Qwen2VLProcessingInfo]):
    """Builds placeholder text and dummy image/video data for Qwen2-VL."""

    def get_dummy_text(self, mm_counts: Mapping[str, int]) -> str:
        """Return one image token per image followed by one video token per video."""
        image_count = mm_counts.get("image", 0)
        video_count = mm_counts.get("video", 0)

        processor = self.info.get_hf_processor()
        image_part: str = processor.image_token * image_count
        video_part: str = processor.video_token * video_count

        return image_part + video_part

    def get_dummy_mm_data(
        self,
        seq_len: int,
        mm_counts: Mapping[str, int],
        mm_options: Mapping[str, BaseDummyOptions] | None = None,
    ) -> MultiModalDataDict:
        """Return dummy images and videos at the largest supported size.

        Args:
            seq_len: Token budget used to pick the dummy video frame count.
            mm_counts: Number of items to generate per modality.
            mm_options: Optional per-modality overrides forwarded to the
                dummy-data helpers.
        """
        image_count = mm_counts.get("image", 0)
        video_count = mm_counts.get("video", 0)

        width, height = self.info.get_image_size_with_most_features()
        frame_count = self.info.get_num_frames_with_most_features(seq_len, mm_counts)

        if mm_options:
            image_overrides = mm_options.get("image")
            video_overrides = mm_options.get("video")
        else:
            image_overrides = None
            video_overrides = None

        dummy_images = self._get_dummy_images(
            width=width,
            height=height,
            num_images=image_count,
            overrides=image_overrides,
        )
        dummy_videos = self._get_dummy_videos(
            width=width,
            height=height,
            num_frames=frame_count,
            num_videos=video_count,
            overrides=video_overrides,
        )
        return {"image": dummy_images, "video": dummy_videos}

1103

1104
class Qwen2VLMultiModalProcessor(BaseMultiModalProcessor[Qwen2VLProcessingInfo]):
    """Multimodal processor wiring Qwen2-VL parsing, fields and prompt updates."""

    def _get_data_parser(self) -> MultiModalDataParser:
        """Return a parser configured with the model's spatial merge size."""
        merge_size = self.info.get_hf_config().vision_config.spatial_merge_size
        return Qwen2VLMultiModalDataParser(merge_size)

    def _get_prompt_updates(
        self,
        mm_items: MultiModalDataItems,
        hf_processor_mm_kwargs: Mapping[str, Any],
        out_mm_kwargs: MultiModalKwargsItems,
    ) -> Sequence[PromptUpdate]:
        """Expand each image/video placeholder token to its per-item length.

        The expanded length of an item is the product of its ``*_grid_thw``
        entries divided by the image processor's ``merge_size`` squared.
        """
        processor = self.info.get_hf_processor(**hf_processor_mm_kwargs)
        image_processor = self.info.get_image_processor(**hf_processor_mm_kwargs)
        vocab = self.info.get_tokenizer().get_vocab()

        pad_token_ids = {
            "image": vocab[processor.image_token],
            "video": vocab[processor.video_token],
        }
        merge_length = image_processor.merge_size**2

        def expand_placeholder(item_idx: int, modality: str):
            item_kwargs = out_mm_kwargs[modality][item_idx]
            grid_thw = item_kwargs[f"{modality}_grid_thw"].data
            assert isinstance(grid_thw, torch.Tensor)

            repeat = int(grid_thw.prod()) // merge_length
            return [pad_token_ids[modality]] * repeat

        updates = []
        for modality in ("image", "video"):
            updates.append(
                PromptReplacement(
                    modality=modality,
                    target=[pad_token_ids[modality]],
                    replacement=partial(expand_placeholder, modality=modality),
                )
            )
        return updates

    def _get_mm_fields_config(
        self,
        hf_inputs: BatchFeature,
        hf_processor_mm_kwargs: Mapping[str, object],
    ) -> Mapping[str, MultiModalFieldConfig]:
        """Describe the layout of the HF outputs using the shared field factory."""
        merge_size = self.info.get_hf_config().vision_config.spatial_merge_size
        field_factory = _create_qwen2vl_field_factory(merge_size)
        return field_factory(hf_inputs)


@MULTIMODAL_REGISTRY.register_processor(
    Qwen2VLMultiModalProcessor,
    info=Qwen2VLProcessingInfo,
    dummy_inputs=Qwen2VLDummyInputsBuilder,
)
class Qwen2VLForConditionalGeneration(
    nn.Module, SupportsMultiModal, SupportsLoRA, SupportsPP, SupportsMRoPE
):
    """Qwen2-VL: a vision transformer feeding a Qwen2 causal language model."""

    # Class-level switches read by the surrounding framework.
    # NOTE(review): their exact semantics are defined outside this file.
    merge_by_field_config = True
    supports_encoder_tp_data = True

    multimodal_cpu_fields = {"image_grid_thw", "video_grid_thw"}

    # Prefix rewrites applied to checkpoint names so weights load correctly.
    hf_to_vllm_mapper = WeightsMapper(
        orig_to_new_prefix={
            # Names used by checkpoints saved after transformers v4.52.
            "model.language_model.": "language_model.model.",
            "model.visual.": "visual.",
            # Names used by the original checkpoints.
            "lm_head.": "language_model.lm_head.",
            "model.": "language_model.model.",
        }
    )

1180
1181
1182
    def get_mrope_input_positions(
        self,
        input_tokens: list[int],
        mm_features: list[MultiModalFeatureSpec],
    ) -> tuple[torch.Tensor, int]:
        """Compute the three-row M-RoPE position ids for one prompt.

        Text tokens get the same increasing index on all three rows. Each
        image/video placeholder span gets (temporal, height, width) indices
        derived from its ``grid_thw``, offset so they continue after the
        preceding text.

        Args:
            input_tokens: Prompt token ids.
            mm_features: Multimodal features from which ``image_grid_thw``,
                ``video_grid_thw`` and ``second_per_grid_ts`` are gathered.

        Returns:
            ``(positions, delta)`` where ``positions`` has three rows and
            ``delta`` is ``positions.max() + 1 - len(input_tokens)``.
        """
        mm_kwargs = MultiModalFeatureSpec.gather_kwargs(
            mm_features,
            {"image_grid_thw", "video_grid_thw", "second_per_grid_ts"},
        )
        image_grids = [grid.tolist() for grid in mm_kwargs.get("image_grid_thw", [])]
        video_grids = [grid.tolist() for grid in mm_kwargs.get("video_grid_thw", [])]
        second_per_grid_ts = mm_kwargs.get("second_per_grid_ts", [])

        cfg = self.config
        image_token_id = cfg.image_token_id
        video_token_id = cfg.video_token_id
        vision_start_token_id = cfg.vision_start_token_id
        merge_size = cfg.vision_config.spatial_merge_size
        tokens_per_second = getattr(cfg.vision_config, "tokens_per_second", 1.0)

        num_tokens = len(input_tokens)
        # Larger than any valid index, so a missing placeholder never wins
        # the "which item comes first" comparison below.
        not_found = num_tokens + 1

        # The token right after each vision-start marker tells whether the
        # item is an image or a video.
        token_tensor = torch.tensor(input_tokens)
        start_positions = torch.argwhere(
            token_tensor == vision_start_token_id
        ).squeeze(1)
        first_vision_tokens = token_tensor[start_positions + 1]
        remaining_images = (first_vision_tokens == image_token_id).sum()
        remaining_videos = (first_vision_tokens == video_token_id).sum()
        num_items = remaining_images + remaining_videos

        chunks: list = []

        def next_offset():
            # Positions continue right after the largest index emitted so far.
            return chunks[-1].max() + 1 if len(chunks) > 0 else 0

        def text_block(length, offset):
            return torch.arange(length).view(1, -1).expand(3, -1) + offset

        def find_placeholder(token_id, remaining, start):
            if remaining > 0:
                try:
                    return input_tokens.index(token_id, start)
                except ValueError:
                    pass
            return not_found

        cursor = 0
        image_idx, video_idx = 0, 0
        for _ in range(num_items):
            image_at = find_placeholder(image_token_id, remaining_images, cursor)
            video_at = find_placeholder(video_token_id, remaining_videos, cursor)

            if image_at < video_at:
                t, h, w = image_grids[image_idx]
                # A zero scale keeps the temporal index of images at 0.
                seconds_per_grid = 0.0
                image_idx += 1
                remaining_images -= 1
                item_at = image_at
            else:
                t, h, w = video_grids[video_idx]
                seconds_per_grid = 1.0
                if second_per_grid_ts:
                    seconds_per_grid = second_per_grid_ts[video_idx]
                video_idx += 1
                remaining_videos -= 1
                item_at = video_at

            grid_t = t
            grid_h = h // merge_size
            grid_w = w // merge_size

            # Text between the previous item (or prompt start) and this one.
            text_len = item_at - cursor
            offset = next_offset()
            chunks.append(text_block(text_len, offset))

            t_index = (
                (
                    torch.arange(grid_t)
                    .view(-1, 1)
                    .expand(-1, grid_h * grid_w)
                    * seconds_per_grid
                    * tokens_per_second
                )
                .long()
                .flatten()
            )
            h_index = (
                torch.arange(grid_h)
                .view(1, -1, 1)
                .expand(grid_t, -1, grid_w)
                .flatten()
            )
            w_index = (
                torch.arange(grid_w)
                .view(1, 1, -1)
                .expand(grid_t, grid_h, -1)
                .flatten()
            )
            chunks.append(torch.stack([t_index, h_index, w_index]) + text_len + offset)

            cursor = item_at + grid_t * grid_h * grid_w

        # Trailing text after the last multimodal item.
        if cursor < num_tokens:
            chunks.append(text_block(num_tokens - cursor, next_offset()))

        llm_positions = torch.cat(chunks, dim=1).reshape(3, -1)
        mrope_position_delta = (llm_positions.max() + 1 - num_tokens).item()

        return llm_positions, mrope_position_delta

1296
    @classmethod
    def get_placeholder_str(cls, modality: str, i: int) -> str | None:
        """Return the prompt placeholder for an image or video item.

        Raises:
            ValueError: If ``modality`` starts with neither "image" nor "video".
        """
        placeholders = (
            ("image", "<|vision_start|><|image_pad|><|vision_end|>"),
            ("video", "<|vision_start|><|video_pad|><|vision_end|>"),
        )
        for prefix, placeholder in placeholders:
            if modality.startswith(prefix):
                return placeholder

        raise ValueError("Only image or video modality is supported")

1305
    def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
        """Build the vision tower (when enabled) and the Qwen2 language model.

        Args:
            vllm_config: Engine configuration; the HF config, quantization
                config and multimodal config are read from it.
            prefix: Parameter-name prefix for this module's submodules.
        """
        super().__init__()
        config: Qwen2VLConfig = vllm_config.model_config.hf_config
        quant_config = vllm_config.quant_config
        multimodal_config = vllm_config.model_config.multimodal_config

        self.use_data_parallel = multimodal_config.mm_encoder_tp_mode == "data"
        self.config = config
        self.multimodal_config = multimodal_config

        # Only build the vision tower when at least one image or video is
        # allowed per prompt.
        if multimodal_config.get_limit_per_prompt(
            "image"
        ) or multimodal_config.get_limit_per_prompt("video"):
            # `multimodal_config` is already dereferenced unconditionally
            # above, so the former `is not None` guard here could never take
            # its fallback branch and has been dropped.
            self.visual = Qwen2VisionTransformer(
                config.vision_config,
                norm_eps=getattr(config, "rms_norm_eps", 1e-6),
                quant_config=quant_config,
                prefix=maybe_prefix(prefix, "visual"),
                use_data_parallel=self.use_data_parallel,
                attn_backend_override=multimodal_config.mm_encoder_attn_backend,
            )
        else:
            self.visual = None

        self.language_model = init_vllm_registered_model(
            vllm_config=vllm_config,
            prefix=maybe_prefix(prefix, "language_model"),
            architectures=["Qwen2ForCausalLM"],
        )

        self.make_empty_intermediate_tensors = (
            self.language_model.make_empty_intermediate_tensors
        )
1343
1344

    def _parse_and_validate_image_input(
        self, **kwargs: object
    ) -> Qwen2VLImageInputs | None:
        """Wrap image kwargs into a typed input, preferring raw pixel values.

        Returns ``None`` when neither ``pixel_values`` nor ``image_embeds``
        is provided.
        """
        pixel_values = kwargs.pop("pixel_values", None)
        image_embeds = kwargs.pop("image_embeds", None)
        image_grid_thw = kwargs.pop("image_grid_thw", None)

        if pixel_values is not None:
            return Qwen2VLImagePixelInputs(
                type="pixel_values",
                pixel_values=pixel_values,
                image_grid_thw=image_grid_thw,
            )

        if image_embeds is not None:
            return Qwen2VLImageEmbeddingInputs(
                type="image_embeds",
                image_embeds=image_embeds,
                image_grid_thw=image_grid_thw,
            )

        return None
1367
1368

    def _parse_and_validate_video_input(
        self, **kwargs: object
    ) -> Qwen2VLVideoInputs | None:
        """Wrap video kwargs into a typed input, preferring raw pixel values.

        Returns ``None`` when neither ``pixel_values_videos`` nor
        ``video_embeds`` is provided.
        """
        pixel_values_videos = kwargs.pop("pixel_values_videos", None)
        video_embeds = kwargs.pop("video_embeds", None)
        video_grid_thw = kwargs.pop("video_grid_thw", None)

        if pixel_values_videos is not None:
            return Qwen2VLVideoPixelInputs(
                type="pixel_values_videos",
                pixel_values_videos=pixel_values_videos,
                video_grid_thw=video_grid_thw,
            )

        if video_embeds is not None:
            return Qwen2VLVideoEmbeddingInputs(
                type="video_embeds",
                video_embeds=video_embeds,
                video_grid_thw=video_grid_thw,
            )

        return None
1391

1392
    def _process_image_input(
        self, image_input: Qwen2VLImageInputs
    ) -> tuple[torch.Tensor, ...]:
        """Return one embedding tensor per image.

        Precomputed embeddings are used directly; pixel values are encoded by
        the vision tower (through the data-parallel helper when enabled).
        """
        grid_thw = image_input["image_grid_thw"]
        assert grid_thw.ndim == 2

        if image_input["type"] == "image_embeds":
            embeddings = image_input["image_embeds"]
        elif self.use_data_parallel:
            # The sharded helper's result is returned as-is, without the
            # per-item split below.
            return run_dp_sharded_mrope_vision_model(
                self.visual,
                image_input["pixel_values"],
                grid_thw.tolist(),
                rope_type="rope_3d",
            )
        else:
            embeddings = self.visual(image_input["pixel_values"], grid_thw=grid_thw)

        # Split concatenated embeddings for each image item.
        merge_size = self.visual.spatial_merge_size
        per_item_sizes = (grid_thw.prod(-1) // merge_size // merge_size).tolist()
        return embeddings.split(per_item_sizes)
1414
1415

    def _process_video_input(
        self, video_input: Qwen2VLVideoInputs
    ) -> tuple[torch.Tensor, ...]:
        """Return one embedding tensor per video.

        Precomputed embeddings are used directly; pixel values are encoded by
        the vision tower (through the data-parallel helper when enabled).
        """
        grid_thw = video_input["video_grid_thw"]
        assert grid_thw.ndim == 2

        if video_input["type"] == "video_embeds":
            embeddings = video_input["video_embeds"]
        elif self.use_data_parallel:
            # The sharded helper's result is returned as-is, without the
            # per-item split below.
            return run_dp_sharded_mrope_vision_model(
                self.visual,
                video_input["pixel_values_videos"],
                grid_thw.tolist(),
                rope_type="rope_3d",
            )
        else:
            embeddings = self.visual(
                video_input["pixel_values_videos"], grid_thw=grid_thw
            )

        # Split concatenated embeddings for each video item.
        merge_size = self.visual.spatial_merge_size
        per_item_sizes = (grid_thw.prod(-1) // merge_size // merge_size).tolist()
        return embeddings.split(per_item_sizes)
1437
1438
1439
1440
1441
1442
1443

    def _parse_and_validate_multimodal_inputs(self, **kwargs: object) -> dict:
        """Parse image/video kwargs into a dict keyed "images" / "videos".

        Entries are inserted in the order their keys first appear in
        ``kwargs``, so the modality order of the caller is preserved.
        """
        parsers = {
            "pixel_values": ("images", self._parse_and_validate_image_input),
            "image_embeds": ("images", self._parse_and_validate_image_input),
            "pixel_values_videos": ("videos", self._parse_and_validate_video_input),
            "video_embeds": ("videos", self._parse_and_validate_video_input),
        }

        modalities = {}
        for input_key in kwargs:
            entry = parsers.get(input_key)
            if entry is None:
                continue

            modality_name, parse = entry
            # Each modality is parsed once, on the first key that names it.
            if modality_name not in modalities:
                modalities[modality_name] = parse(**kwargs)

        return modalities
1456

1457
1458
1459
    def get_language_model(self) -> torch.nn.Module:
        """Return the language-model submodule held in ``self.language_model``."""
        return self.language_model

1460
    def embed_multimodal(self, **kwargs: object) -> MultiModalEmbeddings:
        """Return the embeddings of every image/video item in ``kwargs``.

        Returns an empty list when no multimodal input is present, otherwise
        a tuple with one tensor per multimodal data item.
        """
        modalities = self._parse_and_validate_multimodal_inputs(**kwargs)
        if not modalities:
            return []

        # One tensor per multimodal data item (image or video).
        per_item_embeddings: list[torch.Tensor] = []

        # NOTE: Iterating the dictionary itself preserves the order in which
        # the modalities were supplied.
        for modality, mm_input in modalities.items():
            if modality == "images":
                per_item_embeddings.extend(self._process_image_input(mm_input))
            elif modality == "videos":
                per_item_embeddings.extend(self._process_video_input(mm_input))

        return tuple(per_item_embeddings)

1483
1484
1485
1486
    def forward(
        self,
        input_ids: torch.Tensor,
        positions: torch.Tensor,
        intermediate_tensors: IntermediateTensors | None = None,
        inputs_embeds: torch.Tensor | None = None,
        **kwargs: object,
    ) -> torch.Tensor | IntermediateTensors:
        """Run forward pass for Qwen2-VL.

        Args:
            input_ids: Flattened (concatenated) input_ids corresponding to a
                batch.
            positions: Flattened (concatenated) position ids corresponding to a
                batch.
                **NOTE**: If mrope is enabled (default setting for Qwen2-VL
                opensource models), the shape will be `(3, seq_len)`,
                otherwise it will be `(seq_len,)`.
            intermediate_tensors: Intermediate tensors from prior forward pass.
            inputs_embeds: Optional tensor of input embeddings.
        """
        # Input embeddings are dropped whenever intermediate tensors are given.
        effective_embeds = None if intermediate_tensors is not None else inputs_embeds

        return self.language_model.model(
            input_ids=input_ids,
            positions=positions,
            intermediate_tensors=intermediate_tensors,
            inputs_embeds=effective_embeds,
        )

1516
1517
1518
    def compute_logits(self, hidden_states: torch.Tensor) -> torch.Tensor | None:
        """Delegate logit computation to the language model."""
        language_model = self.language_model
        return language_model.compute_logits(hidden_states)
1521

1522
    def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]:
        """Load checkpoint weights, renaming them through ``hf_to_vllm_mapper``.

        Weights under ``visual.`` are skipped when no vision tower was built.
        Returns whatever the loader reports as loaded.
        """
        skip_prefixes = ["visual."] if self.visual is None else []
        weights_loader = AutoWeightsLoader(self, skip_prefixes=skip_prefixes)
        return weights_loader.load_weights(weights, mapper=self.hf_to_vllm_mapper)
1528
1529
1530
1531
1532
1533
1534

    def get_mm_mapping(self) -> MultiModelKeys:
        """
        Get the module prefix in multimodal models
        """
        module_prefixes = {
            "language_model": "language_model",
            "connector": "visual.merger.",
            "tower_model": "visual.",
        }
        return MultiModelKeys.from_string_field(**module_prefixes)
1538
1539
1540
1541
1542
1543
1544
1545
1546


class Tarsier2MultiModalProcessor(Qwen2VLMultiModalProcessor):
    """Tarsier2 multimodal processor; behaves exactly like the Qwen2-VL one."""


class Tarsier2ImageProcessor(Qwen2VLImageProcessor):
    """Qwen2-VL image processor that also accepts Tarsier2-style ``size``."""

    def __init__(
        self,
        size: dict[str, int] | None = None,
        **kwargs,
    ) -> None:
        """Initialise the processor, translating Tarsier2 ``size`` keys.

        When ``size`` carries both ``min_pixels`` and ``max_pixels``, it is
        remapped to ``shortest_edge`` / ``longest_edge``; otherwise ``size``
        is forwarded untouched.
        """
        is_tarsier2_format = (
            size is not None and "min_pixels" in size and "max_pixels" in size
        )
        if is_tarsier2_format:
            # Remap if Tarsier2-specific format is provided
            effective_size = {
                "shortest_edge": size["min_pixels"],
                "longest_edge": size["max_pixels"],
            }
        else:
            effective_size = size

        super().__init__(size=effective_size, **kwargs)


class Tarsier2Processor(Qwen2VLProcessor):
    """Qwen2-VL processor built from a Tarsier2 vision config dict."""

    def __init__(
        self,
        vision_config: dict,
        tokenizer: AnyTokenizer,
        **kwargs,
    ):
        """Create the image and video processors from ``vision_config``.

        Args:
            vision_config: Keyword arguments used for both the image and the
                video processor.
            tokenizer: Tokenizer handed to the base processor.
            **kwargs: Extra arguments forwarded to the base processor.
        """
        image_processor = Tarsier2ImageProcessor(**vision_config)
        # Assigned before the base initialiser runs, as in the original flow.
        self.image_processor = image_processor
        video_processor = Qwen2VLVideoProcessor(**vision_config)

        super().__init__(
            image_processor=image_processor,
            tokenizer=tokenizer,
            video_processor=video_processor,
            chat_template=None,
            **kwargs,
        )
1576
1577
1578
1579
1580


class Tarsier2ProcessingInfo(Qwen2VLProcessingInfo):
    """Processing info for Tarsier2, reusing the Qwen2-VL token accounting."""

    def get_hf_config(self) -> Qwen2VLConfig:
        """Load a ``Qwen2VLConfig`` directly from the model path.

        NOTE(review): this reloads the config on every call; presumably the
        context's own config is not a ``Qwen2VLConfig`` for Tarsier2
        checkpoints. Confirm before adding caching.
        """
        model_path = self.ctx.model_config.model
        correct_config = Qwen2VLConfig.from_pretrained(model_path)

        return correct_config

    def get_hf_processor(self, **kwargs: object) -> Tarsier2Processor:
        """Build a ``Tarsier2Processor`` from the image-processor config."""
        return Tarsier2Processor(
            vision_config=self.ctx.get_hf_image_processor_config(),
            tokenizer=self.get_tokenizer(),
            **kwargs,
        )

    def get_image_processor(self, **kwargs: object) -> Tarsier2ImageProcessor:
        """Build a ``Tarsier2ImageProcessor`` from the image-processor config.

        The base class declares ``get_image_processor(**kwargs)`` and the
        shared ``_get_prompt_updates`` calls it with the per-request processor
        kwargs, so this override must accept them too instead of raising
        ``TypeError``. Supplied kwargs take precedence over the values from
        the image-processor config.
        """
        processor_config = {**self.ctx.get_hf_image_processor_config(), **kwargs}
        return Tarsier2ImageProcessor(**processor_config)
1594
1595


1596
1597
1598
1599
1600
@MULTIMODAL_REGISTRY.register_processor(
    Tarsier2MultiModalProcessor,
    info=Tarsier2ProcessingInfo,
    dummy_inputs=Qwen2VLDummyInputsBuilder,
)
class Tarsier2ForConditionalGeneration(Qwen2VLForConditionalGeneration):
    """Tarsier2 model: Qwen2-VL with its own checkpoint prefix mapping."""

    # Checkpoints name the vision tower "vision_tower."; this module calls
    # it "visual.".
    hf_to_vllm_mapper = WeightsMapper(
        orig_to_new_prefix={
            "vision_tower.": "visual.",
        }
    )

    def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
        """Swap in the nested Qwen2-VL config, then build the base model."""
        # Tarsier2 uses llava as model_type, which will create a Qwen2VLConfig
        # as text_config, we need to reconstruct Qwen2VLConfig from LlavaConfig.
        model_config = vllm_config.model_config
        llava_config = model_config.hf_config
        qwen2vl_config = llava_config.text_config
        qwen2vl_config.architectures = llava_config.architectures
        # The shared vllm_config is updated in place before delegating.
        model_config.hf_config = qwen2vl_config
        super().__init__(vllm_config=vllm_config, prefix=prefix)

    def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]:
        """Load checkpoint weights, renaming them through ``hf_to_vllm_mapper``.

        Weights under ``visual.`` are skipped when no vision tower was built.
        """
        skip_prefixes = ["visual."] if self.visual is None else []
        weights_loader = AutoWeightsLoader(self, skip_prefixes=skip_prefixes)
        return weights_loader.load_weights(weights, mapper=self.hf_to_vllm_mapper)