qwen2_vl.py 59 KB
Newer Older
1
# SPDX-License-Identifier: Apache-2.0
2
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
3

4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
# Adapted from
# https://github.com/huggingface/transformers/blob/19e6e80e10118f855137b90740936c0b11ac397f/src/transformers/models/qwen2_vl/modeling_qwen2_vl.py
# Copyright 2024 The Qwen team.
# Copyright 2023 The vLLM team.
# Copyright 2022 EleutherAI and the HuggingFace Inc. team. All rights reserved.
#
# This code is based on EleutherAI's GPT-NeoX library and the GPT-NeoX
# and OPT implementations in this library. It has been modified from its
# original forms to accommodate minor architectural differences compared
# to GPT-NeoX and OPT used by the Meta AI team that trained the model.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Inference-only Qwen2-VL model compatible with HuggingFace weights."""
27

28
from collections.abc import Callable, Iterable, Mapping, Sequence
29
from functools import partial
30
from typing import Annotated, Any, Literal, TypeAlias
31
32
33
34
35

import torch
import torch.nn as nn
import torch.nn.functional as F
from einops import rearrange, repeat
36
from transformers import AutoConfig, BatchFeature, PretrainedConfig
37
from transformers.models.qwen2_vl import Qwen2VLImageProcessor, Qwen2VLProcessor
38
from transformers.models.qwen2_vl.configuration_qwen2_vl import (
39
40
41
    Qwen2VLConfig,
    Qwen2VLVisionConfig,
)
42
from transformers.models.qwen2_vl.image_processing_qwen2_vl import smart_resize
43
from transformers.models.qwen2_vl.video_processing_qwen2_vl import Qwen2VLVideoProcessor
44

45
from vllm.attention.backends.registry import _Backend
46
47
48
49
from vllm.attention.layer import (
    check_upstream_fa_availability,
    maybe_get_vit_flash_attn_backend,
)
50
from vllm.config import VllmConfig
51
from vllm.config.multimodal import BaseDummyOptions
52
from vllm.distributed import parallel_state, tensor_model_parallel_all_gather
53
54
55
from vllm.distributed import utils as dist_utils
from vllm.logger import init_logger
from vllm.model_executor.layers.activation import QuickGELU
56
from vllm.model_executor.layers.linear import ColumnParallelLinear, RowParallelLinear
57
from vllm.model_executor.layers.quantization import QuantizationConfig
58
from vllm.model_executor.layers.rotary_embedding.common import (
59
60
    dispatch_rotary_emb_function,
)
61
from vllm.model_executor.model_loader.weight_utils import default_weight_loader
62
from vllm.model_executor.models.module_mapping import MultiModelKeys
63
from vllm.multimodal import MULTIMODAL_REGISTRY
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
from vllm.multimodal.inputs import (
    ImageItem,
    ModalityData,
    MultiModalDataDict,
    MultiModalFieldConfig,
    MultiModalKwargsItems,
    VideoItem,
)
from vllm.multimodal.parse import (
    DictEmbeddingItems,
    ImageSize,
    ModalityDataItems,
    MultiModalDataItems,
    MultiModalDataParser,
)
from vllm.multimodal.processing import (
    BaseMultiModalProcessor,
    BaseProcessingInfo,
    PromptReplacement,
    PromptUpdate,
)
85
from vllm.multimodal.profiling import BaseDummyInputsBuilder
86
from vllm.sequence import IntermediateTensors
87
from vllm.transformers_utils.tokenizer import AnyTokenizer
88
from vllm.utils.tensor_schema import TensorSchema, TensorShape
89

90
91
92
93
94
95
96
97
98
99
100
101
102
from .interfaces import (
    MultiModalEmbeddings,
    SupportsLoRA,
    SupportsMRoPE,
    SupportsMultiModal,
    SupportsPP,
)
from .utils import (
    AutoWeightsLoader,
    WeightsMapper,
    init_vllm_registered_model,
    maybe_prefix,
)
103
from .vision import get_vit_attn_backend, run_dp_sharded_mrope_vision_model
104

105
106
# Module-level logger for this model file.
logger = init_logger(__name__)

107
# For profile run
108
# Frame cap per video; per the "For profile run" note it is meant for the
# profiling run (its use is not visible in this part of the file).
_MAX_FRAMES_PER_VIDEO = 14
109

110
111
112
# === Vision Inputs === #


113
class Qwen2VLImagePixelInputs(TensorSchema):
    """
    Image inputs given as flattened pixel patches for the vision encoder.

    Dimensions:
        - np: The total number of patches over each image over each prompt in
              the batch
        - ni: Number of images
        - cps: Number of channels * temporal_patch_size * patch_size *
               patch_size (the vision patch embedding views every row as a
               (channels, temporal_patch_size, patch_size, patch_size) cube,
               so still images carry the temporal factor too)

    Historical context:
        - pixel_values shape: (num_patches, num_channels *
          temporal_patch_size * patch_size * patch_size)
        - image_grid_thw shape: (num_images, 3) in (grid_t, grid_h, grid_w)
          format
    """

    type: Literal["pixel_values"]

    pixel_values: Annotated[
        torch.Tensor,
        TensorShape("np", "cps"),
    ]

    image_grid_thw: Annotated[
        torch.Tensor,
        TensorShape("ni", 3),
    ]


class Qwen2VLImageEmbeddingInputs(TensorSchema):
    """
    Image inputs given as already-computed embeddings
    (``type == "image_embeds"``).

    Dimensions:
        - nf: Number of image features
        - hs: Hidden size
        - ni: Number of images

    Historical context:
        - image_embeds shape: (num_image_features, hidden_size); the feature
          count depends on how many images there are and on their
          resolution.
        - hidden_size must equal the hidden size of the language model
          backbone.
        - image_grid_thw shape: (num_images, 3) laid out as
          (grid_t, grid_h, grid_w)
    """

    type: Literal["image_embeds"]

    image_embeds: Annotated[torch.Tensor, TensorShape("nf", "hs")]

    image_grid_thw: Annotated[torch.Tensor, TensorShape("ni", 3)]


# Either form of image input accepted by the model.
Qwen2VLImageInputs: TypeAlias = Qwen2VLImagePixelInputs | Qwen2VLImageEmbeddingInputs
171
172


173
174
175
176
177
class Qwen2VLVideoPixelInputs(TensorSchema):
    """
    Video inputs given as flattened pixel patches for the vision encoder.

    Dimensions:
        - np: Total patch count across every video of every prompt in the
              batch
        - ctps: Number of channels * temporal_patch_size * patch_size *
                patch_size
        - nv: Number of videos

    Historical context:
        - pixel_values_videos shape: (num_patches, num_channels *
          temporal_patch_size * patch_size * patch_size)
        - video_grid_thw shape: (num_videos, 3) laid out as
          (grid_t, grid_h, grid_w)
    """

    type: Literal["pixel_values_videos"]

    pixel_values_videos: Annotated[torch.Tensor, TensorShape("np", "ctps")]

    video_grid_thw: Annotated[torch.Tensor, TensorShape("nv", 3)]
200
201


202
203
204
205
206
207
class Qwen2VLVideoEmbeddingInputs(TensorSchema):
    """
    Video inputs given as already-computed embeddings
    (``type == "video_embeds"``).

    Dimensions:
        - nf: Number of video features
        - hs: Hidden size
        - nv: Number of videos

    Historical context:
        - video_embeds shape: (num_video_features, hidden_size); the feature
          count depends on how many videos there are and on their
          resolution.
        - hidden_size must equal the hidden size of the language model
          backbone.
        - video_grid_thw shape: (num_videos, 3) laid out as
          (grid_t, grid_h, grid_w)
    """

    type: Literal["video_embeds"]

    video_embeds: Annotated[torch.Tensor, TensorShape("nf", "hs")]

    video_grid_thw: Annotated[torch.Tensor, TensorShape("nv", 3)]


# Either form of video input accepted by the model.
Qwen2VLVideoInputs: TypeAlias = Qwen2VLVideoPixelInputs | Qwen2VLVideoEmbeddingInputs
232

233
234
235
236
237
238
239
# === Vision Encoder === #


class Qwen2VisionMLP(nn.Module):
    """Feed-forward sub-layer of a vision block: fc1 -> activation -> fc2.

    ``fc1`` is a ``ColumnParallelLinear`` and ``fc2`` a ``RowParallelLinear``;
    both are built with ``disable_tp=use_data_parallel``.
    """

    def __init__(
        self,
        in_features: int,
        hidden_features: int,
        act_layer: type[nn.Module] = QuickGELU,
        quant_config: QuantizationConfig | None = None,
        prefix: str = "",
        use_data_parallel: bool = False,
    ):
        super().__init__()
        common = dict(quant_config=quant_config, disable_tp=use_data_parallel)
        self.fc1 = ColumnParallelLinear(
            in_features, hidden_features, prefix=f"{prefix}.fc1", **common
        )
        self.act = act_layer()
        self.fc2 = RowParallelLinear(
            hidden_features, in_features, prefix=f"{prefix}.fc2", **common
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # The parallel linear layers return (output, bias); only output is used.
        hidden, _ = self.fc1(x)
        out, _ = self.fc2(self.act(hidden))
        return out


def rotate_half(x: torch.Tensor, interleaved: bool = False) -> torch.Tensor:
    """Rotate channel pairs of ``x`` by 90 degrees along the last dimension.

    Non-interleaved: the last dim is split into halves ``(a, b)`` and the
    result is ``(-b, a)``. Interleaved: pairs are adjacent channels
    ``(x[2i], x[2i + 1])`` and each becomes ``(-x[2i + 1], x[2i])``.
    """
    if interleaved:
        evens = x[..., ::2]
        odds = x[..., 1::2]
        # Stack to (..., d, 2) and flatten so the pair elements stay adjacent.
        return torch.stack((-odds, evens), dim=-1).flatten(-2)
    first, second = x.chunk(2, dim=-1)
    return torch.cat((-second, first), dim=-1)
279
280


281
282
283
def apply_rotary_emb_torch(
    x: torch.Tensor, cos: torch.Tensor, sin: torch.Tensor, interleaved: bool = False
) -> torch.Tensor:
    """
    Apply rotary embedding to the first ``2 * cos.shape[-1]`` channels of x.

    x: (batch_size, seqlen, nheads, headdim)
    cos, sin: (seqlen, rotary_dim / 2) or (batch_size, seqlen, rotary_dim / 2)

    Channels past the rotary dimension are passed through unchanged.
    """
    rotary_dim = 2 * cos.shape[-1]
    assert rotary_dim <= x.shape[-1]

    def _widen(table: torch.Tensor) -> torch.Tensor:
        # Insert a broadcast axis for heads, then duplicate each frequency
        # either as [f0, f0, f1, f1, ...] (interleaved) or [f..., f...].
        table = table.unsqueeze(-2)
        if interleaved:
            return table.repeat_interleave(2, dim=-1)
        return torch.cat((table, table), dim=-1)

    cos = _widen(cos)
    sin = _widen(sin)
    rotated = x[..., :rotary_dim]
    passthrough = x[..., rotary_dim:]
    return torch.cat(
        [rotated * cos + rotate_half(rotated, interleaved) * sin, passthrough],
        dim=-1,
    )


305
306
def apply_rotary_pos_emb_vision(t: torch.Tensor, freqs: torch.Tensor) -> torch.Tensor:
    """Rotate ``t`` by the angles in ``freqs`` and return it in ``t``'s dtype.

    The rotation runs in float32 using whichever implementation
    ``dispatch_rotary_emb_function`` selects (falling back to
    ``apply_rotary_emb_torch``).
    """
    rotary_fn = dispatch_rotary_emb_function(default=apply_rotary_emb_torch)
    return rotary_fn(t.float(), freqs.cos(), freqs.sin()).type_as(t)


class Qwen2VisionAttention(nn.Module):
    """Self-attention layer of a Qwen2-VL vision block.

    A fused ``qkv`` projection (``ColumnParallelLinear``) produces queries,
    keys and values. After the optional rotary embedding, attention is
    computed independently inside each segment delimited by ``cu_seqlens``.
    The result then goes through the ``proj`` output projection
    (``RowParallelLinear``). Both projections are built with
    ``disable_tp=use_data_parallel``.

    Supported backends:
        - FLASH_ATTN / ROCM_AITER_FA: varlen kernel
        - TORCH_SDPA: per-segment loop
        - XFORMERS: block-diagonal mask
    """

    def __init__(
        self,
        embed_dim: int,
        num_heads: int,
        projection_size: int,
        quant_config: QuantizationConfig | None = None,
        prefix: str = "",
        use_data_parallel: bool = False,
        attn_backend_override: _Backend | None = None,
    ) -> None:
        super().__init__()
        # Per attention head and per partition values.
        self.tp_size = (
            1
            if use_data_parallel
            else parallel_state.get_tensor_model_parallel_world_size()
        )
        self.tp_rank = parallel_state.get_tensor_model_parallel_rank()
        self.hidden_size_per_attention_head = dist_utils.divide(
            projection_size, num_heads
        )
        self.num_attention_heads_per_partition = dist_utils.divide(
            num_heads, self.tp_size
        )

        self.qkv = ColumnParallelLinear(
            input_size=embed_dim,
            output_size=3 * projection_size,
            quant_config=quant_config,
            prefix=f"{prefix}.qkv",
            disable_tp=use_data_parallel,
        )
        self.proj = RowParallelLinear(
            input_size=projection_size,
            output_size=embed_dim,
            quant_config=quant_config,
            prefix=f"{prefix}.proj",
            disable_tp=use_data_parallel,
        )

        # Detect attention implementation.
        self.attn_backend = get_vit_attn_backend(
            head_size=self.hidden_size_per_attention_head,
            dtype=torch.get_default_dtype(),
            attn_backend_override=attn_backend_override,
        )
        self.use_upstream_fa = False

        # The helper may replace the backend. It also returns the varlen
        # attention function used on the flash-attention path.
        self.attn_backend, self.flash_attn_varlen_func = (
            maybe_get_vit_flash_attn_backend(
                self.attn_backend,
                self.use_upstream_fa,
            )
        )

        if self.attn_backend not in {
            _Backend.FLASH_ATTN,
            _Backend.TORCH_SDPA,
            _Backend.XFORMERS,
            _Backend.ROCM_AITER_FA,
        }:
            raise RuntimeError(
                f"Qwen2-VL does not support {self.attn_backend} backend now."
            )

        self.is_flash_attn_backend = self.attn_backend in {
            _Backend.FLASH_ATTN,
            _Backend.ROCM_AITER_FA,
        }

    def split_qkv(self, qkv: torch.Tensor) -> tuple[torch.Tensor, ...]:
        """Split fused QKV activations into per-head q, k, v for this rank.

        With tensor parallelism, the activations are first all-gathered and
        chunked into q/k/v. Each of q, k, v is then split along the last dim
        into ``tp_size`` parts, and this rank keeps part ``tp_rank``.

        Returns:
            q, k, v, each viewed as
            ``[seq, batch, heads_per_partition, head_dim]``.
        """
        # [s, b, 3 * head * head_dim]
        seq_len, bs, _ = qkv.shape
        if self.tp_size > 1:
            qkv = tensor_model_parallel_all_gather(qkv)

        # [s, b, 3 * head * head_dim] -> 3 * [s, b, head * head_dim]
        q, k, v = qkv.chunk(3, dim=2)

        # 3 * [s, b, head * head_dim]
        if self.tp_size > 1:
            splitter = partial(
                dist_utils.split_tensor_along_last_dim, num_partitions=self.tp_size
            )
            q = splitter(q)[self.tp_rank]
            k = splitter(k)[self.tp_rank]
            v = splitter(v)[self.tp_rank]

        # 3 * [s, b, head * head_dim] -> 3 * [s, b, head, head_dim]
        new_shape = (
            seq_len,
            bs,
            self.num_attention_heads_per_partition,
            self.hidden_size_per_attention_head,
        )
        q, k, v = (x.view(*new_shape) for x in (q, k, v))
        return q, k, v

    def forward(
        self,
        x: torch.Tensor,
        cu_seqlens: torch.Tensor,
        rotary_pos_emb: torch.Tensor | None,
        max_seqlen: int | None = None,  # Only used for Flash Attention
        seqlens: list[int] | None = None,  # Only used for xFormers
    ) -> torch.Tensor:
        """Attend within each ``cu_seqlens`` segment and project the result.

        Args:
            x: ``[s, b, c]`` hidden states.
            cu_seqlens: cumulative segment boundaries along the sequence axis,
                starting at 0.
            rotary_pos_emb: rotary angles, or None to skip rotary embedding.
            max_seqlen: longest segment length (flash-attention backends).
                If None, it is derived from ``cu_seqlens``.
            seqlens: per-segment lengths (xFormers backend). If None, they
                are derived from ``cu_seqlens``.

        Returns:
            ``[s, b, c]`` output of the ``proj`` projection.
        """
        # [s, b, c] --> [s, b, 3 * head * head_dim]
        x, _ = self.qkv(x)

        # [s, b, 3 * head * head_dim] -> 3 * [s, b, head, head_dim]
        q, k, v = self.split_qkv(x)
        batch_size = q.shape[1]

        q, k, v = (rearrange(x, "s b ... -> b s ...") for x in (q, k, v))
        if rotary_pos_emb is not None:
            # [2 * b, s, heads, head_dim]
            qk_concat = torch.cat([q, k], dim=0)
            qk_rotated = apply_rotary_pos_emb_vision(qk_concat, rotary_pos_emb)
            q, k = torch.chunk(qk_rotated, 2, dim=0)

        if self.is_flash_attn_backend:
            if max_seqlen is None:
                # Callers normally precompute this once per encoder pass so
                # the host sync from .item() is not repeated in every layer.
                max_seqlen = (cu_seqlens[1:] - cu_seqlens[:-1]).max().item()

            q, k, v = (rearrange(x, "b s ... -> (b s) ...") for x in [q, k, v])

            output = self.flash_attn_varlen_func(
                q,
                k,
                v,
                cu_seqlens_q=cu_seqlens,
                cu_seqlens_k=cu_seqlens,
                max_seqlen_q=max_seqlen,
                max_seqlen_k=max_seqlen,
                dropout_p=0.0,
                causal=False,
            )

            context_layer = rearrange(
                output, "(b s) h d -> s b (h d)", b=batch_size
            ).contiguous()
        elif self.attn_backend == _Backend.TORCH_SDPA:
            # Execute attention entry by entry for speed & less VRAM.
            outputs = []
            for i in range(1, len(cu_seqlens)):
                start_idx = cu_seqlens[i - 1]
                end_idx = cu_seqlens[i]
                q_i = q[:, start_idx:end_idx]
                k_i = k[:, start_idx:end_idx]
                v_i = v[:, start_idx:end_idx]
                q_i, k_i, v_i = (
                    rearrange(x, "b s h d -> b h s d") for x in [q_i, k_i, v_i]
                )
                output_i = F.scaled_dot_product_attention(q_i, k_i, v_i, dropout_p=0.0)
                output_i = rearrange(output_i, "b h s d -> b s h d ")
                outputs.append(output_i)
            context_layer = torch.cat(outputs, dim=1)
            context_layer = rearrange(
                context_layer, "b s h d -> s b (h d)"
            ).contiguous()
        elif self.attn_backend == _Backend.XFORMERS:
            from xformers import ops as xops
            from xformers.ops.fmha.attn_bias import BlockDiagonalMask

            if seqlens is None:
                # The block-diagonal mask needs a Python list of lengths.
                seqlens = (cu_seqlens[1:] - cu_seqlens[:-1]).tolist()

            attn_bias = BlockDiagonalMask.from_seqlens(
                q_seqlen=seqlens, kv_seqlen=None, device=q.device
            )

            context_layer = xops.memory_efficient_attention_forward(
                q, k, v, attn_bias=attn_bias, p=0, scale=None
            )
            context_layer = rearrange(
                context_layer, "b s h d -> s b (h d)"
            ).contiguous()

        output, _ = self.proj(context_layer)
        return output


class Qwen2VisionBlock(nn.Module):
    """Pre-norm transformer block of the Qwen2-VL vision encoder.

    Computes ``x + attn(norm1(x))`` and then adds ``mlp(norm2(.))`` to that.
    The norms default to ``nn.LayerNorm(eps=1e-6)``, and the MLP hidden
    width is ``int(dim * mlp_ratio)``.
    """

    def __init__(
        self,
        dim: int,
        num_heads: int,
        mlp_ratio: float,
        act_layer: type[nn.Module] = QuickGELU,
        norm_layer: Callable[[int], nn.Module] | None = None,
        quant_config: QuantizationConfig | None = None,
        prefix: str = "",
        use_data_parallel: bool = False,
        attn_backend_override: _Backend | None = None,
    ) -> None:
        super().__init__()
        make_norm = (
            partial(nn.LayerNorm, eps=1e-6) if norm_layer is None else norm_layer
        )
        self.norm1 = make_norm(dim)
        self.norm2 = make_norm(dim)

        self.attn = Qwen2VisionAttention(
            embed_dim=dim,
            num_heads=num_heads,
            projection_size=dim,
            quant_config=quant_config,
            prefix=f"{prefix}.attn",
            use_data_parallel=use_data_parallel,
            attn_backend_override=attn_backend_override,
        )
        self.mlp = Qwen2VisionMLP(
            dim,
            int(dim * mlp_ratio),
            act_layer=act_layer,
            quant_config=quant_config,
            prefix=f"{prefix}.mlp",
            use_data_parallel=use_data_parallel,
        )

    def forward(
        self,
        x: torch.Tensor,
        cu_seqlens: torch.Tensor,
        rotary_pos_emb: torch.Tensor,
        max_seqlen: int | None = None,  # Only used for Flash Attention
        seqlens: list[int] | None = None,  # Only used for xFormers
    ) -> torch.Tensor:
        attn_out = self.attn(
            self.norm1(x),
            cu_seqlens=cu_seqlens,
            rotary_pos_emb=rotary_pos_emb,
            max_seqlen=max_seqlen,
            seqlens=seqlens,
        )
        x = x + attn_out
        mlp_out = self.mlp(self.norm2(x))
        return x + mlp_out


class Qwen2VisionPatchEmbed(nn.Module):
    """Project flattened pixel patches to ``embed_dim`` with a 3-D convolution.

    Each input row is viewed as a ``(channels, temporal_patch_size,
    patch_size, patch_size)`` cube. The bias-free convolution's kernel and
    stride both equal that cube, so it yields exactly one output vector per
    row.
    """

    def __init__(
        self,
        patch_size: int = 14,
        temporal_patch_size: int = 2,
        in_channels: int = 3,
        embed_dim: int = 1152,
    ) -> None:
        super().__init__()
        self.patch_size = patch_size
        self.temporal_patch_size = temporal_patch_size
        self.embed_dim = embed_dim

        cube = (temporal_patch_size, patch_size, patch_size)
        self.proj = nn.Conv3d(
            in_channels, embed_dim, kernel_size=cube, stride=cube, bias=False
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # Unpacking two values also enforces a 2-D (num_patches, features) input.
        num_patches, _ = x.shape
        side = self.patch_size
        cubes = x.view(num_patches, -1, self.temporal_patch_size, side, side)
        return self.proj(cubes).view(num_patches, self.embed_dim)


class Qwen2VisionPatchMerger(nn.Module):
    """Fuse groups of ``spatial_merge_size ** 2`` tokens into one output token.

    Tokens are layer-normed over ``context_dim``. Consecutive groups are then
    reshaped into vectors of ``context_dim * spatial_merge_size ** 2`` and
    projected to ``d_model`` by a two-layer MLP with a GELU in between.
    """

    def __init__(
        self,
        d_model: int,
        context_dim: int,
        norm_layer: Callable[[int], nn.Module] | None = None,
        spatial_merge_size: int = 2,
        quant_config: QuantizationConfig | None = None,
        prefix: str = "",
        use_data_parallel: bool = False,
    ) -> None:
        super().__init__()
        self.hidden_size = context_dim * (spatial_merge_size**2)
        make_norm = (
            partial(nn.LayerNorm, eps=1e-6) if norm_layer is None else norm_layer
        )
        self.ln_q = make_norm(context_dim)

        # Index layout (0: fc1, 1: GELU, 2: fc2) must match the checkpoint
        # names "mlp.0" / "mlp.2".
        fc1 = ColumnParallelLinear(
            self.hidden_size,
            self.hidden_size,
            bias=True,
            quant_config=quant_config,
            prefix=f"{prefix}.mlp.0",
            disable_tp=use_data_parallel,
        )
        act = nn.GELU()
        fc2 = RowParallelLinear(
            self.hidden_size,
            d_model,
            bias=True,
            quant_config=quant_config,
            prefix=f"{prefix}.mlp.2",
            disable_tp=use_data_parallel,
        )
        self.mlp = nn.ModuleList([fc1, act, fc2])

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        grouped = self.ln_q(x).view(-1, self.hidden_size)
        fc1, act, fc2 = self.mlp
        hidden, _ = fc1(grouped)
        out, _ = fc2(act(hidden))
        return out


class Qwen2VisionRotaryEmbedding(nn.Module):
    """Lazily cached rotary-frequency table for the vision tower.

    ``forward(seqlen)`` returns a table of shape
    ``(seqlen, len(range(0, dim, 2)))`` whose row ``p`` is ``p * inv_freq``.
    When the table has to grow, it is rebuilt for twice the requested
    length, so slowly increasing lengths do not trigger a rebuild each call.
    """

    def __init__(self, dim: int, theta: float = 10000.0) -> None:
        super().__init__()
        self.dim = dim
        self.theta = theta
        inv_freq = 1.0 / (theta ** (torch.arange(0, dim, 2, dtype=torch.float) / dim))
        self.register_buffer("inv_freq", inv_freq, persistent=False)
        self._seq_len_cached = 0
        self._freqs_cached = None

    def update_freqs_cache(self, seqlen: int) -> None:
        """Make sure the cached table covers ``seqlen`` rows on the right device.

        The table is (re)built in three cases:
            - it does not exist yet;
            - ``seqlen`` exceeds the cached length (the new length is
              ``2 * seqlen``);
            - the module was moved to another device since the last build.
              ``_freqs_cached`` is a plain attribute, not a buffer, so
              ``Module.to`` does not move it.
        """
        grow = self._freqs_cached is None or seqlen > self._seq_len_cached
        stale = (
            self._freqs_cached is not None
            and self._freqs_cached.device != self.inv_freq.device
        )
        if not (grow or stale):
            return

        new_len = 2 * seqlen if grow else self._seq_len_cached
        self._seq_len_cached = new_len
        device = self.inv_freq.device
        # Recompute inv_freq in float32. The buffer itself may have been cast
        # to a lower precision by Module.to(dtype).
        self.inv_freq = 1.0 / (
            self.theta
            ** (
                torch.arange(0, self.dim, 2, dtype=torch.float, device=device)
                / self.dim
            )
        )
        seq = torch.arange(new_len, device=device, dtype=self.inv_freq.dtype)
        self._freqs_cached = torch.outer(seq, self.inv_freq)

    def forward(self, seqlen: int) -> torch.Tensor:
        self.update_freqs_cache(seqlen)
        return self._freqs_cached[:seqlen]


class Qwen2VisionTransformer(nn.Module):
    """Qwen2-VL vision encoder.

    Pipeline:
        - embed flattened pixel patches with ``Qwen2VisionPatchEmbed``;
        - run ``vision_config.depth`` ``Qwen2VisionBlock`` layers with 2-D
          (row, column) rotary position embeddings;
        - fold every ``spatial_merge_size ** 2`` consecutive tokens into one
          ``vision_config.hidden_size`` vector with ``Qwen2VisionPatchMerger``.
    """

    def __init__(
        self,
        vision_config: Qwen2VLVisionConfig,
        norm_eps: float = 1e-6,
        quant_config: QuantizationConfig | None = None,
        prefix: str = "",
        use_data_parallel: bool = False,
        attn_backend_override: _Backend | None = None,
    ) -> None:
        super().__init__()

        cfg = vision_config
        embed_dim = cfg.embed_dim
        num_heads = cfg.num_heads
        head_dim = embed_dim // num_heads

        self.use_data_parallel = use_data_parallel
        self.out_hidden_size = cfg.hidden_size
        self.spatial_merge_size = cfg.spatial_merge_size
        self.num_heads = num_heads
        self.embed_dim = embed_dim

        self.patch_embed = Qwen2VisionPatchEmbed(
            patch_size=cfg.patch_size,
            temporal_patch_size=cfg.temporal_patch_size,
            in_channels=cfg.in_channels,
            embed_dim=embed_dim,
        )

        norm_layer = partial(nn.LayerNorm, eps=norm_eps)
        # Half of head_dim per axis: rows and columns each get their own angles.
        self.rotary_pos_emb = Qwen2VisionRotaryEmbedding(head_dim // 2)

        self.blocks = nn.ModuleList(
            Qwen2VisionBlock(
                dim=embed_dim,
                num_heads=num_heads,
                mlp_ratio=cfg.mlp_ratio,
                norm_layer=norm_layer,
                quant_config=quant_config,
                prefix=f"{prefix}.blocks.{block_idx}",
                use_data_parallel=use_data_parallel,
                attn_backend_override=attn_backend_override,
            )
            for block_idx in range(cfg.depth)
        )
        self.merger = Qwen2VisionPatchMerger(
            d_model=cfg.hidden_size,
            context_dim=embed_dim,
            norm_layer=norm_layer,
            quant_config=quant_config,
            prefix=f"{prefix}.merger",
            use_data_parallel=use_data_parallel,
        )

        default_dtype = torch.get_default_dtype()
        self.attn_backend = get_vit_attn_backend(
            head_size=head_dim,
            dtype=default_dtype,
            attn_backend_override=attn_backend_override,
        )
        # Switch to FLASH_ATTN whenever check_upstream_fa_availability says so.
        # NOTE(review): this is computed separately from the backend each
        # Qwen2VisionAttention picks; compute_attn_mask_seqlen relies on the
        # two agreeing.
        if self.attn_backend != _Backend.FLASH_ATTN and check_upstream_fa_availability(
            default_dtype
        ):
            self.attn_backend = _Backend.FLASH_ATTN

    @property
    def dtype(self) -> torch.dtype:
        return self.patch_embed.proj.weight.dtype

    @property
    def device(self) -> torch.device:
        return self.patch_embed.proj.weight.device

    def rot_pos_emb(self, grid_thw: list[list[int]]) -> torch.Tensor:
        """Build per-token rotary angles from each item's (t, h, w) grid.

        Every token gets its (row, column) patch coordinates. The ids are
        reordered so each ``merge x merge`` window is contiguous, because the
        merger later folds consecutive ``spatial_merge_size ** 2`` tokens
        together. The ids are repeated ``t`` times and used to look up
        (and concatenate) the row and column angles.
        """
        merge = self.spatial_merge_size
        per_item: list[torch.Tensor] = []
        largest_side = 0
        for t, h, w in grid_thw:
            rows = torch.arange(h).unsqueeze(1).expand(-1, w)
            cols = torch.arange(w).unsqueeze(0).expand(h, -1)
            rows, cols = (
                ids.reshape(h // merge, merge, w // merge, merge)
                .permute(0, 2, 1, 3)
                .flatten()
                for ids in (rows, cols)
            )
            per_item.append(torch.stack([rows, cols], dim=-1).repeat(t, 1))
            largest_side = max(largest_side, h, w)
        pos_ids = torch.cat(per_item, dim=0)
        angle_table = self.rotary_pos_emb(largest_side)
        return angle_table[pos_ids].flatten(1)

    def compute_attn_mask_seqlen(
        self, cu_seqlens: torch.Tensor
    ) -> tuple[int | None, list[int] | None]:
        """Return ``(max_seqlen, seqlens)`` as needed by the chosen backend.

        Flash-attention backends get the maximum segment length, xFormers
        gets the list of segment lengths, and anything else gets
        ``(None, None)``.
        """
        if self.attn_backend in (_Backend.FLASH_ATTN, _Backend.ROCM_AITER_FA):
            return (cu_seqlens[1:] - cu_seqlens[:-1]).max().item(), None
        if self.attn_backend == _Backend.XFORMERS:
            return None, (cu_seqlens[1:] - cu_seqlens[:-1]).tolist()
        return None, None

    def forward(
        self,
        x: torch.Tensor,
        grid_thw: list[list[int]],
    ) -> torch.Tensor:
        """Encode patches for all items described by ``grid_thw``.

        Args:
            x: flattened pixel patches, one row per patch (layout as expected
                by ``Qwen2VisionPatchEmbed``).
            grid_thw: one ``(t, h, w)`` patch grid per image/video.

        Returns:
            merged vision tokens, one row per ``spatial_merge_size ** 2``
            input patches.
        """
        # patchify
        hidden = self.patch_embed(x.to(device=self.device, dtype=self.dtype))

        # compute position embedding
        rotary_pos_emb = self.rot_pos_emb(grid_thw)

        # Attention segments are individual frames: h * w tokens, t times each.
        thw = torch.tensor(grid_thw, device=hidden.device, dtype=torch.long)
        frame_sizes = thw[:, 1] * thw[:, 2]
        cu_seqlens = torch.repeat_interleave(frame_sizes, thw[:, 0]).cumsum(
            dim=0, dtype=torch.int32
        )
        cu_seqlens = F.pad(cu_seqlens, (1, 0), "constant", 0)

        # Insert the batch axis expected by the blocks: [s, 1, c].
        hidden = hidden.unsqueeze(1)

        # pre-compute seqlens for attn mask to reduce cuMemcpy operations
        max_seqlen, seqlens = self.compute_attn_mask_seqlen(cu_seqlens)
        for block in self.blocks:
            hidden = block(
                hidden,
                cu_seqlens=cu_seqlens,
                rotary_pos_emb=rotary_pos_emb,
                max_seqlen=max_seqlen,
                seqlens=seqlens,
            )

        # adapter
        return self.merger(hidden)

    def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]:
        """Load checkpoint tensors and return the names of parameters loaded.

        Names containing ``q_proj``/``k_proj``/``v_proj`` are renamed to
        ``qkv_proj`` and loaded as the matching shard. Every other tensor is
        loaded through the parameter's ``weight_loader`` (or
        ``default_weight_loader``).
        """
        stacked_params_mapping = [
            # (param_name, shard_name, shard_id)
            ("qkv_proj", "q_proj", "q"),
            ("qkv_proj", "k_proj", "k"),
            ("qkv_proj", "v_proj", "v"),
        ]
        params_dict = dict(self.named_parameters(remove_duplicate=False))
        loaded_params: set[str] = set()

        for name, loaded_weight in weights:
            for param_name, weight_name, shard_id in stacked_params_mapping:
                if weight_name in name:
                    name = name.replace(weight_name, param_name)
                    param = params_dict[name]
                    param.weight_loader(param, loaded_weight, shard_id)
                    break
            else:
                param = params_dict[name]
                loader = getattr(param, "weight_loader", default_weight_loader)
                loader(param, loaded_weight)
            loaded_params.add(name)
        return loaded_params

853

854
def _create_qwen2vl_field_factory(
    spatial_merge_size: int,
) -> Callable[
    [Mapping[str, torch.Tensor]],
    Mapping[str, MultiModalFieldConfig],
]:
    """Return a callable that builds Qwen2-VL multimodal field configs.

    The returned function reads ``image_grid_thw`` / ``video_grid_thw`` from
    the HF processor outputs (an empty ``(0, 3)`` tensor when absent). It
    splits the flattened pixel fields by each item's ``t * h * w`` product.
    It splits the embedding fields by that product floor-divided by
    ``spatial_merge_size`` twice. The grid fields are batched per item.
    """

    def _merged(grid_sizes: torch.Tensor) -> torch.Tensor:
        # Two successive floor divisions, exactly as in the pixel->embed ratio.
        return grid_sizes // spatial_merge_size // spatial_merge_size

    def _qwen2vl_field_config(hf_inputs: Mapping[str, torch.Tensor]):
        image_patch_counts = hf_inputs.get(
            "image_grid_thw", torch.empty((0, 3))
        ).prod(-1)
        video_patch_counts = hf_inputs.get(
            "video_grid_thw", torch.empty((0, 3))
        ).prod(-1)

        flat = MultiModalFieldConfig.flat_from_sizes
        return {
            "pixel_values": flat("image", image_patch_counts),
            "image_embeds": flat("image", _merged(image_patch_counts)),
            "image_grid_thw": MultiModalFieldConfig.batched("image"),
            "pixel_values_videos": flat("video", video_patch_counts),
            "video_embeds": flat("video", _merged(video_patch_counts)),
            "video_grid_thw": MultiModalFieldConfig.batched("video"),
        }

    return _qwen2vl_field_config
891

892

Roger Wang's avatar
Roger Wang committed
893
class Qwen2VLMultiModalDataParser(MultiModalDataParser):
    """Multimodal data parser that also accepts precomputed embedding dicts.

    A ``dict`` passed as image or video data is wrapped in
    ``DictEmbeddingItems``. It requires ``<modality>_embeds`` and
    ``<modality>_grid_thw`` and uses the Qwen2-VL field layout. Any other
    input is delegated to the base parser.
    """

    def __init__(self, spatial_merge_size: int, *args, **kwargs):
        self._spatial_merge_size = spatial_merge_size
        super().__init__(*args, **kwargs)

    def _embedding_items_from_dict(
        self, data: dict[str, torch.Tensor], modality: str
    ) -> DictEmbeddingItems:
        """Wrap an embeddings dict for ``modality`` ("image" or "video")."""
        return DictEmbeddingItems(
            data,
            modality=modality,
            required_fields={f"{modality}_embeds", f"{modality}_grid_thw"},
            fields_factory=_create_qwen2vl_field_factory(self._spatial_merge_size),
        )

    def _parse_image_data(
        self,
        data: dict[str, torch.Tensor] | ModalityData[ImageItem],
    ) -> ModalityDataItems[Any, Any] | None:
        if not isinstance(data, dict):
            return super()._parse_image_data(data)
        return self._embedding_items_from_dict(data, "image")

    def _parse_video_data(
        self,
        data: dict[str, torch.Tensor] | ModalityData[VideoItem],
    ) -> ModalityDataItems[Any, Any] | None:
        if not isinstance(data, dict):
            return super()._parse_video_data(data)
        return self._embedding_items_from_dict(data, "video")


927
928
class Qwen2VLProcessingInfo(BaseProcessingInfo):
    """Config/processor accessors and vision-token budgeting for Qwen2-VL."""

    def get_hf_config(self):
        return self.ctx.get_hf_config(Qwen2VLConfig)

    def get_hf_processor(self, **kwargs: object) -> Qwen2VLProcessor:
        # The fast processor is used unless the caller overrides ``use_fast``.
        use_fast = kwargs.pop("use_fast", True)
        return self.ctx.get_hf_processor(
            Qwen2VLProcessor,
            use_fast=use_fast,
            **kwargs,
        )

    def get_image_processor(self, **kwargs: object) -> Qwen2VLImageProcessor:
        hf_processor = self.get_hf_processor(**kwargs)
        return hf_processor.image_processor

    def get_supported_mm_limits(self) -> Mapping[str, int | None]:
        # ``None`` for both modalities: no per-prompt cap is imposed here.
        return {"image": None, "video": None}

    def get_mm_max_tokens_per_item(
        self,
        seq_len: int,
        mm_counts: Mapping[str, int],
    ) -> Mapping[str, int]:
        return {
            "image": self.get_max_image_tokens(),
            "video": self.get_max_video_tokens(seq_len, mm_counts),
        }

    def _get_vision_info(
        self,
        *,
        image_width: int,
        image_height: int,
        num_frames: int = 1,
        do_resize: bool = True,
        image_processor: Qwen2VLImageProcessor | None,
    ) -> tuple[ImageSize, int]:
        """Return the preprocessed image size and the number of vision tokens.

        With ``do_resize`` the size is computed by ``smart_resize``. It uses
        the processor's pixel bounds and a factor of
        ``patch_size * spatial_merge_size``. The token count is the number
        of ``t x h x w`` patches divided by ``spatial_merge_size ** 2``.
        """
        if image_processor is None:
            image_processor = self.get_image_processor()

        vision_config = self.get_hf_config().vision_config
        patch_size = vision_config.patch_size
        merge_size = vision_config.spatial_merge_size
        temporal_patch_size = vision_config.temporal_patch_size

        if not do_resize:
            preprocessed_size = ImageSize(width=image_width, height=image_height)
        else:
            new_height, new_width = smart_resize(
                height=image_height,
                width=image_width,
                factor=patch_size * merge_size,
                min_pixels=image_processor.min_pixels,
                max_pixels=image_processor.max_pixels,
            )
            preprocessed_size = ImageSize(width=new_width, height=new_height)

        # NOTE: Frames are padded to be divisible by `temporal_patch_size`
        # https://github.com/huggingface/transformers/blob/v4.48.3/src/transformers/models/qwen2_vl/image_processing_qwen2_vl.py#L294
        padded_frames = num_frames + num_frames % temporal_patch_size

        grid_t = max(padded_frames // temporal_patch_size, 1)
        grid_h = preprocessed_size.height // patch_size
        grid_w = preprocessed_size.width // patch_size

        total_patches = grid_t * grid_h * grid_w
        return preprocessed_size, total_patches // (merge_size**2)

    def get_num_image_tokens(
        self,
        *,
        image_width: int,
        image_height: int,
        image_processor: Qwen2VLImageProcessor | None,
    ) -> int:
        vision_info = self._get_vision_info(
            image_width=image_width,
            image_height=image_height,
            num_frames=1,
            image_processor=image_processor,
        )
        return vision_info[1]

    def get_num_video_tokens(
        self,
        *,
        image_width: int,
        image_height: int,
        num_frames: int,
        image_processor: Qwen2VLImageProcessor | None,
    ) -> int:
        vision_info = self._get_vision_info(
            image_width=image_width,
            image_height=image_height,
            num_frames=num_frames,
            image_processor=image_processor,
        )
        return vision_info[1]

    def get_image_size_with_most_features(self) -> ImageSize:
        # An oversized request lets smart_resize clamp to the largest size.
        vision_info = self._get_vision_info(
            image_width=9999999,
            image_height=9999999,
            num_frames=1,
            image_processor=None,
        )
        return vision_info[0]

    def get_max_image_tokens(self) -> int:
        width, height = self.get_image_size_with_most_features()
        return self.get_num_image_tokens(
            image_width=width,
            image_height=height,
            image_processor=None,
        )

    def _get_max_video_frames(self, max_tokens: int, start_num_frames: int = 1) -> int:
        """Largest frame count (>= ``start_num_frames``) within ``max_tokens``.

        Frames are added one at a time while the token count for the next
        frame count, at the largest image size, stays within ``max_tokens``.
        """
        width, height = self.get_image_size_with_most_features()
        frames = start_num_frames

        while (
            self.get_num_video_tokens(
                image_width=width,
                image_height=height,
                num_frames=frames + 1,
                image_processor=None,
            )
            <= max_tokens
        ):
            frames += 1

        return frames

    def get_num_frames_with_most_features(
        self,
        seq_len: int,
        mm_counts: Mapping[str, int],
        max_frames_per_video: int = _MAX_FRAMES_PER_VIDEO,
    ) -> int:
        """Frames per video when ``seq_len`` is shared evenly across videos."""
        num_videos = mm_counts.get("video", 0)
        frame_budget = self._get_max_video_frames(seq_len)

        per_video = min(frame_budget // max(num_videos, 1), max_frames_per_video)
        return max(per_video, 1)

    def get_max_video_tokens(
        self,
        seq_len: int,
        mm_counts: Mapping[str, int],
    ) -> int:
        width, height = self.get_image_size_with_most_features()
        frames = self.get_num_frames_with_most_features(seq_len, mm_counts)
        return self.get_num_video_tokens(
            image_width=width,
            image_height=height,
            num_frames=frames,
            image_processor=None,
        )

1095
1096

class Qwen2VLDummyInputsBuilder(BaseDummyInputsBuilder[Qwen2VLProcessingInfo]):
    """Builds dummy text and media sized by the processing info's maxima."""

    def get_dummy_text(self, mm_counts: Mapping[str, int]) -> str:
        """One image token per image followed by one video token per video."""
        processor = self.info.get_hf_processor()
        image_part: str = processor.image_token * mm_counts.get("image", 0)
        video_part: str = processor.video_token * mm_counts.get("video", 0)
        return image_part + video_part

    def get_dummy_mm_data(
        self,
        seq_len: int,
        mm_counts: Mapping[str, int],
        mm_options: Mapping[str, BaseDummyOptions] | None = None,
    ) -> MultiModalDataDict:
        """Dummy images/videos at the largest image size and frame count."""
        width, height = self.info.get_image_size_with_most_features()
        frames = self.info.get_num_frames_with_most_features(seq_len, mm_counts)

        options = mm_options or {}
        images = self._get_dummy_images(
            width=width,
            height=height,
            num_images=mm_counts.get("image", 0),
            overrides=options.get("image"),
        )
        videos = self._get_dummy_videos(
            width=width,
            height=height,
            num_frames=frames,
            num_videos=mm_counts.get("video", 0),
            overrides=options.get("video"),
        )
        return {"image": images, "video": videos}

1140

1141
class Qwen2VLMultiModalProcessor(BaseMultiModalProcessor[Qwen2VLProcessingInfo]):
    """Multimodal processor that expands image/video pad tokens per item."""

    def _get_data_parser(self) -> MultiModalDataParser:
        merge = self.info.get_hf_config().vision_config.spatial_merge_size
        return Qwen2VLMultiModalDataParser(merge)

    def _get_prompt_updates(
        self,
        mm_items: MultiModalDataItems,
        hf_processor_mm_kwargs: Mapping[str, Any],
        out_mm_kwargs: MultiModalKwargsItems,
    ) -> Sequence[PromptUpdate]:
        """Replace each single pad token with one token per merged patch."""
        hf_processor = self.info.get_hf_processor(**hf_processor_mm_kwargs)
        image_processor = self.info.get_image_processor(**hf_processor_mm_kwargs)
        vocab = self.info.get_tokenizer().get_vocab()

        pad_token_ids = {
            "image": vocab[hf_processor.image_token],
            "video": vocab[hf_processor.video_token],
        }
        patches_per_token = image_processor.merge_size**2

        def _replacement_tokens(item_idx: int, modality: str):
            grid_thw = out_mm_kwargs[modality][item_idx][f"{modality}_grid_thw"].data
            assert isinstance(grid_thw, torch.Tensor)
            count = int(grid_thw.prod()) // patches_per_token
            return [pad_token_ids[modality]] * count

        updates = []
        for modality in ("image", "video"):
            updates.append(
                PromptReplacement(
                    modality=modality,
                    target=[pad_token_ids[modality]],
                    replacement=partial(_replacement_tokens, modality=modality),
                )
            )
        return updates

    def _get_mm_fields_config(
        self,
        hf_inputs: BatchFeature,
        hf_processor_mm_kwargs: Mapping[str, object],
    ) -> Mapping[str, MultiModalFieldConfig]:
        merge = self.info.get_hf_config().vision_config.spatial_merge_size
        field_factory = _create_qwen2vl_field_factory(merge)
        return field_factory(hf_inputs)


@MULTIMODAL_REGISTRY.register_processor(
    Qwen2VLMultiModalProcessor,
    info=Qwen2VLProcessingInfo,
    dummy_inputs=Qwen2VLDummyInputsBuilder,
)
class Qwen2VLForConditionalGeneration(
    nn.Module, SupportsMultiModal, SupportsLoRA, SupportsPP, SupportsMRoPE
):
    """Qwen2-VL for conditional generation.

    Wraps a ``Qwen2ForCausalLM`` language model. It builds a
    ``Qwen2VisionTransformer`` (``self.visual``) only when the per-prompt
    image or video limit is non-zero; otherwise ``self.visual`` is None.
    """

    # To ensure correct weight loading and mapping.
    hf_to_vllm_mapper = WeightsMapper(
        orig_to_new_prefix={
            # mapping for new names in checkpoint saved after transformers v4.52
            "model.language_model.": "language_model.model.",
            "model.visual.": "visual.",
            # mapping for original checkpoint
            "lm_head.": "language_model.lm_head.",
            "model.": "language_model.model.",
        }
    )

    supports_encoder_tp_data = True

    @staticmethod
    def _find_next_token(
        input_tokens: list[int], token_id: int, start: int, remaining
    ) -> int:
        """Index of the next ``token_id`` at or after ``start``.

        Returns the sentinel ``len(input_tokens) + 1`` when ``remaining`` is
        not positive or the token does not occur again.
        """
        if remaining > 0:
            try:
                return input_tokens.index(token_id, start)
            except ValueError:
                pass
        return len(input_tokens) + 1

    def get_mrope_input_positions(
        self,
        input_tokens: list[int],
        hf_config: PretrainedConfig,
        image_grid_thw: list[list[int]] | torch.Tensor | None,
        video_grid_thw: list[list[int]] | torch.Tensor | None,
        second_per_grid_ts: list[float] | None = None,
        context_len: int = 0,
        seq_len: int | None = None,
        audio_feature_lengths: torch.Tensor | None = None,
        use_audio_in_video: bool = False,
    ) -> tuple[torch.Tensor, int]:
        """Get M-RoPE input positions for Qwen2-VL model.

        Text tokens get the same value on all three axes, increasing by one
        per token. Each image/video placeholder run gets a (t, h, w) grid of
        positions offset past the preceding text. Video temporal positions
        are scaled by ``second_per_grid_ts`` and ``tokens_per_second``.
        ``audio_feature_lengths`` and ``use_audio_in_video`` are accepted for
        interface compatibility and are not used here.

        Returns:
            The ``(3, n)`` position tensor sliced to ``[context_len:seq_len]``,
            and ``max position + 1 - len(input_tokens)`` as the M-RoPE delta.
        """
        if image_grid_thw is None:
            image_grid_thw = []
        if video_grid_thw is None:
            video_grid_thw = []
        if second_per_grid_ts is None:
            second_per_grid_ts = []

        image_token_id = hf_config.image_token_id
        video_token_id = hf_config.video_token_id
        vision_start_token_id = hf_config.vision_start_token_id
        spatial_merge_size = hf_config.vision_config.spatial_merge_size
        tokens_per_second = getattr(hf_config.vision_config, "tokens_per_second", 1.0)

        input_tokens_tensor = torch.tensor(input_tokens)
        vision_start_indices = torch.argwhere(
            input_tokens_tensor == vision_start_token_id
        ).squeeze(1)
        # A vision-start token in the last position has no modality token
        # after it; drop it so the lookup below cannot index out of range.
        vision_start_indices = vision_start_indices[
            vision_start_indices + 1 < len(input_tokens)
        ]
        vision_tokens = input_tokens_tensor[vision_start_indices + 1]
        image_nums = (vision_tokens == image_token_id).sum()
        video_nums = (vision_tokens == video_token_id).sum()
        llm_pos_ids_list: list = []

        st = 0
        remain_images, remain_videos = image_nums, video_nums

        image_index, video_index = 0, 0
        for _ in range(image_nums + video_nums):
            video_second_per_grid_t = 0.0
            ed_image = self._find_next_token(
                input_tokens, image_token_id, st, remain_images
            )
            ed_video = self._find_next_token(
                input_tokens, video_token_id, st, remain_videos
            )
            if ed_image < ed_video:
                t, h, w = (
                    image_grid_thw[image_index][0],
                    image_grid_thw[image_index][1],
                    image_grid_thw[image_index][2],
                )
                image_index += 1
                remain_images -= 1
                ed = ed_image
            else:
                t, h, w = (
                    video_grid_thw[video_index][0],
                    video_grid_thw[video_index][1],
                    video_grid_thw[video_index][2],
                )
                video_second_per_grid_t = 1.0
                if second_per_grid_ts:
                    video_second_per_grid_t = second_per_grid_ts[video_index]
                video_index += 1
                remain_videos -= 1
                ed = ed_video

            llm_grid_t, llm_grid_h, llm_grid_w = (
                t,
                h // spatial_merge_size,
                w // spatial_merge_size,
            )
            text_len = ed - st

            # Text preceding this vision item: same position on all 3 axes.
            st_idx = llm_pos_ids_list[-1].max() + 1 if len(llm_pos_ids_list) > 0 else 0
            llm_pos_ids_list.append(
                torch.arange(text_len).view(1, -1).expand(3, -1) + st_idx
            )

            t_index = (
                (
                    torch.arange(llm_grid_t)
                    .view(-1, 1)
                    .expand(-1, llm_grid_h * llm_grid_w)
                    * video_second_per_grid_t
                    * tokens_per_second
                )
                .long()
                .flatten()
            )
            h_index = (
                torch.arange(llm_grid_h)
                .view(1, -1, 1)
                .expand(llm_grid_t, -1, llm_grid_w)
                .flatten()
            )
            w_index = (
                torch.arange(llm_grid_w)
                .view(1, 1, -1)
                .expand(llm_grid_t, llm_grid_h, -1)
                .flatten()
            )
            llm_pos_ids_list.append(
                torch.stack([t_index, h_index, w_index]) + text_len + st_idx
            )
            st = ed + llm_grid_t * llm_grid_h * llm_grid_w

        # Trailing text after the last vision item.
        if st < len(input_tokens):
            st_idx = llm_pos_ids_list[-1].max() + 1 if len(llm_pos_ids_list) > 0 else 0
            text_len = len(input_tokens) - st
            llm_pos_ids_list.append(
                torch.arange(text_len).view(1, -1).expand(3, -1) + st_idx
            )

        llm_positions = torch.cat(llm_pos_ids_list, dim=1).reshape(3, -1)
        mrope_position_delta = (llm_positions.max() + 1 - len(input_tokens)).item()
        llm_positions = llm_positions[:, context_len:seq_len]

        return llm_positions, mrope_position_delta

    @classmethod
    def get_placeholder_str(cls, modality: str, i: int) -> str | None:
        if modality.startswith("image"):
            return "<|vision_start|><|image_pad|><|vision_end|>"
        if modality.startswith("video"):
            return "<|vision_start|><|video_pad|><|vision_end|>"

        raise ValueError("Only image or video modality is supported")

    def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
        super().__init__()
        config: Qwen2VLConfig = vllm_config.model_config.hf_config
        quant_config = vllm_config.quant_config
        multimodal_config = vllm_config.model_config.multimodal_config

        self.use_data_parallel = multimodal_config.mm_encoder_tp_mode == "data"

        self.config = config
        self.multimodal_config = multimodal_config

        # The vision tower is only built when images or videos are allowed in
        # a prompt; otherwise ``visual`` is None and ``load_weights`` skips
        # the "visual." weights.
        if multimodal_config.get_limit_per_prompt(
            "image"
        ) or multimodal_config.get_limit_per_prompt("video"):
            # ``multimodal_config`` was already dereferenced above, so a
            # ``None`` check here would be dead code; read it directly.
            attn_backend_override = multimodal_config.mm_encoder_attn_backend
            self.visual = Qwen2VisionTransformer(
                config.vision_config,
                norm_eps=getattr(config, "rms_norm_eps", 1e-6),
                quant_config=quant_config,
                prefix=maybe_prefix(prefix, "visual"),
                use_data_parallel=self.use_data_parallel,
                attn_backend_override=attn_backend_override,
            )
        else:
            self.visual = None

        self.language_model = init_vllm_registered_model(
            vllm_config=vllm_config,
            prefix=maybe_prefix(prefix, "language_model"),
            architectures=["Qwen2ForCausalLM"],
        )

        self.make_empty_intermediate_tensors = (
            self.language_model.make_empty_intermediate_tensors
        )

    def _validate_and_reshape_mm_tensor(
        self, mm_input: object, name: str
    ) -> torch.Tensor:
        """Flatten a tensor or list of tensors to 2D.

        A 2D tensor is returned as-is, and a 3D tensor is reshaped to
        ``(-1, last_dim)``. A list is concatenated along dim 0. Anything
        else raises ValueError.
        """
        if not isinstance(mm_input, (torch.Tensor, list)):
            raise ValueError(f"Incorrect type of {name}. Got type: {type(mm_input)}")
        if isinstance(mm_input, torch.Tensor):
            if mm_input.ndim == 2:
                return mm_input
            if mm_input.ndim != 3:
                raise ValueError(
                    f"{name} should be 2D or batched 3D tensor. "
                    f"Got ndim: {mm_input.ndim} "
                    f"(shape={mm_input.shape})"
                )
            return mm_input.reshape(-1, mm_input.shape[-1])
        else:
            return torch.concat(mm_input)

    def _parse_and_validate_image_input(
        self, **kwargs: object
    ) -> Qwen2VLImageInputs | None:
        pixel_values = kwargs.pop("pixel_values", None)
        image_embeds = kwargs.pop("image_embeds", None)
        image_grid_thw = kwargs.pop("image_grid_thw", None)

        if pixel_values is None and image_embeds is None:
            return None

        if pixel_values is not None:
            pixel_values = self._validate_and_reshape_mm_tensor(
                pixel_values, "image pixel values"
            )
            image_grid_thw = self._validate_and_reshape_mm_tensor(
                image_grid_thw, "image grid_thw"
            )

            return Qwen2VLImagePixelInputs(
                type="pixel_values",
                pixel_values=pixel_values,
                image_grid_thw=image_grid_thw,
            )

        if image_embeds is not None:
            image_embeds = self._validate_and_reshape_mm_tensor(
                image_embeds, "image embeds"
            )
            image_grid_thw = self._validate_and_reshape_mm_tensor(
                image_grid_thw, "image grid_thw"
            )

            return Qwen2VLImageEmbeddingInputs(
                type="image_embeds",
                image_embeds=image_embeds,
                image_grid_thw=image_grid_thw,
            )

    def _parse_and_validate_video_input(
        self, **kwargs: object
    ) -> Qwen2VLVideoInputs | None:
        pixel_values_videos = kwargs.pop("pixel_values_videos", None)
        video_embeds = kwargs.pop("video_embeds", None)
        video_grid_thw = kwargs.pop("video_grid_thw", None)

        if pixel_values_videos is None and video_embeds is None:
            return None

        if pixel_values_videos is not None:
            pixel_values_videos = self._validate_and_reshape_mm_tensor(
                pixel_values_videos, "video pixel values"
            )
            video_grid_thw = self._validate_and_reshape_mm_tensor(
                video_grid_thw, "video grid_thw"
            )

            return Qwen2VLVideoPixelInputs(
                type="pixel_values_videos",
                pixel_values_videos=pixel_values_videos,
                video_grid_thw=video_grid_thw,
            )

        if video_embeds is not None:
            video_embeds = self._validate_and_reshape_mm_tensor(
                video_embeds, "video embeds"
            )
            video_grid_thw = self._validate_and_reshape_mm_tensor(
                video_grid_thw, "video grid_thw"
            )

            return Qwen2VLVideoEmbeddingInputs(
                type="video_embeds",
                video_embeds=video_embeds,
                video_grid_thw=video_grid_thw,
            )

    def _process_image_input(
        self, image_input: Qwen2VLImageInputs
    ) -> tuple[torch.Tensor, ...]:
        grid_thw = image_input["image_grid_thw"]
        assert grid_thw.ndim == 2
        grid_thw_list = grid_thw.tolist()

        if image_input["type"] == "image_embeds":
            image_embeds = image_input["image_embeds"]
        else:
            pixel_values = image_input["pixel_values"]

            if self.use_data_parallel:
                return run_dp_sharded_mrope_vision_model(
                    self.visual, pixel_values, grid_thw_list, rope_type="rope_3d"
                )
            else:
                image_embeds = self.visual(pixel_values, grid_thw=grid_thw_list)

        # Split concatenated embeddings for each image item.
        merge_size = self.visual.spatial_merge_size
        sizes = (
            torch.tensor(grid_thw_list, dtype=torch.long).prod(-1)
            // (merge_size * merge_size)
        ).tolist()

        return image_embeds.split(sizes)

    def _process_video_input(
        self, video_input: Qwen2VLVideoInputs
    ) -> tuple[torch.Tensor, ...]:
        grid_thw = video_input["video_grid_thw"]
        assert grid_thw.ndim == 2
        grid_thw_list = grid_thw.tolist()

        if video_input["type"] == "video_embeds":
            video_embeds = video_input["video_embeds"]
        else:
            pixel_values_videos = video_input["pixel_values_videos"]
            if self.use_data_parallel:
                return run_dp_sharded_mrope_vision_model(
                    self.visual, pixel_values_videos, grid_thw_list, rope_type="rope_3d"
                )
            else:
                video_embeds = self.visual(pixel_values_videos, grid_thw=grid_thw_list)

        # Split concatenated embeddings for each video item.
        merge_size = self.visual.spatial_merge_size
        sizes = (
            torch.tensor(grid_thw_list, dtype=torch.long).prod(-1)
            // (merge_size * merge_size)
        ).tolist()

        return video_embeds.split(sizes)

    def _parse_and_validate_multimodal_inputs(self, **kwargs: object) -> dict:
        modalities = {}

        # Preserve the order of modalities if there are multiple of them
        # from the order of kwargs.
        for input_key in kwargs:
            if (
                input_key in ("pixel_values", "image_embeds")
                and "images" not in modalities
            ):
                modalities["images"] = self._parse_and_validate_image_input(**kwargs)
            if (
                input_key in ("pixel_values_videos", "video_embeds")
                and "videos" not in modalities
            ):
                modalities["videos"] = self._parse_and_validate_video_input(**kwargs)

        return modalities

    def get_language_model(self) -> torch.nn.Module:
        return self.language_model

    def get_multimodal_embeddings(self, **kwargs: object) -> MultiModalEmbeddings:
        modalities = self._parse_and_validate_multimodal_inputs(**kwargs)
        if not modalities:
            return []

        # The result multimodal_embeddings is tuple of tensors, with each
        # tensor corresponding to a multimodal data item (image or video).
        multimodal_embeddings: tuple[torch.Tensor, ...] = ()

        # NOTE: It is important to iterate over the keys in this dictionary
        # to preserve the order of the modalities.
        for modality in modalities:
            if modality == "images":
                image_input = modalities["images"]
                image_embeddings = self._process_image_input(image_input)
                multimodal_embeddings += tuple(image_embeddings)
            if modality == "videos":
                video_input = modalities["videos"]
                video_embeddings = self._process_video_input(video_input)
                multimodal_embeddings += tuple(video_embeddings)

        return multimodal_embeddings

    def forward(
        self,
        input_ids: torch.Tensor,
        positions: torch.Tensor,
        intermediate_tensors: IntermediateTensors | None = None,
        inputs_embeds: torch.Tensor | None = None,
        **kwargs: object,
    ) -> torch.Tensor | IntermediateTensors:
        """Run forward pass for Qwen2-VL.

        Args:
            input_ids: Flattened (concatenated) input_ids corresponding to a
                batch.
            positions: Flattened (concatenated) position ids corresponding to a
                batch.
                **NOTE**: If mrope is enabled (default setting for Qwen2-VL
                opensource models), the shape will be `(3, seq_len)`,
                otherwise it will be `(seq_len,)`.
            intermediate_tensors: Intermediate tensors from prior forward pass.
            inputs_embeds: Optional tensor of input embeddings.
        """

        if intermediate_tensors is not None:
            inputs_embeds = None

        hidden_states = self.language_model.model(
            input_ids=input_ids,
            positions=positions,
            intermediate_tensors=intermediate_tensors,
            inputs_embeds=inputs_embeds,
        )
        return hidden_states

    def compute_logits(
        self,
        hidden_states: torch.Tensor,
    ) -> torch.Tensor | None:
        return self.language_model.compute_logits(hidden_states)

    def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]:
        skip_prefixes = []
        if self.visual is None:
            skip_prefixes.extend(["visual."])
        loader = AutoWeightsLoader(self, skip_prefixes=skip_prefixes)
        return loader.load_weights(weights, mapper=self.hf_to_vllm_mapper)

    def get_mm_mapping(self) -> MultiModelKeys:
        """
        Get the module prefix in multimodal models
        """
        return MultiModelKeys.from_string_field(
            language_model="language_model",
            connector="visual.merger.",
            tower_model="visual.",
        )
1642
1643
1644
1645
1646
1647
1648
1649
1650


class Tarsier2MultiModalProcessor(Qwen2VLMultiModalProcessor):
    """Multimodal processor for Tarsier2.

    Inherits the Qwen2-VL processing pipeline unchanged. It is the processor
    registered for ``Tarsier2ForConditionalGeneration`` together with
    ``Tarsier2ProcessingInfo``.
    """


class Tarsier2ImageProcessor(Qwen2VLImageProcessor):
    """Qwen2-VL image processor that also accepts Tarsier2's ``size`` format.

    A Tarsier2-style ``size`` is ``{"min_pixels": ..., "max_pixels": ...}``.
    When both keys are present, the dict is rewritten to
    ``{"shortest_edge": ..., "longest_edge": ...}`` before it is handed to
    ``Qwen2VLImageProcessor``. Any other ``size`` value, including ``None``,
    is forwarded untouched.
    """

    def __init__(
        self,
        size: dict[str, int] | None = None,
        **kwargs,
    ) -> None:
        uses_tarsier2_format = size is not None and all(
            key in size for key in ("min_pixels", "max_pixels")
        )
        if uses_tarsier2_format:
            # Only the two pixel bounds are carried over; any other keys in a
            # Tarsier2-style dict are dropped.
            size = {
                "shortest_edge": size["min_pixels"],
                "longest_edge": size["max_pixels"],
            }
        super().__init__(size=size, **kwargs)


class Tarsier2Processor(Qwen2VLProcessor):
    """Qwen2-VL processor assembled from a Tarsier2 image-processor config.

    The image and the video processor are both built from the same
    ``vision_config`` dict. The image side goes through
    ``Tarsier2ImageProcessor``, so Tarsier2-style ``size`` entries are
    remapped there. No chat template is attached (``chat_template=None``).
    """

    def __init__(
        self,
        vision_config: dict,
        tokenizer: AnyTokenizer,
        **kwargs,
    ):
        image_processor = Tarsier2ImageProcessor(**vision_config)
        # The attribute is populated before the base initializer runs.
        self.image_processor = image_processor
        # NOTE(review): the video processor receives ``vision_config`` as-is,
        # without the Tarsier2 ``size`` remapping applied to images — confirm
        # Qwen2VLVideoProcessor accepts that format.
        video_processor = Qwen2VLVideoProcessor(**vision_config)
        super().__init__(
            image_processor=image_processor,
            tokenizer=tokenizer,
            video_processor=video_processor,
            chat_template=None,
            **kwargs,
        )


class Tarsier2ProcessingInfo(Qwen2VLProcessingInfo):
    """Processing info for Tarsier2 checkpoints.

    The checkpoint config is re-parsed as a ``Qwen2VLConfig``, and the Hugging
    Face processors are replaced by the Tarsier2 variants, which understand
    Tarsier2's image-processor ``size`` format.
    """

    def get_hf_config(self) -> Qwen2VLConfig:
        """Return the model config re-parsed as a ``Qwen2VLConfig``.

        The config is loaded from ``self.ctx.model_config.model`` with
        ``AutoConfig`` and converted through its dict form. Loading and
        re-parsing is comparatively expensive and depends only on the model
        path, so the result is cached on this instance. Callers receive a
        shared object and should not mutate it.
        """
        cached_config = getattr(self, "_tarsier2_hf_config", None)
        if cached_config is None:
            model_path = self.ctx.model_config.model
            original_config = AutoConfig.from_pretrained(model_path)
            cached_config = Qwen2VLConfig.from_dict(original_config.to_dict())
            self._tarsier2_hf_config = cached_config
        return cached_config

    def get_hf_processor(self, **kwargs: object) -> Tarsier2Processor:
        """Build a ``Tarsier2Processor`` from the model's image-processor config.

        Extra keyword arguments are forwarded to the processor constructor.
        """
        image_processor_config = self.ctx.get_hf_image_processor_config()
        return Tarsier2Processor(
            vision_config=image_processor_config,
            tokenizer=self.get_tokenizer(),
            **kwargs,
        )

    def get_image_processor(self) -> Tarsier2ImageProcessor:
        """Build a standalone ``Tarsier2ImageProcessor`` from the model config."""
        image_processor_config = self.ctx.get_hf_image_processor_config()
        return Tarsier2ImageProcessor(**image_processor_config)


@MULTIMODAL_REGISTRY.register_processor(
    Tarsier2MultiModalProcessor,
    info=Tarsier2ProcessingInfo,
    dummy_inputs=Qwen2VLDummyInputsBuilder,
)
class Tarsier2ForConditionalGeneration(Qwen2VLForConditionalGeneration):
    """Tarsier2 model served through the Qwen2-VL implementation.

    Checkpoint weights under ``vision_tower.`` are renamed to ``visual.`` on
    load. At construction time, the HF config is unwrapped to its
    ``text_config`` before the Qwen2-VL constructor runs.
    """

    # Rename checkpoint prefixes to the module names used by this model.
    hf_to_vllm_mapper = WeightsMapper(
        orig_to_new_prefix={
            "vision_tower.": "visual.",
        }
    )

    def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
        # Tarsier2 uses llava as model_type, which will create a Qwen2VLConfig
        # as text_config; we reconstruct the Qwen2VLConfig from the LlavaConfig.
        # NOTE: vllm_config.model_config.hf_config is replaced in place, so the
        # caller's config object is modified.
        outer_config = vllm_config.model_config.hf_config
        text_config = outer_config.text_config
        text_config.architectures = outer_config.architectures
        vllm_config.model_config.hf_config = text_config
        super().__init__(vllm_config=vllm_config, prefix=prefix)

    def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]:
        """Load Tarsier2 checkpoint weights.

        Names are translated with ``self.hf_to_vllm_mapper``. When
        ``self.visual`` is ``None``, all weights under ``visual.`` are skipped.
        """
        skip_prefixes = ["visual."] if self.visual is None else []
        loader = AutoWeightsLoader(self, skip_prefixes=skip_prefixes)
        return loader.load_weights(weights, mapper=self.hf_to_vllm_mapper)