qwen2_vl.py 56.8 KB
Newer Older
1
# SPDX-License-Identifier: Apache-2.0
2
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
3

4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
# Adapted from
# https://github.com/huggingface/transformers/blob/19e6e80e10118f855137b90740936c0b11ac397f/src/transformers/models/qwen2_vl/modeling_qwen2_vl.py
# Copyright 2024 The Qwen team.
# Copyright 2023 The vLLM team.
# Copyright 2022 EleutherAI and the HuggingFace Inc. team. All rights reserved.
#
# This code is based on EleutherAI's GPT-NeoX library and the GPT-NeoX
# and OPT implementations in this library. It has been modified from its
# original forms to accommodate minor architectural differences compared
# to GPT-NeoX and OPT used by the Meta AI team that trained the model.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Inference-only Qwen2-VL model compatible with HuggingFace weights."""
27

28
import math
29
from collections.abc import Callable, Iterable, Mapping, Sequence
30
from functools import partial
31
from typing import Annotated, Any, Literal, TypeAlias
32
33
34
35
36

import torch
import torch.nn as nn
import torch.nn.functional as F
from einops import rearrange, repeat
37
from transformers import BatchFeature
38
from transformers.models.qwen2_vl import Qwen2VLImageProcessor, Qwen2VLProcessor
39
from transformers.models.qwen2_vl.configuration_qwen2_vl import (
40
41
42
    Qwen2VLConfig,
    Qwen2VLVisionConfig,
)
43
from transformers.models.qwen2_vl.image_processing_qwen2_vl import smart_resize
44
from transformers.models.qwen2_vl.video_processing_qwen2_vl import Qwen2VLVideoProcessor
45

46
from vllm.attention.backends.registry import AttentionBackendEnum
47
48
49
50
from vllm.attention.layer import (
    check_upstream_fa_availability,
    maybe_get_vit_flash_attn_backend,
)
51
from vllm.config import VllmConfig
52
from vllm.config.multimodal import BaseDummyOptions
53
from vllm.distributed import parallel_state
54
55
56
from vllm.distributed import utils as dist_utils
from vllm.logger import init_logger
from vllm.model_executor.layers.activation import QuickGELU
57
58
59
60
61
from vllm.model_executor.layers.linear import (
    ColumnParallelLinear,
    ReplicatedLinear,
    RowParallelLinear,
)
62
from vllm.model_executor.layers.quantization import QuantizationConfig
63
from vllm.model_executor.layers.rotary_embedding.common import (
64
65
    dispatch_rotary_emb_function,
)
66
from vllm.model_executor.model_loader.weight_utils import default_weight_loader
67
from vllm.model_executor.models.module_mapping import MultiModelKeys
68
from vllm.multimodal import MULTIMODAL_REGISTRY
69
70
71
72
from vllm.multimodal.inputs import (
    ImageItem,
    ModalityData,
    MultiModalDataDict,
73
    MultiModalFeatureSpec,
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
    MultiModalFieldConfig,
    MultiModalKwargsItems,
    VideoItem,
)
from vllm.multimodal.parse import (
    DictEmbeddingItems,
    ImageSize,
    ModalityDataItems,
    MultiModalDataItems,
    MultiModalDataParser,
)
from vllm.multimodal.processing import (
    BaseMultiModalProcessor,
    BaseProcessingInfo,
    PromptReplacement,
    PromptUpdate,
)
91
from vllm.multimodal.profiling import BaseDummyInputsBuilder
92
from vllm.sequence import IntermediateTensors
93
from vllm.transformers_utils.tokenizer import AnyTokenizer
94
from vllm.utils.tensor_schema import TensorSchema, TensorShape
95

96
97
98
99
100
101
102
103
104
105
106
107
108
from .interfaces import (
    MultiModalEmbeddings,
    SupportsLoRA,
    SupportsMRoPE,
    SupportsMultiModal,
    SupportsPP,
)
from .utils import (
    AutoWeightsLoader,
    WeightsMapper,
    init_vllm_registered_model,
    maybe_prefix,
)
109
110
111
112
113
from .vision import (
    conv3d_to_linear_weight,
    get_vit_attn_backend,
    run_dp_sharded_mrope_vision_model,
)
114

115
116
# Module-level logger, named after this module.
logger = init_logger(__name__)

117
# For profile run
118
# NOTE(review): judging by the name, a cap on the number of frames per video;
# its use sites are outside this chunk -- confirm before changing the value.
_MAX_FRAMES_PER_VIDEO = 14
119

120
121
122
# === Vision Inputs === #


123
class Qwen2VLImagePixelInputs(TensorSchema):
    """
    Schema for images supplied as raw pixel patches.

    Dimensions:
        - np: The total number of patches over each image over each prompt in
              the batch
        - ni: Number of images
        - cps: Number of channels * patch_size * patch_size

    Historical context:
        - pixel_values shape: (num_patches, num_channels * patch_size *
          patch_size)
        - image_grid_thw shape: (num_images, 3) in (grid_t, grid_h, grid_w)
          format
    """

    # Discriminator that tells this schema apart from the embedding variant.
    type: Literal["pixel_values"]

    # Flattened patches of all images, one row per patch.
    pixel_values: Annotated[torch.Tensor, TensorShape("np", "cps")]

    # One (grid_t, grid_h, grid_w) row per image.
    image_grid_thw: Annotated[torch.Tensor, TensorShape("ni", 3)]


class Qwen2VLImageEmbeddingInputs(TensorSchema):
    """
    Schema for images supplied as precomputed embeddings.

    Dimensions:
        - nf: Number of image features
        - hs: Hidden size
        - ni: Number of images

    Historical context:
        - image_embeds shape: (num_image_features, hidden_size)
        - num_image_features varies based on the number and resolution of the
          images.
        - hidden_size must match the hidden size of language model backbone.
        - image_grid_thw shape: (num_images, 3) in (grid_t, grid_h, grid_w)
          format
    """

    # Discriminator that tells this schema apart from the pixel variant.
    type: Literal["image_embeds"]

    # Image features of all images, one row per feature.
    image_embeds: Annotated[torch.Tensor, TensorShape("nf", "hs")]

    # One (grid_t, grid_h, grid_w) row per image.
    image_grid_thw: Annotated[torch.Tensor, TensorShape("ni", 3)]
178
179


180
# An image input is either raw pixel patches or precomputed embeddings.
Qwen2VLImageInputs: TypeAlias = Qwen2VLImagePixelInputs | Qwen2VLImageEmbeddingInputs
181
182


183
184
185
186
187
class Qwen2VLVideoPixelInputs(TensorSchema):
    """
    Schema for videos supplied as raw pixel patches.

    Dimensions:
        - np: The total number of patches over each video over each prompt in
              the batch
        - ctps: Number of channels * temporal_patch_size * patch_size *
          patch_size
        - nv: Number of videos

    Historical context:
        - pixel_values_videos shape: (num_patches, num_channels *
          temporal_patch_size * patch_size * patch_size)
        - video_grid_thw shape: (num_videos, 3) in (grid_t, grid_h, grid_w)
          format
    """

    # Discriminator that tells this schema apart from the embedding variant.
    type: Literal["pixel_values_videos"]

    # Flattened patches of all videos, one row per patch.
    pixel_values_videos: Annotated[torch.Tensor, TensorShape("np", "ctps")]

    # One (grid_t, grid_h, grid_w) row per video.
    video_grid_thw: Annotated[torch.Tensor, TensorShape("nv", 3)]
210
211


212
213
214
215
216
217
class Qwen2VLVideoEmbeddingInputs(TensorSchema):
    """
    Schema for videos supplied as precomputed embeddings.

    Dimensions:
        - nf: Number of video features
        - hs: Hidden size
        - nv: Number of videos

    Historical context:
        - video_embeds shape: (num_video_features, hidden_size)
        - num_video_features varies based on the number and resolution of the
          videos.
        - hidden_size must match the hidden size of language model backbone.
        - video_grid_thw shape: (num_videos, 3) in (grid_t, grid_h, grid_w)
          format
    """

    # Discriminator that tells this schema apart from the pixel variant.
    type: Literal["video_embeds"]

    # Video features of all videos, one row per feature.
    video_embeds: Annotated[torch.Tensor, TensorShape("nf", "hs")]

    # One (grid_t, grid_h, grid_w) row per video.
    video_grid_thw: Annotated[torch.Tensor, TensorShape("nv", 3)]
239
240


241
# A video input is either raw pixel patches or precomputed embeddings.
Qwen2VLVideoInputs: TypeAlias = Qwen2VLVideoPixelInputs | Qwen2VLVideoEmbeddingInputs
242

243
244
245
246
247
248
249
# === Vision Encoder === #


class Qwen2VisionMLP(nn.Module):
    """Feed-forward block of the vision encoder: ``fc1`` -> activation -> ``fc2``.

    Args:
        in_features: Width of the input; also the width of the output.
        hidden_features: Width of the intermediate activation.
        act_layer: Zero-argument factory for the activation module.
        quant_config: Forwarded to both linear layers.
        prefix: Module-name prefix; ``.fc1`` / ``.fc2`` are appended to it.
        use_data_parallel: Forwarded as ``disable_tp`` to both linear layers.
    """

    def __init__(
        self,
        in_features: int,
        hidden_features: int,
        act_layer: type[nn.Module] = QuickGELU,
        quant_config: QuantizationConfig | None = None,
        prefix: str = "",
        use_data_parallel: bool = False,
    ):
        super().__init__()
        # Keyword arguments shared by the up- and down-projection.
        linear_kwargs = dict(
            quant_config=quant_config,
            disable_tp=use_data_parallel,
        )
        self.fc1 = ColumnParallelLinear(
            in_features,
            hidden_features,
            prefix=f"{prefix}.fc1",
            **linear_kwargs,
        )
        self.act = act_layer()
        self.fc2 = RowParallelLinear(
            hidden_features,
            in_features,
            prefix=f"{prefix}.fc2",
            **linear_kwargs,
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Apply ``fc1``, the activation and ``fc2`` in sequence."""
        # Both linear layers return a 2-tuple; only the first item is used.
        hidden, _ = self.fc1(x)
        projected, _ = self.fc2(self.act(hidden))
        return projected


def rotate_half(x: torch.Tensor, interleaved: bool = False) -> torch.Tensor:
    """Return the rotary-embedding companion of ``x`` along its last dim.

    With ``interleaved=False`` the last dimension is split into two halves
    ``(a, b)`` and the result is ``(-b, a)``. With ``interleaved=True`` every
    adjacent pair ``(even, odd)`` is replaced by ``(-odd, even)`` in place.
    """
    if interleaved:
        evens = x[..., ::2]
        odds = x[..., 1::2]
        # Stack pairs on a new trailing axis, then merge that axis back into
        # the last dimension so the pairs stay adjacent.
        return torch.stack((-odds, evens), dim=-1).flatten(-2)

    first_half, second_half = x.chunk(2, dim=-1)
    return torch.cat((-second_half, first_half), dim=-1)
289
290


291
292
293
def apply_rotary_emb_torch(
    x: torch.Tensor, cos: torch.Tensor, sin: torch.Tensor, interleaved: bool = False
) -> torch.Tensor:
    """
    Apply rotary position embedding using plain PyTorch/einops operations.

    Only the leading ``2 * cos.shape[-1]`` features of the last dimension are
    rotated; any remaining features are passed through unchanged.

    x: (batch_size, seqlen, nheads, headdim)
    cos, sin: (seqlen, rotary_dim / 2) or (batch_size, seqlen, rotary_dim / 2)
    """
    rotary_width = 2 * cos.shape[-1]
    assert rotary_width <= x.shape[-1]

    # Duplicate cos/sin up to the full rotary width and insert a singleton
    # axis before the last dim; the layout of the duplication follows the
    # same convention (interleaved or half-split) as rotate_half.
    expand_pattern = (
        "... d -> ... 1 (d 2)" if interleaved else "... d -> ... 1 (2 d)"
    )
    cos_full = repeat(cos, expand_pattern)
    sin_full = repeat(sin, expand_pattern)

    rotary_part = x[..., :rotary_width]
    passthrough = x[..., rotary_width:]
    rotated = (
        rotary_part * cos_full + rotate_half(rotary_part, interleaved) * sin_full
    )
    return torch.cat([rotated, passthrough], dim=-1)


315
316
def apply_rotary_pos_emb_vision(t: torch.Tensor, freqs: torch.Tensor) -> torch.Tensor:
    """Rotate ``t`` by the angles in ``freqs`` and return it in ``t``'s dtype.

    The rotation is computed on a float32 copy of ``t`` by whichever
    implementation ``dispatch_rotary_emb_function`` selects, with
    ``apply_rotary_emb_torch`` passed as its default.
    """
    rotate = dispatch_rotary_emb_function(default=apply_rotary_emb_torch)
    rotated = rotate(t.float(), freqs.cos(), freqs.sin())
    # Cast back so callers get the same dtype they passed in.
    return rotated.type_as(t)


class Qwen2VisionAttention(nn.Module):
    """Multi-head self-attention over packed, variable-length vision sequences.

    A fused ``qkv`` projection feeds one of several attention kernels
    (flash-attention variants, PyTorch SDPA or xFormers) selected at
    construction time, followed by the ``proj`` output projection. Sequence
    boundaries inside the packed input are given by ``cu_seqlens`` so that
    attention is computed separately for each sequence.

    Args:
        embed_dim: Input and output feature size.
        num_heads: Total number of attention heads, before tensor-parallel
            partitioning.
        projection_size: Combined width of all heads; q, k and v each have
            this width.
        quant_config: Forwarded to the linear layers.
        prefix: Module-name prefix; ``.qkv`` / ``.proj`` are appended to it.
        use_data_parallel: When true, ``tp_size`` is forced to 1 and
            ``disable_tp`` is set on the linear layers.
        attn_backend_override: Forwarded to the backend-selection helpers.

    Raises:
        RuntimeError: If the selected backend is not one of FLASH_ATTN,
            TORCH_SDPA, XFORMERS or ROCM_AITER_FA.
    """

    def __init__(
        self,
        embed_dim: int,
        num_heads: int,
        projection_size: int,
        quant_config: QuantizationConfig | None = None,
        prefix: str = "",
        use_data_parallel: bool = False,
        attn_backend_override: AttentionBackendEnum | None = None,
    ) -> None:
        super().__init__()
        # Per attention head and per partition values.
        self.tp_size = (
            1
            if use_data_parallel
            else parallel_state.get_tensor_model_parallel_world_size()
        )
        self.tp_rank = parallel_state.get_tensor_model_parallel_rank()
        self.hidden_size_per_attention_head = dist_utils.divide(
            projection_size, num_heads
        )
        self.num_attention_heads_per_partition = dist_utils.divide(
            num_heads, self.tp_size
        )

        self.qkv = ColumnParallelLinear(
            input_size=embed_dim,
            output_size=3 * projection_size,
            quant_config=quant_config,
            prefix=f"{prefix}.qkv",
            disable_tp=use_data_parallel,
        )
        self.proj = RowParallelLinear(
            input_size=projection_size,
            output_size=embed_dim,
            quant_config=quant_config,
            prefix=f"{prefix}.proj",
            disable_tp=use_data_parallel,
        )

        # Detect attention implementation.
        self.attn_backend = get_vit_attn_backend(
            head_size=self.hidden_size_per_attention_head,
            dtype=torch.get_default_dtype(),
            attn_backend_override=attn_backend_override,
        )
        self.use_upstream_fa = False

        # The helper may replace the backend chosen above and also hands back
        # the varlen flash-attention callable used in forward().
        self.attn_backend, self.flash_attn_varlen_func = (
            maybe_get_vit_flash_attn_backend(
                self.attn_backend,
                self.use_upstream_fa,
                attn_backend_override=attn_backend_override,
            )
        )

        if self.attn_backend not in {
            AttentionBackendEnum.FLASH_ATTN,
            AttentionBackendEnum.TORCH_SDPA,
            AttentionBackendEnum.XFORMERS,
            AttentionBackendEnum.ROCM_AITER_FA,
        }:
            raise RuntimeError(
                f"Qwen2-VL does not support {self.attn_backend} backend now."
            )

        self.is_flash_attn_backend = self.attn_backend in {
            AttentionBackendEnum.FLASH_ATTN,
            AttentionBackendEnum.ROCM_AITER_FA,
        }

    def split_qkv(self, qkv: torch.Tensor) -> tuple[torch.Tensor, ...]:
        """Split a fused qkv activation into per-head q, k and v tensors."""
        # [s, b, 3 * head * head_dim]
        seq_len, bs, _ = qkv.shape

        # [s, b, 3 * head * head_dim] -> 3 * [s, b, head * head_dim]
        q, k, v = qkv.chunk(3, dim=2)

        # 3 * [s, b, head * head_dim] -> 3 * [s, b, head, head_dim]
        new_shape = (
            seq_len,
            bs,
            self.num_attention_heads_per_partition,
            self.hidden_size_per_attention_head,
        )
        q, k, v = (x.view(*new_shape) for x in (q, k, v))
        return q, k, v

    def forward(
        self,
        x: torch.Tensor,
        cu_seqlens: torch.Tensor,
        rotary_pos_emb: torch.Tensor,
        max_seqlen: int | None = None,  # Only used for Flash Attention
        seqlens: list[int] | None = None,  # Only used for xFormers
    ) -> torch.Tensor:
        """Run self-attention followed by the output projection.

        Args:
            x: Input activations, laid out as [s, b, c].
            cu_seqlens: Cumulative sequence boundaries along the sequence
                axis, starting at 0.
            rotary_pos_emb: Rotary angles applied to q and k; skipped when
                ``None``.
            max_seqlen: Longest sequence length; consumed only by the
                flash-attention path.
            seqlens: Per-sequence lengths; consumed only by the xFormers path.

        Returns:
            The projected attention output, laid out as [s, b, c].
        """
        # [s, b, c] --> [s, b, 3 * head * head_dim]
        x, _ = self.qkv(x)

        # [s, b, 3 * head * head_dim] -> 3 * [s, b, head, head_dim]
        q, k, v = self.split_qkv(x)
        batch_size = q.shape[1]

        q, k, v = (rearrange(t, "s b ... -> b s ...") for t in (q, k, v))
        if rotary_pos_emb is not None:
            # Rotate q and k in a single call by stacking them on the batch
            # axis: [2 * b, s, heads, head_dim]
            qk_concat = torch.cat([q, k], dim=0)
            qk_rotated = apply_rotary_pos_emb_vision(qk_concat, rotary_pos_emb)
            q, k = torch.chunk(qk_rotated, 2, dim=0)

        if self.is_flash_attn_backend:
            q, k, v = (rearrange(t, "b s ... -> (b s) ...") for t in [q, k, v])

            output = self.flash_attn_varlen_func(
                q,
                k,
                v,
                cu_seqlens_q=cu_seqlens,
                cu_seqlens_k=cu_seqlens,
                max_seqlen_q=max_seqlen,
                max_seqlen_k=max_seqlen,
                dropout_p=0.0,
                causal=False,
            )

            context_layer = rearrange(
                output, "(b s) h d -> s b (h d)", b=batch_size
            ).contiguous()
        elif self.attn_backend == AttentionBackendEnum.TORCH_SDPA:
            # Execute attention entry by entry for speed & less VRAM.
            from vllm.platforms import current_platform

            if current_platform.is_rocm():
                q = q.contiguous()
                k = k.contiguous()
                v = v.contiguous()

            # Convert the boundaries to Python ints once. Slicing with 0-d
            # tensor entries of cu_seqlens would convert each bound
            # separately (two conversions per sequence).
            boundaries = cu_seqlens.tolist()
            outputs = []
            for start_idx, end_idx in zip(boundaries[:-1], boundaries[1:]):
                q_i, k_i, v_i = (
                    rearrange(t[:, start_idx:end_idx], "b s h d -> b h s d")
                    for t in (q, k, v)
                )
                output_i = F.scaled_dot_product_attention(q_i, k_i, v_i, dropout_p=0.0)
                output_i = rearrange(output_i, "b h s d -> b s h d ")
                outputs.append(output_i)
            context_layer = torch.cat(outputs, dim=1)
            context_layer = rearrange(
                context_layer, "b s h d -> s b (h d)"
            ).contiguous()
        elif self.attn_backend == AttentionBackendEnum.XFORMERS:
            from xformers import ops as xops
            from xformers.ops.fmha.attn_bias import BlockDiagonalMask

            # Block-diagonal bias keeps attention within each sequence.
            attn_bias = BlockDiagonalMask.from_seqlens(
                q_seqlen=seqlens, kv_seqlen=None, device=q.device
            )

            context_layer = xops.memory_efficient_attention_forward(
                q, k, v, attn_bias=attn_bias, p=0, scale=None
            )
            context_layer = rearrange(
                context_layer, "b s h d -> s b (h d)"
            ).contiguous()

        output, _ = self.proj(context_layer)
        return output


class Qwen2VisionBlock(nn.Module):
    """Pre-norm transformer block: attention and MLP, each with a residual.

    Args:
        dim: Feature width of the block.
        num_heads: Number of attention heads.
        mlp_ratio: The MLP hidden width is ``int(dim * mlp_ratio)``.
        act_layer: Activation factory forwarded to the MLP.
        norm_layer: Factory called with ``dim`` to build each norm; defaults
            to ``nn.LayerNorm`` with ``eps=1e-6``.
        quant_config: Forwarded to the attention and MLP sub-modules.
        prefix: Module-name prefix; ``.attn`` / ``.mlp`` are appended to it.
        use_data_parallel: Forwarded to the attention and MLP sub-modules.
        attn_backend_override: Forwarded to the attention sub-module.
    """

    def __init__(
        self,
        dim: int,
        num_heads: int,
        mlp_ratio: float,
        act_layer: type[nn.Module] = QuickGELU,
        norm_layer: Callable[[int], nn.Module] | None = None,
        quant_config: QuantizationConfig | None = None,
        prefix: str = "",
        use_data_parallel: bool = False,
        attn_backend_override: AttentionBackendEnum | None = None,
    ) -> None:
        super().__init__()
        make_norm = (
            norm_layer if norm_layer is not None else partial(nn.LayerNorm, eps=1e-6)
        )
        self.norm1 = make_norm(dim)
        self.norm2 = make_norm(dim)

        self.attn = Qwen2VisionAttention(
            embed_dim=dim,
            num_heads=num_heads,
            projection_size=dim,
            quant_config=quant_config,
            prefix=f"{prefix}.attn",
            use_data_parallel=use_data_parallel,
            attn_backend_override=attn_backend_override,
        )
        self.mlp = Qwen2VisionMLP(
            dim,
            int(dim * mlp_ratio),
            act_layer=act_layer,
            quant_config=quant_config,
            prefix=f"{prefix}.mlp",
            use_data_parallel=use_data_parallel,
        )

    def forward(
        self,
        x: torch.Tensor,
        cu_seqlens: torch.Tensor,
        rotary_pos_emb: torch.Tensor,
        max_seqlen: int | None = None,  # Only used for Flash Attention
        seqlens: list[int] | None = None,  # Only used for xFormers
    ) -> torch.Tensor:
        """Apply the attention and MLP sub-layers, each added to its input."""
        attn_out = self.attn(
            self.norm1(x),
            cu_seqlens=cu_seqlens,
            rotary_pos_emb=rotary_pos_emb,
            max_seqlen=max_seqlen,
            seqlens=seqlens,
        )
        x = x + attn_out
        return x + self.mlp(self.norm2(x))


class Qwen2VisionPatchEmbed(nn.Module):
    """Project flattened patches to the embedding width with one linear layer.

    Args:
        patch_size: Spatial side length of a patch.
        temporal_patch_size: Number of frames covered by a patch.
        in_channels: Number of input channels.
        embed_dim: Output embedding width.
    """

    def __init__(
        self,
        patch_size: int = 14,
        temporal_patch_size: int = 2,
        in_channels: int = 3,
        embed_dim: int = 1152,
    ) -> None:
        super().__init__()
        self.patch_size = patch_size
        self.temporal_patch_size = temporal_patch_size
        self.embed_dim = embed_dim

        # Each input row is one flattened patch of
        # in_channels * temporal_patch_size * patch_size * patch_size values.
        patch_volume = temporal_patch_size * patch_size * patch_size
        self.proj = ReplicatedLinear(
            in_channels * patch_volume,
            embed_dim,
            bias=False,
            return_bias=False,
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Return the projection of ``x``."""
        return self.proj(x)


class Qwen2VisionPatchMerger(nn.Module):
    """Merge groups of patch features and project them to ``d_model``.

    The input is normalized, reshaped so that each row holds
    ``spatial_merge_size ** 2`` consecutive patch features, and passed through
    a linear -> GELU -> linear stack.

    Args:
        d_model: Output feature width.
        context_dim: Feature width of a single incoming patch.
        norm_layer: Factory called with ``context_dim`` to build the norm;
            defaults to ``nn.LayerNorm`` with ``eps=1e-6``.
        spatial_merge_size: Side length of the merge window.
        quant_config: Forwarded to both linear layers.
        prefix: Module-name prefix; ``.mlp.0`` / ``.mlp.2`` are appended.
        use_data_parallel: Forwarded as ``disable_tp`` to both linear layers.
    """

    def __init__(
        self,
        d_model: int,
        context_dim: int,
        norm_layer: Callable[[int], nn.Module] | None = None,
        spatial_merge_size: int = 2,
        quant_config: QuantizationConfig | None = None,
        prefix: str = "",
        use_data_parallel: bool = False,
    ) -> None:
        super().__init__()
        self.hidden_size = context_dim * (spatial_merge_size**2)
        make_norm = (
            norm_layer if norm_layer is not None else partial(nn.LayerNorm, eps=1e-6)
        )
        self.ln_q = make_norm(context_dim)

        expand = ColumnParallelLinear(
            self.hidden_size,
            self.hidden_size,
            bias=True,
            quant_config=quant_config,
            prefix=f"{prefix}.mlp.0",
            disable_tp=use_data_parallel,
        )
        activation = nn.GELU()
        project = RowParallelLinear(
            self.hidden_size,
            d_model,
            bias=True,
            quant_config=quant_config,
            prefix=f"{prefix}.mlp.2",
            disable_tp=use_data_parallel,
        )
        # Indices 0 and 2 of the list match the ".mlp.0" / ".mlp.2" prefixes.
        self.mlp = nn.ModuleList([expand, activation, project])

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Normalize, regroup and project ``x``."""
        fc1, activation, fc2 = self.mlp
        merged = self.ln_q(x).view(-1, self.hidden_size)
        # Both linear layers return a 2-tuple; only the first item is used.
        hidden, _ = fc1(merged)
        out, _ = fc2(activation(hidden))
        return out


class Qwen2VisionRotaryEmbedding(nn.Module):
    """Cached table of rotary angles ``position * inv_freq``.

    ``forward(seqlen)`` returns the first ``seqlen`` rows of a table with
    ``dim // 2`` columns (for even ``dim``). The table is rebuilt at twice the
    requested length whenever a request exceeds the cached length, so
    repeated calls with slowly growing lengths do not rebuild every time.

    Args:
        dim: Rotary dimension; one frequency is produced per even index in
            ``range(0, dim, 2)``.
        theta: Base of the geometric frequency progression.
    """

    def __init__(self, dim: int, theta: float = 10000.0) -> None:
        super().__init__()
        self.dim = dim
        self.theta = theta
        inv_freq = 1.0 / (theta ** (torch.arange(0, dim, 2, dtype=torch.float) / dim))
        self.register_buffer("inv_freq", inv_freq, persistent=False)
        self._seq_len_cached = 0
        self._freqs_cached: torch.Tensor | None = None

    def update_freqs_cache(self, seqlen: int) -> None:
        """Make sure the cached table has at least ``seqlen`` rows."""
        # Also build when nothing is cached yet, so that a first request of
        # length 0 yields an empty table instead of leaving the cache as None.
        if self._freqs_cached is not None and seqlen <= self._seq_len_cached:
            return

        # Over-allocate so the next slightly longer request is a cache hit.
        seqlen *= 2
        self._seq_len_cached = seqlen
        # Recompute the frequencies in float32 on the buffer's current device
        # instead of reusing the stored buffer (presumably because the buffer
        # may have been cast to a lower precision with the module -- TODO
        # confirm).
        self.inv_freq = 1.0 / (
            self.theta
            ** (
                torch.arange(
                    0, self.dim, 2, dtype=torch.float, device=self.inv_freq.device
                )
                / self.dim
            )
        )
        seq = torch.arange(
            seqlen, device=self.inv_freq.device, dtype=self.inv_freq.dtype
        )
        self._freqs_cached = torch.outer(seq, self.inv_freq)

    def forward(self, seqlen: int) -> torch.Tensor:
        """Return the angle table for positions ``0 .. seqlen - 1``."""
        self.update_freqs_cache(seqlen)
        return self._freqs_cached[:seqlen]


class Qwen2VisionTransformer(nn.Module):
    """Vision encoder: patch embedding, a stack of blocks, then a patch merger.

    Args:
        vision_config: Source of the architecture hyper-parameters read in
            ``__init__`` (patch sizes, widths, depth, heads, MLP ratio).
        norm_eps: Epsilon of every ``nn.LayerNorm`` created here.
        quant_config: Forwarded to the blocks and the merger.
        prefix: Module-name prefix for the sub-modules.
        use_data_parallel: Forwarded to the blocks and the merger.
        attn_backend_override: Forwarded to the blocks and to the
            backend-selection helper.
    """

    def __init__(
        self,
        vision_config: Qwen2VLVisionConfig,
        norm_eps: float = 1e-6,
        quant_config: QuantizationConfig | None = None,
        prefix: str = "",
        use_data_parallel: bool = False,
        attn_backend_override: AttentionBackendEnum | None = None,
    ) -> None:
        super().__init__()

        # Read every hyper-parameter up front.
        patch_size = vision_config.patch_size
        temporal_patch_size = vision_config.temporal_patch_size
        spatial_merge_size = vision_config.spatial_merge_size
        in_channels = vision_config.in_channels
        hidden_size = vision_config.hidden_size
        embed_dim = vision_config.embed_dim
        depth = vision_config.depth
        num_heads = vision_config.num_heads
        mlp_ratio = vision_config.mlp_ratio

        self.use_data_parallel = use_data_parallel
        self.out_hidden_size = vision_config.hidden_size
        self.spatial_merge_size = spatial_merge_size
        self.num_heads = num_heads
        self.embed_dim = embed_dim

        self.patch_embed = Qwen2VisionPatchEmbed(
            patch_size=patch_size,
            temporal_patch_size=temporal_patch_size,
            in_channels=in_channels,
            embed_dim=embed_dim,
        )

        norm_layer = partial(nn.LayerNorm, eps=norm_eps)
        head_dim = embed_dim // num_heads
        # The rotary table covers half of the head dim; rot_pos_emb() later
        # concatenates the row and column angles of each patch.
        self.rotary_pos_emb = Qwen2VisionRotaryEmbedding(head_dim // 2)

        blocks = [
            Qwen2VisionBlock(
                dim=embed_dim,
                num_heads=num_heads,
                mlp_ratio=mlp_ratio,
                norm_layer=norm_layer,
                quant_config=quant_config,
                prefix=f"{prefix}.blocks.{layer_idx}",
                use_data_parallel=use_data_parallel,
                attn_backend_override=attn_backend_override,
            )
            for layer_idx in range(depth)
        ]
        self.blocks = nn.ModuleList(blocks)

        self.merger = Qwen2VisionPatchMerger(
            d_model=hidden_size,
            context_dim=embed_dim,
            norm_layer=norm_layer,
            quant_config=quant_config,
            prefix=f"{prefix}.merger",
            use_data_parallel=use_data_parallel,
        )

        self.attn_backend = get_vit_attn_backend(
            head_size=head_dim,
            dtype=torch.get_default_dtype(),
            attn_backend_override=attn_backend_override,
        )
        # Record FLASH_ATTN whenever upstream flash-attention is reported as
        # available for the default dtype, even if another backend was chosen.
        flash_selected = self.attn_backend == AttentionBackendEnum.FLASH_ATTN
        if not flash_selected and check_upstream_fa_availability(
            torch.get_default_dtype()
        ):
            self.attn_backend = AttentionBackendEnum.FLASH_ATTN

    @property
    def dtype(self) -> torch.dtype:
        """Dtype of the patch-embedding weight."""
        return self.patch_embed.proj.weight.dtype

    @property
    def device(self) -> torch.device:
        """Device of the patch-embedding weight."""
        return self.patch_embed.proj.weight.device

    def rot_pos_emb(self, grid_thw: list[list[int]]) -> torch.Tensor:
        """Build the rotary angles for every patch described by ``grid_thw``.

        For each ``(t, h, w)`` entry, row and column indices are generated for
        an ``h x w`` grid, reordered so that the patches of each
        ``spatial_merge_size x spatial_merge_size`` window are contiguous, and
        repeated ``t`` times. The angles of the row and column index are
        looked up and concatenated per patch.
        """
        merge = self.spatial_merge_size
        per_item_ids = []
        largest_side = 0
        for t, h, w in grid_thw:
            # (h // merge, merge, w // merge, merge): splits both axes into
            # window index and offset within the window.
            window_shape = (h // merge, merge, w // merge, merge)

            row_ids = torch.arange(h).unsqueeze(1).expand(-1, w)
            row_ids = row_ids.reshape(window_shape).permute(0, 2, 1, 3).flatten()

            col_ids = torch.arange(w).unsqueeze(0).expand(h, -1)
            col_ids = col_ids.reshape(window_shape).permute(0, 2, 1, 3).flatten()

            frame_ids = torch.stack([row_ids, col_ids], dim=-1)
            per_item_ids.append(frame_ids.repeat(t, 1))
            largest_side = max(largest_side, h, w)

        pos_ids = torch.cat(per_item_ids, dim=0)
        angle_table = self.rotary_pos_emb(largest_side)
        # Index with (row, col) pairs, then merge the two angle vectors.
        return angle_table[pos_ids].flatten(1)

    def compute_attn_mask_seqlen(
        self, cu_seqlens: torch.Tensor
    ) -> tuple[int | None, list[int] | None]:
        """Derive the backend-specific sequence-length argument.

        Returns ``(max_seqlen, seqlens)``: the longest length for the
        flash-attention backends, the list of lengths for xFormers, and
        ``None`` for whichever value the recorded backend does not need.
        """
        backend = self.attn_backend
        if backend in {
            AttentionBackendEnum.FLASH_ATTN,
            AttentionBackendEnum.ROCM_AITER_FA,
        }:
            return (cu_seqlens[1:] - cu_seqlens[:-1]).max().item(), None
        if backend == AttentionBackendEnum.XFORMERS:
            return None, (cu_seqlens[1:] - cu_seqlens[:-1]).tolist()
        return None, None

    def forward(
        self,
        x: torch.Tensor,
        grid_thw: torch.Tensor | list[list[int]],
    ) -> torch.Tensor:
        """Encode flattened patches ``x`` whose layout is given by ``grid_thw``.

        ``grid_thw`` holds one ``(t, h, w)`` row per item, as a tensor or a
        nested list. Every frame (``h * w`` patches, ``t`` frames per item) is
        treated as its own attention sequence.
        """
        # patchify
        x = self.patch_embed(x.to(device=self.device, dtype=self.dtype))

        # Keep both a list view (for Python loops) and a tensor view.
        if isinstance(grid_thw, list):
            grid_thw_list = grid_thw
            grid_thw = torch.tensor(grid_thw, dtype=torch.int32)
        else:
            grid_thw_list = grid_thw.tolist()

        # compute position embedding
        rotary_pos_emb = self.rot_pos_emb(grid_thw_list)

        # compute cu_seqlens: one entry of h * w per frame, accumulated and
        # prefixed with 0
        patches_per_frame = grid_thw[:, 1] * grid_thw[:, 2]
        frame_lengths = torch.repeat_interleave(patches_per_frame, grid_thw[:, 0])
        cu_seqlens = frame_lengths.cumsum(dim=0, dtype=torch.int32)
        cu_seqlens = torch.cat([cu_seqlens.new_zeros(1), cu_seqlens])
        cu_seqlens = cu_seqlens.to(self.device, non_blocking=True)

        # transformers
        x = x.unsqueeze(1)

        # pre-compute seqlens for attn mask to reduce cuMemcpy operations
        max_seqlen, seqlens = self.compute_attn_mask_seqlen(cu_seqlens)
        for blk in self.blocks:
            x = blk(
                x,
                cu_seqlens=cu_seqlens,
                rotary_pos_emb=rotary_pos_emb,
                max_seqlen=max_seqlen,
                seqlens=seqlens,
            )

        # adapter
        return self.merger(x)

    def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]:
        """Load ``(name, tensor)`` pairs into this module's parameters.

        Returns the set of parameter names that were loaded, using the names
        as they appear in this module after any renaming.
        """
        stacked_params_mapping = [
            # (param_name, shard_name, shard_id)
            ("qkv_proj", "q_proj", "q"),
            ("qkv_proj", "k_proj", "k"),
            ("qkv_proj", "v_proj", "v"),
        ]
        params_dict = dict(self.named_parameters(remove_duplicate=False))
        loaded_params: set[str] = set()

        for name, loaded_weight in weights:
            if name.endswith("patch_embed.proj.weight"):
                # The checkpoint tensor is converted to the layout of the
                # linear patch projection.
                loaded_weight = conv3d_to_linear_weight(loaded_weight)

            # First stacked-parameter entry whose shard name occurs in `name`.
            stacked = next(
                (entry for entry in stacked_params_mapping if entry[1] in name),
                None,
            )
            if stacked is not None:
                param_name, weight_name, shard_id = stacked
                name = name.replace(weight_name, param_name)
                param = params_dict[name]
                param.weight_loader(param, loaded_weight, shard_id)
            else:
                param = params_dict[name]
                weight_loader = getattr(param, "weight_loader", default_weight_loader)
                weight_loader(param, loaded_weight)
            loaded_params.add(name)
        return loaded_params

866

867
def _create_qwen2vl_field_factory(
    spatial_merge_size: int,
) -> Callable[
    [Mapping[str, torch.Tensor]],
    Mapping[str, MultiModalFieldConfig],
]:
    """Build the function that describes how Qwen2-VL multimodal fields are laid out.

    Args:
        spatial_merge_size: Merge factor; per-item patch counts are floor-divided
            by it twice to obtain the per-item embedding counts.

    Returns:
        A callable mapping HF inputs to a ``MultiModalFieldConfig`` per field.
    """

    def _embed_sizes(patch_sizes: torch.Tensor) -> torch.Tensor:
        # Two successive floor divisions, exactly as applied to the patch counts.
        return patch_sizes // spatial_merge_size // spatial_merge_size

    def _qwen2vl_field_config(hf_inputs: Mapping[str, torch.Tensor]):
        # A missing grid falls back to an empty (0, 3) tensor, i.e. zero items.
        image_patch_sizes = hf_inputs.get(
            "image_grid_thw", torch.empty((0, 3))
        ).prod(-1)
        image_embed_sizes = _embed_sizes(image_patch_sizes)

        video_patch_sizes = hf_inputs.get(
            "video_grid_thw", torch.empty((0, 3))
        ).prod(-1)
        video_embed_sizes = _embed_sizes(video_patch_sizes)

        flat = MultiModalFieldConfig.flat_from_sizes
        batched = MultiModalFieldConfig.batched

        return {
            "pixel_values": flat("image", image_patch_sizes),
            "image_embeds": flat("image", image_embed_sizes),
            "image_grid_thw": batched("image"),
            "pixel_values_videos": flat("video", video_patch_sizes),
            "video_embeds": flat("video", video_embed_sizes),
            "video_grid_thw": batched("video"),
        }

    return _qwen2vl_field_config
904

905

Roger Wang's avatar
Roger Wang committed
906
class Qwen2VLMultiModalDataParser(MultiModalDataParser):
    """Data parser that additionally accepts dict inputs for images and videos.

    A dict input is wrapped in ``DictEmbeddingItems`` (requiring the
    ``*_embeds`` and ``*_grid_thw`` fields); anything else is delegated to the
    base parser.
    """

    def __init__(self, spatial_merge_size: int, *args, **kwargs):
        # Stored before the base initializer runs; used to build field configs.
        self._spatial_merge_size = spatial_merge_size
        super().__init__(*args, **kwargs)

    def _parse_image_data(
        self,
        data: dict[str, torch.Tensor] | ModalityData[ImageItem],
    ) -> ModalityDataItems[Any, Any] | None:
        """Parse image data, treating a dict as image embeddings plus grid."""
        if not isinstance(data, dict):
            return super()._parse_image_data(data)

        field_factory = _create_qwen2vl_field_factory(self._spatial_merge_size)
        return DictEmbeddingItems(
            data,
            modality="image",
            required_fields={"image_embeds", "image_grid_thw"},
            fields_factory=field_factory,
        )

    def _parse_video_data(
        self,
        data: dict[str, torch.Tensor] | ModalityData[VideoItem],
    ) -> ModalityDataItems[Any, Any] | None:
        """Parse video data, treating a dict as video embeddings plus grid."""
        if not isinstance(data, dict):
            return super()._parse_video_data(data)

        field_factory = _create_qwen2vl_field_factory(self._spatial_merge_size)
        return DictEmbeddingItems(
            data,
            modality="video",
            required_fields={"video_embeds", "video_grid_thw"},
            fields_factory=field_factory,
        )


940
941
class Qwen2VLProcessingInfo(BaseProcessingInfo):
    """Processing metadata for Qwen2-VL.

    Gives access to the HF config/processor and estimates how many vision
    tokens an image or a video occupies.
    """

    def get_hf_config(self):
        """Return the model's HF config, requested as a ``Qwen2VLConfig``."""
        return self.ctx.get_hf_config(Qwen2VLConfig)

    def get_hf_processor(self, **kwargs: object) -> Qwen2VLProcessor:
        """Return the HF processor; ``use_fast`` is True unless overridden."""
        return self.ctx.get_hf_processor(
            Qwen2VLProcessor,
            use_fast=kwargs.pop("use_fast", True),
            **kwargs,
        )

    def get_image_processor(self, **kwargs: object) -> Qwen2VLImageProcessor:
        """Return the image processor attached to the HF processor."""
        return self.get_hf_processor(**kwargs).image_processor

    def get_supported_mm_limits(self) -> Mapping[str, int | None]:
        """Return per-modality item limits; ``None`` is used for both here."""
        return {"image": None, "video": None}

    def get_mm_max_tokens_per_item(
        self,
        seq_len: int,
        mm_counts: Mapping[str, int],
    ) -> Mapping[str, int]:
        """Return the maximum number of tokens one image / one video can take."""
        max_image_tokens = self.get_max_image_tokens()
        max_video_tokens = self.get_max_video_tokens(seq_len, mm_counts)
        return {"image": max_image_tokens, "video": max_video_tokens}

    def _get_vision_info(
        self,
        *,
        image_width: int,
        image_height: int,
        num_frames: int = 1,
        do_resize: bool = True,
        image_processor: Qwen2VLImageProcessor | None,
    ) -> tuple[ImageSize, int]:
        """Compute the preprocessed size and vision-token count of one item.

        Args:
            image_width: Input width in pixels.
            image_height: Input height in pixels.
            num_frames: Number of frames (1 for a still image).
            do_resize: Whether to apply ``smart_resize`` before counting.
            image_processor: Processor supplying ``min_pixels`` /
                ``max_pixels``; the default one is fetched when ``None``.

        Returns:
            ``(preprocessed_size, num_vision_tokens)``.
        """
        if image_processor is None:
            image_processor = self.get_image_processor()

        hf_config = self.get_hf_config()
        vision_config = hf_config.vision_config
        patch_size = vision_config.patch_size
        merge_size = vision_config.spatial_merge_size
        temporal_patch_size = vision_config.temporal_patch_size

        if do_resize:
            resized_height, resized_width = smart_resize(
                height=image_height,
                width=image_width,
                factor=patch_size * merge_size,
                min_pixels=image_processor.min_pixels,
                max_pixels=image_processor.max_pixels,
            )
            preprocessed_size = ImageSize(width=resized_width, height=resized_height)
        else:
            preprocessed_size = ImageSize(width=image_width, height=image_height)

        # NOTE: Frames are padded to be divisible by `temporal_patch_size`
        # https://github.com/huggingface/transformers/blob/v4.48.3/src/transformers/models/qwen2_vl/image_processing_qwen2_vl.py#L294
        # `-num_frames % temporal_patch_size` is the number of frames missing
        # to reach the next multiple. The previous `num_frames % size` form
        # only produced a multiple when `temporal_patch_size == 2`.
        padded_num_frames = num_frames + (-num_frames % temporal_patch_size)

        grid_t = max(padded_num_frames // temporal_patch_size, 1)
        grid_h = preprocessed_size.height // patch_size
        grid_w = preprocessed_size.width // patch_size

        num_patches = grid_t * grid_h * grid_w
        num_vision_tokens = num_patches // (merge_size**2)

        return preprocessed_size, num_vision_tokens

    def get_num_image_tokens(
        self,
        *,
        image_width: int,
        image_height: int,
        image_processor: Qwen2VLImageProcessor | None,
    ) -> int:
        """Return the number of vision tokens for a single image."""
        _, num_image_tokens = self._get_vision_info(
            image_width=image_width,
            image_height=image_height,
            num_frames=1,
            image_processor=image_processor,
        )
        return num_image_tokens

    def get_num_video_tokens(
        self,
        *,
        image_width: int,
        image_height: int,
        num_frames: int,
        image_processor: Qwen2VLImageProcessor | None,
    ) -> int:
        """Return the number of vision tokens for a video of ``num_frames``."""
        _, num_video_tokens = self._get_vision_info(
            image_width=image_width,
            image_height=image_height,
            num_frames=num_frames,
            image_processor=image_processor,
        )
        return num_video_tokens

    def get_image_size_with_most_features(self) -> ImageSize:
        """Return the preprocessed size obtained from a very large input."""
        # A huge input lets the resize step clamp to its own upper bound.
        max_image_size, _ = self._get_vision_info(
            image_width=9999999,
            image_height=9999999,
            num_frames=1,
            image_processor=None,
        )
        return max_image_size

    def get_max_image_tokens(self) -> int:
        """Return the token count of an image at the largest preprocessed size."""
        target_width, target_height = self.get_image_size_with_most_features()

        return self.get_num_image_tokens(
            image_width=target_width,
            image_height=target_height,
            image_processor=None,
        )

    def _get_max_video_frames(self, max_tokens: int, start_num_frames: int = 1) -> int:
        """Return the largest frame count whose tokens stay within ``max_tokens``.

        The search starts at ``start_num_frames`` and never returns less than
        that, even if that count already exceeds the budget.
        """
        target_width, target_height = self.get_image_size_with_most_features()

        num_frames = start_num_frames

        while True:
            next_num_frames = num_frames + 1
            next_max_tokens = self.get_num_video_tokens(
                image_width=target_width,
                image_height=target_height,
                num_frames=next_num_frames,
                image_processor=None,
            )

            if next_max_tokens > max_tokens:
                break

            num_frames = next_num_frames

        return num_frames

    def get_num_frames_with_most_features(
        self,
        seq_len: int,
        mm_counts: Mapping[str, int],
        max_frames_per_video: int = _MAX_FRAMES_PER_VIDEO,
    ) -> int:
        """Return the per-video frame count used for worst-case estimation.

        The frame budget that fits in ``seq_len`` is shared evenly between the
        requested videos, capped at ``max_frames_per_video`` and floored at 1.
        """
        max_videos = mm_counts.get("video", 0)

        max_total_frames = self._get_max_video_frames(seq_len)
        max_frames_per_video = min(
            max_total_frames // max(max_videos, 1), max_frames_per_video
        )

        return max(max_frames_per_video, 1)

    def get_max_video_tokens(
        self,
        seq_len: int,
        mm_counts: Mapping[str, int],
    ) -> int:
        """Return the token count of a video at maximum size and frame count."""
        target_width, target_height = self.get_image_size_with_most_features()

        return self.get_num_video_tokens(
            image_width=target_width,
            image_height=target_height,
            num_frames=self.get_num_frames_with_most_features(seq_len, mm_counts),
            image_processor=None,
        )

1108
1109

class Qwen2VLDummyInputsBuilder(BaseDummyInputsBuilder[Qwen2VLProcessingInfo]):
    """Builds dummy prompt text and dummy image/video data for Qwen2-VL."""

    def get_dummy_text(self, mm_counts: Mapping[str, int]) -> str:
        """Return one image token per image followed by one video token per video."""
        image_count = mm_counts.get("image", 0)
        video_count = mm_counts.get("video", 0)

        processor = self.info.get_hf_processor()
        image_part: str = processor.image_token * image_count
        video_part: str = processor.video_token * video_count

        return image_part + video_part

    def get_dummy_mm_data(
        self,
        seq_len: int,
        mm_counts: Mapping[str, int],
        mm_options: Mapping[str, BaseDummyOptions] | None = None,
    ) -> MultiModalDataDict:
        """Return dummy images and videos at the size with the most features.

        Args:
            seq_len: Sequence length used to derive the dummy frame count.
            mm_counts: Number of items requested per modality.
            mm_options: Optional per-modality overrides forwarded to the
                dummy-data helpers.
        """
        image_count = mm_counts.get("image", 0)
        video_count = mm_counts.get("video", 0)

        width, height = self.info.get_image_size_with_most_features()
        frames_per_video = self.info.get_num_frames_with_most_features(
            seq_len, mm_counts
        )

        # A missing or empty options mapping means "no overrides".
        options = mm_options or {}

        dummy_images = self._get_dummy_images(
            width=width,
            height=height,
            num_images=image_count,
            overrides=options.get("image"),
        )
        dummy_videos = self._get_dummy_videos(
            width=width,
            height=height,
            num_frames=frames_per_video,
            num_videos=video_count,
            overrides=options.get("video"),
        )

        return {"image": dummy_images, "video": dummy_videos}

1153

1154
class Qwen2VLMultiModalProcessor(BaseMultiModalProcessor[Qwen2VLProcessingInfo]):
    """Multimodal processor for Qwen2-VL: parsing, placeholders and field layout."""

    def _get_data_parser(self) -> MultiModalDataParser:
        """Return a parser configured with the vision spatial merge size."""
        vision_config = self.info.get_hf_config().vision_config
        return Qwen2VLMultiModalDataParser(vision_config.spatial_merge_size)

    def _get_prompt_updates(
        self,
        mm_items: MultiModalDataItems,
        hf_processor_mm_kwargs: Mapping[str, Any],
        out_mm_kwargs: MultiModalKwargsItems,
    ) -> Sequence[PromptUpdate]:
        """Return replacements expanding each placeholder token per item.

        Each image/video placeholder is replaced by as many copies of itself
        as ``prod(grid_thw) // merge_size**2`` for the corresponding item.
        """
        hf_processor = self.info.get_hf_processor(**hf_processor_mm_kwargs)
        image_processor = self.info.get_image_processor(**hf_processor_mm_kwargs)
        vocab = self.info.get_tokenizer().get_vocab()

        # Token id of the placeholder for each modality.
        placeholder = {
            "image": vocab[hf_processor.image_token],
            "video": vocab[hf_processor.video_token],
        }

        patches_per_token = image_processor.merge_size**2

        def get_replacement_qwen2vl(item_idx: int, modality: str):
            grid_thw = out_mm_kwargs[modality][item_idx][f"{modality}_grid_thw"].data
            assert isinstance(grid_thw, torch.Tensor)

            repeat_count = int(grid_thw.prod()) // patches_per_token
            return [placeholder[modality]] * repeat_count

        updates = []
        for modality in ("image", "video"):
            updates.append(
                PromptReplacement(
                    modality=modality,
                    target=[placeholder[modality]],
                    replacement=partial(get_replacement_qwen2vl, modality=modality),
                )
            )
        return updates

    def _get_mm_fields_config(
        self,
        hf_inputs: BatchFeature,
        hf_processor_mm_kwargs: Mapping[str, object],
    ) -> Mapping[str, MultiModalFieldConfig]:
        """Return the field configuration for the given HF processor outputs."""
        vision_config = self.info.get_hf_config().vision_config
        field_factory = _create_qwen2vl_field_factory(vision_config.spatial_merge_size)
        return field_factory(hf_inputs)


@MULTIMODAL_REGISTRY.register_processor(
    Qwen2VLMultiModalProcessor,
    info=Qwen2VLProcessingInfo,
    dummy_inputs=Qwen2VLDummyInputsBuilder,
)
class Qwen2VLForConditionalGeneration(
    nn.Module, SupportsMultiModal, SupportsLoRA, SupportsPP, SupportsMRoPE
):
    """Qwen2-VL: a vision transformer feeding a Qwen2 causal language model."""

    merge_by_field_config = True
    multimodal_cpu_fields = {"image_grid_thw", "video_grid_thw"}

    # To ensure correct weight loading and mapping.
    hf_to_vllm_mapper = WeightsMapper(
        orig_to_new_prefix={
            # mapping for new names in checkpoint saved after transformers v4.52
            "model.language_model.": "language_model.model.",
            "model.visual.": "visual.",
            # mapping for original checkpoint
            "lm_head.": "language_model.lm_head.",
            "model.": "language_model.model.",
        }
    )

    supports_encoder_tp_data = True

    def get_mrope_input_positions(
        self,
        input_tokens: list[int],
        mm_features: list[MultiModalFeatureSpec],
    ) -> tuple[torch.Tensor, int]:
        """Compute the three-row (t, h, w) position ids for a prompt.

        Text runs get the same increasing index on all three rows; each image
        or video gets per-axis indices built from its ``grid_thw``.

        Args:
            input_tokens: Token ids of the prompt.
            mm_features: Multimodal features from which the image/video grids
                and ``second_per_grid_ts`` are gathered.

        Returns:
            ``(llm_positions, mrope_position_delta)`` where ``llm_positions``
            has three rows and the delta is ``max(position) + 1`` minus the
            prompt length.
        """
        kwargs = MultiModalFeatureSpec.gather_kwargs(
            mm_features,
            {"image_grid_thw", "video_grid_thw", "second_per_grid_ts"},
        )
        image_grid_thw = [item.tolist() for item in kwargs.get("image_grid_thw", [])]
        video_grid_thw = [item.tolist() for item in kwargs.get("video_grid_thw", [])]
        second_per_grid_ts = kwargs.get("second_per_grid_ts", [])

        hf_config = self.config
        image_token_id = hf_config.image_token_id
        video_token_id = hf_config.video_token_id
        vision_start_token_id = hf_config.vision_start_token_id
        spatial_merge_size = hf_config.vision_config.spatial_merge_size
        tokens_per_second = getattr(hf_config.vision_config, "tokens_per_second", 1.0)

        # Count images/videos by looking at the token right after each
        # vision-start token.
        input_tokens_tensor = torch.tensor(input_tokens)
        vision_start_indices = torch.argwhere(
            input_tokens_tensor == vision_start_token_id
        ).squeeze(1)
        vision_tokens = input_tokens_tensor[vision_start_indices + 1]
        remaining_images = int((vision_tokens == image_token_id).sum())
        remaining_videos = int((vision_tokens == video_token_id).sum())

        num_input_tokens = len(input_tokens)
        # Sentinel larger than any valid index, meaning "no further occurrence".
        not_found = num_input_tokens + 1

        llm_pos_ids_list: list = []

        def _next_occurrence(token_id: int, remaining: int, start: int) -> int:
            if remaining <= 0:
                return not_found
            try:
                return input_tokens.index(token_id, start)
            except ValueError:
                return not_found

        def _next_start_index():
            # Continue right after the largest position emitted so far.
            if llm_pos_ids_list:
                return llm_pos_ids_list[-1].max() + 1
            return 0

        def _text_positions(length: int, offset):
            return torch.arange(length).view(1, -1).expand(3, -1) + offset

        st = 0
        image_index, video_index = 0, 0
        for _ in range(remaining_images + remaining_videos):
            ed_image = _next_occurrence(image_token_id, remaining_images, st)
            ed_video = _next_occurrence(video_token_id, remaining_videos, st)

            if ed_image < ed_video:
                t, h, w = image_grid_thw[image_index]
                # Images use a zero scale, so their temporal index is constant.
                video_second_per_grid_t = 0.0
                image_index += 1
                remaining_images -= 1
                ed = ed_image
            else:
                t, h, w = video_grid_thw[video_index]
                video_second_per_grid_t = 1.0
                if second_per_grid_ts:
                    video_second_per_grid_t = second_per_grid_ts[video_index]
                video_index += 1
                remaining_videos -= 1
                ed = ed_video

            llm_grid_t = t
            llm_grid_h = h // spatial_merge_size
            llm_grid_w = w // spatial_merge_size

            # Text between the previous item and this one.
            text_len = ed - st
            st_idx = _next_start_index()
            llm_pos_ids_list.append(_text_positions(text_len, st_idx))

            # Per-axis indices of the vision tokens.
            scaled_t = (
                torch.arange(llm_grid_t)
                .view(-1, 1)
                .expand(-1, llm_grid_h * llm_grid_w)
                * video_second_per_grid_t
                * tokens_per_second
            )
            t_index = scaled_t.long().flatten()
            h_index = (
                torch.arange(llm_grid_h)
                .view(1, -1, 1)
                .expand(llm_grid_t, -1, llm_grid_w)
                .flatten()
            )
            w_index = (
                torch.arange(llm_grid_w)
                .view(1, 1, -1)
                .expand(llm_grid_t, llm_grid_h, -1)
                .flatten()
            )
            llm_pos_ids_list.append(
                torch.stack([t_index, h_index, w_index]) + text_len + st_idx
            )

            st = ed + llm_grid_t * llm_grid_h * llm_grid_w

        # Trailing text after the last image/video.
        if st < num_input_tokens:
            st_idx = _next_start_index()
            llm_pos_ids_list.append(_text_positions(num_input_tokens - st, st_idx))

        llm_positions = torch.cat(llm_pos_ids_list, dim=1).reshape(3, -1)
        mrope_position_delta = (llm_positions.max() + 1 - num_input_tokens).item()

        return llm_positions, mrope_position_delta

    @classmethod
    def get_placeholder_str(cls, modality: str, i: int) -> str | None:
        """Return the prompt placeholder string for an image or video item.

        Raises:
            ValueError: If ``modality`` starts with neither "image" nor "video".
        """
        placeholders = {
            "image": "<|vision_start|><|image_pad|><|vision_end|>",
            "video": "<|vision_start|><|video_pad|><|vision_end|>",
        }
        for modality_prefix, placeholder_str in placeholders.items():
            if modality.startswith(modality_prefix):
                return placeholder_str

        raise ValueError("Only image or video modality is supported")

    def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
        """Build the (optional) vision tower and the language model."""
        super().__init__()
        model_config = vllm_config.model_config
        config: Qwen2VLConfig = model_config.hf_config
        quant_config = vllm_config.quant_config
        multimodal_config = model_config.multimodal_config

        self.use_data_parallel = multimodal_config.mm_encoder_tp_mode == "data"
        self.config = config
        self.multimodal_config = multimodal_config

        vision_enabled = multimodal_config.get_limit_per_prompt(
            "image"
        ) or multimodal_config.get_limit_per_prompt("video")

        if not vision_enabled:
            # Neither modality is allowed, so no vision tower is built.
            self.visual = None
        else:
            self.visual = Qwen2VisionTransformer(
                config.vision_config,
                norm_eps=getattr(config, "rms_norm_eps", 1e-6),
                quant_config=quant_config,
                prefix=maybe_prefix(prefix, "visual"),
                use_data_parallel=self.use_data_parallel,
                attn_backend_override=multimodal_config.mm_encoder_attn_backend,
            )

        self.language_model = init_vllm_registered_model(
            vllm_config=vllm_config,
            prefix=maybe_prefix(prefix, "language_model"),
            architectures=["Qwen2ForCausalLM"],
        )

        self.make_empty_intermediate_tensors = (
            self.language_model.make_empty_intermediate_tensors
        )

    def _parse_and_validate_image_input(
        self, **kwargs: object
    ) -> Qwen2VLImageInputs | None:
        """Build the image input from kwargs; pixel values win over embeddings."""
        pixel_values = kwargs.pop("pixel_values", None)
        image_embeds = kwargs.pop("image_embeds", None)
        image_grid_thw = kwargs.pop("image_grid_thw", None)

        if pixel_values is not None:
            return Qwen2VLImagePixelInputs(
                type="pixel_values",
                pixel_values=pixel_values,
                image_grid_thw=image_grid_thw,
            )

        if image_embeds is not None:
            return Qwen2VLImageEmbeddingInputs(
                type="image_embeds",
                image_embeds=image_embeds,
                image_grid_thw=image_grid_thw,
            )

        return None

    def _parse_and_validate_video_input(
        self, **kwargs: object
    ) -> Qwen2VLVideoInputs | None:
        """Build the video input from kwargs; pixel values win over embeddings."""
        pixel_values_videos = kwargs.pop("pixel_values_videos", None)
        video_embeds = kwargs.pop("video_embeds", None)
        video_grid_thw = kwargs.pop("video_grid_thw", None)

        if pixel_values_videos is not None:
            return Qwen2VLVideoPixelInputs(
                type="pixel_values_videos",
                pixel_values_videos=pixel_values_videos,
                video_grid_thw=video_grid_thw,
            )

        if video_embeds is not None:
            return Qwen2VLVideoEmbeddingInputs(
                type="video_embeds",
                video_embeds=video_embeds,
                video_grid_thw=video_grid_thw,
            )

        return None

    def _process_image_input(
        self, image_input: Qwen2VLImageInputs
    ) -> tuple[torch.Tensor, ...]:
        """Return one embedding tensor per image."""
        grid_thw = image_input["image_grid_thw"]
        assert grid_thw.ndim == 2

        if image_input["type"] == "image_embeds":
            image_embeds = image_input["image_embeds"]
        elif self.use_data_parallel:
            # The sharded helper's result is returned as-is.
            return run_dp_sharded_mrope_vision_model(
                self.visual,
                image_input["pixel_values"],
                grid_thw.tolist(),
                rope_type="rope_3d",
            )
        else:
            image_embeds = self.visual(image_input["pixel_values"], grid_thw=grid_thw)

        # Split concatenated embeddings for each image item.
        merge_size = self.visual.spatial_merge_size
        sizes = (grid_thw.prod(-1) // merge_size // merge_size).tolist()
        return image_embeds.split(sizes)

    def _process_video_input(
        self, video_input: Qwen2VLVideoInputs
    ) -> tuple[torch.Tensor, ...]:
        """Return one embedding tensor per video."""
        grid_thw = video_input["video_grid_thw"]
        assert grid_thw.ndim == 2

        if video_input["type"] == "video_embeds":
            video_embeds = video_input["video_embeds"]
        elif self.use_data_parallel:
            # The sharded helper's result is returned as-is.
            return run_dp_sharded_mrope_vision_model(
                self.visual,
                video_input["pixel_values_videos"],
                grid_thw.tolist(),
                rope_type="rope_3d",
            )
        else:
            video_embeds = self.visual(
                video_input["pixel_values_videos"], grid_thw=grid_thw
            )

        # Split concatenated embeddings for each video item.
        merge_size = self.visual.spatial_merge_size
        sizes = (grid_thw.prod(-1) // merge_size // merge_size).tolist()
        return video_embeds.split(sizes)

    def _parse_and_validate_multimodal_inputs(self, **kwargs: object) -> dict:
        """Parse image/video inputs, keyed by modality in kwargs order."""
        image_keys = ("pixel_values", "image_embeds")
        video_keys = ("pixel_values_videos", "video_embeds")
        modalities = {}

        # Preserve the order of modalities if there are multiple of them
        # from the order of kwargs.
        for input_key in kwargs:
            if input_key in image_keys:
                if "images" not in modalities:
                    modalities["images"] = self._parse_and_validate_image_input(
                        **kwargs
                    )
            elif input_key in video_keys:
                if "videos" not in modalities:
                    modalities["videos"] = self._parse_and_validate_video_input(
                        **kwargs
                    )

        return modalities

    def get_language_model(self) -> torch.nn.Module:
        """Return the wrapped language model."""
        return self.language_model

    def embed_multimodal(self, **kwargs: object) -> MultiModalEmbeddings:
        """Return per-item embeddings for all images and videos in kwargs."""
        modalities = self._parse_and_validate_multimodal_inputs(**kwargs)
        if not modalities:
            return []

        # The result is a tuple of tensors, with each tensor corresponding
        # to a multimodal data item (image or video).
        collected: list[torch.Tensor] = []

        # NOTE: It is important to iterate over the keys in this dictionary
        # to preserve the order of the modalities.
        for modality, parsed_input in modalities.items():
            if modality == "images":
                collected.extend(self._process_image_input(parsed_input))
            if modality == "videos":
                collected.extend(self._process_video_input(parsed_input))

        return tuple(collected)

    def forward(
        self,
        input_ids: torch.Tensor,
        positions: torch.Tensor,
        intermediate_tensors: IntermediateTensors | None = None,
        inputs_embeds: torch.Tensor | None = None,
        **kwargs: object,
    ) -> torch.Tensor | IntermediateTensors:
        """Run forward pass for Qwen2-VL.

        Args:
            input_ids: Flattened (concatenated) input_ids corresponding to a
                batch.
            positions: Flattened (concatenated) position ids corresponding to a
                batch.
                **NOTE**: If mrope is enabled (default setting for Qwen2-VL
                opensource models), the shape will be `(3, seq_len)`,
                otherwise it will be `(seq_len,)`.
            intermediate_tensors: Intermediate tensors from prior forward pass.
            inputs_embeds: Optional tensor of input embeddings.
        """
        # Input embeddings are dropped whenever intermediate tensors are given.
        if intermediate_tensors is not None:
            inputs_embeds = None

        return self.language_model.model(
            input_ids=input_ids,
            positions=positions,
            intermediate_tensors=intermediate_tensors,
            inputs_embeds=inputs_embeds,
        )

    def compute_logits(
        self,
        hidden_states: torch.Tensor,
    ) -> torch.Tensor | None:
        """Delegate logit computation to the language model."""
        return self.language_model.compute_logits(hidden_states)

    def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]:
        """Load weights, skipping ``visual.`` entries when no vision tower exists."""
        skip_prefixes = ["visual."] if self.visual is None else []
        loader = AutoWeightsLoader(self, skip_prefixes=skip_prefixes)
        return loader.load_weights(weights, mapper=self.hf_to_vllm_mapper)

    def get_mm_mapping(self) -> MultiModelKeys:
        """
        Get the module prefix in multimodal models
        """
        return MultiModelKeys.from_string_field(
            language_model="language_model",
            connector="visual.merger.",
            tower_model="visual.",
        )
1588
1589
1590
1591
1592
1593
1594
1595
1596


class Tarsier2MultiModalProcessor(Qwen2VLMultiModalProcessor):
    """Multimodal processor for Tarsier2; adds nothing to the Qwen2-VL one."""


class Tarsier2ImageProcessor(Qwen2VLImageProcessor):
    """Qwen2-VL image processor that also accepts a min/max-pixels ``size``."""

    def __init__(
        self,
        size: dict[str, int] | None = None,
        **kwargs,
    ) -> None:
        """Initialize the processor, remapping ``size`` keys when needed.

        Args:
            size: Optional size dict. When it contains both ``min_pixels`` and
                ``max_pixels``, a new dict with ``shortest_edge`` /
                ``longest_edge`` is passed on instead; otherwise it is
                forwarded unchanged.
            **kwargs: Forwarded to ``Qwen2VLImageProcessor``.
        """
        uses_pixel_keys = (
            size is not None and "min_pixels" in size and "max_pixels" in size
        )
        if uses_pixel_keys:
            # Remap if Tarsier2-specific format is provided
            size = {
                "shortest_edge": size["min_pixels"],
                "longest_edge": size["max_pixels"],
            }

        super().__init__(size=size, **kwargs)


class Tarsier2Processor(Qwen2VLProcessor):
    """Qwen2-VL processor assembled from a Tarsier2 vision config."""

    def __init__(
        self,
        vision_config: dict,
        tokenizer: AnyTokenizer,
        **kwargs,
    ):
        """Create the image/video processors and initialize the base processor.

        Args:
            vision_config: Keyword arguments used for both the image processor
                and the video processor.
            tokenizer: Tokenizer handed to the base processor.
            **kwargs: Forwarded to ``Qwen2VLProcessor``.
        """
        # The attribute is assigned before the base initializer runs.
        image_processor = Tarsier2ImageProcessor(**vision_config)
        self.image_processor = image_processor

        video_processor = Qwen2VLVideoProcessor(**vision_config)
        super().__init__(
            image_processor=image_processor,
            tokenizer=tokenizer,
            video_processor=video_processor,
            chat_template=None,
            **kwargs,
        )
1626
1627
1628
1629
1630


class Tarsier2ProcessingInfo(Qwen2VLProcessingInfo):
    """Processing info for Tarsier2, built on the Qwen2-VL implementation."""

    def get_hf_config(self) -> Qwen2VLConfig:
        """Load a ``Qwen2VLConfig`` directly from the model path."""
        model_path = self.ctx.model_config.model
        correct_config = Qwen2VLConfig.from_pretrained(model_path)

        return correct_config

    def get_hf_processor(self, **kwargs: object) -> Tarsier2Processor:
        """Build a ``Tarsier2Processor`` from the image-processor config."""
        return Tarsier2Processor(
            vision_config=self.ctx.get_hf_image_processor_config(),
            tokenizer=self.get_tokenizer(),
            **kwargs,
        )

    def get_image_processor(self, **kwargs: object) -> Tarsier2ImageProcessor:
        """Build a ``Tarsier2ImageProcessor`` from the image-processor config.

        The base class signature accepts keyword arguments and
        ``Qwen2VLMultiModalProcessor._get_prompt_updates`` passes the
        per-request processor kwargs here, so this override must accept them
        too; otherwise any non-empty kwargs raise ``TypeError``.

        Args:
            **kwargs: Optional overrides applied on top of the configured
                image-processor settings. With no kwargs the result is built
                from the configuration alone.
        """
        # Copy so the config mapping returned by the context is never mutated.
        processor_config = dict(self.ctx.get_hf_image_processor_config())
        processor_config.update(kwargs)
        return Tarsier2ImageProcessor(**processor_config)
1644
1645


1646
1647
1648
1649
1650
@MULTIMODAL_REGISTRY.register_processor(
    Tarsier2MultiModalProcessor,
    info=Tarsier2ProcessingInfo,
    dummy_inputs=Qwen2VLDummyInputsBuilder,
)
class Tarsier2ForConditionalGeneration(Qwen2VLForConditionalGeneration):
    """Tarsier2 model implemented on top of the Qwen2-VL model class."""

    # Checkpoint names under `vision_tower.` are loaded into `visual.`.
    hf_to_vllm_mapper = WeightsMapper(
        orig_to_new_prefix={
            "vision_tower.": "visual.",
        }
    )

    def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
        """Swap in the nested Qwen2-VL config, then build the base model."""
        # Tarsier2 uses llava as model_type, which will create a Qwen2VLConfig
        # as text_config, we need to reconstruct Qwen2VLConfig from LlavaConfig.
        model_config = vllm_config.model_config
        outer_config = model_config.hf_config

        qwen2vl_config = outer_config.text_config
        qwen2vl_config.architectures = outer_config.architectures

        # NOTE: this replaces `hf_config` on the caller's `vllm_config`.
        model_config.hf_config = qwen2vl_config
        super().__init__(vllm_config=vllm_config, prefix=prefix)

    def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]:
        """Load weights, skipping ``visual.`` entries when no vision tower exists."""
        skip_prefixes = ["visual."] if self.visual is None else []
        loader = AutoWeightsLoader(self, skip_prefixes=skip_prefixes)
        return loader.load_weights(weights, mapper=self.hf_to_vllm_mapper)