qwen2_vl.py 59 KB
Newer Older
1
# SPDX-License-Identifier: Apache-2.0
2
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
3

4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
# Adapted from
# https://github.com/huggingface/transformers/blob/19e6e80e10118f855137b90740936c0b11ac397f/src/transformers/models/qwen2_vl/modeling_qwen2_vl.py
# Copyright 2024 The Qwen team.
# Copyright 2023 The vLLM team.
# Copyright 2022 EleutherAI and the HuggingFace Inc. team. All rights reserved.
#
# This code is based on EleutherAI's GPT-NeoX library and the GPT-NeoX
# and OPT implementations in this library. It has been modified from its
# original forms to accommodate minor architectural differences compared
# to GPT-NeoX and OPT used by the Meta AI team that trained the model.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Inference-only Qwen2-VL model compatible with HuggingFace weights."""
27

28
from collections.abc import Callable, Iterable, Mapping, Sequence
29
from functools import partial
30
from typing import Annotated, Any, Literal, TypeAlias
31
32
33
34
35

import torch
import torch.nn as nn
import torch.nn.functional as F
from einops import rearrange, repeat
36
from transformers import AutoConfig, BatchFeature, PretrainedConfig
37
from transformers.models.qwen2_vl import Qwen2VLImageProcessor, Qwen2VLProcessor
38
from transformers.models.qwen2_vl.configuration_qwen2_vl import (
39
40
41
    Qwen2VLConfig,
    Qwen2VLVisionConfig,
)
42
from transformers.models.qwen2_vl.image_processing_qwen2_vl import smart_resize
43
from transformers.models.qwen2_vl.video_processing_qwen2_vl import Qwen2VLVideoProcessor
44

45
from vllm.attention.backends.registry import _Backend
46
47
48
49
from vllm.attention.layer import (
    check_upstream_fa_availability,
    maybe_get_vit_flash_attn_backend,
)
50
from vllm.config import VllmConfig
51
from vllm.config.multimodal import BaseDummyOptions
52
from vllm.distributed import parallel_state, tensor_model_parallel_all_gather
53
54
55
from vllm.distributed import utils as dist_utils
from vllm.logger import init_logger
from vllm.model_executor.layers.activation import QuickGELU
56
from vllm.model_executor.layers.linear import ColumnParallelLinear, RowParallelLinear
57
from vllm.model_executor.layers.quantization import QuantizationConfig
58
from vllm.model_executor.layers.rotary_embedding.common import (
59
60
    dispatch_rotary_emb_function,
)
61
from vllm.model_executor.model_loader.weight_utils import default_weight_loader
62
from vllm.model_executor.models.module_mapping import MultiModelKeys
63
from vllm.multimodal import MULTIMODAL_REGISTRY
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
from vllm.multimodal.inputs import (
    ImageItem,
    ModalityData,
    MultiModalDataDict,
    MultiModalFieldConfig,
    MultiModalKwargsItems,
    VideoItem,
)
from vllm.multimodal.parse import (
    DictEmbeddingItems,
    ImageSize,
    ModalityDataItems,
    MultiModalDataItems,
    MultiModalDataParser,
)
from vllm.multimodal.processing import (
    BaseMultiModalProcessor,
    BaseProcessingInfo,
    PromptReplacement,
    PromptUpdate,
)
85
from vllm.multimodal.profiling import BaseDummyInputsBuilder
86
from vllm.sequence import IntermediateTensors
87
from vllm.transformers_utils.tokenizer import AnyTokenizer
88
from vllm.utils.tensor_schema import TensorSchema, TensorShape
89

90
91
92
93
94
95
96
97
98
99
100
101
102
from .interfaces import (
    MultiModalEmbeddings,
    SupportsLoRA,
    SupportsMRoPE,
    SupportsMultiModal,
    SupportsPP,
)
from .utils import (
    AutoWeightsLoader,
    WeightsMapper,
    init_vllm_registered_model,
    maybe_prefix,
)
103
from .vision import get_vit_attn_backend, run_dp_sharded_mrope_vision_model
104

105
106
# Module-level logger for this file, created via vLLM's init_logger.
logger = init_logger(__name__)

107
# For profile run
108
# NOTE(review): presumably the per-video frame count assumed when building
# dummy inputs for the profiling run; its users are outside this chunk.
_MAX_FRAMES_PER_VIDEO = 14
109

110
111
112
# === Vision Inputs === #


113
class Qwen2VLImagePixelInputs(TensorSchema):
    """
    Raw-pixel image inputs for Qwen2-VL.

    Dimensions:
        - np: Total number of patches across all images of all prompts in
              the batch
        - ni: Number of images
        - cps: Number of channels * patch_size * patch_size
               (NOTE(review): the vision patch embedding views each row as
               (-1, temporal_patch_size, patch_size, patch_size), so rows
               likely also include temporal_patch_size -- confirm)

    Historical context:
        - pixel_values shape: (num_patches, num_channels * patch_size *
          patch_size)
        - image_grid_thw shape: (num_images, 3) in (grid_t, grid_h, grid_w)
          format
    """

    # Discriminator separating this variant from the embedding variant.
    type: Literal["pixel_values"]

    # Flattened patches, one row per patch.
    pixel_values: Annotated[torch.Tensor, TensorShape("np", "cps")]

    # One (grid_t, grid_h, grid_w) row per image.
    image_grid_thw: Annotated[torch.Tensor, TensorShape("ni", 3)]


class Qwen2VLImageEmbeddingInputs(TensorSchema):
    """
    Precomputed image embeddings for Qwen2-VL.

    Dimensions:
        - nf: Number of image features
        - hs: Hidden size
        - ni: Number of images

    Historical context:
        - image_embeds shape: (num_image_features, hidden_size)
        - num_image_features varies with the number and resolution of the
          images.
        - hidden_size must match the hidden size of the language model
          backbone.
        - image_grid_thw shape: (num_images, 3) in (grid_t, grid_h, grid_w)
          format
    """

    # Discriminator separating this variant from the raw-pixel variant.
    type: Literal["image_embeds"]

    # One row per image feature, already in the backbone's hidden size.
    image_embeds: Annotated[torch.Tensor, TensorShape("nf", "hs")]

    # One (grid_t, grid_h, grid_w) row per image.
    image_grid_thw: Annotated[torch.Tensor, TensorShape("ni", 3)]
168
169


170
# Image input in either raw-pixel or precomputed-embedding form; the two
# variants are told apart by their ``type`` field.
Qwen2VLImageInputs: TypeAlias = Qwen2VLImagePixelInputs | Qwen2VLImageEmbeddingInputs
171
172


173
174
175
176
177
class Qwen2VLVideoPixelInputs(TensorSchema):
    """
    Raw-pixel video inputs for Qwen2-VL.

    Dimensions:
        - np: Total number of patches across all videos of all prompts in
              the batch
        - ctps: Number of channels * temporal_patch_size * patch_size *
                patch_size
        - nv: Number of videos

    Historical context:
        - pixel_values_videos shape: (num_patches, num_channels *
          temporal_patch_size * patch_size * patch_size)
        - video_grid_thw shape: (num_videos, 3) in (grid_t, grid_h, grid_w)
          format
    """

    # Discriminator separating this variant from the embedding variant.
    type: Literal["pixel_values_videos"]

    # Flattened spatio-temporal patches, one row per patch.
    pixel_values_videos: Annotated[torch.Tensor, TensorShape("np", "ctps")]

    # One (grid_t, grid_h, grid_w) row per video.
    video_grid_thw: Annotated[torch.Tensor, TensorShape("nv", 3)]
200
201


202
203
204
205
206
207
class Qwen2VLVideoEmbeddingInputs(TensorSchema):
    """
    Precomputed video embeddings for Qwen2-VL.

    Dimensions:
        - nf: Number of video features
        - hs: Hidden size
        - nv: Number of videos

    Historical context:
        - video_embeds shape: (num_video_features, hidden_size)
        - num_video_features varies with the number and resolution of the
          videos.
        - hidden_size must match the hidden size of the language model
          backbone.
        - video_grid_thw shape: (num_videos, 3) in (grid_t, grid_h, grid_w)
          format
    """

    # Discriminator separating this variant from the raw-pixel variant.
    type: Literal["video_embeds"]

    # One row per video feature, already in the backbone's hidden size.
    video_embeds: Annotated[torch.Tensor, TensorShape("nf", "hs")]

    # One (grid_t, grid_h, grid_w) row per video.
    video_grid_thw: Annotated[torch.Tensor, TensorShape("nv", 3)]
229
230


231
# Video input in either raw-pixel or precomputed-embedding form; the two
# variants are told apart by their ``type`` field.
Qwen2VLVideoInputs: TypeAlias = Qwen2VLVideoPixelInputs | Qwen2VLVideoEmbeddingInputs
232

233
234
235
236
237
238
239
# === Vision Encoder === #


class Qwen2VisionMLP(nn.Module):
    """Two-layer feed-forward network of a vision transformer block.

    ``fc1`` (in -> hidden) is column-parallel and ``fc2`` (hidden -> in) is
    row-parallel; both disable tensor parallelism when ``use_data_parallel``
    is set.
    """

    def __init__(
        self,
        in_features: int,
        hidden_features: int,
        act_layer: type[nn.Module] = QuickGELU,
        quant_config: QuantizationConfig | None = None,
        prefix: str = "",
        use_data_parallel: bool = False,
    ):
        super().__init__()
        common = dict(quant_config=quant_config, disable_tp=use_data_parallel)
        self.fc1 = ColumnParallelLinear(
            in_features, hidden_features, prefix=f"{prefix}.fc1", **common
        )
        self.act = act_layer()
        self.fc2 = RowParallelLinear(
            hidden_features, in_features, prefix=f"{prefix}.fc2", **common
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # The parallel linear layers return (output, bias); only the output
        # is used.
        hidden, _ = self.fc1(x)
        result, _ = self.fc2(self.act(hidden))
        return result


def rotate_half(x: torch.Tensor, interleaved: bool = False) -> torch.Tensor:
    """Rotate feature pairs of ``x`` by 90 degrees along the last axis.

    Non-interleaved: the two halves ``[a, b]`` become ``[-b, a]``.
    Interleaved: each even/odd pair ``(x0, x1)`` becomes ``(-x1, x0)``.
    """
    if interleaved:
        even, odd = x[..., 0::2], x[..., 1::2]
        # Stacking gives (..., d, 2); flattening the last two axes restores
        # the even/odd interleaving.
        return torch.stack((-odd, even), dim=-1).flatten(-2)
    first, second = x.chunk(2, dim=-1)
    return torch.cat((-second, first), dim=-1)
279
280


281
282
283
def apply_rotary_emb_torch(
    x: torch.Tensor, cos: torch.Tensor, sin: torch.Tensor, interleaved: bool = False
) -> torch.Tensor:
    """Pure-PyTorch rotary embedding.

    x: (batch_size, seqlen, nheads, headdim)
    cos, sin: (seqlen, rotary_dim / 2) or (batch_size, seqlen, rotary_dim / 2)

    Only the first ``rotary_dim = 2 * cos.shape[-1]`` features of ``x`` are
    rotated; the remaining features are passed through unchanged.
    """
    rotary_dim = 2 * cos.shape[-1]
    assert rotary_dim <= x.shape[-1]

    # Insert a singleton heads axis and widen the half-size tables to the
    # full rotary width: tiled "(2 d)" for split halves, "(d 2)" for
    # interleaved pairs.
    pattern = "... d -> ... 1 (d 2)" if interleaved else "... d -> ... 1 (2 d)"
    cos_full = repeat(cos, pattern)
    sin_full = repeat(sin, pattern)

    head = x[..., :rotary_dim]
    tail = x[..., rotary_dim:]
    rotated = head * cos_full + rotate_half(head, interleaved) * sin_full
    return torch.cat([rotated, tail], dim=-1)


305
306
def apply_rotary_pos_emb_vision(t: torch.Tensor, freqs: torch.Tensor) -> torch.Tensor:
    """Apply rotary position embeddings given by angle table ``freqs`` to ``t``.

    The rotation is computed on a float32 copy of ``t`` and the result is
    cast back to ``t``'s dtype. The kernel comes from
    ``dispatch_rotary_emb_function``, with ``apply_rotary_emb_torch`` as the
    fallback.
    """
    rotary_fn = dispatch_rotary_emb_function(default=apply_rotary_emb_torch)
    return rotary_fn(t.float(), freqs.cos(), freqs.sin()).type_as(t)


class Qwen2VisionAttention(nn.Module):
    """Self-attention layer of the Qwen2-VL vision encoder.

    Tokens of several images/frames are packed into one sequence. The
    boundaries in ``cu_seqlens`` split that sequence into independent
    segments, and each token only attends to tokens of its own segment.
    Depending on the backend this uses a varlen flash-attention kernel
    (FLASH_ATTN / ROCM_AITER_FA), one ``F.scaled_dot_product_attention``
    call per segment (TORCH_SDPA), or an xFormers block-diagonal mask
    (XFORMERS).
    """

    def __init__(
        self,
        embed_dim: int,
        num_heads: int,
        projection_size: int,
        quant_config: QuantizationConfig | None = None,
        prefix: str = "",
        use_data_parallel: bool = False,
        attn_backend_override: _Backend | None = None,
    ) -> None:
        super().__init__()
        # Per attention head and per partition values. With data-parallel
        # encoding the linear layers disable TP, so treat the TP size as 1.
        self.tp_size = (
            1
            if use_data_parallel
            else parallel_state.get_tensor_model_parallel_world_size()
        )
        self.tp_rank = parallel_state.get_tensor_model_parallel_rank()
        self.hidden_size_per_attention_head = dist_utils.divide(
            projection_size, num_heads
        )
        self.num_attention_heads_per_partition = dist_utils.divide(
            num_heads, self.tp_size
        )

        # Fused Q/K/V projection (column-parallel) and output projection
        # (row-parallel).
        self.qkv = ColumnParallelLinear(
            input_size=embed_dim,
            output_size=3 * projection_size,
            quant_config=quant_config,
            prefix=f"{prefix}.qkv",
            disable_tp=use_data_parallel,
        )
        self.proj = RowParallelLinear(
            input_size=projection_size,
            output_size=embed_dim,
            quant_config=quant_config,
            prefix=f"{prefix}.proj",
            disable_tp=use_data_parallel,
        )

        # Detect attention implementation.
        self.attn_backend = get_vit_attn_backend(
            head_size=self.hidden_size_per_attention_head,
            dtype=torch.get_default_dtype(),
            attn_backend_override=attn_backend_override,
        )
        self.use_upstream_fa = False

        # Returns the (possibly adjusted) backend together with the varlen
        # kernel used by the flash-attention branch of forward().
        self.attn_backend, self.flash_attn_varlen_func = (
            maybe_get_vit_flash_attn_backend(
                self.attn_backend,
                self.use_upstream_fa,
                attn_backend_override=attn_backend_override,
            )
        )

        if self.attn_backend not in {
            _Backend.FLASH_ATTN,
            _Backend.TORCH_SDPA,
            _Backend.XFORMERS,
            _Backend.ROCM_AITER_FA,
        }:
            raise RuntimeError(
                f"Qwen2-VL does not support {self.attn_backend} backend now."
            )

        self.is_flash_attn_backend = self.attn_backend in {
            _Backend.FLASH_ATTN,
            _Backend.ROCM_AITER_FA,
        }

    def split_qkv(self, qkv: torch.Tensor) -> tuple[torch.Tensor, ...]:
        """Split the fused projection into per-rank Q, K and V heads.

        With TP > 1 the sharded projection is all-gathered first, so that
        ``chunk(3)`` sees the full [q | k | v] layout. Then only this rank's
        slice of heads is kept.

        Args:
            qkv: ``[s, b, 3 * head * head_dim]`` projection output.

        Returns:
            Three tensors of shape ``[s, b, heads_per_partition, head_dim]``.
        """
        # [s, b, 3 * head * head_dim]
        seq_len, bs, _ = qkv.shape
        if self.tp_size > 1:
            qkv = tensor_model_parallel_all_gather(qkv)

        # [s, b, 3 * head * head_dim] -> 3 * [s, b, head * head_dim]
        q, k, v = qkv.chunk(3, dim=2)

        # 3 * [s, b, head * head_dim]
        if self.tp_size > 1:
            splitter = partial(
                dist_utils.split_tensor_along_last_dim, num_partitions=self.tp_size
            )
            q = splitter(q)[self.tp_rank]
            k = splitter(k)[self.tp_rank]
            v = splitter(v)[self.tp_rank]

        # 3 * [s, b, head * head_dim] -> 3 * [s, b, head, head_dim]
        new_shape = (
            seq_len,
            bs,
            self.num_attention_heads_per_partition,
            self.hidden_size_per_attention_head,
        )
        q, k, v = (t.view(*new_shape) for t in (q, k, v))
        return q, k, v

    def forward(
        self,
        x: torch.Tensor,
        cu_seqlens: torch.Tensor,
        rotary_pos_emb: torch.Tensor,
        max_seqlen: int | None = None,  # Only used for Flash Attention
        seqlens: list[int] | None = None,  # Only used for xFormers
    ) -> torch.Tensor:
        """Run segment-restricted self-attention.

        Args:
            x: ``[s, b, c]`` hidden states.
            cu_seqlens: Cumulative segment boundaries, starting at 0.
            rotary_pos_emb: Rotary angle table applied to Q and K; skipped
                when ``None``.
            max_seqlen: Longest segment length (flash-attention backends).
            seqlens: Per-segment lengths (xFormers backend).

        Returns:
            ``[s, b, embed_dim]`` output of the projection layer.
        """
        # [s, b, c] --> [s, b, 3 * head * head_dim]
        x, _ = self.qkv(x)

        # [s, b, 3 * head * head_dim] -> 3 * [s, b, head, head_dim]
        q, k, v = self.split_qkv(x)
        batch_size = q.shape[1]

        q, k, v = (rearrange(t, "s b ... -> b s ...") for t in (q, k, v))
        if rotary_pos_emb is not None:
            # [2 * b, s, heads, head_dim]
            qk_concat = torch.cat([q, k], dim=0)
            qk_rotated = apply_rotary_pos_emb_vision(qk_concat, rotary_pos_emb)
            q, k = torch.chunk(qk_rotated, 2, dim=0)

        if self.is_flash_attn_backend:
            q, k, v = (rearrange(t, "b s ... -> (b s) ...") for t in [q, k, v])

            output = self.flash_attn_varlen_func(
                q,
                k,
                v,
                cu_seqlens_q=cu_seqlens,
                cu_seqlens_k=cu_seqlens,
                max_seqlen_q=max_seqlen,
                max_seqlen_k=max_seqlen,
                dropout_p=0.0,
                causal=False,
            )

            context_layer = rearrange(
                output, "(b s) h d -> s b (h d)", b=batch_size
            ).contiguous()
        elif self.attn_backend == _Backend.TORCH_SDPA:
            # Execute attention entry by entry for speed & less VRAM.
            # Copy the boundaries to the host once: slicing with 0-d device
            # tensors would force a device-to-host sync for every index.
            bounds = cu_seqlens.tolist()
            outputs = []
            for start_idx, end_idx in zip(bounds[:-1], bounds[1:]):
                q_i = q[:, start_idx:end_idx]
                k_i = k[:, start_idx:end_idx]
                v_i = v[:, start_idx:end_idx]
                q_i, k_i, v_i = (
                    rearrange(t, "b s h d -> b h s d") for t in [q_i, k_i, v_i]
                )
                output_i = F.scaled_dot_product_attention(q_i, k_i, v_i, dropout_p=0.0)
                output_i = rearrange(output_i, "b h s d -> b s h d ")
                outputs.append(output_i)
            context_layer = torch.cat(outputs, dim=1)
            context_layer = rearrange(
                context_layer, "b s h d -> s b (h d)"
            ).contiguous()
        elif self.attn_backend == _Backend.XFORMERS:
            from xformers import ops as xops
            from xformers.ops.fmha.attn_bias import BlockDiagonalMask

            # Block-diagonal mask: each segment attends only to itself.
            attn_bias = BlockDiagonalMask.from_seqlens(
                q_seqlen=seqlens, kv_seqlen=None, device=q.device
            )

            context_layer = xops.memory_efficient_attention_forward(
                q, k, v, attn_bias=attn_bias, p=0, scale=None
            )
            context_layer = rearrange(
                context_layer, "b s h d -> s b (h d)"
            ).contiguous()

        output, _ = self.proj(context_layer)
        return output


class Qwen2VisionBlock(nn.Module):
    """Pre-norm transformer block of the vision encoder.

    Computes ``x + attn(norm1(x))`` followed by ``x + mlp(norm2(x))``.
    ``norm_layer`` defaults to ``nn.LayerNorm`` with ``eps=1e-6``.
    """

    def __init__(
        self,
        dim: int,
        num_heads: int,
        mlp_ratio: float,
        act_layer: type[nn.Module] = QuickGELU,
        norm_layer: Callable[[int], nn.Module] | None = None,
        quant_config: QuantizationConfig | None = None,
        prefix: str = "",
        use_data_parallel: bool = False,
        attn_backend_override: _Backend | None = None,
    ) -> None:
        super().__init__()
        make_norm = (
            norm_layer if norm_layer is not None else partial(nn.LayerNorm, eps=1e-6)
        )
        self.norm1 = make_norm(dim)
        self.norm2 = make_norm(dim)

        self.attn = Qwen2VisionAttention(
            embed_dim=dim,
            num_heads=num_heads,
            projection_size=dim,
            quant_config=quant_config,
            prefix=f"{prefix}.attn",
            use_data_parallel=use_data_parallel,
            attn_backend_override=attn_backend_override,
        )
        self.mlp = Qwen2VisionMLP(
            dim,
            int(dim * mlp_ratio),
            act_layer=act_layer,
            quant_config=quant_config,
            prefix=f"{prefix}.mlp",
            use_data_parallel=use_data_parallel,
        )

    def forward(
        self,
        x: torch.Tensor,
        cu_seqlens: torch.Tensor,
        rotary_pos_emb: torch.Tensor,
        max_seqlen: int | None = None,  # Only used for Flash Attention
        seqlens: list[int] | None = None,  # Only used for xFormers
    ) -> torch.Tensor:
        attn_out = self.attn(
            self.norm1(x),
            cu_seqlens=cu_seqlens,
            rotary_pos_emb=rotary_pos_emb,
            max_seqlen=max_seqlen,
            seqlens=seqlens,
        )
        x = x + attn_out
        return x + self.mlp(self.norm2(x))


class Qwen2VisionPatchEmbed(nn.Module):
    """Embed flattened spatio-temporal patches with a bias-free Conv3d.

    The convolution uses kernel_size == stride == (temporal_patch_size,
    patch_size, patch_size). Because the output is viewed as
    (num_patches, embed_dim), each input row is projected on its own to
    ``embed_dim`` features.
    """

    def __init__(
        self,
        patch_size: int = 14,
        temporal_patch_size: int = 2,
        in_channels: int = 3,
        embed_dim: int = 1152,
    ) -> None:
        super().__init__()
        self.patch_size = patch_size
        self.temporal_patch_size = temporal_patch_size
        self.embed_dim = embed_dim

        patch_kernel = (temporal_patch_size, patch_size, patch_size)
        self.proj = nn.Conv3d(
            in_channels,
            embed_dim,
            kernel_size=patch_kernel,
            stride=patch_kernel,
            bias=False,
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Map (num_patches, C * T * P * P) rows to (num_patches, embed_dim)."""
        num_patches, _ = x.shape  # unpacking also rejects non-2-D input
        volumes = x.view(
            num_patches, -1, self.temporal_patch_size, self.patch_size, self.patch_size
        )
        return self.proj(volumes).view(num_patches, self.embed_dim)


class Qwen2VisionPatchMerger(nn.Module):
    """Merge groups of patch features and project them to ``d_model``.

    After LayerNorm over ``context_dim``, every ``spatial_merge_size**2``
    consecutive rows are concatenated into one ``hidden_size`` vector. The
    result then goes through Linear -> GELU -> Linear.
    """

    def __init__(
        self,
        d_model: int,
        context_dim: int,
        norm_layer: Callable[[int], nn.Module] | None = None,
        spatial_merge_size: int = 2,
        quant_config: QuantizationConfig | None = None,
        prefix: str = "",
        use_data_parallel: bool = False,
    ) -> None:
        super().__init__()
        self.hidden_size = context_dim * spatial_merge_size**2
        make_norm = (
            norm_layer if norm_layer is not None else partial(nn.LayerNorm, eps=1e-6)
        )
        self.ln_q = make_norm(context_dim)

        # Submodule indices 0 / 2 must match the checkpoint names mlp.0 / mlp.2.
        fc1 = ColumnParallelLinear(
            self.hidden_size,
            self.hidden_size,
            bias=True,
            quant_config=quant_config,
            prefix=f"{prefix}.mlp.0",
            disable_tp=use_data_parallel,
        )
        act = nn.GELU()
        fc2 = RowParallelLinear(
            self.hidden_size,
            d_model,
            bias=True,
            quant_config=quant_config,
            prefix=f"{prefix}.mlp.2",
            disable_tp=use_data_parallel,
        )
        self.mlp = nn.ModuleList([fc1, act, fc2])

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        merged = self.ln_q(x).view(-1, self.hidden_size)
        fc1, act, fc2 = self.mlp
        hidden, _ = fc1(merged)
        out, _ = fc2(act(hidden))
        return out


class Qwen2VisionRotaryEmbedding(nn.Module):
    def __init__(self, dim: int, theta: float = 10000.0) -> None:
        super().__init__()
        self.dim = dim
        self.theta = theta
633
        inv_freq = 1.0 / (theta ** (torch.arange(0, dim, 2, dtype=torch.float) / dim))
634
635
636
637
638
639
640
641
        self.register_buffer("inv_freq", inv_freq, persistent=False)
        self._seq_len_cached = 0
        self._freqs_cached = None

    def update_freqs_cache(self, seqlen: int) -> None:
        if seqlen > self._seq_len_cached:
            seqlen *= 2
            self._seq_len_cached = seqlen
642
643
644
645
646
647
648
649
650
651
652
653
            self.inv_freq = 1.0 / (
                self.theta
                ** (
                    torch.arange(
                        0, self.dim, 2, dtype=torch.float, device=self.inv_freq.device
                    )
                    / self.dim
                )
            )
            seq = torch.arange(
                seqlen, device=self.inv_freq.device, dtype=self.inv_freq.dtype
            )
654
655
656
657
658
659
660
661
662
663
664
665
666
            freqs = torch.outer(seq, self.inv_freq)
            self._freqs_cached = freqs

    def forward(self, seqlen: int) -> torch.Tensor:
        self.update_freqs_cache(seqlen)
        return self._freqs_cached[:seqlen]


class Qwen2VisionTransformer(nn.Module):
    """Qwen2-VL vision encoder.

    Pipeline: patch embedding -> ``depth`` pre-norm transformer blocks with
    2-D (row/column) rotary position embeddings -> patch merger projecting
    to ``vision_config.hidden_size``.
    """

    def __init__(
        self,
        vision_config: Qwen2VLVisionConfig,
        norm_eps: float = 1e-6,
        quant_config: QuantizationConfig | None = None,
        prefix: str = "",
        use_data_parallel: bool = False,
        attn_backend_override: _Backend | None = None,
    ) -> None:
        super().__init__()

        # Read every config field up front.
        patch_size = vision_config.patch_size
        temporal_patch_size = vision_config.temporal_patch_size
        spatial_merge_size = vision_config.spatial_merge_size
        in_channels = vision_config.in_channels
        hidden_size = vision_config.hidden_size
        embed_dim = vision_config.embed_dim
        depth = vision_config.depth
        num_heads = vision_config.num_heads
        mlp_ratio = vision_config.mlp_ratio

        self.use_data_parallel = use_data_parallel
        self.out_hidden_size = vision_config.hidden_size
        self.spatial_merge_size = spatial_merge_size
        self.num_heads = num_heads
        self.embed_dim = embed_dim

        self.patch_embed = Qwen2VisionPatchEmbed(
            patch_size=patch_size,
            temporal_patch_size=temporal_patch_size,
            in_channels=in_channels,
            embed_dim=embed_dim,
        )

        norm_layer = partial(nn.LayerNorm, eps=norm_eps)
        head_dim = embed_dim // num_heads
        # The angle table is half the head width; rot_pos_emb concatenates a
        # row-angle and a column-angle lookup per patch.
        self.rotary_pos_emb = Qwen2VisionRotaryEmbedding(head_dim // 2)

        self.blocks = nn.ModuleList(
            Qwen2VisionBlock(
                dim=embed_dim,
                num_heads=num_heads,
                mlp_ratio=mlp_ratio,
                norm_layer=norm_layer,
                quant_config=quant_config,
                prefix=f"{prefix}.blocks.{layer_idx}",
                use_data_parallel=use_data_parallel,
                attn_backend_override=attn_backend_override,
            )
            for layer_idx in range(depth)
        )
        self.merger = Qwen2VisionPatchMerger(
            d_model=hidden_size,
            context_dim=embed_dim,
            norm_layer=norm_layer,
            quant_config=quant_config,
            prefix=f"{prefix}.merger",
            use_data_parallel=use_data_parallel,
        )

        # Backend used to decide which mask metadata forward() precomputes.
        self.attn_backend = get_vit_attn_backend(
            head_size=head_dim,
            dtype=torch.get_default_dtype(),
            attn_backend_override=attn_backend_override,
        )
        if self.attn_backend != _Backend.FLASH_ATTN and check_upstream_fa_availability(
            torch.get_default_dtype()
        ):
            self.attn_backend = _Backend.FLASH_ATTN

    @property
    def dtype(self) -> torch.dtype:
        return self.patch_embed.proj.weight.dtype

    @property
    def device(self) -> torch.device:
        return self.patch_embed.proj.weight.device

    def rot_pos_emb(self, grid_thw: list[list[int]]) -> torch.Tensor:
        """Gather rotary angles for every patch described by ``grid_thw``.

        Row and column ids of each (h, w) grid are reordered so that every
        ``spatial_merge_size x spatial_merge_size`` window is contiguous.
        That is the grouping the merger's ``view(-1, hidden_size)`` relies
        on. The ids are repeated for each of the ``t`` frames and used to
        index the angle table; row and column angles are concatenated per
        patch.
        """
        merge = self.spatial_merge_size

        def window_major(ids: torch.Tensor, h: int, w: int) -> torch.Tensor:
            # (h, w) -> (h/m, m, w/m, m) -> (h/m, w/m, m, m) -> flat
            return (
                ids.reshape(h // merge, merge, w // merge, merge)
                .permute(0, 2, 1, 3)
                .flatten()
            )

        per_item_ids = []
        longest_side = 0
        for t, h, w in grid_thw:
            rows = torch.arange(h).unsqueeze(1).expand(-1, w)
            cols = torch.arange(w).unsqueeze(0).expand(h, -1)
            row_col = torch.stack(
                [window_major(rows, h, w), window_major(cols, h, w)], dim=-1
            )
            per_item_ids.append(row_col.repeat(t, 1))
            longest_side = max(longest_side, h, w)
        all_ids = torch.cat(per_item_ids, dim=0)
        angle_table = self.rotary_pos_emb(longest_side)
        return angle_table[all_ids].flatten(1)

    def compute_attn_mask_seqlen(
        self, cu_seqlens: torch.Tensor
    ) -> tuple[int | None, list[int] | None]:
        """Precompute backend-specific metadata as ``(max_seqlen, seqlens)``.

        Flash-attention backends get the longest segment length. xFormers
        gets the list of segment lengths. Other backends get ``(None, None)``.
        """
        if self.attn_backend in (_Backend.FLASH_ATTN, _Backend.ROCM_AITER_FA):
            return (cu_seqlens[1:] - cu_seqlens[:-1]).max().item(), None
        if self.attn_backend == _Backend.XFORMERS:
            return None, (cu_seqlens[1:] - cu_seqlens[:-1]).tolist()
        return None, None

    def forward(
        self,
        x: torch.Tensor,
        grid_thw: list[list[int]],
    ) -> torch.Tensor:
        # patchify
        hidden = self.patch_embed(x.to(device=self.device, dtype=self.dtype))

        # compute position embedding
        rotary_pos_emb = self.rot_pos_emb(grid_thw)

        # compute cu_seqlens: one segment of h * w patches per temporal slice
        grid = torch.tensor(grid_thw, device=hidden.device, dtype=torch.long)
        patches_per_frame = grid[:, 1] * grid[:, 2]
        cu_seqlens = torch.repeat_interleave(patches_per_frame, grid[:, 0]).cumsum(
            dim=0, dtype=torch.int32
        )
        cu_seqlens = F.pad(cu_seqlens, (1, 0), "constant", 0)

        # transformers: add the singleton batch axis expected as [s, b, c]
        hidden = hidden.unsqueeze(1)

        # pre-compute seqlens for attn mask to reduce cuMemcpy operations
        max_seqlen, seqlens = self.compute_attn_mask_seqlen(cu_seqlens)
        for blk in self.blocks:
            hidden = blk(
                hidden,
                cu_seqlens=cu_seqlens,
                rotary_pos_emb=rotary_pos_emb,
                max_seqlen=max_seqlen,
                seqlens=seqlens,
            )

        # adapter
        return self.merger(hidden)

    def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]:
        """Load checkpoint tensors and return the set of parameter names filled.

        Names containing a ``q_proj``/``k_proj``/``v_proj`` fragment are
        renamed to the fused ``qkv_proj`` parameter and loaded as that shard.
        Any other name is loaded directly with the parameter's
        ``weight_loader`` (or ``default_weight_loader``).
        """
        stacked_params_mapping = [
            # (param_name, shard_name, shard_id)
            ("qkv_proj", "q_proj", "q"),
            ("qkv_proj", "k_proj", "k"),
            ("qkv_proj", "v_proj", "v"),
        ]
        params_dict = dict(self.named_parameters(remove_duplicate=False))
        loaded_params: set[str] = set()

        for name, loaded_weight in weights:
            # First mapping whose shard name occurs in the checkpoint name.
            match = next(
                (entry for entry in stacked_params_mapping if entry[1] in name),
                None,
            )
            if match is None:
                param = params_dict[name]
                weight_loader = getattr(param, "weight_loader", default_weight_loader)
                weight_loader(param, loaded_weight)
            else:
                param_name, weight_name, shard_id = match
                name = name.replace(weight_name, param_name)
                param = params_dict[name]
                param.weight_loader(param, loaded_weight, shard_id)
            loaded_params.add(name)
        return loaded_params

854

855
def _create_qwen2vl_field_factory(
    spatial_merge_size: int,
) -> Callable[
    [Mapping[str, torch.Tensor]],
    Mapping[str, MultiModalFieldConfig],
]:
    """Build the multimodal field-config function for Qwen2-VL inputs.

    Flattened pixel tensors are split per item by the product of each
    ``*_grid_thw`` row. Precomputed embeddings are split by that product
    floor-divided by ``spatial_merge_size`` twice. The ``*_grid_thw`` tensors
    themselves are batched per item.

    Args:
        spatial_merge_size: Spatial merge factor of the vision tower.

    Returns:
        A function mapping the HF inputs to configs for ``pixel_values``,
        ``image_embeds``, ``image_grid_thw``, ``pixel_values_videos``,
        ``video_embeds`` and ``video_grid_thw``.
    """

    def _grid_sizes(
        hf_inputs: Mapping[str, torch.Tensor], grid_key: str
    ) -> tuple[torch.Tensor, torch.Tensor]:
        # A missing grid is treated as zero items of that modality.
        grid_thw = hf_inputs.get(grid_key, torch.empty((0, 3)))
        patch_counts = grid_thw.prod(-1)
        embed_counts = patch_counts // spatial_merge_size // spatial_merge_size
        return patch_counts, embed_counts

    def _qwen2vl_field_config(hf_inputs: Mapping[str, torch.Tensor]):
        image_patches, image_embeds = _grid_sizes(hf_inputs, "image_grid_thw")
        video_patches, video_embeds = _grid_sizes(hf_inputs, "video_grid_thw")
        flat = MultiModalFieldConfig.flat_from_sizes
        return {
            "pixel_values": flat("image", image_patches),
            "image_embeds": flat("image", image_embeds),
            "image_grid_thw": MultiModalFieldConfig.batched("image"),
            "pixel_values_videos": flat("video", video_patches),
            "video_embeds": flat("video", video_embeds),
            "video_grid_thw": MultiModalFieldConfig.batched("video"),
        }

    return _qwen2vl_field_config
892

893

# Roger Wang's avatar
# Roger Wang committed
894
class Qwen2VLMultiModalDataParser(MultiModalDataParser):
    """Data parser that additionally accepts precomputed Qwen2-VL embeddings.

    Image or video data passed as a ``dict`` of tensors is wrapped in
    ``DictEmbeddingItems``, with a field layout from
    ``_create_qwen2vl_field_factory``. Any other form is handed to the base
    ``MultiModalDataParser``.
    """

    def __init__(self, spatial_merge_size: int, *args, **kwargs):
        # Set before the base initializer runs, matching the original order.
        self._spatial_merge_size = spatial_merge_size
        super().__init__(*args, **kwargs)

    def _parse_image_data(
        self,
        data: dict[str, torch.Tensor] | ModalityData[ImageItem],
    ) -> ModalityDataItems[Any, Any] | None:
        """Wrap an image-embedding dict, or defer to the base parser."""
        if not isinstance(data, dict):
            return super()._parse_image_data(data)

        factory = _create_qwen2vl_field_factory(self._spatial_merge_size)
        return DictEmbeddingItems(
            data,
            modality="image",
            required_fields={"image_embeds", "image_grid_thw"},
            fields_factory=factory,
        )

    def _parse_video_data(
        self,
        data: dict[str, torch.Tensor] | ModalityData[VideoItem],
    ) -> ModalityDataItems[Any, Any] | None:
        """Wrap a video-embedding dict, or defer to the base parser."""
        if not isinstance(data, dict):
            return super()._parse_video_data(data)

        factory = _create_qwen2vl_field_factory(self._spatial_merge_size)
        return DictEmbeddingItems(
            data,
            modality="video",
            required_fields={"video_embeds", "video_grid_thw"},
            fields_factory=factory,
        )


928
929
class Qwen2VLProcessingInfo(BaseProcessingInfo):
    """Processing metadata for Qwen2-VL.

    Provides access to the HF config and processors. Also computes how many
    vision tokens an image or video of a given size produces, which is used
    to size dummy inputs and to bound per-item token counts.
    """

    def get_hf_config(self):
        return self.ctx.get_hf_config(Qwen2VLConfig)

    def get_hf_processor(self, **kwargs: object) -> Qwen2VLProcessor:
        # Use the fast processor unless the caller explicitly overrides it.
        return self.ctx.get_hf_processor(
            Qwen2VLProcessor,
            use_fast=kwargs.pop("use_fast", True),
            **kwargs,
        )

    def get_image_processor(self, **kwargs: object) -> Qwen2VLImageProcessor:
        return self.get_hf_processor(**kwargs).image_processor

    def get_supported_mm_limits(self) -> Mapping[str, int | None]:
        # No fixed per-prompt count is declared for either modality here.
        return {"image": None, "video": None}

    def get_mm_max_tokens_per_item(
        self,
        seq_len: int,
        mm_counts: Mapping[str, int],
    ) -> Mapping[str, int]:
        """Return the maximum vision-token count of a single image/video."""
        max_image_tokens = self.get_max_image_tokens()
        max_video_tokens = self.get_max_video_tokens(seq_len, mm_counts)
        return {"image": max_image_tokens, "video": max_video_tokens}

    def _get_vision_info(
        self,
        *,
        image_width: int,
        image_height: int,
        num_frames: int = 1,
        do_resize: bool = True,
        image_processor: Qwen2VLImageProcessor | None,
    ) -> tuple[ImageSize, int]:
        """Compute the preprocessed frame size and the vision-token count.

        Args:
            image_width: Input width in pixels.
            image_height: Input height in pixels.
            num_frames: Number of frames (1 for a still image).
            do_resize: Whether to apply ``smart_resize`` using the image
                processor's ``min_pixels``/``max_pixels`` bounds.
            image_processor: Supplies the pixel bounds. It is fetched via
                ``get_image_processor()`` when ``None``.

        Returns:
            ``(preprocessed_size, num_vision_tokens)``, where the token count
            is ``grid_t * grid_h * grid_w // spatial_merge_size**2``.
        """
        if image_processor is None:
            image_processor = self.get_image_processor()

        hf_config = self.get_hf_config()
        vision_config = hf_config.vision_config
        patch_size = vision_config.patch_size
        merge_size = vision_config.spatial_merge_size
        temporal_patch_size = vision_config.temporal_patch_size

        if do_resize:
            resized_height, resized_width = smart_resize(
                height=image_height,
                width=image_width,
                factor=patch_size * merge_size,
                min_pixels=image_processor.min_pixels,
                max_pixels=image_processor.max_pixels,
            )
            preprocessed_size = ImageSize(width=resized_width, height=resized_height)
        else:
            preprocessed_size = ImageSize(width=image_width, height=image_height)

        # NOTE: Frames are padded to be divisible by `temporal_patch_size`
        # https://github.com/huggingface/transformers/blob/v4.48.3/src/transformers/models/qwen2_vl/image_processing_qwen2_vl.py#L294
        # Round up to the next multiple of `temporal_patch_size`. The former
        # `num_frames + num_frames % temporal_patch_size` only did this when
        # `temporal_patch_size == 2` (e.g. 4 frames with size 3 gave 5).
        padded_num_frames = (
            -(-num_frames // temporal_patch_size) * temporal_patch_size
        )

        grid_t = max(padded_num_frames // temporal_patch_size, 1)
        grid_h = preprocessed_size.height // patch_size
        grid_w = preprocessed_size.width // patch_size

        num_patches = grid_t * grid_h * grid_w
        num_vision_tokens = num_patches // (merge_size**2)

        return preprocessed_size, num_vision_tokens

    def get_num_image_tokens(
        self,
        *,
        image_width: int,
        image_height: int,
        image_processor: Qwen2VLImageProcessor | None,
    ) -> int:
        """Return the number of vision tokens for one image."""
        _, num_image_tokens = self._get_vision_info(
            image_width=image_width,
            image_height=image_height,
            num_frames=1,
            image_processor=image_processor,
        )
        return num_image_tokens

    def get_num_video_tokens(
        self,
        *,
        image_width: int,
        image_height: int,
        num_frames: int,
        image_processor: Qwen2VLImageProcessor | None,
    ) -> int:
        """Return the number of vision tokens for one video."""
        _, num_video_tokens = self._get_vision_info(
            image_width=image_width,
            image_height=image_height,
            num_frames=num_frames,
            image_processor=image_processor,
        )
        return num_video_tokens

    def get_image_size_with_most_features(self) -> ImageSize:
        """Return the preprocessed size of an oversized input.

        A 9999999x9999999 input is passed through the resize logic, which
        should clamp it to the ``max_pixels`` bound.
        """
        max_image_size, _ = self._get_vision_info(
            image_width=9999999,
            image_height=9999999,
            num_frames=1,
            image_processor=None,
        )
        return max_image_size

    def get_max_image_tokens(self) -> int:
        target_width, target_height = self.get_image_size_with_most_features()

        return self.get_num_image_tokens(
            image_width=target_width,
            image_height=target_height,
            image_processor=None,
        )

    def _get_max_video_frames(self, max_tokens: int, start_num_frames: int = 1) -> int:
        """Grow the frame count while a max-size video stays within budget.

        The count starts at ``start_num_frames``. Frames are added one at a
        time as long as a video of the next frame count needs at most
        ``max_tokens`` tokens. The last accepted count is returned.
        """
        target_width, target_height = self.get_image_size_with_most_features()

        num_frames = start_num_frames

        while True:
            next_num_frames = num_frames + 1
            next_max_tokens = self.get_num_video_tokens(
                image_width=target_width,
                image_height=target_height,
                num_frames=next_num_frames,
                image_processor=None,
            )

            if next_max_tokens > max_tokens:
                break

            num_frames = next_num_frames

        return num_frames

    def get_num_frames_with_most_features(
        self,
        seq_len: int,
        mm_counts: Mapping[str, int],
        max_frames_per_video: int = _MAX_FRAMES_PER_VIDEO,
    ) -> int:
        """Split the frame budget that fits in ``seq_len`` across the videos.

        The result is capped at ``max_frames_per_video`` and is never below 1.
        """
        max_videos = mm_counts.get("video", 0)

        max_total_frames = self._get_max_video_frames(seq_len)
        max_frames_per_video = min(
            max_total_frames // max(max_videos, 1), max_frames_per_video
        )

        return max(max_frames_per_video, 1)

    def get_max_video_tokens(
        self,
        seq_len: int,
        mm_counts: Mapping[str, int],
    ) -> int:
        target_width, target_height = self.get_image_size_with_most_features()

        return self.get_num_video_tokens(
            image_width=target_width,
            image_height=target_height,
            num_frames=self.get_num_frames_with_most_features(seq_len, mm_counts),
            image_processor=None,
        )

1096
1097

class Qwen2VLDummyInputsBuilder(BaseDummyInputsBuilder[Qwen2VLProcessingInfo]):
    """Builds dummy text and media inputs for Qwen2-VL profiling."""

    def get_dummy_text(self, mm_counts: Mapping[str, int]) -> str:
        """Return one image token per image, then one video token per video.

        The token strings come from the HF processor.
        """
        processor = self.info.get_hf_processor()
        image_part = processor.image_token * mm_counts.get("image", 0)
        video_part = processor.video_token * mm_counts.get("video", 0)
        return image_part + video_part

    def get_dummy_mm_data(
        self,
        seq_len: int,
        mm_counts: Mapping[str, int],
        mm_options: Mapping[str, BaseDummyOptions] | None = None,
    ) -> MultiModalDataDict:
        """Create max-size dummy images and videos.

        Images use the size with the most features. Videos use that size and
        the frame count from ``get_num_frames_with_most_features``.
        Per-modality overrides are read from ``mm_options`` when it is given
        and non-empty.
        """
        width, height = self.info.get_image_size_with_most_features()
        num_frames = self.info.get_num_frames_with_most_features(seq_len, mm_counts)
        options = mm_options or {}

        images = self._get_dummy_images(
            width=width,
            height=height,
            num_images=mm_counts.get("image", 0),
            overrides=options.get("image"),
        )
        videos = self._get_dummy_videos(
            width=width,
            height=height,
            num_frames=num_frames,
            num_videos=mm_counts.get("video", 0),
            overrides=options.get("video"),
        )
        return {"image": images, "video": videos}

1141

1142
class Qwen2VLMultiModalProcessor(BaseMultiModalProcessor[Qwen2VLProcessingInfo]):
    """Multimodal processor that expands Qwen2-VL image/video placeholders."""

    def _get_data_parser(self) -> MultiModalDataParser:
        merge_size = self.info.get_hf_config().vision_config.spatial_merge_size
        return Qwen2VLMultiModalDataParser(merge_size)

    def _get_prompt_updates(
        self,
        mm_items: MultiModalDataItems,
        hf_processor_mm_kwargs: Mapping[str, Any],
        out_mm_kwargs: MultiModalKwargsItems,
    ) -> Sequence[PromptUpdate]:
        """Replace each placeholder token with one token per merged patch.

        For item ``i`` of a modality, the replacement length is
        ``prod(<modality>_grid_thw) // merge_size**2``.
        """
        hf_processor = self.info.get_hf_processor(**hf_processor_mm_kwargs)
        image_processor = self.info.get_image_processor(**hf_processor_mm_kwargs)
        vocab = self.info.get_tokenizer().get_vocab()

        token_ids = {
            "image": vocab[hf_processor.image_token],
            "video": vocab[hf_processor.video_token],
        }
        patches_per_token = image_processor.merge_size**2

        def get_replacement_qwen2vl(item_idx: int, modality: str):
            item = out_mm_kwargs[modality][item_idx]
            grid_thw = item[f"{modality}_grid_thw"].data
            assert isinstance(grid_thw, torch.Tensor)
            token_count = int(grid_thw.prod()) // patches_per_token
            return [token_ids[modality]] * token_count

        return [
            PromptReplacement(
                modality=modality,
                target=[token_ids[modality]],
                replacement=partial(get_replacement_qwen2vl, modality=modality),
            )
            for modality in ("image", "video")
        ]

    def _get_mm_fields_config(
        self,
        hf_inputs: BatchFeature,
        hf_processor_mm_kwargs: Mapping[str, object],
    ) -> Mapping[str, MultiModalFieldConfig]:
        merge_size = self.info.get_hf_config().vision_config.spatial_merge_size
        field_config = _create_qwen2vl_field_factory(merge_size)
        return field_config(hf_inputs)


@MULTIMODAL_REGISTRY.register_processor(
    Qwen2VLMultiModalProcessor,
    info=Qwen2VLProcessingInfo,
    dummy_inputs=Qwen2VLDummyInputsBuilder,
)
class Qwen2VLForConditionalGeneration(
    nn.Module, SupportsMultiModal, SupportsLoRA, SupportsPP, SupportsMRoPE
):
    """Qwen2-VL: a Qwen2 vision tower feeding a Qwen2 causal language model.

    The vision tower (``self.visual``) is built only when the multimodal
    config allows image or video inputs. Otherwise it is ``None`` and the
    ``visual.`` checkpoint weights are skipped at load time.
    """

    # To ensure correct weight loading and mapping.
    hf_to_vllm_mapper = WeightsMapper(
        orig_to_new_prefix={
            # mapping for new names in checkpoint saved after transformers v4.52
            "model.language_model.": "language_model.model.",
            "model.visual.": "visual.",
            # mapping for original checkpoint
            "lm_head.": "language_model.lm_head.",
            "model.": "language_model.model.",
        }
    )

    supports_encoder_tp_data = True

    def get_mrope_input_positions(
        self,
        input_tokens: list[int],
        hf_config: PretrainedConfig,
        image_grid_thw: list[list[int]] | torch.Tensor | None,
        video_grid_thw: list[list[int]] | torch.Tensor | None,
        second_per_grid_ts: list[float] | None = None,
        context_len: int = 0,
        seq_len: int | None = None,
        audio_feature_lengths: torch.Tensor | None = None,
        use_audio_in_video: bool = False,
    ) -> tuple[torch.Tensor, int]:
        """Get M-RoPE input positions for Qwen2-VL model.

        Text tokens get equal (t, h, w) positions that advance by one per
        token. Each image/video placeholder run gets 3D positions over its
        merged patch grid, offset to follow the preceding text.

        Args:
            input_tokens: Prompt token ids.
            hf_config: HF config providing the special token ids, the
                ``spatial_merge_size`` and (optionally) ``tokens_per_second``.
            image_grid_thw: Per-image ``(t, h, w)`` patch grids in prompt order.
            video_grid_thw: Per-video ``(t, h, w)`` patch grids in prompt order.
            second_per_grid_ts: Optional per-video temporal scale. It defaults
                to 1.0 per video when empty.
            context_len: First position column to return.
            seq_len: End (exclusive) of the returned position columns.
            audio_feature_lengths: Not used by this implementation.
            use_audio_in_video: Not used by this implementation.

        Returns:
            ``(positions, delta)``. ``positions`` is the ``(3, N)`` position
            tensor sliced to columns ``context_len:seq_len``. ``delta`` is
            ``max position + 1 - len(input_tokens)``.
        """
        if image_grid_thw is None:
            image_grid_thw = []
        if video_grid_thw is None:
            video_grid_thw = []
        if second_per_grid_ts is None:
            second_per_grid_ts = []

        image_token_id = hf_config.image_token_id
        video_token_id = hf_config.video_token_id
        vision_start_token_id = hf_config.vision_start_token_id
        spatial_merge_size = hf_config.vision_config.spatial_merge_size
        tokens_per_second = getattr(hf_config.vision_config, "tokens_per_second", 1.0)

        input_tokens_tensor = torch.tensor(input_tokens)
        vision_start_indices = torch.argwhere(
            input_tokens_tensor == vision_start_token_id
        ).squeeze(1)
        # A vision-start token in the last position has no following
        # image/video token; drop it so the lookup below stays in bounds
        # instead of raising IndexError.
        vision_start_indices = vision_start_indices[
            vision_start_indices + 1 < len(input_tokens)
        ]
        vision_tokens = input_tokens_tensor[vision_start_indices + 1]
        image_nums = (vision_tokens == image_token_id).sum()
        video_nums = (vision_tokens == video_token_id).sum()
        llm_pos_ids_list: list = []

        st = 0
        remain_images, remain_videos = image_nums, video_nums

        image_index, video_index = 0, 0
        for _ in range(image_nums + video_nums):
            # Images keep a zero temporal scale (all t indices are 0).
            video_second_per_grid_t = 0.0
            if remain_images > 0:
                try:
                    ed_image = input_tokens.index(image_token_id, st)
                except ValueError:
                    ed_image = len(input_tokens) + 1
            else:
                ed_image = len(input_tokens) + 1
            if remain_videos > 0:
                try:
                    ed_video = input_tokens.index(video_token_id, st)
                except ValueError:
                    ed_video = len(input_tokens) + 1
            else:
                ed_video = len(input_tokens) + 1
            if ed_image < ed_video:
                t, h, w = (
                    image_grid_thw[image_index][0],
                    image_grid_thw[image_index][1],
                    image_grid_thw[image_index][2],
                )
                image_index += 1
                remain_images -= 1
                ed = ed_image
            else:
                t, h, w = (
                    video_grid_thw[video_index][0],
                    video_grid_thw[video_index][1],
                    video_grid_thw[video_index][2],
                )
                video_second_per_grid_t = 1.0
                if second_per_grid_ts:
                    video_second_per_grid_t = second_per_grid_ts[video_index]
                video_index += 1
                remain_videos -= 1
                ed = ed_video

            llm_grid_t, llm_grid_h, llm_grid_w = (
                t,
                h // spatial_merge_size,
                w // spatial_merge_size,
            )
            text_len = ed - st

            # Text before this vision item: identical positions on all 3 axes.
            st_idx = llm_pos_ids_list[-1].max() + 1 if len(llm_pos_ids_list) > 0 else 0
            llm_pos_ids_list.append(
                torch.arange(text_len).view(1, -1).expand(3, -1) + st_idx
            )

            t_index = (
                (
                    torch.arange(llm_grid_t)
                    .view(-1, 1)
                    .expand(-1, llm_grid_h * llm_grid_w)
                    * video_second_per_grid_t
                    * tokens_per_second
                )
                .long()
                .flatten()
            )

            h_index = (
                torch.arange(llm_grid_h)
                .view(1, -1, 1)
                .expand(llm_grid_t, -1, llm_grid_w)
                .flatten()
            )
            w_index = (
                torch.arange(llm_grid_w)
                .view(1, 1, -1)
                .expand(llm_grid_t, llm_grid_h, -1)
                .flatten()
            )
            llm_pos_ids_list.append(
                torch.stack([t_index, h_index, w_index]) + text_len + st_idx
            )
            st = ed + llm_grid_t * llm_grid_h * llm_grid_w

        # Trailing text after the last vision item.
        if st < len(input_tokens):
            st_idx = llm_pos_ids_list[-1].max() + 1 if len(llm_pos_ids_list) > 0 else 0
            text_len = len(input_tokens) - st
            llm_pos_ids_list.append(
                torch.arange(text_len).view(1, -1).expand(3, -1) + st_idx
            )

        llm_positions = torch.cat(llm_pos_ids_list, dim=1).reshape(3, -1)
        mrope_position_delta = (llm_positions.max() + 1 - len(input_tokens)).item()
        llm_positions = llm_positions[:, context_len:seq_len]

        return llm_positions, mrope_position_delta

    @classmethod
    def get_placeholder_str(cls, modality: str, i: int) -> str | None:
        if modality.startswith("image"):
            return "<|vision_start|><|image_pad|><|vision_end|>"
        if modality.startswith("video"):
            return "<|vision_start|><|video_pad|><|vision_end|>"

        raise ValueError("Only image or video modality is supported")

    def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
        super().__init__()
        config: Qwen2VLConfig = vllm_config.model_config.hf_config
        quant_config = vllm_config.quant_config
        multimodal_config = vllm_config.model_config.multimodal_config

        self.use_data_parallel = multimodal_config.mm_encoder_tp_mode == "data"
        self.config = config
        self.multimodal_config = multimodal_config

        if multimodal_config.get_limit_per_prompt(
            "image"
        ) or multimodal_config.get_limit_per_prompt("video"):
            # `multimodal_config` was already dereferenced above, so a
            # `None` check here would be dead code.
            attn_backend_override = multimodal_config.mm_encoder_attn_backend
            self.visual = Qwen2VisionTransformer(
                config.vision_config,
                norm_eps=getattr(config, "rms_norm_eps", 1e-6),
                quant_config=quant_config,
                prefix=maybe_prefix(prefix, "visual"),
                use_data_parallel=self.use_data_parallel,
                attn_backend_override=attn_backend_override,
            )
        else:
            self.visual = None

        self.language_model = init_vllm_registered_model(
            vllm_config=vllm_config,
            prefix=maybe_prefix(prefix, "language_model"),
            architectures=["Qwen2ForCausalLM"],
        )

        self.make_empty_intermediate_tensors = (
            self.language_model.make_empty_intermediate_tensors
        )

    def _validate_and_reshape_mm_tensor(
        self, mm_input: object, name: str
    ) -> torch.Tensor:
        """Return ``mm_input`` as a single 2D tensor.

        A 2D tensor is returned as-is. A 3D tensor is flattened to
        ``(-1, last_dim)``. A list is concatenated with ``torch.concat``.

        Raises:
            ValueError: If the input is not a tensor/list, or is a tensor of
                any rank other than 2 or 3.
        """
        if not isinstance(mm_input, (torch.Tensor, list)):
            raise ValueError(f"Incorrect type of {name}. Got type: {type(mm_input)}")
        if isinstance(mm_input, torch.Tensor):
            if mm_input.ndim == 2:
                return mm_input
            if mm_input.ndim != 3:
                raise ValueError(
                    f"{name} should be 2D or batched 3D tensor. "
                    f"Got ndim: {mm_input.ndim} "
                    f"(shape={mm_input.shape})"
                )
            return mm_input.reshape(-1, mm_input.shape[-1])
        else:
            return torch.concat(mm_input)

    def _parse_and_validate_image_input(
        self, **kwargs: object
    ) -> Qwen2VLImageInputs | None:
        """Build image inputs from kwargs; ``None`` if no image data given."""
        pixel_values = kwargs.pop("pixel_values", None)
        image_embeds = kwargs.pop("image_embeds", None)
        image_grid_thw = kwargs.pop("image_grid_thw", None)

        if pixel_values is None and image_embeds is None:
            return None

        if pixel_values is not None:
            pixel_values = self._validate_and_reshape_mm_tensor(
                pixel_values, "image pixel values"
            )
            image_grid_thw = self._validate_and_reshape_mm_tensor(
                image_grid_thw, "image grid_thw"
            )

            return Qwen2VLImagePixelInputs(
                type="pixel_values",
                pixel_values=pixel_values,
                image_grid_thw=image_grid_thw,
            )

        if image_embeds is not None:
            image_embeds = self._validate_and_reshape_mm_tensor(
                image_embeds, "image embeds"
            )
            image_grid_thw = self._validate_and_reshape_mm_tensor(
                image_grid_thw, "image grid_thw"
            )

            return Qwen2VLImageEmbeddingInputs(
                type="image_embeds",
                image_embeds=image_embeds,
                image_grid_thw=image_grid_thw,
            )

    def _parse_and_validate_video_input(
        self, **kwargs: object
    ) -> Qwen2VLVideoInputs | None:
        """Build video inputs from kwargs; ``None`` if no video data given."""
        pixel_values_videos = kwargs.pop("pixel_values_videos", None)
        video_embeds = kwargs.pop("video_embeds", None)
        video_grid_thw = kwargs.pop("video_grid_thw", None)

        if pixel_values_videos is None and video_embeds is None:
            return None

        if pixel_values_videos is not None:
            pixel_values_videos = self._validate_and_reshape_mm_tensor(
                pixel_values_videos, "video pixel values"
            )
            video_grid_thw = self._validate_and_reshape_mm_tensor(
                video_grid_thw, "video grid_thw"
            )

            return Qwen2VLVideoPixelInputs(
                type="pixel_values_videos",
                pixel_values_videos=pixel_values_videos,
                video_grid_thw=video_grid_thw,
            )

        if video_embeds is not None:
            video_embeds = self._validate_and_reshape_mm_tensor(
                video_embeds, "video embeds"
            )
            video_grid_thw = self._validate_and_reshape_mm_tensor(
                video_grid_thw, "video grid_thw"
            )

            return Qwen2VLVideoEmbeddingInputs(
                type="video_embeds",
                video_embeds=video_embeds,
                video_grid_thw=video_grid_thw,
            )

    def _process_image_input(
        self, image_input: Qwen2VLImageInputs
    ) -> tuple[torch.Tensor, ...]:
        """Encode images (or take precomputed embeds) and split per image."""
        grid_thw = image_input["image_grid_thw"]
        assert grid_thw.ndim == 2
        grid_thw_list = grid_thw.tolist()

        if image_input["type"] == "image_embeds":
            image_embeds = image_input["image_embeds"]
        else:
            pixel_values = image_input["pixel_values"]

            if self.use_data_parallel:
                return run_dp_sharded_mrope_vision_model(
                    self.visual, pixel_values, grid_thw_list, rope_type="rope_3d"
                )
            else:
                image_embeds = self.visual(pixel_values, grid_thw=grid_thw_list)

        # Split concatenated embeddings for each image item.
        merge_size = self.visual.spatial_merge_size
        sizes = (
            torch.tensor(grid_thw_list, dtype=torch.long).prod(-1)
            // (merge_size * merge_size)
        ).tolist()

        return image_embeds.split(sizes)

    def _process_video_input(
        self, video_input: Qwen2VLVideoInputs
    ) -> tuple[torch.Tensor, ...]:
        """Encode videos (or take precomputed embeds) and split per video."""
        grid_thw = video_input["video_grid_thw"]
        assert grid_thw.ndim == 2
        grid_thw_list = grid_thw.tolist()

        if video_input["type"] == "video_embeds":
            video_embeds = video_input["video_embeds"]
        else:
            pixel_values_videos = video_input["pixel_values_videos"]
            if self.use_data_parallel:
                return run_dp_sharded_mrope_vision_model(
                    self.visual, pixel_values_videos, grid_thw_list, rope_type="rope_3d"
                )
            else:
                video_embeds = self.visual(pixel_values_videos, grid_thw=grid_thw_list)

        # Split concatenated embeddings for each video item.
        merge_size = self.visual.spatial_merge_size
        sizes = (
            torch.tensor(grid_thw_list, dtype=torch.long).prod(-1)
            // (merge_size * merge_size)
        ).tolist()

        return video_embeds.split(sizes)

    def _parse_and_validate_multimodal_inputs(self, **kwargs: object) -> dict:
        modalities = {}

        # Preserve the order of modalities if there are multiple of them
        # from the order of kwargs.
        for input_key in kwargs:
            if (
                input_key in ("pixel_values", "image_embeds")
                and "images" not in modalities
            ):
                modalities["images"] = self._parse_and_validate_image_input(**kwargs)
            if (
                input_key in ("pixel_values_videos", "video_embeds")
                and "videos" not in modalities
            ):
                modalities["videos"] = self._parse_and_validate_video_input(**kwargs)

        return modalities

    def get_language_model(self) -> torch.nn.Module:
        return self.language_model

    def get_multimodal_embeddings(self, **kwargs: object) -> MultiModalEmbeddings:
        modalities = self._parse_and_validate_multimodal_inputs(**kwargs)
        if not modalities:
            return []

        # The result multimodal_embeddings is tuple of tensors, with each
        # tensor corresponding to a multimodal data item (image or video).
        multimodal_embeddings: tuple[torch.Tensor, ...] = ()

        # NOTE: It is important to iterate over the keys in this dictionary
        # to preserve the order of the modalities.
        for modality in modalities:
            if modality == "images":
                image_input = modalities["images"]
                image_embeddings = self._process_image_input(image_input)
                multimodal_embeddings += tuple(image_embeddings)
            if modality == "videos":
                video_input = modalities["videos"]
                video_embeddings = self._process_video_input(video_input)
                multimodal_embeddings += tuple(video_embeddings)

        return multimodal_embeddings

    def forward(
        self,
        input_ids: torch.Tensor,
        positions: torch.Tensor,
        intermediate_tensors: IntermediateTensors | None = None,
        inputs_embeds: torch.Tensor | None = None,
        **kwargs: object,
    ) -> torch.Tensor | IntermediateTensors:
        """Run forward pass for Qwen2-VL.

        Args:
            input_ids: Flattened (concatenated) input_ids corresponding to a
                batch.
            positions: Flattened (concatenated) position ids corresponding to a
                batch.
                **NOTE**: If mrope is enabled (default setting for Qwen2-VL
                opensource models), the shape will be `(3, seq_len)`,
                otherwise it will be `(seq_len,)`.
            intermediate_tensors: Intermediate tensors from prior forward pass.
            inputs_embeds: Optional tensor of input embeddings.
        """
        # Input embeddings are ignored when intermediate tensors are supplied.
        if intermediate_tensors is not None:
            inputs_embeds = None

        hidden_states = self.language_model.model(
            input_ids=input_ids,
            positions=positions,
            intermediate_tensors=intermediate_tensors,
            inputs_embeds=inputs_embeds,
        )
        return hidden_states

    def compute_logits(
        self,
        hidden_states: torch.Tensor,
    ) -> torch.Tensor | None:
        return self.language_model.compute_logits(hidden_states)

    def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]:
        # Skip vision-tower weights when no vision tower was built.
        skip_prefixes = []
        if self.visual is None:
            skip_prefixes.extend(["visual."])
        loader = AutoWeightsLoader(self, skip_prefixes=skip_prefixes)
        return loader.load_weights(weights, mapper=self.hf_to_vllm_mapper)

    def get_mm_mapping(self) -> MultiModelKeys:
        """
        Get the module prefix in multimodal models
        """
        return MultiModelKeys.from_string_field(
            language_model="language_model",
            connector="visual.merger.",
            tower_model="visual.",
        )
1643
1644
1645
1646
1647
1648
1649
1650
1651


class Tarsier2MultiModalProcessor(Qwen2VLMultiModalProcessor):
    """Multimodal processor for Tarsier2.

    Adds no behavior on top of ``Qwen2VLMultiModalProcessor``. It is the
    processor registered for the Tarsier2 model together with
    ``Tarsier2ProcessingInfo``.
    """


class Tarsier2ImageProcessor(Qwen2VLImageProcessor):
    """Qwen2-VL image processor that also accepts Tarsier2's ``size`` format.

    If ``size`` contains both ``min_pixels`` and ``max_pixels``, it is
    rewritten to a dict with only ``shortest_edge`` (from ``min_pixels``) and
    ``longest_edge`` (from ``max_pixels``) before reaching the parent
    constructor. Any other ``size`` value, including ``None``, is forwarded
    unchanged.
    """

    def __init__(
        self,
        size: dict[str, int] | None = None,
        **kwargs,
    ) -> None:
        uses_tarsier2_keys = size is not None and all(
            key in size for key in ("min_pixels", "max_pixels")
        )
        if uses_tarsier2_keys:
            # Translate Tarsier2 key names into the ones the parent expects;
            # any other keys in the original dict are not carried over.
            size = {
                "shortest_edge": size["min_pixels"],
                "longest_edge": size["max_pixels"],
            }
        super().__init__(size=size, **kwargs)


class Tarsier2Processor(Qwen2VLProcessor):
    """Qwen2-VL processor wired up with Tarsier2's image processor.

    The same ``vision_config`` mapping is unpacked into both a
    ``Tarsier2ImageProcessor`` and a ``Qwen2VLVideoProcessor``.
    ``chat_template`` is always passed to the parent as ``None``, and any
    extra keyword arguments are forwarded to the parent constructor.
    """

    def __init__(
        self,
        vision_config: dict,
        tokenizer: AnyTokenizer,
        **kwargs,
    ):
        # Assigned on self before the parent constructor runs, and also
        # handed to it as ``image_processor``.
        self.image_processor = Tarsier2ImageProcessor(**vision_config)
        video_processor = Qwen2VLVideoProcessor(**vision_config)
        super().__init__(
            image_processor=self.image_processor,
            tokenizer=tokenizer,
            video_processor=video_processor,
            chat_template=None,
            **kwargs,
        )
1681
1682
1683
1684
1685
1686
1687
1688
1689
1690
1691
1692
1693
1694
1695
1696
1697
1698
1699


class Tarsier2ProcessingInfo(Qwen2VLProcessingInfo):
    """Qwen2-VL processing info with Tarsier2-specific config and processors."""

    def get_hf_config(self) -> Qwen2VLConfig:
        """Load the checkpoint config and rebuild it as a ``Qwen2VLConfig``.

        The config is loaded with ``AutoConfig.from_pretrained`` from
        ``self.ctx.model_config.model``. It is then converted to a dict and
        re-created via ``Qwen2VLConfig.from_dict``.

        NOTE(review): nothing here caches the result, so every call re-reads
        the config.
        """
        loaded_config = AutoConfig.from_pretrained(self.ctx.model_config.model)
        return Qwen2VLConfig.from_dict(loaded_config.to_dict())

    def get_hf_processor(self, **kwargs: object) -> Tarsier2Processor:
        """Build a ``Tarsier2Processor`` from the HF image-processor config.

        Extra keyword arguments are forwarded to ``Tarsier2Processor``.
        """
        image_processor_config = self.ctx.get_hf_image_processor_config()
        return Tarsier2Processor(
            vision_config=image_processor_config,
            tokenizer=self.get_tokenizer(),
            **kwargs,
        )

    def get_image_processor(self) -> Tarsier2ImageProcessor:
        """Build a ``Tarsier2ImageProcessor`` from the HF image-processor config."""
        image_processor_config = self.ctx.get_hf_image_processor_config()
        return Tarsier2ImageProcessor(**image_processor_config)
1701
1702


1703
1704
1705
1706
1707
@MULTIMODAL_REGISTRY.register_processor(
    Tarsier2MultiModalProcessor,
    info=Tarsier2ProcessingInfo,
    dummy_inputs=Qwen2VLDummyInputsBuilder,
)
class Tarsier2ForConditionalGeneration(Qwen2VLForConditionalGeneration):
    """Tarsier2 model, implemented on top of Qwen2-VL.

    Tarsier2 checkpoints store the vision tower under ``vision_tower.``. The
    mapper below renames that prefix to ``visual.``. Weight loading is
    inherited from ``Qwen2VLForConditionalGeneration``, which loads through
    ``self.hf_to_vllm_mapper`` and therefore picks up this class's mapper.
    """

    hf_to_vllm_mapper = WeightsMapper(
        orig_to_new_prefix={
            "vision_tower.": "visual.",
        }
    )

    def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
        # Tarsier2 uses llava as model_type, and the resulting config carries
        # a Qwen2VLConfig as its text_config. That text config is promoted to
        # be the model's hf_config before the Qwen2-VL constructor runs.
        config = vllm_config.model_config.hf_config
        qwen2vl_config = config.text_config
        qwen2vl_config.architectures = config.architectures
        # NOTE: this replaces hf_config on the passed-in vllm_config in place,
        # so later readers of vllm_config.model_config.hf_config see the
        # Qwen2-VL config instead of the outer one.
        vllm_config.model_config.hf_config = qwen2vl_config
        super().__init__(vllm_config=vllm_config, prefix=prefix)

    # ``load_weights`` is deliberately not overridden. The previous override
    # duplicated the parent implementation line-for-line. The parent already
    # skips ``visual.`` when ``self.visual is None`` and applies
    # ``self.hf_to_vllm_mapper``.