# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project

# Adapted from
# https://github.com/huggingface/transformers/blob/19e6e80e10118f855137b90740936c0b11ac397f/src/transformers/models/qwen2_vl/modeling_qwen2_vl.py
# Copyright 2024 The Qwen team.
# Copyright 2023 The vLLM team.
# Copyright 2022 EleutherAI and the HuggingFace Inc. team. All rights reserved.
#
# This code is based on EleutherAI's GPT-NeoX library and the GPT-NeoX
# and OPT implementations in this library. It has been modified from its
# original forms to accommodate minor architectural differences compared
# to GPT-NeoX and OPT used by the Meta AI team that trained the model.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Inference-only Qwen2-VL model compatible with HuggingFace weights."""

import math
from collections.abc import Callable, Iterable, Mapping, Sequence
from functools import partial
from typing import Annotated, Any, Literal, TypeAlias

import torch
import torch.nn as nn
import torch.nn.functional as F
from einops import rearrange, repeat
from transformers import AutoConfig, BatchFeature, PretrainedConfig
from transformers.models.qwen2_vl import Qwen2VLImageProcessor, Qwen2VLProcessor
from transformers.models.qwen2_vl.configuration_qwen2_vl import (
    Qwen2VLConfig,
    Qwen2VLVisionConfig,
)
from transformers.models.qwen2_vl.image_processing_qwen2_vl import smart_resize
from transformers.models.qwen2_vl.video_processing_qwen2_vl import Qwen2VLVideoProcessor

from vllm.attention.backends.registry import _Backend
from vllm.attention.layer import (
    check_upstream_fa_availability,
    maybe_get_vit_flash_attn_backend,
)
from vllm.config import VllmConfig
from vllm.config.multimodal import BaseDummyOptions
from vllm.distributed import parallel_state, tensor_model_parallel_all_gather
from vllm.distributed import utils as dist_utils
from vllm.logger import init_logger
from vllm.model_executor.layers.activation import QuickGELU
from vllm.model_executor.layers.linear import (
    ColumnParallelLinear,
    ReplicatedLinear,
    RowParallelLinear,
)
from vllm.model_executor.layers.quantization import QuantizationConfig
from vllm.model_executor.layers.rotary_embedding.common import (
    dispatch_rotary_emb_function,
)
from vllm.model_executor.model_loader.weight_utils import default_weight_loader
from vllm.model_executor.models.module_mapping import MultiModelKeys
from vllm.multimodal import MULTIMODAL_REGISTRY
from vllm.multimodal.inputs import (
    ImageItem,
    ModalityData,
    MultiModalDataDict,
    MultiModalFieldConfig,
    MultiModalKwargsItems,
    VideoItem,
)
from vllm.multimodal.parse import (
    DictEmbeddingItems,
    ImageSize,
    ModalityDataItems,
    MultiModalDataItems,
    MultiModalDataParser,
)
from vllm.multimodal.processing import (
    BaseMultiModalProcessor,
    BaseProcessingInfo,
    PromptReplacement,
    PromptUpdate,
)
from vllm.multimodal.profiling import BaseDummyInputsBuilder
from vllm.sequence import IntermediateTensors
from vllm.transformers_utils.tokenizer import AnyTokenizer
from vllm.utils.tensor_schema import TensorSchema, TensorShape

from .interfaces import (
    MultiModalEmbeddings,
    SupportsLoRA,
    SupportsMRoPE,
    SupportsMultiModal,
    SupportsPP,
)
from .utils import (
    AutoWeightsLoader,
    WeightsMapper,
    init_vllm_registered_model,
    maybe_prefix,
)
from .vision import (
    conv3d_to_linear_weight,
    get_vit_attn_backend,
    run_dp_sharded_mrope_vision_model,
)

logger = init_logger(__name__)

# Frame cap per video; used for the profile run.
_MAX_FRAMES_PER_VIDEO = 14


# === Vision Inputs === #


class Qwen2VLImagePixelInputs(TensorSchema):
    """
    Schema for image inputs supplied as pixel patches.

    Dimensions:
        - np: The total number of patches over each image over each prompt in
              the batch
        - ni: Number of images
        - cps: Number of channels * patch_size * patch_size

    Historical context:
        - pixel_values shape: (num_patches, num_channels * patch_size *
          patch_size)
        - image_grid_thw shape: (num_images, 3) in (grid_t, grid_h, grid_w)
          format
    """

    # Discriminator that distinguishes this schema from the embedding variant.
    type: Literal["pixel_values"]

    pixel_values: Annotated[
        torch.Tensor,
        TensorShape("np", "cps"),
    ]

    image_grid_thw: Annotated[
        torch.Tensor,
        TensorShape("ni", 3),
    ]


class Qwen2VLImageEmbeddingInputs(TensorSchema):
    """
    Schema for image inputs supplied as precomputed embeddings.

    Dimensions:
        - nf: Number of image features
        - hs: Hidden size
        - ni: Number of images

    Historical context:
        - image_embeds shape: (num_image_features, hidden_size)
        - num_image_features varies based on the number and resolution of the
          images.
        - hidden_size must match the hidden size of language model backbone.
        - image_grid_thw shape: (num_images, 3) in (grid_t, grid_h, grid_w)
          format
    """

    # Discriminator that distinguishes this schema from the pixel variant.
    type: Literal["image_embeds"]

    image_embeds: Annotated[
        torch.Tensor,
        TensorShape("nf", "hs"),
    ]

    image_grid_thw: Annotated[
        torch.Tensor,
        TensorShape("ni", 3),
    ]


# Image input is either raw pixel patches or precomputed embeddings.
Qwen2VLImageInputs: TypeAlias = Qwen2VLImagePixelInputs | Qwen2VLImageEmbeddingInputs


class Qwen2VLVideoPixelInputs(TensorSchema):
    """
    Schema for video inputs supplied as pixel patches.

    Dimensions:
        - np: The total number of patches over each video over each prompt in
              the batch
        - ctps: Number of channels * temporal_patch_size * patch_size *
          patch_size
        - nv: Number of videos

    Historical context:
        - pixel_values_videos shape: (num_patches, num_channels *
          temporal_patch_size * patch_size * patch_size)
        - video_grid_thw shape: (num_videos, 3) in (grid_t, grid_h, grid_w)
          format
    """

    # Discriminator that distinguishes this schema from the embedding variant.
    type: Literal["pixel_values_videos"]

    pixel_values_videos: Annotated[
        torch.Tensor,
        TensorShape("np", "ctps"),
    ]

    video_grid_thw: Annotated[
        torch.Tensor,
        TensorShape("nv", 3),
    ]


class Qwen2VLVideoEmbeddingInputs(TensorSchema):
    """
    Schema for video inputs supplied as precomputed embeddings.

    Dimensions:
        - nf: Number of video features
        - hs: Hidden size
        - nv: Number of videos

    Historical context:
        - video_embeds shape: (num_video_features, hidden_size)
        - num_video_features varies based on the number and resolution of the
          videos.
        - hidden_size must match the hidden size of language model backbone.
        - video_grid_thw shape: (num_videos, 3) in (grid_t, grid_h, grid_w)
          format
    """

    # Discriminator that distinguishes this schema from the pixel variant.
    type: Literal["video_embeds"]

    video_embeds: Annotated[
        torch.Tensor,
        TensorShape("nf", "hs"),
    ]

    video_grid_thw: Annotated[
        torch.Tensor,
        TensorShape("nv", 3),
    ]


# Video input is either raw pixel patches or precomputed embeddings.
Qwen2VLVideoInputs: TypeAlias = Qwen2VLVideoPixelInputs | Qwen2VLVideoEmbeddingInputs


# === Vision Encoder === #


class Qwen2VisionMLP(nn.Module):
    """Feed-forward block: ``fc1`` -> activation -> ``fc2``.

    ``fc1`` is a column-parallel linear layer and ``fc2`` a row-parallel one.
    """

    def __init__(
        self,
        in_features: int,
        hidden_features: int,
        act_layer: type[nn.Module] = QuickGELU,
        quant_config: QuantizationConfig | None = None,
        prefix: str = "",
        use_data_parallel: bool = False,
    ):
        """
        Args:
            in_features: Input (and output) feature size.
            hidden_features: Size of the intermediate projection.
            act_layer: Activation module class, instantiated with no arguments.
            quant_config: Quantization config forwarded to both linear layers.
            prefix: Module-name prefix; ``.fc1`` / ``.fc2`` are appended.
            use_data_parallel: Forwarded as ``disable_tp`` to both layers.
        """
        super().__init__()
        self.fc1 = ColumnParallelLinear(
            in_features,
            hidden_features,
            quant_config=quant_config,
            prefix=f"{prefix}.fc1",
            disable_tp=use_data_parallel,
        )
        self.act = act_layer()
        self.fc2 = RowParallelLinear(
            hidden_features,
            in_features,
            quant_config=quant_config,
            prefix=f"{prefix}.fc2",
            disable_tp=use_data_parallel,
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Apply ``fc2(act(fc1(x)))``.

        Both linear layers return a 2-tuple; only the first element (the
        output tensor) is used here.
        """
        x_parallel, _ = self.fc1(x)
        x_parallel = self.act(x_parallel)
        x, _ = self.fc2(x_parallel)
        return x


def rotate_half(x: torch.Tensor, interleaved: bool = False) -> torch.Tensor:
    """Rotate pairs of features along the last dimension.

    Args:
        x: Input tensor; the last dimension is split into pairs.
        interleaved: If False, element ``i`` of the first half is paired with
            element ``i`` of the second half and the result is
            ``cat(-second_half, first_half)``. If True, adjacent elements
            ``(x[2i], x[2i + 1])`` form a pair and become
            ``(-x[2i + 1], x[2i])``.

    Returns:
        A tensor with the same shape as ``x``.
    """
    if not interleaved:
        x1, x2 = x.chunk(2, dim=-1)
        return torch.cat((-x2, x1), dim=-1)

    x1, x2 = x[..., ::2], x[..., 1::2]
    # Stack to (..., d, 2) and merge the last two dims so that the pairs end
    # up interleaved again: (..., 2 * d).
    return torch.stack((-x2, x1), dim=-1).flatten(-2)


def apply_rotary_emb_torch(
    x: torch.Tensor, cos: torch.Tensor, sin: torch.Tensor, interleaved: bool = False
) -> torch.Tensor:
    """
    Apply rotary embedding to the first ``rotary_dim`` features of ``x``.

    x: (batch_size, seqlen, nheads, headdim)
    cos, sin: (seqlen, rotary_dim / 2) or (batch_size, seqlen, rotary_dim / 2)

    Features beyond ``rotary_dim`` are passed through unchanged.
    """
    ro_dim = cos.shape[-1] * 2
    assert ro_dim <= x.shape[-1]

    # Expand cos/sin to the full rotary width and insert a broadcast axis for
    # the heads dimension. The duplication layout must match rotate_half.
    pattern = "... d -> ... 1 (2 d)" if not interleaved else "... d -> ... 1 (d 2)"
    cos = repeat(cos, pattern)
    sin = repeat(sin, pattern)

    x_rot = x[..., :ro_dim]
    return torch.cat(
        [
            x_rot * cos + rotate_half(x_rot, interleaved) * sin,
            x[..., ro_dim:],
        ],
        dim=-1,
    )


def apply_rotary_pos_emb_vision(t: torch.Tensor, freqs: torch.Tensor) -> torch.Tensor:
    """Apply rotary position embedding to ``t`` using angles ``freqs``.

    The rotation is computed on a float32 copy of ``t`` and the result is cast
    back to the dtype of ``t``. The implementation is chosen by
    ``dispatch_rotary_emb_function``, with ``apply_rotary_emb_torch`` passed as
    its default.
    """
    rotary_emb_function = dispatch_rotary_emb_function(default=apply_rotary_emb_torch)
    t_ = t.float()
    cos = freqs.cos()
    sin = freqs.sin()
    output = rotary_emb_function(t_, cos, sin).type_as(t)
    return output


class Qwen2VisionAttention(nn.Module):
    """Multi-head self-attention for the vision encoder.

    Attention is computed per sequence segment delimited by ``cu_seqlens``,
    using one of the FlashAttention, torch SDPA or xFormers backends.
    """

    def __init__(
        self,
        embed_dim: int,
        num_heads: int,
        projection_size: int,
        quant_config: QuantizationConfig | None = None,
        prefix: str = "",
        use_data_parallel: bool = False,
        attn_backend_override: _Backend | None = None,
    ) -> None:
        """
        Args:
            embed_dim: Input/output feature size.
            num_heads: Total number of attention heads.
            projection_size: Width of each of the q, k and v projections.
            quant_config: Quantization config forwarded to the linear layers.
            prefix: Module-name prefix; ``.qkv`` / ``.proj`` are appended.
            use_data_parallel: When True the layer is treated as unsharded
                (``tp_size`` is 1 and ``disable_tp`` is set on the linears).
            attn_backend_override: Optional backend forwarded to the backend
                selection helpers.

        Raises:
            RuntimeError: If the resolved backend is not one of FLASH_ATTN,
                TORCH_SDPA, XFORMERS or ROCM_AITER_FA.
        """
        super().__init__()
        # Per attention head and per partition values.
        self.tp_size = (
            1
            if use_data_parallel
            else parallel_state.get_tensor_model_parallel_world_size()
        )
        # Only consulted in split_qkv when tp_size > 1.
        self.tp_rank = parallel_state.get_tensor_model_parallel_rank()
        self.hidden_size_per_attention_head = dist_utils.divide(
            projection_size, num_heads
        )
        self.num_attention_heads_per_partition = dist_utils.divide(
            num_heads, self.tp_size
        )

        self.qkv = ColumnParallelLinear(
            input_size=embed_dim,
            output_size=3 * projection_size,
            quant_config=quant_config,
            prefix=f"{prefix}.qkv",
            disable_tp=use_data_parallel,
        )
        self.proj = RowParallelLinear(
            input_size=projection_size,
            output_size=embed_dim,
            quant_config=quant_config,
            prefix=f"{prefix}.proj",
            disable_tp=use_data_parallel,
        )

        # Detect attention implementation.
        self.attn_backend = get_vit_attn_backend(
            head_size=self.hidden_size_per_attention_head,
            dtype=torch.get_default_dtype(),
            attn_backend_override=attn_backend_override,
        )
        self.use_upstream_fa = False

        # The helper may replace the backend and also returns the varlen
        # flash-attention callable used in forward().
        self.attn_backend, self.flash_attn_varlen_func = (
            maybe_get_vit_flash_attn_backend(
                self.attn_backend,
                self.use_upstream_fa,
                attn_backend_override=attn_backend_override,
            )
        )

        if self.attn_backend not in {
            _Backend.FLASH_ATTN,
            _Backend.TORCH_SDPA,
            _Backend.XFORMERS,
            _Backend.ROCM_AITER_FA,
        }:
            raise RuntimeError(
                f"Qwen2-VL does not support {self.attn_backend} backend now."
            )

        self.is_flash_attn_backend = self.attn_backend in {
            _Backend.FLASH_ATTN,
            _Backend.ROCM_AITER_FA,
        }

    def split_qkv(self, qkv: torch.Tensor) -> tuple[torch.Tensor, ...]:
        """Split a fused qkv tensor into per-head q, k and v.

        Args:
            qkv: ``[s, b, 3 * head * head_dim]`` output of ``self.qkv``.

        Returns:
            ``(q, k, v)``, each ``[s, b, head, head_dim]`` with ``head`` being
            the per-partition head count.
        """
        # [s, b, 3 * head * head_dim]
        seq_len, bs, _ = qkv.shape
        if self.tp_size > 1:
            qkv = tensor_model_parallel_all_gather(qkv)

        # [s, b, 3 * head * head_dim] -> 3 * [s, b, head * head_dim]
        q, k, v = qkv.chunk(3, dim=2)

        # 3 * [s, b, head * head_dim]
        if self.tp_size > 1:
            # After the all-gather, keep only this rank's slice of each of
            # q, k and v.
            splitter = partial(
                dist_utils.split_tensor_along_last_dim, num_partitions=self.tp_size
            )
            q = splitter(q)[self.tp_rank]
            k = splitter(k)[self.tp_rank]
            v = splitter(v)[self.tp_rank]

        # 3 * [s, b, head * head_dim] -> 3 * [s, b, head, head_dim]
        new_shape = (
            seq_len,
            bs,
            self.num_attention_heads_per_partition,
            self.hidden_size_per_attention_head,
        )
        q, k, v = (t.view(*new_shape) for t in (q, k, v))
        return q, k, v

    def forward(
        self,
        x: torch.Tensor,
        cu_seqlens: torch.Tensor,
        rotary_pos_emb: torch.Tensor,
        max_seqlen: int | None = None,  # Only used for Flash Attention
        seqlens: list[int] | None = None,  # Only used for xFormers
    ) -> torch.Tensor:
        """Run self-attention over ``x``.

        Args:
            x: ``[s, b, c]`` input.
            cu_seqlens: Cumulative segment boundaries along the sequence axis;
                attention never crosses a boundary.
            rotary_pos_emb: Rotary angles applied to q and k; skipped if None.
            max_seqlen: Longest segment length (FlashAttention backends only).
            seqlens: Per-segment lengths (xFormers backend only).

        Returns:
            The output of ``self.proj`` applied to the ``[s, b, head * head_dim]``
            attention context.
        """
        # [s, b, c] --> [s, b, 3 * head * head_dim]
        x, _ = self.qkv(x)

        # [s, b, 3 * head * head_dim] -> 3 * [s, b, head, head_dim]
        q, k, v = self.split_qkv(x)
        batch_size = q.shape[1]

        q, k, v = (rearrange(t, "s b ... -> b s ...") for t in (q, k, v))
        if rotary_pos_emb is not None:
            # Rotate q and k in a single call by stacking them on the batch
            # axis: [2 * b, s, heads, head_dim]
            qk_concat = torch.cat([q, k], dim=0)
            qk_rotated = apply_rotary_pos_emb_vision(qk_concat, rotary_pos_emb)
            q, k = torch.chunk(qk_rotated, 2, dim=0)

        if self.is_flash_attn_backend:
            q, k, v = (rearrange(t, "b s ... -> (b s) ...") for t in [q, k, v])

            output = self.flash_attn_varlen_func(
                q,
                k,
                v,
                cu_seqlens_q=cu_seqlens,
                cu_seqlens_k=cu_seqlens,
                max_seqlen_q=max_seqlen,
                max_seqlen_k=max_seqlen,
                dropout_p=0.0,
                causal=False,
            )

            context_layer = rearrange(
                output, "(b s) h d -> s b (h d)", b=batch_size
            ).contiguous()
        elif self.attn_backend == _Backend.TORCH_SDPA:
            # Execute attention entry by entry for speed & less VRAM.
            outputs = []
            for i in range(1, len(cu_seqlens)):
                start_idx = cu_seqlens[i - 1]
                end_idx = cu_seqlens[i]
                q_i = q[:, start_idx:end_idx]
                k_i = k[:, start_idx:end_idx]
                v_i = v[:, start_idx:end_idx]
                q_i, k_i, v_i = (
                    rearrange(t, "b s h d -> b h s d") for t in [q_i, k_i, v_i]
                )
                output_i = F.scaled_dot_product_attention(q_i, k_i, v_i, dropout_p=0.0)
                output_i = rearrange(output_i, "b h s d -> b s h d ")
                outputs.append(output_i)
            context_layer = torch.cat(outputs, dim=1)
            context_layer = rearrange(
                context_layer, "b s h d -> s b (h d)"
            ).contiguous()
        elif self.attn_backend == _Backend.XFORMERS:
            # Imported lazily so xformers is only required for this backend.
            from xformers import ops as xops
            from xformers.ops.fmha.attn_bias import BlockDiagonalMask

            attn_bias = BlockDiagonalMask.from_seqlens(
                q_seqlen=seqlens, kv_seqlen=None, device=q.device
            )

            context_layer = xops.memory_efficient_attention_forward(
                q, k, v, attn_bias=attn_bias, p=0, scale=None
            )
            context_layer = rearrange(
                context_layer, "b s h d -> s b (h d)"
            ).contiguous()

        output, _ = self.proj(context_layer)
        return output


class Qwen2VisionBlock(nn.Module):
    """Pre-norm transformer block: attention and MLP, each with a residual."""

    def __init__(
        self,
        dim: int,
        num_heads: int,
        mlp_ratio: float,
        act_layer: type[nn.Module] = QuickGELU,
        norm_layer: Callable[[int], nn.Module] | None = None,
        quant_config: QuantizationConfig | None = None,
        prefix: str = "",
        use_data_parallel: bool = False,
        attn_backend_override: _Backend | None = None,
    ) -> None:
        """
        Args:
            dim: Feature size of the block.
            num_heads: Number of attention heads.
            mlp_ratio: MLP hidden size is ``int(dim * mlp_ratio)``.
            act_layer: Activation module class for the MLP.
            norm_layer: Factory taking ``dim`` and returning a norm module;
                defaults to ``nn.LayerNorm`` with ``eps=1e-6``.
            quant_config: Quantization config forwarded to sub-modules.
            prefix: Module-name prefix; ``.attn`` / ``.mlp`` are appended.
            use_data_parallel: Forwarded to the attention and MLP sub-modules.
            attn_backend_override: Forwarded to the attention sub-module.
        """
        super().__init__()
        if norm_layer is None:
            norm_layer = partial(nn.LayerNorm, eps=1e-6)
        self.norm1 = norm_layer(dim)
        self.norm2 = norm_layer(dim)
        mlp_hidden_dim = int(dim * mlp_ratio)

        self.attn = Qwen2VisionAttention(
            embed_dim=dim,
            num_heads=num_heads,
            projection_size=dim,
            quant_config=quant_config,
            prefix=f"{prefix}.attn",
            use_data_parallel=use_data_parallel,
            attn_backend_override=attn_backend_override,
        )
        self.mlp = Qwen2VisionMLP(
            dim,
            mlp_hidden_dim,
            act_layer=act_layer,
            quant_config=quant_config,
            prefix=f"{prefix}.mlp",
            use_data_parallel=use_data_parallel,
        )

    def forward(
        self,
        x: torch.Tensor,
        cu_seqlens: torch.Tensor,
        rotary_pos_emb: torch.Tensor,
        max_seqlen: int | None = None,  # Only used for Flash Attention
        seqlens: list[int] | None = None,  # Only used for xFormers
    ) -> torch.Tensor:
        """Return ``x + attn(norm1(x))`` followed by ``+ mlp(norm2(.))``.

        All arguments other than ``x`` are passed through to the attention
        sub-module unchanged.
        """
        x = x + self.attn(
            self.norm1(x),
            cu_seqlens=cu_seqlens,
            rotary_pos_emb=rotary_pos_emb,
            max_seqlen=max_seqlen,
            seqlens=seqlens,
        )

        x = x + self.mlp(self.norm2(x))
        return x


class Qwen2VisionPatchEmbed(nn.Module):
    """Project flattened patches to ``embed_dim`` with a bias-free linear layer.

    The input feature size is
    ``in_channels * temporal_patch_size * patch_size * patch_size``.
    """

    def __init__(
        self,
        patch_size: int = 14,
        temporal_patch_size: int = 2,
        in_channels: int = 3,
        embed_dim: int = 1152,
    ) -> None:
        super().__init__()
        self.patch_size = patch_size
        self.temporal_patch_size = temporal_patch_size
        self.embed_dim = embed_dim

        # Volume of one (temporal, height, width) patch.
        kernel_size = (temporal_patch_size, patch_size, patch_size)
        self.proj = ReplicatedLinear(
            in_channels * math.prod(kernel_size),
            embed_dim,
            bias=False,
            # With return_bias=False the layer's output is used directly as
            # the result in forward().
            return_bias=False,
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Return ``self.proj(x)``."""
        x = self.proj(x)
        return x


class Qwen2VisionPatchMerger(nn.Module):
    """Merge groups of ``spatial_merge_size ** 2`` patch features into one.

    Features are layer-normalised, regrouped to width
    ``context_dim * spatial_merge_size ** 2`` and passed through a two-layer
    MLP whose output width is ``d_model``.
    """

    def __init__(
        self,
        d_model: int,
        context_dim: int,
        norm_layer: Callable[[int], nn.Module] | None = None,
        spatial_merge_size: int = 2,
        quant_config: QuantizationConfig | None = None,
        prefix: str = "",
        use_data_parallel: bool = False,
    ) -> None:
        """
        Args:
            d_model: Output feature size.
            context_dim: Per-patch input feature size.
            norm_layer: Factory taking ``context_dim`` and returning a norm
                module; defaults to ``nn.LayerNorm`` with ``eps=1e-6``.
            spatial_merge_size: Side length of the merged patch window.
            quant_config: Quantization config forwarded to the linear layers.
            prefix: Module-name prefix; ``.mlp.0`` / ``.mlp.2`` are appended.
            use_data_parallel: Forwarded as ``disable_tp`` to the linears.
        """
        super().__init__()
        self.hidden_size = context_dim * (spatial_merge_size**2)
        if norm_layer is None:
            norm_layer = partial(nn.LayerNorm, eps=1e-6)
        self.ln_q = norm_layer(context_dim)
        # Kept as a ModuleList because the linear layers return tuples; the
        # three entries are unpacked and called by hand in forward().
        self.mlp = nn.ModuleList(
            [
                ColumnParallelLinear(
                    self.hidden_size,
                    self.hidden_size,
                    bias=True,
                    quant_config=quant_config,
                    prefix=f"{prefix}.mlp.0",
                    disable_tp=use_data_parallel,
                ),
                nn.GELU(),
                RowParallelLinear(
                    self.hidden_size,
                    d_model,
                    bias=True,
                    quant_config=quant_config,
                    prefix=f"{prefix}.mlp.2",
                    disable_tp=use_data_parallel,
                ),
            ]
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Normalise, regroup to ``(-1, hidden_size)`` and apply the MLP."""
        x = self.ln_q(x)
        x = x.view(-1, self.hidden_size)

        mlp_fc1, mlp_act, mlp_fc2 = self.mlp
        x_parallel, _ = mlp_fc1(x)
        x_parallel = mlp_act(x_parallel)
        out, _ = mlp_fc2(x_parallel)
        return out


class Qwen2VisionRotaryEmbedding(nn.Module):
    """Table of rotary angles ``position * inv_freq`` with a growing cache.

    ``forward(seqlen)`` returns the first ``seqlen`` rows of a cached
    ``(cached_len, dim // 2)`` table. The table is rebuilt at twice the
    requested length whenever a longer one is needed.
    """

    def __init__(self, dim: int, theta: float = 10000.0) -> None:
        """
        Args:
            dim: Rotary dimension; the table has ``dim // 2`` columns.
            theta: Base of the inverse-frequency geometric progression.
        """
        super().__init__()
        self.dim = dim
        self.theta = theta
        inv_freq = 1.0 / (theta ** (torch.arange(0, dim, 2, dtype=torch.float) / dim))
        self.register_buffer("inv_freq", inv_freq, persistent=False)
        self._seq_len_cached = 0
        self._freqs_cached: torch.Tensor | None = None

    def update_freqs_cache(self, seqlen: int) -> None:
        """Ensure the cached table covers at least ``seqlen`` positions."""
        # The `is None` check builds an (empty) table on the very first call
        # even for seqlen == 0, so forward() never subscripts None.
        if seqlen > self._seq_len_cached or self._freqs_cached is None:
            # Over-allocate so that slowly growing requests do not rebuild the
            # table on every call.
            seqlen *= 2
            self._seq_len_cached = seqlen
            # inv_freq is recomputed in float32 on the buffer's current device.
            self.inv_freq = 1.0 / (
                self.theta
                ** (
                    torch.arange(
                        0, self.dim, 2, dtype=torch.float, device=self.inv_freq.device
                    )
                    / self.dim
                )
            )
            seq = torch.arange(
                seqlen, device=self.inv_freq.device, dtype=self.inv_freq.dtype
            )
            freqs = torch.outer(seq, self.inv_freq)
            self._freqs_cached = freqs

    def forward(self, seqlen: int) -> torch.Tensor:
        """Return the ``(seqlen, dim // 2)`` angle table."""
        self.update_freqs_cache(seqlen)
        return self._freqs_cached[:seqlen]


class Qwen2VisionTransformer(nn.Module):
    """Vision encoder: patch embedding, rotary attention blocks, patch merger."""

    def __init__(
        self,
        vision_config: Qwen2VLVisionConfig,
        norm_eps: float = 1e-6,
        quant_config: QuantizationConfig | None = None,
        prefix: str = "",
        use_data_parallel: bool = False,
        attn_backend_override: _Backend | None = None,
    ) -> None:
        """
        Args:
            vision_config: Source of patch sizes, widths, depth and head count.
            norm_eps: Epsilon of the LayerNorm used in blocks and merger.
            quant_config: Quantization config forwarded to blocks and merger.
            prefix: Module-name prefix for sub-modules.
            use_data_parallel: Forwarded to blocks and merger.
            attn_backend_override: Forwarded to backend selection and blocks.
        """
        super().__init__()

        patch_size = vision_config.patch_size
        temporal_patch_size = vision_config.temporal_patch_size
        spatial_merge_size = vision_config.spatial_merge_size
        in_channels = vision_config.in_channels
        hidden_size = vision_config.hidden_size
        embed_dim = vision_config.embed_dim
        depth = vision_config.depth
        num_heads = vision_config.num_heads
        mlp_ratio = vision_config.mlp_ratio

        self.use_data_parallel = use_data_parallel
        self.out_hidden_size = vision_config.hidden_size

        self.spatial_merge_size = spatial_merge_size
        self.num_heads = num_heads
        self.embed_dim = embed_dim

        self.patch_embed = Qwen2VisionPatchEmbed(
            patch_size=patch_size,
            temporal_patch_size=temporal_patch_size,
            in_channels=in_channels,
            embed_dim=embed_dim,
        )

        norm_layer = partial(nn.LayerNorm, eps=norm_eps)
        head_dim = embed_dim // num_heads
        # Half of the head dim per spatial axis: rot_pos_emb() concatenates
        # the row and column angle tables.
        self.rotary_pos_emb = Qwen2VisionRotaryEmbedding(head_dim // 2)

        self.blocks = nn.ModuleList(
            [
                Qwen2VisionBlock(
                    dim=embed_dim,
                    num_heads=num_heads,
                    mlp_ratio=mlp_ratio,
                    norm_layer=norm_layer,
                    quant_config=quant_config,
                    prefix=f"{prefix}.blocks.{layer_idx}",
                    use_data_parallel=use_data_parallel,
                    attn_backend_override=attn_backend_override,
                )
                for layer_idx in range(depth)
            ]
        )
        self.merger = Qwen2VisionPatchMerger(
            d_model=hidden_size,
            context_dim=embed_dim,
            norm_layer=norm_layer,
            quant_config=quant_config,
            prefix=f"{prefix}.merger",
            use_data_parallel=use_data_parallel,
        )
        # NOTE(review): this backend is resolved independently of the one each
        # Qwen2VisionAttention selects via maybe_get_vit_flash_attn_backend.
        # compute_attn_mask_seqlen() keys off this value, so confirm the two
        # always agree.
        self.attn_backend = get_vit_attn_backend(
            head_size=head_dim,
            dtype=torch.get_default_dtype(),
            attn_backend_override=attn_backend_override,
        )
        if self.attn_backend != _Backend.FLASH_ATTN and check_upstream_fa_availability(
            torch.get_default_dtype()
        ):
            self.attn_backend = _Backend.FLASH_ATTN

    @property
    def dtype(self) -> torch.dtype:
        """Dtype of the patch-embedding weight."""
        return self.patch_embed.proj.weight.dtype

    @property
    def device(self) -> torch.device:
        """Device of the patch-embedding weight."""
        return self.patch_embed.proj.weight.device

    def rot_pos_emb(self, grid_thw: list[list[int]]) -> torch.Tensor:
        """Build per-patch rotary angles for a list of ``(t, h, w)`` grids.

        For every grid the row and column index of each patch is generated,
        reordered so that each ``spatial_merge_size x spatial_merge_size``
        window is contiguous, and repeated for the ``t`` frames. The indices
        are then looked up in the rotary table and the row/column angles are
        flattened into one vector per patch.
        """
        pos_ids = []
        max_grid_size = 0
        merge = self.spatial_merge_size
        for t, h, w in grid_thw:
            window_shape = (h // merge, merge, w // merge, merge)
            hpos_ids = torch.arange(h).unsqueeze(1).expand(-1, w)
            wpos_ids = torch.arange(w).unsqueeze(0).expand(h, -1)
            hpos_ids = hpos_ids.reshape(*window_shape).permute(0, 2, 1, 3).flatten()
            wpos_ids = wpos_ids.reshape(*window_shape).permute(0, 2, 1, 3).flatten()
            pos_ids.append(torch.stack([hpos_ids, wpos_ids], dim=-1).repeat(t, 1))
            max_grid_size = max(max_grid_size, h, w)
        pos_ids = torch.cat(pos_ids, dim=0)
        rotary_pos_emb_full = self.rotary_pos_emb(max_grid_size)
        rotary_pos_emb = rotary_pos_emb_full[pos_ids].flatten(1)
        return rotary_pos_emb

    def compute_attn_mask_seqlen(
        self, cu_seqlens: torch.Tensor
    ) -> tuple[int | None, list[int] | None]:
        """Precompute the backend-specific segment-length argument.

        Returns:
            ``(max_seqlen, seqlens)``: ``max_seqlen`` is set for the
            FlashAttention backends, ``seqlens`` for xFormers; the other entry
            (or both, for any other backend) is None.
        """
        max_seqlen, seqlens = None, None
        if self.attn_backend in {_Backend.FLASH_ATTN, _Backend.ROCM_AITER_FA}:
            max_seqlen = (cu_seqlens[1:] - cu_seqlens[:-1]).max().item()
        elif self.attn_backend == _Backend.XFORMERS:
            seqlens = (cu_seqlens[1:] - cu_seqlens[:-1]).tolist()
        return max_seqlen, seqlens

    def forward(
        self,
        x: torch.Tensor,
        grid_thw: list[list[int]],
    ) -> torch.Tensor:
        """Encode patches ``x`` described by ``grid_thw``.

        Args:
            x: Patch features accepted by the patch embedding; moved to this
                module's device and dtype first.
            grid_thw: One ``(t, h, w)`` triple per image/video.

        Returns:
            The merger output.
        """
        # patchify
        x = x.to(device=self.device, dtype=self.dtype)
        x = self.patch_embed(x)

        # compute position embedding
        rotary_pos_emb = self.rot_pos_emb(grid_thw)

        # compute cu_seqlens: every frame (h * w patches, repeated t times per
        # item) is its own attention segment.
        grid_thw_ = torch.tensor(grid_thw, device=x.device, dtype=torch.long)
        cu_seqlens = torch.repeat_interleave(
            grid_thw_[:, 1] * grid_thw_[:, 2], grid_thw_[:, 0]
        ).cumsum(dim=0, dtype=torch.int32)
        cu_seqlens = F.pad(cu_seqlens, (1, 0), "constant", 0)

        # transformers
        x = x.unsqueeze(1)

        # pre-compute seqlens for attn mask to reduce cuMemcpy operations
        max_seqlen, seqlens = self.compute_attn_mask_seqlen(cu_seqlens)
        for blk in self.blocks:
            x = blk(
                x,
                cu_seqlens=cu_seqlens,
                rotary_pos_emb=rotary_pos_emb,
                max_seqlen=max_seqlen,
                seqlens=seqlens,
            )

        # adapter
        x = self.merger(x)

        return x

    def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]:
        """Load ``(name, tensor)`` pairs into this module's parameters.

        Weights whose name ends with ``patch_embed.proj.weight`` are first
        converted with ``conv3d_to_linear_weight``. Names containing a shard
        name from the stacked mapping are renamed and loaded with their shard
        id; all others are loaded directly.

        Returns:
            The set of parameter names that were loaded.
        """
        stacked_params_mapping = [
            # (param_name, shard_name, shard_id)
            ("qkv_proj", "q_proj", "q"),
            ("qkv_proj", "k_proj", "k"),
            ("qkv_proj", "v_proj", "v"),
        ]
        params_dict = dict(self.named_parameters(remove_duplicate=False))
        loaded_params: set[str] = set()

        for name, loaded_weight in weights:
            if name.endswith("patch_embed.proj.weight"):
                loaded_weight = conv3d_to_linear_weight(loaded_weight)

            for param_name, weight_name, shard_id in stacked_params_mapping:
                if weight_name not in name:
                    continue
                name = name.replace(weight_name, param_name)

                param = params_dict[name]
                weight_loader = param.weight_loader
                weight_loader(param, loaded_weight, shard_id)
                break
            else:
                # No stacked shard matched: load the tensor as-is.
                param = params_dict[name]
                weight_loader = getattr(param, "weight_loader", default_weight_loader)
                weight_loader(param, loaded_weight)
            loaded_params.add(name)
        return loaded_params


def _create_qwen2vl_field_factory(
    spatial_merge_size: int,
) -> Callable[
    [Mapping[str, torch.Tensor]],
    Mapping[str, MultiModalFieldConfig],
]:
    """Create the function that derives field configs from HF inputs.

    Args:
        spatial_merge_size: Divisor applied twice to a per-item patch count
            to obtain the per-item embedding count.

    Returns:
        A callable mapping HF inputs to a ``MultiModalFieldConfig`` for each
        Qwen2-VL image/video field.
    """

    def _qwen2vl_field_config(hf_inputs: Mapping[str, torch.Tensor]):
        def _patch_counts(grid_key: str) -> torch.Tensor:
            # Product over the last axis of each grid row; a missing key
            # yields an empty result.
            return hf_inputs.get(grid_key, torch.empty((0, 3))).prod(-1)

        def _embed_counts(patch_counts: torch.Tensor) -> torch.Tensor:
            return patch_counts // spatial_merge_size // spatial_merge_size

        image_patches = _patch_counts("image_grid_thw")
        image_embeds = _embed_counts(image_patches)

        video_patches = _patch_counts("video_grid_thw")
        video_embeds = _embed_counts(video_patches)

        flat = MultiModalFieldConfig.flat_from_sizes
        batched = MultiModalFieldConfig.batched

        return dict(
            pixel_values=flat("image", image_patches),
            image_embeds=flat("image", image_embeds),
            image_grid_thw=batched("image"),
            pixel_values_videos=flat("video", video_patches),
            video_embeds=flat("video", video_embeds),
            video_grid_thw=batched("video"),
        )

    return _qwen2vl_field_config
901

902

Roger Wang's avatar
Roger Wang committed
903
class Qwen2VLMultiModalDataParser(MultiModalDataParser):
    """Data parser that also accepts dicts of precomputed embeddings.

    A ``dict`` input is wrapped in ``DictEmbeddingItems``; any other input is
    delegated to the base parser.
    """

    def __init__(self, spatial_merge_size: int, *args, **kwargs):
        # Stored before the base initializer runs, as in the original order.
        self._spatial_merge_size = spatial_merge_size
        super().__init__(*args, **kwargs)

    def _parse_image_data(
        self,
        data: dict[str, torch.Tensor] | ModalityData[ImageItem],
    ) -> ModalityDataItems[Any, Any] | None:
        """Parse image data; dicts must carry `image_embeds` and `image_grid_thw`."""
        if not isinstance(data, dict):
            return super()._parse_image_data(data)

        field_factory = _create_qwen2vl_field_factory(self._spatial_merge_size)
        return DictEmbeddingItems(
            data,
            modality="image",
            required_fields={"image_embeds", "image_grid_thw"},
            fields_factory=field_factory,
        )

    def _parse_video_data(
        self,
        data: dict[str, torch.Tensor] | ModalityData[VideoItem],
    ) -> ModalityDataItems[Any, Any] | None:
        """Parse video data; dicts must carry `video_embeds` and `video_grid_thw`."""
        if not isinstance(data, dict):
            return super()._parse_video_data(data)

        field_factory = _create_qwen2vl_field_factory(self._spatial_merge_size)
        return DictEmbeddingItems(
            data,
            modality="video",
            required_fields={"video_embeds", "video_grid_thw"},
            fields_factory=field_factory,
        )


937
938
class Qwen2VLProcessingInfo(BaseProcessingInfo):
    """Processing metadata for Qwen2-VL.

    Gives access to the HF config and processor, and computes how many vision
    tokens an image or a video of a given size expands to.
    """

    def get_hf_config(self):
        """Return the HF config, requested from the context as `Qwen2VLConfig`."""
        return self.ctx.get_hf_config(Qwen2VLConfig)

    def get_hf_processor(self, **kwargs: object) -> Qwen2VLProcessor:
        """Return the HF processor; `use_fast` defaults to `True` if not given."""
        return self.ctx.get_hf_processor(
            Qwen2VLProcessor,
            use_fast=kwargs.pop("use_fast", True),
            **kwargs,
        )

    def get_image_processor(self, **kwargs: object) -> Qwen2VLImageProcessor:
        """Return the image processor attached to the HF processor."""
        return self.get_hf_processor(**kwargs).image_processor

    def get_supported_mm_limits(self) -> Mapping[str, int | None]:
        """Return the per-modality limits; `None` for both image and video."""
        return {"image": None, "video": None}

    def get_mm_max_tokens_per_item(
        self,
        seq_len: int,
        mm_counts: Mapping[str, int],
    ) -> Mapping[str, int]:
        """Return the maximum number of tokens one image / one video can use."""
        max_image_tokens = self.get_max_image_tokens()
        max_video_tokens = self.get_max_video_tokens(seq_len, mm_counts)
        return {"image": max_image_tokens, "video": max_video_tokens}

    def _get_vision_info(
        self,
        *,
        image_width: int,
        image_height: int,
        num_frames: int = 1,
        do_resize: bool = True,
        image_processor: Qwen2VLImageProcessor | None,
    ) -> tuple[ImageSize, int]:
        """Compute the preprocessed size and vision-token count of an input.

        Args:
            image_width: Input width in pixels.
            image_height: Input height in pixels.
            num_frames: Number of frames (1 for a still image).
            do_resize: Whether to apply `smart_resize` before counting patches.
            image_processor: Processor supplying `min_pixels` / `max_pixels`;
                when `None`, `self.get_image_processor()` is used.

        Returns:
            A `(preprocessed_size, num_vision_tokens)` tuple.
        """
        if image_processor is None:
            image_processor = self.get_image_processor()

        hf_config = self.get_hf_config()
        vision_config = hf_config.vision_config
        patch_size = vision_config.patch_size
        merge_size = vision_config.spatial_merge_size
        temporal_patch_size = vision_config.temporal_patch_size

        if do_resize:
            resized_height, resized_width = smart_resize(
                height=image_height,
                width=image_width,
                factor=patch_size * merge_size,
                min_pixels=image_processor.min_pixels,
                max_pixels=image_processor.max_pixels,
            )
            preprocessed_size = ImageSize(width=resized_width, height=resized_height)
        else:
            preprocessed_size = ImageSize(width=image_width, height=image_height)

        # NOTE: Frames are padded to be divisible by `temporal_patch_size`
        # https://github.com/huggingface/transformers/blob/v4.48.3/src/transformers/models/qwen2_vl/image_processing_qwen2_vl.py#L294
        # so the temporal grid is the ceiling of
        # `num_frames / temporal_patch_size`. The previous
        # `num_frames + num_frames % temporal_patch_size` form only reached a
        # multiple of `temporal_patch_size` when that size was 1 or 2.
        grid_t = max(
            (num_frames + temporal_patch_size - 1) // temporal_patch_size,
            1,
        )
        grid_h = preprocessed_size.height // patch_size
        grid_w = preprocessed_size.width // patch_size

        num_patches = grid_t * grid_h * grid_w
        num_vision_tokens = num_patches // (merge_size**2)

        return preprocessed_size, num_vision_tokens

    def get_num_image_tokens(
        self,
        *,
        image_width: int,
        image_height: int,
        image_processor: Qwen2VLImageProcessor | None,
    ) -> int:
        """Return the vision-token count of a single image of the given size."""
        _, num_image_tokens = self._get_vision_info(
            image_width=image_width,
            image_height=image_height,
            num_frames=1,
            image_processor=image_processor,
        )
        return num_image_tokens

    def get_num_video_tokens(
        self,
        *,
        image_width: int,
        image_height: int,
        num_frames: int,
        image_processor: Qwen2VLImageProcessor | None,
    ) -> int:
        """Return the vision-token count of a video with `num_frames` frames."""
        _, num_video_tokens = self._get_vision_info(
            image_width=image_width,
            image_height=image_height,
            num_frames=num_frames,
            image_processor=image_processor,
        )
        return num_video_tokens

    def get_image_size_with_most_features(self) -> ImageSize:
        """Return the preprocessed size obtained for a very large input image."""
        # A deliberately oversized request; the size returned is whatever
        # `_get_vision_info` reports after its resize step.
        max_image_size, _ = self._get_vision_info(
            image_width=9999999,
            image_height=9999999,
            num_frames=1,
            image_processor=None,
        )
        return max_image_size

    def get_max_image_tokens(self) -> int:
        """Return the token count of the image size with the most features."""
        target_width, target_height = self.get_image_size_with_most_features()

        return self.get_num_image_tokens(
            image_width=target_width,
            image_height=target_height,
            image_processor=None,
        )

    def _get_max_video_frames(self, max_tokens: int, start_num_frames: int = 1) -> int:
        """Return the largest frame count whose tokens stay within `max_tokens`.

        The search starts at `start_num_frames` and grows one frame at a time
        until adding another frame would exceed `max_tokens`.
        """
        target_width, target_height = self.get_image_size_with_most_features()

        num_frames = start_num_frames

        while True:
            next_num_frames = num_frames + 1
            next_max_tokens = self.get_num_video_tokens(
                image_width=target_width,
                image_height=target_height,
                num_frames=next_num_frames,
                image_processor=None,
            )

            if next_max_tokens > max_tokens:
                break

            num_frames = next_num_frames

        return num_frames

    def get_num_frames_with_most_features(
        self,
        seq_len: int,
        mm_counts: Mapping[str, int],
        max_frames_per_video: int = _MAX_FRAMES_PER_VIDEO,
    ) -> int:
        """Return the per-video frame count used for worst-case sizing.

        The frame budget that fits in `seq_len` is split evenly across the
        requested number of videos, capped at `max_frames_per_video`, and is
        never less than 1.
        """
        max_videos = mm_counts.get("video", 0)

        max_total_frames = self._get_max_video_frames(seq_len)
        max_frames_per_video = min(
            max_total_frames // max(max_videos, 1), max_frames_per_video
        )

        return max(max_frames_per_video, 1)

    def get_max_video_tokens(
        self,
        seq_len: int,
        mm_counts: Mapping[str, int],
    ) -> int:
        """Return the token count of a video at the largest size and frame count."""
        target_width, target_height = self.get_image_size_with_most_features()

        return self.get_num_video_tokens(
            image_width=target_width,
            image_height=target_height,
            num_frames=self.get_num_frames_with_most_features(seq_len, mm_counts),
            image_processor=None,
        )

1105
1106

class Qwen2VLDummyInputsBuilder(BaseDummyInputsBuilder[Qwen2VLProcessingInfo]):
    """Builds dummy prompt text and dummy image/video data for Qwen2-VL."""

    def get_dummy_text(self, mm_counts: Mapping[str, int]) -> str:
        """Return one image token per image followed by one video token per video."""
        image_count = mm_counts.get("image", 0)
        video_count = mm_counts.get("video", 0)

        processor = self.info.get_hf_processor()
        image_token: str = processor.image_token
        video_token: str = processor.video_token

        image_part = image_token * image_count
        video_part = video_token * video_count
        return image_part + video_part

    def get_dummy_mm_data(
        self,
        seq_len: int,
        mm_counts: Mapping[str, int],
        mm_options: Mapping[str, BaseDummyOptions] | None = None,
    ) -> MultiModalDataDict:
        """Return dummy images and videos at the size with the most features.

        Args:
            seq_len: Sequence length used to pick the dummy video frame count.
            mm_counts: Number of items requested per modality.
            mm_options: Optional per-modality overrides forwarded to the
                dummy-data helpers.
        """
        image_count = mm_counts.get("image", 0)
        video_count = mm_counts.get("video", 0)

        width, height = self.info.get_image_size_with_most_features()
        frame_count = self.info.get_num_frames_with_most_features(seq_len, mm_counts)

        if mm_options:
            image_overrides = mm_options.get("image")
            video_overrides = mm_options.get("video")
        else:
            image_overrides = None
            video_overrides = None

        dummy_images = self._get_dummy_images(
            width=width,
            height=height,
            num_images=image_count,
            overrides=image_overrides,
        )
        dummy_videos = self._get_dummy_videos(
            width=width,
            height=height,
            num_frames=frame_count,
            num_videos=video_count,
            overrides=video_overrides,
        )
        return {"image": dummy_images, "video": dummy_videos}

1150

1151
class Qwen2VLMultiModalProcessor(BaseMultiModalProcessor[Qwen2VLProcessingInfo]):
    """Multimodal processor for Qwen2-VL image and video inputs."""

    def _get_data_parser(self) -> MultiModalDataParser:
        """Return a data parser configured with the model's spatial merge size."""
        merge_size = self.info.get_hf_config().vision_config.spatial_merge_size
        return Qwen2VLMultiModalDataParser(merge_size)

    def _get_prompt_updates(
        self,
        mm_items: MultiModalDataItems,
        hf_processor_mm_kwargs: Mapping[str, Any],
        out_mm_kwargs: MultiModalKwargsItems,
    ) -> Sequence[PromptUpdate]:
        """Return replacements expanding each placeholder token per item.

        Each image/video placeholder is replaced by as many copies of itself
        as the item's `grid_thw` product divided by `merge_size ** 2`.
        """
        hf_processor = self.info.get_hf_processor(**hf_processor_mm_kwargs)
        image_processor = self.info.get_image_processor(**hf_processor_mm_kwargs)
        vocab = self.info.get_tokenizer().get_vocab()

        placeholder_ids = {
            "image": vocab[hf_processor.image_token],
            "video": vocab[hf_processor.video_token],
        }
        patches_per_token = image_processor.merge_size**2

        def get_replacement_qwen2vl(item_idx: int, modality: str):
            grid_thw = out_mm_kwargs[modality][item_idx][f"{modality}_grid_thw"].data
            assert isinstance(grid_thw, torch.Tensor)

            token_count = int(grid_thw.prod()) // patches_per_token
            return [placeholder_ids[modality]] * token_count

        updates = []
        for modality in ("image", "video"):
            updates.append(
                PromptReplacement(
                    modality=modality,
                    target=[placeholder_ids[modality]],
                    replacement=partial(get_replacement_qwen2vl, modality=modality),
                )
            )
        return updates

    def _get_mm_fields_config(
        self,
        hf_inputs: BatchFeature,
        hf_processor_mm_kwargs: Mapping[str, object],
    ) -> Mapping[str, MultiModalFieldConfig]:
        """Return the field configs derived from the HF processor outputs."""
        merge_size = self.info.get_hf_config().vision_config.spatial_merge_size
        field_factory = _create_qwen2vl_field_factory(merge_size)
        return field_factory(hf_inputs)


@MULTIMODAL_REGISTRY.register_processor(
    Qwen2VLMultiModalProcessor,
    info=Qwen2VLProcessingInfo,
    dummy_inputs=Qwen2VLDummyInputsBuilder,
)
class Qwen2VLForConditionalGeneration(
    nn.Module, SupportsMultiModal, SupportsLoRA, SupportsPP, SupportsMRoPE
):
1210
    # Prefix rewrites applied to checkpoint weight names when loading.
    hf_to_vllm_mapper = WeightsMapper(
        orig_to_new_prefix={
            # Names used by checkpoints saved after transformers v4.52.
            "model.language_model.": "language_model.model.",
            "model.visual.": "visual.",
            # Names used by the original checkpoints.
            "lm_head.": "language_model.lm_head.",
            "model.": "language_model.model.",
        }
    )

    supports_encoder_tp_data = True

1224
1225
1226
1227
    def get_mrope_input_positions(
        self,
        input_tokens: list[int],
        hf_config: PretrainedConfig,
        image_grid_thw: list[list[int]] | torch.Tensor | None,
        video_grid_thw: list[list[int]] | torch.Tensor | None,
        second_per_grid_ts: list[float] | None = None,
        context_len: int = 0,
        seq_len: int | None = None,
        audio_feature_lengths: torch.Tensor | None = None,
        use_audio_in_video: bool = False,
    ) -> tuple[torch.Tensor, int]:
        """Get M-RoPE input positions for Qwen2-VL model.

        Text tokens receive the same increasing index on all three rows.
        Tokens of an image/video item receive separate temporal, height and
        width indices, offset so they follow the preceding text.

        Args:
            input_tokens: Prompt token ids.
            hf_config: Config providing the image/video/vision-start token ids
                and the vision config.
            image_grid_thw: One `(t, h, w)` entry per image, or `None`.
            video_grid_thw: One `(t, h, w)` entry per video, or `None`.
            second_per_grid_ts: Optional per-video temporal scale; 1.0 is used
                when empty.
            context_len: Start of the returned position slice.
            seq_len: End of the returned position slice.
            audio_feature_lengths: Not used by this implementation.
            use_audio_in_video: Not used by this implementation.

        Returns:
            `(positions, delta)`: the sliced positions with 3 rows, and the
            maximum position plus one minus the number of input tokens.
        """
        image_grid_thw = [] if image_grid_thw is None else image_grid_thw
        video_grid_thw = [] if video_grid_thw is None else video_grid_thw
        second_per_grid_ts = [] if second_per_grid_ts is None else second_per_grid_ts

        image_token_id = hf_config.image_token_id
        video_token_id = hf_config.video_token_id
        vision_start_token_id = hf_config.vision_start_token_id
        spatial_merge_size = hf_config.vision_config.spatial_merge_size
        tokens_per_second = getattr(hf_config.vision_config, "tokens_per_second", 1.0)

        # Count images/videos by looking at the token right after each
        # vision-start token.
        token_tensor = torch.tensor(input_tokens)
        start_positions = torch.argwhere(
            token_tensor == vision_start_token_id
        ).squeeze(1)
        leading_vision_tokens = token_tensor[start_positions + 1]
        image_nums = (leading_vision_tokens == image_token_id).sum()
        video_nums = (leading_vision_tokens == video_token_id).sum()

        def _next_placeholder(token_id, remaining, start):
            # Index of the next `token_id` at or after `start`; a value past
            # the end of the prompt when none remains or none is found.
            if remaining > 0:
                try:
                    return input_tokens.index(token_id, start)
                except ValueError:
                    pass
            return len(input_tokens) + 1

        pos_chunks: list = []
        cursor = 0
        remain_images, remain_videos = image_nums, video_nums
        image_index, video_index = 0, 0

        for _ in range(image_nums + video_nums):
            ed_image = _next_placeholder(image_token_id, remain_images, cursor)
            ed_video = _next_placeholder(video_token_id, remain_videos, cursor)

            if ed_image < ed_video:
                grid = image_grid_thw[image_index]
                t, h, w = grid[0], grid[1], grid[2]
                video_second_per_grid_t = 0.0
                image_index += 1
                remain_images -= 1
                ed = ed_image
            else:
                grid = video_grid_thw[video_index]
                t, h, w = grid[0], grid[1], grid[2]
                video_second_per_grid_t = (
                    second_per_grid_ts[video_index] if second_per_grid_ts else 1.0
                )
                video_index += 1
                remain_videos -= 1
                ed = ed_video

            llm_grid_t = t
            llm_grid_h = h // spatial_merge_size
            llm_grid_w = w // spatial_merge_size
            text_len = ed - cursor

            # Text preceding this item continues from the last position used.
            st_idx = pos_chunks[-1].max() + 1 if pos_chunks else 0
            text_positions = torch.arange(text_len).view(1, -1).expand(3, -1)
            pos_chunks.append(text_positions + st_idx)

            t_grid = (
                torch.arange(llm_grid_t).view(-1, 1).expand(-1, llm_grid_h * llm_grid_w)
            )
            t_index = (
                (t_grid * video_second_per_grid_t * tokens_per_second).long().flatten()
            )
            h_index = (
                torch.arange(llm_grid_h)
                .view(1, -1, 1)
                .expand(llm_grid_t, -1, llm_grid_w)
                .flatten()
            )
            w_index = (
                torch.arange(llm_grid_w)
                .view(1, 1, -1)
                .expand(llm_grid_t, llm_grid_h, -1)
                .flatten()
            )
            vision_positions = torch.stack([t_index, h_index, w_index])
            pos_chunks.append(vision_positions + text_len + st_idx)

            cursor = ed + llm_grid_t * llm_grid_h * llm_grid_w

        # Text remaining after the last vision item.
        if cursor < len(input_tokens):
            st_idx = pos_chunks[-1].max() + 1 if pos_chunks else 0
            text_len = len(input_tokens) - cursor
            tail_positions = torch.arange(text_len).view(1, -1).expand(3, -1)
            pos_chunks.append(tail_positions + st_idx)

        llm_positions = torch.cat(pos_chunks, dim=1).reshape(3, -1)
        mrope_position_delta = (llm_positions.max() + 1 - len(input_tokens)).item()
        llm_positions = llm_positions[:, context_len:seq_len]

        return llm_positions, mrope_position_delta

1355
    @classmethod
1356
    def get_placeholder_str(cls, modality: str, i: int) -> str | None:
1357
1358
1359
1360
1361
1362
1363
        if modality.startswith("image"):
            return "<|vision_start|><|image_pad|><|vision_end|>"
        if modality.startswith("video"):
            return "<|vision_start|><|video_pad|><|vision_end|>"

        raise ValueError("Only image or video modality is supported")

1364
    def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
        """Build the vision tower (when enabled) and the language model.

        Args:
            vllm_config: Supplies the HF config, quantization config and
                multimodal config.
            prefix: Module-name prefix passed to `maybe_prefix` for submodules.
        """
        super().__init__()
        config: Qwen2VLConfig = vllm_config.model_config.hf_config
        quant_config = vllm_config.quant_config
        multimodal_config = vllm_config.model_config.multimodal_config

        self.use_data_parallel = multimodal_config.mm_encoder_tp_mode == "data"
        self.config = config
        self.multimodal_config = multimodal_config

        # The vision tower is only built when at least one of the image/video
        # per-prompt limits is truthy.
        if multimodal_config.get_limit_per_prompt(
            "image"
        ) or multimodal_config.get_limit_per_prompt("video"):
            # `multimodal_config` has already been dereferenced above, so it
            # cannot be None here and needs no fallback.
            self.visual = Qwen2VisionTransformer(
                config.vision_config,
                norm_eps=getattr(config, "rms_norm_eps", 1e-6),
                quant_config=quant_config,
                prefix=maybe_prefix(prefix, "visual"),
                use_data_parallel=self.use_data_parallel,
                attn_backend_override=multimodal_config.mm_encoder_attn_backend,
            )
        else:
            self.visual = None

        self.language_model = init_vllm_registered_model(
            vllm_config=vllm_config,
            prefix=maybe_prefix(prefix, "language_model"),
            architectures=["Qwen2ForCausalLM"],
        )

        self.make_empty_intermediate_tensors = (
            self.language_model.make_empty_intermediate_tensors
        )
1402

1403
1404
1405
    def _validate_and_reshape_mm_tensor(
        self, mm_input: object, name: str
    ) -> torch.Tensor:
1406
        if not isinstance(mm_input, (torch.Tensor, list)):
1407
            raise ValueError(f"Incorrect type of {name}. Got type: {type(mm_input)}")
1408
1409
1410
1411
        if isinstance(mm_input, torch.Tensor):
            if mm_input.ndim == 2:
                return mm_input
            if mm_input.ndim != 3:
1412
1413
1414
1415
1416
                raise ValueError(
                    f"{name} should be 2D or batched 3D tensor. "
                    f"Got ndim: {mm_input.ndim} "
                    f"(shape={mm_input.shape})"
                )
1417
            return mm_input.reshape(-1, mm_input.shape[-1])
1418
1419
1420
1421
        else:
            return torch.concat(mm_input)

    def _parse_and_validate_image_input(
        self, **kwargs: object
    ) -> Qwen2VLImageInputs | None:
        """Build the image input structure from keyword arguments.

        Pixel values take precedence over precomputed embeddings. Returns
        `None` when neither `pixel_values` nor `image_embeds` is given.
        """
        pixel_values = kwargs.pop("pixel_values", None)
        image_embeds = kwargs.pop("image_embeds", None)
        image_grid_thw = kwargs.pop("image_grid_thw", None)

        if pixel_values is not None:
            return Qwen2VLImagePixelInputs(
                type="pixel_values",
                pixel_values=self._validate_and_reshape_mm_tensor(
                    pixel_values, "image pixel values"
                ),
                image_grid_thw=self._validate_and_reshape_mm_tensor(
                    image_grid_thw, "image grid_thw"
                ),
            )

        if image_embeds is not None:
            return Qwen2VLImageEmbeddingInputs(
                type="image_embeds",
                image_embeds=self._validate_and_reshape_mm_tensor(
                    image_embeds, "image embeds"
                ),
                image_grid_thw=self._validate_and_reshape_mm_tensor(
                    image_grid_thw, "image grid_thw"
                ),
            )

        return None
1458
1459

    def _parse_and_validate_video_input(
        self, **kwargs: object
    ) -> Qwen2VLVideoInputs | None:
        """Build the video input structure from keyword arguments.

        Pixel values take precedence over precomputed embeddings. Returns
        `None` when neither `pixel_values_videos` nor `video_embeds` is given.
        """
        pixel_values_videos = kwargs.pop("pixel_values_videos", None)
        video_embeds = kwargs.pop("video_embeds", None)
        video_grid_thw = kwargs.pop("video_grid_thw", None)

        if pixel_values_videos is not None:
            return Qwen2VLVideoPixelInputs(
                type="pixel_values_videos",
                pixel_values_videos=self._validate_and_reshape_mm_tensor(
                    pixel_values_videos, "video pixel values"
                ),
                video_grid_thw=self._validate_and_reshape_mm_tensor(
                    video_grid_thw, "video grid_thw"
                ),
            )

        if video_embeds is not None:
            return Qwen2VLVideoEmbeddingInputs(
                type="video_embeds",
                video_embeds=self._validate_and_reshape_mm_tensor(
                    video_embeds, "video embeds"
                ),
                video_grid_thw=self._validate_and_reshape_mm_tensor(
                    video_grid_thw, "video grid_thw"
                ),
            )

        return None
1496

1497
    def _process_image_input(
        self, image_input: Qwen2VLImageInputs
    ) -> tuple[torch.Tensor, ...]:
        """Return one embedding tensor per image.

        Precomputed embeddings are used directly; pixel values are run through
        the vision tower (via the data-parallel helper when enabled, whose
        result is returned as is).
        """
        grid_thw = image_input["image_grid_thw"]
        assert grid_thw.ndim == 2
        grid_thw_list = grid_thw.tolist()

        if image_input["type"] != "image_embeds":
            pixel_values = image_input["pixel_values"]
            if self.use_data_parallel:
                return run_dp_sharded_mrope_vision_model(
                    self.visual, pixel_values, grid_thw_list, rope_type="rope_3d"
                )
            image_embeds = self.visual(pixel_values, grid_thw=grid_thw_list)
        else:
            image_embeds = image_input["image_embeds"]

        # Split the concatenated embeddings back into one chunk per image.
        merge_size = self.visual.spatial_merge_size
        patches_per_item = torch.tensor(grid_thw_list, dtype=torch.long).prod(-1)
        tokens_per_item = patches_per_item // (merge_size * merge_size)
        return image_embeds.split(tokens_per_item.tolist())
1524
1525

    def _process_video_input(
        self, video_input: Qwen2VLVideoInputs
    ) -> tuple[torch.Tensor, ...]:
        """Return one embedding tensor per video.

        Precomputed embeddings are used directly; pixel values are run through
        the vision tower (via the data-parallel helper when enabled, whose
        result is returned as is).
        """
        grid_thw = video_input["video_grid_thw"]
        assert grid_thw.ndim == 2
        grid_thw_list = grid_thw.tolist()

        if video_input["type"] != "video_embeds":
            pixel_values_videos = video_input["pixel_values_videos"]
            if self.use_data_parallel:
                return run_dp_sharded_mrope_vision_model(
                    self.visual, pixel_values_videos, grid_thw_list, rope_type="rope_3d"
                )
            video_embeds = self.visual(pixel_values_videos, grid_thw=grid_thw_list)
        else:
            video_embeds = video_input["video_embeds"]

        # Split the concatenated embeddings back into one chunk per video.
        merge_size = self.visual.spatial_merge_size
        patches_per_item = torch.tensor(grid_thw_list, dtype=torch.long).prod(-1)
        tokens_per_item = patches_per_item // (merge_size * merge_size)
        return video_embeds.split(tokens_per_item.tolist())
1551
1552
1553
1554
1555
1556
1557

    def _parse_and_validate_multimodal_inputs(self, **kwargs: object) -> dict:
        """Parse image and video inputs, keyed by "images" / "videos".

        The result's key order follows the order in which the corresponding
        keyword arguments first appear.
        """
        image_keys = ("pixel_values", "image_embeds")
        video_keys = ("pixel_values_videos", "video_embeds")
        parsed = {}

        # Walking kwargs in order keeps the modalities in caller order.
        for key in kwargs:
            if "images" not in parsed and key in image_keys:
                parsed["images"] = self._parse_and_validate_image_input(**kwargs)
            if "videos" not in parsed and key in video_keys:
                parsed["videos"] = self._parse_and_validate_video_input(**kwargs)

        return parsed
1570

1571
1572
1573
    def get_language_model(self) -> torch.nn.Module:
        """Return the wrapped language model module."""
        language_model = self.language_model
        return language_model

1574
    def get_multimodal_embeddings(self, **kwargs: object) -> MultiModalEmbeddings:
        """Compute embeddings for every multimodal item in `kwargs`.

        Returns an empty list when no image or video input is present;
        otherwise a tuple with one tensor per item (image or video), ordered
        by modality as parsed from `kwargs`.
        """
        modalities = self._parse_and_validate_multimodal_inputs(**kwargs)
        if not modalities:
            return []

        # One tensor per multimodal data item (image or video).
        collected: list[torch.Tensor] = []

        # NOTE: Iterating the dictionary preserves the order of the modalities.
        for modality, parsed_input in modalities.items():
            if modality == "images":
                collected.extend(self._process_image_input(parsed_input))
            if modality == "videos":
                collected.extend(self._process_video_input(parsed_input))

        return tuple(collected)

1597
1598
1599
1600
    def forward(
        self,
        input_ids: torch.Tensor,
        positions: torch.Tensor,
        intermediate_tensors: IntermediateTensors | None = None,
        inputs_embeds: torch.Tensor | None = None,
        **kwargs: object,
    ) -> torch.Tensor | IntermediateTensors:
        """Run forward pass for Qwen2-VL.

        Args:
            input_ids: Flattened (concatenated) input_ids corresponding to a
                batch.
            positions: Flattened (concatenated) position ids corresponding to a
                batch.
                **NOTE**: If mrope is enabled (default setting for Qwen2-VL
                opensource models), the shape will be `(3, seq_len)`,
                otherwise it will be `(seq_len,)`.
            intermediate_tensors: Intermediate tensors from prior forward pass.
            inputs_embeds: Optional tensor of input embeddings. Ignored when
                `intermediate_tensors` is provided.
        """
        # Input embeddings are dropped whenever intermediate tensors are given.
        embeds = inputs_embeds if intermediate_tensors is None else None

        return self.language_model.model(
            input_ids=input_ids,
            positions=positions,
            intermediate_tensors=intermediate_tensors,
            inputs_embeds=embeds,
        )

1630
1631
1632
    def compute_logits(
        self,
        hidden_states: torch.Tensor,
    ) -> torch.Tensor | None:
        """Delegate logits computation to the language model."""
        logits = self.language_model.compute_logits(hidden_states)
        return logits
1635

1636
    def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]:
        """Load weights through `AutoWeightsLoader` using `hf_to_vllm_mapper`.

        Weights under the ``visual.`` prefix are skipped when no vision tower
        was built. Returns whatever the loader's `load_weights` returns.
        """
        skip_prefixes = ["visual."] if self.visual is None else []
        weights_loader = AutoWeightsLoader(self, skip_prefixes=skip_prefixes)
        return weights_loader.load_weights(weights, mapper=self.hf_to_vllm_mapper)
1642
1643
1644
1645
1646
1647
1648

    def get_mm_mapping(self) -> MultiModelKeys:
        """Return the module prefixes of the language model, connector and tower."""
        prefixes = {
            "language_model": "language_model",
            "connector": "visual.merger.",
            "tower_model": "visual.",
        }
        return MultiModelKeys.from_string_field(**prefixes)
1652
1653
1654
1655
1656
1657
1658
1659
1660


class Tarsier2MultiModalProcessor(Qwen2VLMultiModalProcessor):
    """Multimodal processor for Tarsier2.

    Adds nothing of its own: all behavior is inherited unchanged from
    ``Qwen2VLMultiModalProcessor``.
    """


class Tarsier2ImageProcessor(Qwen2VLImageProcessor):
    """Qwen2-VL image processor that also accepts Tarsier2-style ``size`` keys."""

    def __init__(
        self,
        size: dict[str, int] | None = None,
        **kwargs,
    ) -> None:
        """Initialize the processor, remapping Tarsier2 size keys if present.

        Args:
            size: Optional size mapping. When it contains both
                ``"min_pixels"`` and ``"max_pixels"``, a new mapping with
                ``"shortest_edge"`` / ``"longest_edge"`` is passed on in its
                place; otherwise ``size`` is forwarded unchanged.
            **kwargs: Forwarded to ``Qwen2VLImageProcessor.__init__``.
        """
        if size is not None and "min_pixels" in size and "max_pixels" in size:
            # Remap if Tarsier2-specific format is provided. A fresh dict is
            # built, so the caller's mapping is not modified.
            size = {
                "shortest_edge": size["min_pixels"],
                "longest_edge": size["max_pixels"],
            }
        super().__init__(size=size, **kwargs)


class Tarsier2Processor(Qwen2VLProcessor):
    """Qwen2-VL processor assembled from a Tarsier2 image-processor config."""

    def __init__(
        self,
        vision_config: dict,
        tokenizer: AnyTokenizer,
        **kwargs,
    ):
        """Build the image/video processors and initialize the base class.

        Args:
            vision_config: Keyword arguments used to construct both the
                ``Tarsier2ImageProcessor`` and the ``Qwen2VLVideoProcessor``.
            tokenizer: Tokenizer handed to ``Qwen2VLProcessor``.
            **kwargs: Extra keyword arguments forwarded to
                ``Qwen2VLProcessor.__init__``.
        """
        self.image_processor = Tarsier2ImageProcessor(**vision_config)
        super().__init__(
            image_processor=self.image_processor,
            tokenizer=tokenizer,
            video_processor=Qwen2VLVideoProcessor(**vision_config),
            # The base processor is always created without a chat template.
            chat_template=None,
            **kwargs,
        )
1690
1691
1692
1693
1694
1695
1696
1697
1698
1699
1700
1701
1702
1703
1704
1705
1706
1707
1708


class Tarsier2ProcessingInfo(Qwen2VLProcessingInfo):
    """Processing info that supplies Tarsier2's config and processors."""

    def get_hf_config(self) -> Qwen2VLConfig:
        """Return the checkpoint config re-parsed as a ``Qwen2VLConfig``.

        The config is loaded with ``AutoConfig`` from the model path, dumped
        to a dict, and rebuilt through ``Qwen2VLConfig.from_dict``.
        """
        model_path = self.ctx.model_config.model
        config_dict = AutoConfig.from_pretrained(model_path).to_dict()
        return Qwen2VLConfig.from_dict(config_dict)

    def get_hf_processor(self, **kwargs: object) -> Tarsier2Processor:
        """Create a ``Tarsier2Processor`` from the image-processor config.

        Args:
            **kwargs: Extra keyword arguments forwarded to
                ``Tarsier2Processor``.
        """
        return Tarsier2Processor(
            vision_config=self.ctx.get_hf_image_processor_config(),
            tokenizer=self.get_tokenizer(),
            **kwargs,
        )

    def get_image_processor(self) -> Tarsier2ImageProcessor:
        """Create a ``Tarsier2ImageProcessor`` from the image-processor config."""
        return Tarsier2ImageProcessor(**self.ctx.get_hf_image_processor_config())
1710
1711


1712
1713
1714
1715
1716
@MULTIMODAL_REGISTRY.register_processor(
    Tarsier2MultiModalProcessor,
    info=Tarsier2ProcessingInfo,
    dummy_inputs=Qwen2VLDummyInputsBuilder,
)
class Tarsier2ForConditionalGeneration(Qwen2VLForConditionalGeneration):
    """Tarsier2 model served through the Qwen2-VL implementation."""

    # Checkpoint names starting with "vision_tower." are renamed to "visual.".
    hf_to_vllm_mapper = WeightsMapper(
        orig_to_new_prefix={
            "vision_tower.": "visual.",
        }
    )

    def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
        """Swap the top-level HF config for its text config, then initialize.

        Args:
            vllm_config: vLLM configuration; its ``model_config.hf_config``
                is replaced in place by the nested text config.
            prefix: Module-name prefix forwarded to the base class.
        """
        # Tarsier2 uses llava as its model_type, which creates a
        # Qwen2VLConfig as text_config, so the Qwen2VLConfig has to be
        # recovered from the LlavaConfig before the base class sees it.
        config = vllm_config.model_config.hf_config
        qwen2vl_config = config.text_config
        # Carry the architectures of the outer config over to the text config.
        qwen2vl_config.architectures = config.architectures
        vllm_config.model_config.hf_config = qwen2vl_config
        super().__init__(vllm_config=vllm_config, prefix=prefix)

    def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]:
        """Load checkpoint weights into this model via ``AutoWeightsLoader``.

        Args:
            weights: Iterable of ``(name, tensor)`` pairs from the checkpoint.

        Returns:
            The set returned by ``AutoWeightsLoader.load_weights``
            (presumably the names of the loaded parameters -- confirm
            against the loader).
        """
        # When no vision tower was built (``self.visual is None``), ask the
        # loader to skip every checkpoint entry under the "visual." prefix.
        skip_prefixes = ["visual."] if self.visual is None else []
        loader = AutoWeightsLoader(self, skip_prefixes=skip_prefixes)
        return loader.load_weights(weights, mapper=self.hf_to_vllm_mapper)