qwen2_vl.py 63.3 KB
Newer Older
1
# SPDX-License-Identifier: Apache-2.0
2
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
3

4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
# Adapted from
# https://github.com/huggingface/transformers/blob/19e6e80e10118f855137b90740936c0b11ac397f/src/transformers/models/qwen2_vl/modeling_qwen2_vl.py
# Copyright 2024 The Qwen team.
# Copyright 2023 The vLLM team.
# Copyright 2022 EleutherAI and the HuggingFace Inc. team. All rights reserved.
#
# This code is based on EleutherAI's GPT-NeoX library and the GPT-NeoX
# and OPT implementations in this library. It has been modified from its
# original forms to accommodate minor architectural differences compared
# to GPT-NeoX and OPT used by the Meta AI team that trained the model.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Inference-only Qwen2-VL model compatible with HuggingFace weights."""
27
from collections.abc import Iterable, Mapping, Sequence
28
from functools import partial
29
from typing import Annotated, Any, Callable, Literal, Optional, Union
30
31
32
33
34

import torch
import torch.nn as nn
import torch.nn.functional as F
from einops import rearrange, repeat
35
from transformers import AutoConfig, BatchFeature, PretrainedConfig
36
37
from transformers.models.qwen2_vl import (Qwen2VLImageProcessor,
                                          Qwen2VLProcessor)
38
39
from transformers.models.qwen2_vl.configuration_qwen2_vl import (
    Qwen2VLConfig, Qwen2VLVisionConfig)
40
from transformers.models.qwen2_vl.image_processing_qwen2_vl import smart_resize
41
42
from transformers.models.qwen2_vl.video_processing_qwen2_vl import (
    Qwen2VLVideoProcessor)
43

44
from vllm.attention.layer import check_upstream_fa_availability
45
from vllm.config import VllmConfig
46
from vllm.distributed import parallel_state, tensor_model_parallel_all_gather
47
48
49
50
51
from vllm.distributed import utils as dist_utils
from vllm.logger import init_logger
from vllm.model_executor.layers.activation import QuickGELU
from vllm.model_executor.layers.linear import (ColumnParallelLinear,
                                               RowParallelLinear)
52
53
54
55
from vllm.model_executor.layers.quantization import QuantizationConfig
from vllm.model_executor.layers.quantization.gptq import GPTQConfig
from vllm.model_executor.layers.quantization.gptq_marlin import (
    GPTQMarlinConfig)
56
from vllm.model_executor.model_loader.weight_utils import default_weight_loader
57
from vllm.model_executor.models.module_mapping import MultiModelKeys
58
from vllm.multimodal import MULTIMODAL_REGISTRY
59
from vllm.multimodal.inputs import (ImageItem, ModalityData,
60
                                    MultiModalDataDict, MultiModalFieldConfig,
61
                                    MultiModalKwargsItems, VideoItem)
62
63
64
from vllm.multimodal.parse import (DictEmbeddingItems, ImageSize,
                                   ModalityDataItems, MultiModalDataItems,
                                   MultiModalDataParser)
65
from vllm.multimodal.processing import (BaseMultiModalProcessor,
66
67
                                        BaseProcessingInfo, PromptReplacement,
                                        PromptUpdate)
68
from vllm.multimodal.profiling import BaseDummyInputsBuilder
69
from vllm.platforms import _Backend, current_platform
70
from vllm.sequence import IntermediateTensors
71
from vllm.transformers_utils.config import uses_mrope
72
from vllm.transformers_utils.tokenizer import AnyTokenizer
73
from vllm.utils.tensor_schema import TensorSchema, TensorShape
74

75
from .interfaces import (MultiModalEmbeddings, SupportsLoRA, SupportsMRoPE,
76
                         SupportsMultiModal, SupportsPP)
77
from .utils import (AutoWeightsLoader, WeightsMapper,
78
79
                    init_vllm_registered_model, maybe_prefix,
                    merge_multimodal_embeddings)
80
from .vision import get_vit_attn_backend, run_dp_sharded_mrope_vision_model
81

82
83
logger = init_logger(__name__)

# Upper bound on the number of frames per video used for the profile run.
_MAX_FRAMES_PER_VIDEO = 600
86

87
88
89
# === Vision Inputs === #


90
class Qwen2VLImagePixelInputs(TensorSchema):
    """
    Dimensions:
        - np: The total number of patches over each image over each prompt in
              the batch
        - ni: Number of images
        - cps: Number of channels * patch_size * patch_size
          (NOTE(review): Qwen2VisionPatchEmbed reshapes its input to
          (L, -1, temporal_patch_size, patch_size, patch_size), so this
          dimension presumably also includes temporal_patch_size — confirm.)

    Historical context:
        - pixel_values shape: (num_patches, num_channels * patch_size *
          patch_size)
        - image_grid_thw shape: (num_images, 3) in (grid_t, grid_h, grid_w)
          format
    """
    type: Literal["pixel_values"]

    pixel_values: Annotated[
        torch.Tensor,
        TensorShape("np", "cps"),
    ]

    image_grid_thw: Annotated[
        torch.Tensor,
        TensorShape("ni", 3),
    ]


class Qwen2VLImageEmbeddingInputs(TensorSchema):
    """
    Dimensions:
        - nf: Number of image features
        - hs: Hidden size
        - ni: Number of images

    Historical context:
        - image_embeds shape: (num_image_features, hidden_size)
        - num_image_features varies based on the number and resolution of the
          images.
        - hidden_size must match the hidden size of language model backbone.
        - image_grid_thw shape: (num_images, 3) in (grid_t, grid_h, grid_w)
          format
    """
    type: Literal["image_embeds"]

    image_embeds: Annotated[
        torch.Tensor,
        TensorShape("nf", "hs"),
    ]

    image_grid_thw: Annotated[
        torch.Tensor,
        TensorShape("ni", 3),
    ]


# An image input is either raw pixel patches or precomputed embeddings.
Qwen2VLImageInputs = Union[Qwen2VLImagePixelInputs,
                           Qwen2VLImageEmbeddingInputs]


149
150
151
152
153
154
155
156
157
158
159
160
161
162
class Qwen2VLVideoPixelInputs(TensorSchema):
    """
    Dimensions:
        - np: The total number of patches over each video over each prompt in
              the batch
        - ctps: Number of channels * temporal_patch_size * patch_size *
          patch_size
        - nv: Number of videos

    Historical context:
        - pixel_values_videos shape: (num_patches, num_channels *
          temporal_patch_size * patch_size * patch_size)
        - video_grid_thw shape: (num_videos, 3) in (grid_t, grid_h, grid_w)
          format
    """
    type: Literal["pixel_values_videos"]

    pixel_values_videos: Annotated[
        torch.Tensor,
        TensorShape("np", "ctps"),
    ]

    video_grid_thw: Annotated[
        torch.Tensor,
        TensorShape("nv", 3),
    ]
175
176


177
178
179
180
181
182
class Qwen2VLVideoEmbeddingInputs(TensorSchema):
    """
    Dimensions:
        - nf: Number of video features
        - hs: Hidden size
        - nv: Number of videos

    Historical context:
        - video_embeds shape: (num_video_features, hidden_size)
        - num_video_features varies based on the number and resolution of the
          videos.
        - hidden_size must match the hidden size of language model backbone.
        - video_grid_thw shape: (num_videos, 3) in (grid_t, grid_h, grid_w)
          format
    """
    type: Literal["video_embeds"]

    video_embeds: Annotated[
        torch.Tensor,
        TensorShape("nf", "hs"),
    ]

    video_grid_thw: Annotated[
        torch.Tensor,
        TensorShape("nv", 3),
    ]


# A video input is either raw pixel patches or precomputed embeddings.
Qwen2VLVideoInputs = Union[Qwen2VLVideoPixelInputs,
                           Qwen2VLVideoEmbeddingInputs]

208
209
210
211
212
213
214
215
# === Vision Encoder === #


class Qwen2VisionMLP(nn.Module):
    """Two-layer feed-forward network used inside each vision block.

    ``fc1`` is a column-parallel linear layer and ``fc2`` a row-parallel one,
    with ``act_layer`` applied in between. When ``use_data_parallel`` is set,
    both layers are built with ``disable_tp=True`` so every rank holds the
    full weights.
    """

    def __init__(
        self,
        in_features: int,
        hidden_features: int,
        act_layer: type[nn.Module] = QuickGELU,
        quant_config: Optional[QuantizationConfig] = None,
        prefix: str = "",
        use_data_parallel: bool = False,
    ):
        super().__init__()
        self.fc1 = ColumnParallelLinear(in_features,
                                        hidden_features,
                                        quant_config=quant_config,
                                        prefix=f"{prefix}.fc1",
                                        disable_tp=use_data_parallel)
        self.act = act_layer()
        self.fc2 = RowParallelLinear(hidden_features,
                                     in_features,
                                     quant_config=quant_config,
                                     prefix=f"{prefix}.fc2",
                                     disable_tp=use_data_parallel)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Apply fc1 -> activation -> fc2; output width is ``in_features``."""
        # The parallel linear layers return (output, bias); only the output
        # is used here.
        x_parallel, _ = self.fc1(x)
        x_parallel = self.act(x_parallel)
        x, _ = self.fc2(x_parallel)
        return x


def rotate_half(x: torch.Tensor, interleaved: bool = False) -> torch.Tensor:
    """Rotate feature pairs of ``x`` by 90 degrees for rotary embeddings.

    Split-half layout (default): the last dim is viewed as ``[a, b]`` and
    ``[-b, a]`` is returned.
    Interleaved layout: each adjacent pair ``(a_i, b_i)`` along the last dim
    becomes ``(-b_i, a_i)``.
    """
    if interleaved:
        evens = x[..., ::2]
        odds = x[..., 1::2]
        # Stack to (..., d, 2) and fold the pair axis back into the last dim.
        return torch.stack((-odds, evens), dim=-1).flatten(-2)
    first, second = x.chunk(2, dim=-1)
    return torch.cat((-second, first), dim=-1)


def apply_rotary_emb_torch(x: torch.Tensor,
                           cos: torch.Tensor,
                           sin: torch.Tensor,
                           interleaved: bool = False) -> torch.Tensor:
    """Pure-PyTorch rotary position embedding.

    x: (batch_size, seqlen, nheads, headdim)
    cos, sin: (seqlen, rotary_dim / 2) or (batch_size, seqlen, rotary_dim / 2)

    Only the first ``rotary_dim`` (= 2 * cos.shape[-1]) features of ``x`` are
    rotated; any remaining features are passed through unchanged.
    """
    rotary_dim = 2 * cos.shape[-1]
    assert rotary_dim <= x.shape[-1]

    def _widen(table: torch.Tensor) -> torch.Tensor:
        # (..., d) -> (..., 1, 2d). Interleaved layout repeats each entry
        # twice in place; split-half layout tiles the whole row. The inserted
        # size-1 axis broadcasts over the heads dimension of ``x``.
        if interleaved:
            wide = table.repeat_interleave(2, dim=-1)
        else:
            wide = torch.cat((table, table), dim=-1)
        return wide.unsqueeze(-2)

    cos_wide = _widen(cos)
    sin_wide = _widen(sin)
    head = x[..., :rotary_dim]
    tail = x[..., rotary_dim:]
    rotated = (head * cos_wide +
               rotate_half(head, interleaved) * sin_wide)
    return torch.cat([rotated, tail], dim=-1)


def apply_rotary_pos_emb_vision(t: torch.Tensor,
                                freqs: torch.Tensor) -> torch.Tensor:
    """Apply vision rotary position embeddings to ``t``.

    The rotation is computed in float32 and the result is cast back to the
    dtype of ``t``. On CUDA the flash-attention rotary kernel is used;
    other platforms fall back to :func:`apply_rotary_emb_torch`.

    Args:
        t: Tensor to rotate (the attention layer passes q and k
            concatenated along dim 0).
        freqs: Rotary angles; ``cos``/``sin`` of these are fed to the
            rotary kernel.
    """
    t_ = t.float()
    cos = freqs.cos()
    sin = freqs.sin()
    apply_rotary_emb = apply_rotary_emb_torch
    if current_platform.is_cuda():
        from vllm.vllm_flash_attn.layers.rotary import apply_rotary_emb
    output = apply_rotary_emb(t_, cos, sin).type_as(t)
    return output


class Qwen2VisionAttention(nn.Module):
    """Self-attention over vision patch tokens with rotary embeddings.

    The fused ``qkv`` projection is column-parallel and the output ``proj``
    is row-parallel across tensor-parallel ranks. With
    ``use_data_parallel`` both are unsharded and this module behaves as a
    single TP rank. Attention runs over the variable-length segments given
    by ``cu_seqlens`` using flash-attention (vLLM, upstream, or ROCm aiter),
    PyTorch SDPA, or xFormers, depending on the detected backend.
    """

    def __init__(
        self,
        embed_dim: int,
        num_heads: int,
        projection_size: int,
        quant_config: Optional[QuantizationConfig] = None,
        prefix: str = "",
        use_data_parallel: bool = False,
    ) -> None:
        super().__init__()
        # Per attention head and per partition values.
        # Under data parallelism each rank holds the full weights, so it acts
        # as the only rank of a size-1 TP group.
        self.tp_size = (1 if use_data_parallel else
                        parallel_state.get_tensor_model_parallel_world_size())
        self.tp_rank = (0 if use_data_parallel else
                        parallel_state.get_tensor_model_parallel_rank())
        self.hidden_size_per_attention_head = dist_utils.divide(
            projection_size, num_heads)
        self.num_attention_heads_per_partition = dist_utils.divide(
            num_heads, self.tp_size)

        self.qkv = ColumnParallelLinear(input_size=embed_dim,
                                        output_size=3 * projection_size,
                                        quant_config=quant_config,
                                        prefix=f"{prefix}.qkv",
                                        disable_tp=use_data_parallel)
        self.proj = RowParallelLinear(input_size=projection_size,
                                      output_size=embed_dim,
                                      quant_config=quant_config,
                                      prefix=f"{prefix}.proj",
                                      disable_tp=use_data_parallel)

        # Detect attention implementation. If the platform-selected backend
        # is not FLASH_ATTN but the upstream flash_attn package is usable,
        # prefer it.
        self.attn_backend = get_vit_attn_backend(
            head_size=self.hidden_size_per_attention_head,
            dtype=torch.get_default_dtype())
        self.use_upstream_fa = False
        if self.attn_backend != _Backend.FLASH_ATTN and \
            check_upstream_fa_availability(
                torch.get_default_dtype()):
            self.attn_backend = _Backend.FLASH_ATTN
            self.use_upstream_fa = True

        if self.attn_backend not in {
                _Backend.FLASH_ATTN, _Backend.TORCH_SDPA, _Backend.XFORMERS,
                _Backend.ROCM_AITER_FA
        }:
            raise RuntimeError(
                f"Qwen2-VL does not support {self.attn_backend} backend now.")
        self.is_flash_attn_backend = self.attn_backend in {
            _Backend.FLASH_ATTN, _Backend.ROCM_AITER_FA
        }

    def split_qkv(self, qkv: torch.Tensor) -> tuple[torch.Tensor, ...]:
        """Split fused qkv activations into this rank's q, k and v heads.

        Returns three tensors of shape
        ``[s, b, heads_per_partition, head_dim]``.
        """
        # [s, b, 3 * head * head_dim]
        seq_len, bs, _ = qkv.shape
        # The fused output is sharded along its last dim, so a rank's shard
        # does not line up with its own q/k/v heads: gather the full tensor,
        # split it into q/k/v, then keep this rank's slice of each.
        if self.tp_size > 1:
            qkv = tensor_model_parallel_all_gather(qkv)

        # [s, b, 3 * head * head_dim] -> 3 * [s, b, head * head_dim]
        q, k, v = qkv.chunk(3, dim=2)

        # 3 * [s, b, head * head_dim]
        if self.tp_size > 1:
            splitter = partial(dist_utils.split_tensor_along_last_dim,
                               num_partitions=self.tp_size)
            q = splitter(q)[self.tp_rank]
            k = splitter(k)[self.tp_rank]
            v = splitter(v)[self.tp_rank]

        # 3 * [s, b, head * head_dim] -> 3 * [s, b, head, head_dim]
        new_shape = (seq_len, bs, self.num_attention_heads_per_partition,
                     self.hidden_size_per_attention_head)
        q, k, v = (t.view(*new_shape) for t in (q, k, v))
        return q, k, v

    def forward(
            self,
            x: torch.Tensor,
            cu_seqlens: torch.Tensor,
            rotary_pos_emb: torch.Tensor,
            max_seqlen: Optional[int] = None,  # Only used for Flash Attention
            seqlens: Optional[list[int]] = None,  # Only used for xFormers
    ) -> torch.Tensor:
        """Attend within each segment delimited by ``cu_seqlens``.

        Args:
            x: Hidden states of shape ``[s, b, c]``.
            cu_seqlens: Cumulative segment boundaries along ``s``; the
                flash-attention backends pass them straight to the varlen
                kernel, SDPA uses them as slice bounds.
            rotary_pos_emb: Rotary angles applied to q and k (skipped if
                ``None``).
            max_seqlen: Longest segment length (flash-attention only).
            seqlens: Per-segment lengths (xFormers only).

        Returns:
            Tensor of shape ``[s, b, embed_dim]``.
        """
        # [s, b, c] --> [s, b, 3 * head * head_dim]
        x, _ = self.qkv(x)

        # [s, b, 3 * head * head_dim] -> 3 * [s, b, head, head_dim]
        q, k, v = self.split_qkv(x)
        batch_size = q.shape[1]

        q, k, v = (rearrange(t, "s b ... -> b s ...").contiguous()
                   for t in (q, k, v))
        if rotary_pos_emb is not None:
            # Rotate q and k in one call: [2 * b, s, heads, head_dim]
            qk_concat = torch.cat([q, k], dim=0)
            qk_rotated = apply_rotary_pos_emb_vision(qk_concat, rotary_pos_emb)
            q, k = torch.chunk(qk_rotated, 2, dim=0)

        if self.is_flash_attn_backend:
            if self.attn_backend == _Backend.ROCM_AITER_FA:
                from aiter import flash_attn_varlen_func
            elif self.use_upstream_fa:
                from flash_attn import flash_attn_varlen_func
            else:
                from vllm.vllm_flash_attn import flash_attn_varlen_func

            # The varlen kernels take packed [(b * s), heads, head_dim].
            q, k, v = (rearrange(t, "b s ... -> (b s) ...") for t in [q, k, v])

            output = flash_attn_varlen_func(q,
                                            k,
                                            v,
                                            cu_seqlens_q=cu_seqlens,
                                            cu_seqlens_k=cu_seqlens,
                                            max_seqlen_q=max_seqlen,
                                            max_seqlen_k=max_seqlen,
                                            dropout_p=0.0,
                                            causal=False)

            context_layer = rearrange(output,
                                      "(b s) h d -> s b (h d)",
                                      b=batch_size).contiguous()
        elif self.attn_backend == _Backend.TORCH_SDPA:
            # Execute attention entry by entry for speed & less VRAM.
            outputs = []
            for i in range(1, len(cu_seqlens)):
                start_idx = cu_seqlens[i - 1]
                end_idx = cu_seqlens[i]
                q_i = q[:, start_idx:end_idx]
                k_i = k[:, start_idx:end_idx]
                v_i = v[:, start_idx:end_idx]
                q_i, k_i, v_i = (rearrange(t, "b s h d -> b h s d")
                                 for t in [q_i, k_i, v_i])
                output_i = F.scaled_dot_product_attention(q_i,
                                                          k_i,
                                                          v_i,
                                                          dropout_p=0.0)
                output_i = rearrange(output_i, "b h s d -> b s h d ")
                outputs.append(output_i)
            context_layer = torch.cat(outputs, dim=1)
            context_layer = rearrange(context_layer,
                                      "b s h d -> s b (h d)").contiguous()
        elif self.attn_backend == _Backend.XFORMERS:
            from xformers import ops as xops
            from xformers.ops.fmha.attn_bias import BlockDiagonalMask

            # Block-diagonal mask keeps attention within each segment.
            attn_bias = BlockDiagonalMask.from_seqlens(q_seqlen=seqlens,
                                                       kv_seqlen=None,
                                                       device=q.device)

            context_layer = xops.memory_efficient_attention_forward(
                q, k, v, attn_bias=attn_bias, p=0, scale=None)
            context_layer = rearrange(context_layer,
                                      "b s h d -> s b (h d)").contiguous()
        else:
            # Unreachable after the __init__ validation; fail loudly instead
            # of hitting an UnboundLocalError on ``context_layer``.
            raise RuntimeError(
                f"Qwen2-VL does not support {self.attn_backend} backend now.")

        output, _ = self.proj(context_layer)
        return output


class Qwen2VisionBlock(nn.Module):
    """Pre-norm transformer block of the Qwen2-VL vision encoder.

    Computes ``x + attn(norm1(x))`` followed by ``x + mlp(norm2(x))``.
    ``norm_layer`` defaults to ``nn.LayerNorm(eps=1e-6)``.
    """

    def __init__(
        self,
        dim: int,
        num_heads: int,
        mlp_ratio: float,
        act_layer: type[nn.Module] = QuickGELU,
        norm_layer: Optional[Callable[[int], nn.Module]] = None,
        quant_config: Optional[QuantizationConfig] = None,
        prefix: str = "",
        use_data_parallel: bool = False,
    ) -> None:
        super().__init__()
        if norm_layer is None:
            norm_layer = partial(nn.LayerNorm, eps=1e-6)
        self.norm1 = norm_layer(dim)
        self.norm2 = norm_layer(dim)
        mlp_hidden_dim = int(dim * mlp_ratio)

        self.attn = Qwen2VisionAttention(embed_dim=dim,
                                         num_heads=num_heads,
                                         projection_size=dim,
                                         quant_config=quant_config,
                                         prefix=f"{prefix}.attn",
                                         use_data_parallel=use_data_parallel)
        self.mlp = Qwen2VisionMLP(dim,
                                  mlp_hidden_dim,
                                  act_layer=act_layer,
                                  quant_config=quant_config,
                                  prefix=f"{prefix}.mlp",
                                  use_data_parallel=use_data_parallel)

    def forward(
            self,
            x: torch.Tensor,
            cu_seqlens: torch.Tensor,
            rotary_pos_emb: torch.Tensor,
            max_seqlen: Optional[int] = None,  # Only used for Flash Attention
            seqlens: Optional[list[int]] = None,  # Only used for xFormers
    ) -> torch.Tensor:
        """Run attention and MLP sub-layers, each with a residual add."""
        x = x + self.attn(
            self.norm1(x),
            cu_seqlens=cu_seqlens,
            rotary_pos_emb=rotary_pos_emb,
            max_seqlen=max_seqlen,
            seqlens=seqlens,
        )
        x = x + self.mlp(self.norm2(x))
        return x


class Qwen2VisionPatchEmbed(nn.Module):
    """Embed pre-flattened pixel patches with a bias-free Conv3d.

    Each input row holds one patch of
    ``in_channels * temporal_patch_size * patch_size * patch_size`` values.
    The convolution's kernel and stride both equal the patch extent, so
    each row maps to exactly one ``embed_dim`` vector.
    """

    def __init__(
        self,
        patch_size: int = 14,
        temporal_patch_size: int = 2,
        in_channels: int = 3,
        embed_dim: int = 1152,
    ) -> None:
        super().__init__()
        self.patch_size = patch_size
        self.temporal_patch_size = temporal_patch_size
        self.embed_dim = embed_dim

        kernel_size = (temporal_patch_size, patch_size, patch_size)
        self.proj = nn.Conv3d(in_channels,
                              embed_dim,
                              kernel_size=kernel_size,
                              stride=kernel_size,
                              bias=False)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Map ``x`` of shape ``(L, C*T*P*P)`` to ``(L, embed_dim)``."""
        # Unpacking also enforces that x is 2-D.
        L, _ = x.shape
        x = x.view(L, -1, self.temporal_patch_size, self.patch_size,
                   self.patch_size)
        x = self.proj(x).view(L, self.embed_dim)
        return x


class Qwen2VisionPatchMerger(nn.Module):
    """Merge groups of patch embeddings into language-model-sized tokens.

    Applies ``ln_q`` over ``context_dim``, then concatenates every
    ``spatial_merge_size ** 2`` consecutive patch embeddings into one vector
    of size ``context_dim * spatial_merge_size ** 2``. That vector is
    projected to ``d_model`` by Linear -> GELU -> Linear.
    """

    def __init__(
        self,
        d_model: int,
        context_dim: int,
        norm_layer: Optional[Callable[[int], nn.Module]] = None,
        spatial_merge_size: int = 2,
        quant_config: Optional[QuantizationConfig] = None,
        prefix: str = "",
        use_data_parallel: bool = False,
    ) -> None:
        super().__init__()
        self.hidden_size = context_dim * (spatial_merge_size**2)
        if norm_layer is None:
            norm_layer = partial(nn.LayerNorm, eps=1e-6)
        self.ln_q = norm_layer(context_dim)
        # Kept as a ModuleList so parameter names stay "mlp.0" / "mlp.2",
        # matching the prefixes passed below.
        self.mlp = nn.ModuleList([
            ColumnParallelLinear(self.hidden_size,
                                 self.hidden_size,
                                 bias=True,
                                 quant_config=quant_config,
                                 prefix=f"{prefix}.mlp.0",
                                 disable_tp=use_data_parallel),
            nn.GELU(),
            RowParallelLinear(self.hidden_size,
                              d_model,
                              bias=True,
                              quant_config=quant_config,
                              prefix=f"{prefix}.mlp.2",
                              disable_tp=use_data_parallel),
        ])

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Normalize, group and project ``x``; returns ``(-1, d_model)``."""
        x = self.ln_q(x)
        x = x.view(-1, self.hidden_size)

        mlp_fc1, mlp_act, mlp_fc2 = self.mlp
        x_parallel, _ = mlp_fc1(x)
        x_parallel = mlp_act(x_parallel)
        out, _ = mlp_fc2(x_parallel)
        return out


class Qwen2VisionRotaryEmbedding(nn.Module):
    """Lazily grown table of rotary angles for integer positions.

    ``forward(seqlen)`` returns ``outer(arange(seqlen), inv_freq)``, where
    ``inv_freq[i] = 1 / theta ** (2i / dim)``. The table is cached and
    regrown to twice the requested length whenever a longer one is needed.
    """

    def __init__(self, dim: int, theta: float = 10000.0) -> None:
        super().__init__()
        self.dim = dim
        self.theta = theta
        exponents = torch.arange(0, dim, 2, dtype=torch.float) / dim
        self.register_buffer("inv_freq",
                             1.0 / (theta**exponents),
                             persistent=False)
        self._seq_len_cached = 0
        self._freqs_cached = None

    def update_freqs_cache(self, seqlen: int) -> None:
        """Ensure the cached table covers at least ``seqlen`` positions."""
        if seqlen <= self._seq_len_cached:
            return
        capacity = seqlen * 2
        self._seq_len_cached = capacity
        # inv_freq is rebuilt in float32 on the buffer's current device.
        # NOTE(review): presumably this guards against the buffer having been
        # cast to a lower precision by a model-wide dtype change — confirm.
        device = self.inv_freq.device
        exponents = torch.arange(
            0, self.dim, 2, dtype=torch.float, device=device) / self.dim
        self.inv_freq = 1.0 / (self.theta**exponents)
        positions = torch.arange(capacity,
                                 device=self.inv_freq.device,
                                 dtype=self.inv_freq.dtype)
        self._freqs_cached = torch.outer(positions, self.inv_freq)

    def forward(self, seqlen: int) -> torch.Tensor:
        """Return the angle table for positions ``0 .. seqlen - 1``."""
        self.update_freqs_cache(seqlen)
        table = self._freqs_cached
        return table[:seqlen]


class Qwen2VisionTransformer(nn.Module):
    """Qwen2-VL vision encoder.

    Pipeline:
        1. Embed flattened pixel patches with ``patch_embed``.
        2. Run ``depth`` vision blocks. Attention is restricted to segments
           of ``h * w`` patches, one per temporal frame of each grid.
        3. Merge groups of ``spatial_merge_size ** 2`` patch embeddings into
           ``vision_config.hidden_size``-wide tokens with ``merger``.
    """

    def __init__(
        self,
        vision_config: Qwen2VLVisionConfig,
        norm_eps: float = 1e-6,
        quant_config: Optional[QuantizationConfig] = None,
        prefix: str = "",
        use_data_parallel: bool = False,
    ) -> None:
        super().__init__()

        patch_size = vision_config.patch_size
        temporal_patch_size = vision_config.temporal_patch_size
        spatial_merge_size = vision_config.spatial_merge_size
        in_channels = vision_config.in_channels
        hidden_size = vision_config.hidden_size
        embed_dim = vision_config.embed_dim
        depth = vision_config.depth
        num_heads = vision_config.num_heads
        mlp_ratio = vision_config.mlp_ratio

        self.use_data_parallel = use_data_parallel
        self.out_hidden_size = vision_config.hidden_size

        self.spatial_merge_size = spatial_merge_size
        self.num_heads = num_heads
        self.embed_dim = embed_dim

        self.patch_embed = Qwen2VisionPatchEmbed(
            patch_size=patch_size,
            temporal_patch_size=temporal_patch_size,
            in_channels=in_channels,
            embed_dim=embed_dim,
        )

        norm_layer = partial(nn.LayerNorm, eps=norm_eps)
        head_dim = embed_dim // num_heads
        # Half of head_dim per spatial axis (row and column positions).
        self.rotary_pos_emb = Qwen2VisionRotaryEmbedding(head_dim // 2)

        self.blocks = nn.ModuleList([
            Qwen2VisionBlock(dim=embed_dim,
                             num_heads=num_heads,
                             mlp_ratio=mlp_ratio,
                             norm_layer=norm_layer,
                             quant_config=quant_config,
                             prefix=f"{prefix}.blocks.{layer_idx}",
                             use_data_parallel=use_data_parallel)
            for layer_idx in range(depth)
        ])
        self.merger = Qwen2VisionPatchMerger(
            d_model=hidden_size,
            context_dim=embed_dim,
            norm_layer=norm_layer,
            quant_config=quant_config,
            prefix=f"{prefix}.merger",
            use_data_parallel=use_data_parallel,
        )
        # Same backend selection rule as the attention layers, so the
        # metadata prepared in forward() matches what they consume.
        self.attn_backend = get_vit_attn_backend(
            head_size=head_dim, dtype=torch.get_default_dtype())
        if self.attn_backend != _Backend.FLASH_ATTN and \
            check_upstream_fa_availability(
                torch.get_default_dtype()):
            self.attn_backend = _Backend.FLASH_ATTN

    @property
    def dtype(self) -> torch.dtype:
        return self.patch_embed.proj.weight.dtype

    @property
    def device(self) -> torch.device:
        return self.patch_embed.proj.weight.device

    def rot_pos_emb(self, grid_thw: list[list[int]]) -> torch.Tensor:
        """Build per-patch rotary angles for each ``(t, h, w)`` grid.

        Row/column ids are permuted so that each
        ``spatial_merge_size x spatial_merge_size`` window is contiguous
        (presumably to match the merger's grouping). The ids are repeated for
        every frame, and the row and column angles are concatenated along
        the last dim.
        """
        pos_ids = []
        max_grid_size = 0
        for t, h, w in grid_thw:
            hpos_ids = torch.arange(h).unsqueeze(1).expand(-1, w)
            wpos_ids = torch.arange(w).unsqueeze(0).expand(h, -1)
            hpos_ids = hpos_ids.reshape(
                h // self.spatial_merge_size,
                self.spatial_merge_size,
                w // self.spatial_merge_size,
                self.spatial_merge_size,
            ).permute(0, 2, 1, 3).flatten()
            wpos_ids = wpos_ids.reshape(
                h // self.spatial_merge_size,
                self.spatial_merge_size,
                w // self.spatial_merge_size,
                self.spatial_merge_size,
            ).permute(0, 2, 1, 3).flatten()
            pos_ids.append(
                torch.stack([hpos_ids, wpos_ids], dim=-1).repeat(t, 1))
            max_grid_size = max(max_grid_size, h, w)
        pos_ids = torch.cat(pos_ids, dim=0)
        rotary_pos_emb_full = self.rotary_pos_emb(max_grid_size)
        rotary_pos_emb = rotary_pos_emb_full[pos_ids].flatten(1)
        return rotary_pos_emb

    def compute_attn_mask_seqlen(
            self, cu_seqlens: torch.Tensor
    ) -> tuple[Optional[int], Optional[list[int]]]:
        """Return ``(max_seqlen, seqlens)`` for the active backend.

        Flash-attention backends need only ``max_seqlen``; xFormers needs
        the per-segment lengths. The unused entry is ``None``.
        """
        max_seqlen, seqlens = None, None
        if (self.attn_backend == _Backend.FLASH_ATTN
                or self.attn_backend == _Backend.ROCM_AITER_FA):
            max_seqlen = (cu_seqlens[1:] - cu_seqlens[:-1]).max().item()
        elif self.attn_backend == _Backend.XFORMERS:
            seqlens = (cu_seqlens[1:] - cu_seqlens[:-1]).tolist()
        return max_seqlen, seqlens

    def forward(
        self,
        x: torch.Tensor,
        grid_thw: list[list[int]],
    ) -> torch.Tensor:
        """Encode flattened pixel patches ``x`` described by ``grid_thw``."""
        # patchify
        x = x.to(device=self.device, dtype=self.dtype)
        x = self.patch_embed(x)

        # compute position embedding
        rotary_pos_emb = self.rot_pos_emb(grid_thw)

        # compute cu_seqlens: one segment of h * w patches per frame,
        # prefixed with 0. Built on the CPU from the python-side grid sizes.
        grid_thw_ = torch.tensor(grid_thw)
        cu_seqlens = torch.repeat_interleave(grid_thw_[:, 1] * grid_thw_[:, 2],
                                             grid_thw_[:, 0]).cumsum(
                                                 dim=0, dtype=torch.int32)
        cu_seqlens = F.pad(cu_seqlens, (1, 0), "constant", 0)

        # pre-compute seqlens for attn mask to reduce cuMemcpy operations;
        # doing it while cu_seqlens is still on the CPU avoids a device sync.
        max_seqlen, seqlens = self.compute_attn_mask_seqlen(cu_seqlens)

        # Flash-attention varlen kernels require cu_seqlens on the same
        # device as q/k/v. SDPA only uses them as slice bounds and xFormers
        # uses the precomputed ``seqlens``, so those keep the CPU copy.
        if self.attn_backend in (_Backend.FLASH_ATTN, _Backend.ROCM_AITER_FA):
            cu_seqlens = cu_seqlens.to(self.device, non_blocking=True)

        # transformers: add a batch dim of 1 -> [s, 1, embed_dim]
        x = x.unsqueeze(1)
        for blk in self.blocks:
            x = blk(
                x,
                cu_seqlens=cu_seqlens,
                rotary_pos_emb=rotary_pos_emb,
                max_seqlen=max_seqlen,
                seqlens=seqlens,
            )

        # adapter
        x = self.merger(x)
        return x

    def load_weights(self, weights: Iterable[tuple[str,
                                                   torch.Tensor]]) -> set[str]:
        """Load vision-tower weights and return the names that were loaded.

        Names containing ``q_proj``/``k_proj``/``v_proj`` are redirected to a
        fused ``qkv_proj`` parameter with the matching shard id. Everything
        else goes through the parameter's ``weight_loader`` (or
        ``default_weight_loader``).

        NOTE(review): the attention layers here name their fused projection
        ``qkv``, not ``qkv_proj``, so the stacked mapping looks inert for
        this module — confirm against the checkpoint naming.
        """
        stacked_params_mapping = [
            # (param_name, shard_name, shard_id)
            ("qkv_proj", "q_proj", "q"),
            ("qkv_proj", "k_proj", "k"),
            ("qkv_proj", "v_proj", "v"),
        ]
        params_dict = dict(self.named_parameters(remove_duplicate=False))
        loaded_params: set[str] = set()

        for name, loaded_weight in weights:
            for (param_name, weight_name, shard_id) in stacked_params_mapping:
                if weight_name not in name:
                    continue
                name = name.replace(weight_name, param_name)

                param = params_dict[name]
                weight_loader = param.weight_loader
                weight_loader(param, loaded_weight, shard_id)
                break
            else:
                param = params_dict[name]
                weight_loader = getattr(param, "weight_loader",
                                        default_weight_loader)
                weight_loader(param, loaded_weight)
            loaded_params.add(name)
        return loaded_params

784

785
786
787
788
789
790
791
792
793
794
795
796
797
798
799
800
801
802
803
804
805
806
807
808
809
810
811
812
813
814
815
816
def _create_qwen2vl_field_factory(
    spatial_merge_size: int,
) -> Callable[[Mapping[str, torch.Tensor]],
              Mapping[str, MultiModalFieldConfig]]:
    """Return a callable mapping HF processor outputs to field configs.

    The returned function reads ``image_grid_thw`` / ``video_grid_thw``
    (treated as empty ``(0, 3)`` tensors when missing) and describes how the
    flat ``pixel_values*`` and ``*_embeds`` tensors split into per-item
    chunks. Pixel fields are sized by ``t * h * w`` patches per item.
    Embedding fields are sized by that count divided by
    ``spatial_merge_size`` twice.

    Args:
        spatial_merge_size: Side length of the patch-merging window.
    """

    def _to_embed_sizes(patch_counts: torch.Tensor) -> torch.Tensor:
        # One embedding per spatial_merge_size x spatial_merge_size patches.
        return patch_counts // spatial_merge_size // spatial_merge_size

    def _qwen2vl_field_config(hf_inputs: Mapping[str, torch.Tensor]):
        image_patch_counts = hf_inputs.get("image_grid_thw",
                                           torch.empty((0, 3))).prod(-1)
        video_patch_counts = hf_inputs.get("video_grid_thw",
                                           torch.empty((0, 3))).prod(-1)

        return {
            "pixel_values":
            MultiModalFieldConfig.flat_from_sizes("image",
                                                  image_patch_counts),
            "image_embeds":
            MultiModalFieldConfig.flat_from_sizes(
                "image", _to_embed_sizes(image_patch_counts)),
            "image_grid_thw":
            MultiModalFieldConfig.batched("image"),
            "pixel_values_videos":
            MultiModalFieldConfig.flat_from_sizes("video",
                                                  video_patch_counts),
            "video_embeds":
            MultiModalFieldConfig.flat_from_sizes(
                "video", _to_embed_sizes(video_patch_counts)),
            "video_grid_thw":
            MultiModalFieldConfig.batched("video"),
        }

    return _qwen2vl_field_config
817

818

Roger Wang's avatar
Roger Wang committed
819
class Qwen2VLMultiModalDataParser(MultiModalDataParser):
    """Multimodal data parser that also accepts precomputed embeddings.

    Image or video data given as a ``dict`` of tensors is wrapped in
    ``DictEmbeddingItems``. Such a dict must provide ``*_embeds`` and
    ``*_grid_thw``. Anything else is delegated to the base parser.
    """

    def __init__(self, spatial_merge_size: int, *args, **kwargs):
        # Used to build the field factory for dict (embedding) inputs.
        self._spatial_merge_size = spatial_merge_size
        super().__init__(*args, **kwargs)

    def _parse_image_data(
        self,
        data: Union[dict[str, torch.Tensor], ModalityData[ImageItem]],
    ) -> Optional[ModalityDataItems[Any, Any]]:
        if not isinstance(data, dict):
            return super()._parse_image_data(data)

        field_factory = _create_qwen2vl_field_factory(
            self._spatial_merge_size)
        return DictEmbeddingItems(
            data,
            modality="image",
            required_fields={"image_embeds", "image_grid_thw"},
            fields_factory=field_factory,
        )

    def _parse_video_data(
        self,
        data: Union[dict[str, torch.Tensor], ModalityData[VideoItem]],
    ) -> Optional[ModalityDataItems[Any, Any]]:
        if not isinstance(data, dict):
            return super()._parse_video_data(data)

        field_factory = _create_qwen2vl_field_factory(
            self._spatial_merge_size)
        return DictEmbeddingItems(
            data,
            modality="video",
            required_fields={"video_embeds", "video_grid_thw"},
            fields_factory=field_factory,
        )


856
class Qwen2VLProcessingInfo(BaseProcessingInfo):
    """Config/processor access and vision-token budgeting for Qwen2-VL."""

    def get_hf_config(self):
        return self.ctx.get_hf_config(Qwen2VLConfig)

    def get_hf_processor(self, **kwargs: object) -> Qwen2VLProcessor:
        """Return the HF processor, requesting ``use_fast=True`` by default."""
        return self.ctx.get_hf_processor(
            Qwen2VLProcessor,
            use_fast=kwargs.pop("use_fast", True),
            **kwargs,
        )

    def get_image_processor(self, **kwargs: object) -> Qwen2VLImageProcessor:
        return self.get_hf_processor(**kwargs).image_processor

    def get_supported_mm_limits(self) -> Mapping[str, Optional[int]]:
        return {"image": None, "video": None}

    def get_mm_max_tokens_per_item(
        self,
        seq_len: int,
        mm_counts: Mapping[str, int],
    ) -> Mapping[str, int]:
        max_image_tokens = self.get_max_image_tokens()
        max_video_tokens = self.get_max_video_tokens(seq_len, mm_counts)
        return {"image": max_image_tokens, "video": max_video_tokens}

    def _get_vision_info(
        self,
        *,
        image_width: int,
        image_height: int,
        num_frames: int = 1,
        do_resize: bool = True,
        image_processor: Optional[Qwen2VLImageProcessor],
    ) -> tuple[ImageSize, int]:
        """Compute the preprocessed size and vision-token count of an input.

        Args:
            image_width: Original frame width in pixels.
            image_height: Original frame height in pixels.
            num_frames: Number of frames (1 for a still image).
            do_resize: Whether to apply ``smart_resize`` before patching.
            image_processor: Processor supplying ``min_pixels`` /
                ``max_pixels``; fetched via ``get_image_processor`` if None.

        Returns:
            ``(preprocessed_size, num_vision_tokens)``, where the token count
            is ``grid_t * grid_h * grid_w // spatial_merge_size**2``.
        """
        if image_processor is None:
            image_processor = self.get_image_processor()

        hf_config = self.get_hf_config()
        vision_config = hf_config.vision_config
        patch_size = vision_config.patch_size
        merge_size = vision_config.spatial_merge_size
        temporal_patch_size = vision_config.temporal_patch_size

        if do_resize:
            resized_height, resized_width = smart_resize(
                height=image_height,
                width=image_width,
                factor=patch_size * merge_size,
                min_pixels=image_processor.min_pixels,
                max_pixels=image_processor.max_pixels,
            )
            preprocessed_size = ImageSize(width=resized_width,
                                          height=resized_height)
        else:
            preprocessed_size = ImageSize(width=image_width,
                                          height=image_height)

        # NOTE: Frames are padded to be divisible by `temporal_patch_size`
        # https://github.com/huggingface/transformers/blob/v4.48.3/src/transformers/models/qwen2_vl/image_processing_qwen2_vl.py#L294
        # Round up to the next multiple of `temporal_patch_size`. For the
        # default temporal_patch_size == 2 this equals the previous formula
        # `num_frames + num_frames % temporal_patch_size`, which did not
        # produce a multiple for other patch sizes.
        padded_num_frames = (-(-num_frames // temporal_patch_size) *
                             temporal_patch_size)

        grid_t = max(padded_num_frames // temporal_patch_size, 1)
        grid_h = preprocessed_size.height // patch_size
        grid_w = preprocessed_size.width // patch_size

        num_patches = grid_t * grid_h * grid_w
        num_vision_tokens = num_patches // (merge_size**2)

        return preprocessed_size, num_vision_tokens

    def get_num_image_tokens(
        self,
        *,
        image_width: int,
        image_height: int,
        image_processor: Optional[Qwen2VLImageProcessor],
    ) -> int:
        _, num_image_tokens = self._get_vision_info(
            image_width=image_width,
            image_height=image_height,
            image_processor=image_processor,
        )
        return num_image_tokens

    def get_num_video_tokens(
        self,
        *,
        image_width: int,
        image_height: int,
        num_frames: int,
        image_processor: Optional[Qwen2VLImageProcessor],
    ) -> int:
        _, num_video_tokens = self._get_vision_info(
            image_width=image_width,
            image_height=image_height,
            num_frames=num_frames,
            image_processor=image_processor,
        )
        return num_video_tokens

    def get_image_size_with_most_features(self) -> ImageSize:
        # An oversized input is resized down to the processor's max_pixels
        # bound, giving the largest preprocessed size.
        max_image_size, _ = self._get_vision_info(
            image_width=9999999,
            image_height=9999999,
            image_processor=None,
        )
        return max_image_size

    def get_max_image_tokens(self) -> int:
        target_width, target_height = self.get_image_size_with_most_features()

        return self.get_num_image_tokens(
            image_width=target_width,
            image_height=target_height,
            image_processor=None,
        )

    def _get_max_video_frames(self, max_tokens: int) -> int:
        """Return the largest frame count whose token cost fits max_tokens.

        Returns 0 if even a single frame exceeds ``max_tokens``.
        """
        target_width, target_height = self.get_image_size_with_most_features()
        # Fetch the image processor once, instead of once per candidate
        # frame count inside the loop below.
        image_processor = self.get_image_processor()

        num_frames = 0

        while True:
            next_num_frames = num_frames + 1
            next_max_tokens = self.get_num_video_tokens(
                image_width=target_width,
                image_height=target_height,
                num_frames=next_num_frames,
                image_processor=image_processor,
            )

            if next_max_tokens > max_tokens:
                break

            num_frames = next_num_frames

        return num_frames

    def get_num_frames_with_most_features(
        self,
        seq_len: int,
        mm_counts: Mapping[str, int],
    ) -> int:
        """Split the frame budget for ``seq_len`` evenly across videos.

        The result is capped at ``_MAX_FRAMES_PER_VIDEO`` and is at least 1.
        """
        max_videos = mm_counts.get("video", 0)

        max_total_frames = self._get_max_video_frames(seq_len)
        max_frames_per_video = min(max_total_frames // max(max_videos, 1),
                                   _MAX_FRAMES_PER_VIDEO)

        return max(max_frames_per_video, 1)

    def get_max_video_tokens(
        self,
        seq_len: int,
        mm_counts: Mapping[str, int],
    ) -> int:
        target_width, target_height = self.get_image_size_with_most_features()

        return self.get_num_video_tokens(
            image_width=target_width,
            image_height=target_height,
            num_frames=self.get_num_frames_with_most_features(
                seq_len, mm_counts),
            image_processor=None,
        )

1024
1025
1026

class Qwen2VLDummyInputsBuilder(BaseDummyInputsBuilder[Qwen2VLProcessingInfo]):
    """Builds dummy prompt text and media at the largest supported sizes."""

    def get_dummy_text(self, mm_counts: Mapping[str, int]) -> str:
        """Return one image token per image, then one video token per video."""
        processor = self.info.get_hf_processor()
        image_token: str = processor.image_token
        video_token: str = processor.video_token

        return (image_token * mm_counts.get("image", 0) +
                video_token * mm_counts.get("video", 0))

    def get_dummy_mm_data(
        self,
        seq_len: int,
        mm_counts: Mapping[str, int],
    ) -> MultiModalDataDict:
        """Create dummy images/videos at the feature-maximizing resolution."""
        image_count = mm_counts.get("image", 0)
        video_count = mm_counts.get("video", 0)

        width, height = self.info.get_image_size_with_most_features()
        frame_count = self.info.get_num_frames_with_most_features(
            seq_len, mm_counts)

        dummy_images = self._get_dummy_images(width=width,
                                              height=height,
                                              num_images=image_count)
        dummy_videos = self._get_dummy_videos(
            width=width,
            height=height,
            num_frames=frame_count,
            num_videos=video_count,
        )
        return {"image": dummy_images, "video": dummy_videos}

1064

1065
1066
class Qwen2VLMultiModalProcessor(
        BaseMultiModalProcessor[Qwen2VLProcessingInfo]):
    """Multimodal processor that expands Qwen2-VL placeholder tokens."""

    def _get_data_parser(self) -> MultiModalDataParser:
        vision_config = self.info.get_hf_config().vision_config
        return Qwen2VLMultiModalDataParser(vision_config.spatial_merge_size)

    def _get_prompt_updates(
        self,
        mm_items: MultiModalDataItems,
        hf_processor_mm_kwargs: Mapping[str, Any],
        out_mm_kwargs: MultiModalKwargsItems,
    ) -> Sequence[PromptUpdate]:
        """Expand each image/video placeholder token into N copies.

        N is ``prod(grid_thw) // merge_size**2`` for the corresponding
        item, read from ``out_mm_kwargs``.
        """
        hf_processor = self.info.get_hf_processor(**hf_processor_mm_kwargs)
        image_processor = self.info.get_image_processor(
            **hf_processor_mm_kwargs)
        vocab = self.info.get_tokenizer().get_vocab()

        placeholder_ids = {
            "image": vocab[hf_processor.image_token],
            "video": vocab[hf_processor.video_token],
        }
        patches_per_token = image_processor.merge_size**2

        def _replacement_for(item_idx: int, modality: str):
            item = out_mm_kwargs[modality][item_idx]
            grid_thw = item[f"{modality}_grid_thw"].data
            assert isinstance(grid_thw, torch.Tensor)

            token_count = int(grid_thw.prod()) // patches_per_token
            return [placeholder_ids[modality]] * token_count

        return [
            PromptReplacement(
                modality=modality,
                target=[placeholder_ids[modality]],
                replacement=partial(_replacement_for, modality=modality),
            ) for modality in ("image", "video")
        ]

    def _get_mm_fields_config(
        self,
        hf_inputs: BatchFeature,
        hf_processor_mm_kwargs: Mapping[str, object],
    ) -> Mapping[str, MultiModalFieldConfig]:
        merge_size = self.info.get_hf_config().vision_config.spatial_merge_size
        build_field_config = _create_qwen2vl_field_factory(merge_size)
        return build_field_config(hf_inputs)
1116

1117

1118
1119
1120
@MULTIMODAL_REGISTRY.register_processor(Qwen2VLMultiModalProcessor,
                                        info=Qwen2VLProcessingInfo,
                                        dummy_inputs=Qwen2VLDummyInputsBuilder)
class Qwen2VLForConditionalGeneration(nn.Module, SupportsMultiModal,
                                      SupportsLoRA, SupportsPP, SupportsMRoPE):
    """Qwen2-VL: a Qwen2 causal LM with an optional vision transformer.

    Image/video inputs are encoded by ``Qwen2VisionTransformer`` and merged
    into the language model's input embeddings at the image/video
    placeholder token ids from the HF config.
    """

    # To ensure correct weight loading and mapping.
    hf_to_vllm_mapper = WeightsMapper(
        orig_to_new_prefix={
            # mapping for new names in checkpoint saved after transformers v4.52
            "model.language_model.": "language_model.model.",
            "model.visual.": "visual.",
            # mapping for original checkpoint
            "lm_head.": "language_model.lm_head.",
            "model.": "language_model.model.",
        })

    supports_encoder_tp_data = True

    def get_mrope_input_positions(
        self,
        input_tokens: list[int],
        hf_config: PretrainedConfig,
        image_grid_thw: Optional[Union[list[list[int]], torch.Tensor]],
        video_grid_thw: Optional[Union[list[list[int]], torch.Tensor]],
        second_per_grid_ts: Optional[list[float]] = None,
        context_len: int = 0,
        seq_len: Optional[int] = None,
        audio_feature_lengths: Optional[torch.Tensor] = None,
        use_audio_in_video: bool = False,
    ) -> tuple[torch.Tensor, int]:
        """Get M-RoPE input positions for Qwen2-VL model.

        Text tokens get the same position on all three (t, h, w) axes. Each
        image/video span gets a 3-D grid of positions (h and w divided by
        ``spatial_merge_size``), offset past the preceding text. For
        videos, the temporal index is scaled by the per-item
        ``second_per_grid_ts`` (default 1.0) and by ``tokens_per_second``.
        Images use a temporal scale of 0.

        ``audio_feature_lengths`` and ``use_audio_in_video`` are not used
        by this model.

        Returns:
            ``(positions, delta)``. ``positions`` is a ``(3, n)`` tensor
            sliced to ``[context_len:seq_len]``, and ``delta`` is
            ``max position + 1 - len(input_tokens)``.
        """
        if image_grid_thw is None:
            image_grid_thw = []
        if video_grid_thw is None:
            video_grid_thw = []
        if second_per_grid_ts is None:
            second_per_grid_ts = []

        image_token_id = hf_config.image_token_id
        video_token_id = hf_config.video_token_id
        vision_start_token_id = hf_config.vision_start_token_id
        spatial_merge_size = hf_config.vision_config.spatial_merge_size
        tokens_per_second = getattr(hf_config.vision_config,
                                    "tokens_per_second", 1.0)

        input_tokens_tensor = torch.tensor(input_tokens)
        vision_start_indices = torch.argwhere(
            input_tokens_tensor == vision_start_token_id).squeeze(1)
        # A vision-start token in the last position has no following token
        # to inspect (e.g. a truncated prompt); drop it instead of indexing
        # past the end of the tensor.
        vision_start_indices = vision_start_indices[
            vision_start_indices + 1 < len(input_tokens)]
        vision_tokens = input_tokens_tensor[vision_start_indices + 1]
        image_nums = (vision_tokens == image_token_id).sum()
        video_nums = (vision_tokens == video_token_id).sum()
        llm_pos_ids_list: list = []

        st = 0
        remain_images, remain_videos = image_nums, video_nums

        image_index, video_index = 0, 0
        for _ in range(image_nums + video_nums):
            video_second_per_grid_t = 0.0
            if remain_images > 0:
                try:
                    ed_image = input_tokens.index(image_token_id, st)
                except ValueError:
                    ed_image = len(input_tokens) + 1
            else:
                ed_image = len(input_tokens) + 1
            if remain_videos > 0:
                try:
                    ed_video = input_tokens.index(video_token_id, st)
                except ValueError:
                    ed_video = len(input_tokens) + 1
            else:
                ed_video = len(input_tokens) + 1
            if ed_image < ed_video:
                t, h, w = (
                    image_grid_thw[image_index][0],
                    image_grid_thw[image_index][1],
                    image_grid_thw[image_index][2],
                )
                image_index += 1
                remain_images -= 1
                ed = ed_image
            else:
                t, h, w = (
                    video_grid_thw[video_index][0],
                    video_grid_thw[video_index][1],
                    video_grid_thw[video_index][2],
                )
                video_second_per_grid_t = 1.0
                if second_per_grid_ts:
                    video_second_per_grid_t = second_per_grid_ts[video_index]
                video_index += 1
                remain_videos -= 1
                ed = ed_video

            llm_grid_t, llm_grid_h, llm_grid_w = \
                t, h // spatial_merge_size, w // spatial_merge_size
            text_len = ed - st

            # Text preceding this vision span continues from the last
            # position used so far.
            st_idx = llm_pos_ids_list[-1].max() + 1 if len(
                llm_pos_ids_list) > 0 else 0
            llm_pos_ids_list.append(
                torch.arange(text_len).view(1, -1).expand(3, -1) + st_idx)

            t_index = (torch.arange(llm_grid_t).view(-1, 1).expand(
                -1, llm_grid_h * llm_grid_w) * video_second_per_grid_t *
                       tokens_per_second).long().flatten()

            h_index = torch.arange(llm_grid_h).view(1, -1, 1).expand(
                llm_grid_t, -1, llm_grid_w).flatten()
            w_index = torch.arange(llm_grid_w).view(1, 1, -1).expand(
                llm_grid_t, llm_grid_h, -1).flatten()
            llm_pos_ids_list.append(
                torch.stack([t_index, h_index, w_index]) + text_len + st_idx)
            st = ed + llm_grid_t * llm_grid_h * llm_grid_w

        # Trailing text after the last vision span.
        if st < len(input_tokens):
            st_idx = llm_pos_ids_list[-1].max() + 1 if len(
                llm_pos_ids_list) > 0 else 0
            text_len = len(input_tokens) - st
            llm_pos_ids_list.append(
                torch.arange(text_len).view(1, -1).expand(3, -1) + st_idx)

        llm_positions = torch.cat(llm_pos_ids_list, dim=1).reshape(3, -1)
        mrope_position_delta = (llm_positions.max() + 1 -
                                len(input_tokens)).item()
        llm_positions = llm_positions[:, context_len:seq_len]

        return llm_positions, mrope_position_delta

    @classmethod
    def get_placeholder_str(cls, modality: str, i: int) -> Optional[str]:
        if modality.startswith("image"):
            return "<|vision_start|><|image_pad|><|vision_end|>"
        if modality.startswith("video"):
            return "<|vision_start|><|video_pad|><|vision_end|>"

        raise ValueError("Only image or video modality is supported")

    def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
        super().__init__()
        config: Qwen2VLConfig = vllm_config.model_config.hf_config
        quant_config = vllm_config.quant_config
        multimodal_config = vllm_config.model_config.multimodal_config

        self.use_data_parallel = multimodal_config.mm_encoder_tp_mode == "data"
        self.config = config
        self.multimodal_config = multimodal_config

        # The vision tower is only built when images or videos are allowed
        # by the per-prompt limits.
        if multimodal_config.get_limit_per_prompt("image") or \
            multimodal_config.get_limit_per_prompt("video"):
            self.visual = Qwen2VisionTransformer(
                config.vision_config,
                norm_eps=getattr(config, "rms_norm_eps", 1e-6),
                quant_config=self._maybe_ignore_quant_config(quant_config),
                prefix=maybe_prefix(prefix, "visual"),
                use_data_parallel=self.use_data_parallel,
            )
        else:
            self.visual = None

        self.language_model = init_vllm_registered_model(
            vllm_config=vllm_config,
            prefix=maybe_prefix(prefix, "language_model"),
            architectures=["Qwen2ForCausalLM"],
        )

        self.make_empty_intermediate_tensors = (
            self.language_model.make_empty_intermediate_tensors)

    def _maybe_ignore_quant_config(self, quant_config: QuantizationConfig):
        # GPTQ configs do not have a list of ignored modules, however AutoGPTQ
        # seems to avoid vision encoder sections for some models.
        # See: https://huggingface.co/Qwen/Qwen2-VL-2B-Instruct-GPTQ-Int4
        if isinstance(quant_config, (GPTQConfig, GPTQMarlinConfig)):
            return None
        return quant_config

    def _validate_and_reshape_mm_tensor(self, mm_input: object,
                                        name: str) -> torch.Tensor:
        """Return ``mm_input`` as a single 2D tensor.

        2D tensors pass through unchanged and 3D tensors are flattened over
        their leading dims. Lists are concatenated along dim 0.

        Raises:
            ValueError: If the input is not a tensor/list, or is a tensor
                that is neither 2D nor 3D.
        """
        if not isinstance(mm_input, (torch.Tensor, list)):
            raise ValueError(f"Incorrect type of {name}. "
                             f"Got type: {type(mm_input)}")
        if isinstance(mm_input, torch.Tensor):
            if mm_input.ndim == 2:
                return mm_input
            if mm_input.ndim != 3:
                raise ValueError(f"{name} should be 2D or batched 3D tensor. "
                                 f"Got ndim: {mm_input.ndim} "
                                 f"(shape={mm_input.shape})")
            return mm_input.reshape(-1, mm_input.shape[-1])
        else:
            return torch.concat(mm_input)

    def _parse_and_validate_image_input(
            self, **kwargs: object) -> Optional[Qwen2VLImageInputs]:
        pixel_values = kwargs.pop("pixel_values", None)
        image_embeds = kwargs.pop("image_embeds", None)
        image_grid_thw = kwargs.pop("image_grid_thw", None)

        if pixel_values is None and image_embeds is None:
            return None

        if pixel_values is not None:
            pixel_values = self._validate_and_reshape_mm_tensor(
                pixel_values, "image pixel values")
            image_grid_thw = self._validate_and_reshape_mm_tensor(
                image_grid_thw, "image grid_thw")

            return Qwen2VLImagePixelInputs(type="pixel_values",
                                           pixel_values=pixel_values,
                                           image_grid_thw=image_grid_thw)

        if image_embeds is not None:
            image_embeds = self._validate_and_reshape_mm_tensor(
                image_embeds, "image embeds")
            image_grid_thw = self._validate_and_reshape_mm_tensor(
                image_grid_thw, "image grid_thw")

            return Qwen2VLImageEmbeddingInputs(type="image_embeds",
                                               image_embeds=image_embeds,
                                               image_grid_thw=image_grid_thw)

    def _parse_and_validate_video_input(
            self, **kwargs: object) -> Optional[Qwen2VLVideoInputs]:
        pixel_values_videos = kwargs.pop("pixel_values_videos", None)
        video_embeds = kwargs.pop("video_embeds", None)
        video_grid_thw = kwargs.pop("video_grid_thw", None)

        if pixel_values_videos is None and video_embeds is None:
            return None

        if pixel_values_videos is not None:
            pixel_values_videos = self._validate_and_reshape_mm_tensor(
                pixel_values_videos, "video pixel values")
            video_grid_thw = self._validate_and_reshape_mm_tensor(
                video_grid_thw, "video grid_thw")

            return Qwen2VLVideoPixelInputs(
                type="pixel_values_videos",
                pixel_values_videos=pixel_values_videos,
                video_grid_thw=video_grid_thw,
            )

        if video_embeds is not None:
            video_embeds = self._validate_and_reshape_mm_tensor(
                video_embeds, "video embeds")
            video_grid_thw = self._validate_and_reshape_mm_tensor(
                video_grid_thw, "video grid_thw")

            return Qwen2VLVideoEmbeddingInputs(type="video_embeds",
                                               video_embeds=video_embeds,
                                               video_grid_thw=video_grid_thw)

    def _process_image_input(
            self, image_input: Qwen2VLImageInputs) -> tuple[torch.Tensor, ...]:
        """Encode (or pass through) image inputs, split per image item."""

        grid_thw = image_input["image_grid_thw"]
        assert grid_thw.ndim == 2
        grid_thw_list = grid_thw.tolist()

        if image_input["type"] == "image_embeds":
            image_embeds = image_input["image_embeds"]
        else:
            pixel_values = image_input["pixel_values"]

            if self.use_data_parallel:
                return run_dp_sharded_mrope_vision_model(self.visual,
                                                         pixel_values,
                                                         grid_thw_list,
                                                         rope_type="rope_3d")
            else:
                image_embeds = self.visual(pixel_values,
                                           grid_thw=grid_thw_list)

        # Split concatenated embeddings for each image item.
        merge_size = self.visual.spatial_merge_size
        sizes = (torch.tensor(grid_thw_list, dtype=torch.long).prod(-1) //
                 (merge_size * merge_size)).tolist()

        return image_embeds.split(sizes)

    def _process_video_input(
            self, video_input: Qwen2VLVideoInputs) -> tuple[torch.Tensor, ...]:
        """Encode (or pass through) video inputs, split per video item."""

        grid_thw = video_input["video_grid_thw"]
        assert grid_thw.ndim == 2
        grid_thw_list = grid_thw.tolist()

        if video_input["type"] == "video_embeds":
            video_embeds = video_input["video_embeds"]
        else:
            pixel_values_videos = video_input["pixel_values_videos"]
            if self.use_data_parallel:
                return run_dp_sharded_mrope_vision_model(self.visual,
                                                         pixel_values_videos,
                                                         grid_thw_list,
                                                         rope_type="rope_3d")
            else:
                video_embeds = self.visual(pixel_values_videos,
                                           grid_thw=grid_thw_list)

        # Split concatenated embeddings for each video item.
        merge_size = self.visual.spatial_merge_size
        sizes = (torch.tensor(grid_thw_list, dtype=torch.long).prod(-1) //
                 (merge_size * merge_size)).tolist()

        return video_embeds.split(sizes)

    def _parse_and_validate_multimodal_inputs(self, **kwargs: object) -> dict:
        modalities = {}

        # Preserve the order of modalities if there are multiple of them
        # from the order of kwargs.
        for input_key in kwargs:
            if input_key in ("pixel_values",
                             "image_embeds") and "images" not in modalities:
                modalities["images"] = self._parse_and_validate_image_input(
                    **kwargs)
            if input_key in ("pixel_values_videos",
                             "video_embeds") and "videos" not in modalities:
                modalities["videos"] = self._parse_and_validate_video_input(
                    **kwargs)

        return modalities

    def get_language_model(self) -> torch.nn.Module:
        return self.language_model

    def get_multimodal_embeddings(self,
                                  **kwargs: object) -> MultiModalEmbeddings:

        modalities = self._parse_and_validate_multimodal_inputs(**kwargs)
        if not modalities:
            return []

        # The result multimodal_embeddings is tuple of tensors, with each
        # tensor corresponding to a multimodal data item (image or video).
        multimodal_embeddings: tuple[torch.Tensor, ...] = ()

        # NOTE: It is important to iterate over the keys in this dictionary
        # to preserve the order of the modalities.
        for modality in modalities:
            if modality == "images":
                image_input = modalities["images"]
                vision_embeddings = self._process_image_input(image_input)
                multimodal_embeddings += vision_embeddings
            if modality == "videos":
                video_input = modalities["videos"]
                video_embeddings = self._process_video_input(video_input)
                multimodal_embeddings += video_embeddings

        return multimodal_embeddings

    def get_input_embeddings(
        self,
        input_ids: torch.Tensor,
        multimodal_embeddings: Optional[MultiModalEmbeddings] = None,
    ) -> torch.Tensor:
        inputs_embeds = self.language_model.get_input_embeddings(input_ids)
        if multimodal_embeddings is not None \
            and len(multimodal_embeddings) != 0:
            inputs_embeds = merge_multimodal_embeddings(
                input_ids, inputs_embeds, multimodal_embeddings,
                [self.config.image_token_id, self.config.video_token_id])
        return inputs_embeds

    def get_input_embeddings_v0(
        self,
        input_ids: torch.Tensor,
        image_input: Optional[Qwen2VLImagePixelInputs] = None,
        video_input: Optional[Qwen2VLVideoPixelInputs] = None,
    ) -> torch.Tensor:
        inputs_embeds = self.get_input_embeddings(input_ids)
        if image_input is not None:
            image_embeds = self._process_image_input(image_input)
            inputs_embeds = merge_multimodal_embeddings(
                input_ids,
                inputs_embeds,
                image_embeds,
                placeholder_token_id=self.config.image_token_id,
            )

        if video_input is not None:
            video_embeds = self._process_video_input(video_input)
            inputs_embeds = merge_multimodal_embeddings(
                input_ids,
                inputs_embeds,
                video_embeds,
                placeholder_token_id=self.config.video_token_id,
            )
        return inputs_embeds

    def forward(
        self,
        input_ids: torch.Tensor,
        positions: torch.Tensor,
        intermediate_tensors: Optional[IntermediateTensors] = None,
        inputs_embeds: Optional[torch.Tensor] = None,
        **kwargs: object,
    ) -> Union[torch.Tensor, IntermediateTensors]:
        """Run forward pass for Qwen2-VL.

        Args:
            input_ids: Flattened (concatenated) input_ids corresponding to a
                batch.
            positions: Flattened (concatenated) position ids corresponding to a
                batch.
                **NOTE**: If mrope is enabled (default setting for Qwen2-VL
                opensource models), the shape will be `(3, seq_len)`,
                otherwise it will be `(seq_len,)`.
            intermediate_tensors: Intermediate tensors from prior forward pass.
            inputs_embeds: Optional tensor of input embeddings.
        """

        if intermediate_tensors is not None:
            inputs_embeds = None
        elif inputs_embeds is None:
            # NOTE: In v1, inputs_embeds is always generated at model runner
            # from `get_multimodal_embeddings` and `get_input_embeddings`,
            # this condition is only for v0 compatibility.
            image_input = self._parse_and_validate_image_input(**kwargs)
            video_input = self._parse_and_validate_video_input(**kwargs)

            if image_input is None and video_input is None:
                inputs_embeds = None
            else:
                if uses_mrope(self.config):
                    assert positions.ndim == 2 and positions.size(0) == 3, (
                        "multimodal section rotary embedding requires "
                        f"(3, seq_len) positions, but got {positions.size()}")
                inputs_embeds = self.get_input_embeddings_v0(
                    input_ids,
                    image_input=image_input,
                    video_input=video_input)
                input_ids = None

        hidden_states = self.language_model.model(
            input_ids=input_ids,
            positions=positions,
            intermediate_tensors=intermediate_tensors,
            inputs_embeds=inputs_embeds,
        )
        return hidden_states

    def compute_logits(
        self,
        hidden_states: torch.Tensor,
    ) -> Optional[torch.Tensor]:
        return self.language_model.compute_logits(hidden_states)

    def load_weights(self, weights: Iterable[tuple[str,
                                                   torch.Tensor]]) -> set[str]:
        """Load checkpoint weights, skipping ``visual.*`` when no tower."""

        skip_prefixes = []
        if self.visual is None:
            skip_prefixes.extend(["visual."])
        loader = AutoWeightsLoader(self, skip_prefixes=skip_prefixes)
        return loader.load_weights(weights, mapper=self.hf_to_vllm_mapper)

    def get_mm_mapping(self) -> MultiModelKeys:
        """
        Get the module prefix in multimodal models
        """
        return MultiModelKeys.from_string_field(
            language_model="language_model",
            connector="visual.merger.",
            tower_model="visual.",
        )
1589
1590
1591
1592
1593
1594
1595
1596
1597
1598
1599
1600
1601
1602
1603
1604
1605
1606
1607
1608
1609
1610
1611
1612
1613
1614
1615
1616
1617
1618
1619
1620
1621


class Tarsier2MultiModalProcessor(Qwen2VLMultiModalProcessor):
    """Multimodal processor for Tarsier2.

    Adds no behavior on top of ``Qwen2VLMultiModalProcessor``. It is a
    separate class so that it can be registered for
    ``Tarsier2ForConditionalGeneration``.
    """


class Tarsier2ImageProcessor(Qwen2VLImageProcessor):
    """Qwen2-VL image processor that accepts Tarsier2's ``size`` format.

    If ``size`` carries both ``min_pixels`` and ``max_pixels``, the two
    values are renamed to ``shortest_edge`` and ``longest_edge`` before
    being handed to ``Qwen2VLImageProcessor``. In every other case,
    including ``size=None``, ``size`` is passed through untouched.
    """

    def __init__(
        self,
        size: Optional[dict[str, int]] = None,
        **kwargs,
    ) -> None:
        is_tarsier2_format = (size is not None and "min_pixels" in size
                              and "max_pixels" in size)
        if is_tarsier2_format:
            # Only the two pixel bounds are kept; they are renamed.
            size = {
                "shortest_edge": size["min_pixels"],
                "longest_edge": size["max_pixels"],
            }
        super().__init__(size=size, **kwargs)


class Tarsier2Processor(Qwen2VLProcessor):
    """Qwen2-VL processor assembled from a Tarsier2 image-processor config.

    Args:
        vision_config: Keyword arguments used to build both the
            ``Tarsier2ImageProcessor`` and the ``Qwen2VLVideoProcessor``.
        tokenizer: Tokenizer forwarded to ``Qwen2VLProcessor``.
        **kwargs: Extra arguments forwarded to ``Qwen2VLProcessor.__init__``.
    """

    def __init__(
        self,
        vision_config: dict,
        tokenizer: AnyTokenizer,
        **kwargs,
    ):
        self.image_processor = Tarsier2ImageProcessor(**vision_config)
        # NOTE(review): ``chat_template`` is passed explicitly, so a
        # ``chat_template`` entry in kwargs would raise a TypeError for a
        # duplicate keyword argument.
        super().__init__(
            image_processor=self.image_processor,
            tokenizer=tokenizer,
            video_processor=Qwen2VLVideoProcessor(**vision_config),
            chat_template=None,
            **kwargs)


class Tarsier2ProcessingInfo(Qwen2VLProcessingInfo):
    """Processing info for Tarsier2, built on the Qwen2-VL processing info.

    It supplies a ``Qwen2VLConfig`` and Tarsier2-specific processors.
    """

    def get_hf_config(self) -> Qwen2VLConfig:
        """Reload the checkpoint config and rebuild it as a Qwen2VLConfig.

        NOTE(review): this calls ``AutoConfig.from_pretrained`` every time
        it is invoked.
        """
        loaded = AutoConfig.from_pretrained(self.ctx.model_config.model)
        return Qwen2VLConfig.from_dict(loaded.to_dict())

    def get_hf_processor(self, **kwargs: object) -> Tarsier2Processor:
        """Build a ``Tarsier2Processor`` from the image-processor config."""
        image_processor_config = self.ctx.get_hf_image_processor_config()
        tokenizer = self.get_tokenizer()
        return Tarsier2Processor(
            vision_config=image_processor_config,
            tokenizer=tokenizer,
            **kwargs,
        )

    def get_image_processor(self) -> Tarsier2ImageProcessor:
        """Build a ``Tarsier2ImageProcessor`` from the image-processor config."""
        processor_kwargs = self.ctx.get_hf_image_processor_config()
        return Tarsier2ImageProcessor(**processor_kwargs)


@MULTIMODAL_REGISTRY.register_processor(Tarsier2MultiModalProcessor,
                                        info=Tarsier2ProcessingInfo,
                                        dummy_inputs=Qwen2VLDummyInputsBuilder)
class Tarsier2ForConditionalGeneration(Qwen2VLForConditionalGeneration):
    """Tarsier2 model served through the Qwen2-VL implementation.

    Checkpoint weights under ``vision_tower.`` are remapped to this model's
    ``visual.`` prefix.
    """

    hf_to_vllm_mapper = WeightsMapper(orig_to_new_prefix={
        "vision_tower.": "visual.",
    })

    def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
        # Tarsier2 uses llava as model_type, so hf_config is a LlavaConfig
        # whose text_config holds the Qwen2VLConfig. Promote text_config to
        # be the model's hf_config, keeping the top-level architectures,
        # before calling the parent constructor.
        # Note: this mutates vllm_config.model_config.hf_config in place.
        config = vllm_config.model_config.hf_config
        qwen2vl_config = config.text_config
        qwen2vl_config.architectures = config.architectures
        vllm_config.model_config.hf_config = qwen2vl_config
        super().__init__(vllm_config=vllm_config, prefix=prefix)

    def load_weights(self, weights: Iterable[tuple[str,
                                                   torch.Tensor]]) -> set[str]:
        """Load checkpoint weights, remapping names via ``hf_to_vllm_mapper``.

        When no vision tower was built (``self.visual is None``), the
        ``visual.`` prefix is passed to ``AutoWeightsLoader`` as a prefix
        to skip.

        Args:
            weights: Iterable of ``(name, tensor)`` pairs.

        Returns:
            The set of names returned by ``AutoWeightsLoader.load_weights``.
        """
        skip_prefixes = ["visual."] if self.visual is None else []
        loader = AutoWeightsLoader(self, skip_prefixes=skip_prefixes)
        return loader.load_weights(weights, mapper=self.hf_to_vllm_mapper)