qwen2_vl.py 55.1 KB
Newer Older
1
# SPDX-License-Identifier: Apache-2.0
2
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
3

4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
# Adapted from
# https://github.com/huggingface/transformers/blob/19e6e80e10118f855137b90740936c0b11ac397f/src/transformers/models/qwen2_vl/modeling_qwen2_vl.py
# Copyright 2024 The Qwen team.
# Copyright 2023 The vLLM team.
# Copyright 2022 EleutherAI and the HuggingFace Inc. team. All rights reserved.
#
# This code is based on EleutherAI's GPT-NeoX library and the GPT-NeoX
# and OPT implementations in this library. It has been modified from its
# original forms to accommodate minor architectural differences compared
# to GPT-NeoX and OPT used by the Meta AI team that trained the model.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Inference-only Qwen2-VL model compatible with HuggingFace weights."""
27

28
import math
29
from collections.abc import Callable, Iterable, Mapping, Sequence
30
from functools import partial
31
from typing import Annotated, Any, Literal, TypeAlias
32

33
import numpy as np
34
35
36
import torch
import torch.nn as nn
import torch.nn.functional as F
37
from einops import rearrange
38
from transformers import BatchFeature
39
from transformers.models.qwen2_vl import Qwen2VLImageProcessor, Qwen2VLProcessor
40
from transformers.models.qwen2_vl.configuration_qwen2_vl import (
41
42
43
    Qwen2VLConfig,
    Qwen2VLVisionConfig,
)
44
from transformers.models.qwen2_vl.image_processing_qwen2_vl import smart_resize
45
from transformers.models.qwen2_vl.video_processing_qwen2_vl import Qwen2VLVideoProcessor
46

47
from vllm.attention.backends.registry import AttentionBackendEnum
48
49
50
from vllm.attention.layer import (
    maybe_get_vit_flash_attn_backend,
)
51
from vllm.config import VllmConfig
52
from vllm.config.multimodal import BaseDummyOptions
53
from vllm.distributed import parallel_state
54
55
56
from vllm.distributed import utils as dist_utils
from vllm.logger import init_logger
from vllm.model_executor.layers.activation import QuickGELU
57
from vllm.model_executor.layers.conv import Conv3dLayer
58
59
60
61
from vllm.model_executor.layers.linear import (
    ColumnParallelLinear,
    RowParallelLinear,
)
62
from vllm.model_executor.layers.quantization import QuantizationConfig
63
from vllm.model_executor.layers.rotary_embedding import get_rope
64
from vllm.model_executor.layers.rotary_embedding.common import (
65
    apply_rotary_emb_torch,
66
67
    dispatch_rotary_emb_function,
)
68
from vllm.model_executor.model_loader.weight_utils import default_weight_loader
69
from vllm.model_executor.models.module_mapping import MultiModelKeys
70
from vllm.multimodal import MULTIMODAL_REGISTRY
71
72
73
74
from vllm.multimodal.inputs import (
    ImageItem,
    ModalityData,
    MultiModalDataDict,
75
    MultiModalFeatureSpec,
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
    MultiModalFieldConfig,
    MultiModalKwargsItems,
    VideoItem,
)
from vllm.multimodal.parse import (
    DictEmbeddingItems,
    ImageSize,
    ModalityDataItems,
    MultiModalDataItems,
    MultiModalDataParser,
)
from vllm.multimodal.processing import (
    BaseMultiModalProcessor,
    BaseProcessingInfo,
    PromptReplacement,
    PromptUpdate,
)
93
from vllm.multimodal.profiling import BaseDummyInputsBuilder
94
from vllm.sequence import IntermediateTensors
95
from vllm.tokenizers import TokenizerLike
96
from vllm.utils.tensor_schema import TensorSchema, TensorShape
97

98
99
100
101
102
103
104
105
106
107
108
109
110
from .interfaces import (
    MultiModalEmbeddings,
    SupportsLoRA,
    SupportsMRoPE,
    SupportsMultiModal,
    SupportsPP,
)
from .utils import (
    AutoWeightsLoader,
    WeightsMapper,
    init_vllm_registered_model,
    maybe_prefix,
)
111
112
113
114
from .vision import (
    get_vit_attn_backend,
    run_dp_sharded_mrope_vision_model,
)
115

116
117
# Module-level logger, named after this module.
logger = init_logger(__name__)

118
# For profile run
119
# NOTE(review): presumably the number of frames assumed per dummy video during
# the profile run; the consumer is not visible in this part of the file —
# confirm.
_MAX_FRAMES_PER_VIDEO = 14
120

121
122
123
# === Vision Inputs === #


124
class Qwen2VLImagePixelInputs(TensorSchema):
    """
    Schema for raw image pixel patches plus the per-image patch grid.

    Dimensions:
        - np: The total number of patches over each image over each prompt in
              the batch
        - ni: Number of images
        - cps: Number of channels * patch_size * patch_size

    Historical context:
        - pixel_values shape: (num_patches, num_channels * patch_size *
          patch_size)
        - image_grid_thw shape: (num_images, 3) in (grid_t, grid_h, grid_w)
          format
    """

    # NOTE(review): Qwen2VisionPatchEmbed.forward in this file views each row
    # as (-1, temporal_patch_size, patch_size, patch_size), so "cps"
    # presumably also includes a temporal_patch_size factor — confirm.

    # Discriminator tag that distinguishes this schema from the embedding one.
    type: Literal["pixel_values"]

    pixel_values: Annotated[
        torch.Tensor,
        TensorShape("np", "cps"),
    ]

    image_grid_thw: Annotated[
        torch.Tensor,
        TensorShape("ni", 3),
    ]


class Qwen2VLImageEmbeddingInputs(TensorSchema):
    """
    Schema for pre-computed image embeddings plus the per-image patch grid.

    Dimensions:
        - nf: Number of image features
        - hs: Hidden size
        - ni: Number of images

    Historical context:
        - image_embeds shape: (num_image_features, hidden_size)
        - num_image_features varies based on the number and resolution of the
          images.
        - hidden_size must match the hidden size of language model backbone.
        - image_grid_thw shape: (num_images, 3) in (grid_t, grid_h, grid_w)
          format
    """

    # Discriminator tag that distinguishes this schema from the pixel one.
    type: Literal["image_embeds"]

    image_embeds: Annotated[
        torch.Tensor,
        TensorShape("nf", "hs"),
    ]

    image_grid_thw: Annotated[
        torch.Tensor,
        TensorShape("ni", 3),
    ]
179
180


181
# An image input is either raw pixel patches or pre-computed embeddings.
Qwen2VLImageInputs: TypeAlias = Qwen2VLImagePixelInputs | Qwen2VLImageEmbeddingInputs
182
183


184
185
186
187
188
class Qwen2VLVideoPixelInputs(TensorSchema):
    """
    Schema for raw video pixel patches plus the per-video patch grid.

    Dimensions:
        - np: The total number of patches over each video over each prompt in
              the batch
        - ctps: Number of channels * temporal_patch_size * patch_size *
          patch_size
        - nv: Number of videos

    Historical context:
        - pixel_values_videos shape: (num_patches, num_channels *
          temporal_patch_size * patch_size * patch_size)
        - video_grid_thw shape: (num_videos, 3) in (grid_t, grid_h, grid_w)
          format
    """

    # Discriminator tag that distinguishes this schema from the embedding one.
    type: Literal["pixel_values_videos"]

    pixel_values_videos: Annotated[
        torch.Tensor,
        TensorShape("np", "ctps"),
    ]

    video_grid_thw: Annotated[
        torch.Tensor,
        TensorShape("nv", 3),
    ]
211
212


213
214
215
216
217
218
class Qwen2VLVideoEmbeddingInputs(TensorSchema):
    """
    Schema for pre-computed video embeddings plus the per-video patch grid.

    Dimensions:
        - nf: Number of video features
        - hs: Hidden size
        - nv: Number of videos

    Historical context:
        - video_embeds shape: (num_video_features, hidden_size)
        - num_video_features varies based on the number and resolution of the
          videos.
        - hidden_size must match the hidden size of language model backbone.
        - video_grid_thw shape: (num_videos, 3) in (grid_t, grid_h, grid_w)
          format
    """

    # Discriminator tag that distinguishes this schema from the pixel one.
    type: Literal["video_embeds"]

    video_embeds: Annotated[
        torch.Tensor,
        TensorShape("nf", "hs"),
    ]

    video_grid_thw: Annotated[
        torch.Tensor,
        TensorShape("nv", 3),
    ]
240
241


242
# A video input is either raw pixel patches or pre-computed embeddings.
Qwen2VLVideoInputs: TypeAlias = Qwen2VLVideoPixelInputs | Qwen2VLVideoEmbeddingInputs
243

244
245
246
247
248
249
250
# === Vision Encoder === #


class Qwen2VisionMLP(nn.Module):
    """Two-layer feed-forward block: fc1 -> activation -> fc2."""

    def __init__(
        self,
        in_features: int,
        hidden_features: int,
        act_layer: type[nn.Module] = QuickGELU,
        quant_config: QuantizationConfig | None = None,
        prefix: str = "",
        use_data_parallel: bool = False,
    ) -> None:
        """
        Args:
            in_features: Input width of fc1 and output width of fc2.
            hidden_features: Output width of fc1 and input width of fc2.
            act_layer: Activation module class, instantiated with no
                arguments.
            quant_config: Forwarded unchanged to both linear layers.
            prefix: Parameter-name prefix; the layers are registered as
                `{prefix}.fc1` and `{prefix}.fc2`.
            use_data_parallel: Passed to both linear layers as `disable_tp`.
        """
        super().__init__()
        self.fc1 = ColumnParallelLinear(
            in_features,
            hidden_features,
            quant_config=quant_config,
            prefix=f"{prefix}.fc1",
            disable_tp=use_data_parallel,
        )
        self.act = act_layer()
        self.fc2 = RowParallelLinear(
            hidden_features,
            in_features,
            quant_config=quant_config,
            prefix=f"{prefix}.fc2",
            disable_tp=use_data_parallel,
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Apply fc1, the activation, then fc2, and return the result."""
        # Both linear layers return a 2-tuple; only the first element is used.
        x_parallel, _ = self.fc1(x)
        x_parallel = self.act(x_parallel)
        x, _ = self.fc2(x_parallel)
        return x


281
282
def apply_rotary_pos_emb_vision(
    t: torch.Tensor, cos: torch.Tensor, sin: torch.Tensor
) -> torch.Tensor:
    """Apply rotary position embedding to `t` and return it in `t`'s dtype.

    The rotary kernel is selected by `dispatch_rotary_emb_function`, with the
    neox-style torch implementation supplied as the default.

    Args:
        t: Tensor to rotate.
        cos: Cosine table passed to the rotary function.
        sin: Sine table passed to the rotary function.
    """
    rotary_emb_function = dispatch_rotary_emb_function(
        default=partial(apply_rotary_emb_torch, is_neox_style=True)
    )
    # Cast back so the result always matches the input's dtype.
    return rotary_emb_function(t, cos, sin).type_as(t)


class Qwen2VisionAttention(nn.Module):
    """Multi-head self-attention over variable-length vision sequences.

    Uses a fused qkv projection, applies rotary embedding to q and k, and runs
    attention through either a flash-attention varlen kernel or per-sequence
    torch SDPA, depending on the detected backend.
    """

    def __init__(
        self,
        embed_dim: int,
        num_heads: int,
        projection_size: int,
        quant_config: QuantizationConfig | None = None,
        prefix: str = "",
        use_data_parallel: bool = False,
        attn_backend_override: AttentionBackendEnum | None = None,
    ) -> None:
        """
        Args:
            embed_dim: Input width of qkv and output width of proj.
            num_heads: Total number of attention heads.
            projection_size: Width of each of q, k and v before sharding.
            quant_config: Forwarded unchanged to both linear layers.
            prefix: Parameter-name prefix for `qkv` and `proj`.
            use_data_parallel: When true, tensor parallelism is disabled for
                the linear layers and the head count is not partitioned.
            attn_backend_override: Forwarded to the backend selection helpers.

        Raises:
            RuntimeError: If the selected backend is not FLASH_ATTN,
                TORCH_SDPA or ROCM_AITER_FA.
        """
        super().__init__()
        # Per attention head and per partition values.
        self.tp_size = (
            1
            if use_data_parallel
            else parallel_state.get_tensor_model_parallel_world_size()
        )
        self.tp_rank = parallel_state.get_tensor_model_parallel_rank()
        self.hidden_size_per_attention_head = dist_utils.divide(
            projection_size, num_heads
        )
        self.num_attention_heads_per_partition = dist_utils.divide(
            num_heads, self.tp_size
        )

        self.qkv = ColumnParallelLinear(
            input_size=embed_dim,
            output_size=3 * projection_size,
            quant_config=quant_config,
            prefix=f"{prefix}.qkv",
            disable_tp=use_data_parallel,
        )
        self.proj = RowParallelLinear(
            input_size=projection_size,
            output_size=embed_dim,
            quant_config=quant_config,
            prefix=f"{prefix}.proj",
            disable_tp=use_data_parallel,
        )

        # Detect attention implementation.
        self.attn_backend = get_vit_attn_backend(
            head_size=self.hidden_size_per_attention_head,
            dtype=torch.get_default_dtype(),
            attn_backend_override=attn_backend_override,
        )

        # The helper may replace the backend and also returns the varlen
        # kernel that forward() calls on the flash-attention path.
        self.attn_backend, self.flash_attn_varlen_func = (
            maybe_get_vit_flash_attn_backend(
                self.attn_backend,
                attn_backend_override=attn_backend_override,
            )
        )

        if self.attn_backend not in {
            AttentionBackendEnum.FLASH_ATTN,
            AttentionBackendEnum.TORCH_SDPA,
            AttentionBackendEnum.ROCM_AITER_FA,
        }:
            raise RuntimeError(
                f"Qwen2-VL does not support {self.attn_backend} backend now."
            )

        self.is_flash_attn_backend = self.attn_backend in {
            AttentionBackendEnum.FLASH_ATTN,
            AttentionBackendEnum.ROCM_AITER_FA,
        }

    def split_qkv(self, qkv: torch.Tensor) -> tuple[torch.Tensor, ...]:
        """Split a fused qkv tensor into per-head q, k and v views.

        Args:
            qkv: Tensor laid out as [s, b, 3 * head * head_dim].

        Returns:
            `(q, k, v)`, each shaped [s, b, head, head_dim], where `head` is
            the per-partition head count.
        """
        # NOTE(review): with tp_size > 1 the fused qkv output is presumably
        # column-sharded, and chunking the local shard into thirds relies on
        # the shard already being laid out as [q | k | v] for this rank —
        # confirm against the weight loader.
        # [s, b, 3 * head * head_dim]
        seq_len, bs, _ = qkv.shape

        # [s, b, 3 * head * head_dim] -> 3 * [s, b, head * head_dim]
        q, k, v = qkv.chunk(3, dim=2)

        # 3 * [s, b, head * head_dim] -> 3 * [s, b, head, head_dim]
        new_shape = (
            seq_len,
            bs,
            self.num_attention_heads_per_partition,
            self.hidden_size_per_attention_head,
        )
        q, k, v = (t.view(*new_shape) for t in (q, k, v))
        return q, k, v

    def forward(
        self,
        x: torch.Tensor,
        cu_seqlens: torch.Tensor,
        rotary_pos_emb_cos: torch.Tensor,
        rotary_pos_emb_sin: torch.Tensor,
        max_seqlen: int | None = None,  # Only used for Flash Attention
    ) -> torch.Tensor:
        """Run self-attention over packed variable-length sequences.

        Args:
            x: Input laid out as [s, b, c].
            cu_seqlens: Cumulative sequence boundaries along the sequence
                axis; consecutive differences give each sequence's length.
            rotary_pos_emb_cos: Cosine table for the rotary embedding.
            rotary_pos_emb_sin: Sine table for the rotary embedding.
            max_seqlen: Longest sequence length; only passed to the
                flash-attention kernel.

        Returns:
            Output of the `proj` layer applied to the [s, b, (h d)] context.

        Raises:
            RuntimeError: If the backend is neither a flash-attention variant
                nor TORCH_SDPA.
        """
        # [s, b, c] --> [s, b, 3 * head * head_dim]
        x, _ = self.qkv(x)

        # [s, b, 3 * head * head_dim] -> 3 * [s, b, head, head_dim]
        q, k, v = self.split_qkv(x)
        batch_size = q.shape[1]

        q, k, v = (rearrange(t, "s b ... -> b s ...") for t in (q, k, v))

        # Rotate q and k in one call by stacking them along the batch axis.
        # [2 * b, s, heads, head_dim]
        qk_concat = torch.cat([q, k], dim=0)
        qk_rotated = apply_rotary_pos_emb_vision(
            qk_concat, rotary_pos_emb_cos, rotary_pos_emb_sin
        )
        q, k = torch.chunk(qk_rotated, 2, dim=0)

        if self.is_flash_attn_backend:
            q, k, v = (rearrange(t, "b s ... -> (b s) ...") for t in [q, k, v])

            output = self.flash_attn_varlen_func(
                q,
                k,
                v,
                cu_seqlens_q=cu_seqlens,
                cu_seqlens_k=cu_seqlens,
                max_seqlen_q=max_seqlen,
                max_seqlen_k=max_seqlen,
                dropout_p=0.0,
                causal=False,
            )

            context_layer = rearrange(
                output, "(b s) h d -> s b (h d)", b=batch_size
            ).contiguous()
        elif self.attn_backend == AttentionBackendEnum.TORCH_SDPA:
            # Execute attention entry by entry for speed & less VRAM.
            from vllm.platforms import current_platform

            if current_platform.is_rocm():
                q = q.contiguous()
                k = k.contiguous()
                v = v.contiguous()
            outputs = []

            # Per-sequence lengths recovered from the cumulative boundaries.
            lens = (cu_seqlens[1:] - cu_seqlens[:-1]).tolist()
            q_chunks = torch.split(q, lens, dim=1)
            k_chunks = torch.split(k, lens, dim=1)
            v_chunks = torch.split(v, lens, dim=1)
            for q_i, k_i, v_i in zip(q_chunks, k_chunks, v_chunks):
                q_i, k_i, v_i = (
                    rearrange(t, "b s h d -> b h s d") for t in [q_i, k_i, v_i]
                )
                output_i = F.scaled_dot_product_attention(q_i, k_i, v_i, dropout_p=0.0)
                output_i = rearrange(output_i, "b h s d -> b s h d ")
                outputs.append(output_i)
            context_layer = torch.cat(outputs, dim=1)
            context_layer = rearrange(
                context_layer, "b s h d -> s b (h d)"
            ).contiguous()
        else:
            # __init__ already rejects other backends; fail loudly rather than
            # with an unbound-variable error if the attribute is changed later.
            raise RuntimeError(
                f"Qwen2-VL does not support {self.attn_backend} backend now."
            )

        output, _ = self.proj(context_layer)
        return output


class Qwen2VisionBlock(nn.Module):
    """Pre-norm transformer block: attention and MLP, each with a residual."""

    def __init__(
        self,
        dim: int,
        num_heads: int,
        mlp_ratio: float,
        act_layer: type[nn.Module] = QuickGELU,
        norm_layer: Callable[[int], nn.Module] | None = None,
        quant_config: QuantizationConfig | None = None,
        prefix: str = "",
        use_data_parallel: bool = False,
        attn_backend_override: AttentionBackendEnum | None = None,
    ) -> None:
        """
        Args:
            dim: Embedding width; also used as the attention projection size.
            num_heads: Number of attention heads.
            mlp_ratio: MLP hidden width is `int(dim * mlp_ratio)`.
            act_layer: Activation module class for the MLP.
            norm_layer: Factory taking the width and returning a norm module;
                defaults to `nn.LayerNorm` with `eps=1e-6`.
            quant_config: Forwarded to the attention and MLP sub-modules.
            prefix: Parameter-name prefix; sub-modules use `{prefix}.attn`
                and `{prefix}.mlp`.
            use_data_parallel: Forwarded to the attention and MLP sub-modules.
            attn_backend_override: Forwarded to the attention sub-module.
        """
        super().__init__()
        if norm_layer is None:
            norm_layer = partial(nn.LayerNorm, eps=1e-6)
        self.norm1 = norm_layer(dim)
        self.norm2 = norm_layer(dim)
        mlp_hidden_dim = int(dim * mlp_ratio)

        self.attn = Qwen2VisionAttention(
            embed_dim=dim,
            num_heads=num_heads,
            projection_size=dim,
            quant_config=quant_config,
            prefix=f"{prefix}.attn",
            use_data_parallel=use_data_parallel,
            attn_backend_override=attn_backend_override,
        )
        self.mlp = Qwen2VisionMLP(
            dim,
            mlp_hidden_dim,
            act_layer=act_layer,
            quant_config=quant_config,
            prefix=f"{prefix}.mlp",
            use_data_parallel=use_data_parallel,
        )

    def forward(
        self,
        x: torch.Tensor,
        cu_seqlens: torch.Tensor,
        rotary_pos_emb_cos: torch.Tensor,
        rotary_pos_emb_sin: torch.Tensor,
        max_seqlen: int | None = None,  # Only used for Flash Attention
    ) -> torch.Tensor:
        """Return `x` after the attention and MLP residual updates.

        All arguments other than `x` are passed straight to the attention
        sub-module.
        """
        x = x + self.attn(
            self.norm1(x),
            cu_seqlens=cu_seqlens,
            rotary_pos_emb_cos=rotary_pos_emb_cos,
            rotary_pos_emb_sin=rotary_pos_emb_sin,
            max_seqlen=max_seqlen,
        )

        x = x + self.mlp(self.norm2(x))
        return x


class Qwen2VisionPatchEmbed(nn.Module):
    """Project flattened pixel patches to embeddings with a strided 3-D conv."""

    def __init__(
        self,
        patch_size: int = 14,
        temporal_patch_size: int = 2,
        in_channels: int = 3,
        embed_dim: int = 1152,
    ) -> None:
        """
        Args:
            patch_size: Spatial edge length of a patch.
            temporal_patch_size: Temporal extent of a patch.
            in_channels: Number of input channels for the convolution.
            embed_dim: Number of output channels, i.e. the embedding width.
        """
        super().__init__()
        self.patch_size = patch_size
        self.temporal_patch_size = temporal_patch_size
        self.embed_dim = embed_dim

        # Kernel and stride are equal, so each patch is convolved exactly once.
        kernel_size = (temporal_patch_size, patch_size, patch_size)
        self.proj = Conv3dLayer(
            in_channels,
            embed_dim,
            kernel_size=kernel_size,
            stride=kernel_size,
            bias=False,
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Embed a 2-D batch of flattened patches.

        Args:
            x: Tensor of shape (num_patches, flattened_patch_size); anything
                other than a 2-D tensor fails at the shape unpack.

        Returns:
            Tensor of shape (num_patches, embed_dim).
        """
        # The two-name unpack enforces a 2-D input; the width is inferred by
        # the -1 in the view below.
        num_patches, _ = x.shape
        x = x.view(
            num_patches,
            -1,
            self.temporal_patch_size,
            self.patch_size,
            self.patch_size,
        )
        x = self.proj(x).view(num_patches, self.embed_dim)
        return x


class Qwen2VisionPatchMerger(nn.Module):
    """Merge groups of patch features and project them to `d_model`.

    Input features are layer-normalised, regrouped so that
    `spatial_merge_size**2` consecutive features form one row, and passed
    through a linear -> GELU -> linear stack.
    """

    def __init__(
        self,
        d_model: int,
        context_dim: int,
        norm_layer: Callable[[int], nn.Module] | None = None,
        spatial_merge_size: int = 2,
        quant_config: QuantizationConfig | None = None,
        prefix: str = "",
        use_data_parallel: bool = False,
    ) -> None:
        """
        Args:
            d_model: Output width of the final linear layer.
            context_dim: Width of each incoming patch feature.
            norm_layer: Factory taking the width and returning a norm module;
                defaults to `nn.LayerNorm` with `eps=1e-6`.
            spatial_merge_size: Merge factor per spatial axis.
            quant_config: Forwarded to both linear layers.
            prefix: Parameter-name prefix; layers use `{prefix}.mlp.0` and
                `{prefix}.mlp.2`.
            use_data_parallel: Passed to both linear layers as `disable_tp`.
        """
        super().__init__()
        # Width after concatenating spatial_merge_size**2 patch features.
        self.hidden_size = context_dim * (spatial_merge_size**2)
        if norm_layer is None:
            norm_layer = partial(nn.LayerNorm, eps=1e-6)
        self.ln_q = norm_layer(context_dim)
        self.mlp = nn.ModuleList(
            [
                ColumnParallelLinear(
                    self.hidden_size,
                    self.hidden_size,
                    bias=True,
                    quant_config=quant_config,
                    prefix=f"{prefix}.mlp.0",
                    disable_tp=use_data_parallel,
                ),
                nn.GELU(),
                RowParallelLinear(
                    self.hidden_size,
                    d_model,
                    bias=True,
                    quant_config=quant_config,
                    prefix=f"{prefix}.mlp.2",
                    disable_tp=use_data_parallel,
                ),
            ]
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Normalise, regroup to rows of `hidden_size`, and run the MLP."""
        x = self.ln_q(x)
        x = x.view(-1, self.hidden_size)

        # The linear layers return a 2-tuple; only the first element is used.
        mlp_fc1, mlp_act, mlp_fc2 = self.mlp
        x_parallel, _ = mlp_fc1(x)
        x_parallel = mlp_act(x_parallel)
        out, _ = mlp_fc2(x_parallel)
        return out


class Qwen2VisionTransformer(nn.Module):
    """Vision encoder: patch embedding, transformer blocks, patch merger."""

    def __init__(
        self,
        vision_config: Qwen2VLVisionConfig,
        norm_eps: float = 1e-6,
        quant_config: QuantizationConfig | None = None,
        prefix: str = "",
        use_data_parallel: bool = False,
        attn_backend_override: AttentionBackendEnum | None = None,
    ) -> None:
        """
        Args:
            vision_config: Source of patch sizes, widths, depth, head count
                and MLP ratio.
            norm_eps: Epsilon for every `nn.LayerNorm` created here.
            quant_config: Forwarded to the blocks and the merger.
            prefix: Parameter-name prefix; sub-modules use
                `{prefix}.blocks.{i}` and `{prefix}.merger`.
            use_data_parallel: Forwarded to the blocks and the merger.
            attn_backend_override: Forwarded to backend selection and to the
                blocks.
        """
        super().__init__()

        patch_size = vision_config.patch_size
        temporal_patch_size = vision_config.temporal_patch_size
        spatial_merge_size = vision_config.spatial_merge_size
        in_channels = vision_config.in_channels
        hidden_size = vision_config.hidden_size
        embed_dim = vision_config.embed_dim
        depth = vision_config.depth
        num_heads = vision_config.num_heads
        mlp_ratio = vision_config.mlp_ratio

        self.use_data_parallel = use_data_parallel
        self.out_hidden_size = vision_config.hidden_size

        self.spatial_merge_size = spatial_merge_size
        self.num_heads = num_heads
        self.embed_dim = embed_dim

        self.patch_embed = Qwen2VisionPatchEmbed(
            patch_size=patch_size,
            temporal_patch_size=temporal_patch_size,
            in_channels=in_channels,
            embed_dim=embed_dim,
        )

        norm_layer = partial(nn.LayerNorm, eps=norm_eps)
        head_dim = embed_dim // num_heads
        self.rotary_pos_emb = get_rope(
            head_size=head_dim,
            max_position=8192,
            is_neox_style=True,
            rope_parameters={"partial_rotary_factor": 0.5},
        )

        self.blocks = nn.ModuleList(
            [
                Qwen2VisionBlock(
                    dim=embed_dim,
                    num_heads=num_heads,
                    mlp_ratio=mlp_ratio,
                    norm_layer=norm_layer,
                    quant_config=quant_config,
                    prefix=f"{prefix}.blocks.{layer_idx}",
                    use_data_parallel=use_data_parallel,
                    attn_backend_override=attn_backend_override,
                )
                for layer_idx in range(depth)
            ]
        )
        self.merger = Qwen2VisionPatchMerger(
            d_model=hidden_size,
            context_dim=embed_dim,
            norm_layer=norm_layer,
            quant_config=quant_config,
            prefix=f"{prefix}.merger",
            use_data_parallel=use_data_parallel,
        )
        # NOTE(review): the attention layers additionally pass their backend
        # through maybe_get_vit_flash_attn_backend, which may return a
        # different backend than the one stored here; if they can diverge,
        # compute_attn_mask_seqlen could skip max_seqlen for a flash-attention
        # layer — confirm.
        self.attn_backend = get_vit_attn_backend(
            head_size=head_dim,
            dtype=torch.get_default_dtype(),
            attn_backend_override=attn_backend_override,
        )

    @property
    def dtype(self) -> torch.dtype:
        """Dtype of the patch-embedding projection weight."""
        return self.patch_embed.proj.weight.dtype

    @property
    def device(self) -> torch.device:
        """Device of the patch-embedding projection weight."""
        return self.patch_embed.proj.weight.device

    def rot_pos_emb(
        self, grid_thw: list[list[int]]
    ) -> tuple[torch.Tensor, torch.Tensor]:
        """Build rotary cos/sin tables for every patch in `grid_thw`.

        Args:
            grid_thw: One `(t, h, w)` triple per item.

        Returns:
            `(cos, sin)`, each indexed by the (row, column) position of every
            patch and flattened from dimension 1 onward.
        """
        merge = self.spatial_merge_size

        def merge_order(ids: torch.Tensor, h: int, w: int) -> torch.Tensor:
            # Reorder an (h, w) index grid so the patches of each
            # merge x merge window become consecutive.
            return (
                ids.reshape(h // merge, merge, w // merge, merge)
                .permute(0, 2, 1, 3)
                .flatten()
            )

        pos_ids = []
        max_grid_size = 0
        for t, h, w in grid_thw:
            hpos_ids = torch.arange(h).unsqueeze(1).expand(-1, w)
            wpos_ids = torch.arange(w).unsqueeze(0).expand(h, -1)
            hpos_ids = merge_order(hpos_ids, h, w)
            wpos_ids = merge_order(wpos_ids, h, w)
            # The same spatial positions are reused for each of the t frames.
            pos_ids.append(torch.stack([hpos_ids, wpos_ids], dim=-1).repeat(t, 1))
            max_grid_size = max(max_grid_size, h, w)
        pos_ids = torch.cat(pos_ids, dim=0)

        # Use pre-computed cos_sin_cache from RotaryEmbedding
        cos, sin = self.rotary_pos_emb.get_cos_sin(max_grid_size)

        cos_combined = cos[pos_ids].flatten(1)
        sin_combined = sin[pos_ids].flatten(1)
        return cos_combined, sin_combined

    def compute_attn_mask_seqlen(self, cu_seqlens: torch.Tensor) -> int | None:
        """Return the longest sequence length for flash-attention backends.

        Returns `None` when the stored backend is neither FLASH_ATTN nor
        ROCM_AITER_FA.
        """
        max_seqlen = None
        if self.attn_backend in {
            AttentionBackendEnum.FLASH_ATTN,
            AttentionBackendEnum.ROCM_AITER_FA,
        }:
            max_seqlen = (cu_seqlens[1:] - cu_seqlens[:-1]).max().item()
        return max_seqlen

    def forward(
        self,
        x: torch.Tensor,
        grid_thw: torch.Tensor | list[list[int]],
    ) -> torch.Tensor:
        """Encode flattened pixel patches into merged vision features.

        Args:
            x: Flattened patches; moved to this module's device and dtype.
            grid_thw: `(t, h, w)` per item, as a tensor or a nested list.
                A tensor is converted with `.numpy()`.

        Returns:
            Output of the patch merger.
        """
        # patchify
        x = x.to(device=self.device, dtype=self.dtype)
        x = self.patch_embed(x)

        if isinstance(grid_thw, list):
            grid_thw_list = grid_thw
            grid_thw = np.array(grid_thw, dtype=np.int32)
        else:
            grid_thw_list = grid_thw.tolist()
            grid_thw = grid_thw.numpy()

        # compute position embedding
        rotary_pos_emb_cos, rotary_pos_emb_sin = self.rot_pos_emb(grid_thw_list)

        # compute cu_seqlens: every frame (h * w patches, repeated t times)
        # is its own attention sequence.
        cu_seqlens = np.repeat(grid_thw[:, 1] * grid_thw[:, 2], grid_thw[:, 0]).cumsum(
            axis=0, dtype=np.int32
        )
        cu_seqlens = np.concatenate([np.zeros(1, dtype=np.int32), cu_seqlens])
        cu_seqlens = torch.from_numpy(cu_seqlens)

        # transformers
        x = x.unsqueeze(1)

        # pre-compute seqlens for attn mask to reduce cuMemcpy operations
        max_seqlen = self.compute_attn_mask_seqlen(cu_seqlens)
        cu_seqlens = cu_seqlens.to(self.device, non_blocking=True)
        for blk in self.blocks:
            x = blk(
                x,
                cu_seqlens=cu_seqlens,
                rotary_pos_emb_cos=rotary_pos_emb_cos,
                rotary_pos_emb_sin=rotary_pos_emb_sin,
                max_seqlen=max_seqlen,
            )

        # adapter
        x = self.merger(x)

        return x

    def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]:
        """Load `(name, tensor)` pairs into this module's parameters.

        Returns:
            The set of parameter names that were loaded (after any renaming).

        Raises:
            KeyError: If a (possibly renamed) weight name has no matching
                parameter.
        """
        # NOTE(review): no sub-module visible in this class is named
        # "qkv_proj" (attention registers `qkv`), so a checkpoint name
        # containing q_proj/k_proj/v_proj would presumably raise KeyError
        # after renaming — confirm whether this mapping is still needed.
        stacked_params_mapping = [
            # (param_name, shard_name, shard_id)
            ("qkv_proj", "q_proj", "q"),
            ("qkv_proj", "k_proj", "k"),
            ("qkv_proj", "v_proj", "v"),
        ]
        params_dict = dict(self.named_parameters(remove_duplicate=False))
        loaded_params: set[str] = set()

        for name, loaded_weight in weights:
            for param_name, weight_name, shard_id in stacked_params_mapping:
                if weight_name not in name:
                    continue
                name = name.replace(weight_name, param_name)

                param = params_dict[name]
                weight_loader = param.weight_loader
                weight_loader(param, loaded_weight, shard_id)
                break
            else:
                # No stacked mapping matched: load the tensor as-is.
                param = params_dict[name]
                weight_loader = getattr(param, "weight_loader", default_weight_loader)
                weight_loader(param, loaded_weight)
            loaded_params.add(name)
        return loaded_params

788

789
def _create_qwen2vl_field_factory(
    spatial_merge_size: int,
) -> Callable[
    [Mapping[str, torch.Tensor]],
    Mapping[str, MultiModalFieldConfig],
]:
    """Return a function that builds the multimodal field configs.

    Args:
        spatial_merge_size: Merge factor per spatial axis; embedding sizes are
            the pixel grid sizes floor-divided by it twice.
    """

    def _qwen2vl_field_config(hf_inputs: Mapping[str, torch.Tensor]):
        # A missing grid falls back to an empty (0, 3) tensor, giving zero
        # items for that modality.
        image_grid_thw = hf_inputs.get("image_grid_thw", torch.empty((0, 3)))
        image_pixel_grid_sizes = image_grid_thw.prod(-1)
        image_embed_grid_sizes = (
            image_pixel_grid_sizes // spatial_merge_size // spatial_merge_size
        )

        video_grid_thw = hf_inputs.get("video_grid_thw", torch.empty((0, 3)))
        video_grid_sizes = video_grid_thw.prod(-1)
        video_embed_grid_sizes = (
            video_grid_sizes // spatial_merge_size // spatial_merge_size
        )

        return dict(
            pixel_values=MultiModalFieldConfig.flat_from_sizes(
                "image", image_pixel_grid_sizes
            ),
            image_embeds=MultiModalFieldConfig.flat_from_sizes(
                "image", image_embed_grid_sizes
            ),
            image_grid_thw=MultiModalFieldConfig.batched("image", keep_on_cpu=True),
            pixel_values_videos=MultiModalFieldConfig.flat_from_sizes(
                "video", video_grid_sizes
            ),
            video_embeds=MultiModalFieldConfig.flat_from_sizes(
                "video", video_embed_grid_sizes
            ),
            video_grid_thw=MultiModalFieldConfig.batched("video", keep_on_cpu=True),
        )

    return _qwen2vl_field_config
826

827

Roger Wang's avatar
Roger Wang committed
828
class Qwen2VLMultiModalDataParser(MultiModalDataParser):
    """Data parser that also accepts dicts of pre-computed embeddings."""

    def __init__(self, spatial_merge_size: int, *args, **kwargs):
        """Store `spatial_merge_size`; remaining arguments go to the base class."""
        self._spatial_merge_size = spatial_merge_size
        super().__init__(*args, **kwargs)

    def _parse_image_data(
        self,
        data: dict[str, torch.Tensor] | ModalityData[ImageItem],
    ) -> ModalityDataItems[Any, Any] | None:
        """Wrap dict input as embedding items; otherwise defer to the base class."""
        if isinstance(data, dict):
            return DictEmbeddingItems(
                data,
                modality="image",
                required_fields={"image_embeds", "image_grid_thw"},
                fields_factory=_create_qwen2vl_field_factory(self._spatial_merge_size),
            )

        return super()._parse_image_data(data)

    def _parse_video_data(
        self,
        data: dict[str, torch.Tensor] | ModalityData[VideoItem],
    ) -> ModalityDataItems[Any, Any] | None:
        """Wrap dict input as embedding items; otherwise defer to the base class."""
        if isinstance(data, dict):
            return DictEmbeddingItems(
                data,
                modality="video",
                required_fields={"video_embeds", "video_grid_thw"},
                fields_factory=_create_qwen2vl_field_factory(self._spatial_merge_size),
            )

        return super()._parse_video_data(data)


862
863
class Qwen2VLProcessingInfo(BaseProcessingInfo):
    def get_hf_config(self):
        """Return the HF config from the processing context, typed as `Qwen2VLConfig`."""
        hf_config = self.ctx.get_hf_config(Qwen2VLConfig)
        return hf_config

866
    def get_hf_processor(self, **kwargs: object) -> Qwen2VLProcessor:
        """Return the HF processor from the processing context.

        `use_fast` defaults to True when the caller does not pass it; all
        other keyword arguments are forwarded unchanged.
        """
        use_fast = kwargs.pop("use_fast", True)
        return self.ctx.get_hf_processor(
            Qwen2VLProcessor,
            use_fast=use_fast,
            **kwargs,
        )

873
874
    def get_image_processor(self, **kwargs: object) -> Qwen2VLImageProcessor:
        return self.get_hf_processor(**kwargs).image_processor
875

876
    def get_supported_mm_limits(self) -> Mapping[str, int | None]:
877
878
        return {"image": None, "video": None}

879
880
881
882
883
    def get_mm_max_tokens_per_item(
        self,
        seq_len: int,
        mm_counts: Mapping[str, int],
    ) -> Mapping[str, int]:
884
885
886
887
        max_image_tokens = self.get_max_image_tokens()
        max_video_tokens = self.get_max_video_tokens(seq_len, mm_counts)
        return {"image": max_image_tokens, "video": max_video_tokens}

888
889
890
891
892
893
894
    def _get_vision_info(
        self,
        *,
        image_width: int,
        image_height: int,
        num_frames: int = 1,
        do_resize: bool = True,
        image_processor: Qwen2VLImageProcessor | None,
    ) -> tuple[ImageSize, int]:
        """Compute the preprocessed size and vision-token count of an input.

        Args:
            image_width: Width of the input image / video frame.
            image_height: Height of the input image / video frame.
            num_frames: Number of frames (1 for a single image).
            do_resize: When True the size is passed through `smart_resize`;
                otherwise the input size is used as-is.
            image_processor: Processor supplying the pixel bounds. When None,
                `self.get_image_processor()` is used.

        Returns:
            `(preprocessed_size, num_vision_tokens)` where the token count is
            `grid_t * grid_h * grid_w // merge_size**2`.
        """
        if image_processor is None:
            image_processor = self.get_image_processor()

        vision_config = self.get_hf_config().vision_config
        patch_size = vision_config.patch_size
        merge_size = vision_config.spatial_merge_size
        temporal_patch_size = vision_config.temporal_patch_size

        if do_resize:
            # Some processor variants leave `min_pixels` / `max_pixels` unset
            # and only carry the bounds in `size`; fall back the same way
            # `get_image_size_with_most_features` does so that `smart_resize`
            # never receives None.
            min_pixels = image_processor.min_pixels
            if min_pixels is None:
                min_pixels = image_processor.size["shortest_edge"]
            max_pixels = image_processor.max_pixels
            if max_pixels is None:
                max_pixels = image_processor.size["longest_edge"]

            resized_height, resized_width = smart_resize(
                height=image_height,
                width=image_width,
                factor=patch_size * merge_size,
                min_pixels=min_pixels,
                max_pixels=max_pixels,
            )
            preprocessed_size = ImageSize(width=resized_width, height=resized_height)
        else:
            preprocessed_size = ImageSize(width=image_width, height=image_height)

        # NOTE: Frames are padded to be divisible by `temporal_patch_size`
        # https://github.com/huggingface/transformers/blob/v4.48.3/src/transformers/models/qwen2_vl/image_processing_qwen2_vl.py#L294
        # Round up with a ceiling division. The previous
        # `num_frames + num_frames % temporal_patch_size` only yields a
        # multiple when `temporal_patch_size == 2` (e.g. 5 frames with a
        # patch size of 4 gave 6, i.e. one temporal patch instead of two).
        grid_t = max(
            (num_frames + temporal_patch_size - 1) // temporal_patch_size,
            1,
        )
        grid_h = preprocessed_size.height // patch_size
        grid_w = preprocessed_size.width // patch_size

        num_patches = grid_t * grid_h * grid_w
        num_vision_tokens = num_patches // (merge_size**2)

        return preprocessed_size, num_vision_tokens

931
    def get_num_image_tokens(
932
933
934
935
        self,
        *,
        image_width: int,
        image_height: int,
936
        image_processor: Qwen2VLImageProcessor | None,
937
938
939
940
    ) -> int:
        _, num_image_tokens = self._get_vision_info(
            image_width=image_width,
            image_height=image_height,
941
            num_frames=1,
942
            image_processor=image_processor,
943
944
945
        )
        return num_image_tokens

946
    def get_num_video_tokens(
947
948
949
950
951
        self,
        *,
        image_width: int,
        image_height: int,
        num_frames: int,
952
        image_processor: Qwen2VLImageProcessor | None,
953
954
955
956
957
    ) -> int:
        _, num_video_tokens = self._get_vision_info(
            image_width=image_width,
            image_height=image_height,
            num_frames=num_frames,
958
            image_processor=image_processor,
959
960
961
        )
        return num_video_tokens

962
    def get_image_size_with_most_features(self) -> ImageSize:
        """Return an image size that maximizes the number of merged patches.

        NOTE: Simply processing a huge size with `_get_vision_info` might not
        give a size that maximizes the number of features, i.e., the number of
        (merged) patches. This is because the number of patches limits the
        allowed aspect ratios. For example, suppose the maximum number of
        patches is 1280. A square image cannot be broken down into 1280
        patches, so feeding a giant square image into `_get_vision_info` will
        not yield a size that maximizes the number of patches. Therefore, we
        directly factorize the maximum number of patches into height and
        width. The tricky part is to avoid extreme aspect ratios (>200 for
        qwen2-vl). If we can't find a suitable aspect ratio, we decrease the
        number of patches and retry. This is safe because the processor does
        not accept extreme aspect ratios, so there is no valid post-resize
        image with the number of patches that yields extreme aspect ratios.
        """
        vision_config = self.get_hf_config().vision_config
        patch_size = vision_config.patch_size
        merge_size = vision_config.spatial_merge_size

        image_processor = self.get_image_processor()
        max_pixels = image_processor.max_pixels or image_processor.size["longest_edge"]

        # Side length (in pixels) of one merged patch.
        unit = patch_size * merge_size
        max_seq_len = max_pixels // (unit * unit)

        def most_square_factors(n: int) -> tuple[int, int]:
            # Largest divisor of n not above sqrt(n), paired with its cofactor
            # (smaller factor first).
            small = next(
                (d for d in range(math.isqrt(n), 0, -1) if n % d == 0),
                1,
            )
            return small, n // small

        # Fallback used when the search range below is empty.
        rows, cols = 1, max_seq_len
        for candidate in range(max_seq_len, 0, -1):
            rows, cols = most_square_factors(candidate)
            if cols / rows > 200:
                # Aspect ratio too extreme; try one patch fewer.
                continue
            break

        return ImageSize(width=unit * cols, height=unit * rows)
999

1000
1001
    def get_max_image_tokens(self) -> int:
        target_width, target_height = self.get_image_size_with_most_features()
1002

1003
        return self.get_num_image_tokens(
1004
1005
            image_width=target_width,
            image_height=target_height,
1006
            image_processor=None,
1007
        )
1008

1009
    def _get_max_video_frames(self, max_tokens: int, start_num_frames: int = 1) -> int:
1010
        target_width, target_height = self.get_image_size_with_most_features()
1011

1012
        num_frames = start_num_frames
1013
1014
1015

        while True:
            next_num_frames = num_frames + 1
1016
            next_max_tokens = self.get_num_video_tokens(
1017
1018
1019
                image_width=target_width,
                image_height=target_height,
                num_frames=next_num_frames,
1020
                image_processor=None,
1021
            )
1022

1023
            if next_max_tokens > max_tokens:
1024
1025
1026
1027
1028
1029
                break

            num_frames = next_num_frames

        return num_frames

1030
1031
1032
1033
    def get_num_frames_with_most_features(
        self,
        seq_len: int,
        mm_counts: Mapping[str, int],
1034
        max_frames_per_video: int = _MAX_FRAMES_PER_VIDEO,
1035
1036
    ) -> int:
        max_videos = mm_counts.get("video", 0)
1037

1038
        max_total_frames = self._get_max_video_frames(seq_len)
1039
1040
1041
        max_frames_per_video = min(
            max_total_frames // max(max_videos, 1), max_frames_per_video
        )
1042

1043
        return max(max_frames_per_video, 1)
1044

1045
1046
1047
1048
1049
    def get_max_video_tokens(
        self,
        seq_len: int,
        mm_counts: Mapping[str, int],
    ) -> int:
1050
        target_width, target_height = self.get_image_size_with_most_features()
1051

1052
        return self.get_num_video_tokens(
1053
1054
            image_width=target_width,
            image_height=target_height,
1055
            num_frames=self.get_num_frames_with_most_features(seq_len, mm_counts),
1056
            image_processor=None,
1057
1058
        )

1059
1060

class Qwen2VLDummyInputsBuilder(BaseDummyInputsBuilder[Qwen2VLProcessingInfo]):
    """Builds dummy text and multimodal data for Qwen2-VL."""

    def get_dummy_text(self, mm_counts: Mapping[str, int]) -> str:
        """Return one image token per image followed by one video token per video."""
        hf_processor = self.info.get_hf_processor()
        image_token: str = hf_processor.image_token
        video_token: str = hf_processor.video_token

        image_part = image_token * mm_counts.get("image", 0)
        video_part = video_token * mm_counts.get("video", 0)
        return image_part + video_part

    def get_dummy_mm_data(
        self,
        seq_len: int,
        mm_counts: Mapping[str, int],
        mm_options: Mapping[str, BaseDummyOptions] | None = None,
    ) -> MultiModalDataDict:
        """Return dummy images and videos at the size with the most features.

        Args:
            seq_len: Used to pick the number of frames per dummy video.
            mm_counts: Number of items to create per modality (missing
                modalities count as 0).
            mm_options: Optional per-modality overrides handed to the dummy
                image / video helpers.
        """
        image_count = mm_counts.get("image", 0)
        video_count = mm_counts.get("video", 0)

        width, height = self.info.get_image_size_with_most_features()
        frame_count = self.info.get_num_frames_with_most_features(seq_len, mm_counts)

        # A missing or empty options mapping means "no overrides".
        options = mm_options or {}
        image_overrides = options.get("image")
        video_overrides = options.get("video")

        dummy_images = self._get_dummy_images(
            width=width,
            height=height,
            num_images=image_count,
            overrides=image_overrides,
        )
        dummy_videos = self._get_dummy_videos(
            width=width,
            height=height,
            num_frames=frame_count,
            num_videos=video_count,
            overrides=video_overrides,
        )
        return {"image": dummy_images, "video": dummy_videos}

1104

1105
class Qwen2VLMultiModalProcessor(BaseMultiModalProcessor[Qwen2VLProcessingInfo]):
    """Multimodal processor for Qwen2-VL image and video inputs."""

    def _get_data_parser(self) -> MultiModalDataParser:
        """Return a parser configured with the model's spatial merge size."""
        vision_config = self.info.get_hf_config().vision_config
        return Qwen2VLMultiModalDataParser(vision_config.spatial_merge_size)

    def _get_prompt_updates(
        self,
        mm_items: MultiModalDataItems,
        hf_processor_mm_kwargs: Mapping[str, Any],
        out_mm_kwargs: MultiModalKwargsItems,
    ) -> Sequence[PromptUpdate]:
        """Expand each image / video placeholder token to its full token count.

        The number of placeholder tokens for an item is the product of its
        `*_grid_thw` entries divided by `merge_size**2`.
        """
        hf_processor = self.info.get_hf_processor(**hf_processor_mm_kwargs)
        image_processor = self.info.get_image_processor(**hf_processor_mm_kwargs)
        vocab = self.info.get_tokenizer().get_vocab()

        placeholder = {
            "image": vocab[hf_processor.image_token],
            "video": vocab[hf_processor.video_token],
        }
        merge_length = image_processor.merge_size**2

        def expand_placeholder(item_idx: int, modality: str):
            grid_thw = out_mm_kwargs[modality][item_idx][f"{modality}_grid_thw"].data
            assert isinstance(grid_thw, torch.Tensor)

            token_count = int(grid_thw.prod()) // merge_length
            return [placeholder[modality]] * token_count

        updates = []
        for modality in ("image", "video"):
            updates.append(
                PromptReplacement(
                    modality=modality,
                    target=[placeholder[modality]],
                    replacement=partial(expand_placeholder, modality=modality),
                )
            )
        return updates

    def _get_mm_fields_config(
        self,
        hf_inputs: BatchFeature,
        hf_processor_mm_kwargs: Mapping[str, object],
    ) -> Mapping[str, MultiModalFieldConfig]:
        """Return the field config produced by the Qwen2-VL field factory."""
        merge_size = self.info.get_hf_config().vision_config.spatial_merge_size
        field_factory = _create_qwen2vl_field_factory(merge_size)
        return field_factory(hf_inputs)


@MULTIMODAL_REGISTRY.register_processor(
    Qwen2VLMultiModalProcessor,
    info=Qwen2VLProcessingInfo,
    dummy_inputs=Qwen2VLDummyInputsBuilder,
)
class Qwen2VLForConditionalGeneration(
    nn.Module, SupportsMultiModal, SupportsLoRA, SupportsPP, SupportsMRoPE
):
1164
    # To ensure correct weight loading and mapping.
1165
1166
1167
1168
1169
1170
1171
1172
    hf_to_vllm_mapper = WeightsMapper(
        orig_to_new_prefix={
            # mapping for new names in checkpoint saved after transformers v4.52
            "model.language_model.": "language_model.model.",
            "model.visual.": "visual.",
            # mapping for original checkpoint
            "lm_head.": "language_model.lm_head.",
            "model.": "language_model.model.",
1173
1174
        }
    )
1175

1176
1177
    supports_encoder_tp_data = True

1178
1179
1180
    def get_mrope_input_positions(
        self,
        input_tokens: list[int],
        mm_features: list[MultiModalFeatureSpec],
    ) -> tuple[torch.Tensor, int]:
        """Build the 3-row (t, h, w) M-RoPE position ids for a prompt.

        Text tokens get the same increasing position in all three rows. Each
        image / video span gets per-axis indices over its merged grid
        (`h` and `w` divided by the spatial merge size), offset so that it
        starts right after the preceding text.

        Args:
            input_tokens: Prompt token ids.
            mm_features: Multimodal features from which `image_grid_thw`,
                `video_grid_thw` and `second_per_grid_ts` are gathered.

        Returns:
            `(llm_positions, mrope_position_delta)` where `llm_positions` has
            three rows and `mrope_position_delta` is
            `llm_positions.max() + 1 - len(input_tokens)`. For an empty prompt
            an empty `(3, 0)` tensor and a delta of 0 are returned.
        """
        kwargs = MultiModalFeatureSpec.gather_kwargs(
            mm_features,
            {"image_grid_thw", "video_grid_thw", "second_per_grid_ts"},
        )
        image_grid_thw = [item.tolist() for item in kwargs.get("image_grid_thw", [])]
        video_grid_thw = [item.tolist() for item in kwargs.get("video_grid_thw", [])]
        second_per_grid_ts = kwargs.get("second_per_grid_ts", [])

        num_tokens = len(input_tokens)
        if num_tokens == 0:
            # Nothing to position; `torch.cat` below rejects an empty list.
            return torch.zeros((3, 0), dtype=torch.long), 0

        hf_config = self.config
        image_token_id = hf_config.image_token_id
        video_token_id = hf_config.video_token_id
        vision_start_token_id = hf_config.vision_start_token_id
        spatial_merge_size = hf_config.vision_config.spatial_merge_size
        tokens_per_second = getattr(hf_config.vision_config, "tokens_per_second", 1.0)

        input_tokens_tensor = torch.tensor(input_tokens)
        vision_start_indices = torch.argwhere(
            input_tokens_tensor == vision_start_token_id
        ).squeeze(1)
        # A vision-start marker in the very last position has no following
        # token to classify; indexing `+ 1` past the end would raise.
        vision_start_indices = vision_start_indices[
            vision_start_indices < num_tokens - 1
        ]
        vision_tokens = input_tokens_tensor[vision_start_indices + 1]
        image_nums = int((vision_tokens == image_token_id).sum())
        video_nums = int((vision_tokens == video_token_id).sum())

        llm_pos_ids_list: list = []

        # Sentinel larger than any valid index: "no such placeholder ahead".
        not_found = num_tokens + 1

        def find_placeholder(token_id: int, start: int, remaining: int) -> int:
            # Index of the next `token_id` at or after `start`, or the
            # sentinel when none remain or none is found.
            if remaining <= 0:
                return not_found
            try:
                return input_tokens.index(token_id, start)
            except ValueError:
                return not_found

        def next_start_position():
            # Position that follows everything emitted so far.
            if llm_pos_ids_list:
                return llm_pos_ids_list[-1].max() + 1
            return 0

        st = 0
        remain_images, remain_videos = image_nums, video_nums
        image_index, video_index = 0, 0
        for _ in range(image_nums + video_nums):
            ed_image = find_placeholder(image_token_id, st, remain_images)
            ed_video = find_placeholder(video_token_id, st, remain_videos)

            if ed_image < ed_video:
                t, h, w = image_grid_thw[image_index]
                # Images get a zero time scale, so their temporal index is 0.
                video_second_per_grid_t = 0.0
                image_index += 1
                remain_images -= 1
                ed = ed_image
            else:
                t, h, w = video_grid_thw[video_index]
                video_second_per_grid_t = 1.0
                if second_per_grid_ts:
                    video_second_per_grid_t = second_per_grid_ts[video_index]
                video_index += 1
                remain_videos -= 1
                ed = ed_video

            llm_grid_t, llm_grid_h, llm_grid_w = (
                t,
                h // spatial_merge_size,
                w // spatial_merge_size,
            )

            # Text between the previous span and this placeholder.
            text_len = ed - st
            st_idx = next_start_position()
            llm_pos_ids_list.append(
                torch.arange(text_len).view(1, -1).expand(3, -1) + st_idx
            )

            t_index = (
                (
                    torch.arange(llm_grid_t)
                    .view(-1, 1)
                    .expand(-1, llm_grid_h * llm_grid_w)
                    * video_second_per_grid_t
                    * tokens_per_second
                )
                .long()
                .flatten()
            )
            h_index = (
                torch.arange(llm_grid_h)
                .view(1, -1, 1)
                .expand(llm_grid_t, -1, llm_grid_w)
                .flatten()
            )
            w_index = (
                torch.arange(llm_grid_w)
                .view(1, 1, -1)
                .expand(llm_grid_t, llm_grid_h, -1)
                .flatten()
            )
            llm_pos_ids_list.append(
                torch.stack([t_index, h_index, w_index]) + text_len + st_idx
            )
            # Skip past the placeholder tokens of this item.
            st = ed + llm_grid_t * llm_grid_h * llm_grid_w

        if st < num_tokens:
            # Trailing text after the last vision span.
            st_idx = next_start_position()
            text_len = num_tokens - st
            llm_pos_ids_list.append(
                torch.arange(text_len).view(1, -1).expand(3, -1) + st_idx
            )

        llm_positions = torch.cat(llm_pos_ids_list, dim=1).reshape(3, -1)
        mrope_position_delta = (llm_positions.max() + 1 - num_tokens).item()

        return llm_positions, mrope_position_delta

1294
    @classmethod
1295
    def get_placeholder_str(cls, modality: str, i: int) -> str | None:
1296
1297
1298
1299
1300
1301
1302
        if modality.startswith("image"):
            return "<|vision_start|><|image_pad|><|vision_end|>"
        if modality.startswith("video"):
            return "<|vision_start|><|video_pad|><|vision_end|>"

        raise ValueError("Only image or video modality is supported")

1303
    def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
        """Build the vision tower (when needed) and the language model.

        The vision tower is only created when the per-prompt limit of images
        or videos is non-zero; otherwise `self.visual` is None.
        """
        super().__init__()
        model_config = vllm_config.model_config
        config: Qwen2VLConfig = model_config.hf_config
        quant_config = vllm_config.quant_config
        multimodal_config = model_config.multimodal_config

        self.use_data_parallel = multimodal_config.mm_encoder_tp_mode == "data"
        self.config = config
        self.multimodal_config = multimodal_config

        needs_vision = any(
            multimodal_config.get_limit_per_prompt(modality)
            for modality in ("image", "video")
        )
        if not needs_vision:
            self.visual = None
        else:
            # `multimodal_config` was already dereferenced above, so it is
            # known not to be None at this point.
            self.visual = Qwen2VisionTransformer(
                config.vision_config,
                norm_eps=getattr(config, "rms_norm_eps", 1e-6),
                quant_config=quant_config,
                prefix=maybe_prefix(prefix, "visual"),
                use_data_parallel=self.use_data_parallel,
                attn_backend_override=multimodal_config.mm_encoder_attn_backend,
            )

        self.language_model = init_vllm_registered_model(
            vllm_config=vllm_config,
            prefix=maybe_prefix(prefix, "language_model"),
            architectures=["Qwen2ForCausalLM"],
        )

        # Re-export the language model's factory under the same name.
        self.make_empty_intermediate_tensors = (
            self.language_model.make_empty_intermediate_tensors
        )
1341
1342

    def _parse_and_validate_image_input(
1343
        self, **kwargs: object
1344
    ) -> Qwen2VLImageInputs | None:
1345
        pixel_values = kwargs.pop("pixel_values", None)
1346
        image_embeds = kwargs.pop("image_embeds", None)
1347
1348
        image_grid_thw = kwargs.pop("image_grid_thw", None)

1349
        if pixel_values is None and image_embeds is None:
1350
1351
            return None

1352
        if pixel_values is not None:
1353
1354
1355
1356
1357
            return Qwen2VLImagePixelInputs(
                type="pixel_values",
                pixel_values=pixel_values,
                image_grid_thw=image_grid_thw,
            )
1358
1359

        if image_embeds is not None:
1360
1361
1362
1363
1364
            return Qwen2VLImageEmbeddingInputs(
                type="image_embeds",
                image_embeds=image_embeds,
                image_grid_thw=image_grid_thw,
            )
1365
1366

    def _parse_and_validate_video_input(
1367
        self, **kwargs: object
1368
    ) -> Qwen2VLVideoInputs | None:
1369
        pixel_values_videos = kwargs.pop("pixel_values_videos", None)
1370
        video_embeds = kwargs.pop("video_embeds", None)
1371
1372
        video_grid_thw = kwargs.pop("video_grid_thw", None)

1373
        if pixel_values_videos is None and video_embeds is None:
1374
1375
            return None

1376
1377
1378
1379
1380
1381
1382
1383
        if pixel_values_videos is not None:
            return Qwen2VLVideoPixelInputs(
                type="pixel_values_videos",
                pixel_values_videos=pixel_values_videos,
                video_grid_thw=video_grid_thw,
            )

        if video_embeds is not None:
1384
1385
1386
1387
1388
            return Qwen2VLVideoEmbeddingInputs(
                type="video_embeds",
                video_embeds=video_embeds,
                video_grid_thw=video_grid_thw,
            )
1389

1390
    def _process_image_input(
1391
1392
        self, image_input: Qwen2VLImageInputs
    ) -> tuple[torch.Tensor, ...]:
1393
1394
1395
        grid_thw = image_input["image_grid_thw"]
        assert grid_thw.ndim == 2

1396
        if image_input["type"] == "image_embeds":
1397
            image_embeds = image_input["image_embeds"]
1398
        else:
1399
            pixel_values = image_input["pixel_values"]
1400
1401

            if self.use_data_parallel:
1402
                return run_dp_sharded_mrope_vision_model(
1403
                    self.visual, pixel_values, grid_thw.tolist(), rope_type="rope_3d"
1404
                )
1405
            else:
1406
                image_embeds = self.visual(pixel_values, grid_thw=grid_thw)
1407
1408
1409

        # Split concatenated embeddings for each image item.
        merge_size = self.visual.spatial_merge_size
1410
        sizes = (grid_thw.prod(-1) // merge_size // merge_size).tolist()
1411
        return image_embeds.split(sizes)
1412
1413

    def _process_video_input(
1414
1415
        self, video_input: Qwen2VLVideoInputs
    ) -> tuple[torch.Tensor, ...]:
1416
1417
        grid_thw = video_input["video_grid_thw"]
        assert grid_thw.ndim == 2
1418

1419
        if video_input["type"] == "video_embeds":
1420
            video_embeds = video_input["video_embeds"]
1421
        else:
1422
            pixel_values_videos = video_input["pixel_values_videos"]
1423
            if self.use_data_parallel:
1424
                return run_dp_sharded_mrope_vision_model(
1425
1426
1427
1428
                    self.visual,
                    pixel_values_videos,
                    grid_thw.tolist(),
                    rope_type="rope_3d",
1429
                )
1430
            else:
1431
                video_embeds = self.visual(pixel_values_videos, grid_thw=grid_thw)
1432

1433
1434
        # Split concatenated embeddings for each video item.
        merge_size = self.visual.spatial_merge_size
1435
        sizes = (grid_thw.prod(-1) // merge_size // merge_size).tolist()
1436
        return video_embeds.split(sizes)
1437
1438
1439
1440
1441
1442
1443

    def _parse_and_validate_multimodal_inputs(self, **kwargs: object) -> dict:
        modalities = {}

        # Preserve the order of modalities if there are multiple of them
        # from the order of kwargs.
        for input_key in kwargs:
1444
1445
1446
1447
1448
1449
1450
1451
1452
1453
            if (
                input_key in ("pixel_values", "image_embeds")
                and "images" not in modalities
            ):
                modalities["images"] = self._parse_and_validate_image_input(**kwargs)
            if (
                input_key in ("pixel_values_videos", "video_embeds")
                and "videos" not in modalities
            ):
                modalities["videos"] = self._parse_and_validate_video_input(**kwargs)
1454
1455

        return modalities
1456

1457
1458
1459
    def get_language_model(self) -> torch.nn.Module:
        return self.language_model

1460
    def embed_multimodal(self, **kwargs: object) -> MultiModalEmbeddings:
1461
1462
        modalities = self._parse_and_validate_multimodal_inputs(**kwargs)
        if not modalities:
1463
            return []
1464

1465
1466
1467
1468
1469
1470
1471
1472
1473
        # The result multimodal_embeddings is tuple of tensors, with each
        # tensor correspoending to a multimodal data item (image or video).
        multimodal_embeddings: tuple[torch.Tensor, ...] = ()

        # NOTE: It is important to iterate over the keys in this dictionary
        # to preserve the order of the modalities.
        for modality in modalities:
            if modality == "images":
                image_input = modalities["images"]
1474
1475
                image_embeddings = self._process_image_input(image_input)
                multimodal_embeddings += tuple(image_embeddings)
1476
1477
1478
            if modality == "videos":
                video_input = modalities["videos"]
                video_embeddings = self._process_video_input(video_input)
1479
                multimodal_embeddings += tuple(video_embeddings)
1480
1481
1482

        return multimodal_embeddings

1483
1484
1485
1486
    def forward(
        self,
        input_ids: torch.Tensor,
        positions: torch.Tensor,
1487
1488
        intermediate_tensors: IntermediateTensors | None = None,
        inputs_embeds: torch.Tensor | None = None,
1489
        **kwargs: object,
1490
    ) -> torch.Tensor | IntermediateTensors:
1491
1492
1493
1494
1495
1496
1497
1498
1499
        """Run forward pass for Qwen2-VL.

        Args:
            input_ids: Flattened (concatenated) input_ids corresponding to a
                batch.
            positions: Flattened (concatenated) position ids corresponding to a
                batch.
                **NOTE**: If mrope is enabled (default setting for Qwen2-VL
                opensource models), the shape will be `(3, seq_len)`,
1500
1501
1502
                otherwise it will be `(seq_len,)`.
            intermediate_tensors: Intermediate tensors from prior forward pass.
            inputs_embeds: Optional tensor of input embeddings.
1503
        """
1504

1505
        if intermediate_tensors is not None:
1506
            inputs_embeds = None
1507

1508
        hidden_states = self.language_model.model(
1509
1510
            input_ids=input_ids,
            positions=positions,
1511
            intermediate_tensors=intermediate_tensors,
1512
1513
1514
1515
            inputs_embeds=inputs_embeds,
        )
        return hidden_states

1516
1517
1518
    def compute_logits(
        self,
        hidden_states: torch.Tensor,
1519
    ) -> torch.Tensor | None:
1520
        return self.language_model.compute_logits(hidden_states)
1521

1522
    def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]:
        """Load weights through `AutoWeightsLoader` using `hf_to_vllm_mapper`.

        Weights under the `visual.` prefix are skipped when no vision tower
        was created.
        """
        skip_prefixes = ["visual."] if self.visual is None else []
        loader = AutoWeightsLoader(self, skip_prefixes=skip_prefixes)
        return loader.load_weights(weights, mapper=self.hf_to_vllm_mapper)
1528
1529
1530
1531
1532
1533
1534

    def get_mm_mapping(self) -> MultiModelKeys:
        """Return the module prefixes of the language model, connector and tower."""
        module_prefixes = {
            "language_model": "language_model",
            "connector": "visual.merger.",
            "tower_model": "visual.",
        }
        return MultiModelKeys.from_string_field(**module_prefixes)
1538
1539
1540
1541
1542
1543
1544
1545
1546


class Tarsier2MultiModalProcessor(Qwen2VLMultiModalProcessor):
    """Tarsier2 processor; adds nothing on top of the Qwen2-VL processor."""


class Tarsier2ImageProcessor(Qwen2VLImageProcessor):
    """Qwen2-VL image processor that also accepts a Tarsier2-style `size`."""

    def __init__(
        self,
        size: dict[str, int] | None = None,
        **kwargs,
    ) -> None:
        """Initialize the processor, remapping `size` keys when needed.

        A `size` dict holding both `min_pixels` and `max_pixels` is converted
        to the `shortest_edge` / `longest_edge` keys before being passed to
        the base class; any other value is passed through unchanged.
        """
        uses_pixel_keys = (
            size is not None and "min_pixels" in size and "max_pixels" in size
        )
        if uses_pixel_keys:
            # Remap if Tarsier2-specific format is provided
            size = {
                "shortest_edge": size["min_pixels"],
                "longest_edge": size["max_pixels"],
            }
        super().__init__(size=size, **kwargs)


class Tarsier2Processor(Qwen2VLProcessor):
    """Qwen2-VL processor whose sub-processors are built from one config dict."""

    def __init__(
        self,
        vision_config: dict,
        tokenizer: TokenizerLike,
        **kwargs,
    ):
        """Create image and video processors from `vision_config`.

        Args:
            vision_config: Keyword arguments used to construct both the
                `Tarsier2ImageProcessor` and the `Qwen2VLVideoProcessor`.
            tokenizer: Tokenizer handed to the base processor.
            **kwargs: Forwarded to the base processor. No chat template is
                set (`chat_template=None`).
        """
        image_processor = Tarsier2ImageProcessor(**vision_config)
        self.image_processor = image_processor
        video_processor = Qwen2VLVideoProcessor(**vision_config)

        super().__init__(
            image_processor=image_processor,
            tokenizer=tokenizer,
            video_processor=video_processor,
            chat_template=None,
            **kwargs,
        )
1576
1577
1578
1579
1580


class Tarsier2ProcessingInfo(Qwen2VLProcessingInfo):
    """Processing info for Tarsier2, built on the Qwen2-VL processing info."""

    def get_hf_config(self) -> Qwen2VLConfig:
        """Load a `Qwen2VLConfig` from the model path of the context."""
        model_path = self.ctx.model_config.model
        correct_config = Qwen2VLConfig.from_pretrained(model_path)

        return correct_config

    def get_hf_processor(self, **kwargs: object) -> Tarsier2Processor:
        """Build a `Tarsier2Processor` from the HF image-processor config."""
        return Tarsier2Processor(
            vision_config=self.ctx.get_hf_image_processor_config(),
            tokenizer=self.get_tokenizer(),
            **kwargs,
        )

    def get_image_processor(self, **kwargs: object) -> Tarsier2ImageProcessor:
        """Build a `Tarsier2ImageProcessor` from the HF image-processor config.

        Keyword arguments are accepted so the signature matches the base
        class, which callers such as `_get_prompt_updates` invoke with
        processor kwargs; previously any keyword argument raised a
        `TypeError`. Given kwargs override same-named entries of the config.
        """
        processor_config = dict(self.ctx.get_hf_image_processor_config())
        processor_config.update(kwargs)
        return Tarsier2ImageProcessor(**processor_config)
1594
1595


1596
1597
1598
1599
1600
@MULTIMODAL_REGISTRY.register_processor(
    Tarsier2MultiModalProcessor,
    info=Tarsier2ProcessingInfo,
    dummy_inputs=Qwen2VLDummyInputsBuilder,
)
class Tarsier2ForConditionalGeneration(Qwen2VLForConditionalGeneration):
    """Qwen2-VL model variant whose checkpoints name the vision tower `vision_tower.`."""

    # Checkpoint prefix -> module prefix used by this model.
    hf_to_vllm_mapper = WeightsMapper(
        orig_to_new_prefix={
            "vision_tower.": "visual.",
        }
    )

    def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]:
        """Load weights through `AutoWeightsLoader` using `hf_to_vllm_mapper`.

        Weights under the `visual.` prefix are skipped when no vision tower
        was created.
        """
        skip_prefixes = ["visual."] if self.visual is None else []
        loader = AutoWeightsLoader(self, skip_prefixes=skip_prefixes)
        return loader.load_weights(weights, mapper=self.hf_to_vllm_mapper)