qwen2_vl.py 56.8 KB
Newer Older
1
# SPDX-License-Identifier: Apache-2.0
2
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
3

4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
# Adapted from
# https://github.com/huggingface/transformers/blob/19e6e80e10118f855137b90740936c0b11ac397f/src/transformers/models/qwen2_vl/modeling_qwen2_vl.py
# Copyright 2024 The Qwen team.
# Copyright 2023 The vLLM team.
# Copyright 2022 EleutherAI and the HuggingFace Inc. team. All rights reserved.
#
# This code is based on EleutherAI's GPT-NeoX library and the GPT-NeoX
# and OPT implementations in this library. It has been modified from its
# original forms to accommodate minor architectural differences compared
# to GPT-NeoX and OPT used by the Meta AI team that trained the model.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Inference-only Qwen2-VL model compatible with HuggingFace weights."""
27

28
from collections.abc import Callable, Iterable, Mapping, Sequence
29
from functools import partial
30
from typing import Annotated, Any, Literal, TypeAlias
31
32
33
34
35

import torch
import torch.nn as nn
import torch.nn.functional as F
from einops import rearrange, repeat
36
from transformers import BatchFeature
37
from transformers.models.qwen2_vl import Qwen2VLImageProcessor, Qwen2VLProcessor
38
from transformers.models.qwen2_vl.configuration_qwen2_vl import (
39
40
41
    Qwen2VLConfig,
    Qwen2VLVisionConfig,
)
42
from transformers.models.qwen2_vl.image_processing_qwen2_vl import smart_resize
43
from transformers.models.qwen2_vl.video_processing_qwen2_vl import Qwen2VLVideoProcessor
44

45
from vllm.attention.backends.registry import AttentionBackendEnum
46
47
48
49
from vllm.attention.layer import (
    check_upstream_fa_availability,
    maybe_get_vit_flash_attn_backend,
)
50
from vllm.config import VllmConfig
51
from vllm.config.multimodal import BaseDummyOptions
52
from vllm.distributed import parallel_state
53
54
55
from vllm.distributed import utils as dist_utils
from vllm.logger import init_logger
from vllm.model_executor.layers.activation import QuickGELU
56
from vllm.model_executor.layers.conv import Conv3dLayer
57
58
59
60
from vllm.model_executor.layers.linear import (
    ColumnParallelLinear,
    RowParallelLinear,
)
61
from vllm.model_executor.layers.quantization import QuantizationConfig
62
from vllm.model_executor.layers.rotary_embedding.common import (
63
64
    dispatch_rotary_emb_function,
)
65
from vllm.model_executor.model_loader.weight_utils import default_weight_loader
66
from vllm.model_executor.models.module_mapping import MultiModelKeys
67
from vllm.multimodal import MULTIMODAL_REGISTRY
68
69
70
71
from vllm.multimodal.inputs import (
    ImageItem,
    ModalityData,
    MultiModalDataDict,
72
    MultiModalFeatureSpec,
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
    MultiModalFieldConfig,
    MultiModalKwargsItems,
    VideoItem,
)
from vllm.multimodal.parse import (
    DictEmbeddingItems,
    ImageSize,
    ModalityDataItems,
    MultiModalDataItems,
    MultiModalDataParser,
)
from vllm.multimodal.processing import (
    BaseMultiModalProcessor,
    BaseProcessingInfo,
    PromptReplacement,
    PromptUpdate,
)
90
from vllm.multimodal.profiling import BaseDummyInputsBuilder
91
from vllm.sequence import IntermediateTensors
92
from vllm.transformers_utils.tokenizer import AnyTokenizer
93
from vllm.utils.tensor_schema import TensorSchema, TensorShape
94

95
96
97
98
99
100
101
102
103
104
105
106
107
from .interfaces import (
    MultiModalEmbeddings,
    SupportsLoRA,
    SupportsMRoPE,
    SupportsMultiModal,
    SupportsPP,
)
from .utils import (
    AutoWeightsLoader,
    WeightsMapper,
    init_vllm_registered_model,
    maybe_prefix,
)
108
109
110
111
from .vision import (
    get_vit_attn_backend,
    run_dp_sharded_mrope_vision_model,
)
112

113
114
logger = init_logger(__name__)

# Per-video frame cap applied during the profile run.
_MAX_FRAMES_PER_VIDEO = 14
117

118
119
120
# === Vision Inputs === #


121
class Qwen2VLImagePixelInputs(TensorSchema):
    """
    Image inputs given as flattened pixel patches.

    Dimensions:
        - np: The total number of patches over each image over each prompt in
              the batch
        - ni: Number of images
        - cps: Number of channels * temporal_patch_size * patch_size *
               patch_size (images are patchified exactly like video frames;
               the vision patch embedding reshapes every row to
               (-1, temporal_patch_size, patch_size, patch_size))

    Historical context:
        - pixel_values shape: (num_patches, num_channels *
          temporal_patch_size * patch_size * patch_size)
        - image_grid_thw shape: (num_images, 3) in (grid_t, grid_h, grid_w)
          format
    """

    type: Literal["pixel_values"]

    pixel_values: Annotated[
        torch.Tensor,
        TensorShape("np", "cps"),
    ]

    image_grid_thw: Annotated[
        torch.Tensor,
        TensorShape("ni", 3),
    ]


class Qwen2VLImageEmbeddingInputs(TensorSchema):
    """
    Image inputs given as embeddings instead of pixels.

    Dimensions:
        - nf: Number of image features
        - hs: Hidden size
        - ni: Number of images

    Historical context:
        - image_embeds shape: (num_image_features, hidden_size); the feature
          count depends on how many images there are and on their
          resolution, and hidden_size must equal the language model
          backbone's hidden size.
        - image_grid_thw shape: (num_images, 3), rows laid out as
          (grid_t, grid_h, grid_w)
    """

    type: Literal["image_embeds"]

    image_embeds: Annotated[torch.Tensor, TensorShape("nf", "hs")]

    image_grid_thw: Annotated[torch.Tensor, TensorShape("ni", 3)]
176
177


178
# Images arrive either as raw pixel patches or as precomputed embeddings.
Qwen2VLImageInputs: TypeAlias = (
    Qwen2VLImagePixelInputs | Qwen2VLImageEmbeddingInputs
)
179
180


181
182
183
184
185
class Qwen2VLVideoPixelInputs(TensorSchema):
    """
    Video inputs given as flattened spatio-temporal pixel patches.

    Dimensions:
        - np: Total number of patches across every video of every prompt in
              the batch
        - ctps: Number of channels * temporal_patch_size * patch_size *
                patch_size
        - nv: Number of videos

    Historical context:
        - pixel_values_videos shape: (num_patches, num_channels *
          temporal_patch_size * patch_size * patch_size)
        - video_grid_thw shape: (num_videos, 3), rows laid out as
          (grid_t, grid_h, grid_w)
    """

    type: Literal["pixel_values_videos"]

    pixel_values_videos: Annotated[torch.Tensor, TensorShape("np", "ctps")]

    video_grid_thw: Annotated[torch.Tensor, TensorShape("nv", 3)]
208
209


210
211
212
213
214
215
class Qwen2VLVideoEmbeddingInputs(TensorSchema):
    """
    Video inputs given as embeddings instead of pixels.

    Dimensions:
        - nf: Number of video features
        - hs: Hidden size
        - nv: Number of videos

    Historical context:
        - video_embeds shape: (num_video_features, hidden_size); the feature
          count depends on how many videos there are and on their
          resolution, and hidden_size must equal the language model
          backbone's hidden size.
        - video_grid_thw shape: (num_videos, 3), rows laid out as
          (grid_t, grid_h, grid_w)
    """

    type: Literal["video_embeds"]

    video_embeds: Annotated[torch.Tensor, TensorShape("nf", "hs")]

    video_grid_thw: Annotated[torch.Tensor, TensorShape("nv", 3)]
237
238


239
# Videos arrive either as raw pixel patches or as precomputed embeddings.
Qwen2VLVideoInputs: TypeAlias = (
    Qwen2VLVideoPixelInputs | Qwen2VLVideoEmbeddingInputs
)
240

241
242
243
244
245
246
247
# === Vision Encoder === #


class Qwen2VisionMLP(nn.Module):
    """Feed-forward network of a vision block: ``fc2(act(fc1(x)))``.

    ``fc1`` is a ``ColumnParallelLinear`` and ``fc2`` a ``RowParallelLinear``;
    tensor parallelism is disabled on both when ``use_data_parallel`` is set.
    """

    def __init__(
        self,
        in_features: int,
        hidden_features: int,
        act_layer: type[nn.Module] = QuickGELU,
        quant_config: QuantizationConfig | None = None,
        prefix: str = "",
        use_data_parallel: bool = False,
    ):
        super().__init__()
        # Options shared by both projections.
        shared = dict(quant_config=quant_config, disable_tp=use_data_parallel)
        self.fc1 = ColumnParallelLinear(
            in_features, hidden_features, prefix=f"{prefix}.fc1", **shared
        )
        self.act = act_layer()
        self.fc2 = RowParallelLinear(
            hidden_features, in_features, prefix=f"{prefix}.fc2", **shared
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # Both parallel linears return (output, bias); only the output is used.
        hidden, _ = self.fc1(x)
        out, _ = self.fc2(self.act(hidden))
        return out


def rotate_half(x: torch.Tensor, interleaved: bool = False) -> torch.Tensor:
    """Rotate feature pairs of ``x`` by 90 degrees for rotary embeddings.

    Non-interleaved: the last dim is split into halves ``(a, b)`` and the
    result is ``(-b, a)``. Interleaved: pairs are the even/odd positions
    ``(x0, x1), (x2, x3), ...`` and each pair becomes ``(-x1, x0)`` in place.
    """
    if interleaved:
        evens, odds = x[..., ::2], x[..., 1::2]
        # (..., d, 2) -> (..., 2d), i.e. the "(d two)" merge.
        return torch.stack((-odds, evens), dim=-1).flatten(-2)
    first, second = x.chunk(2, dim=-1)
    return torch.cat((-second, first), dim=-1)
287
288


289
290
291
def apply_rotary_emb_torch(
    x: torch.Tensor, cos: torch.Tensor, sin: torch.Tensor, interleaved: bool = False
) -> torch.Tensor:
    """Apply rotary embedding to the first ``2 * cos.shape[-1]`` channels of x.

    Channels beyond that rotary span are passed through unchanged.

    x: (batch_size, seqlen, nheads, headdim)
    cos, sin: (seqlen, rotary_dim / 2) or (batch_size, seqlen, rotary_dim / 2)
    """
    ro_dim = 2 * cos.shape[-1]
    assert ro_dim <= x.shape[-1]

    def _widen(table: torch.Tensor) -> torch.Tensor:
        # (..., d) -> (..., 1, 2d). Non-interleaved tiles the table ("(2 d)");
        # interleaved repeats each entry twice in a row ("(d 2)"). The new
        # singleton axis broadcasts over heads.
        if interleaved:
            widened = table.repeat_interleave(2, dim=-1)
        else:
            widened = torch.cat((table, table), dim=-1)
        return widened.unsqueeze(-2)

    cos = _widen(cos)
    sin = _widen(sin)
    x_rot = x[..., :ro_dim]
    x_pass = x[..., ro_dim:]
    rotated = x_rot * cos + rotate_half(x_rot, interleaved) * sin
    return torch.cat([rotated, x_pass], dim=-1)


313
314
def apply_rotary_pos_emb_vision(t: torch.Tensor, freqs: torch.Tensor) -> torch.Tensor:
    """Rotate ``t`` by the angles in ``freqs``; the result has ``t``'s dtype.

    The rotation runs on a float32 copy of ``t``, using whatever function
    ``dispatch_rotary_emb_function`` returns (``apply_rotary_emb_torch`` is
    passed to it as the default).
    """
    rotary_fn = dispatch_rotary_emb_function(default=apply_rotary_emb_torch)
    return rotary_fn(t.float(), freqs.cos(), freqs.sin()).type_as(t)


class Qwen2VisionAttention(nn.Module):
    """Multi-head self-attention over packed variable-length patch sequences.

    Sequences are concatenated along the sequence axis and delimited by
    ``cu_seqlens``; every backend path below restricts attention to within
    those segments. Supported backends: FLASH_ATTN, ROCM_AITER_FA,
    TORCH_SDPA and XFORMERS.
    """

    def __init__(
        self,
        embed_dim: int,
        num_heads: int,
        projection_size: int,
        quant_config: QuantizationConfig | None = None,
        prefix: str = "",
        use_data_parallel: bool = False,
        attn_backend_override: AttentionBackendEnum | None = None,
    ) -> None:
        super().__init__()
        # Per attention head and per partition values.
        self.tp_size = (
            1
            if use_data_parallel
            else parallel_state.get_tensor_model_parallel_world_size()
        )
        self.tp_rank = parallel_state.get_tensor_model_parallel_rank()
        self.hidden_size_per_attention_head = dist_utils.divide(
            projection_size, num_heads
        )
        self.num_attention_heads_per_partition = dist_utils.divide(
            num_heads, self.tp_size
        )

        # Fused Q/K/V projection; its output is split per local head in
        # split_qkv (tensor parallelism disabled under data parallelism).
        self.qkv = ColumnParallelLinear(
            input_size=embed_dim,
            output_size=3 * projection_size,
            quant_config=quant_config,
            prefix=f"{prefix}.qkv",
            disable_tp=use_data_parallel,
        )
        self.proj = RowParallelLinear(
            input_size=projection_size,
            output_size=embed_dim,
            quant_config=quant_config,
            prefix=f"{prefix}.proj",
            disable_tp=use_data_parallel,
        )

        # Detect attention implementation.
        self.attn_backend = get_vit_attn_backend(
            head_size=self.hidden_size_per_attention_head,
            dtype=torch.get_default_dtype(),
            attn_backend_override=attn_backend_override,
        )
        self.use_upstream_fa = False

        # The helper may replace the backend; it also returns the varlen
        # flash-attention function used on the flash path in forward().
        self.attn_backend, self.flash_attn_varlen_func = (
            maybe_get_vit_flash_attn_backend(
                self.attn_backend,
                self.use_upstream_fa,
                attn_backend_override=attn_backend_override,
            )
        )

        if self.attn_backend not in {
            AttentionBackendEnum.FLASH_ATTN,
            AttentionBackendEnum.TORCH_SDPA,
            AttentionBackendEnum.XFORMERS,
            AttentionBackendEnum.ROCM_AITER_FA,
        }:
            raise RuntimeError(
                f"Qwen2-VL does not support {self.attn_backend} backend now."
            )

        self.is_flash_attn_backend = self.attn_backend in {
            AttentionBackendEnum.FLASH_ATTN,
            AttentionBackendEnum.ROCM_AITER_FA,
        }

    def split_qkv(self, qkv: torch.Tensor) -> tuple[torch.Tensor, ...]:
        """Split a fused ``[s, b, 3 * head * head_dim]`` tensor into q, k, v.

        Each result is a view of shape ``[s, b, head, head_dim]`` where
        ``head`` is the number of heads on this partition.
        """
        # [s, b, 3 * head * head_dim]
        seq_len, bs, _ = qkv.shape

        # [s, b, 3 * head * head_dim] -> 3 * [s, b, head * head_dim]
        q, k, v = qkv.chunk(3, dim=2)

        # 3 * [s, b, head * head_dim] -> 3 * [s, b, head, head_dim]
        new_shape = (
            seq_len,
            bs,
            self.num_attention_heads_per_partition,
            self.hidden_size_per_attention_head,
        )
        q, k, v = (x.view(*new_shape) for x in (q, k, v))
        return q, k, v

    def forward(
        self,
        x: torch.Tensor,
        cu_seqlens: torch.Tensor,
        rotary_pos_emb: torch.Tensor,
        max_seqlen: int | None = None,  # Only used for Flash Attention
        seqlens: list[int] | None = None,  # Only used for xFormers
    ) -> torch.Tensor:
        """Run segment-restricted self-attention and the output projection.

        Args:
            x: Hidden states laid out as ``[s, b, embed_dim]``.
            cu_seqlens: Cumulative segment boundaries along ``s``.
            rotary_pos_emb: Rotary angles applied to q and k; skipped when
                ``None``.
            max_seqlen: Longest segment length (flash backends only).
            seqlens: Per-segment lengths (xFormers only).

        Returns:
            Tensor laid out as ``[s, b, embed_dim]``.
        """
        # [s, b, c] --> [s, b, 3 * head * head_dim]
        x, _ = self.qkv(x)

        # [s, b, 3 * head * head_dim] -> 3 * [s, b, head, head_dim]
        q, k, v = self.split_qkv(x)
        batch_size = q.shape[1]

        q, k, v = (rearrange(x, "s b ... -> b s ...") for x in (q, k, v))
        if rotary_pos_emb is not None:
            # Rotate q and k in one call: [2 * b, s, heads, head_dim]
            qk_concat = torch.cat([q, k], dim=0)
            qk_rotated = apply_rotary_pos_emb_vision(qk_concat, rotary_pos_emb)
            q, k = torch.chunk(qk_rotated, 2, dim=0)

        if self.is_flash_attn_backend:
            q, k, v = (rearrange(x, "b s ... -> (b s) ...") for x in [q, k, v])

            output = self.flash_attn_varlen_func(
                q,
                k,
                v,
                cu_seqlens_q=cu_seqlens,
                cu_seqlens_k=cu_seqlens,
                max_seqlen_q=max_seqlen,
                max_seqlen_k=max_seqlen,
                dropout_p=0.0,
                causal=False,
            )

            context_layer = rearrange(
                output, "(b s) h d -> s b (h d)", b=batch_size
            ).contiguous()
        elif self.attn_backend == AttentionBackendEnum.TORCH_SDPA:
            # Execute attention entry by entry for speed & less VRAM.
            from vllm.platforms import current_platform

            if current_platform.is_rocm():
                q = q.contiguous()
                k = k.contiguous()
                v = v.contiguous()

            # Fetch the boundaries as Python ints once. Slicing with 0-d
            # tensor elements converts every bound via .item(), which is a
            # host/device sync per bound when cu_seqlens lives on the GPU.
            boundaries = cu_seqlens.tolist()
            outputs = []
            for start_idx, end_idx in zip(boundaries[:-1], boundaries[1:]):
                q_i = q[:, start_idx:end_idx]
                k_i = k[:, start_idx:end_idx]
                v_i = v[:, start_idx:end_idx]
                q_i, k_i, v_i = (
                    rearrange(x, "b s h d -> b h s d") for x in [q_i, k_i, v_i]
                )
                output_i = F.scaled_dot_product_attention(q_i, k_i, v_i, dropout_p=0.0)
                output_i = rearrange(output_i, "b h s d -> b s h d ")
                outputs.append(output_i)
            context_layer = torch.cat(outputs, dim=1)
            context_layer = rearrange(
                context_layer, "b s h d -> s b (h d)"
            ).contiguous()
        elif self.attn_backend == AttentionBackendEnum.XFORMERS:
            from xformers import ops as xops
            from xformers.ops.fmha.attn_bias import BlockDiagonalMask

            # Block-diagonal mask keeps attention inside each segment.
            attn_bias = BlockDiagonalMask.from_seqlens(
                q_seqlen=seqlens, kv_seqlen=None, device=q.device
            )

            context_layer = xops.memory_efficient_attention_forward(
                q, k, v, attn_bias=attn_bias, p=0, scale=None
            )
            context_layer = rearrange(
                context_layer, "b s h d -> s b (h d)"
            ).contiguous()

        output, _ = self.proj(context_layer)
        return output


class Qwen2VisionBlock(nn.Module):
    """Pre-norm residual vision layer.

    Computes ``x + attn(norm1(x))`` followed by ``x + mlp(norm2(x))``.
    """

    def __init__(
        self,
        dim: int,
        num_heads: int,
        mlp_ratio: float,
        act_layer: type[nn.Module] = QuickGELU,
        norm_layer: Callable[[int], nn.Module] | None = None,
        quant_config: QuantizationConfig | None = None,
        prefix: str = "",
        use_data_parallel: bool = False,
        attn_backend_override: AttentionBackendEnum | None = None,
    ) -> None:
        super().__init__()
        make_norm = (
            norm_layer if norm_layer is not None else partial(nn.LayerNorm, eps=1e-6)
        )
        self.norm1 = make_norm(dim)
        self.norm2 = make_norm(dim)

        self.attn = Qwen2VisionAttention(
            embed_dim=dim,
            num_heads=num_heads,
            projection_size=dim,
            quant_config=quant_config,
            prefix=f"{prefix}.attn",
            use_data_parallel=use_data_parallel,
            attn_backend_override=attn_backend_override,
        )
        self.mlp = Qwen2VisionMLP(
            dim,
            int(dim * mlp_ratio),
            act_layer=act_layer,
            quant_config=quant_config,
            prefix=f"{prefix}.mlp",
            use_data_parallel=use_data_parallel,
        )

    def forward(
        self,
        x: torch.Tensor,
        cu_seqlens: torch.Tensor,
        rotary_pos_emb: torch.Tensor,
        max_seqlen: int | None = None,  # Only used for Flash Attention
        seqlens: list[int] | None = None,  # Only used for xFormers
    ) -> torch.Tensor:
        attn_out = self.attn(
            self.norm1(x),
            cu_seqlens=cu_seqlens,
            rotary_pos_emb=rotary_pos_emb,
            max_seqlen=max_seqlen,
            seqlens=seqlens,
        )
        x = x + attn_out
        return x + self.mlp(self.norm2(x))


class Qwen2VisionPatchEmbed(nn.Module):
    """Embed flattened spatio-temporal patches with a non-overlapping Conv3d.

    Each input row is reshaped to ``(-1, temporal_patch_size, patch_size,
    patch_size)``. The convolution's kernel equals its stride, so every row
    yields exactly one ``embed_dim`` vector.
    """

    def __init__(
        self,
        patch_size: int = 14,
        temporal_patch_size: int = 2,
        in_channels: int = 3,
        embed_dim: int = 1152,
    ) -> None:
        super().__init__()
        self.patch_size = patch_size
        self.temporal_patch_size = temporal_patch_size
        self.embed_dim = embed_dim

        window = (temporal_patch_size, patch_size, patch_size)
        self.proj = Conv3dLayer(
            in_channels, embed_dim, kernel_size=window, stride=window, bias=False
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # Unpacking also enforces a 2-D (num_patches, flat_patch) input.
        num_patches, _ = x.shape
        volumes = x.view(
            num_patches, -1, self.temporal_patch_size, self.patch_size, self.patch_size
        )
        return self.proj(volumes).view(num_patches, self.embed_dim)


class Qwen2VisionPatchMerger(nn.Module):
    """Merge groups of patch features and project them to ``d_model``.

    After ``ln_q``, every ``spatial_merge_size**2`` consecutive rows are
    concatenated into one ``context_dim * spatial_merge_size**2`` vector. That
    vector goes through Linear -> GELU -> Linear to produce ``d_model``
    features.
    """

    def __init__(
        self,
        d_model: int,
        context_dim: int,
        norm_layer: Callable[[int], nn.Module] | None = None,
        spatial_merge_size: int = 2,
        quant_config: QuantizationConfig | None = None,
        prefix: str = "",
        use_data_parallel: bool = False,
    ) -> None:
        super().__init__()
        self.hidden_size = context_dim * (spatial_merge_size**2)
        make_norm = (
            norm_layer if norm_layer is not None else partial(nn.LayerNorm, eps=1e-6)
        )
        self.ln_q = make_norm(context_dim)

        # Checkpoint names are mlp.0 / mlp.2, so keep the GELU at index 1.
        fc1 = ColumnParallelLinear(
            self.hidden_size,
            self.hidden_size,
            bias=True,
            quant_config=quant_config,
            prefix=f"{prefix}.mlp.0",
            disable_tp=use_data_parallel,
        )
        act = nn.GELU()
        fc2 = RowParallelLinear(
            self.hidden_size,
            d_model,
            bias=True,
            quant_config=quant_config,
            prefix=f"{prefix}.mlp.2",
            disable_tp=use_data_parallel,
        )
        self.mlp = nn.ModuleList([fc1, act, fc2])

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        fc1, act, fc2 = self.mlp
        merged = self.ln_q(x).view(-1, self.hidden_size)
        hidden, _ = fc1(merged)
        out, _ = fc2(act(hidden))
        return out


class Qwen2VisionRotaryEmbedding(nn.Module):
    """Lazily grown table of rotary angles ``outer(position, inv_freq)``.

    ``forward(seqlen)`` returns the first ``seqlen`` rows, shape
    ``(seqlen, len(inv_freq))``. When a longer table is needed, it is rebuilt
    with twice the requested length.
    """

    def __init__(self, dim: int, theta: float = 10000.0) -> None:
        super().__init__()
        self.dim = dim
        self.theta = theta
        inv_freq = 1.0 / (theta ** (torch.arange(0, dim, 2, dtype=torch.float) / dim))
        self.register_buffer("inv_freq", inv_freq, persistent=False)
        self._seq_len_cached = 0
        self._freqs_cached = None

    def update_freqs_cache(self, seqlen: int) -> None:
        """Make sure the cached table covers at least ``seqlen`` positions.

        The table is also built on first use. Without that, ``forward(0)`` on
        a fresh module would slice ``None`` and raise ``TypeError``.
        """
        if self._freqs_cached is not None and seqlen <= self._seq_len_cached:
            return
        # Over-allocate so slowly growing requests don't rebuild every time.
        seqlen *= 2
        self._seq_len_cached = seqlen
        # inv_freq is recomputed from scratch on the buffer's current device
        # instead of reusing the stored values (kept from the original code;
        # presumably to repair a buffer left uninitialized by lazy/meta model
        # construction -- TODO confirm).
        self.inv_freq = 1.0 / (
            self.theta
            ** (
                torch.arange(
                    0, self.dim, 2, dtype=torch.float, device=self.inv_freq.device
                )
                / self.dim
            )
        )
        seq = torch.arange(
            seqlen, device=self.inv_freq.device, dtype=self.inv_freq.dtype
        )
        self._freqs_cached = torch.outer(seq, self.inv_freq)

    def forward(self, seqlen: int) -> torch.Tensor:
        """Return rotary angles for positions ``0 .. seqlen - 1``."""
        self.update_freqs_cache(seqlen)
        return self._freqs_cached[:seqlen]


class Qwen2VisionTransformer(nn.Module):
    """Qwen2-VL vision tower.

    Patch-embeds flattened pixel patches and runs ``vision_config.depth``
    ``Qwen2VisionBlock`` layers with 2-D (h, w) rotary position embeddings.
    It then merges patch groups into tokens of width
    ``vision_config.hidden_size``.
    """

    def __init__(
        self,
        vision_config: Qwen2VLVisionConfig,
        norm_eps: float = 1e-6,
        quant_config: QuantizationConfig | None = None,
        prefix: str = "",
        use_data_parallel: bool = False,
        attn_backend_override: AttentionBackendEnum | None = None,
    ) -> None:
        super().__init__()

        patch_size = vision_config.patch_size
        temporal_patch_size = vision_config.temporal_patch_size
        spatial_merge_size = vision_config.spatial_merge_size
        in_channels = vision_config.in_channels
        hidden_size = vision_config.hidden_size
        embed_dim = vision_config.embed_dim
        depth = vision_config.depth
        num_heads = vision_config.num_heads
        mlp_ratio = vision_config.mlp_ratio

        self.use_data_parallel = use_data_parallel
        self.out_hidden_size = vision_config.hidden_size

        self.spatial_merge_size = spatial_merge_size
        self.num_heads = num_heads
        self.embed_dim = embed_dim

        self.patch_embed = Qwen2VisionPatchEmbed(
            patch_size=patch_size,
            temporal_patch_size=temporal_patch_size,
            in_channels=in_channels,
            embed_dim=embed_dim,
        )

        norm_layer = partial(nn.LayerNorm, eps=norm_eps)
        head_dim = embed_dim // num_heads
        self.rotary_pos_emb = Qwen2VisionRotaryEmbedding(head_dim // 2)

        self.blocks = nn.ModuleList(
            [
                Qwen2VisionBlock(
                    dim=embed_dim,
                    num_heads=num_heads,
                    mlp_ratio=mlp_ratio,
                    norm_layer=norm_layer,
                    quant_config=quant_config,
                    prefix=f"{prefix}.blocks.{layer_idx}",
                    use_data_parallel=use_data_parallel,
                    attn_backend_override=attn_backend_override,
                )
                for layer_idx in range(depth)
            ]
        )
        self.merger = Qwen2VisionPatchMerger(
            d_model=hidden_size,
            context_dim=embed_dim,
            norm_layer=norm_layer,
            quant_config=quant_config,
            prefix=f"{prefix}.merger",
            use_data_parallel=use_data_parallel,
        )
        self.attn_backend = get_vit_attn_backend(
            head_size=head_dim,
            dtype=torch.get_default_dtype(),
            attn_backend_override=attn_backend_override,
        )
        if (
            self.attn_backend != AttentionBackendEnum.FLASH_ATTN
            and check_upstream_fa_availability(torch.get_default_dtype())
        ):
            self.attn_backend = AttentionBackendEnum.FLASH_ATTN

    @property
    def dtype(self) -> torch.dtype:
        return self.patch_embed.proj.weight.dtype

    @property
    def device(self) -> torch.device:
        return self.patch_embed.proj.weight.device

    def rot_pos_emb(self, grid_thw: list[list[int]]) -> torch.Tensor:
        """Return per-patch rotary angles (h-axis and w-axis concatenated).

        Positions are ordered so that each ``spatial_merge_size x
        spatial_merge_size`` window of patches is contiguous. The ids of each
        grid are repeated ``t`` times, once per temporal slice.
        """
        pos_ids = []
        max_grid_size = 0
        for t, h, w in grid_thw:
            hpos_ids = torch.arange(h).unsqueeze(1).expand(-1, w)
            wpos_ids = torch.arange(w).unsqueeze(0).expand(h, -1)
            hpos_ids = (
                hpos_ids.reshape(
                    h // self.spatial_merge_size,
                    self.spatial_merge_size,
                    w // self.spatial_merge_size,
                    self.spatial_merge_size,
                )
                .permute(0, 2, 1, 3)
                .flatten()
            )
            wpos_ids = (
                wpos_ids.reshape(
                    h // self.spatial_merge_size,
                    self.spatial_merge_size,
                    w // self.spatial_merge_size,
                    self.spatial_merge_size,
                )
                .permute(0, 2, 1, 3)
                .flatten()
            )
            pos_ids.append(torch.stack([hpos_ids, wpos_ids], dim=-1).repeat(t, 1))
            max_grid_size = max(max_grid_size, h, w)
        pos_ids = torch.cat(pos_ids, dim=0)
        rotary_pos_emb_full = self.rotary_pos_emb(max_grid_size)
        rotary_pos_emb = rotary_pos_emb_full[pos_ids].flatten(1)
        return rotary_pos_emb

    def compute_attn_mask_seqlen(
        self, cu_seqlens: torch.Tensor
    ) -> tuple[int | None, list[int] | None]:
        """Precompute the attention metadata shared by every block.

        Returns ``(max_seqlen, seqlens)``. ``max_seqlen`` is set for the
        flash-attention backends and ``seqlens`` for xFormers; both stay
        ``None`` for other backends (e.g. TORCH_SDPA, which only uses
        ``cu_seqlens``).
        """
        # The blocks consume these values, so dispatch on the backend their
        # attention layers actually resolved. ``self.attn_backend`` is derived
        # separately in __init__, and its FLASH_ATTN upgrade does not look at
        # attn_backend_override. It may therefore disagree with the blocks,
        # e.g. leaving an xFormers block with ``seqlens=None``.
        attn_backend = (
            self.blocks[0].attn.attn_backend
            if len(self.blocks) > 0
            else self.attn_backend
        )
        max_seqlen, seqlens = None, None
        if attn_backend in {
            AttentionBackendEnum.FLASH_ATTN,
            AttentionBackendEnum.ROCM_AITER_FA,
        }:
            max_seqlen = (cu_seqlens[1:] - cu_seqlens[:-1]).max().item()
        elif attn_backend == AttentionBackendEnum.XFORMERS:
            seqlens = (cu_seqlens[1:] - cu_seqlens[:-1]).tolist()
        return max_seqlen, seqlens

    def forward(
        self,
        x: torch.Tensor,
        grid_thw: torch.Tensor | list[list[int]],
    ) -> torch.Tensor:
        """Encode flattened pixel patches into merged vision tokens.

        Args:
            x: Flattened patches, one row per patch, for all images/videos.
            grid_thw: One ``(t, h, w)`` patch-grid row per image/video, given
                as a tensor or a nested list.

        Returns:
            Merged tokens, as produced by ``self.merger``.
        """
        # patchify
        x = x.to(device=self.device, dtype=self.dtype)
        x = self.patch_embed(x)

        if isinstance(grid_thw, list):
            grid_thw_list = grid_thw
            grid_thw = torch.tensor(grid_thw, dtype=torch.int32)
        else:
            grid_thw_list = grid_thw.tolist()

        # compute position embedding
        rotary_pos_emb = self.rot_pos_emb(grid_thw_list)

        # compute cu_seqlens: each temporal slice of h * w patches becomes
        # its own attention segment.
        cu_seqlens = torch.repeat_interleave(
            grid_thw[:, 1] * grid_thw[:, 2], grid_thw[:, 0]
        ).cumsum(dim=0, dtype=torch.int32)
        cu_seqlens = torch.cat([cu_seqlens.new_zeros(1), cu_seqlens])
        cu_seqlens = cu_seqlens.to(self.device, non_blocking=True)

        # transformers
        x = x.unsqueeze(1)

        # pre-compute seqlens for attn mask to reduce cuMemcpy operations
        max_seqlen, seqlens = self.compute_attn_mask_seqlen(cu_seqlens)
        for blk in self.blocks:
            x = blk(
                x,
                cu_seqlens=cu_seqlens,
                rotary_pos_emb=rotary_pos_emb,
                max_seqlen=max_seqlen,
                seqlens=seqlens,
            )

        # adapter
        x = self.merger(x)

        return x

    def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]:
        """Load vision-tower weights and return the parameter names loaded.

        Names containing ``q_proj``/``k_proj``/``v_proj`` are routed to the
        matching shard of ``qkv_proj``. All other weights go to the
        same-named parameter, using its ``weight_loader`` if it has one.
        """
        stacked_params_mapping = [
            # (param_name, shard_name, shard_id)
            ("qkv_proj", "q_proj", "q"),
            ("qkv_proj", "k_proj", "k"),
            ("qkv_proj", "v_proj", "v"),
        ]
        params_dict = dict(self.named_parameters(remove_duplicate=False))
        loaded_params: set[str] = set()

        for name, loaded_weight in weights:
            for param_name, weight_name, shard_id in stacked_params_mapping:
                if weight_name not in name:
                    continue
                name = name.replace(weight_name, param_name)

                param = params_dict[name]
                weight_loader = param.weight_loader
                weight_loader(param, loaded_weight, shard_id)
                break
            else:
                param = params_dict[name]
                weight_loader = getattr(param, "weight_loader", default_weight_loader)
                weight_loader(param, loaded_weight)
            loaded_params.add(name)
        return loaded_params

864

865
def _create_qwen2vl_field_factory(
    spatial_merge_size: int,
) -> Callable[
    [Mapping[str, torch.Tensor]],
    Mapping[str, MultiModalFieldConfig],
]:
    """Return a function that maps HF processor outputs to field configs.

    Pixel tensors are split per item by the product of that item's grid row
    (its patch count). Embedding tensors are split by that product floor-divided
    by ``spatial_merge_size`` twice (its token count). Grid rows come from
    ``image_grid_thw`` / ``video_grid_thw``; a missing grid is treated as an
    empty ``(0, 3)`` tensor, i.e. zero items.
    """

    def _patch_and_token_counts(
        hf_inputs: Mapping[str, torch.Tensor], grid_key: str
    ) -> tuple[torch.Tensor, torch.Tensor]:
        patch_counts = hf_inputs.get(grid_key, torch.empty((0, 3))).prod(-1)
        token_counts = patch_counts // spatial_merge_size // spatial_merge_size
        return patch_counts, token_counts

    def _qwen2vl_field_config(hf_inputs: Mapping[str, torch.Tensor]):
        image_patches, image_tokens = _patch_and_token_counts(
            hf_inputs, "image_grid_thw"
        )
        video_patches, video_tokens = _patch_and_token_counts(
            hf_inputs, "video_grid_thw"
        )
        return {
            "pixel_values": MultiModalFieldConfig.flat_from_sizes(
                "image", image_patches
            ),
            "image_embeds": MultiModalFieldConfig.flat_from_sizes(
                "image", image_tokens
            ),
            "image_grid_thw": MultiModalFieldConfig.batched("image"),
            "pixel_values_videos": MultiModalFieldConfig.flat_from_sizes(
                "video", video_patches
            ),
            "video_embeds": MultiModalFieldConfig.flat_from_sizes(
                "video", video_tokens
            ),
            "video_grid_thw": MultiModalFieldConfig.batched("video"),
        }

    return _qwen2vl_field_config
902

903

Roger Wang's avatar
Roger Wang committed
904
class Qwen2VLMultiModalDataParser(MultiModalDataParser):
    """Data parser that additionally accepts precomputed embeddings.

    Image or video data given as a ``dict`` is wrapped as embedding items
    (requiring ``*_embeds`` and ``*_grid_thw`` entries). Any other data is
    handed to the base parser unchanged.
    """

    def __init__(self, spatial_merge_size: int, *args, **kwargs):
        self._spatial_merge_size = spatial_merge_size
        super().__init__(*args, **kwargs)

    def _as_embedding_items(
        self,
        data: dict[str, torch.Tensor],
        modality: str,
        required_fields: set[str],
    ) -> ModalityDataItems[Any, Any]:
        # Field sizes are derived from the grid using the spatial merge size.
        return DictEmbeddingItems(
            data,
            modality=modality,
            required_fields=required_fields,
            fields_factory=_create_qwen2vl_field_factory(self._spatial_merge_size),
        )

    def _parse_image_data(
        self,
        data: dict[str, torch.Tensor] | ModalityData[ImageItem],
    ) -> ModalityDataItems[Any, Any] | None:
        if not isinstance(data, dict):
            return super()._parse_image_data(data)
        return self._as_embedding_items(
            data, "image", {"image_embeds", "image_grid_thw"}
        )

    def _parse_video_data(
        self,
        data: dict[str, torch.Tensor] | ModalityData[VideoItem],
    ) -> ModalityDataItems[Any, Any] | None:
        if not isinstance(data, dict):
            return super()._parse_video_data(data)
        return self._as_embedding_items(
            data, "video", {"video_embeds", "video_grid_thw"}
        )


938
939
class Qwen2VLProcessingInfo(BaseProcessingInfo):
    """Qwen2-VL processing metadata.

    Provides access to the HF config/processors and computes vision-token
    counts for images and videos, including the worst-case sizes used for
    per-item token budgets.
    """

    def get_hf_config(self):
        return self.ctx.get_hf_config(Qwen2VLConfig)

    def get_hf_processor(self, **kwargs: object) -> Qwen2VLProcessor:
        # Request the fast processor unless the caller overrides `use_fast`.
        return self.ctx.get_hf_processor(
            Qwen2VLProcessor,
            use_fast=kwargs.pop("use_fast", True),
            **kwargs,
        )

    def get_image_processor(self, **kwargs: object) -> Qwen2VLImageProcessor:
        return self.get_hf_processor(**kwargs).image_processor

    def get_supported_mm_limits(self) -> Mapping[str, int | None]:
        return {"image": None, "video": None}

    def get_mm_max_tokens_per_item(
        self,
        seq_len: int,
        mm_counts: Mapping[str, int],
    ) -> Mapping[str, int]:
        """Return the maximum token count of a single image and a single video."""
        max_image_tokens = self.get_max_image_tokens()
        max_video_tokens = self.get_max_video_tokens(seq_len, mm_counts)
        return {"image": max_image_tokens, "video": max_video_tokens}

    def _get_vision_info(
        self,
        *,
        image_width: int,
        image_height: int,
        num_frames: int = 1,
        do_resize: bool = True,
        image_processor: Qwen2VLImageProcessor | None,
    ) -> tuple[ImageSize, int]:
        """Compute the preprocessed frame size and vision-token count.

        Args:
            image_width: Original frame width in pixels.
            image_height: Original frame height in pixels.
            num_frames: Number of frames (1 for a still image).
            do_resize: Whether to apply ``smart_resize`` using the image
                processor's ``min_pixels``/``max_pixels``.
            image_processor: Processor supplying the resize bounds; when
                ``None``, ``self.get_image_processor()`` is used.

        Returns:
            ``(preprocessed_size, num_vision_tokens)``. The token count is
            ``grid_t * grid_h * grid_w // spatial_merge_size**2``.
        """
        if image_processor is None:
            image_processor = self.get_image_processor()

        hf_config = self.get_hf_config()
        vision_config = hf_config.vision_config
        patch_size = vision_config.patch_size
        merge_size = vision_config.spatial_merge_size
        temporal_patch_size = vision_config.temporal_patch_size

        if do_resize:
            resized_height, resized_width = smart_resize(
                height=image_height,
                width=image_width,
                factor=patch_size * merge_size,
                min_pixels=image_processor.min_pixels,
                max_pixels=image_processor.max_pixels,
            )
            preprocessed_size = ImageSize(width=resized_width, height=resized_height)
        else:
            preprocessed_size = ImageSize(width=image_width, height=image_height)

        # NOTE: Frames are padded to be divisible by `temporal_patch_size`
        # https://github.com/huggingface/transformers/blob/v4.48.3/src/transformers/models/qwen2_vl/image_processing_qwen2_vl.py#L294
        # Round up to the next multiple. (`num_frames + num_frames % tps` only
        # achieves this for tps == 2, where both formulas agree.)
        padded_num_frames = num_frames + (-num_frames) % temporal_patch_size

        grid_t = max(padded_num_frames // temporal_patch_size, 1)
        grid_h = preprocessed_size.height // patch_size
        grid_w = preprocessed_size.width // patch_size

        num_patches = grid_t * grid_h * grid_w
        num_vision_tokens = num_patches // (merge_size**2)

        return preprocessed_size, num_vision_tokens

    def get_num_image_tokens(
        self,
        *,
        image_width: int,
        image_height: int,
        image_processor: Qwen2VLImageProcessor | None,
    ) -> int:
        """Return the number of vision tokens for a single image."""
        _, num_image_tokens = self._get_vision_info(
            image_width=image_width,
            image_height=image_height,
            num_frames=1,
            image_processor=image_processor,
        )
        return num_image_tokens

    def get_num_video_tokens(
        self,
        *,
        image_width: int,
        image_height: int,
        num_frames: int,
        image_processor: Qwen2VLImageProcessor | None,
    ) -> int:
        """Return the number of vision tokens for a video of `num_frames` frames."""
        _, num_video_tokens = self._get_vision_info(
            image_width=image_width,
            image_height=image_height,
            num_frames=num_frames,
            image_processor=image_processor,
        )
        return num_video_tokens

    def get_image_size_with_most_features(self) -> ImageSize:
        """Return the preprocessed size of an arbitrarily large input.

        A huge input is clamped by ``smart_resize`` to the processor's
        ``max_pixels`` bound, which yields the largest token count.
        """
        max_image_size, _ = self._get_vision_info(
            image_width=9999999,
            image_height=9999999,
            num_frames=1,
            image_processor=None,
        )
        return max_image_size

    def get_max_image_tokens(self) -> int:
        target_width, target_height = self.get_image_size_with_most_features()

        return self.get_num_image_tokens(
            image_width=target_width,
            image_height=target_height,
            image_processor=None,
        )

    def _get_max_video_frames(self, max_tokens: int, start_num_frames: int = 1) -> int:
        """Return the largest frame count whose token count fits `max_tokens`.

        Starts from `start_num_frames` and increases one frame at a time
        until the next count would exceed `max_tokens`.
        """
        target_width, target_height = self.get_image_size_with_most_features()

        num_frames = start_num_frames

        while True:
            next_num_frames = num_frames + 1
            next_max_tokens = self.get_num_video_tokens(
                image_width=target_width,
                image_height=target_height,
                num_frames=next_num_frames,
                image_processor=None,
            )

            if next_max_tokens > max_tokens:
                break

            num_frames = next_num_frames

        return num_frames

    def get_num_frames_with_most_features(
        self,
        seq_len: int,
        mm_counts: Mapping[str, int],
        max_frames_per_video: int = _MAX_FRAMES_PER_VIDEO,
    ) -> int:
        """Split the frame budget for `seq_len` evenly across requested videos.

        The result is capped at `max_frames_per_video` and is at least 1.
        """
        max_videos = mm_counts.get("video", 0)

        max_total_frames = self._get_max_video_frames(seq_len)
        max_frames_per_video = min(
            max_total_frames // max(max_videos, 1), max_frames_per_video
        )

        return max(max_frames_per_video, 1)

    def get_max_video_tokens(
        self,
        seq_len: int,
        mm_counts: Mapping[str, int],
    ) -> int:
        target_width, target_height = self.get_image_size_with_most_features()

        return self.get_num_video_tokens(
            image_width=target_width,
            image_height=target_height,
            num_frames=self.get_num_frames_with_most_features(seq_len, mm_counts),
            image_processor=None,
        )

1106
1107

class Qwen2VLDummyInputsBuilder(BaseDummyInputsBuilder[Qwen2VLProcessingInfo]):
    """Builds dummy Qwen2-VL prompts and media at the worst-case size.

    Image/video dimensions and frame counts come from
    ``Qwen2VLProcessingInfo``'s "most features" helpers.
    """

    def get_dummy_text(self, mm_counts: Mapping[str, int]) -> str:
        # One placeholder token per requested image, then per requested video.
        processor = self.info.get_hf_processor()
        pieces = [
            processor.image_token * mm_counts.get("image", 0),
            processor.video_token * mm_counts.get("video", 0),
        ]
        return "".join(pieces)

    def get_dummy_mm_data(
        self,
        seq_len: int,
        mm_counts: Mapping[str, int],
        mm_options: Mapping[str, BaseDummyOptions] | None = None,
    ) -> MultiModalDataDict:
        width, height = self.info.get_image_size_with_most_features()
        frames_per_video = self.info.get_num_frames_with_most_features(
            seq_len, mm_counts
        )
        options = mm_options or {}

        dummy_images = self._get_dummy_images(
            width=width,
            height=height,
            num_images=mm_counts.get("image", 0),
            overrides=options.get("image"),
        )
        dummy_videos = self._get_dummy_videos(
            width=width,
            height=height,
            num_frames=frames_per_video,
            num_videos=mm_counts.get("video", 0),
            overrides=options.get("video"),
        )
        return {"image": dummy_images, "video": dummy_videos}

1151

1152
class Qwen2VLMultiModalProcessor(BaseMultiModalProcessor[Qwen2VLProcessingInfo]):
    """Multimodal processor that expands Qwen2-VL image/video placeholders.

    Each single image/video placeholder token is replaced by
    ``prod(grid_thw) // merge_size**2`` copies of that token.
    """

    def _get_data_parser(self) -> MultiModalDataParser:
        merge_size = self.info.get_hf_config().vision_config.spatial_merge_size
        return Qwen2VLMultiModalDataParser(merge_size)

    def _get_prompt_updates(
        self,
        mm_items: MultiModalDataItems,
        hf_processor_mm_kwargs: Mapping[str, Any],
        out_mm_kwargs: MultiModalKwargsItems,
    ) -> Sequence[PromptUpdate]:
        hf_processor = self.info.get_hf_processor(**hf_processor_mm_kwargs)
        image_processor = self.info.get_image_processor(**hf_processor_mm_kwargs)
        vocab = self.info.get_tokenizer().get_vocab()

        token_ids = {
            "image": vocab[hf_processor.image_token],
            "video": vocab[hf_processor.video_token],
        }
        tokens_per_merge = image_processor.merge_size**2

        def _replacement_for(item_idx: int, modality: str):
            grid = out_mm_kwargs[modality][item_idx][f"{modality}_grid_thw"].data
            assert isinstance(grid, torch.Tensor)
            repeat_count = int(grid.prod()) // tokens_per_merge
            return [token_ids[modality]] * repeat_count

        updates = []
        for modality in ("image", "video"):
            updates.append(
                PromptReplacement(
                    modality=modality,
                    target=[token_ids[modality]],
                    replacement=partial(_replacement_for, modality=modality),
                )
            )
        return updates

    def _get_mm_fields_config(
        self,
        hf_inputs: BatchFeature,
        hf_processor_mm_kwargs: Mapping[str, object],
    ) -> Mapping[str, MultiModalFieldConfig]:
        merge_size = self.info.get_hf_config().vision_config.spatial_merge_size
        fields_for = _create_qwen2vl_field_factory(merge_size)
        return fields_for(hf_inputs)


@MULTIMODAL_REGISTRY.register_processor(
    Qwen2VLMultiModalProcessor,
    info=Qwen2VLProcessingInfo,
    dummy_inputs=Qwen2VLDummyInputsBuilder,
)
class Qwen2VLForConditionalGeneration(
    nn.Module, SupportsMultiModal, SupportsLoRA, SupportsPP, SupportsMRoPE
):
    """Qwen2-VL: a Qwen2 language model with an optional vision tower.

    ``self.visual`` is built only when the multimodal config allows images or
    videos. Otherwise it is ``None`` and ``visual.`` weights are skipped at
    load time.
    """

    merge_by_field_config = True
    multimodal_cpu_fields = {"image_grid_thw", "video_grid_thw"}

    # To ensure correct weight loading and mapping.
    hf_to_vllm_mapper = WeightsMapper(
        orig_to_new_prefix={
            # mapping for new names in checkpoint saved after transformers v4.52
            "model.language_model.": "language_model.model.",
            "model.visual.": "visual.",
            # mapping for original checkpoint
            "lm_head.": "language_model.lm_head.",
            "model.": "language_model.model.",
        }
    )

    supports_encoder_tp_data = True

    def get_mrope_input_positions(
        self,
        input_tokens: list[int],
        mm_features: list[MultiModalFeatureSpec],
    ) -> tuple[torch.Tensor, int]:
        """Compute 3-axis (temporal, height, width) M-RoPE positions.

        Text tokens get the same position on all three axes, increasing by
        one per token. Each image/video placeholder run gets a
        ``(t, h // merge, w // merge)`` grid of positions offset from the
        preceding text. For videos the temporal index is scaled by
        ``second_per_grid_t * tokens_per_second``; for images it is 0.

        Returns:
            A ``(3, len(input_tokens))`` position tensor, and the difference
            between the next free position and ``len(input_tokens)``.
        """
        kwargs = MultiModalFeatureSpec.gather_kwargs(
            mm_features,
            {"image_grid_thw", "video_grid_thw", "second_per_grid_ts"},
        )
        image_grid_thw = [item.tolist() for item in kwargs.get("image_grid_thw", [])]
        video_grid_thw = [item.tolist() for item in kwargs.get("video_grid_thw", [])]
        second_per_grid_ts = kwargs.get("second_per_grid_ts", [])

        hf_config = self.config
        image_token_id = hf_config.image_token_id
        video_token_id = hf_config.video_token_id
        vision_start_token_id = hf_config.vision_start_token_id
        spatial_merge_size = hf_config.vision_config.spatial_merge_size
        tokens_per_second = getattr(hf_config.vision_config, "tokens_per_second", 1.0)

        num_input_tokens = len(input_tokens)
        if num_input_tokens == 0:
            # Nothing to position; torch.cat below would fail on an empty list.
            return torch.zeros((3, 0), dtype=torch.long), 0

        input_tokens_tensor = torch.tensor(input_tokens)
        vision_start_indices = torch.argwhere(
            input_tokens_tensor == vision_start_token_id
        ).squeeze(1)
        # A <vision_start> in the last position has no following modality
        # token; drop it rather than index past the end of the prompt.
        vision_start_indices = vision_start_indices[
            vision_start_indices + 1 < num_input_tokens
        ]
        vision_tokens = input_tokens_tensor[vision_start_indices + 1]
        image_nums = int((vision_tokens == image_token_id).sum())
        video_nums = int((vision_tokens == video_token_id).sum())
        llm_pos_ids_list: list[torch.Tensor] = []

        st = 0
        remain_images, remain_videos = image_nums, video_nums

        image_index, video_index = 0, 0
        for _ in range(image_nums + video_nums):
            video_second_per_grid_t = 0.0
            # Locate the next image and video placeholder at or after `st`;
            # an index past the end means "none left".
            if remain_images > 0:
                try:
                    ed_image = input_tokens.index(image_token_id, st)
                except ValueError:
                    ed_image = num_input_tokens + 1
            else:
                ed_image = num_input_tokens + 1
            if remain_videos > 0:
                try:
                    ed_video = input_tokens.index(video_token_id, st)
                except ValueError:
                    ed_video = num_input_tokens + 1
            else:
                ed_video = num_input_tokens + 1
            if ed_image < ed_video:
                t, h, w = image_grid_thw[image_index]
                image_index += 1
                remain_images -= 1
                ed = ed_image
            else:
                t, h, w = video_grid_thw[video_index]
                video_second_per_grid_t = 1.0
                if second_per_grid_ts:
                    video_second_per_grid_t = second_per_grid_ts[video_index]
                video_index += 1
                remain_videos -= 1
                ed = ed_video

            llm_grid_t, llm_grid_h, llm_grid_w = (
                t,
                h // spatial_merge_size,
                w // spatial_merge_size,
            )
            text_len = ed - st

            # Text preceding this vision item: identical positions on all axes.
            st_idx = llm_pos_ids_list[-1].max() + 1 if llm_pos_ids_list else 0
            llm_pos_ids_list.append(
                torch.arange(text_len).view(1, -1).expand(3, -1) + st_idx
            )

            t_index = (
                (
                    torch.arange(llm_grid_t)
                    .view(-1, 1)
                    .expand(-1, llm_grid_h * llm_grid_w)
                    * video_second_per_grid_t
                    * tokens_per_second
                )
                .long()
                .flatten()
            )
            h_index = (
                torch.arange(llm_grid_h)
                .view(1, -1, 1)
                .expand(llm_grid_t, -1, llm_grid_w)
                .flatten()
            )
            w_index = (
                torch.arange(llm_grid_w)
                .view(1, 1, -1)
                .expand(llm_grid_t, llm_grid_h, -1)
                .flatten()
            )
            llm_pos_ids_list.append(
                torch.stack([t_index, h_index, w_index]) + text_len + st_idx
            )
            st = ed + llm_grid_t * llm_grid_h * llm_grid_w

        # Trailing text after the last vision item.
        if st < num_input_tokens:
            st_idx = llm_pos_ids_list[-1].max() + 1 if llm_pos_ids_list else 0
            text_len = num_input_tokens - st
            llm_pos_ids_list.append(
                torch.arange(text_len).view(1, -1).expand(3, -1) + st_idx
            )

        llm_positions = torch.cat(llm_pos_ids_list, dim=1).reshape(3, -1)
        mrope_position_delta = (llm_positions.max() + 1 - num_input_tokens).item()

        return llm_positions, mrope_position_delta

    @classmethod
    def get_placeholder_str(cls, modality: str, i: int) -> str | None:
        if modality.startswith("image"):
            return "<|vision_start|><|image_pad|><|vision_end|>"
        if modality.startswith("video"):
            return "<|vision_start|><|video_pad|><|vision_end|>"

        raise ValueError("Only image or video modality is supported")

    def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
        super().__init__()
        config: Qwen2VLConfig = vllm_config.model_config.hf_config
        quant_config = vllm_config.quant_config
        multimodal_config = vllm_config.model_config.multimodal_config

        self.use_data_parallel = multimodal_config.mm_encoder_tp_mode == "data"

        self.config = config
        self.multimodal_config = multimodal_config

        # Build the vision tower only if images or videos may appear.
        if multimodal_config.get_limit_per_prompt(
            "image"
        ) or multimodal_config.get_limit_per_prompt("video"):
            self.visual = Qwen2VisionTransformer(
                config.vision_config,
                norm_eps=getattr(config, "rms_norm_eps", 1e-6),
                quant_config=quant_config,
                prefix=maybe_prefix(prefix, "visual"),
                use_data_parallel=self.use_data_parallel,
                # multimodal_config is already dereferenced above, so it is
                # never None here.
                attn_backend_override=multimodal_config.mm_encoder_attn_backend,
            )
        else:
            self.visual = None

        self.language_model = init_vllm_registered_model(
            vllm_config=vllm_config,
            prefix=maybe_prefix(prefix, "language_model"),
            architectures=["Qwen2ForCausalLM"],
        )

        self.make_empty_intermediate_tensors = (
            self.language_model.make_empty_intermediate_tensors
        )

    def _parse_and_validate_image_input(
        self, **kwargs: object
    ) -> Qwen2VLImageInputs | None:
        """Build typed image inputs from kwargs; pixel values take precedence."""
        pixel_values = kwargs.pop("pixel_values", None)
        image_embeds = kwargs.pop("image_embeds", None)
        image_grid_thw = kwargs.pop("image_grid_thw", None)

        if pixel_values is None and image_embeds is None:
            return None

        if pixel_values is not None:
            return Qwen2VLImagePixelInputs(
                type="pixel_values",
                pixel_values=pixel_values,
                image_grid_thw=image_grid_thw,
            )

        if image_embeds is not None:
            return Qwen2VLImageEmbeddingInputs(
                type="image_embeds",
                image_embeds=image_embeds,
                image_grid_thw=image_grid_thw,
            )

    def _parse_and_validate_video_input(
        self, **kwargs: object
    ) -> Qwen2VLVideoInputs | None:
        """Build typed video inputs from kwargs; pixel values take precedence."""
        pixel_values_videos = kwargs.pop("pixel_values_videos", None)
        video_embeds = kwargs.pop("video_embeds", None)
        video_grid_thw = kwargs.pop("video_grid_thw", None)

        if pixel_values_videos is None and video_embeds is None:
            return None

        if pixel_values_videos is not None:
            return Qwen2VLVideoPixelInputs(
                type="pixel_values_videos",
                pixel_values_videos=pixel_values_videos,
                video_grid_thw=video_grid_thw,
            )

        if video_embeds is not None:
            return Qwen2VLVideoEmbeddingInputs(
                type="video_embeds",
                video_embeds=video_embeds,
                video_grid_thw=video_grid_thw,
            )

    def _split_by_grid(
        self, embeds: torch.Tensor, grid_thw: torch.Tensor
    ) -> tuple[torch.Tensor, ...]:
        """Split concatenated vision embeddings into one tensor per item.

        Item i contributes ``prod(grid_thw[i]) // merge_size // merge_size``
        rows.
        """
        merge_size = self.visual.spatial_merge_size
        sizes = (grid_thw.prod(-1) // merge_size // merge_size).tolist()
        return embeds.split(sizes)

    def _process_image_input(
        self, image_input: Qwen2VLImageInputs
    ) -> tuple[torch.Tensor, ...]:
        grid_thw = image_input["image_grid_thw"]
        assert grid_thw.ndim == 2

        if image_input["type"] == "image_embeds":
            image_embeds = image_input["image_embeds"]
        else:
            pixel_values = image_input["pixel_values"]

            if self.use_data_parallel:
                return run_dp_sharded_mrope_vision_model(
                    self.visual, pixel_values, grid_thw.tolist(), rope_type="rope_3d"
                )
            else:
                image_embeds = self.visual(pixel_values, grid_thw=grid_thw)

        # Split concatenated embeddings for each image item.
        return self._split_by_grid(image_embeds, grid_thw)

    def _process_video_input(
        self, video_input: Qwen2VLVideoInputs
    ) -> tuple[torch.Tensor, ...]:
        grid_thw = video_input["video_grid_thw"]
        assert grid_thw.ndim == 2

        if video_input["type"] == "video_embeds":
            video_embeds = video_input["video_embeds"]
        else:
            pixel_values_videos = video_input["pixel_values_videos"]
            if self.use_data_parallel:
                grid_thw_list = grid_thw.tolist()
                return run_dp_sharded_mrope_vision_model(
                    self.visual, pixel_values_videos, grid_thw_list, rope_type="rope_3d"
                )
            else:
                video_embeds = self.visual(pixel_values_videos, grid_thw=grid_thw)

        # Split concatenated embeddings for each video item.
        return self._split_by_grid(video_embeds, grid_thw)

    def _parse_and_validate_multimodal_inputs(self, **kwargs: object) -> dict:
        modalities = {}

        # Preserve the order of modalities if there are multiple of them
        # from the order of kwargs.
        for input_key in kwargs:
            if (
                input_key in ("pixel_values", "image_embeds")
                and "images" not in modalities
            ):
                modalities["images"] = self._parse_and_validate_image_input(**kwargs)
            if (
                input_key in ("pixel_values_videos", "video_embeds")
                and "videos" not in modalities
            ):
                modalities["videos"] = self._parse_and_validate_video_input(**kwargs)

        return modalities

    def get_language_model(self) -> torch.nn.Module:
        return self.language_model

    def embed_multimodal(self, **kwargs: object) -> MultiModalEmbeddings:
        modalities = self._parse_and_validate_multimodal_inputs(**kwargs)
        if not modalities:
            return []

        # The result multimodal_embeddings is tuple of tensors, with each
        # tensor corresponding to a multimodal data item (image or video).
        multimodal_embeddings: tuple[torch.Tensor, ...] = ()

        # NOTE: It is important to iterate over the keys in this dictionary
        # to preserve the order of the modalities.
        for modality in modalities:
            if modality == "images":
                image_input = modalities["images"]
                image_embeddings = self._process_image_input(image_input)
                multimodal_embeddings += tuple(image_embeddings)
            if modality == "videos":
                video_input = modalities["videos"]
                video_embeddings = self._process_video_input(video_input)
                multimodal_embeddings += tuple(video_embeddings)

        return multimodal_embeddings

    def forward(
        self,
        input_ids: torch.Tensor,
        positions: torch.Tensor,
        intermediate_tensors: IntermediateTensors | None = None,
        inputs_embeds: torch.Tensor | None = None,
        **kwargs: object,
    ) -> torch.Tensor | IntermediateTensors:
        """Run forward pass for Qwen2-VL.

        Args:
            input_ids: Flattened (concatenated) input_ids corresponding to a
                batch.
            positions: Flattened (concatenated) position ids corresponding to a
                batch.
                **NOTE**: If mrope is enabled (default setting for Qwen2-VL
                opensource models), the shape will be `(3, seq_len)`,
                otherwise it will be `(seq_len,)`.
            intermediate_tensors: Intermediate tensors from prior forward pass.
            inputs_embeds: Optional tensor of input embeddings.
        """

        # When intermediate tensors are supplied, any inputs_embeds are ignored.
        if intermediate_tensors is not None:
            inputs_embeds = None

        hidden_states = self.language_model.model(
            input_ids=input_ids,
            positions=positions,
            intermediate_tensors=intermediate_tensors,
            inputs_embeds=inputs_embeds,
        )
        return hidden_states

    def compute_logits(
        self,
        hidden_states: torch.Tensor,
    ) -> torch.Tensor | None:
        return self.language_model.compute_logits(hidden_states)

    def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]:
        # Skip vision-tower weights when no vision tower was built.
        skip_prefixes = []
        if self.visual is None:
            skip_prefixes.extend(["visual."])
        loader = AutoWeightsLoader(self, skip_prefixes=skip_prefixes)
        return loader.load_weights(weights, mapper=self.hf_to_vllm_mapper)

    def get_mm_mapping(self) -> MultiModelKeys:
        """
        Get the module prefix in multimodal models
        """
        return MultiModelKeys.from_string_field(
            language_model="language_model",
            connector="visual.merger.",
            tower_model="visual.",
        )
1586
1587
1588
1589
1590
1591
1592
1593
1594


class Tarsier2MultiModalProcessor(Qwen2VLMultiModalProcessor):
    """Tarsier2 reuses the Qwen2-VL multimodal processor without changes."""


class Tarsier2ImageProcessor(Qwen2VLImageProcessor):
    """Qwen2-VL image processor that understands Tarsier2's ``size`` format.

    When ``size`` contains both ``min_pixels`` and ``max_pixels`` (the
    Tarsier2-specific format), it is rewritten as ``shortest_edge`` /
    ``longest_edge`` before the parent constructor runs. Any other ``size``
    value is passed through unchanged.
    """

    def __init__(
        self,
        size: dict[str, int] | None = None,
        **kwargs,
    ) -> None:
        uses_pixel_bounds = size is not None and all(
            key in size for key in ("min_pixels", "max_pixels")
        )
        if uses_pixel_bounds:
            size = {
                "shortest_edge": size["min_pixels"],
                "longest_edge": size["max_pixels"],
            }
        super().__init__(size=size, **kwargs)


class Tarsier2Processor(Qwen2VLProcessor):
    """Qwen2-VL processor assembled from a Tarsier2 image-processor config.

    The same ``vision_config`` mapping configures both the image processor
    and the video processor; no chat template is set.
    """

    def __init__(
        self,
        vision_config: dict,
        tokenizer: AnyTokenizer,
        **kwargs,
    ):
        image_processor = Tarsier2ImageProcessor(**vision_config)
        self.image_processor = image_processor
        video_processor = Qwen2VLVideoProcessor(**vision_config)
        super().__init__(
            image_processor=image_processor,
            tokenizer=tokenizer,
            video_processor=video_processor,
            chat_template=None,
            **kwargs,
        )
1624
1625
1626
1627
1628


class Tarsier2ProcessingInfo(Qwen2VLProcessingInfo):
    """Processing info for Tarsier2 checkpoints.

    The HF config is re-read from the checkpoint as a ``Qwen2VLConfig``. The
    processors are built from the checkpoint's image-processor config.
    """

    def get_hf_config(self) -> Qwen2VLConfig:
        # Reload the checkpoint's config explicitly as a Qwen2VLConfig.
        model_path = self.ctx.model_config.model
        correct_config = Qwen2VLConfig.from_pretrained(model_path)

        return correct_config

    def get_hf_processor(self, **kwargs: object) -> Tarsier2Processor:
        return Tarsier2Processor(
            vision_config=self.ctx.get_hf_image_processor_config(),
            tokenizer=self.get_tokenizer(),
            **kwargs,
        )

    def get_image_processor(self, **kwargs: object) -> Tarsier2ImageProcessor:
        """Return a Tarsier2 image processor built from the checkpoint config.

        ``**kwargs`` is accepted for signature compatibility with
        ``Qwen2VLProcessingInfo.get_image_processor``. Without it,
        ``Qwen2VLMultiModalProcessor._get_prompt_updates`` raises TypeError
        whenever it passes multimodal processor kwargs. The kwargs are not
        applied here.
        """
        return Tarsier2ImageProcessor(**self.ctx.get_hf_image_processor_config())
1642
1643


1644
1645
1646
1647
1648
@MULTIMODAL_REGISTRY.register_processor(
    Tarsier2MultiModalProcessor,
    info=Tarsier2ProcessingInfo,
    dummy_inputs=Qwen2VLDummyInputsBuilder,
)
class Tarsier2ForConditionalGeneration(Qwen2VLForConditionalGeneration):
    """Tarsier2: a Qwen2-VL model shipped with a LLaVA-style config.

    Only the checkpoint prefix mapping differs from the parent
    (``vision_tower.`` -> ``visual.``). ``load_weights`` is inherited; the
    parent's implementation already applies ``self.hf_to_vllm_mapper``.
    """

    hf_to_vllm_mapper = WeightsMapper(
        orig_to_new_prefix={
            "vision_tower.": "visual.",
        }
    )

    def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
        # Tarsier2 uses llava as model_type, which will create a Qwen2VLConfig
        # as text_config, we need to reconstruct Qwen2VLConfig from LlavaConfig.
        # NOTE: this replaces vllm_config.model_config.hf_config in place.
        config = vllm_config.model_config.hf_config
        qwen2vl_config = config.text_config
        qwen2vl_config.architectures = config.architectures
        vllm_config.model_config.hf_config = qwen2vl_config
        super().__init__(vllm_config=vllm_config, prefix=prefix)