qwen2_vl.py 55.3 KB
Newer Older
1
# SPDX-License-Identifier: Apache-2.0
2
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
3

4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
# Adapted from
# https://github.com/huggingface/transformers/blob/19e6e80e10118f855137b90740936c0b11ac397f/src/transformers/models/qwen2_vl/modeling_qwen2_vl.py
# Copyright 2024 The Qwen team.
# Copyright 2023 The vLLM team.
# Copyright 2022 EleutherAI and the HuggingFace Inc. team. All rights reserved.
#
# This code is based on EleutherAI's GPT-NeoX library and the GPT-NeoX
# and OPT implementations in this library. It has been modified from its
# original forms to accommodate minor architectural differences compared
# to GPT-NeoX and OPT used by the Meta AI team that trained the model.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Inference-only Qwen2-VL model compatible with HuggingFace weights."""
27

28
from collections.abc import Callable, Iterable, Mapping, Sequence
29
from functools import partial
30
from typing import Annotated, Any, Literal, TypeAlias
31
32
33
34

import torch
import torch.nn as nn
import torch.nn.functional as F
35
from einops import rearrange
36
from transformers import BatchFeature
37
from transformers.models.qwen2_vl import Qwen2VLImageProcessor, Qwen2VLProcessor
38
from transformers.models.qwen2_vl.configuration_qwen2_vl import (
39
40
41
    Qwen2VLConfig,
    Qwen2VLVisionConfig,
)
42
from transformers.models.qwen2_vl.image_processing_qwen2_vl import smart_resize
43
from transformers.models.qwen2_vl.video_processing_qwen2_vl import Qwen2VLVideoProcessor
44

45
from vllm.attention.backends.registry import AttentionBackendEnum
46
47
48
49
from vllm.attention.layer import (
    check_upstream_fa_availability,
    maybe_get_vit_flash_attn_backend,
)
50
from vllm.config import VllmConfig
51
from vllm.config.multimodal import BaseDummyOptions
52
from vllm.distributed import parallel_state
53
54
55
from vllm.distributed import utils as dist_utils
from vllm.logger import init_logger
from vllm.model_executor.layers.activation import QuickGELU
56
from vllm.model_executor.layers.conv import Conv3dLayer
57
58
59
60
from vllm.model_executor.layers.linear import (
    ColumnParallelLinear,
    RowParallelLinear,
)
61
from vllm.model_executor.layers.quantization import QuantizationConfig
62
from vllm.model_executor.layers.rotary_embedding import get_rope
63
from vllm.model_executor.layers.rotary_embedding.common import (
64
    apply_rotary_emb_torch,
65
66
    dispatch_rotary_emb_function,
)
67
from vllm.model_executor.model_loader.weight_utils import default_weight_loader
68
from vllm.model_executor.models.module_mapping import MultiModelKeys
69
from vllm.multimodal import MULTIMODAL_REGISTRY
70
71
72
73
from vllm.multimodal.inputs import (
    ImageItem,
    ModalityData,
    MultiModalDataDict,
74
    MultiModalFeatureSpec,
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
    MultiModalFieldConfig,
    MultiModalKwargsItems,
    VideoItem,
)
from vllm.multimodal.parse import (
    DictEmbeddingItems,
    ImageSize,
    ModalityDataItems,
    MultiModalDataItems,
    MultiModalDataParser,
)
from vllm.multimodal.processing import (
    BaseMultiModalProcessor,
    BaseProcessingInfo,
    PromptReplacement,
    PromptUpdate,
)
92
from vllm.multimodal.profiling import BaseDummyInputsBuilder
93
from vllm.sequence import IntermediateTensors
94
from vllm.transformers_utils.tokenizer import AnyTokenizer
95
from vllm.utils.tensor_schema import TensorSchema, TensorShape
96

97
98
99
100
101
102
103
104
105
106
107
108
109
from .interfaces import (
    MultiModalEmbeddings,
    SupportsLoRA,
    SupportsMRoPE,
    SupportsMultiModal,
    SupportsPP,
)
from .utils import (
    AutoWeightsLoader,
    WeightsMapper,
    init_vllm_registered_model,
    maybe_prefix,
)
110
111
112
113
from .vision import (
    get_vit_attn_backend,
    run_dp_sharded_mrope_vision_model,
)
114

115
116
logger = init_logger(__name__)

# Frames-per-video value used for the profile run. Its consumer is outside
# this view; presumably it bounds the dummy video inputs built for profiling.
_MAX_FRAMES_PER_VIDEO = 14


# === Vision Inputs === #


class Qwen2VLImagePixelInputs(TensorSchema):
    """
    Raw image patches destined for the vision encoder.

    Dimensions:
        - np: The total number of patches over each image over each prompt in
              the batch
        - ni: Number of images
        - cps: Number of channels * patch_size * patch_size

    Historical context:
        - pixel_values shape: (num_patches, num_channels * patch_size *
          patch_size)
        - image_grid_thw shape: (num_images, 3) in (grid_t, grid_h, grid_w)
          format
    """

    # Discriminator selecting this variant within Qwen2VLImageInputs.
    type: Literal["pixel_values"]

    # NOTE(review): Qwen2VisionPatchEmbed views each row as
    # (-1, temporal_patch_size, patch_size, patch_size); if image pixel values
    # go through it, "cps" presumably also includes temporal_patch_size.
    # Confirm against the HF image processor.
    pixel_values: Annotated[
        torch.Tensor,
        TensorShape("np", "cps"),
    ]

    # One (grid_t, grid_h, grid_w) row per image.
    image_grid_thw: Annotated[
        torch.Tensor,
        TensorShape("ni", 3),
    ]


class Qwen2VLImageEmbeddingInputs(TensorSchema):
    """
    Precomputed image embeddings that bypass the vision encoder.

    Dimensions:
        - nf: Number of image features
        - hs: Hidden size
        - ni: Number of images

    Historical context:
        - image_embeds shape: (num_image_features, hidden_size)
        - num_image_features varies based on the number and resolution of the
          images.
        - hidden_size must match the hidden size of language model backbone.
        - image_grid_thw shape: (num_images, 3) in (grid_t, grid_h, grid_w)
          format
    """

    # Discriminator selecting this variant within Qwen2VLImageInputs.
    type: Literal["image_embeds"]

    image_embeds: Annotated[
        torch.Tensor,
        TensorShape("nf", "hs"),
    ]

    # One (grid_t, grid_h, grid_w) row per image.
    image_grid_thw: Annotated[
        torch.Tensor,
        TensorShape("ni", 3),
    ]


# Image inputs arrive either as raw pixel patches or as precomputed embeddings;
# the ``type`` field of each variant tells them apart.
Qwen2VLImageInputs: TypeAlias = (
    Qwen2VLImagePixelInputs | Qwen2VLImageEmbeddingInputs
)


class Qwen2VLVideoPixelInputs(TensorSchema):
    """
    Raw video patches destined for the vision encoder.

    Dimensions:
        - np: The total number of patches over each video over each prompt in
              the batch
        - ctps: Number of channels * temporal_patch_size * patch_size *
          patch_size
        - nv: Number of videos

    Historical context:
        - pixel_values_videos shape: (num_patches, num_channels *
          temporal_patch_size * patch_size * patch_size)
        - video_grid_thw shape: (num_videos, 3) in (grid_t, grid_h, grid_w)
          format
    """

    # Discriminator selecting this variant within Qwen2VLVideoInputs.
    type: Literal["pixel_values_videos"]

    pixel_values_videos: Annotated[
        torch.Tensor,
        TensorShape("np", "ctps"),
    ]

    # One (grid_t, grid_h, grid_w) row per video.
    video_grid_thw: Annotated[
        torch.Tensor,
        TensorShape("nv", 3),
    ]


class Qwen2VLVideoEmbeddingInputs(TensorSchema):
    """
    Precomputed video embeddings that bypass the vision encoder.

    Dimensions:
        - nf: Number of video features
        - hs: Hidden size
        - nv: Number of videos

    Historical context:
        - video_embeds shape: (num_video_features, hidden_size)
        - num_video_features varies based on the number and resolution of the
          videos.
        - hidden_size must match the hidden size of language model backbone.
        - video_grid_thw shape: (num_videos, 3) in (grid_t, grid_h, grid_w)
          format
    """

    # Discriminator selecting this variant within Qwen2VLVideoInputs.
    type: Literal["video_embeds"]

    video_embeds: Annotated[
        torch.Tensor,
        TensorShape("nf", "hs"),
    ]

    # One (grid_t, grid_h, grid_w) row per video.
    video_grid_thw: Annotated[
        torch.Tensor,
        TensorShape("nv", 3),
    ]


# Video inputs arrive either as raw pixel patches or as precomputed embeddings;
# the ``type`` field of each variant tells them apart.
Qwen2VLVideoInputs: TypeAlias = (
    Qwen2VLVideoPixelInputs | Qwen2VLVideoEmbeddingInputs
)


# === Vision Encoder === #


class Qwen2VisionMLP(nn.Module):
    """Feed-forward sub-layer of a vision block: fc1 -> activation -> fc2.

    ``fc1`` is a ``ColumnParallelLinear`` (in_features -> hidden_features) and
    ``fc2`` a ``RowParallelLinear`` (hidden_features -> in_features). When
    ``use_data_parallel`` is True both layers are built with
    ``disable_tp=True``.
    """

    def __init__(
        self,
        in_features: int,
        hidden_features: int,
        act_layer: type[nn.Module] = QuickGELU,
        quant_config: QuantizationConfig | None = None,
        prefix: str = "",
        use_data_parallel: bool = False,
    ):
        super().__init__()
        self.fc1 = ColumnParallelLinear(
            in_features,
            hidden_features,
            quant_config=quant_config,
            prefix=f"{prefix}.fc1",
            disable_tp=use_data_parallel,
        )
        self.act = act_layer()
        self.fc2 = RowParallelLinear(
            hidden_features,
            in_features,
            quant_config=quant_config,
            prefix=f"{prefix}.fc2",
            disable_tp=use_data_parallel,
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # The parallel linear layers return a 2-tuple; only the first element
        # (the projected tensor) is used here.
        x_parallel, _ = self.fc1(x)
        x_parallel = self.act(x_parallel)
        x, _ = self.fc2(x_parallel)
        return x


def apply_rotary_pos_emb_vision(
    t: torch.Tensor, cos: torch.Tensor, sin: torch.Tensor
) -> torch.Tensor:
    """Apply rotary position embedding to ``t`` for the vision tower.

    The kernel is picked by ``dispatch_rotary_emb_function``, with
    ``apply_rotary_emb_torch`` in NeoX style supplied as the default.

    Args:
        t: Tensor to rotate (the caller passes concatenated q and k).
        cos: Cosine table.
        sin: Sine table.

    Returns:
        The rotated tensor, cast back to ``t``'s type via ``type_as``.
    """
    rotary_emb_function = dispatch_rotary_emb_function(
        default=partial(apply_rotary_emb_torch, is_neox_style=True)
    )
    output = rotary_emb_function(t, cos, sin).type_as(t)
    return output


class Qwen2VisionAttention(nn.Module):
    """Multi-head self-attention over packed vision tokens.

    Tokens from every image/video are packed along one sequence axis and
    ``cu_seqlens`` holds the cumulative segment boundaries. Each backend keeps
    attention inside a segment: varlen flash attention (FLASH_ATTN /
    ROCM_AITER_FA), one SDPA call per segment (TORCH_SDPA), or an xFormers
    block-diagonal mask (XFORMERS). Any other resolved backend raises
    ``RuntimeError`` at construction time.
    """

    def __init__(
        self,
        embed_dim: int,
        num_heads: int,
        projection_size: int,
        quant_config: QuantizationConfig | None = None,
        prefix: str = "",
        use_data_parallel: bool = False,
        attn_backend_override: AttentionBackendEnum | None = None,
    ) -> None:
        super().__init__()
        # Per attention head and per partition values.
        # With use_data_parallel the projections below are built with
        # disable_tp, so heads are not split across ranks (tp_size == 1).
        self.tp_size = (
            1
            if use_data_parallel
            else parallel_state.get_tensor_model_parallel_world_size()
        )
        self.tp_rank = parallel_state.get_tensor_model_parallel_rank()
        self.hidden_size_per_attention_head = dist_utils.divide(
            projection_size, num_heads
        )
        self.num_attention_heads_per_partition = dist_utils.divide(
            num_heads, self.tp_size
        )

        self.qkv = ColumnParallelLinear(
            input_size=embed_dim,
            output_size=3 * projection_size,
            quant_config=quant_config,
            prefix=f"{prefix}.qkv",
            disable_tp=use_data_parallel,
        )
        self.proj = RowParallelLinear(
            input_size=projection_size,
            output_size=embed_dim,
            quant_config=quant_config,
            prefix=f"{prefix}.proj",
            disable_tp=use_data_parallel,
        )

        # Detect attention implementation.
        self.attn_backend = get_vit_attn_backend(
            head_size=self.hidden_size_per_attention_head,
            dtype=torch.get_default_dtype(),
            attn_backend_override=attn_backend_override,
        )
        self.use_upstream_fa = False

        # This may replace the backend chosen above. It also returns the varlen
        # flash-attention function used by the flash branch of forward().
        self.attn_backend, self.flash_attn_varlen_func = (
            maybe_get_vit_flash_attn_backend(
                self.attn_backend,
                self.use_upstream_fa,
                attn_backend_override=attn_backend_override,
            )
        )

        if self.attn_backend not in {
            AttentionBackendEnum.FLASH_ATTN,
            AttentionBackendEnum.TORCH_SDPA,
            AttentionBackendEnum.XFORMERS,
            AttentionBackendEnum.ROCM_AITER_FA,
        }:
            raise RuntimeError(
                f"Qwen2-VL does not support {self.attn_backend} backend now."
            )

        self.is_flash_attn_backend = self.attn_backend in {
            AttentionBackendEnum.FLASH_ATTN,
            AttentionBackendEnum.ROCM_AITER_FA,
        }

    def split_qkv(self, qkv: torch.Tensor) -> tuple[torch.Tensor, ...]:
        """Split the fused QKV projection into per-head q, k, v views.

        Args:
            qkv: ``[s, b, 3 * heads * head_dim]`` output of ``self.qkv``.

        Returns:
            ``(q, k, v)``, each ``[s, b, heads, head_dim]``, where ``heads`` is
            the number of heads on this partition.
        """
        # [s, b, 3 * head * head_dim]
        seq_len, bs, _ = qkv.shape

        # [s, b, 3 * head * head_dim] -> 3 * [s, b, head * head_dim]
        q, k, v = qkv.chunk(3, dim=2)

        # 3 * [s, b, head * head_dim] -> 3 * [s, b, head, head_dim]
        new_shape = (
            seq_len,
            bs,
            self.num_attention_heads_per_partition,
            self.hidden_size_per_attention_head,
        )
        q, k, v = (x.view(*new_shape) for x in (q, k, v))
        return q, k, v

    def forward(
        self,
        x: torch.Tensor,
        cu_seqlens: torch.Tensor,
        rotary_pos_emb_cos: torch.Tensor,
        rotary_pos_emb_sin: torch.Tensor,
        max_seqlen: int | None = None,  # Only used for Flash Attention
        seqlens: list[int] | None = None,  # Only used for xFormers
    ) -> torch.Tensor:
        """Attend within each packed segment and project the result.

        Args:
            x: ``[s, b, c]`` hidden states.
            cu_seqlens: Cumulative segment boundaries along ``s``; entry 0 is
                the start of the first segment.
            rotary_pos_emb_cos: Cosine table applied to q and k.
            rotary_pos_emb_sin: Sine table applied to q and k.
            max_seqlen: Longest segment length (flash backends only).
            seqlens: Per-segment lengths (xFormers only).

        Returns:
            Output of ``self.proj`` applied to the ``[s, b, heads * head_dim]``
            attention context.
        """
        # [s, b, c] --> [s, b, 3 * head * head_dim]
        x, _ = self.qkv(x)

        # [s, b, 3 * head * head_dim] -> 3 * [s, b, head, head_dim]
        q, k, v = self.split_qkv(x)
        batch_size = q.shape[1]

        q, k, v = (rearrange(x, "s b ... -> b s ...") for x in (q, k, v))

        # Rotate q and k in one call: [2 * b, s, heads, head_dim]
        qk_concat = torch.cat([q, k], dim=0)
        qk_rotated = apply_rotary_pos_emb_vision(
            qk_concat, rotary_pos_emb_cos, rotary_pos_emb_sin
        )
        q, k = torch.chunk(qk_rotated, 2, dim=0)

        if self.is_flash_attn_backend:
            q, k, v = (rearrange(x, "b s ... -> (b s) ...") for x in [q, k, v])

            output = self.flash_attn_varlen_func(
                q,
                k,
                v,
                cu_seqlens_q=cu_seqlens,
                cu_seqlens_k=cu_seqlens,
                max_seqlen_q=max_seqlen,
                max_seqlen_k=max_seqlen,
                dropout_p=0.0,
                causal=False,
            )

            context_layer = rearrange(
                output, "(b s) h d -> s b (h d)", b=batch_size
            ).contiguous()
        elif self.attn_backend == AttentionBackendEnum.TORCH_SDPA:
            # Execute attention entry by entry for speed & less VRAM.
            from vllm.platforms import current_platform

            if current_platform.is_rocm():
                q = q.contiguous()
                k = k.contiguous()
                v = v.contiguous()
            # Convert the boundaries to Python ints once. Slicing with 0-d
            # tensor indices would otherwise convert each index individually,
            # which syncs with the device on every slice when cu_seqlens is
            # not on the CPU.
            boundaries = cu_seqlens.tolist()
            outputs = []
            for start_idx, end_idx in zip(boundaries[:-1], boundaries[1:]):
                q_i = q[:, start_idx:end_idx]
                k_i = k[:, start_idx:end_idx]
                v_i = v[:, start_idx:end_idx]
                q_i, k_i, v_i = (
                    rearrange(x, "b s h d -> b h s d") for x in [q_i, k_i, v_i]
                )
                output_i = F.scaled_dot_product_attention(q_i, k_i, v_i, dropout_p=0.0)
                output_i = rearrange(output_i, "b h s d -> b s h d ")
                outputs.append(output_i)
            context_layer = torch.cat(outputs, dim=1)
            context_layer = rearrange(
                context_layer, "b s h d -> s b (h d)"
            ).contiguous()
        elif self.attn_backend == AttentionBackendEnum.XFORMERS:
            from xformers import ops as xops
            from xformers.ops.fmha.attn_bias import BlockDiagonalMask

            # Block-diagonal mask: each segment attends only to itself.
            attn_bias = BlockDiagonalMask.from_seqlens(
                q_seqlen=seqlens, kv_seqlen=None, device=q.device
            )

            context_layer = xops.memory_efficient_attention_forward(
                q, k, v, attn_bias=attn_bias, p=0, scale=None
            )
            context_layer = rearrange(
                context_layer, "b s h d -> s b (h d)"
            ).contiguous()

        output, _ = self.proj(context_layer)
        return output


class Qwen2VisionBlock(nn.Module):
    """One pre-norm transformer layer of the vision tower.

    Computes ``x + attn(norm1(x))`` and then ``x + mlp(norm2(x))``. The norm
    defaults to ``nn.LayerNorm(eps=1e-6)`` when ``norm_layer`` is None. The
    MLP hidden width is ``int(dim * mlp_ratio)``.
    """

    def __init__(
        self,
        dim: int,
        num_heads: int,
        mlp_ratio: float,
        act_layer: type[nn.Module] = QuickGELU,
        norm_layer: Callable[[int], nn.Module] | None = None,
        quant_config: QuantizationConfig | None = None,
        prefix: str = "",
        use_data_parallel: bool = False,
        attn_backend_override: AttentionBackendEnum | None = None,
    ) -> None:
        super().__init__()
        if norm_layer is None:
            norm_layer = partial(nn.LayerNorm, eps=1e-6)
        self.norm1 = norm_layer(dim)
        self.norm2 = norm_layer(dim)
        mlp_hidden_dim = int(dim * mlp_ratio)

        self.attn = Qwen2VisionAttention(
            embed_dim=dim,
            num_heads=num_heads,
            projection_size=dim,
            quant_config=quant_config,
            prefix=f"{prefix}.attn",
            use_data_parallel=use_data_parallel,
            attn_backend_override=attn_backend_override,
        )
        self.mlp = Qwen2VisionMLP(
            dim,
            mlp_hidden_dim,
            act_layer=act_layer,
            quant_config=quant_config,
            prefix=f"{prefix}.mlp",
            use_data_parallel=use_data_parallel,
        )

    def forward(
        self,
        x: torch.Tensor,
        cu_seqlens: torch.Tensor,
        rotary_pos_emb_cos: torch.Tensor,
        rotary_pos_emb_sin: torch.Tensor,
        max_seqlen: int | None = None,  # Only used for Flash Attention
        seqlens: list[int] | None = None,  # Only used for xFormers
    ) -> torch.Tensor:
        """Apply the attention and MLP residual sub-layers to ``x``."""
        x = x + self.attn(
            self.norm1(x),
            cu_seqlens=cu_seqlens,
            rotary_pos_emb_cos=rotary_pos_emb_cos,
            rotary_pos_emb_sin=rotary_pos_emb_sin,
            max_seqlen=max_seqlen,
            seqlens=seqlens,
        )

        x = x + self.mlp(self.norm2(x))
        return x


class Qwen2VisionPatchEmbed(nn.Module):
    """Project flattened spatio-temporal patches to ``embed_dim`` vectors.

    Each input row is reshaped to
    ``(channels, temporal_patch_size, patch_size, patch_size)`` and passed
    through a bias-free 3-D convolution whose kernel and stride both equal
    that patch size. The result is exactly one ``embed_dim`` vector per row.
    """

    def __init__(
        self,
        patch_size: int = 14,
        temporal_patch_size: int = 2,
        in_channels: int = 3,
        embed_dim: int = 1152,
    ) -> None:
        super().__init__()
        self.patch_size = patch_size
        self.temporal_patch_size = temporal_patch_size
        self.embed_dim = embed_dim

        kernel_size = (temporal_patch_size, patch_size, patch_size)
        self.proj = Conv3dLayer(
            in_channels,
            embed_dim,
            kernel_size=kernel_size,
            stride=kernel_size,
            bias=False,
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Map ``[L, C]`` flattened patches to ``[L, embed_dim]``."""
        # Unpacking enforces a 2-D input; the row width itself is only
        # used implicitly by the view below.
        L, _ = x.shape
        x = x.view(L, -1, self.temporal_patch_size, self.patch_size, self.patch_size)
        x = self.proj(x).view(L, self.embed_dim)
        return x


class Qwen2VisionPatchMerger(nn.Module):
    """Merge groups of ``spatial_merge_size**2`` tokens and project to ``d_model``.

    Tokens are normalised at width ``context_dim``. Every consecutive group of
    ``spatial_merge_size**2`` tokens is then concatenated into one row of
    width ``context_dim * spatial_merge_size**2``, and the row goes through
    Linear -> GELU -> Linear.
    """

    def __init__(
        self,
        d_model: int,
        context_dim: int,
        norm_layer: Callable[[int], nn.Module] | None = None,
        spatial_merge_size: int = 2,
        quant_config: QuantizationConfig | None = None,
        prefix: str = "",
        use_data_parallel: bool = False,
    ) -> None:
        super().__init__()
        self.hidden_size = context_dim * (spatial_merge_size**2)
        if norm_layer is None:
            norm_layer = partial(nn.LayerNorm, eps=1e-6)
        self.ln_q = norm_layer(context_dim)
        # Kept as a ModuleList with indices 0/1/2 so parameter names
        # ("mlp.0", "mlp.2") line up with the prefixes used below.
        self.mlp = nn.ModuleList(
            [
                ColumnParallelLinear(
                    self.hidden_size,
                    self.hidden_size,
                    bias=True,
                    quant_config=quant_config,
                    prefix=f"{prefix}.mlp.0",
                    disable_tp=use_data_parallel,
                ),
                nn.GELU(),
                RowParallelLinear(
                    self.hidden_size,
                    d_model,
                    bias=True,
                    quant_config=quant_config,
                    prefix=f"{prefix}.mlp.2",
                    disable_tp=use_data_parallel,
                ),
            ]
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Normalise, group and project tokens; returns ``[-1, d_model]``."""
        x = self.ln_q(x)
        # Concatenate each run of spatial_merge_size**2 consecutive tokens.
        x = x.view(-1, self.hidden_size)

        mlp_fc1, mlp_act, mlp_fc2 = self.mlp
        x_parallel, _ = mlp_fc1(x)
        x_parallel = mlp_act(x_parallel)
        out, _ = mlp_fc2(x_parallel)
        return out


class Qwen2VisionTransformer(nn.Module):
    """Qwen2-VL vision encoder: patch embedding -> vision blocks -> merger.

    ``forward`` takes flattened patches plus one ``(t, h, w)`` patch grid per
    image/video. It returns one ``out_hidden_size``-wide row per merged
    ``spatial_merge_size x spatial_merge_size`` window.
    """

    def __init__(
        self,
        vision_config: Qwen2VLVisionConfig,
        norm_eps: float = 1e-6,
        quant_config: QuantizationConfig | None = None,
        prefix: str = "",
        use_data_parallel: bool = False,
        attn_backend_override: AttentionBackendEnum | None = None,
    ) -> None:
        super().__init__()

        patch_size = vision_config.patch_size
        temporal_patch_size = vision_config.temporal_patch_size
        spatial_merge_size = vision_config.spatial_merge_size
        in_channels = vision_config.in_channels
        hidden_size = vision_config.hidden_size
        embed_dim = vision_config.embed_dim
        depth = vision_config.depth
        num_heads = vision_config.num_heads
        mlp_ratio = vision_config.mlp_ratio

        self.use_data_parallel = use_data_parallel
        self.out_hidden_size = vision_config.hidden_size

        self.spatial_merge_size = spatial_merge_size
        self.num_heads = num_heads
        self.embed_dim = embed_dim

        self.patch_embed = Qwen2VisionPatchEmbed(
            patch_size=patch_size,
            temporal_patch_size=temporal_patch_size,
            in_channels=in_channels,
            embed_dim=embed_dim,
        )

        norm_layer = partial(nn.LayerNorm, eps=norm_eps)
        head_dim = embed_dim // num_heads
        # 2-D rotary: half of rotary_dim for the h position, half for w
        # (see rot_pos_emb).
        self.rotary_pos_emb = get_rope(
            head_size=head_dim,
            rotary_dim=head_dim // 2,
            max_position=8192,
            base=10000.0,
            is_neox_style=True,
        )

        self.blocks = nn.ModuleList(
            [
                Qwen2VisionBlock(
                    dim=embed_dim,
                    num_heads=num_heads,
                    mlp_ratio=mlp_ratio,
                    norm_layer=norm_layer,
                    quant_config=quant_config,
                    prefix=f"{prefix}.blocks.{layer_idx}",
                    use_data_parallel=use_data_parallel,
                    attn_backend_override=attn_backend_override,
                )
                for layer_idx in range(depth)
            ]
        )
        # Pass the configured merge size so the merger groups tokens the same
        # way rot_pos_emb orders them (the merger's own default is 2).
        self.merger = Qwen2VisionPatchMerger(
            d_model=hidden_size,
            context_dim=embed_dim,
            norm_layer=norm_layer,
            spatial_merge_size=spatial_merge_size,
            quant_config=quant_config,
            prefix=f"{prefix}.merger",
            use_data_parallel=use_data_parallel,
        )
        self.attn_backend = get_vit_attn_backend(
            head_size=head_dim,
            dtype=torch.get_default_dtype(),
            attn_backend_override=attn_backend_override,
        )
        if (
            self.attn_backend != AttentionBackendEnum.FLASH_ATTN
            and check_upstream_fa_availability(torch.get_default_dtype())
        ):
            self.attn_backend = AttentionBackendEnum.FLASH_ATTN
        # compute_attn_mask_seqlen() must produce exactly the metadata the
        # blocks consume: max_seqlen for the flash backends, seqlens for
        # xFormers. Each block's attention resolves its backend through
        # maybe_get_vit_flash_attn_backend(), which is a different path from
        # the heuristic above. When blocks exist, adopt their resolved backend
        # so the two cannot disagree.
        if len(self.blocks) > 0:
            self.attn_backend = self.blocks[0].attn.attn_backend

    @property
    def dtype(self) -> torch.dtype:
        return self.patch_embed.proj.weight.dtype

    @property
    def device(self) -> torch.device:
        return self.patch_embed.proj.weight.device

    def rot_pos_emb(
        self, grid_thw: list[list[int]]
    ) -> tuple[torch.Tensor, torch.Tensor]:
        """Build per-token 2-D rotary cos/sin tables.

        Every patch of a ``(t, h, w)`` grid gets an (h, w) position id. The
        ids are reordered so that each ``spatial_merge_size x
        spatial_merge_size`` window is contiguous, which matches the groups of
        consecutive tokens that the merger concatenates. The ids are then
        repeated for each of the ``t`` frames. The cos/sin rows looked up for
        the h and w ids are concatenated on the last dimension.

        Returns:
            ``(cos, sin)`` with one row per token in the order described
            above. This assumes the input patches use the same ordering; TODO
            confirm against the HF processor.
        """
        pos_ids = []
        max_grid_size = 0
        for t, h, w in grid_thw:
            hpos_ids = torch.arange(h).unsqueeze(1).expand(-1, w)
            wpos_ids = torch.arange(w).unsqueeze(0).expand(h, -1)
            hpos_ids = (
                hpos_ids.reshape(
                    h // self.spatial_merge_size,
                    self.spatial_merge_size,
                    w // self.spatial_merge_size,
                    self.spatial_merge_size,
                )
                .permute(0, 2, 1, 3)
                .flatten()
            )
            wpos_ids = (
                wpos_ids.reshape(
                    h // self.spatial_merge_size,
                    self.spatial_merge_size,
                    w // self.spatial_merge_size,
                    self.spatial_merge_size,
                )
                .permute(0, 2, 1, 3)
                .flatten()
            )
            pos_ids.append(torch.stack([hpos_ids, wpos_ids], dim=-1).repeat(t, 1))
            max_grid_size = max(max_grid_size, h, w)
        pos_ids = torch.cat(pos_ids, dim=0)

        # Use pre-computed cos_sin_cache from RotaryEmbedding
        cos, sin = self.rotary_pos_emb.get_cos_sin(max_grid_size)

        cos_h = cos[pos_ids[:, 0]]  # (num_tokens, rotary_dim // 2)
        cos_w = cos[pos_ids[:, 1]]
        sin_h = sin[pos_ids[:, 0]]
        sin_w = sin[pos_ids[:, 1]]

        cos_combined = torch.cat([cos_h, cos_w], dim=-1)
        sin_combined = torch.cat([sin_h, sin_w], dim=-1)
        return cos_combined, sin_combined

    def compute_attn_mask_seqlen(
        self, cu_seqlens: torch.Tensor
    ) -> tuple[int | None, list[int] | None]:
        """Derive backend-specific attention metadata from ``cu_seqlens``.

        Returns:
            ``(max_seqlen, seqlens)``. ``max_seqlen`` is set only for the flash
            backends and ``seqlens`` only for xFormers; the unused entry is
            None.
        """
        max_seqlen, seqlens = None, None
        if self.attn_backend in {
            AttentionBackendEnum.FLASH_ATTN,
            AttentionBackendEnum.ROCM_AITER_FA,
        }:
            max_seqlen = (cu_seqlens[1:] - cu_seqlens[:-1]).max().item()
        elif self.attn_backend == AttentionBackendEnum.XFORMERS:
            seqlens = (cu_seqlens[1:] - cu_seqlens[:-1]).tolist()
        return max_seqlen, seqlens

    def forward(
        self,
        x: torch.Tensor,
        grid_thw: torch.Tensor | list[list[int]],
    ) -> torch.Tensor:
        """Encode flattened patches into merged vision embeddings.

        Args:
            x: Flattened patches, one row per patch (consumed by
                ``Qwen2VisionPatchEmbed``).
            grid_thw: ``(t, h, w)`` patch grid per image/video, either as a
                tensor or as a nested list.

        Returns:
            One ``out_hidden_size``-wide row per merged window.
        """
        # patchify
        x = x.to(device=self.device, dtype=self.dtype)
        x = self.patch_embed(x)

        if isinstance(grid_thw, list):
            grid_thw_list = grid_thw
            grid_thw = torch.tensor(grid_thw, dtype=torch.int32)
        else:
            grid_thw_list = grid_thw.tolist()

        # compute position embedding
        rotary_pos_emb_cos, rotary_pos_emb_sin = self.rot_pos_emb(grid_thw_list)

        # compute cu_seqlens: each of an item's t frames is its own attention
        # segment of h * w patches.
        cu_seqlens = torch.repeat_interleave(
            grid_thw[:, 1] * grid_thw[:, 2], grid_thw[:, 0]
        ).cumsum(dim=0, dtype=torch.int32)
        cu_seqlens = torch.cat([cu_seqlens.new_zeros(1), cu_seqlens])

        # Pre-compute seqlens for the attention mask once for all blocks,
        # before cu_seqlens is moved. When grid_thw was given as a list it is
        # still on the CPU here, so .item()/.tolist() don't sync the device.
        max_seqlen, seqlens = self.compute_attn_mask_seqlen(cu_seqlens)
        cu_seqlens = cu_seqlens.to(self.device, non_blocking=True)

        # transformers
        x = x.unsqueeze(1)

        for blk in self.blocks:
            x = blk(
                x,
                cu_seqlens=cu_seqlens,
                rotary_pos_emb_cos=rotary_pos_emb_cos,
                rotary_pos_emb_sin=rotary_pos_emb_sin,
                max_seqlen=max_seqlen,
                seqlens=seqlens,
            )

        # adapter
        x = self.merger(x)

        return x

    def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]:
        """Load vision-tower weights and return the names of loaded params.

        Checkpoint names that contain ``q_proj``/``k_proj``/``v_proj`` are
        renamed to ``qkv_proj`` and loaded as the matching shard. Every other
        name is loaded with the parameter's ``weight_loader`` attribute, or
        ``default_weight_loader`` when there is none.

        NOTE(review): this module's fused projection is named ``qkv`` (not
        ``qkv_proj``), so a checkpoint that actually had split q/k/v names
        would hit a KeyError here; presumably such checkpoints don't occur.
        """
        stacked_params_mapping = [
            # (param_name, shard_name, shard_id)
            ("qkv_proj", "q_proj", "q"),
            ("qkv_proj", "k_proj", "k"),
            ("qkv_proj", "v_proj", "v"),
        ]
        params_dict = dict(self.named_parameters(remove_duplicate=False))
        loaded_params: set[str] = set()

        for name, loaded_weight in weights:
            for param_name, weight_name, shard_id in stacked_params_mapping:
                if weight_name not in name:
                    continue
                name = name.replace(weight_name, param_name)

                param = params_dict[name]
                weight_loader = param.weight_loader
                weight_loader(param, loaded_weight, shard_id)
                break
            else:
                param = params_dict[name]
                weight_loader = getattr(param, "weight_loader", default_weight_loader)
                weight_loader(param, loaded_weight)
            loaded_params.add(name)
        return loaded_params


def _create_qwen2vl_field_factory(
    spatial_merge_size: int,
) -> Callable[
    [Mapping[str, torch.Tensor]],
    Mapping[str, MultiModalFieldConfig],
]:
    """Return a factory mapping HF processor outputs to field configs.

    Per item, pixel fields are split by ``prod(grid_thw)`` patches. Embedding
    fields are split by ``prod(grid_thw) // spatial_merge_size // spatial_merge_size``,
    since each merge window becomes one embedding. A missing ``*_grid_thw`` is
    treated as an empty ``(0, 3)`` tensor. Grid tensors themselves are batched
    per item.

    Args:
        spatial_merge_size: Side length of the token-merge window.
    """

    def _qwen2vl_field_config(hf_inputs: Mapping[str, torch.Tensor]):
        image_grid_thw = hf_inputs.get("image_grid_thw", torch.empty((0, 3)))
        image_pixel_grid_sizes = image_grid_thw.prod(-1)
        image_embed_grid_sizes = (
            image_pixel_grid_sizes // spatial_merge_size // spatial_merge_size
        )

        video_grid_thw = hf_inputs.get("video_grid_thw", torch.empty((0, 3)))
        video_grid_sizes = video_grid_thw.prod(-1)
        video_embed_grid_sizes = (
            video_grid_sizes // spatial_merge_size // spatial_merge_size
        )

        return dict(
            pixel_values=MultiModalFieldConfig.flat_from_sizes(
                "image", image_pixel_grid_sizes
            ),
            image_embeds=MultiModalFieldConfig.flat_from_sizes(
                "image", image_embed_grid_sizes
            ),
            image_grid_thw=MultiModalFieldConfig.batched("image"),
            pixel_values_videos=MultiModalFieldConfig.flat_from_sizes(
                "video", video_grid_sizes
            ),
            video_embeds=MultiModalFieldConfig.flat_from_sizes(
                "video", video_embed_grid_sizes
            ),
            video_grid_thw=MultiModalFieldConfig.batched("video"),
        )

    return _qwen2vl_field_config


class Qwen2VLMultiModalDataParser(MultiModalDataParser):
    """Data parser that also accepts precomputed Qwen2-VL embeddings.

    A ``dict`` given as image or video data is wrapped as embedding items.
    The dict must contain the modality's ``*_embeds`` and ``*_grid_thw``
    fields. Any other data goes through the generic parser logic.
    """

    def __init__(self, spatial_merge_size: int, *args, **kwargs):
        # Stored before delegating; used by the embedding field factory.
        self._spatial_merge_size = spatial_merge_size
        super().__init__(*args, **kwargs)

    def _wrap_qwen2vl_embeddings(
        self,
        data: dict[str, torch.Tensor],
        modality: str,
        required_fields: set[str],
    ) -> ModalityDataItems[Any, Any]:
        """Wrap a dict of precomputed embeddings for ``modality``."""
        return DictEmbeddingItems(
            data,
            modality=modality,
            required_fields=required_fields,
            fields_factory=_create_qwen2vl_field_factory(self._spatial_merge_size),
        )

    def _parse_image_data(
        self,
        data: dict[str, torch.Tensor] | ModalityData[ImageItem],
    ) -> ModalityDataItems[Any, Any] | None:
        if not isinstance(data, dict):
            return super()._parse_image_data(data)
        return self._wrap_qwen2vl_embeddings(
            data, "image", {"image_embeds", "image_grid_thw"}
        )

    def _parse_video_data(
        self,
        data: dict[str, torch.Tensor] | ModalityData[VideoItem],
    ) -> ModalityDataItems[Any, Any] | None:
        if not isinstance(data, dict):
            return super()._parse_video_data(data)
        return self._wrap_qwen2vl_embeddings(
            data, "video", {"video_embeds", "video_grid_thw"}
        )


895
896
class Qwen2VLProcessingInfo(BaseProcessingInfo):
    """Config/processor access and vision-token budgeting for Qwen2-VL."""

    def get_hf_config(self):
        return self.ctx.get_hf_config(Qwen2VLConfig)

    def get_hf_processor(self, **kwargs: object) -> Qwen2VLProcessor:
        # ``use_fast`` defaults to True unless the caller passes it explicitly.
        return self.ctx.get_hf_processor(
            Qwen2VLProcessor,
            use_fast=kwargs.pop("use_fast", True),
            **kwargs,
        )

    def get_image_processor(self, **kwargs: object) -> Qwen2VLImageProcessor:
        return self.get_hf_processor(**kwargs).image_processor

    def get_supported_mm_limits(self) -> Mapping[str, int | None]:
        return {"image": None, "video": None}

    def get_mm_max_tokens_per_item(
        self,
        seq_len: int,
        mm_counts: Mapping[str, int],
    ) -> Mapping[str, int]:
        """Return the worst-case token count of one image and of one video."""
        max_image_tokens = self.get_max_image_tokens()
        max_video_tokens = self.get_max_video_tokens(seq_len, mm_counts)
        return {"image": max_image_tokens, "video": max_video_tokens}

    def _get_vision_info(
        self,
        *,
        image_width: int,
        image_height: int,
        num_frames: int = 1,
        do_resize: bool = True,
        image_processor: Qwen2VLImageProcessor | None,
    ) -> tuple[ImageSize, int]:
        """Compute the preprocessed size and vision-token count of an input.

        Args:
            image_width: Width of the image / video frame in pixels.
            image_height: Height of the image / video frame in pixels.
            num_frames: Number of frames (1 for a still image).
            do_resize: Whether to apply ``smart_resize`` using the image
                processor's ``min_pixels`` / ``max_pixels`` bounds.
            image_processor: Source of those bounds. When ``None`` it is
                fetched via :meth:`get_image_processor`.

        Returns:
            ``(preprocessed_size, num_vision_tokens)``. The token count is
            ``grid_t * grid_h * grid_w // spatial_merge_size**2``.
        """
        if image_processor is None:
            image_processor = self.get_image_processor()

        hf_config = self.get_hf_config()
        vision_config = hf_config.vision_config
        patch_size = vision_config.patch_size
        merge_size = vision_config.spatial_merge_size
        temporal_patch_size = vision_config.temporal_patch_size

        if do_resize:
            resized_height, resized_width = smart_resize(
                height=image_height,
                width=image_width,
                factor=patch_size * merge_size,
                min_pixels=image_processor.min_pixels,
                max_pixels=image_processor.max_pixels,
            )
            preprocessed_size = ImageSize(width=resized_width, height=resized_height)
        else:
            preprocessed_size = ImageSize(width=image_width, height=image_height)

        # NOTE: Frames are padded to be divisible by `temporal_patch_size`
        # https://github.com/huggingface/transformers/blob/v4.48.3/src/transformers/models/qwen2_vl/image_processing_qwen2_vl.py#L294
        # `(-n) % k` is the number of extra frames needed to reach the next
        # multiple of `k` (0 when already divisible). This holds for any
        # `temporal_patch_size`, whereas `n % k` only agrees with it for k <= 2.
        padded_num_frames = num_frames + (-num_frames) % temporal_patch_size

        grid_t = max(padded_num_frames // temporal_patch_size, 1)
        grid_h = preprocessed_size.height // patch_size
        grid_w = preprocessed_size.width // patch_size

        num_patches = grid_t * grid_h * grid_w
        num_vision_tokens = num_patches // (merge_size**2)

        return preprocessed_size, num_vision_tokens

    def get_num_image_tokens(
        self,
        *,
        image_width: int,
        image_height: int,
        image_processor: Qwen2VLImageProcessor | None,
    ) -> int:
        """Return the number of placeholder tokens for one image."""
        _, num_image_tokens = self._get_vision_info(
            image_width=image_width,
            image_height=image_height,
            num_frames=1,
            image_processor=image_processor,
        )
        return num_image_tokens

    def get_num_video_tokens(
        self,
        *,
        image_width: int,
        image_height: int,
        num_frames: int,
        image_processor: Qwen2VLImageProcessor | None,
    ) -> int:
        """Return the number of placeholder tokens for one video."""
        _, num_video_tokens = self._get_vision_info(
            image_width=image_width,
            image_height=image_height,
            num_frames=num_frames,
            image_processor=image_processor,
        )
        return num_video_tokens

    def get_image_size_with_most_features(self) -> ImageSize:
        """Return the preprocessed size of an oversized input.

        An oversized input is bounded by the processor's ``max_pixels``.
        """
        max_image_size, _ = self._get_vision_info(
            image_width=9999999,
            image_height=9999999,
            num_frames=1,
            image_processor=None,
        )
        return max_image_size

    def get_max_image_tokens(self) -> int:
        target_width, target_height = self.get_image_size_with_most_features()

        return self.get_num_image_tokens(
            image_width=target_width,
            image_height=target_height,
            image_processor=None,
        )

    def _get_max_video_frames(self, max_tokens: int, start_num_frames: int = 1) -> int:
        """Return the largest frame count whose tokens fit in ``max_tokens``.

        Frames are assumed to be at the maximum image size. The search starts
        at ``start_num_frames`` and that value is returned even if it already
        exceeds the budget.
        """
        target_width, target_height = self.get_image_size_with_most_features()

        num_frames = start_num_frames

        while True:
            next_num_frames = num_frames + 1
            next_max_tokens = self.get_num_video_tokens(
                image_width=target_width,
                image_height=target_height,
                num_frames=next_num_frames,
                image_processor=None,
            )

            if next_max_tokens > max_tokens:
                break

            num_frames = next_num_frames

        return num_frames

    def get_num_frames_with_most_features(
        self,
        seq_len: int,
        mm_counts: Mapping[str, int],
        max_frames_per_video: int = _MAX_FRAMES_PER_VIDEO,
    ) -> int:
        """Split the frame budget for ``seq_len`` evenly across videos.

        The per-video result is capped at ``max_frames_per_video`` and is
        never below 1.
        """
        max_videos = mm_counts.get("video", 0)

        max_total_frames = self._get_max_video_frames(seq_len)
        max_frames_per_video = min(
            max_total_frames // max(max_videos, 1), max_frames_per_video
        )

        return max(max_frames_per_video, 1)

    def get_max_video_tokens(
        self,
        seq_len: int,
        mm_counts: Mapping[str, int],
    ) -> int:
        """Return the token count of one worst-case video for ``seq_len``."""
        target_width, target_height = self.get_image_size_with_most_features()

        return self.get_num_video_tokens(
            image_width=target_width,
            image_height=target_height,
            num_frames=self.get_num_frames_with_most_features(seq_len, mm_counts),
            image_processor=None,
        )

1063
1064

class Qwen2VLDummyInputsBuilder(BaseDummyInputsBuilder[Qwen2VLProcessingInfo]):
    """Builds maximal-size dummy prompts and media for Qwen2-VL profiling."""

    def get_dummy_text(self, mm_counts: Mapping[str, int]) -> str:
        """Return one image/video placeholder token per requested item."""
        n_image = mm_counts.get("image", 0)
        n_video = mm_counts.get("video", 0)

        processor = self.info.get_hf_processor()
        image_token: str = processor.image_token
        video_token: str = processor.video_token

        return "".join((image_token * n_image, video_token * n_video))

    def get_dummy_mm_data(
        self,
        seq_len: int,
        mm_counts: Mapping[str, int],
        mm_options: Mapping[str, BaseDummyOptions] | None = None,
    ) -> MultiModalDataDict:
        """Return dummy images/videos at the largest preprocessed size."""
        n_image = mm_counts.get("image", 0)
        n_video = mm_counts.get("video", 0)

        width, height = self.info.get_image_size_with_most_features()
        frames = self.info.get_num_frames_with_most_features(seq_len, mm_counts)

        # Missing or empty options mean "no overrides" for either modality.
        options = mm_options or {}

        images = self._get_dummy_images(
            width=width,
            height=height,
            num_images=n_image,
            overrides=options.get("image"),
        )
        videos = self._get_dummy_videos(
            width=width,
            height=height,
            num_frames=frames,
            num_videos=n_video,
            overrides=options.get("video"),
        )
        return {"image": images, "video": videos}

1108

1109
class Qwen2VLMultiModalProcessor(BaseMultiModalProcessor[Qwen2VLProcessingInfo]):
    """Multimodal processor that expands Qwen2-VL image/video placeholders."""

    def _get_data_parser(self) -> MultiModalDataParser:
        return Qwen2VLMultiModalDataParser(
            self.info.get_hf_config().vision_config.spatial_merge_size
        )

    def _get_prompt_updates(
        self,
        mm_items: MultiModalDataItems,
        hf_processor_mm_kwargs: Mapping[str, Any],
        out_mm_kwargs: MultiModalKwargsItems,
    ) -> Sequence[PromptUpdate]:
        """Replace each placeholder token with one token per merged patch.

        The replacement length of an item is
        ``prod(grid_thw) // image_processor.merge_size**2``.
        """
        hf_processor = self.info.get_hf_processor(**hf_processor_mm_kwargs)
        # Reuse the image processor owned by ``hf_processor`` rather than
        # calling ``info.get_image_processor(**hf_processor_mm_kwargs)``. That
        # call would look the processor up a second time, and it would fail
        # for ProcessingInfo subclasses whose ``get_image_processor`` accepts
        # no keyword arguments.
        image_processor = hf_processor.image_processor
        tokenizer = self.info.get_tokenizer()
        vocab = tokenizer.get_vocab()

        placeholder = {
            "image": vocab[hf_processor.image_token],
            "video": vocab[hf_processor.video_token],
        }

        merge_length = image_processor.merge_size**2

        def get_replacement_qwen2vl(item_idx: int, modality: str):
            out_item = out_mm_kwargs[modality][item_idx]
            grid_thw = out_item[f"{modality}_grid_thw"].data
            assert isinstance(grid_thw, torch.Tensor)

            num_tokens = int(grid_thw.prod()) // merge_length
            return [placeholder[modality]] * num_tokens

        return [
            PromptReplacement(
                modality=modality,
                target=[placeholder[modality]],
                replacement=partial(get_replacement_qwen2vl, modality=modality),
            )
            for modality in ("image", "video")
        ]

    def _get_mm_fields_config(
        self,
        hf_inputs: BatchFeature,
        hf_processor_mm_kwargs: Mapping[str, object],
    ) -> Mapping[str, MultiModalFieldConfig]:
        return _create_qwen2vl_field_factory(
            self.info.get_hf_config().vision_config.spatial_merge_size
        )(hf_inputs)


@MULTIMODAL_REGISTRY.register_processor(
    Qwen2VLMultiModalProcessor,
    info=Qwen2VLProcessingInfo,
    dummy_inputs=Qwen2VLDummyInputsBuilder,
)
class Qwen2VLForConditionalGeneration(
    nn.Module, SupportsMultiModal, SupportsLoRA, SupportsPP, SupportsMRoPE
):
    """Qwen2-VL: a Qwen2 causal LM combined with an optional vision tower.

    ``self.visual`` is only built when the multimodal config allows images or
    videos. Otherwise it is ``None`` and its weights are skipped at load time.
    """

    merge_by_field_config = True
    multimodal_cpu_fields = {"image_grid_thw", "video_grid_thw"}

    # To ensure correct weight loading and mapping.
    hf_to_vllm_mapper = WeightsMapper(
        orig_to_new_prefix={
            # mapping for new names in checkpoint saved after transformers v4.52
            "model.language_model.": "language_model.model.",
            "model.visual.": "visual.",
            # mapping for original checkpoint
            "lm_head.": "language_model.lm_head.",
            "model.": "language_model.model.",
        }
    )

    supports_encoder_tp_data = True

    def get_mrope_input_positions(
        self,
        input_tokens: list[int],
        mm_features: list[MultiModalFeatureSpec],
    ) -> tuple[torch.Tensor, int]:
        """Compute 3-axis (t, h, w) M-RoPE positions for a prompt.

        Text tokens get the same position on all three axes, increasing by one
        per token. Each image/video placeholder run gets positions built from
        its ``(t, h // merge, w // merge)`` grid, offset to follow the
        preceding text. For videos, temporal indices are scaled by
        ``second_per_grid_t * tokens_per_second``.

        Returns:
            ``(positions, delta)``. ``positions`` has shape ``(3, N)``, where
            ``N`` equals ``len(input_tokens)`` when the placeholder runs match
            the grids. ``delta`` is ``positions.max() + 1 - len(input_tokens)``.
        """
        kwargs = MultiModalFeatureSpec.gather_kwargs(
            mm_features,
            {"image_grid_thw", "video_grid_thw", "second_per_grid_ts"},
        )
        image_grid_thw = [item.tolist() for item in kwargs.get("image_grid_thw", [])]
        video_grid_thw = [item.tolist() for item in kwargs.get("video_grid_thw", [])]
        second_per_grid_ts = kwargs.get("second_per_grid_ts", [])

        hf_config = self.config
        image_token_id = hf_config.image_token_id
        video_token_id = hf_config.video_token_id
        vision_start_token_id = hf_config.vision_start_token_id
        spatial_merge_size = hf_config.vision_config.spatial_merge_size
        tokens_per_second = getattr(hf_config.vision_config, "tokens_per_second", 1.0)

        input_tokens_tensor = torch.tensor(input_tokens)
        vision_start_indices = torch.argwhere(
            input_tokens_tensor == vision_start_token_id
        ).squeeze(1)
        # A vision-start token in the final position has no successor to
        # inspect. Drop it so the ``+ 1`` lookup below stays in range.
        vision_start_indices = vision_start_indices[
            vision_start_indices + 1 < len(input_tokens)
        ]
        vision_tokens = input_tokens_tensor[vision_start_indices + 1]
        image_nums = int((vision_tokens == image_token_id).sum())
        video_nums = int((vision_tokens == video_token_id).sum())
        llm_pos_ids_list: list[torch.Tensor] = []

        st = 0
        remain_images, remain_videos = image_nums, video_nums

        image_index, video_index = 0, 0
        for _ in range(image_nums + video_nums):
            video_second_per_grid_t = 0.0
            # Find the next image / video placeholder at or after ``st``.
            # ``len(input_tokens) + 1`` means "not found".
            if remain_images > 0:
                try:
                    ed_image = input_tokens.index(image_token_id, st)
                except ValueError:
                    ed_image = len(input_tokens) + 1
            else:
                ed_image = len(input_tokens) + 1
            if remain_videos > 0:
                try:
                    ed_video = input_tokens.index(video_token_id, st)
                except ValueError:
                    ed_video = len(input_tokens) + 1
            else:
                ed_video = len(input_tokens) + 1
            if ed_image < ed_video:
                t, h, w = image_grid_thw[image_index]
                image_index += 1
                remain_images -= 1
                ed = ed_image
            else:
                t, h, w = video_grid_thw[video_index]
                video_second_per_grid_t = 1.0
                if second_per_grid_ts:
                    video_second_per_grid_t = second_per_grid_ts[video_index]
                video_index += 1
                remain_videos -= 1
                ed = ed_video

            llm_grid_t, llm_grid_h, llm_grid_w = (
                t,
                h // spatial_merge_size,
                w // spatial_merge_size,
            )
            text_len = ed - st

            # Text preceding this item: identical positions on all three axes.
            st_idx = llm_pos_ids_list[-1].max() + 1 if len(llm_pos_ids_list) > 0 else 0
            llm_pos_ids_list.append(
                torch.arange(text_len).view(1, -1).expand(3, -1) + st_idx
            )

            t_index = (
                (
                    torch.arange(llm_grid_t)
                    .view(-1, 1)
                    .expand(-1, llm_grid_h * llm_grid_w)
                    * video_second_per_grid_t
                    * tokens_per_second
                )
                .long()
                .flatten()
            )
            h_index = (
                torch.arange(llm_grid_h)
                .view(1, -1, 1)
                .expand(llm_grid_t, -1, llm_grid_w)
                .flatten()
            )
            w_index = (
                torch.arange(llm_grid_w)
                .view(1, 1, -1)
                .expand(llm_grid_t, llm_grid_h, -1)
                .flatten()
            )
            llm_pos_ids_list.append(
                torch.stack([t_index, h_index, w_index]) + text_len + st_idx
            )
            # Skip past this item's placeholder tokens.
            st = ed + llm_grid_t * llm_grid_h * llm_grid_w

        if st < len(input_tokens):
            # Trailing text after the last multimodal item.
            st_idx = llm_pos_ids_list[-1].max() + 1 if len(llm_pos_ids_list) > 0 else 0
            text_len = len(input_tokens) - st
            llm_pos_ids_list.append(
                torch.arange(text_len).view(1, -1).expand(3, -1) + st_idx
            )

        llm_positions = torch.cat(llm_pos_ids_list, dim=1).reshape(3, -1)
        mrope_position_delta = (llm_positions.max() + 1 - len(input_tokens)).item()

        return llm_positions, mrope_position_delta

    @classmethod
    def get_placeholder_str(cls, modality: str, i: int) -> str | None:
        if modality.startswith("image"):
            return "<|vision_start|><|image_pad|><|vision_end|>"
        if modality.startswith("video"):
            return "<|vision_start|><|video_pad|><|vision_end|>"

        raise ValueError("Only image or video modality is supported")

    def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
        super().__init__()
        config: Qwen2VLConfig = vllm_config.model_config.hf_config
        quant_config = vllm_config.quant_config
        multimodal_config = vllm_config.model_config.multimodal_config

        self.use_data_parallel = multimodal_config.mm_encoder_tp_mode == "data"
        self.config = config
        self.multimodal_config = multimodal_config

        # Only build the vision tower when images or videos are allowed.
        if multimodal_config.get_limit_per_prompt(
            "image"
        ) or multimodal_config.get_limit_per_prompt("video"):
            self.visual = Qwen2VisionTransformer(
                config.vision_config,
                norm_eps=getattr(config, "rms_norm_eps", 1e-6),
                quant_config=quant_config,
                prefix=maybe_prefix(prefix, "visual"),
                use_data_parallel=self.use_data_parallel,
                # ``multimodal_config`` is dereferenced above, so it is not
                # None here.
                attn_backend_override=multimodal_config.mm_encoder_attn_backend,
            )
        else:
            self.visual = None

        self.language_model = init_vllm_registered_model(
            vllm_config=vllm_config,
            prefix=maybe_prefix(prefix, "language_model"),
            architectures=["Qwen2ForCausalLM"],
        )

        self.make_empty_intermediate_tensors = (
            self.language_model.make_empty_intermediate_tensors
        )

    def _parse_and_validate_image_input(
        self, **kwargs: object
    ) -> Qwen2VLImageInputs | None:
        """Build image inputs from kwargs; pixel values win over embeddings."""
        pixel_values = kwargs.pop("pixel_values", None)
        image_embeds = kwargs.pop("image_embeds", None)
        image_grid_thw = kwargs.pop("image_grid_thw", None)

        if pixel_values is None and image_embeds is None:
            return None

        if pixel_values is not None:
            return Qwen2VLImagePixelInputs(
                type="pixel_values",
                pixel_values=pixel_values,
                image_grid_thw=image_grid_thw,
            )

        if image_embeds is not None:
            return Qwen2VLImageEmbeddingInputs(
                type="image_embeds",
                image_embeds=image_embeds,
                image_grid_thw=image_grid_thw,
            )

    def _parse_and_validate_video_input(
        self, **kwargs: object
    ) -> Qwen2VLVideoInputs | None:
        """Build video inputs from kwargs; pixel values win over embeddings."""
        pixel_values_videos = kwargs.pop("pixel_values_videos", None)
        video_embeds = kwargs.pop("video_embeds", None)
        video_grid_thw = kwargs.pop("video_grid_thw", None)

        if pixel_values_videos is None and video_embeds is None:
            return None

        if pixel_values_videos is not None:
            return Qwen2VLVideoPixelInputs(
                type="pixel_values_videos",
                pixel_values_videos=pixel_values_videos,
                video_grid_thw=video_grid_thw,
            )

        if video_embeds is not None:
            return Qwen2VLVideoEmbeddingInputs(
                type="video_embeds",
                video_embeds=video_embeds,
                video_grid_thw=video_grid_thw,
            )

    def _split_by_grid(
        self, embeds: torch.Tensor, grid_thw: torch.Tensor
    ) -> tuple[torch.Tensor, ...]:
        """Split concatenated embeddings into one tensor per grid row.

        Uses the config's ``spatial_merge_size``, the same value the processor
        uses to size embedding fields. This works even when the vision tower
        was not built.
        """
        merge_size = self.config.vision_config.spatial_merge_size
        sizes = (grid_thw.prod(-1) // merge_size // merge_size).tolist()
        return embeds.split(sizes)

    def _process_image_input(
        self, image_input: Qwen2VLImageInputs
    ) -> tuple[torch.Tensor, ...]:
        """Return one embedding tensor per image in ``image_input``."""
        grid_thw = image_input["image_grid_thw"]
        assert grid_thw.ndim == 2

        if image_input["type"] == "image_embeds":
            image_embeds = image_input["image_embeds"]
        else:
            pixel_values = image_input["pixel_values"]

            if self.use_data_parallel:
                # The sharded helper's output is returned as-is.
                return run_dp_sharded_mrope_vision_model(
                    self.visual, pixel_values, grid_thw.tolist(), rope_type="rope_3d"
                )
            image_embeds = self.visual(pixel_values, grid_thw=grid_thw)

        # Split concatenated embeddings for each image item.
        return self._split_by_grid(image_embeds, grid_thw)

    def _process_video_input(
        self, video_input: Qwen2VLVideoInputs
    ) -> tuple[torch.Tensor, ...]:
        """Return one embedding tensor per video in ``video_input``."""
        grid_thw = video_input["video_grid_thw"]
        assert grid_thw.ndim == 2

        if video_input["type"] == "video_embeds":
            video_embeds = video_input["video_embeds"]
        else:
            pixel_values_videos = video_input["pixel_values_videos"]
            if self.use_data_parallel:
                grid_thw_list = grid_thw.tolist()
                # The sharded helper's output is returned as-is.
                return run_dp_sharded_mrope_vision_model(
                    self.visual, pixel_values_videos, grid_thw_list, rope_type="rope_3d"
                )
            video_embeds = self.visual(pixel_values_videos, grid_thw=grid_thw)

        # Split concatenated embeddings for each video item.
        return self._split_by_grid(video_embeds, grid_thw)

    def _parse_and_validate_multimodal_inputs(self, **kwargs: object) -> dict:
        """Parse image/video inputs, keyed in order of first appearance."""
        modalities = {}

        # Preserve the order of modalities if there are multiple of them
        # from the order of kwargs.
        for input_key in kwargs:
            if (
                input_key in ("pixel_values", "image_embeds")
                and "images" not in modalities
            ):
                modalities["images"] = self._parse_and_validate_image_input(**kwargs)
            if (
                input_key in ("pixel_values_videos", "video_embeds")
                and "videos" not in modalities
            ):
                modalities["videos"] = self._parse_and_validate_video_input(**kwargs)

        return modalities

    def get_language_model(self) -> torch.nn.Module:
        return self.language_model

    def embed_multimodal(self, **kwargs: object) -> MultiModalEmbeddings:
        """Return per-item embeddings for all images/videos in ``kwargs``."""
        modalities = self._parse_and_validate_multimodal_inputs(**kwargs)
        if not modalities:
            return []

        # The result multimodal_embeddings is tuple of tensors, with each
        # tensor corresponding to a multimodal data item (image or video).
        multimodal_embeddings: tuple[torch.Tensor, ...] = ()

        # NOTE: It is important to iterate over the keys in this dictionary
        # to preserve the order of the modalities.
        for modality in modalities:
            if modality == "images":
                image_input = modalities["images"]
                image_embeddings = self._process_image_input(image_input)
                multimodal_embeddings += tuple(image_embeddings)
            if modality == "videos":
                video_input = modalities["videos"]
                video_embeddings = self._process_video_input(video_input)
                multimodal_embeddings += tuple(video_embeddings)

        return multimodal_embeddings

    def forward(
        self,
        input_ids: torch.Tensor,
        positions: torch.Tensor,
        intermediate_tensors: IntermediateTensors | None = None,
        inputs_embeds: torch.Tensor | None = None,
        **kwargs: object,
    ) -> torch.Tensor | IntermediateTensors:
        """Run forward pass for Qwen2-VL.

        Args:
            input_ids: Flattened (concatenated) input_ids corresponding to a
                batch.
            positions: Flattened (concatenated) position ids corresponding to a
                batch.
                **NOTE**: If mrope is enabled (default setting for Qwen2-VL
                opensource models), the shape will be `(3, seq_len)`,
                otherwise it will be `(seq_len,)`.
            intermediate_tensors: Intermediate tensors from prior forward pass.
            inputs_embeds: Optional tensor of input embeddings.
        """
        # When intermediate tensors are supplied, inputs_embeds is discarded.
        if intermediate_tensors is not None:
            inputs_embeds = None

        hidden_states = self.language_model.model(
            input_ids=input_ids,
            positions=positions,
            intermediate_tensors=intermediate_tensors,
            inputs_embeds=inputs_embeds,
        )
        return hidden_states

    def compute_logits(
        self,
        hidden_states: torch.Tensor,
    ) -> torch.Tensor | None:
        return self.language_model.compute_logits(hidden_states)

    def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]:
        """Load checkpoint weights; skip ``visual.`` when no tower was built."""
        skip_prefixes = ["visual."] if self.visual is None else []
        loader = AutoWeightsLoader(self, skip_prefixes=skip_prefixes)
        return loader.load_weights(weights, mapper=self.hf_to_vllm_mapper)

    def get_mm_mapping(self) -> MultiModelKeys:
        """
        Get the module prefix in multimodal models
        """
        return MultiModelKeys.from_string_field(
            language_model="language_model",
            connector="visual.merger.",
            tower_model="visual.",
        )
1543
1544
1545
1546
1547
1548
1549
1550
1551


class Tarsier2MultiModalProcessor(Qwen2VLMultiModalProcessor):
    """Tarsier2 reuses the Qwen2-VL multimodal processing unchanged."""


class Tarsier2ImageProcessor(Qwen2VLImageProcessor):
    """Qwen2-VL image processor that also accepts Tarsier2's ``size`` format.

    If ``size`` contains both ``min_pixels`` and ``max_pixels``, it is
    translated to ``{"shortest_edge": ..., "longest_edge": ...}`` before the
    parent constructor runs. Otherwise ``size`` is passed through unchanged.
    """

    def __init__(
        self,
        size: dict[str, int] | None = None,
        **kwargs,
    ) -> None:
        is_tarsier2_format = (
            size is not None and "min_pixels" in size and "max_pixels" in size
        )
        if is_tarsier2_format:
            # Build a new dict; the caller's mapping is left untouched.
            size = {
                "shortest_edge": size["min_pixels"],
                "longest_edge": size["max_pixels"],
            }
        super().__init__(size=size, **kwargs)


class Tarsier2Processor(Qwen2VLProcessor):
    """Qwen2-VL processor assembled from a Tarsier2 image-processor config."""

    def __init__(
        self,
        vision_config: dict,
        tokenizer: AnyTokenizer,
        **kwargs,
    ):
        # The same config dict configures both the image and video processors.
        image_processor = Tarsier2ImageProcessor(**vision_config)
        self.image_processor = image_processor
        video_processor = Qwen2VLVideoProcessor(**vision_config)

        super().__init__(
            image_processor=image_processor,
            tokenizer=tokenizer,
            video_processor=video_processor,
            chat_template=None,
            **kwargs,
        )
1581
1582
1583
1584
1585


class Tarsier2ProcessingInfo(Qwen2VLProcessingInfo):
    """Processing info for Tarsier2 checkpoints built on Qwen2-VL."""

    def get_hf_config(self) -> Qwen2VLConfig:
        """Load a ``Qwen2VLConfig`` from the model path.

        NOTE: the config is re-read via ``from_pretrained`` on every call.
        """
        model_path = self.ctx.model_config.model
        correct_config = Qwen2VLConfig.from_pretrained(model_path)

        return correct_config

    def get_hf_processor(self, **kwargs: object) -> Tarsier2Processor:
        return Tarsier2Processor(
            vision_config=self.ctx.get_hf_image_processor_config(),
            tokenizer=self.get_tokenizer(),
            **kwargs,
        )

    def get_image_processor(self, **kwargs: object) -> Tarsier2ImageProcessor:
        """Build the image processor from the checkpoint's processor config.

        ``kwargs`` is accepted so this override matches the base signature,
        which callers may invoke with processor kwargs. The kwargs are ignored;
        the processor is always built from the stored config.
        """
        return Tarsier2ImageProcessor(**self.ctx.get_hf_image_processor_config())
1599
1600


1601
1602
1603
1604
1605
@MULTIMODAL_REGISTRY.register_processor(
    Tarsier2MultiModalProcessor,
    info=Tarsier2ProcessingInfo,
    dummy_inputs=Qwen2VLDummyInputsBuilder,
)
class Tarsier2ForConditionalGeneration(Qwen2VLForConditionalGeneration):
    """Tarsier2: a Qwen2-VL model loaded from a LLaVA-typed checkpoint."""

    # Rename the checkpoint's ``vision_tower.`` prefix to this model's tower.
    hf_to_vllm_mapper = WeightsMapper(
        orig_to_new_prefix={
            "vision_tower.": "visual.",
        }
    )

    def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
        # Tarsier2 uses llava as model_type, which will create a Qwen2VLConfig
        # as text_config, we need to reconstruct Qwen2VLConfig from LlavaConfig.
        config = vllm_config.model_config.hf_config
        qwen2vl_config = config.text_config
        qwen2vl_config.architectures = config.architectures
        vllm_config.model_config.hf_config = qwen2vl_config
        super().__init__(vllm_config=vllm_config, prefix=prefix)

    def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]:
        """Load weights using the Tarsier2 prefix mapping.

        The parent implementation reads ``self.hf_to_vllm_mapper``, so it
        already applies this class's mapper. Delegating avoids keeping a
        verbatim copy of the loader logic.
        """
        return super().load_weights(weights)