qwen2_vl.py 57.5 KB
Newer Older
1
# SPDX-License-Identifier: Apache-2.0
2
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
3

4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
# Adapted from
# https://github.com/huggingface/transformers/blob/19e6e80e10118f855137b90740936c0b11ac397f/src/transformers/models/qwen2_vl/modeling_qwen2_vl.py
# Copyright 2024 The Qwen team.
# Copyright 2023 The vLLM team.
# Copyright 2022 EleutherAI and the HuggingFace Inc. team. All rights reserved.
#
# This code is based on EleutherAI's GPT-NeoX library and the GPT-NeoX
# and OPT implementations in this library. It has been modified from its
# original forms to accommodate minor architectural differences compared
# to GPT-NeoX and OPT used by the Meta AI team that trained the model.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Inference-only Qwen2-VL model compatible with HuggingFace weights."""
27

28
import math
29
from collections.abc import Callable, Iterable, Mapping, Sequence
30
from functools import partial
31
from typing import Annotated, Any, Literal, TypeAlias
32
33
34
35
36

import torch
import torch.nn as nn
import torch.nn.functional as F
from einops import rearrange, repeat
37
from transformers import BatchFeature, PretrainedConfig
38
from transformers.models.qwen2_vl import Qwen2VLImageProcessor, Qwen2VLProcessor
39
from transformers.models.qwen2_vl.configuration_qwen2_vl import (
40
41
42
    Qwen2VLConfig,
    Qwen2VLVisionConfig,
)
43
from transformers.models.qwen2_vl.image_processing_qwen2_vl import smart_resize
44
from transformers.models.qwen2_vl.video_processing_qwen2_vl import Qwen2VLVideoProcessor
45

46
from vllm.attention.backends.registry import _Backend
47
48
49
50
from vllm.attention.layer import (
    check_upstream_fa_availability,
    maybe_get_vit_flash_attn_backend,
)
51
from vllm.config import VllmConfig
52
from vllm.config.multimodal import BaseDummyOptions
53
from vllm.distributed import parallel_state, tensor_model_parallel_all_gather
54
55
56
from vllm.distributed import utils as dist_utils
from vllm.logger import init_logger
from vllm.model_executor.layers.activation import QuickGELU
57
58
59
60
61
from vllm.model_executor.layers.linear import (
    ColumnParallelLinear,
    ReplicatedLinear,
    RowParallelLinear,
)
62
from vllm.model_executor.layers.quantization import QuantizationConfig
63
from vllm.model_executor.layers.rotary_embedding.common import (
64
65
    dispatch_rotary_emb_function,
)
66
from vllm.model_executor.model_loader.weight_utils import default_weight_loader
67
from vllm.model_executor.models.module_mapping import MultiModelKeys
68
from vllm.multimodal import MULTIMODAL_REGISTRY
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
from vllm.multimodal.inputs import (
    ImageItem,
    ModalityData,
    MultiModalDataDict,
    MultiModalFieldConfig,
    MultiModalKwargsItems,
    VideoItem,
)
from vllm.multimodal.parse import (
    DictEmbeddingItems,
    ImageSize,
    ModalityDataItems,
    MultiModalDataItems,
    MultiModalDataParser,
)
from vllm.multimodal.processing import (
    BaseMultiModalProcessor,
    BaseProcessingInfo,
    PromptReplacement,
    PromptUpdate,
)
90
from vllm.multimodal.profiling import BaseDummyInputsBuilder
91
from vllm.sequence import IntermediateTensors
92
from vllm.transformers_utils.tokenizer import AnyTokenizer
93
from vllm.utils.tensor_schema import TensorSchema, TensorShape
94

95
96
97
98
99
100
101
102
103
104
105
106
107
from .interfaces import (
    MultiModalEmbeddings,
    SupportsLoRA,
    SupportsMRoPE,
    SupportsMultiModal,
    SupportsPP,
)
from .utils import (
    AutoWeightsLoader,
    WeightsMapper,
    init_vllm_registered_model,
    maybe_prefix,
)
108
109
110
111
112
from .vision import (
    conv3d_to_linear_weight,
    get_vit_attn_backend,
    run_dp_sharded_mrope_vision_model,
)
113

114
115
logger = init_logger(__name__)

# Maximum number of frames per video assumed for the profile run.
_MAX_FRAMES_PER_VIDEO = 14
118

119
120
121
# === Vision Inputs === #


122
class Qwen2VLImagePixelInputs(TensorSchema):
    """
    Flattened image patches to be run through the vision encoder.

    Dimensions:
        - np: The total number of patches over each image over each prompt in
              the batch
        - ni: Number of images
        - cps: Number of channels * temporal_patch_size * patch_size *
               patch_size (the vision patch embedding consumes this many
               features per patch for images as well as videos)

    Historical context:
        - pixel_values shape: (num_patches, num_channels *
          temporal_patch_size * patch_size * patch_size)
        - image_grid_thw shape: (num_images, 3) in (grid_t, grid_h, grid_w)
          format
    """

    type: Literal["pixel_values"]

    pixel_values: Annotated[
        torch.Tensor,
        TensorShape("np", "cps"),
    ]

    image_grid_thw: Annotated[
        torch.Tensor,
        TensorShape("ni", 3),
    ]


class Qwen2VLImageEmbeddingInputs(TensorSchema):
    """
    Precomputed image embeddings, supplied instead of raw pixel patches.

    Dimensions:
        - nf: Number of image features
        - hs: Hidden size
        - ni: Number of images

    Historical context:
        - image_embeds shape: (num_image_features, hidden_size)
        - num_image_features depends on how many images there are and on
          their resolution.
        - hidden_size has to equal the hidden size of the language model
          backbone.
        - image_grid_thw shape: (num_images, 3) in (grid_t, grid_h, grid_w)
          format
    """

    type: Literal["image_embeds"]

    image_embeds: Annotated[torch.Tensor, TensorShape("nf", "hs")]

    image_grid_thw: Annotated[torch.Tensor, TensorShape("ni", 3)]
177
178


179
# Image inputs arrive either as raw pixel patches or as precomputed embeddings.
Qwen2VLImageInputs: TypeAlias = (
    Qwen2VLImagePixelInputs | Qwen2VLImageEmbeddingInputs
)
180
181


182
183
184
185
186
class Qwen2VLVideoPixelInputs(TensorSchema):
    """
    Flattened video patches to be run through the vision encoder.

    Dimensions:
        - np: The total number of patches over each video over each prompt in
              the batch
        - ctps: Number of channels * temporal_patch_size * patch_size *
          patch_size
        - nv: Number of videos

    Historical context:
        - pixel_values_videos shape: (num_patches, num_channels *
          temporal_patch_size * patch_size * patch_size)
        - video_grid_thw shape: (num_videos, 3) in (grid_t, grid_h, grid_w)
          format
    """

    type: Literal["pixel_values_videos"]

    pixel_values_videos: Annotated[torch.Tensor, TensorShape("np", "ctps")]

    video_grid_thw: Annotated[torch.Tensor, TensorShape("nv", 3)]
209
210


211
212
213
214
215
216
class Qwen2VLVideoEmbeddingInputs(TensorSchema):
    """
    Precomputed video embeddings, supplied instead of raw pixel patches.

    Dimensions:
        - nf: Number of video features
        - hs: Hidden size
        - nv: Number of videos

    Historical context:
        - video_embeds shape: (num_video_features, hidden_size)
        - num_video_features depends on how many videos there are and on
          their resolution.
        - hidden_size has to equal the hidden size of the language model
          backbone.
        - video_grid_thw shape: (num_videos, 3) in (grid_t, grid_h, grid_w)
          format
    """

    type: Literal["video_embeds"]

    video_embeds: Annotated[torch.Tensor, TensorShape("nf", "hs")]

    video_grid_thw: Annotated[torch.Tensor, TensorShape("nv", 3)]
238
239


240
# Video inputs arrive either as raw pixel patches or as precomputed embeddings.
Qwen2VLVideoInputs: TypeAlias = (
    Qwen2VLVideoPixelInputs | Qwen2VLVideoEmbeddingInputs
)
241

242
243
244
245
246
247
248
# === Vision Encoder === #


class Qwen2VisionMLP(nn.Module):
    """Two-layer feed-forward network used inside each vision block.

    ``fc1`` is a ColumnParallelLinear and ``fc2`` a RowParallelLinear, with the
    activation applied in between. ``use_data_parallel`` is forwarded to both
    projections as ``disable_tp``.
    """

    def __init__(
        self,
        in_features: int,
        hidden_features: int,
        act_layer: type[nn.Module] = QuickGELU,
        quant_config: QuantizationConfig | None = None,
        prefix: str = "",
        use_data_parallel: bool = False,
    ):
        super().__init__()
        common_kwargs = dict(quant_config=quant_config, disable_tp=use_data_parallel)
        self.fc1 = ColumnParallelLinear(
            in_features, hidden_features, prefix=f"{prefix}.fc1", **common_kwargs
        )
        self.act = act_layer()
        self.fc2 = RowParallelLinear(
            hidden_features, in_features, prefix=f"{prefix}.fc2", **common_kwargs
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # The parallel linear layers return (output, bias); only output is used.
        hidden, _ = self.fc1(x)
        projected, _ = self.fc2(self.act(hidden))
        return projected


def rotate_half(x: torch.Tensor, interleaved: bool = False) -> torch.Tensor:
    """Rotate feature pairs of the last dimension for rotary embeddings.

    Non-interleaved: split the last dim into halves ``(a, b)`` and return
    ``(-b, a)``. Interleaved: treat consecutive elements ``(x0, x1)`` as a pair
    and replace each pair with ``(-x1, x0)`` in place.
    """
    if interleaved:
        evens = x[..., ::2]
        odds = x[..., 1::2]
        # Stack on a trailing pair axis, then fold it back into the last dim.
        return torch.stack((-odds, evens), dim=-1).flatten(start_dim=-2)
    first, second = x.chunk(2, dim=-1)
    return torch.cat((-second, first), dim=-1)
288
289


290
291
292
def apply_rotary_emb_torch(
    x: torch.Tensor, cos: torch.Tensor, sin: torch.Tensor, interleaved: bool = False
) -> torch.Tensor:
    """Apply rotary position embedding to the leading features of ``x``.

    x: (batch_size, seqlen, nheads, headdim)
    cos, sin: (seqlen, rotary_dim / 2) or (batch_size, seqlen, rotary_dim / 2)

    Only the first ``rotary_dim`` features are rotated; the remainder of the
    last dimension is passed through unchanged.
    """
    rotary_dim = 2 * cos.shape[-1]
    assert rotary_dim <= x.shape[-1]
    # Insert a singleton head axis and duplicate each frequency to cover both
    # members of a pair: blockwise "(2 d)" or element-wise "(d 2)".
    pattern = "... d -> ... 1 (d 2)" if interleaved else "... d -> ... 1 (2 d)"
    cos = repeat(cos, pattern)
    sin = repeat(sin, pattern)
    x_rot = x[..., :rotary_dim]
    x_pass = x[..., rotary_dim:]
    rotated = x_rot * cos + rotate_half(x_rot, interleaved) * sin
    return torch.cat([rotated, x_pass], dim=-1)


314
315
def apply_rotary_pos_emb_vision(t: torch.Tensor, freqs: torch.Tensor) -> torch.Tensor:
    """Rotate ``t`` by the angles in ``freqs``.

    The rotation is computed in float32 and the result is cast back to the
    dtype of ``t``. The kernel comes from ``dispatch_rotary_emb_function``,
    with ``apply_rotary_emb_torch`` as the default.
    """
    rotary_fn = dispatch_rotary_emb_function(default=apply_rotary_emb_torch)
    rotated = rotary_fn(t.float(), freqs.cos(), freqs.sin())
    return rotated.type_as(t)


class Qwen2VisionAttention(nn.Module):
    """Multi-head self-attention for the Qwen2-VL vision encoder.

    Tokens of all images/videos are packed along the sequence axis, and
    ``cu_seqlens`` (cumulative segment boundaries) delimits the independent
    segments that attention is computed over.
    """

    def __init__(
        self,
        embed_dim: int,
        num_heads: int,
        projection_size: int,
        quant_config: QuantizationConfig | None = None,
        prefix: str = "",
        use_data_parallel: bool = False,
        attn_backend_override: _Backend | None = None,
    ) -> None:
        super().__init__()
        # Per attention head and per partition values. Under data parallelism
        # the layer is not sharded, so it acts as partition 0 of 1.
        self.tp_size = (
            1
            if use_data_parallel
            else parallel_state.get_tensor_model_parallel_world_size()
        )
        self.tp_rank = (
            0 if use_data_parallel else parallel_state.get_tensor_model_parallel_rank()
        )
        self.hidden_size_per_attention_head = dist_utils.divide(
            projection_size, num_heads
        )
        self.num_attention_heads_per_partition = dist_utils.divide(
            num_heads, self.tp_size
        )

        self.qkv = ColumnParallelLinear(
            input_size=embed_dim,
            output_size=3 * projection_size,
            quant_config=quant_config,
            prefix=f"{prefix}.qkv",
            disable_tp=use_data_parallel,
        )
        self.proj = RowParallelLinear(
            input_size=projection_size,
            output_size=embed_dim,
            quant_config=quant_config,
            prefix=f"{prefix}.proj",
            disable_tp=use_data_parallel,
        )

        # Detect attention implementation.
        self.attn_backend = get_vit_attn_backend(
            head_size=self.hidden_size_per_attention_head,
            dtype=torch.get_default_dtype(),
            attn_backend_override=attn_backend_override,
        )
        self.use_upstream_fa = False

        # May switch the backend and returns the varlen flash kernel to use
        # (if any) for the resulting backend.
        self.attn_backend, self.flash_attn_varlen_func = (
            maybe_get_vit_flash_attn_backend(
                self.attn_backend,
                self.use_upstream_fa,
                attn_backend_override=attn_backend_override,
            )
        )

        if self.attn_backend not in {
            _Backend.FLASH_ATTN,
            _Backend.TORCH_SDPA,
            _Backend.XFORMERS,
            _Backend.ROCM_AITER_FA,
        }:
            raise RuntimeError(
                f"Qwen2-VL does not support {self.attn_backend} backend now."
            )

        self.is_flash_attn_backend = self.attn_backend in {
            _Backend.FLASH_ATTN,
            _Backend.ROCM_AITER_FA,
        }

    def split_qkv(self, qkv: torch.Tensor) -> tuple[torch.Tensor, ...]:
        """Split fused qkv into per-head q, k, v for this TP partition.

        Under TP, the local shards are all-gathered first so the q/k/v
        boundaries are taken on the full tensor; each is then re-split and
        this rank keeps its own slice.
        """
        # [s, b, 3 * head * head_dim]
        seq_len, bs, _ = qkv.shape
        if self.tp_size > 1:
            qkv = tensor_model_parallel_all_gather(qkv)

        # [s, b, 3 * head * head_dim] -> 3 * [s, b, head * head_dim]
        q, k, v = qkv.chunk(3, dim=2)

        # 3 * [s, b, head * head_dim]
        if self.tp_size > 1:
            splitter = partial(
                dist_utils.split_tensor_along_last_dim, num_partitions=self.tp_size
            )
            q = splitter(q)[self.tp_rank]
            k = splitter(k)[self.tp_rank]
            v = splitter(v)[self.tp_rank]

        # 3 * [s, b, head * head_dim] -> 3 * [s, b, head, head_dim]
        new_shape = (
            seq_len,
            bs,
            self.num_attention_heads_per_partition,
            self.hidden_size_per_attention_head,
        )
        q, k, v = (x.view(*new_shape) for x in (q, k, v))
        return q, k, v

    def forward(
        self,
        x: torch.Tensor,
        cu_seqlens: torch.Tensor,
        rotary_pos_emb: torch.Tensor,
        max_seqlen: int | None = None,  # Only used for Flash Attention
        seqlens: list[int] | None = None,  # Only used for xFormers
    ) -> torch.Tensor:
        """Attend within each ``cu_seqlens`` segment and project the result.

        ``max_seqlen`` / ``seqlens`` are optional precomputed values. When
        they are not supplied, they are derived from ``cu_seqlens``, for
        example when the caller selected a different backend than this layer.
        """
        # [s, b, c] --> [s, b, 3 * head * head_dim]
        x, _ = self.qkv(x)

        # [s, b, 3 * head * head_dim] -> 3 * [s, b, head, head_dim]
        q, k, v = self.split_qkv(x)
        batch_size = q.shape[1]

        q, k, v = (rearrange(x, "s b ... -> b s ...") for x in (q, k, v))
        if rotary_pos_emb is not None:
            # [2 * b, s, heads, head_dim]
            qk_concat = torch.cat([q, k], dim=0)
            qk_rotated = apply_rotary_pos_emb_vision(qk_concat, rotary_pos_emb)
            q, k = torch.chunk(qk_rotated, 2, dim=0)

        if self.is_flash_attn_backend:
            if max_seqlen is None:
                # The varlen kernel needs the longest segment length.
                max_seqlen = (cu_seqlens[1:] - cu_seqlens[:-1]).max().item()
            q, k, v = (rearrange(x, "b s ... -> (b s) ...") for x in [q, k, v])

            output = self.flash_attn_varlen_func(
                q,
                k,
                v,
                cu_seqlens_q=cu_seqlens,
                cu_seqlens_k=cu_seqlens,
                max_seqlen_q=max_seqlen,
                max_seqlen_k=max_seqlen,
                dropout_p=0.0,
                causal=False,
            )

            context_layer = rearrange(
                output, "(b s) h d -> s b (h d)", b=batch_size
            ).contiguous()
        elif self.attn_backend == _Backend.TORCH_SDPA:
            # Execute attention entry by entry for speed & less VRAM.
            from vllm.platforms import current_platform

            if current_platform.is_rocm():
                q = q.contiguous()
                k = k.contiguous()
                v = v.contiguous()
            # Fetch all segment boundaries to the host once, instead of reading
            # two device scalars per segment inside the loop.
            boundaries = cu_seqlens.tolist()
            outputs = []
            for start_idx, end_idx in zip(boundaries[:-1], boundaries[1:]):
                q_i = q[:, start_idx:end_idx]
                k_i = k[:, start_idx:end_idx]
                v_i = v[:, start_idx:end_idx]
                q_i, k_i, v_i = (
                    rearrange(x, "b s h d -> b h s d") for x in [q_i, k_i, v_i]
                )
                output_i = F.scaled_dot_product_attention(q_i, k_i, v_i, dropout_p=0.0)
                output_i = rearrange(output_i, "b h s d -> b s h d ")
                outputs.append(output_i)
            context_layer = torch.cat(outputs, dim=1)
            context_layer = rearrange(
                context_layer, "b s h d -> s b (h d)"
            ).contiguous()
        elif self.attn_backend == _Backend.XFORMERS:
            from xformers import ops as xops
            from xformers.ops.fmha.attn_bias import BlockDiagonalMask

            if seqlens is None:
                # The block-diagonal mask needs per-segment lengths.
                seqlens = (cu_seqlens[1:] - cu_seqlens[:-1]).tolist()
            attn_bias = BlockDiagonalMask.from_seqlens(
                q_seqlen=seqlens, kv_seqlen=None, device=q.device
            )

            context_layer = xops.memory_efficient_attention_forward(
                q, k, v, attn_bias=attn_bias, p=0, scale=None
            )
            context_layer = rearrange(
                context_layer, "b s h d -> s b (h d)"
            ).contiguous()

        output, _ = self.proj(context_layer)
        return output


class Qwen2VisionBlock(nn.Module):
    """Pre-norm transformer block: ``x + attn(norm1(x))``, then ``+ mlp(norm2(x))``."""

    def __init__(
        self,
        dim: int,
        num_heads: int,
        mlp_ratio: float,
        act_layer: type[nn.Module] = QuickGELU,
        norm_layer: Callable[[int], nn.Module] | None = None,
        quant_config: QuantizationConfig | None = None,
        prefix: str = "",
        use_data_parallel: bool = False,
        attn_backend_override: _Backend | None = None,
    ) -> None:
        super().__init__()
        make_norm = (
            partial(nn.LayerNorm, eps=1e-6) if norm_layer is None else norm_layer
        )
        self.norm1 = make_norm(dim)
        self.norm2 = make_norm(dim)

        self.attn = Qwen2VisionAttention(
            embed_dim=dim,
            num_heads=num_heads,
            projection_size=dim,
            quant_config=quant_config,
            prefix=f"{prefix}.attn",
            use_data_parallel=use_data_parallel,
            attn_backend_override=attn_backend_override,
        )
        self.mlp = Qwen2VisionMLP(
            dim,
            int(dim * mlp_ratio),
            act_layer=act_layer,
            quant_config=quant_config,
            prefix=f"{prefix}.mlp",
            use_data_parallel=use_data_parallel,
        )

    def forward(
        self,
        x: torch.Tensor,
        cu_seqlens: torch.Tensor,
        rotary_pos_emb: torch.Tensor,
        max_seqlen: int | None = None,  # Only used for Flash Attention
        seqlens: list[int] | None = None,  # Only used for xFormers
    ) -> torch.Tensor:
        attn_out = self.attn(
            self.norm1(x),
            cu_seqlens=cu_seqlens,
            rotary_pos_emb=rotary_pos_emb,
            max_seqlen=max_seqlen,
            seqlens=seqlens,
        )
        residual = x + attn_out
        return residual + self.mlp(self.norm2(residual))


class Qwen2VisionPatchEmbed(nn.Module):
    """Linear patch embedding for flattened pixel patches.

    Each input row holds ``in_channels * temporal_patch_size * patch_size *
    patch_size`` values and is projected, without bias, to ``embed_dim``.
    """

    def __init__(
        self,
        patch_size: int = 14,
        temporal_patch_size: int = 2,
        in_channels: int = 3,
        embed_dim: int = 1152,
    ) -> None:
        super().__init__()
        self.patch_size = patch_size
        self.temporal_patch_size = temporal_patch_size
        self.embed_dim = embed_dim

        patch_dim = in_channels * temporal_patch_size * patch_size * patch_size
        self.proj = ReplicatedLinear(
            patch_dim, embed_dim, bias=False, return_bias=False
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.proj(x)


class Qwen2VisionPatchMerger(nn.Module):
    """Merge ``spatial_merge_size ** 2`` neighbouring tokens and project them.

    Tokens are normalized, consecutive groups of ``spatial_merge_size ** 2``
    tokens are concatenated into rows of width ``hidden_size``, and the rows
    go through a Linear-GELU-Linear MLP that outputs ``d_model`` features.
    """

    def __init__(
        self,
        d_model: int,
        context_dim: int,
        norm_layer: Callable[[int], nn.Module] | None = None,
        spatial_merge_size: int = 2,
        quant_config: QuantizationConfig | None = None,
        prefix: str = "",
        use_data_parallel: bool = False,
    ) -> None:
        super().__init__()
        self.hidden_size = context_dim * (spatial_merge_size**2)
        make_norm = (
            partial(nn.LayerNorm, eps=1e-6) if norm_layer is None else norm_layer
        )
        self.ln_q = make_norm(context_dim)

        # Kept as a ModuleList so checkpoint keys stay "mlp.0" / "mlp.2".
        fc1 = ColumnParallelLinear(
            self.hidden_size,
            self.hidden_size,
            bias=True,
            quant_config=quant_config,
            prefix=f"{prefix}.mlp.0",
            disable_tp=use_data_parallel,
        )
        act = nn.GELU()
        fc2 = RowParallelLinear(
            self.hidden_size,
            d_model,
            bias=True,
            quant_config=quant_config,
            prefix=f"{prefix}.mlp.2",
            disable_tp=use_data_parallel,
        )
        self.mlp = nn.ModuleList([fc1, act, fc2])

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        fc1, act, fc2 = self.mlp
        merged = self.ln_q(x).view(-1, self.hidden_size)
        hidden, _ = fc1(merged)
        out, _ = fc2(act(hidden))
        return out


class Qwen2VisionRotaryEmbedding(nn.Module):
    """Lazily cached rotary angle table of shape ``(positions, dim // 2)``.

    ``forward(seqlen)`` returns ``outer(arange(seqlen), inv_freq)``. The table
    is rebuilt with twice the requested length whenever a longer sequence is
    requested, so nearby larger requests reuse it.
    """

    def __init__(self, dim: int, theta: float = 10000.0) -> None:
        super().__init__()
        self.dim = dim
        self.theta = theta
        self.register_buffer("inv_freq", self._inverse_frequencies(), persistent=False)
        self._seq_len_cached = 0
        self._freqs_cached = None

    def _inverse_frequencies(self, device=None) -> torch.Tensor:
        # theta ** (-2i / dim) for i in [0, dim / 2), always in float32.
        exponents = (
            torch.arange(0, self.dim, 2, dtype=torch.float, device=device) / self.dim
        )
        return 1.0 / (self.theta**exponents)

    def update_freqs_cache(self, seqlen: int) -> None:
        if seqlen <= self._seq_len_cached:
            return
        capacity = seqlen * 2
        self._seq_len_cached = capacity
        # Rebuilt in float32, presumably in case the buffer was cast to a
        # lower-precision model dtype after construction.
        self.inv_freq = self._inverse_frequencies(device=self.inv_freq.device)
        positions = torch.arange(
            capacity, device=self.inv_freq.device, dtype=self.inv_freq.dtype
        )
        self._freqs_cached = torch.outer(positions, self.inv_freq)

    def forward(self, seqlen: int) -> torch.Tensor:
        self.update_freqs_cache(seqlen)
        return self._freqs_cached[:seqlen]


class Qwen2VisionTransformer(nn.Module):
    """Qwen2-VL vision encoder: patch embed -> transformer blocks -> merger."""

    def __init__(
        self,
        vision_config: Qwen2VLVisionConfig,
        norm_eps: float = 1e-6,
        quant_config: QuantizationConfig | None = None,
        prefix: str = "",
        use_data_parallel: bool = False,
        attn_backend_override: _Backend | None = None,
    ) -> None:
        super().__init__()

        cfg = vision_config
        embed_dim = cfg.embed_dim
        num_heads = cfg.num_heads
        head_dim = embed_dim // num_heads

        self.use_data_parallel = use_data_parallel
        self.out_hidden_size = cfg.hidden_size
        self.spatial_merge_size = cfg.spatial_merge_size
        self.num_heads = num_heads
        self.embed_dim = embed_dim

        self.patch_embed = Qwen2VisionPatchEmbed(
            patch_size=cfg.patch_size,
            temporal_patch_size=cfg.temporal_patch_size,
            in_channels=cfg.in_channels,
            embed_dim=embed_dim,
        )

        norm_layer = partial(nn.LayerNorm, eps=norm_eps)
        # Half of head_dim: height and width each get their own rotary angles.
        self.rotary_pos_emb = Qwen2VisionRotaryEmbedding(head_dim // 2)

        self.blocks = nn.ModuleList(
            Qwen2VisionBlock(
                dim=embed_dim,
                num_heads=num_heads,
                mlp_ratio=cfg.mlp_ratio,
                norm_layer=norm_layer,
                quant_config=quant_config,
                prefix=f"{prefix}.blocks.{layer_idx}",
                use_data_parallel=use_data_parallel,
                attn_backend_override=attn_backend_override,
            )
            for layer_idx in range(cfg.depth)
        )
        self.merger = Qwen2VisionPatchMerger(
            d_model=cfg.hidden_size,
            context_dim=embed_dim,
            norm_layer=norm_layer,
            quant_config=quant_config,
            prefix=f"{prefix}.merger",
            use_data_parallel=use_data_parallel,
        )

        default_dtype = torch.get_default_dtype()
        backend = get_vit_attn_backend(
            head_size=head_dim,
            dtype=default_dtype,
            attn_backend_override=attn_backend_override,
        )
        # Prefer upstream flash-attn whenever it is available.
        if backend != _Backend.FLASH_ATTN and check_upstream_fa_availability(
            default_dtype
        ):
            backend = _Backend.FLASH_ATTN
        self.attn_backend = backend

    @property
    def dtype(self) -> torch.dtype:
        return self.patch_embed.proj.weight.dtype

    @property
    def device(self) -> torch.device:
        return self.patch_embed.proj.weight.device

    def rot_pos_emb(self, grid_thw: list[list[int]]) -> torch.Tensor:
        """Build per-token rotary angles from each item's (t, h, w) grid."""
        merge = self.spatial_merge_size

        def _window_order(ids: torch.Tensor, h: int, w: int) -> torch.Tensor:
            # Reorder an (h, w) id grid so each merge x merge window is contiguous.
            return (
                ids.reshape(h // merge, merge, w // merge, merge)
                .permute(0, 2, 1, 3)
                .flatten()
            )

        coords_per_item = []
        largest_side = 0
        for t, h, w in grid_thw:
            rows = torch.arange(h).unsqueeze(1).expand(-1, w)
            cols = torch.arange(w).unsqueeze(0).expand(h, -1)
            coords = torch.stack(
                [_window_order(rows, h, w), _window_order(cols, h, w)], dim=-1
            )
            coords_per_item.append(coords.repeat(t, 1))
            largest_side = max(largest_side, h, w)

        pos_ids = torch.cat(coords_per_item, dim=0)
        angle_table = self.rotary_pos_emb(largest_side)
        return angle_table[pos_ids].flatten(1)

    def compute_attn_mask_seqlen(
        self, cu_seqlens: torch.Tensor
    ) -> tuple[int | None, list[int] | None]:
        """Precompute (max_seqlen, seqlens) for the selected backend once."""
        if self.attn_backend in {_Backend.FLASH_ATTN, _Backend.ROCM_AITER_FA}:
            return (cu_seqlens[1:] - cu_seqlens[:-1]).max().item(), None
        if self.attn_backend == _Backend.XFORMERS:
            return None, (cu_seqlens[1:] - cu_seqlens[:-1]).tolist()
        return None, None

    def forward(
        self,
        x: torch.Tensor,
        grid_thw: torch.Tensor | list[list[int]],
    ) -> torch.Tensor:
        # patchify
        hidden = self.patch_embed(x.to(device=self.device, dtype=self.dtype))

        if isinstance(grid_thw, list):
            grid_thw_list = grid_thw
            grid_thw = torch.tensor(grid_thw, dtype=torch.int32)
        else:
            grid_thw_list = grid_thw.tolist()

        rotary_pos_emb = self.rot_pos_emb(grid_thw_list)

        # One segment per frame: h * w tokens, repeated t times per item.
        tokens_per_frame = grid_thw[:, 1] * grid_thw[:, 2]
        cu_seqlens = torch.repeat_interleave(tokens_per_frame, grid_thw[:, 0]).cumsum(
            dim=0, dtype=torch.int32
        )
        cu_seqlens = torch.cat([cu_seqlens.new_zeros(1), cu_seqlens]).to(
            self.device, non_blocking=True
        )

        # transformers: add the batch axis expected by the blocks.
        hidden = hidden.unsqueeze(1)

        # pre-compute seqlens for attn mask to reduce cuMemcpy operations
        max_seqlen, seqlens = self.compute_attn_mask_seqlen(cu_seqlens)
        for blk in self.blocks:
            hidden = blk(
                hidden,
                cu_seqlens=cu_seqlens,
                rotary_pos_emb=rotary_pos_emb,
                max_seqlen=max_seqlen,
                seqlens=seqlens,
            )

        # adapter
        return self.merger(hidden)

    def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]:
        """Load checkpoint weights and return the set of parameter names loaded."""
        stacked_params_mapping = [
            # (param_name, shard_name, shard_id)
            ("qkv_proj", "q_proj", "q"),
            ("qkv_proj", "k_proj", "k"),
            ("qkv_proj", "v_proj", "v"),
        ]
        params_dict = dict(self.named_parameters(remove_duplicate=False))
        loaded_params: set[str] = set()

        for name, loaded_weight in weights:
            # Checkpoints store the patch embedding as a conv3d kernel.
            if name.endswith("patch_embed.proj.weight"):
                loaded_weight = conv3d_to_linear_weight(loaded_weight)

            for param_name, weight_name, shard_id in stacked_params_mapping:
                if weight_name in name:
                    name = name.replace(weight_name, param_name)
                    param = params_dict[name]
                    param.weight_loader(param, loaded_weight, shard_id)
                    break
            else:
                param = params_dict[name]
                loader = getattr(param, "weight_loader", default_weight_loader)
                loader(param, loaded_weight)
            loaded_params.add(name)
        return loaded_params

872

873
def _create_qwen2vl_field_factory(
    spatial_merge_size: int,
) -> Callable[
    [Mapping[str, torch.Tensor]],
    Mapping[str, MultiModalFieldConfig],
]:
    """Build a ``fields_factory`` describing the Qwen2-VL multimodal tensors.

    For each modality, the per-item patch count is the product ``t * h * w``
    of that item's ``*_grid_thw`` row. Pixel tensors are flat tensors split by
    the patch count. Embedding tensors are flat tensors split by the patch
    count divided by ``spatial_merge_size`` twice. ``*_grid_thw`` tensors are
    batched per item.

    Args:
        spatial_merge_size: Spatial merge factor of the vision encoder.

    Returns:
        A callable mapping HF processor outputs to per-field configs.
    """

    def _split_sizes(
        hf_inputs: Mapping[str, torch.Tensor], grid_key: str
    ) -> tuple[torch.Tensor, torch.Tensor]:
        # A missing grid is treated as an empty (0, 3) grid, i.e. no items.
        grid_thw = hf_inputs.get(grid_key, torch.empty((0, 3)))
        patch_counts = grid_thw.prod(-1)
        token_counts = patch_counts // spatial_merge_size // spatial_merge_size
        return patch_counts, token_counts

    def _qwen2vl_field_config(hf_inputs: Mapping[str, torch.Tensor]):
        image_patches, image_tokens = _split_sizes(hf_inputs, "image_grid_thw")
        video_patches, video_tokens = _split_sizes(hf_inputs, "video_grid_thw")
        return {
            "pixel_values": MultiModalFieldConfig.flat_from_sizes(
                "image", image_patches
            ),
            "image_embeds": MultiModalFieldConfig.flat_from_sizes(
                "image", image_tokens
            ),
            "image_grid_thw": MultiModalFieldConfig.batched("image"),
            "pixel_values_videos": MultiModalFieldConfig.flat_from_sizes(
                "video", video_patches
            ),
            "video_embeds": MultiModalFieldConfig.flat_from_sizes(
                "video", video_tokens
            ),
            "video_grid_thw": MultiModalFieldConfig.batched("video"),
        }

    return _qwen2vl_field_config
910

911

Roger Wang's avatar
Roger Wang committed
912
class Qwen2VLMultiModalDataParser(MultiModalDataParser):
    """Multimodal data parser that also accepts precomputed Qwen2-VL embeddings.

    Image or video data given as a ``dict`` of tensors is wrapped in
    :class:`DictEmbeddingItems`, which requires that modality's ``*_embeds``
    and ``*_grid_thw`` entries. Any other data is parsed by the base class.
    """

    def __init__(self, spatial_merge_size: int, *args, **kwargs):
        # Set before the base initializer runs; later passed to the field
        # factory that sizes the per-item embedding splits.
        self._spatial_merge_size = spatial_merge_size
        super().__init__(*args, **kwargs)

    def _parse_image_data(
        self,
        data: dict[str, torch.Tensor] | ModalityData[ImageItem],
    ) -> ModalityDataItems[Any, Any] | None:
        """Parse image data, accepting a dict of precomputed embeddings."""
        if not isinstance(data, dict):
            return super()._parse_image_data(data)

        field_factory = _create_qwen2vl_field_factory(self._spatial_merge_size)
        return DictEmbeddingItems(
            data,
            modality="image",
            required_fields={"image_embeds", "image_grid_thw"},
            fields_factory=field_factory,
        )

    def _parse_video_data(
        self,
        data: dict[str, torch.Tensor] | ModalityData[VideoItem],
    ) -> ModalityDataItems[Any, Any] | None:
        """Parse video data, accepting a dict of precomputed embeddings."""
        if not isinstance(data, dict):
            return super()._parse_video_data(data)

        field_factory = _create_qwen2vl_field_factory(self._spatial_merge_size)
        return DictEmbeddingItems(
            data,
            modality="video",
            required_fields={"video_embeds", "video_grid_thw"},
            fields_factory=field_factory,
        )


946
947
class Qwen2VLProcessingInfo(BaseProcessingInfo):
    """Config/processor access and vision-token budgeting for Qwen2-VL."""

    def get_hf_config(self):
        return self.ctx.get_hf_config(Qwen2VLConfig)

    def get_hf_processor(self, **kwargs: object) -> Qwen2VLProcessor:
        # Request the fast processor unless the caller passes `use_fast`.
        return self.ctx.get_hf_processor(
            Qwen2VLProcessor,
            use_fast=kwargs.pop("use_fast", True),
            **kwargs,
        )

    def get_image_processor(self, **kwargs: object) -> Qwen2VLImageProcessor:
        return self.get_hf_processor(**kwargs).image_processor

    def get_supported_mm_limits(self) -> Mapping[str, int | None]:
        return {"image": None, "video": None}

    def get_mm_max_tokens_per_item(
        self,
        seq_len: int,
        mm_counts: Mapping[str, int],
    ) -> Mapping[str, int]:
        """Return the maximum token count of a single image and a single video."""
        max_image_tokens = self.get_max_image_tokens()
        max_video_tokens = self.get_max_video_tokens(seq_len, mm_counts)
        return {"image": max_image_tokens, "video": max_video_tokens}

    def _get_vision_info(
        self,
        *,
        image_width: int,
        image_height: int,
        num_frames: int = 1,
        do_resize: bool = True,
        image_processor: Qwen2VLImageProcessor | None,
    ) -> tuple[ImageSize, int]:
        """Compute the preprocessed size and vision-token count of an input.

        Args:
            image_width: Input frame width in pixels.
            image_height: Input frame height in pixels.
            num_frames: Number of frames (1 for a still image).
            do_resize: Whether to apply ``smart_resize`` using the image
                processor's ``min_pixels``/``max_pixels`` bounds.
            image_processor: Processor supplying those bounds; fetched via
                :meth:`get_image_processor` when ``None``.

        Returns:
            ``(preprocessed_size, num_vision_tokens)``. The token count is the
            number of ``patch_size`` patches over the padded frames divided by
            ``spatial_merge_size ** 2``.
        """
        if image_processor is None:
            image_processor = self.get_image_processor()

        hf_config = self.get_hf_config()
        vision_config = hf_config.vision_config
        patch_size = vision_config.patch_size
        merge_size = vision_config.spatial_merge_size
        temporal_patch_size = vision_config.temporal_patch_size

        if do_resize:
            resized_height, resized_width = smart_resize(
                height=image_height,
                width=image_width,
                factor=patch_size * merge_size,
                min_pixels=image_processor.min_pixels,
                max_pixels=image_processor.max_pixels,
            )
            preprocessed_size = ImageSize(width=resized_width, height=resized_height)
        else:
            preprocessed_size = ImageSize(width=image_width, height=image_height)

        # NOTE: Frames are padded to be divisible by `temporal_patch_size`.
        # `(-n) % k` is the number of frames needed to reach the next multiple
        # of `k`. This is correct for any patch size, whereas `n + n % k` only
        # rounds up correctly when `k == 2`.
        # https://github.com/huggingface/transformers/blob/v4.48.3/src/transformers/models/qwen2_vl/image_processing_qwen2_vl.py#L294
        padded_num_frames = num_frames + (-num_frames) % temporal_patch_size

        grid_t = max(padded_num_frames // temporal_patch_size, 1)
        grid_h = preprocessed_size.height // patch_size
        grid_w = preprocessed_size.width // patch_size

        num_patches = grid_t * grid_h * grid_w
        num_vision_tokens = num_patches // (merge_size**2)

        return preprocessed_size, num_vision_tokens

    def get_num_image_tokens(
        self,
        *,
        image_width: int,
        image_height: int,
        image_processor: Qwen2VLImageProcessor | None,
    ) -> int:
        """Return the number of vision tokens for one image of the given size."""
        _, num_image_tokens = self._get_vision_info(
            image_width=image_width,
            image_height=image_height,
            num_frames=1,
            image_processor=image_processor,
        )
        return num_image_tokens

    def get_num_video_tokens(
        self,
        *,
        image_width: int,
        image_height: int,
        num_frames: int,
        image_processor: Qwen2VLImageProcessor | None,
    ) -> int:
        """Return the number of vision tokens for one video of the given shape."""
        _, num_video_tokens = self._get_vision_info(
            image_width=image_width,
            image_height=image_height,
            num_frames=num_frames,
            image_processor=image_processor,
        )
        return num_video_tokens

    def get_image_size_with_most_features(self) -> ImageSize:
        """Return the resized size of an arbitrarily large input image."""
        # A huge input is resized down by `smart_resize` to the largest size
        # the image processor's `max_pixels` permits.
        max_image_size, _ = self._get_vision_info(
            image_width=9999999,
            image_height=9999999,
            num_frames=1,
            image_processor=None,
        )
        return max_image_size

    def get_max_image_tokens(self) -> int:
        """Return the vision-token count of a maximum-size image."""
        target_width, target_height = self.get_image_size_with_most_features()

        return self.get_num_image_tokens(
            image_width=target_width,
            image_height=target_height,
            image_processor=None,
        )

    def _get_max_video_frames(self, max_tokens: int, start_num_frames: int = 1) -> int:
        """Return the most frames a maximum-size video can have within a budget.

        Starting from ``start_num_frames``, keep adding frames while the token
        count of a maximum-size video stays within ``max_tokens``. The result
        is never less than ``start_num_frames``.
        """
        target_width, target_height = self.get_image_size_with_most_features()

        num_frames = start_num_frames

        while True:
            next_num_frames = num_frames + 1
            next_max_tokens = self.get_num_video_tokens(
                image_width=target_width,
                image_height=target_height,
                num_frames=next_num_frames,
                image_processor=None,
            )

            if next_max_tokens > max_tokens:
                break

            num_frames = next_num_frames

        return num_frames

    def get_num_frames_with_most_features(
        self,
        seq_len: int,
        mm_counts: Mapping[str, int],
        max_frames_per_video: int = _MAX_FRAMES_PER_VIDEO,
    ) -> int:
        """Return the per-video frame count when ``seq_len`` is shared by videos.

        The frame budget that fits in ``seq_len`` is divided evenly among the
        requested videos, capped at ``max_frames_per_video``, and floored at 1.
        """
        max_videos = mm_counts.get("video", 0)

        max_total_frames = self._get_max_video_frames(seq_len)
        max_frames_per_video = min(
            max_total_frames // max(max_videos, 1), max_frames_per_video
        )

        return max(max_frames_per_video, 1)

    def get_max_video_tokens(
        self,
        seq_len: int,
        mm_counts: Mapping[str, int],
    ) -> int:
        """Return the vision-token count of a maximum-size, maximum-length video."""
        target_width, target_height = self.get_image_size_with_most_features()

        return self.get_num_video_tokens(
            image_width=target_width,
            image_height=target_height,
            num_frames=self.get_num_frames_with_most_features(seq_len, mm_counts),
            image_processor=None,
        )

1114
1115

class Qwen2VLDummyInputsBuilder(BaseDummyInputsBuilder[Qwen2VLProcessingInfo]):
    """Builds dummy Qwen2-VL prompts and media at the largest supported size."""

    def get_dummy_text(self, mm_counts: Mapping[str, int]) -> str:
        """Return one image token per image followed by one video token per video."""
        processor = self.info.get_hf_processor()
        image_token: str = processor.image_token
        video_token: str = processor.video_token
        return "".join(
            [
                image_token * mm_counts.get("image", 0),
                video_token * mm_counts.get("video", 0),
            ]
        )

    def get_dummy_mm_data(
        self,
        seq_len: int,
        mm_counts: Mapping[str, int],
        mm_options: Mapping[str, BaseDummyOptions] | None = None,
    ) -> MultiModalDataDict:
        """Return maximum-size dummy images and videos for the requested counts."""
        width, height = self.info.get_image_size_with_most_features()
        num_frames = self.info.get_num_frames_with_most_features(seq_len, mm_counts)
        # Empty or missing options mean no per-modality overrides.
        options = mm_options or {}

        dummy_images = self._get_dummy_images(
            width=width,
            height=height,
            num_images=mm_counts.get("image", 0),
            overrides=options.get("image"),
        )
        dummy_videos = self._get_dummy_videos(
            width=width,
            height=height,
            num_frames=num_frames,
            num_videos=mm_counts.get("video", 0),
            overrides=options.get("video"),
        )
        return {"image": dummy_images, "video": dummy_videos}

1159

1160
class Qwen2VLMultiModalProcessor(BaseMultiModalProcessor[Qwen2VLProcessingInfo]):
    """Multimodal processor that expands Qwen2-VL image/video placeholders."""

    def _get_data_parser(self) -> MultiModalDataParser:
        vision_config = self.info.get_hf_config().vision_config
        return Qwen2VLMultiModalDataParser(vision_config.spatial_merge_size)

    def _get_prompt_updates(
        self,
        mm_items: MultiModalDataItems,
        hf_processor_mm_kwargs: Mapping[str, Any],
        out_mm_kwargs: MultiModalKwargsItems,
    ) -> Sequence[PromptUpdate]:
        """Replace each placeholder token with one token per merged patch group.

        An item whose ``*_grid_thw`` is ``(t, h, w)`` expands to
        ``t * h * w // merge_size**2`` copies of its modality's placeholder id.
        """
        hf_processor = self.info.get_hf_processor(**hf_processor_mm_kwargs)
        image_processor = self.info.get_image_processor(**hf_processor_mm_kwargs)
        vocab = self.info.get_tokenizer().get_vocab()

        token_id_by_modality = {
            "image": vocab[hf_processor.image_token],
            "video": vocab[hf_processor.video_token],
        }
        # Number of patches that are merged into a single LLM token.
        patches_per_token = image_processor.merge_size**2

        def expand_placeholder(item_idx: int, modality: str):
            item = out_mm_kwargs[modality][item_idx]
            grid_thw = item[f"{modality}_grid_thw"].data
            assert isinstance(grid_thw, torch.Tensor)
            repeat_count = int(grid_thw.prod()) // patches_per_token
            return [token_id_by_modality[modality]] * repeat_count

        updates: list[PromptUpdate] = []
        for modality, token_id in token_id_by_modality.items():
            updates.append(
                PromptReplacement(
                    modality=modality,
                    target=[token_id],
                    replacement=partial(expand_placeholder, modality=modality),
                )
            )
        return updates

    def _get_mm_fields_config(
        self,
        hf_inputs: BatchFeature,
        hf_processor_mm_kwargs: Mapping[str, object],
    ) -> Mapping[str, MultiModalFieldConfig]:
        merge_size = self.info.get_hf_config().vision_config.spatial_merge_size
        build_field_config = _create_qwen2vl_field_factory(merge_size)
        return build_field_config(hf_inputs)


@MULTIMODAL_REGISTRY.register_processor(
    Qwen2VLMultiModalProcessor,
    info=Qwen2VLProcessingInfo,
    dummy_inputs=Qwen2VLDummyInputsBuilder,
)
class Qwen2VLForConditionalGeneration(
    nn.Module, SupportsMultiModal, SupportsLoRA, SupportsPP, SupportsMRoPE
):
    """Qwen2-VL: an optional vision tower (``visual``) plus a Qwen2 causal LM.

    ``visual`` is only built when the per-prompt image or video limit is
    non-zero. Otherwise it is ``None`` and ``visual.*`` weights are skipped
    during loading.
    """

    merge_by_field_config = True
    multimodal_cpu_fields = {"image_grid_thw", "video_grid_thw"}

    # To ensure correct weight loading and mapping.
    hf_to_vllm_mapper = WeightsMapper(
        orig_to_new_prefix={
            # mapping for new names in checkpoint saved after transformers v4.52
            "model.language_model.": "language_model.model.",
            "model.visual.": "visual.",
            # mapping for original checkpoint
            "lm_head.": "language_model.lm_head.",
            "model.": "language_model.model.",
        }
    )

    supports_encoder_tp_data = True

    def get_mrope_input_positions(
        self,
        input_tokens: list[int],
        hf_config: PretrainedConfig,
        image_grid_thw: list[list[int]] | torch.Tensor | None,
        video_grid_thw: list[list[int]] | torch.Tensor | None,
        second_per_grid_ts: list[float] | None = None,
        context_len: int = 0,
        seq_len: int | None = None,
        audio_feature_lengths: torch.Tensor | None = None,
        use_audio_in_video: bool = False,
    ) -> tuple[torch.Tensor, int]:
        """Get M-RoPE input positions for Qwen2-VL model.

        Text tokens get the same position on all three axes, increasing by one
        per token.

        Each image/video placeholder run of ``t * (h // m) * (w // m)`` tokens
        (``m`` = ``spatial_merge_size``) gets 3-D ``(t, h, w)`` positions
        offset from the preceding text. The temporal index is scaled by
        ``second_per_grid_t * tokens_per_second`` for videos and is zero for
        images.

        Args:
            input_tokens: Full prompt token ids.
            hf_config: Model config providing the special token ids and the
                vision config.
            image_grid_thw: Per-image ``(t, h, w)`` patch grids in prompt order.
            video_grid_thw: Per-video ``(t, h, w)`` patch grids in prompt order.
            second_per_grid_ts: Optional per-video seconds per temporal grid
                step. When empty, 1.0 is used for every video.
            context_len: Start of the returned position slice.
            seq_len: End (exclusive) of the returned position slice.
            audio_feature_lengths: Not used by this model.
            use_audio_in_video: Not used by this model.

        Returns:
            ``(positions, delta)``. ``positions`` is a ``(3, n)`` tensor for
            ``input_tokens[context_len:seq_len]``. ``delta`` is
            ``max position + 1 - len(input_tokens)``.
        """
        if image_grid_thw is None:
            image_grid_thw = []
        if video_grid_thw is None:
            video_grid_thw = []
        if second_per_grid_ts is None:
            second_per_grid_ts = []

        image_token_id = hf_config.image_token_id
        video_token_id = hf_config.video_token_id
        vision_start_token_id = hf_config.vision_start_token_id
        spatial_merge_size = hf_config.vision_config.spatial_merge_size
        tokens_per_second = getattr(hf_config.vision_config, "tokens_per_second", 1.0)

        input_tokens_tensor = torch.tensor(input_tokens)
        vision_start_indices = torch.argwhere(
            input_tokens_tensor == vision_start_token_id
        ).squeeze(1)
        # A vision-start token in the last position has no modality token
        # after it. Drop it so the lookup below cannot index out of range.
        vision_start_indices = vision_start_indices[
            vision_start_indices + 1 < len(input_tokens)
        ]
        vision_tokens = input_tokens_tensor[vision_start_indices + 1]
        image_nums = int((vision_tokens == image_token_id).sum())
        video_nums = int((vision_tokens == video_token_id).sum())
        llm_pos_ids_list: list = []

        st = 0
        remain_images, remain_videos = image_nums, video_nums

        image_index, video_index = 0, 0
        for _ in range(image_nums + video_nums):
            video_second_per_grid_t = 0.0
            # Locate the next image / video placeholder at or after `st`.
            if remain_images > 0:
                try:
                    ed_image = input_tokens.index(image_token_id, st)
                except ValueError:
                    ed_image = len(input_tokens) + 1
            else:
                ed_image = len(input_tokens) + 1
            if remain_videos > 0:
                try:
                    ed_video = input_tokens.index(video_token_id, st)
                except ValueError:
                    ed_video = len(input_tokens) + 1
            else:
                ed_video = len(input_tokens) + 1
            if ed_image < ed_video:
                t, h, w = (
                    image_grid_thw[image_index][0],
                    image_grid_thw[image_index][1],
                    image_grid_thw[image_index][2],
                )
                image_index += 1
                remain_images -= 1
                ed = ed_image
            else:
                t, h, w = (
                    video_grid_thw[video_index][0],
                    video_grid_thw[video_index][1],
                    video_grid_thw[video_index][2],
                )
                video_second_per_grid_t = 1.0
                if second_per_grid_ts:
                    video_second_per_grid_t = second_per_grid_ts[video_index]
                video_index += 1
                remain_videos -= 1
                ed = ed_video

            llm_grid_t, llm_grid_h, llm_grid_w = (
                t,
                h // spatial_merge_size,
                w // spatial_merge_size,
            )
            text_len = ed - st

            # Positions for the text between the previous item and this one.
            st_idx = llm_pos_ids_list[-1].max() + 1 if len(llm_pos_ids_list) > 0 else 0
            llm_pos_ids_list.append(
                torch.arange(text_len).view(1, -1).expand(3, -1) + st_idx
            )

            t_index = (
                (
                    torch.arange(llm_grid_t)
                    .view(-1, 1)
                    .expand(-1, llm_grid_h * llm_grid_w)
                    * video_second_per_grid_t
                    * tokens_per_second
                )
                .long()
                .flatten()
            )

            h_index = (
                torch.arange(llm_grid_h)
                .view(1, -1, 1)
                .expand(llm_grid_t, -1, llm_grid_w)
                .flatten()
            )
            w_index = (
                torch.arange(llm_grid_w)
                .view(1, 1, -1)
                .expand(llm_grid_t, llm_grid_h, -1)
                .flatten()
            )
            llm_pos_ids_list.append(
                torch.stack([t_index, h_index, w_index]) + text_len + st_idx
            )
            st = ed + llm_grid_t * llm_grid_h * llm_grid_w

        # Positions for any trailing text after the last vision item.
        if st < len(input_tokens):
            st_idx = llm_pos_ids_list[-1].max() + 1 if len(llm_pos_ids_list) > 0 else 0
            text_len = len(input_tokens) - st
            llm_pos_ids_list.append(
                torch.arange(text_len).view(1, -1).expand(3, -1) + st_idx
            )

        llm_positions = torch.cat(llm_pos_ids_list, dim=1).reshape(3, -1)
        mrope_position_delta = (llm_positions.max() + 1 - len(input_tokens)).item()
        llm_positions = llm_positions[:, context_len:seq_len]

        return llm_positions, mrope_position_delta

    @classmethod
    def get_placeholder_str(cls, modality: str, i: int) -> str | None:
        """Return the prompt placeholder string for one image or video item.

        Raises:
            ValueError: If ``modality`` starts with neither ``"image"`` nor
                ``"video"``.
        """
        if modality.startswith("image"):
            return "<|vision_start|><|image_pad|><|vision_end|>"
        if modality.startswith("video"):
            return "<|vision_start|><|video_pad|><|vision_end|>"

        raise ValueError("Only image or video modality is supported")

    def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
        super().__init__()
        config: Qwen2VLConfig = vllm_config.model_config.hf_config
        quant_config = vllm_config.quant_config
        multimodal_config = vllm_config.model_config.multimodal_config

        self.use_data_parallel = multimodal_config.mm_encoder_tp_mode == "data"

        self.config = config
        self.multimodal_config = multimodal_config

        # Build the vision tower only if images or videos have a non-zero
        # per-prompt limit.
        if multimodal_config.get_limit_per_prompt(
            "image"
        ) or multimodal_config.get_limit_per_prompt("video"):
            self.visual = Qwen2VisionTransformer(
                config.vision_config,
                norm_eps=getattr(config, "rms_norm_eps", 1e-6),
                quant_config=quant_config,
                prefix=maybe_prefix(prefix, "visual"),
                use_data_parallel=self.use_data_parallel,
                # `multimodal_config` is already dereferenced above, so a
                # None check here would be dead code.
                attn_backend_override=multimodal_config.mm_encoder_attn_backend,
            )
        else:
            self.visual = None

        self.language_model = init_vllm_registered_model(
            vllm_config=vllm_config,
            prefix=maybe_prefix(prefix, "language_model"),
            architectures=["Qwen2ForCausalLM"],
        )

        self.make_empty_intermediate_tensors = (
            self.language_model.make_empty_intermediate_tensors
        )

    def _parse_and_validate_image_input(
        self, **kwargs: object
    ) -> Qwen2VLImageInputs | None:
        """Pop image inputs from ``kwargs``; pixel values take precedence."""
        pixel_values = kwargs.pop("pixel_values", None)
        image_embeds = kwargs.pop("image_embeds", None)
        image_grid_thw = kwargs.pop("image_grid_thw", None)

        if pixel_values is None and image_embeds is None:
            return None

        if pixel_values is not None:
            return Qwen2VLImagePixelInputs(
                type="pixel_values",
                pixel_values=pixel_values,
                image_grid_thw=image_grid_thw,
            )

        if image_embeds is not None:
            return Qwen2VLImageEmbeddingInputs(
                type="image_embeds",
                image_embeds=image_embeds,
                image_grid_thw=image_grid_thw,
            )

    def _parse_and_validate_video_input(
        self, **kwargs: object
    ) -> Qwen2VLVideoInputs | None:
        """Pop video inputs from ``kwargs``; pixel values take precedence."""
        pixel_values_videos = kwargs.pop("pixel_values_videos", None)
        video_embeds = kwargs.pop("video_embeds", None)
        video_grid_thw = kwargs.pop("video_grid_thw", None)

        if pixel_values_videos is None and video_embeds is None:
            return None

        if pixel_values_videos is not None:
            return Qwen2VLVideoPixelInputs(
                type="pixel_values_videos",
                pixel_values_videos=pixel_values_videos,
                video_grid_thw=video_grid_thw,
            )

        if video_embeds is not None:
            return Qwen2VLVideoEmbeddingInputs(
                type="video_embeds",
                video_embeds=video_embeds,
                video_grid_thw=video_grid_thw,
            )

    def _process_image_input(
        self, image_input: Qwen2VLImageInputs
    ) -> tuple[torch.Tensor, ...]:
        """Encode (or pass through) image embeddings, split per image.

        In encoder data-parallel mode, the sharded helper's result is
        returned directly.
        """
        grid_thw = image_input["image_grid_thw"]
        assert grid_thw.ndim == 2

        if image_input["type"] == "image_embeds":
            image_embeds = image_input["image_embeds"]
        else:
            pixel_values = image_input["pixel_values"]

            if self.use_data_parallel:
                return run_dp_sharded_mrope_vision_model(
                    self.visual, pixel_values, grid_thw.tolist(), rope_type="rope_3d"
                )
            else:
                image_embeds = self.visual(pixel_values, grid_thw=grid_thw)

        # Split concatenated embeddings for each image item.
        merge_size = self.visual.spatial_merge_size
        sizes = (grid_thw.prod(-1) // merge_size // merge_size).tolist()
        return image_embeds.split(sizes)

    def _process_video_input(
        self, video_input: Qwen2VLVideoInputs
    ) -> tuple[torch.Tensor, ...]:
        """Encode (or pass through) video embeddings, split per video.

        In encoder data-parallel mode, the sharded helper's result is
        returned directly.
        """
        grid_thw = video_input["video_grid_thw"]
        assert grid_thw.ndim == 2

        if video_input["type"] == "video_embeds":
            video_embeds = video_input["video_embeds"]
        else:
            pixel_values_videos = video_input["pixel_values_videos"]
            if self.use_data_parallel:
                grid_thw_list = grid_thw.tolist()
                return run_dp_sharded_mrope_vision_model(
                    self.visual, pixel_values_videos, grid_thw_list, rope_type="rope_3d"
                )
            else:
                video_embeds = self.visual(pixel_values_videos, grid_thw=grid_thw)

        # Split concatenated embeddings for each video item.
        merge_size = self.visual.spatial_merge_size
        sizes = (grid_thw.prod(-1) // merge_size // merge_size).tolist()
        return video_embeds.split(sizes)

    def _parse_and_validate_multimodal_inputs(self, **kwargs: object) -> dict:
        modalities = {}

        # Preserve the order of modalities if there are multiple of them
        # from the order of kwargs.
        for input_key in kwargs:
            if (
                input_key in ("pixel_values", "image_embeds")
                and "images" not in modalities
            ):
                modalities["images"] = self._parse_and_validate_image_input(**kwargs)
            if (
                input_key in ("pixel_values_videos", "video_embeds")
                and "videos" not in modalities
            ):
                modalities["videos"] = self._parse_and_validate_video_input(**kwargs)

        return modalities

    def get_language_model(self) -> torch.nn.Module:
        return self.language_model

    def get_multimodal_embeddings(self, **kwargs: object) -> MultiModalEmbeddings:
        """Return one embedding tensor per image/video item, in kwargs order."""
        modalities = self._parse_and_validate_multimodal_inputs(**kwargs)
        if not modalities:
            return []

        # The result multimodal_embeddings is tuple of tensors, with each
        # tensor corresponding to a multimodal data item (image or video).
        multimodal_embeddings: tuple[torch.Tensor, ...] = ()

        # NOTE: It is important to iterate over the keys in this dictionary
        # to preserve the order of the modalities.
        for modality in modalities:
            if modality == "images":
                image_input = modalities["images"]
                image_embeddings = self._process_image_input(image_input)
                multimodal_embeddings += tuple(image_embeddings)
            if modality == "videos":
                video_input = modalities["videos"]
                video_embeddings = self._process_video_input(video_input)
                multimodal_embeddings += tuple(video_embeddings)

        return multimodal_embeddings

    def forward(
        self,
        input_ids: torch.Tensor,
        positions: torch.Tensor,
        intermediate_tensors: IntermediateTensors | None = None,
        inputs_embeds: torch.Tensor | None = None,
        **kwargs: object,
    ) -> torch.Tensor | IntermediateTensors:
        """Run forward pass for Qwen2-VL.

        Args:
            input_ids: Flattened (concatenated) input_ids corresponding to a
                batch.
            positions: Flattened (concatenated) position ids corresponding to a
                batch.
                **NOTE**: If mrope is enabled (default setting for Qwen2-VL
                opensource models), the shape will be `(3, seq_len)`,
                otherwise it will be `(seq_len,)`.
            intermediate_tensors: Intermediate tensors from prior forward pass.
            inputs_embeds: Optional tensor of input embeddings.
        """

        # With intermediate tensors present, any given embeddings are ignored.
        if intermediate_tensors is not None:
            inputs_embeds = None

        hidden_states = self.language_model.model(
            input_ids=input_ids,
            positions=positions,
            intermediate_tensors=intermediate_tensors,
            inputs_embeds=inputs_embeds,
        )
        return hidden_states

    def compute_logits(
        self,
        hidden_states: torch.Tensor,
    ) -> torch.Tensor | None:
        return self.language_model.compute_logits(hidden_states)

    def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]:
        """Load checkpoint weights, skipping ``visual.*`` when no tower exists."""
        skip_prefixes = []
        if self.visual is None:
            skip_prefixes.extend(["visual."])
        loader = AutoWeightsLoader(self, skip_prefixes=skip_prefixes)
        return loader.load_weights(weights, mapper=self.hf_to_vllm_mapper)

    def get_mm_mapping(self) -> MultiModelKeys:
        """
        Get the module prefix in multimodal models
        """
        return MultiModelKeys.from_string_field(
            language_model="language_model",
            connector="visual.merger.",
            tower_model="visual.",
        )
1609
1610
1611
1612
1613
1614
1615
1616
1617


class Tarsier2MultiModalProcessor(Qwen2VLMultiModalProcessor):
    """Tarsier2 multimodal processor; reuses Qwen2-VL processing unchanged."""


class Tarsier2ImageProcessor(Qwen2VLImageProcessor):
    """Qwen2-VL image processor that accepts Tarsier2's ``size`` format.

    When ``size`` carries both ``min_pixels`` and ``max_pixels``, they are
    translated to ``shortest_edge``/``longest_edge`` before delegating to
    :class:`Qwen2VLImageProcessor`. Any other ``size`` is passed unchanged.
    """

    def __init__(
        self,
        size: dict[str, int] | None = None,
        **kwargs,
    ) -> None:
        is_tarsier2_format = (
            size is not None and "min_pixels" in size and "max_pixels" in size
        )
        if is_tarsier2_format:
            # Build a new dict; the caller's `size` mapping is not mutated.
            size = {
                "shortest_edge": size["min_pixels"],
                "longest_edge": size["max_pixels"],
            }
        super().__init__(size=size, **kwargs)


class Tarsier2Processor(Qwen2VLProcessor):
    """Qwen2-VL processor built from a Tarsier2 image-processor config dict.

    The same ``vision_config`` mapping configures both the image processor
    and the video processor. No chat template is attached.
    """

    def __init__(
        self,
        vision_config: dict,
        tokenizer: AnyTokenizer,
        **kwargs,
    ):
        image_processor = Tarsier2ImageProcessor(**vision_config)
        self.image_processor = image_processor
        video_processor = Qwen2VLVideoProcessor(**vision_config)
        super().__init__(
            image_processor=image_processor,
            tokenizer=tokenizer,
            video_processor=video_processor,
            chat_template=None,
            **kwargs,
        )
1647
1648
1649
1650
1651


class Tarsier2ProcessingInfo(Qwen2VLProcessingInfo):
    """Processing info for Tarsier2 checkpoints.

    Overrides the config and processor accessors of ``Qwen2VLProcessingInfo``
    so that they build Qwen2-VL / Tarsier2 objects directly from the model
    path and the image-processor config.
    """

    def get_hf_config(self) -> Qwen2VLConfig:
        """Load the checkpoint's config as a ``Qwen2VLConfig``.

        The config is read from the model path rather than taken from the
        context. Presumably this is because Tarsier2 checkpoints declare a
        llava ``model_type`` (see ``Tarsier2ForConditionalGeneration``).
        """
        # NOTE(review): this calls from_pretrained on every invocation; consider
        # caching if it turns out to be on a hot path.
        return Qwen2VLConfig.from_pretrained(self.ctx.model_config.model)

    def get_hf_processor(self, **kwargs: object) -> Tarsier2Processor:
        """Build a ``Tarsier2Processor`` from the image-processor config.

        Extra ``kwargs`` are forwarded to ``Tarsier2Processor``.
        """
        return Tarsier2Processor(
            vision_config=self.ctx.get_hf_image_processor_config(),
            tokenizer=self.get_tokenizer(),
            **kwargs,
        )

    def get_image_processor(self) -> Tarsier2ImageProcessor:
        """Build a ``Tarsier2ImageProcessor`` from the image-processor config."""
        return Tarsier2ImageProcessor(**self.ctx.get_hf_image_processor_config())


@MULTIMODAL_REGISTRY.register_processor(
    Tarsier2MultiModalProcessor,
    info=Tarsier2ProcessingInfo,
    dummy_inputs=Qwen2VLDummyInputsBuilder,
)
class Tarsier2ForConditionalGeneration(Qwen2VLForConditionalGeneration):
    """Tarsier2 model: a Qwen2-VL model loaded from a llava-typed checkpoint."""

    # Rename checkpoint weights under ``vision_tower.`` to the ``visual.``
    # prefix used by this model.
    hf_to_vllm_mapper = WeightsMapper(
        orig_to_new_prefix={
            "vision_tower.": "visual.",
        }
    )

    def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
        # Tarsier2 uses llava as its model_type, so the loaded hf_config is a
        # LlavaConfig whose ``text_config`` is the actual Qwen2VLConfig.
        # Promote that text_config to be the model's hf_config (carrying over
        # the top-level ``architectures``) before building the Qwen2-VL model.
        config = vllm_config.model_config.hf_config
        qwen2vl_config = config.text_config
        qwen2vl_config.architectures = config.architectures
        # NOTE: this replaces hf_config on the passed-in vllm_config in place.
        vllm_config.model_config.hf_config = qwen2vl_config
        super().__init__(vllm_config=vllm_config, prefix=prefix)

    def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]:
        """Load checkpoint weights, remapping names with this class's mapper.

        When the vision tower was not built (``self.visual is None``),
        ``visual.`` is passed as a skip prefix so those weights are not loaded.

        Args:
            weights: Iterable of ``(name, tensor)`` pairs from the checkpoint.

        Returns:
            The set of names returned by ``AutoWeightsLoader.load_weights``.
        """
        skip_prefixes: list[str] = ["visual."] if self.visual is None else []
        loader = AutoWeightsLoader(self, skip_prefixes=skip_prefixes)
        return loader.load_weights(weights, mapper=self.hf_to_vllm_mapper)