qwen2_vl.py 54.2 KB
Newer Older
1
# SPDX-License-Identifier: Apache-2.0
2
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
3

4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
# Adapted from
# https://github.com/huggingface/transformers/blob/19e6e80e10118f855137b90740936c0b11ac397f/src/transformers/models/qwen2_vl/modeling_qwen2_vl.py
# Copyright 2024 The Qwen team.
# Copyright 2023 The vLLM team.
# Copyright 2022 EleutherAI and the HuggingFace Inc. team. All rights reserved.
#
# This code is based on EleutherAI's GPT-NeoX library and the GPT-NeoX
# and OPT implementations in this library. It has been modified from its
# original forms to accommodate minor architectural differences compared
# to GPT-NeoX and OPT used by the Meta AI team that trained the model.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Inference-only Qwen2-VL model compatible with HuggingFace weights."""
27

28
from collections.abc import Callable, Iterable, Mapping, Sequence
29
from functools import partial
30
from typing import Annotated, Any, Literal, TypeAlias
31

32
import numpy as np
33
34
35
import torch
import torch.nn as nn
import torch.nn.functional as F
36
from einops import rearrange
37
from transformers import BatchFeature
38
from transformers.models.qwen2_vl import Qwen2VLImageProcessor, Qwen2VLProcessor
39
from transformers.models.qwen2_vl.configuration_qwen2_vl import (
40
41
42
    Qwen2VLConfig,
    Qwen2VLVisionConfig,
)
43
from transformers.models.qwen2_vl.image_processing_qwen2_vl import smart_resize
44
from transformers.models.qwen2_vl.video_processing_qwen2_vl import Qwen2VLVideoProcessor
45

46
from vllm.attention.backends.registry import AttentionBackendEnum
47
48
49
50
from vllm.attention.layer import (
    check_upstream_fa_availability,
    maybe_get_vit_flash_attn_backend,
)
51
from vllm.config import VllmConfig
52
from vllm.config.multimodal import BaseDummyOptions
53
from vllm.distributed import parallel_state
54
55
56
from vllm.distributed import utils as dist_utils
from vllm.logger import init_logger
from vllm.model_executor.layers.activation import QuickGELU
57
from vllm.model_executor.layers.conv import Conv3dLayer
58
59
60
61
from vllm.model_executor.layers.linear import (
    ColumnParallelLinear,
    RowParallelLinear,
)
62
from vllm.model_executor.layers.quantization import QuantizationConfig
63
from vllm.model_executor.layers.rotary_embedding import get_rope
64
from vllm.model_executor.layers.rotary_embedding.common import (
65
    apply_rotary_emb_torch,
66
67
    dispatch_rotary_emb_function,
)
68
from vllm.model_executor.model_loader.weight_utils import default_weight_loader
69
from vllm.model_executor.models.module_mapping import MultiModelKeys
70
from vllm.multimodal import MULTIMODAL_REGISTRY
71
72
73
74
from vllm.multimodal.inputs import (
    ImageItem,
    ModalityData,
    MultiModalDataDict,
75
    MultiModalFeatureSpec,
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
    MultiModalFieldConfig,
    MultiModalKwargsItems,
    VideoItem,
)
from vllm.multimodal.parse import (
    DictEmbeddingItems,
    ImageSize,
    ModalityDataItems,
    MultiModalDataItems,
    MultiModalDataParser,
)
from vllm.multimodal.processing import (
    BaseMultiModalProcessor,
    BaseProcessingInfo,
    PromptReplacement,
    PromptUpdate,
)
93
from vllm.multimodal.profiling import BaseDummyInputsBuilder
94
from vllm.sequence import IntermediateTensors
95
from vllm.transformers_utils.tokenizer import AnyTokenizer
96
from vllm.utils.tensor_schema import TensorSchema, TensorShape
97

98
99
100
101
102
103
104
105
106
107
108
109
110
from .interfaces import (
    MultiModalEmbeddings,
    SupportsLoRA,
    SupportsMRoPE,
    SupportsMultiModal,
    SupportsPP,
)
from .utils import (
    AutoWeightsLoader,
    WeightsMapper,
    init_vllm_registered_model,
    maybe_prefix,
)
111
112
113
114
from .vision import (
    get_vit_attn_backend,
    run_dp_sharded_mrope_vision_model,
)
115

116
117
# Module-level logger, created via vLLM's init_logger and keyed by this
# module's import name.
logger = init_logger(__name__)

118
# For profile run
119
# NOTE(review): presumably the cap on frames per dummy video built for the
# profiling run; the consumer is not visible in this part of the file - confirm.
_MAX_FRAMES_PER_VIDEO = 14
120

121
122
123
# === Vision Inputs === #


124
class Qwen2VLImagePixelInputs(TensorSchema):
    """
    Schema for image inputs supplied as flattened pixel patches.

    Dimensions:
        - np: The total number of patches over each image over each prompt in
              the batch
        - ni: Number of images
        - cps: Number of channels * patch_size * patch_size

    Field layout:
        - pixel_values: (np, cps), one row per patch.
        - image_grid_thw: (ni, 3), one (grid_t, grid_h, grid_w) row per image.

    NOTE(review): Qwen2VisionPatchEmbed.forward in this file views its input
    as (L, -1, temporal_patch_size, patch_size, patch_size); confirm whether
    `cps` should also mention the temporal_patch_size factor.
    """

    # Discriminator tag identifying this variant of the image-input union.
    type: Literal["pixel_values"]

    pixel_values: Annotated[torch.Tensor, TensorShape("np", "cps")]

    image_grid_thw: Annotated[torch.Tensor, TensorShape("ni", 3)]


class Qwen2VLImageEmbeddingInputs(TensorSchema):
    """
    Schema for image inputs supplied as precomputed embeddings.

    Dimensions:
        - nf: Number of image features
        - hs: Hidden size
        - ni: Number of images

    Field layout:
        - image_embeds: (nf, hs). The number of image features varies with
          the number and resolution of the images; the hidden size must match
          the hidden size of the language model backbone.
        - image_grid_thw: (ni, 3), one (grid_t, grid_h, grid_w) row per image.
    """

    # Discriminator tag identifying this variant of the image-input union.
    type: Literal["image_embeds"]

    image_embeds: Annotated[torch.Tensor, TensorShape("nf", "hs")]

    image_grid_thw: Annotated[torch.Tensor, TensorShape("ni", 3)]
179
180


181
# Image inputs are either raw pixel patches or precomputed embeddings.
Qwen2VLImageInputs: TypeAlias = Qwen2VLImagePixelInputs | Qwen2VLImageEmbeddingInputs
182
183


184
185
186
187
188
class Qwen2VLVideoPixelInputs(TensorSchema):
    """
    Schema for video inputs supplied as flattened pixel patches.

    Dimensions:
        - np: The total number of patches over each video over each prompt in
              the batch
        - ctps: Number of channels * temporal_patch_size * patch_size *
          patch_size
        - nv: Number of videos

    Field layout:
        - pixel_values_videos: (np, ctps), one row per patch.
        - video_grid_thw: (nv, 3), one (grid_t, grid_h, grid_w) row per video.
    """

    # Discriminator tag identifying this variant of the video-input union.
    type: Literal["pixel_values_videos"]

    pixel_values_videos: Annotated[torch.Tensor, TensorShape("np", "ctps")]

    video_grid_thw: Annotated[torch.Tensor, TensorShape("nv", 3)]
211
212


213
214
215
216
217
218
class Qwen2VLVideoEmbeddingInputs(TensorSchema):
    """
    Schema for video inputs supplied as precomputed embeddings.

    Dimensions:
        - nf: Number of video features
        - hs: Hidden size
        - nv: Number of videos

    Field layout:
        - video_embeds: (nf, hs). The number of video features varies with
          the number and resolution of the videos; the hidden size must match
          the hidden size of the language model backbone.
        - video_grid_thw: (nv, 3), one (grid_t, grid_h, grid_w) row per video.
    """

    # Discriminator tag identifying this variant of the video-input union.
    type: Literal["video_embeds"]

    video_embeds: Annotated[torch.Tensor, TensorShape("nf", "hs")]

    video_grid_thw: Annotated[torch.Tensor, TensorShape("nv", 3)]
240
241


242
# Video inputs are either raw pixel patches or precomputed embeddings.
Qwen2VLVideoInputs: TypeAlias = Qwen2VLVideoPixelInputs | Qwen2VLVideoEmbeddingInputs
243

244
245
246
247
248
249
250
# === Vision Encoder === #


class Qwen2VisionMLP(nn.Module):
    """Two-layer feed-forward network: fc1 -> activation -> fc2.

    `fc1` is a ColumnParallelLinear mapping `in_features` to
    `hidden_features`; `fc2` is a RowParallelLinear mapping back to
    `in_features`.
    """

    def __init__(
        self,
        in_features: int,
        hidden_features: int,
        act_layer: type[nn.Module] = QuickGELU,
        quant_config: QuantizationConfig | None = None,
        prefix: str = "",
        use_data_parallel: bool = False,
    ):
        """Build the two projections and the activation.

        Args:
            in_features: Input (and output) feature size.
            hidden_features: Width of the intermediate projection.
            act_layer: Activation module class, instantiated with no args.
            quant_config: Quantization config forwarded to both linears.
            prefix: Parameter-name prefix; ".fc1" / ".fc2" are appended.
            use_data_parallel: Forwarded to both linears as `disable_tp`.
        """
        super().__init__()
        # Options shared by both linear layers.
        linear_opts = dict(
            quant_config=quant_config,
            disable_tp=use_data_parallel,
        )
        self.fc1 = ColumnParallelLinear(
            in_features,
            hidden_features,
            prefix=f"{prefix}.fc1",
            **linear_opts,
        )
        self.act = act_layer()
        self.fc2 = RowParallelLinear(
            hidden_features,
            in_features,
            prefix=f"{prefix}.fc2",
            **linear_opts,
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Apply fc1, the activation, then fc2, and return fc2's output."""
        # Both linear layers return a 2-tuple; only the first item is used.
        hidden, _ = self.fc1(x)
        projected, _ = self.fc2(self.act(hidden))
        return projected


281
282
def apply_rotary_pos_emb_vision(
    t: torch.Tensor, cos: torch.Tensor, sin: torch.Tensor
) -> torch.Tensor:
    """Apply rotary position embedding to `t` and return it in `t`'s dtype.

    The implementation is chosen by `dispatch_rotary_emb_function`, with the
    neox-style torch implementation passed as its default.
    """
    torch_default = partial(apply_rotary_emb_torch, is_neox_style=True)
    rotate = dispatch_rotary_emb_function(default=torch_default)
    # Cast back so the result matches the input tensor's type.
    return rotate(t, cos, sin).type_as(t)


class Qwen2VisionAttention(nn.Module):
    """Self-attention layer of the Qwen2-VL vision encoder.

    A fused QKV projection (ColumnParallelLinear) is followed by rotary
    position embedding on Q/K, variable-length attention delimited by
    `cu_seqlens`, and an output projection (RowParallelLinear).

    Supported attention backends are FLASH_ATTN, ROCM_AITER_FA and
    TORCH_SDPA; any other resolved backend raises at construction time.
    """

    def __init__(
        self,
        embed_dim: int,
        num_heads: int,
        projection_size: int,
        quant_config: QuantizationConfig | None = None,
        prefix: str = "",
        use_data_parallel: bool = False,
        attn_backend_override: AttentionBackendEnum | None = None,
    ) -> None:
        """Create the projections and resolve the attention backend.

        Args:
            embed_dim: Input/output feature size of the layer.
            num_heads: Total number of attention heads.
            projection_size: Combined size of all heads for each of Q, K, V.
            quant_config: Quantization config forwarded to the linears.
            prefix: Parameter-name prefix; ".qkv" / ".proj" are appended.
            use_data_parallel: When true, tensor parallelism is disabled for
                the linears and the TP size is treated as 1.
            attn_backend_override: Optional backend forwarded to the backend
                selection helpers.

        Raises:
            RuntimeError: If the resolved backend is not one of FLASH_ATTN,
                TORCH_SDPA or ROCM_AITER_FA.
        """
        super().__init__()
        # Per attention head and per partition values.
        self.tp_size = (
            1
            if use_data_parallel
            else parallel_state.get_tensor_model_parallel_world_size()
        )
        self.tp_rank = parallel_state.get_tensor_model_parallel_rank()
        self.hidden_size_per_attention_head = dist_utils.divide(
            projection_size, num_heads
        )
        self.num_attention_heads_per_partition = dist_utils.divide(
            num_heads, self.tp_size
        )

        self.qkv = ColumnParallelLinear(
            input_size=embed_dim,
            output_size=3 * projection_size,
            quant_config=quant_config,
            prefix=f"{prefix}.qkv",
            disable_tp=use_data_parallel,
        )
        self.proj = RowParallelLinear(
            input_size=projection_size,
            output_size=embed_dim,
            quant_config=quant_config,
            prefix=f"{prefix}.proj",
            disable_tp=use_data_parallel,
        )

        # Detect attention implementation.
        self.attn_backend = get_vit_attn_backend(
            head_size=self.hidden_size_per_attention_head,
            dtype=torch.get_default_dtype(),
            attn_backend_override=attn_backend_override,
        )
        self.use_upstream_fa = False

        # The helper may adjust the backend and returns the varlen
        # flash-attention callable used in forward().
        self.attn_backend, self.flash_attn_varlen_func = (
            maybe_get_vit_flash_attn_backend(
                self.attn_backend,
                self.use_upstream_fa,
                attn_backend_override=attn_backend_override,
            )
        )

        if self.attn_backend not in {
            AttentionBackendEnum.FLASH_ATTN,
            AttentionBackendEnum.TORCH_SDPA,
            AttentionBackendEnum.ROCM_AITER_FA,
        }:
            raise RuntimeError(
                f"Qwen2-VL does not support {self.attn_backend} backend now."
            )

        self.is_flash_attn_backend = self.attn_backend in {
            AttentionBackendEnum.FLASH_ATTN,
            AttentionBackendEnum.ROCM_AITER_FA,
        }

    def split_qkv(self, qkv: torch.Tensor) -> tuple[torch.Tensor, ...]:
        """Split a fused QKV tensor into per-head Q, K and V tensors.

        Args:
            qkv: Tensor laid out as [s, b, 3 * head * head_dim].

        Returns:
            `(q, k, v)`, each viewed as [s, b, head, head_dim] using this
            partition's head count.
        """
        # [s, b, 3 * head * head_dim]
        seq_len, bs, _ = qkv.shape

        # [s, b, 3 * head * head_dim] -> 3 * [s, b, head * head_dim]
        q, k, v = qkv.chunk(3, dim=2)

        # 3 * [s, b, head * head_dim] -> 3 * [s, b, head, head_dim]
        new_shape = (
            seq_len,
            bs,
            self.num_attention_heads_per_partition,
            self.hidden_size_per_attention_head,
        )
        q, k, v = (x.view(*new_shape) for x in (q, k, v))
        return q, k, v

    def forward(
        self,
        x: torch.Tensor,
        cu_seqlens: torch.Tensor,
        rotary_pos_emb_cos: torch.Tensor,
        rotary_pos_emb_sin: torch.Tensor,
        max_seqlen: int | None = None,  # Only used for Flash Attention
    ) -> torch.Tensor:
        """Run attention over the packed sequence.

        Args:
            x: Input laid out as [s, b, c].
            cu_seqlens: Cumulative sequence boundaries; consecutive entries
                delimit one attention segment along the sequence axis.
            rotary_pos_emb_cos: Cosine table passed to the rotary embedding.
            rotary_pos_emb_sin: Sine table passed to the rotary embedding.
            max_seqlen: Longest segment length; only used by the
                flash-attention backends.

        Returns:
            The output of the `proj` layer applied to the attention context,
            which is laid out as [s, b, head * head_dim].
        """
        # [s, b, c] --> [s, b, 3 * head * head_dim]
        x, _ = self.qkv(x)

        # [s, b, 3 * head * head_dim] -> 3 * [s, b, head, head_dim]
        q, k, v = self.split_qkv(x)
        batch_size = q.shape[1]

        q, k, v = (rearrange(x, "s b ... -> b s ...") for x in (q, k, v))

        # Rotate Q and K in one call by stacking them on the batch axis.
        # [2 * b, s, heads, head_dim]
        qk_concat = torch.cat([q, k], dim=0)
        qk_rotated = apply_rotary_pos_emb_vision(
            qk_concat, rotary_pos_emb_cos, rotary_pos_emb_sin
        )
        q, k = torch.chunk(qk_rotated, 2, dim=0)

        if self.is_flash_attn_backend:
            q, k, v = (rearrange(x, "b s ... -> (b s) ...") for x in [q, k, v])

            output = self.flash_attn_varlen_func(
                q,
                k,
                v,
                cu_seqlens_q=cu_seqlens,
                cu_seqlens_k=cu_seqlens,
                max_seqlen_q=max_seqlen,
                max_seqlen_k=max_seqlen,
                dropout_p=0.0,
                causal=False,
            )

            context_layer = rearrange(
                output, "(b s) h d -> s b (h d)", b=batch_size
            ).contiguous()
        elif self.attn_backend == AttentionBackendEnum.TORCH_SDPA:
            # Execute attention entry by entry for speed & less VRAM.
            from vllm.platforms import current_platform

            if current_platform.is_rocm():
                q = q.contiguous()
                k = k.contiguous()
                v = v.contiguous()

            # Fetch all segment boundaries as Python ints in one call instead
            # of indexing the tensor once per bound inside the loop.
            boundaries = cu_seqlens.tolist()
            outputs = []
            for start_idx, end_idx in zip(boundaries, boundaries[1:]):
                q_i = q[:, start_idx:end_idx]
                k_i = k[:, start_idx:end_idx]
                v_i = v[:, start_idx:end_idx]
                q_i, k_i, v_i = (
                    rearrange(x, "b s h d -> b h s d") for x in [q_i, k_i, v_i]
                )
                output_i = F.scaled_dot_product_attention(q_i, k_i, v_i, dropout_p=0.0)
                output_i = rearrange(output_i, "b h s d -> b s h d ")
                outputs.append(output_i)
            context_layer = torch.cat(outputs, dim=1)
            context_layer = rearrange(
                context_layer, "b s h d -> s b (h d)"
            ).contiguous()

        output, _ = self.proj(context_layer)
        return output


class Qwen2VisionBlock(nn.Module):
    """One vision-encoder layer: pre-norm attention and pre-norm MLP, each
    added back onto its input as a residual."""

    def __init__(
        self,
        dim: int,
        num_heads: int,
        mlp_ratio: float,
        act_layer: type[nn.Module] = QuickGELU,
        norm_layer: Callable[[int], nn.Module] | None = None,
        quant_config: QuantizationConfig | None = None,
        prefix: str = "",
        use_data_parallel: bool = False,
        attn_backend_override: AttentionBackendEnum | None = None,
    ) -> None:
        """Build the two norms, the attention layer and the MLP.

        Args:
            dim: Feature size of the block.
            num_heads: Number of attention heads.
            mlp_ratio: MLP hidden width as a multiple of `dim` (truncated to
                an int).
            act_layer: Activation class used by the MLP.
            norm_layer: Factory taking a feature size; defaults to
                `nn.LayerNorm` with `eps=1e-6`.
            quant_config: Quantization config forwarded to submodules.
            prefix: Parameter-name prefix; ".attn" / ".mlp" are appended.
            use_data_parallel: Forwarded to the attention layer and the MLP.
            attn_backend_override: Forwarded to the attention layer.
        """
        super().__init__()
        make_norm = (
            norm_layer if norm_layer is not None else partial(nn.LayerNorm, eps=1e-6)
        )
        self.norm1 = make_norm(dim)
        self.norm2 = make_norm(dim)

        self.attn = Qwen2VisionAttention(
            embed_dim=dim,
            num_heads=num_heads,
            projection_size=dim,
            quant_config=quant_config,
            prefix=f"{prefix}.attn",
            use_data_parallel=use_data_parallel,
            attn_backend_override=attn_backend_override,
        )
        self.mlp = Qwen2VisionMLP(
            dim,
            int(dim * mlp_ratio),
            act_layer=act_layer,
            quant_config=quant_config,
            prefix=f"{prefix}.mlp",
            use_data_parallel=use_data_parallel,
        )

    def forward(
        self,
        x: torch.Tensor,
        cu_seqlens: torch.Tensor,
        rotary_pos_emb_cos: torch.Tensor,
        rotary_pos_emb_sin: torch.Tensor,
        max_seqlen: int | None = None,  # Only used for Flash Attention
    ) -> torch.Tensor:
        """Return `x` after the attention and MLP residual branches."""
        attn_out = self.attn(
            self.norm1(x),
            cu_seqlens=cu_seqlens,
            rotary_pos_emb_cos=rotary_pos_emb_cos,
            rotary_pos_emb_sin=rotary_pos_emb_sin,
            max_seqlen=max_seqlen,
        )
        after_attn = x + attn_out
        return after_attn + self.mlp(self.norm2(after_attn))


class Qwen2VisionPatchEmbed(nn.Module):
    """Project flattened patches to `embed_dim` with a bias-free 3D conv
    whose kernel and stride both equal one patch."""

    def __init__(
        self,
        patch_size: int = 14,
        temporal_patch_size: int = 2,
        in_channels: int = 3,
        embed_dim: int = 1152,
    ) -> None:
        """Store the patch geometry and build the projection.

        Args:
            patch_size: Spatial side length of a patch.
            temporal_patch_size: Temporal extent of a patch.
            in_channels: Number of input channels.
            embed_dim: Output embedding size per patch.
        """
        super().__init__()
        self.patch_size = patch_size
        self.temporal_patch_size = temporal_patch_size
        self.embed_dim = embed_dim

        # Kernel == stride, so every patch is consumed exactly once.
        patch_shape = (temporal_patch_size, patch_size, patch_size)
        self.proj = Conv3dLayer(
            in_channels,
            embed_dim,
            kernel_size=patch_shape,
            stride=patch_shape,
            bias=False,
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Embed a 2-D tensor of flattened patches.

        Args:
            x: Tensor of shape (L, C); each row is viewed as
                (-1, temporal_patch_size, patch_size, patch_size).

        Returns:
            Tensor of shape (L, embed_dim).
        """
        # Unpacking also enforces that the input is exactly 2-D.
        num_patches, _ = x.shape
        patches = x.view(
            num_patches,
            -1,
            self.temporal_patch_size,
            self.patch_size,
            self.patch_size,
        )
        return self.proj(patches).view(num_patches, self.embed_dim)


class Qwen2VisionPatchMerger(nn.Module):
    """Normalize patch features, regroup them into rows of
    `context_dim * spatial_merge_size**2` features, and project each row to
    `d_model` through a Linear -> GELU -> Linear stack."""

    def __init__(
        self,
        d_model: int,
        context_dim: int,
        norm_layer: Callable[[int], nn.Module] | None = None,
        spatial_merge_size: int = 2,
        quant_config: QuantizationConfig | None = None,
        prefix: str = "",
        use_data_parallel: bool = False,
    ) -> None:
        """Build the input norm and the projection stack.

        Args:
            d_model: Output feature size.
            context_dim: Feature size of each incoming patch.
            norm_layer: Factory taking a feature size; defaults to
                `nn.LayerNorm` with `eps=1e-6`.
            spatial_merge_size: Merge factor; its square multiplies
                `context_dim` to give the merged row width.
            quant_config: Quantization config forwarded to both linears.
            prefix: Parameter-name prefix; ".mlp.0" / ".mlp.2" are appended.
            use_data_parallel: Forwarded to both linears as `disable_tp`.
        """
        super().__init__()
        self.hidden_size = context_dim * (spatial_merge_size**2)

        make_norm = (
            norm_layer if norm_layer is not None else partial(nn.LayerNorm, eps=1e-6)
        )
        self.ln_q = make_norm(context_dim)

        fc_in = ColumnParallelLinear(
            self.hidden_size,
            self.hidden_size,
            bias=True,
            quant_config=quant_config,
            prefix=f"{prefix}.mlp.0",
            disable_tp=use_data_parallel,
        )
        fc_out = RowParallelLinear(
            self.hidden_size,
            d_model,
            bias=True,
            quant_config=quant_config,
            prefix=f"{prefix}.mlp.2",
            disable_tp=use_data_parallel,
        )
        # Indices 0 and 2 hold the linears, matching the prefixes above.
        self.mlp = nn.ModuleList([fc_in, nn.GELU(), fc_out])

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Return the merged and projected features for `x`."""
        fc_in, activation, fc_out = self.mlp
        merged = self.ln_q(x).view(-1, self.hidden_size)

        # Both linear layers return a 2-tuple; only the first item is used.
        hidden, _ = fc_in(merged)
        out, _ = fc_out(activation(hidden))
        return out


class Qwen2VisionTransformer(nn.Module):
    """Qwen2-VL vision encoder: patch embedding, a stack of attention blocks
    driven by 2-D rotary position embeddings, and a patch merger."""

    def __init__(
        self,
        vision_config: Qwen2VLVisionConfig,
        norm_eps: float = 1e-6,
        quant_config: QuantizationConfig | None = None,
        prefix: str = "",
        use_data_parallel: bool = False,
        attn_backend_override: AttentionBackendEnum | None = None,
    ) -> None:
        """Build all submodules from `vision_config`.

        Args:
            vision_config: Source of patch geometry, depth, widths and head
                count.
            norm_eps: Epsilon for every LayerNorm created here.
            quant_config: Quantization config forwarded to submodules.
            prefix: Parameter-name prefix for submodules.
            use_data_parallel: Forwarded to blocks and merger.
            attn_backend_override: Forwarded to backend selection and blocks.
        """
        super().__init__()

        patch_size = vision_config.patch_size
        temporal_patch_size = vision_config.temporal_patch_size
        spatial_merge_size = vision_config.spatial_merge_size
        in_channels = vision_config.in_channels
        hidden_size = vision_config.hidden_size
        embed_dim = vision_config.embed_dim
        depth = vision_config.depth
        num_heads = vision_config.num_heads
        mlp_ratio = vision_config.mlp_ratio

        self.use_data_parallel = use_data_parallel
        self.out_hidden_size = vision_config.hidden_size
        self.spatial_merge_size = spatial_merge_size
        self.num_heads = num_heads
        self.embed_dim = embed_dim

        self.patch_embed = Qwen2VisionPatchEmbed(
            patch_size=patch_size,
            temporal_patch_size=temporal_patch_size,
            in_channels=in_channels,
            embed_dim=embed_dim,
        )

        make_norm = partial(nn.LayerNorm, eps=norm_eps)
        head_dim = embed_dim // num_heads
        # Rotary dim is half the head size; rot_pos_emb() concatenates the
        # row and column lookups along the last axis.
        self.rotary_pos_emb = get_rope(
            head_size=head_dim,
            rotary_dim=head_dim // 2,
            max_position=8192,
            is_neox_style=True,
        )

        encoder_layers = [
            Qwen2VisionBlock(
                dim=embed_dim,
                num_heads=num_heads,
                mlp_ratio=mlp_ratio,
                norm_layer=make_norm,
                quant_config=quant_config,
                prefix=f"{prefix}.blocks.{layer_idx}",
                use_data_parallel=use_data_parallel,
                attn_backend_override=attn_backend_override,
            )
            for layer_idx in range(depth)
        ]
        self.blocks = nn.ModuleList(encoder_layers)

        self.merger = Qwen2VisionPatchMerger(
            d_model=hidden_size,
            context_dim=embed_dim,
            norm_layer=make_norm,
            quant_config=quant_config,
            prefix=f"{prefix}.merger",
            use_data_parallel=use_data_parallel,
        )

        default_dtype = torch.get_default_dtype()
        self.attn_backend = get_vit_attn_backend(
            head_size=head_dim,
            dtype=default_dtype,
            attn_backend_override=attn_backend_override,
        )
        # Switch to FLASH_ATTN when it was not selected but the upstream
        # flash-attention check passes for the default dtype.
        if (
            self.attn_backend != AttentionBackendEnum.FLASH_ATTN
            and check_upstream_fa_availability(torch.get_default_dtype())
        ):
            self.attn_backend = AttentionBackendEnum.FLASH_ATTN

    @property
    def dtype(self) -> torch.dtype:
        """Dtype of the patch-embedding projection weight."""
        return self.patch_embed.proj.weight.dtype

    @property
    def device(self) -> torch.device:
        """Device of the patch-embedding projection weight."""
        return self.patch_embed.proj.weight.device

    def rot_pos_emb(
        self, grid_thw: list[list[int]]
    ) -> tuple[torch.Tensor, torch.Tensor]:
        """Gather rotary cos/sin rows for every patch position.

        Args:
            grid_thw: One (t, h, w) triple per item.

        Returns:
            `(cos, sin)`, each indexed by the (row, column) ids of all
            patches and flattened from dim 1 onward.
        """
        merge = self.spatial_merge_size

        def _in_merge_order(ids: torch.Tensor, h: int, w: int) -> torch.Tensor:
            # Group positions into merge x merge windows, then flatten.
            windows = ids.reshape(h // merge, merge, w // merge, merge)
            return windows.permute(0, 2, 1, 3).flatten()

        per_item_ids = []
        largest_side = 0
        for t, h, w in grid_thw:
            row_ids = torch.arange(h).unsqueeze(1).expand(-1, w)
            col_ids = torch.arange(w).unsqueeze(0).expand(h, -1)
            hw_ids = torch.stack(
                [_in_merge_order(row_ids, h, w), _in_merge_order(col_ids, h, w)],
                dim=-1,
            )
            # The same spatial ids are reused for each of the t frames.
            per_item_ids.append(hw_ids.repeat(t, 1))
            largest_side = max(largest_side, h, w)
        all_ids = torch.cat(per_item_ids, dim=0)

        # Use pre-computed cos_sin_cache from RotaryEmbedding
        cos, sin = self.rotary_pos_emb.get_cos_sin(largest_side)

        return cos[all_ids].flatten(1), sin[all_ids].flatten(1)

    def compute_attn_mask_seqlen(self, cu_seqlens: torch.Tensor) -> int | None:
        """Return the longest segment length for flash-attention backends,
        or None for any other backend."""
        flash_backends = {
            AttentionBackendEnum.FLASH_ATTN,
            AttentionBackendEnum.ROCM_AITER_FA,
        }
        if self.attn_backend not in flash_backends:
            return None
        segment_lengths = cu_seqlens[1:] - cu_seqlens[:-1]
        return segment_lengths.max().item()

    def forward(
        self,
        x: torch.Tensor,
        grid_thw: torch.Tensor | list[list[int]],
    ) -> torch.Tensor:
        """Encode flattened patches.

        Args:
            x: Patch tensor accepted by the patch embedding.
            grid_thw: One (t, h, w) row per item, as a tensor or a list.

        Returns:
            The output of the merger applied to the final block's output.
        """
        # patchify
        hidden = self.patch_embed(x.to(device=self.device, dtype=self.dtype))

        # Keep both a Python-list and a numpy view of the grid sizes.
        if isinstance(grid_thw, list):
            grid_list = grid_thw
            grid_np = np.array(grid_thw, dtype=np.int32)
        else:
            grid_list = grid_thw.tolist()
            grid_np = grid_thw.numpy()

        # compute position embedding
        rotary_pos_emb_cos, rotary_pos_emb_sin = self.rot_pos_emb(grid_list)

        # compute cu_seqlens: each frame (h * w patches) is one segment,
        # repeated t times per item, with a leading zero.
        frame_lengths = np.repeat(grid_np[:, 1] * grid_np[:, 2], grid_np[:, 0])
        running_total = frame_lengths.cumsum(axis=0, dtype=np.int32)
        cu_seqlens = torch.from_numpy(
            np.concatenate([np.zeros(1, dtype=np.int32), running_total])
        )

        # transformers
        hidden = hidden.unsqueeze(1)

        # pre-compute seqlens for attn mask to reduce cuMemcpy operations
        max_seqlen = self.compute_attn_mask_seqlen(cu_seqlens)
        cu_seqlens = cu_seqlens.to(self.device, non_blocking=True)
        for layer in self.blocks:
            hidden = layer(
                hidden,
                cu_seqlens=cu_seqlens,
                rotary_pos_emb_cos=rotary_pos_emb_cos,
                rotary_pos_emb_sin=rotary_pos_emb_sin,
                max_seqlen=max_seqlen,
            )

        # adapter
        return self.merger(hidden)

    def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]:
        """Load `(name, tensor)` pairs into this module's parameters.

        Names containing a q/k/v shard name are rewritten to the fused
        parameter name and loaded with the matching shard id; all other
        names are loaded through the parameter's `weight_loader` when it has
        one, else `default_weight_loader`.

        Returns:
            The set of (possibly rewritten) parameter names that were loaded.
        """
        stacked_params_mapping = [
            # (param_name, shard_name, shard_id)
            ("qkv_proj", "q_proj", "q"),
            ("qkv_proj", "k_proj", "k"),
            ("qkv_proj", "v_proj", "v"),
        ]
        params_dict = dict(self.named_parameters(remove_duplicate=False))
        loaded_params: set[str] = set()

        for name, loaded_weight in weights:
            for fused_name, shard_name, shard_id in stacked_params_mapping:
                if shard_name in name:
                    name = name.replace(shard_name, fused_name)
                    param = params_dict[name]
                    param.weight_loader(param, loaded_weight, shard_id)
                    break
            else:
                # No shard name matched: load the tensor as a whole.
                param = params_dict[name]
                loader = getattr(param, "weight_loader", default_weight_loader)
                loader(param, loaded_weight)
            loaded_params.add(name)
        return loaded_params

795

796
def _create_qwen2vl_field_factory(
    spatial_merge_size: int,
) -> Callable[
    [Mapping[str, torch.Tensor]],
    Mapping[str, MultiModalFieldConfig],
]:
    """Return a function that maps HF processor outputs to field configs.

    The returned function reads `image_grid_thw` / `video_grid_thw` from its
    input (an empty (0, 3) tensor when absent) and derives per-item sizes:
    the product of each (t, h, w) row for pixel fields, and that product
    floor-divided twice by `spatial_merge_size` for embedding fields.
    """

    def _qwen2vl_field_config(hf_inputs: Mapping[str, torch.Tensor]):
        def _pixel_sizes(grid_key: str) -> torch.Tensor:
            # Patches per item = t * h * w.
            grid = hf_inputs.get(grid_key, torch.empty((0, 3)))
            return grid.prod(-1)

        def _embed_sizes(pixel_sizes: torch.Tensor) -> torch.Tensor:
            return pixel_sizes // spatial_merge_size // spatial_merge_size

        image_pixel_grid_sizes = _pixel_sizes("image_grid_thw")
        image_embed_grid_sizes = _embed_sizes(image_pixel_grid_sizes)

        video_grid_sizes = _pixel_sizes("video_grid_thw")
        video_embed_grid_sizes = _embed_sizes(video_grid_sizes)

        flat = MultiModalFieldConfig.flat_from_sizes
        batched = MultiModalFieldConfig.batched
        return {
            "pixel_values": flat("image", image_pixel_grid_sizes),
            "image_embeds": flat("image", image_embed_grid_sizes),
            "image_grid_thw": batched("image"),
            "pixel_values_videos": flat("video", video_grid_sizes),
            "video_embeds": flat("video", video_embed_grid_sizes),
            "video_grid_thw": batched("video"),
        }

    return _qwen2vl_field_config
833

834

Roger Wang's avatar
Roger Wang committed
835
class Qwen2VLMultiModalDataParser(MultiModalDataParser):
    """Data parser that additionally accepts dict inputs carrying
    precomputed image/video embeddings together with their grid sizes."""

    def __init__(self, spatial_merge_size: int, *args, **kwargs):
        """Remember `spatial_merge_size`; remaining arguments go to the base
        parser unchanged."""
        self._spatial_merge_size = spatial_merge_size
        super().__init__(*args, **kwargs)

    def _parse_image_data(
        self,
        data: dict[str, torch.Tensor] | ModalityData[ImageItem],
    ) -> ModalityDataItems[Any, Any] | None:
        """Wrap dict inputs as embedding items; defer anything else to the
        base parser."""
        if not isinstance(data, dict):
            return super()._parse_image_data(data)

        field_factory = _create_qwen2vl_field_factory(self._spatial_merge_size)
        return DictEmbeddingItems(
            data,
            modality="image",
            required_fields={"image_embeds", "image_grid_thw"},
            fields_factory=field_factory,
        )

    def _parse_video_data(
        self,
        data: dict[str, torch.Tensor] | ModalityData[VideoItem],
    ) -> ModalityDataItems[Any, Any] | None:
        """Wrap dict inputs as embedding items; defer anything else to the
        base parser."""
        if not isinstance(data, dict):
            return super()._parse_video_data(data)

        field_factory = _create_qwen2vl_field_factory(self._spatial_merge_size)
        return DictEmbeddingItems(
            data,
            modality="video",
            required_fields={"video_embeds", "video_grid_thw"},
            fields_factory=field_factory,
        )


869
870
class Qwen2VLProcessingInfo(BaseProcessingInfo):
    """Processing metadata (config, processors, token budgets) for Qwen2-VL."""

    def get_hf_config(self):
        """Return the HF config, requested from the context as Qwen2VLConfig."""
        return self.ctx.get_hf_config(Qwen2VLConfig)

    def get_hf_processor(self, **kwargs: object) -> Qwen2VLProcessor:
        """Return the HF processor; ``use_fast`` defaults to ``True``."""
        # Pop first so `use_fast` is not passed twice through **kwargs.
        use_fast = kwargs.pop("use_fast", True)
        return self.ctx.get_hf_processor(
            Qwen2VLProcessor,
            use_fast=use_fast,
            **kwargs,
        )

    def get_image_processor(self, **kwargs: object) -> Qwen2VLImageProcessor:
        """Return the image processor attached to the HF processor."""
        return self.get_hf_processor(**kwargs).image_processor

    def get_supported_mm_limits(self) -> Mapping[str, int | None]:
        """Return the per-modality limits; both modalities map to ``None``."""
        return {"image": None, "video": None}

    def get_mm_max_tokens_per_item(
        self,
        seq_len: int,
        mm_counts: Mapping[str, int],
    ) -> Mapping[str, int]:
        """Return the maximum token count of one image and of one video."""
        return {
            "image": self.get_max_image_tokens(),
            "video": self.get_max_video_tokens(seq_len, mm_counts),
        }

    def _get_vision_info(
        self,
        *,
        image_width: int,
        image_height: int,
        num_frames: int = 1,
        do_resize: bool = True,
        image_processor: Qwen2VLImageProcessor | None,
    ) -> tuple[ImageSize, int]:
        """Compute the preprocessed size and vision-token count of an input.

        Args:
            image_width: Width of the raw image / video frame.
            image_height: Height of the raw image / video frame.
            num_frames: Number of frames (1 for a still image).
            do_resize: Whether to apply ``smart_resize`` to the input size.
            image_processor: Processor supplying ``min_pixels`` and
                ``max_pixels``; fetched via ``get_image_processor()`` when
                ``None``.

        Returns:
            ``(preprocessed_size, num_vision_tokens)``.
        """
        if image_processor is None:
            image_processor = self.get_image_processor()

        vision_config = self.get_hf_config().vision_config
        patch_size = vision_config.patch_size
        merge_size = vision_config.spatial_merge_size
        temporal_patch_size = vision_config.temporal_patch_size

        if do_resize:
            resized_height, resized_width = smart_resize(
                height=image_height,
                width=image_width,
                factor=patch_size * merge_size,
                min_pixels=image_processor.min_pixels,
                max_pixels=image_processor.max_pixels,
            )
            preprocessed_size = ImageSize(width=resized_width, height=resized_height)
        else:
            preprocessed_size = ImageSize(width=image_width, height=image_height)

        # NOTE: Frames are padded to be divisible by `temporal_patch_size`
        # https://github.com/huggingface/transformers/blob/v4.48.3/src/transformers/models/qwen2_vl/image_processing_qwen2_vl.py#L294
        # `-n % k` is the distance from n up to the next multiple of k. The
        # previous `n + n % k` only reached a multiple when k <= 2.
        padded_num_frames = num_frames + (-num_frames) % temporal_patch_size

        grid_t = max(padded_num_frames // temporal_patch_size, 1)
        grid_h = preprocessed_size.height // patch_size
        grid_w = preprocessed_size.width // patch_size

        num_patches = grid_t * grid_h * grid_w
        num_vision_tokens = num_patches // (merge_size**2)

        return preprocessed_size, num_vision_tokens

    def get_num_image_tokens(
        self,
        *,
        image_width: int,
        image_height: int,
        image_processor: Qwen2VLImageProcessor | None,
    ) -> int:
        """Return the number of vision tokens for a single image."""
        _, num_image_tokens = self._get_vision_info(
            image_width=image_width,
            image_height=image_height,
            num_frames=1,
            image_processor=image_processor,
        )
        return num_image_tokens

    def get_num_video_tokens(
        self,
        *,
        image_width: int,
        image_height: int,
        num_frames: int,
        image_processor: Qwen2VLImageProcessor | None,
    ) -> int:
        """Return the number of vision tokens for a video of ``num_frames``."""
        _, num_video_tokens = self._get_vision_info(
            image_width=image_width,
            image_height=image_height,
            num_frames=num_frames,
            image_processor=image_processor,
        )
        return num_video_tokens

    def get_image_size_with_most_features(self) -> ImageSize:
        """Return the preprocessed size obtained for a very large input."""
        # An oversized square input is passed through `_get_vision_info`,
        # whose resize step is bounded by the processor's pixel limits.
        max_image_size, _ = self._get_vision_info(
            image_width=9999999,
            image_height=9999999,
            num_frames=1,
            image_processor=None,
        )
        return max_image_size

    def get_max_image_tokens(self) -> int:
        """Return the token count of the largest supported image."""
        target_width, target_height = self.get_image_size_with_most_features()
        return self.get_num_image_tokens(
            image_width=target_width,
            image_height=target_height,
            image_processor=None,
        )

    def _get_max_video_frames(self, max_tokens: int, start_num_frames: int = 1) -> int:
        """Return the largest frame count whose tokens fit in ``max_tokens``.

        Starting at ``start_num_frames``, frames are added one at a time
        until the next frame would exceed the budget; ``start_num_frames``
        itself is never checked against the budget.
        """
        target_width, target_height = self.get_image_size_with_most_features()

        num_frames = start_num_frames
        while True:
            candidate = num_frames + 1
            candidate_tokens = self.get_num_video_tokens(
                image_width=target_width,
                image_height=target_height,
                num_frames=candidate,
                image_processor=None,
            )
            if candidate_tokens > max_tokens:
                return num_frames
            num_frames = candidate

    def get_num_frames_with_most_features(
        self,
        seq_len: int,
        mm_counts: Mapping[str, int],
        max_frames_per_video: int = _MAX_FRAMES_PER_VIDEO,
    ) -> int:
        """Return the per-video frame count used for worst-case sizing.

        The frame budget that fits in ``seq_len`` is split evenly across the
        requested videos, capped at ``max_frames_per_video`` and floored at 1.
        """
        max_videos = mm_counts.get("video", 0)
        max_total_frames = self._get_max_video_frames(seq_len)

        frames_per_video = min(
            max_total_frames // max(max_videos, 1), max_frames_per_video
        )
        return max(frames_per_video, 1)

    def get_max_video_tokens(
        self,
        seq_len: int,
        mm_counts: Mapping[str, int],
    ) -> int:
        """Return the token count of the largest supported video."""
        target_width, target_height = self.get_image_size_with_most_features()
        return self.get_num_video_tokens(
            image_width=target_width,
            image_height=target_height,
            num_frames=self.get_num_frames_with_most_features(seq_len, mm_counts),
            image_processor=None,
        )

1037
1038

class Qwen2VLDummyInputsBuilder(BaseDummyInputsBuilder[Qwen2VLProcessingInfo]):
    """Builds dummy text and multimodal inputs for Qwen2-VL."""

    def get_dummy_text(self, mm_counts: Mapping[str, int]) -> str:
        """Return one image/video placeholder token per requested item."""
        processor = self.info.get_hf_processor()
        image_token: str = processor.image_token
        video_token: str = processor.video_token

        image_part = image_token * mm_counts.get("image", 0)
        video_part = video_token * mm_counts.get("video", 0)
        return image_part + video_part

    def get_dummy_mm_data(
        self,
        seq_len: int,
        mm_counts: Mapping[str, int],
        mm_options: Mapping[str, BaseDummyOptions] | None = None,
    ) -> MultiModalDataDict:
        """Return dummy images and videos at the largest supported size.

        Args:
            seq_len: Sequence length used to bound the dummy video frames.
            mm_counts: Number of dummy items requested per modality.
            mm_options: Optional per-modality overrides forwarded to the
                dummy image/video helpers.
        """
        width, height = self.info.get_image_size_with_most_features()
        frames = self.info.get_num_frames_with_most_features(seq_len, mm_counts)

        # A missing or empty options mapping yields no overrides.
        options = mm_options or {}

        dummy_images = self._get_dummy_images(
            width=width,
            height=height,
            num_images=mm_counts.get("image", 0),
            overrides=options.get("image"),
        )
        dummy_videos = self._get_dummy_videos(
            width=width,
            height=height,
            num_frames=frames,
            num_videos=mm_counts.get("video", 0),
            overrides=options.get("video"),
        )
        return {"image": dummy_images, "video": dummy_videos}

1082

1083
class Qwen2VLMultiModalProcessor(BaseMultiModalProcessor[Qwen2VLProcessingInfo]):
    """Multimodal processor for Qwen2-VL image and video inputs."""

    def _get_data_parser(self) -> MultiModalDataParser:
        """Return a parser configured with the model's spatial merge size."""
        vision_config = self.info.get_hf_config().vision_config
        return Qwen2VLMultiModalDataParser(vision_config.spatial_merge_size)

    def _get_prompt_updates(
        self,
        mm_items: MultiModalDataItems,
        hf_processor_mm_kwargs: Mapping[str, Any],
        out_mm_kwargs: MultiModalKwargsItems,
    ) -> Sequence[PromptUpdate]:
        """Expand each image/video placeholder to its real token count."""
        hf_processor = self.info.get_hf_processor(**hf_processor_mm_kwargs)
        image_processor = self.info.get_image_processor(**hf_processor_mm_kwargs)
        vocab = self.info.get_tokenizer().get_vocab()

        placeholder_ids = {
            "image": vocab[hf_processor.image_token],
            "video": vocab[hf_processor.video_token],
        }
        # Number of grid cells covered by one placeholder token.
        cells_per_token = image_processor.merge_size**2

        def get_replacement_qwen2vl(item_idx: int, modality: str):
            grid_thw = out_mm_kwargs[modality][item_idx][f"{modality}_grid_thw"].data
            assert isinstance(grid_thw, torch.Tensor)

            repeat = int(grid_thw.prod()) // cells_per_token
            return [placeholder_ids[modality]] * repeat

        updates = []
        for modality in ("image", "video"):
            updates.append(
                PromptReplacement(
                    modality=modality,
                    target=[placeholder_ids[modality]],
                    replacement=partial(get_replacement_qwen2vl, modality=modality),
                )
            )
        return updates

    def _get_mm_fields_config(
        self,
        hf_inputs: BatchFeature,
        hf_processor_mm_kwargs: Mapping[str, object],
    ) -> Mapping[str, MultiModalFieldConfig]:
        """Build the field config for ``hf_inputs`` via the field factory."""
        vision_config = self.info.get_hf_config().vision_config
        field_factory = _create_qwen2vl_field_factory(vision_config.spatial_merge_size)
        return field_factory(hf_inputs)


@MULTIMODAL_REGISTRY.register_processor(
    Qwen2VLMultiModalProcessor,
    info=Qwen2VLProcessingInfo,
    dummy_inputs=Qwen2VLDummyInputsBuilder,
)
class Qwen2VLForConditionalGeneration(
    nn.Module, SupportsMultiModal, SupportsLoRA, SupportsPP, SupportsMRoPE
):
1142
    merge_by_field_config = True
1143
    multimodal_cpu_fields = {"image_grid_thw", "video_grid_thw"}
1144

1145
    # To ensure correct weight loading and mapping.
1146
1147
1148
1149
1150
1151
1152
1153
    hf_to_vllm_mapper = WeightsMapper(
        orig_to_new_prefix={
            # mapping for new names in checkpoint saved after transformers v4.52
            "model.language_model.": "language_model.model.",
            "model.visual.": "visual.",
            # mapping for original checkpoint
            "lm_head.": "language_model.lm_head.",
            "model.": "language_model.model.",
1154
1155
        }
    )
1156

1157
1158
    supports_encoder_tp_data = True

1159
1160
1161
    def get_mrope_input_positions(
        self,
        input_tokens: list[int],
        mm_features: list[MultiModalFeatureSpec],
    ) -> tuple[torch.Tensor, int]:
        """Compute the 3-row M-RoPE position ids for one prompt.

        Text tokens get the same running index in all three rows. Each
        image/video span gets separate temporal, height and width indices,
        offset to follow the preceding text.

        Args:
            input_tokens: Token ids of the prompt.
            mm_features: Multimodal features from which the image/video grids
                and ``second_per_grid_ts`` are gathered.

        Returns:
            ``(llm_positions, mrope_position_delta)``: positions reshaped to
            ``(3, -1)``, and ``max(position) + 1 - len(input_tokens)``.
        """
        gathered = MultiModalFeatureSpec.gather_kwargs(
            mm_features,
            {"image_grid_thw", "video_grid_thw", "second_per_grid_ts"},
        )
        image_grids = [grid.tolist() for grid in gathered.get("image_grid_thw", [])]
        video_grids = [grid.tolist() for grid in gathered.get("video_grid_thw", [])]
        second_per_grid_ts = gathered.get("second_per_grid_ts", [])

        hf_config = self.config
        vision_config = hf_config.vision_config
        image_token_id = hf_config.image_token_id
        video_token_id = hf_config.video_token_id
        vision_start_token_id = hf_config.vision_start_token_id
        spatial_merge_size = vision_config.spatial_merge_size
        tokens_per_second = getattr(vision_config, "tokens_per_second", 1.0)

        num_tokens = len(input_tokens)
        # Sentinel index meaning "no further placeholder of this kind".
        not_found = num_tokens + 1

        # The token right after each vision-start marker tells whether the
        # span is an image or a video.
        token_tensor = torch.tensor(input_tokens)
        start_positions = torch.argwhere(
            token_tensor == vision_start_token_id
        ).squeeze(1)
        span_kinds = token_tensor[start_positions + 1]
        images_left = int((span_kinds == image_token_id).sum())
        videos_left = int((span_kinds == video_token_id).sum())

        chunks: list = []

        def _next_offset():
            # Continue numbering right after the largest position so far.
            return chunks[-1].max() + 1 if len(chunks) > 0 else 0

        def _find_next(token_id: int, remaining: int, start: int) -> int:
            if remaining <= 0:
                return not_found
            try:
                return input_tokens.index(token_id, start)
            except ValueError:
                return not_found

        cursor = 0
        image_index, video_index = 0, 0
        for _ in range(images_left + videos_left):
            next_image = _find_next(image_token_id, images_left, cursor)
            next_video = _find_next(video_token_id, videos_left, cursor)

            if next_image < next_video:
                t, h, w = image_grids[image_index]
                # Images carry no time, so their temporal index stays at 0.
                seconds_per_grid_t = 0.0
                image_index += 1
                images_left -= 1
                span_start = next_image
            else:
                t, h, w = video_grids[video_index]
                seconds_per_grid_t = 1.0
                if second_per_grid_ts:
                    seconds_per_grid_t = second_per_grid_ts[video_index]
                video_index += 1
                videos_left -= 1
                span_start = next_video

            grid_t = t
            grid_h = h // spatial_merge_size
            grid_w = w // spatial_merge_size

            # Text between the previous span and this one.
            text_len = span_start - cursor
            offset = _next_offset()
            chunks.append(torch.arange(text_len).view(1, -1).expand(3, -1) + offset)

            t_index = (
                (
                    torch.arange(grid_t).view(-1, 1).expand(-1, grid_h * grid_w)
                    * seconds_per_grid_t
                    * tokens_per_second
                )
                .long()
                .flatten()
            )
            h_index = (
                torch.arange(grid_h).view(1, -1, 1).expand(grid_t, -1, grid_w).flatten()
            )
            w_index = (
                torch.arange(grid_w).view(1, 1, -1).expand(grid_t, grid_h, -1).flatten()
            )
            chunks.append(torch.stack([t_index, h_index, w_index]) + text_len + offset)

            cursor = span_start + grid_t * grid_h * grid_w

        # Trailing text after the last vision span.
        if cursor < num_tokens:
            offset = _next_offset()
            tail_len = num_tokens - cursor
            chunks.append(torch.arange(tail_len).view(1, -1).expand(3, -1) + offset)

        llm_positions = torch.cat(chunks, dim=1).reshape(3, -1)
        mrope_position_delta = (llm_positions.max() + 1 - num_tokens).item()

        return llm_positions, mrope_position_delta

1275
    @classmethod
    def get_placeholder_str(cls, modality: str, i: int) -> str | None:
        """Return the prompt placeholder for an image or video modality.

        Raises:
            ValueError: If ``modality`` starts with neither "image" nor "video".
        """
        placeholders = (
            ("image", "<|vision_start|><|image_pad|><|vision_end|>"),
            ("video", "<|vision_start|><|video_pad|><|vision_end|>"),
        )
        for prefix, text in placeholders:
            if modality.startswith(prefix):
                return text

        raise ValueError("Only image or video modality is supported")

1284
    def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
        """Build the vision tower (if enabled) and the Qwen2 language model.

        Args:
            vllm_config: Engine configuration; the HF config, quantization
                config and multimodal config are read from it.
            prefix: Module-name prefix applied to the sub-modules.
        """
        super().__init__()
        model_config = vllm_config.model_config
        config: Qwen2VLConfig = model_config.hf_config
        quant_config = vllm_config.quant_config
        multimodal_config = model_config.multimodal_config

        self.use_data_parallel = multimodal_config.mm_encoder_tp_mode == "data"
        self.config = config
        self.multimodal_config = multimodal_config

        # Only build the vision tower when at least one image or video is
        # allowed per prompt.
        vision_enabled = multimodal_config.get_limit_per_prompt(
            "image"
        ) or multimodal_config.get_limit_per_prompt("video")
        if vision_enabled:
            # `multimodal_config` is already dereferenced above, so the former
            # `is not None` guard here could never take its fallback branch.
            self.visual = Qwen2VisionTransformer(
                config.vision_config,
                norm_eps=getattr(config, "rms_norm_eps", 1e-6),
                quant_config=quant_config,
                prefix=maybe_prefix(prefix, "visual"),
                use_data_parallel=self.use_data_parallel,
                attn_backend_override=multimodal_config.mm_encoder_attn_backend,
            )
        else:
            self.visual = None

        self.language_model = init_vllm_registered_model(
            vllm_config=vllm_config,
            prefix=maybe_prefix(prefix, "language_model"),
            architectures=["Qwen2ForCausalLM"],
        )

        self.make_empty_intermediate_tensors = (
            self.language_model.make_empty_intermediate_tensors
        )
1322
1323

    def _parse_and_validate_image_input(
        self, **kwargs: object
    ) -> Qwen2VLImageInputs | None:
        """Wrap image kwargs into a typed input, or return ``None`` if absent.

        Pixel values take precedence over precomputed embeddings when both
        are supplied.
        """
        pixel_values = kwargs.get("pixel_values")
        image_embeds = kwargs.get("image_embeds")
        image_grid_thw = kwargs.get("image_grid_thw")

        if pixel_values is not None:
            return Qwen2VLImagePixelInputs(
                type="pixel_values",
                pixel_values=pixel_values,
                image_grid_thw=image_grid_thw,
            )

        if image_embeds is not None:
            return Qwen2VLImageEmbeddingInputs(
                type="image_embeds",
                image_embeds=image_embeds,
                image_grid_thw=image_grid_thw,
            )

        return None
1346
1347

    def _parse_and_validate_video_input(
        self, **kwargs: object
    ) -> Qwen2VLVideoInputs | None:
        """Wrap video kwargs into a typed input, or return ``None`` if absent.

        Pixel values take precedence over precomputed embeddings when both
        are supplied.
        """
        pixel_values_videos = kwargs.get("pixel_values_videos")
        video_embeds = kwargs.get("video_embeds")
        video_grid_thw = kwargs.get("video_grid_thw")

        if pixel_values_videos is not None:
            return Qwen2VLVideoPixelInputs(
                type="pixel_values_videos",
                pixel_values_videos=pixel_values_videos,
                video_grid_thw=video_grid_thw,
            )

        if video_embeds is not None:
            return Qwen2VLVideoEmbeddingInputs(
                type="video_embeds",
                video_embeds=video_embeds,
                video_grid_thw=video_grid_thw,
            )

        return None
1370

1371
    def _process_image_input(
        self, image_input: Qwen2VLImageInputs
    ) -> tuple[torch.Tensor, ...]:
        """Return one embedding tensor per image item.

        Precomputed embeddings are used as-is; pixel values are encoded by
        the vision tower (through the data-parallel helper when enabled).
        """
        grid_thw = image_input["image_grid_thw"]
        assert grid_thw.ndim == 2

        if image_input["type"] == "image_embeds":
            embeds = image_input["image_embeds"]
        else:
            pixel_values = image_input["pixel_values"]
            if self.use_data_parallel:
                # The data-parallel helper's result is returned directly.
                return run_dp_sharded_mrope_vision_model(
                    self.visual, pixel_values, grid_thw.tolist(), rope_type="rope_3d"
                )
            embeds = self.visual(pixel_values, grid_thw=grid_thw)

        # Split concatenated embeddings for each image item.
        merge_size = self.visual.spatial_merge_size
        tokens_per_item = (grid_thw.prod(-1) // merge_size // merge_size).tolist()
        return embeds.split(tokens_per_item)
1393
1394

    def _process_video_input(
        self, video_input: Qwen2VLVideoInputs
    ) -> tuple[torch.Tensor, ...]:
        """Return one embedding tensor per video item.

        Precomputed embeddings are used as-is; pixel values are encoded by
        the vision tower (through the data-parallel helper when enabled).
        """
        grid_thw = video_input["video_grid_thw"]
        assert grid_thw.ndim == 2

        if video_input["type"] == "video_embeds":
            embeds = video_input["video_embeds"]
        else:
            pixel_values_videos = video_input["pixel_values_videos"]
            if self.use_data_parallel:
                # The data-parallel helper's result is returned directly.
                return run_dp_sharded_mrope_vision_model(
                    self.visual,
                    pixel_values_videos,
                    grid_thw.tolist(),
                    rope_type="rope_3d",
                )
            embeds = self.visual(pixel_values_videos, grid_thw=grid_thw)

        # Split concatenated embeddings for each video item.
        merge_size = self.visual.spatial_merge_size
        tokens_per_item = (grid_thw.prod(-1) // merge_size // merge_size).tolist()
        return embeds.split(tokens_per_item)
1416
1417
1418
1419
1420
1421
1422

    def _parse_and_validate_multimodal_inputs(self, **kwargs: object) -> dict:
        """Parse image/video kwargs into a dict ordered by first appearance.

        Returns:
            A dict with optional ``"images"`` and ``"videos"`` entries, in the
            order their input keys first occur in ``kwargs``.
        """
        image_keys = ("pixel_values", "image_embeds")
        video_keys = ("pixel_values_videos", "video_embeds")
        parsed = {}

        # Dict insertion order records which modality came first in kwargs.
        for key in kwargs:
            if key in image_keys:
                if "images" not in parsed:
                    parsed["images"] = self._parse_and_validate_image_input(**kwargs)
            elif key in video_keys:
                if "videos" not in parsed:
                    parsed["videos"] = self._parse_and_validate_video_input(**kwargs)

        return parsed
1435

1436
1437
1438
    def get_language_model(self) -> torch.nn.Module:
        """Return the wrapped language model module."""
        language_model = self.language_model
        return language_model

1439
    def embed_multimodal(self, **kwargs: object) -> MultiModalEmbeddings:
        """Embed all image/video inputs, one tensor per multimodal item.

        Returns:
            An empty list when no multimodal input is present; otherwise a
            tuple of tensors, each corresponding to one image or video, in
            the order the modalities appear in ``kwargs``.
        """
        modalities = self._parse_and_validate_multimodal_inputs(**kwargs)
        if not modalities:
            return []

        per_item_embeddings: list[torch.Tensor] = []

        # NOTE: Iterating the dict keys preserves the order of the modalities.
        for modality, mm_input in modalities.items():
            if modality == "images":
                per_item_embeddings.extend(self._process_image_input(mm_input))
            if modality == "videos":
                per_item_embeddings.extend(self._process_video_input(mm_input))

        return tuple(per_item_embeddings)

1462
1463
1464
1465
    def forward(
        self,
        input_ids: torch.Tensor,
        positions: torch.Tensor,
        intermediate_tensors: IntermediateTensors | None = None,
        inputs_embeds: torch.Tensor | None = None,
        **kwargs: object,
    ) -> torch.Tensor | IntermediateTensors:
        """Run forward pass for Qwen2-VL.

        Args:
            input_ids: Flattened (concatenated) input_ids corresponding to a
                batch.
            positions: Flattened (concatenated) position ids corresponding to a
                batch.
                **NOTE**: If mrope is enabled (default setting for Qwen2-VL
                opensource models), the shape will be `(3, seq_len)`,
                otherwise it will be `(seq_len,)`.
            intermediate_tensors: Intermediate tensors from prior forward pass.
            inputs_embeds: Optional tensor of input embeddings.
        """
        # Input embeddings are dropped whenever intermediate tensors are given.
        embeds = None if intermediate_tensors is not None else inputs_embeds

        return self.language_model.model(
            input_ids=input_ids,
            positions=positions,
            intermediate_tensors=intermediate_tensors,
            inputs_embeds=embeds,
        )

1495
1496
1497
    def compute_logits(
        self,
        hidden_states: torch.Tensor,
    ) -> torch.Tensor | None:
        """Delegate logit computation to the language model."""
        logits = self.language_model.compute_logits(hidden_states)
        return logits
1500

1501
    def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]:
        """Load checkpoint weights, skipping ``visual.`` when no tower exists."""
        # Without a vision tower there is nothing to receive `visual.*` weights.
        skip_prefixes = ["visual."] if self.visual is None else []
        weights_loader = AutoWeightsLoader(self, skip_prefixes=skip_prefixes)
        return weights_loader.load_weights(weights, mapper=self.hf_to_vllm_mapper)
1507
1508
1509
1510
1511
1512
1513

    def get_mm_mapping(self) -> MultiModelKeys:
        """Return the module prefixes of the multimodal model's components."""
        prefixes = {
            "language_model": "language_model",
            "connector": "visual.merger.",
            "tower_model": "visual.",
        }
        return MultiModelKeys.from_string_field(**prefixes)
1517
1518
1519
1520
1521
1522
1523
1524
1525


class Tarsier2MultiModalProcessor(Qwen2VLMultiModalProcessor):
    """Tarsier2 processor; inherits the Qwen2-VL processor unchanged."""


class Tarsier2ImageProcessor(Qwen2VLImageProcessor):
    """Qwen2-VL image processor that accepts Tarsier2-style ``size`` keys."""

    def __init__(
        self,
        size: dict[str, int] | None = None,
        **kwargs,
    ) -> None:
        """Initialise the processor, remapping Tarsier2 ``size`` keys.

        When ``size`` has both ``min_pixels`` and ``max_pixels``, they are
        passed to the parent as ``shortest_edge`` and ``longest_edge``;
        otherwise ``size`` is forwarded unchanged.
        """
        has_pixel_bounds = (
            size is not None and "min_pixels" in size and "max_pixels" in size
        )
        if has_pixel_bounds:
            # Remap if Tarsier2-specific format is provided
            size = {
                "shortest_edge": size["min_pixels"],
                "longest_edge": size["max_pixels"],
            }
        super().__init__(size=size, **kwargs)


class Tarsier2Processor(Qwen2VLProcessor):
    """Qwen2-VL processor assembled from a Tarsier2 vision config dict."""

    def __init__(
        self,
        vision_config: dict,
        tokenizer: AnyTokenizer,
        **kwargs,
    ):
        """Build the image and video processors from ``vision_config``.

        Args:
            vision_config: Keyword arguments for both the image processor and
                the video processor.
            tokenizer: Tokenizer handed to the parent processor.
            **kwargs: Extra arguments forwarded to the parent processor.
        """
        image_processor = Tarsier2ImageProcessor(**vision_config)
        # Assigned before the parent constructor runs, as in the original.
        self.image_processor = image_processor
        video_processor = Qwen2VLVideoProcessor(**vision_config)

        super().__init__(
            image_processor=image_processor,
            tokenizer=tokenizer,
            video_processor=video_processor,
            chat_template=None,
            **kwargs,
        )
1555
1556
1557
1558
1559


class Tarsier2ProcessingInfo(Qwen2VLProcessingInfo):
    """Processing info for Tarsier2, built on the Qwen2-VL implementation."""

    def get_hf_config(self) -> Qwen2VLConfig:
        """Load a ``Qwen2VLConfig`` directly from the model path."""
        model_path = self.ctx.model_config.model
        correct_config = Qwen2VLConfig.from_pretrained(model_path)

        return correct_config

    def get_hf_processor(self, **kwargs: object) -> Tarsier2Processor:
        """Build a ``Tarsier2Processor`` from the image-processor config."""
        return Tarsier2Processor(
            vision_config=self.ctx.get_hf_image_processor_config(),
            tokenizer=self.get_tokenizer(),
            **kwargs,
        )

    def get_image_processor(self, **kwargs: object) -> Tarsier2ImageProcessor:
        """Build a ``Tarsier2ImageProcessor`` from the image-processor config.

        ``**kwargs`` is accepted to match the base-class signature: callers
        written against ``Qwen2VLProcessingInfo`` forward HF processor kwargs
        here, which previously raised ``TypeError`` on this override. The
        kwargs are not applied, because the processor is constructed solely
        from the image-processor config.
        """
        return Tarsier2ImageProcessor(**self.ctx.get_hf_image_processor_config())
1573
1574


1575
1576
1577
1578
1579
@MULTIMODAL_REGISTRY.register_processor(
    Tarsier2MultiModalProcessor,
    info=Tarsier2ProcessingInfo,
    dummy_inputs=Qwen2VLDummyInputsBuilder,
)
class Tarsier2ForConditionalGeneration(Qwen2VLForConditionalGeneration):
    """Tarsier2 model: Qwen2-VL with a Tarsier2-specific weight prefix map."""

    # Checkpoint weights under `vision_tower.` are loaded into `visual.`.
    hf_to_vllm_mapper = WeightsMapper(
        orig_to_new_prefix={
            "vision_tower.": "visual.",
        }
    )

    def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
        """Swap in the nested Qwen2-VL config, then build as Qwen2-VL.

        Note: this replaces ``vllm_config.model_config.hf_config`` in place.
        """
        # Tarsier2 uses llava as model_type, which will create a Qwen2VLConfig
        # as text_config, we need to reconstruct Qwen2VLConfig from LlavaConfig.
        model_config = vllm_config.model_config
        llava_config = model_config.hf_config
        qwen2vl_config = llava_config.text_config
        qwen2vl_config.architectures = llava_config.architectures
        model_config.hf_config = qwen2vl_config
        super().__init__(vllm_config=vllm_config, prefix=prefix)

    def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]:
        """Load checkpoint weights, skipping ``visual.`` when no tower exists."""
        skip_prefixes = ["visual."] if self.visual is None else []
        weights_loader = AutoWeightsLoader(self, skip_prefixes=skip_prefixes)
        return weights_loader.load_weights(weights, mapper=self.hf_to_vllm_mapper)