# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project

# Adapted from
# https://github.com/huggingface/transformers/blob/19e6e80e10118f855137b90740936c0b11ac397f/src/transformers/models/qwen2_vl/modeling_qwen2_vl.py
# Copyright 2024 The Qwen team.
# Copyright 2023 The vLLM team.
# Copyright 2022 EleutherAI and the HuggingFace Inc. team. All rights reserved.
#
# This code is based on EleutherAI's GPT-NeoX library and the GPT-NeoX
# and OPT implementations in this library. It has been modified from its
# original forms to accommodate minor architectural differences compared
# to GPT-NeoX and OPT used by the Meta AI team that trained the model.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Inference-only Qwen2-VL model compatible with HuggingFace weights."""
27
from collections.abc import Iterable, Mapping, Sequence
28
from functools import partial
29
from typing import Annotated, Any, Callable, Literal, Optional, Union
30
31
32
33
34

import torch
import torch.nn as nn
import torch.nn.functional as F
from einops import rearrange, repeat
35
from transformers import AutoConfig, BatchFeature
36
37
from transformers.models.qwen2_vl import (Qwen2VLImageProcessor,
                                          Qwen2VLProcessor)
38
39
from transformers.models.qwen2_vl.configuration_qwen2_vl import (
    Qwen2VLConfig, Qwen2VLVisionConfig)
40
from transformers.models.qwen2_vl.image_processing_qwen2_vl import smart_resize
41
42
from transformers.models.qwen2_vl.video_processing_qwen2_vl import (
    Qwen2VLVideoProcessor)
43

44
from vllm.config import VllmConfig
45
from vllm.distributed import parallel_state, tensor_model_parallel_all_gather
46
47
48
49
50
51
from vllm.distributed import utils as dist_utils
from vllm.logger import init_logger
from vllm.model_executor import SamplingMetadata
from vllm.model_executor.layers.activation import QuickGELU
from vllm.model_executor.layers.linear import (ColumnParallelLinear,
                                               RowParallelLinear)
52
53
54
55
from vllm.model_executor.layers.quantization import QuantizationConfig
from vllm.model_executor.layers.quantization.gptq import GPTQConfig
from vllm.model_executor.layers.quantization.gptq_marlin import (
    GPTQMarlinConfig)
56
from vllm.model_executor.model_loader.weight_utils import default_weight_loader
57
from vllm.model_executor.models.module_mapping import MultiModelKeys
58
from vllm.multimodal import MULTIMODAL_REGISTRY
59
from vllm.multimodal.inputs import (ImageItem, ModalityData,
60
                                    MultiModalDataDict, MultiModalFieldConfig,
61
                                    MultiModalKwargsItems, VideoItem)
62
63
64
from vllm.multimodal.parse import (DictEmbeddingItems, ImageSize,
                                   ModalityDataItems, MultiModalDataItems,
                                   MultiModalDataParser)
65
from vllm.multimodal.processing import (BaseMultiModalProcessor,
66
67
                                        BaseProcessingInfo, PromptReplacement,
                                        PromptUpdate)
68
from vllm.multimodal.profiling import BaseDummyInputsBuilder
69
from vllm.platforms import _Backend, current_platform
70
from vllm.sequence import IntermediateTensors
71
from vllm.transformers_utils.config import uses_mrope
72
from vllm.transformers_utils.tokenizer import AnyTokenizer
73
from vllm.utils.tensor_schema import TensorSchema, TensorShape
74

75
76
from .interfaces import (MultiModalEmbeddings, SupportsLoRA,
                         SupportsMultiModal, SupportsPP)
77
from .utils import (AutoWeightsLoader, WeightsMapper,
78
79
                    init_vllm_registered_model, maybe_prefix,
                    merge_multimodal_embeddings)
80
from .vision import get_vit_attn_backend
81

82
83
logger = init_logger(__name__)

84
85
86
# For profile run
_MAX_FRAMES_PER_VIDEO = 16

87
88
89
# === Vision Inputs === #


90
class Qwen2VLImagePixelInputs(TensorSchema):
    """
    Dimensions:
        - np: The total number of patches over each image over each prompt in
              the batch
        - ni: Number of images
        - cps: Number of channels * temporal_patch_size * patch_size *
               patch_size. Image patches are flattened exactly like video
               patches: the vision patch embedding reshapes every row to
               (channels, temporal_patch_size, patch_size, patch_size).

    Historical context:
        - pixel_values shape: (num_patches, num_channels *
          temporal_patch_size * patch_size * patch_size)
        - image_grid_thw shape: (num_images, 3) in (grid_t, grid_h, grid_w)
          format
    """
    type: Literal["pixel_values"]

    pixel_values: Annotated[
        torch.Tensor,
        TensorShape("np", "cps"),
    ]

    image_grid_thw: Annotated[
        torch.Tensor,
        TensorShape("ni", 3),
    ]


class Qwen2VLImageEmbeddingInputs(TensorSchema):
    """
    Dimensions:
        - nf: Number of image features (one row per merged patch, i.e. the
              sum over images of t * h * w // spatial_merge_size**2)
        - hs: Hidden size
        - ni: Number of images

    Historical context:
        - image_embeds shape: (num_image_features, hidden_size)
        - num_image_features varies based on the number and resolution of the
          images.
        - hidden_size must match the hidden size of language model backbone.
        - image_grid_thw shape: (num_images, 3) in (grid_t, grid_h, grid_w)
          format
    """
    type: Literal["image_embeds"]

    image_embeds: Annotated[
        torch.Tensor,
        TensorShape("nf", "hs"),
    ]

    image_grid_thw: Annotated[
        torch.Tensor,
        TensorShape("ni", 3),
    ]
143
144
145
146
147
148


# Image input as either raw pixel patches or precomputed embeddings.
Qwen2VLImageInputs = Union[
    Qwen2VLImagePixelInputs,
    Qwen2VLImageEmbeddingInputs,
]


149
150
151
152
153
154
155
156
157
158
159
160
161
162
class Qwen2VLVideoPixelInputs(TensorSchema):
    """
    Dimensions:
        - np: The total number of patches over each video over each prompt in
              the batch
        - ctps: Number of channels * temporal_patch_size * patch_size *
                patch_size
        - nv: Number of videos

    Historical context:
        - pixel_values_videos shape: (num_patches, num_channels *
          temporal_patch_size * patch_size * patch_size)
        - video_grid_thw shape: (num_videos, 3) in (grid_t, grid_h, grid_w)
          format
    """
    type: Literal["pixel_values_videos"]

    pixel_values_videos: Annotated[
        torch.Tensor,
        TensorShape("np", "ctps"),
    ]

    video_grid_thw: Annotated[
        torch.Tensor,
        TensorShape("nv", 3),
    ]
175
176


177
178
179
180
181
182
class Qwen2VLVideoEmbeddingInputs(TensorSchema):
    """
    Dimensions:
        - nf: Number of video features (one row per merged patch, i.e. the
              sum over videos of t * h * w // spatial_merge_size**2)
        - hs: Hidden size
        - nv: Number of videos

    Historical context:
        - video_embeds shape: (num_video_features, hidden_size)
        - num_video_features varies based on the number and resolution of the
          videos.
        - hidden_size must match the hidden size of language model backbone.
        - video_grid_thw shape: (num_videos, 3) in (grid_t, grid_h, grid_w)
          format
    """
    type: Literal["video_embeds"]

    video_embeds: Annotated[
        torch.Tensor,
        TensorShape("nf", "hs"),
    ]

    video_grid_thw: Annotated[
        torch.Tensor,
        TensorShape("nv", 3),
    ]
203
204
205
206
207


# Video input as either raw pixel patches or precomputed embeddings.
Qwen2VLVideoInputs = Union[
    Qwen2VLVideoPixelInputs,
    Qwen2VLVideoEmbeddingInputs,
]

208
209
210
211
212
213
214
215
# === Vision Encoder === #


class Qwen2VisionMLP(nn.Module):
    """Feed-forward sub-layer of a vision block: fc1 -> activation -> fc2.

    ``fc1`` is a ``ColumnParallelLinear`` and ``fc2`` a ``RowParallelLinear``;
    only the first element of each layer's returned tuple is used.
    """

    def __init__(
        self,
        in_features: int,
        hidden_features: int,
        act_layer: type[nn.Module] = QuickGELU,
        quant_config: Optional[QuantizationConfig] = None,
        prefix: str = "",
    ):
        super().__init__()
        self.fc1 = ColumnParallelLinear(in_features,
                                        hidden_features,
                                        quant_config=quant_config,
                                        prefix=f"{prefix}.fc1")
        self.act = act_layer()
        self.fc2 = RowParallelLinear(hidden_features,
                                     in_features,
                                     quant_config=quant_config,
                                     prefix=f"{prefix}.fc2")

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        x_parallel, _ = self.fc1(x)
        x_parallel = self.act(x_parallel)
        x, _ = self.fc2(x_parallel)
        return x


def rotate_half(x: torch.Tensor, interleaved: bool = False) -> torch.Tensor:
    """Rotate feature pairs of ``x`` by 90 degrees (rotary-embedding helper).

    Non-interleaved: the last dimension is split into halves ``(x1, x2)`` and
    ``(-x2, x1)`` is returned. Interleaved: pairs are the even/odd positions,
    and each ``(x_even, x_odd)`` pair becomes ``(-x_odd, x_even)`` in place.
    """
    if interleaved:
        evens = x[..., ::2]
        odds = x[..., 1::2]
        # Pair up on a new trailing axis, then fold it back into the feature
        # axis (row-major), giving the same layout as "... d two -> ... (d two)".
        return torch.stack((-odds, evens), dim=-1).flatten(-2)
    first, second = x.chunk(2, dim=-1)
    return torch.cat((-second, first), dim=-1)


def apply_rotary_emb_torch(x: torch.Tensor,
                           cos: torch.Tensor,
                           sin: torch.Tensor,
                           interleaved: bool = False) -> torch.Tensor:
    """Pure-PyTorch rotary embedding.

    x: (batch_size, seqlen, nheads, headdim)
    cos, sin: (seqlen, rotary_dim / 2) or (batch_size, seqlen, rotary_dim / 2)

    Only the first ``rotary_dim`` (= 2 * cos.shape[-1]) features of ``x`` are
    rotated; the remaining features are concatenated back unchanged.
    """
    rotary_dim = 2 * cos.shape[-1]
    assert rotary_dim <= x.shape[-1]
    # Insert a broadcast axis for heads and expand the half-width angle table
    # to rotary_dim, tiled ("(2 d)") or pairwise-repeated ("(d 2)").
    if interleaved:
        pattern = "... d -> ... 1 (d 2)"
    else:
        pattern = "... d -> ... 1 (2 d)"
    cos = repeat(cos, pattern)
    sin = repeat(sin, pattern)
    x_rot = x[..., :rotary_dim]
    x_pass = x[..., rotary_dim:]
    rotated = x_rot * cos + rotate_half(x_rot, interleaved) * sin
    return torch.cat([rotated, x_pass], dim=-1)


def apply_rotary_pos_emb_vision(t: torch.Tensor,
                                freqs: torch.Tensor) -> torch.Tensor:
    """Rotate ``t`` by the angles in ``freqs`` (vision-tower RoPE).

    The rotation is computed on a float32 copy of ``t`` and the result is cast
    back to ``t``'s dtype. On CUDA the kernel from
    ``vllm.vllm_flash_attn.layers.rotary`` is used; on other platforms the
    pure-PyTorch ``apply_rotary_emb_torch`` fallback runs.

    Args:
        t: Query or key tensor; ``Qwen2VisionAttention`` passes it as
            (batch, seq, heads, head_dim).
        freqs: Rotation angles, one row per sequence position.
    """
    t_float = t.float()
    cos = freqs.cos()
    sin = freqs.sin()
    if current_platform.is_cuda():
        from vllm.vllm_flash_attn.layers.rotary import apply_rotary_emb
    else:
        apply_rotary_emb = apply_rotary_emb_torch
    return apply_rotary_emb(t_float, cos, sin).type_as(t)


class Qwen2VisionAttention(nn.Module):
    """Multi-head self-attention for the Qwen2-VL vision tower.

    ``cu_seqlens`` holds cumulative segment boundaries along the sequence
    axis; attention is computed independently inside each
    ``[cu_seqlens[i], cu_seqlens[i + 1])`` segment. Supported backends are
    FlashAttention, ROCm AITER FlashAttention, PyTorch SDPA and xFormers.
    """

    def __init__(
        self,
        embed_dim: int,
        num_heads: int,
        projection_size: int,
        quant_config: Optional[QuantizationConfig] = None,
        prefix: str = "",
    ) -> None:
        super().__init__()
        # Per attention head and per partition values.
        world_size = parallel_state.get_tensor_model_parallel_world_size()
        self.tp_size = world_size
        self.tp_rank = parallel_state.get_tensor_model_parallel_rank()
        self.hidden_size_per_attention_head = dist_utils.divide(
            projection_size, num_heads)
        self.num_attention_heads_per_partition = dist_utils.divide(
            num_heads, world_size)

        self.qkv = ColumnParallelLinear(input_size=embed_dim,
                                        output_size=3 * projection_size,
                                        quant_config=quant_config,
                                        prefix=f"{prefix}.qkv")
        self.proj = RowParallelLinear(input_size=projection_size,
                                      output_size=embed_dim,
                                      quant_config=quant_config,
                                      prefix=f"{prefix}.proj")

        # Detect attention implementation.
        self.attn_backend: _Backend = get_vit_attn_backend(support_fa=True)
        if self.attn_backend not in {
                _Backend.FLASH_ATTN, _Backend.TORCH_SDPA, _Backend.XFORMERS,
                _Backend.ROCM_AITER_FA
        }:
            raise RuntimeError(
                f"Qwen2-VL does not support {self.attn_backend} backend now.")
        self.is_flash_attn_backend = self.attn_backend in {
            _Backend.FLASH_ATTN, _Backend.ROCM_AITER_FA
        }

    def split_qkv(self, qkv: torch.Tensor) -> tuple[torch.Tensor, ...]:
        """Split the fused qkv projection into per-head q, k, v.

        With tensor parallelism the local qkv output is first all-gathered,
        split into q/k/v, and each of those is re-split along the last
        dimension so this rank keeps only its own block of heads.
        """
        # [s, b, 3 * head * head_dim]
        seq_len, bs, _ = qkv.shape
        if self.tp_size > 1:
            qkv = tensor_model_parallel_all_gather(qkv)

        # [s, b, 3 * head * head_dim] -> 3 * [s, b, head * head_dim]
        q, k, v = qkv.chunk(3, dim=2)

        # 3 * [s, b, head * head_dim]
        if self.tp_size > 1:
            splitter = partial(dist_utils.split_tensor_along_last_dim,
                               num_partitions=self.tp_size)
            q = splitter(q)[self.tp_rank]
            k = splitter(k)[self.tp_rank]
            v = splitter(v)[self.tp_rank]

        # 3 * [s, b, head * head_dim] -> 3 * [s, b, head, head_dim]
        new_shape = (seq_len, bs, self.num_attention_heads_per_partition,
                     self.hidden_size_per_attention_head)
        q, k, v = (t.view(*new_shape) for t in (q, k, v))
        return q, k, v

    def forward(
            self,
            x: torch.Tensor,
            cu_seqlens: torch.Tensor,
            rotary_pos_emb: torch.Tensor,
            max_seqlen: Optional[int] = None,  # Only used for Flash Attention
            seqlens: Optional[list[int]] = None,  # Only used for xFormers
    ) -> torch.Tensor:
        """Attend within each segment of ``x`` ([s, b, c]) and project back."""
        # [s, b, c] --> [s, b, 3 * head * head_dim]
        x, _ = self.qkv(x)

        # [s, b, 3 * head * head_dim] -> 3 * [s, b, head, head_dim]
        q, k, v = self.split_qkv(x)
        batch_size = q.shape[1]

        # 3 * [s, b, head, head_dim] -> 3 * [b, s, head, head_dim]
        q, k, v = (rearrange(t, "s b ... -> b s ...").contiguous()
                   for t in (q, k, v))
        if rotary_pos_emb is not None:
            q = apply_rotary_pos_emb_vision(q, rotary_pos_emb)
            k = apply_rotary_pos_emb_vision(k, rotary_pos_emb)

        if self.is_flash_attn_backend:
            if self.attn_backend == _Backend.ROCM_AITER_FA:
                from aiter import flash_attn_varlen_func
            else:
                from flash_attn import flash_attn_varlen_func

            # Varlen kernels take a flat token axis plus cu_seqlens.
            q, k, v = (rearrange(t, "b s ... -> (b s) ...")
                       for t in (q, k, v))
            output = flash_attn_varlen_func(q,
                                            k,
                                            v,
                                            cu_seqlens_q=cu_seqlens,
                                            cu_seqlens_k=cu_seqlens,
                                            max_seqlen_q=max_seqlen,
                                            max_seqlen_k=max_seqlen,
                                            dropout_p=0.0,
                                            causal=False)
            context_layer = rearrange(output,
                                      "(b s) ... -> b s ...",
                                      b=batch_size)
        elif self.attn_backend == _Backend.TORCH_SDPA:
            # Execute attention entry by entry for speed & less VRAM.
            # Convert the boundaries to Python ints once: slicing with 0-d
            # tensors converts every bound separately (a device sync per
            # bound when cu_seqlens lives on the GPU).
            boundaries = cu_seqlens.tolist()
            outputs = []
            for start_idx, end_idx in zip(boundaries[:-1], boundaries[1:]):
                q_i, k_i, v_i = (rearrange(t[:, start_idx:end_idx],
                                           "b s h d -> b h s d")
                                 for t in (q, k, v))
                output_i = F.scaled_dot_product_attention(q_i,
                                                          k_i,
                                                          v_i,
                                                          dropout_p=0.0)
                outputs.append(rearrange(output_i, "b h s d -> b s h d "))
            context_layer = torch.cat(outputs, dim=1)
        elif self.attn_backend == _Backend.XFORMERS:
            from xformers import ops as xops
            from xformers.ops.fmha.attn_bias import BlockDiagonalMask

            attn_bias = BlockDiagonalMask.from_seqlens(q_seqlen=seqlens,
                                                       kv_seqlen=None,
                                                       device=q.device)
            context_layer = xops.memory_efficient_attention_forward(
                q, k, v, attn_bias=attn_bias, p=0, scale=None)

        # [b, s, head, head_dim] -> [s, b, head * head_dim]
        context_layer = rearrange(context_layer,
                                  "b s h d -> s b (h d)").contiguous()

        output, _ = self.proj(context_layer)
        return output


class Qwen2VisionBlock(nn.Module):
    """Pre-norm vision transformer block.

    Computes ``x + attn(norm1(x))`` followed by ``x + mlp(norm2(x))``. The
    default norm layer is ``nn.LayerNorm`` with ``eps=1e-6``.
    """

    def __init__(
        self,
        dim: int,
        num_heads: int,
        mlp_ratio: float,
        act_layer: type[nn.Module] = QuickGELU,
        norm_layer: Optional[Callable[[int], nn.Module]] = None,
        quant_config: Optional[QuantizationConfig] = None,
        prefix: str = "",
    ) -> None:
        super().__init__()
        if norm_layer is None:
            norm_layer = partial(nn.LayerNorm, eps=1e-6)
        self.norm1 = norm_layer(dim)
        self.norm2 = norm_layer(dim)
        mlp_hidden_dim = int(dim * mlp_ratio)

        self.attn = Qwen2VisionAttention(embed_dim=dim,
                                         num_heads=num_heads,
                                         projection_size=dim,
                                         quant_config=quant_config,
                                         prefix=f"{prefix}.attn")
        self.mlp = Qwen2VisionMLP(dim,
                                  mlp_hidden_dim,
                                  act_layer=act_layer,
                                  quant_config=quant_config,
                                  prefix=f"{prefix}.mlp")

    def forward(
            self,
            x: torch.Tensor,
            cu_seqlens: torch.Tensor,
            rotary_pos_emb: torch.Tensor,
            max_seqlen: Optional[int] = None,  # Only used for Flash Attention
            seqlens: Optional[list[int]] = None,  # Only used for xFormers
    ) -> torch.Tensor:
        x = x + self.attn(
            self.norm1(x),
            cu_seqlens=cu_seqlens,
            rotary_pos_emb=rotary_pos_emb,
            max_seqlen=max_seqlen,
            seqlens=seqlens,
        )
        x = x + self.mlp(self.norm2(x))
        return x


class Qwen2VisionPatchEmbed(nn.Module):
    """Project flattened pixel patches to ``embed_dim`` with a 3-D conv.

    Each input row is one patch, reshaped to
    ``(in_channels, temporal_patch_size, patch_size, patch_size)``. The conv
    kernel and stride both equal the patch size, so every patch maps to a
    single output vector.
    """

    def __init__(
        self,
        patch_size: int = 14,
        temporal_patch_size: int = 2,
        in_channels: int = 3,
        embed_dim: int = 1152,
    ) -> None:
        super().__init__()
        self.patch_size = patch_size
        self.temporal_patch_size = temporal_patch_size
        self.embed_dim = embed_dim

        kernel_size = (temporal_patch_size, patch_size, patch_size)
        self.proj = nn.Conv3d(in_channels,
                              embed_dim,
                              kernel_size=kernel_size,
                              stride=kernel_size,
                              bias=False)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Map ``(num_patches, C * T * P * P)`` to ``(num_patches, embed_dim)``.
        """
        # Unpacking also enforces a 2-D input.
        num_patches, _ = x.shape
        x = x.view(num_patches, -1, self.temporal_patch_size, self.patch_size,
                   self.patch_size)
        return self.proj(x).view(num_patches, self.embed_dim)


class Qwen2VisionPatchMerger(nn.Module):
    """Merge neighbouring patch embeddings and project them to ``d_model``.

    After LayerNorm over ``context_dim``, every ``spatial_merge_size ** 2``
    consecutive rows are concatenated into one vector of size
    ``context_dim * spatial_merge_size ** 2``. That vector then goes through
    Linear -> GELU -> Linear.
    """

    def __init__(
        self,
        d_model: int,
        context_dim: int,
        norm_layer: Optional[Callable[[int], nn.Module]] = None,
        spatial_merge_size: int = 2,
        quant_config: Optional[QuantizationConfig] = None,
        prefix: str = "",
    ) -> None:
        super().__init__()
        self.hidden_size = context_dim * (spatial_merge_size**2)
        if norm_layer is None:
            norm_layer = partial(nn.LayerNorm, eps=1e-6)
        self.ln_q = norm_layer(context_dim)
        # ModuleList positions 0 and 2 line up with the "mlp.0" / "mlp.2"
        # prefixes given to the linear layers.
        self.mlp = nn.ModuleList([
            ColumnParallelLinear(self.hidden_size,
                                 self.hidden_size,
                                 bias=True,
                                 quant_config=quant_config,
                                 prefix=f"{prefix}.mlp.0"),
            nn.GELU(),
            RowParallelLinear(self.hidden_size,
                              d_model,
                              bias=True,
                              quant_config=quant_config,
                              prefix=f"{prefix}.mlp.2"),
        ])

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        x = self.ln_q(x)
        # NOTE(review): assumes rows are already ordered so that each group of
        # spatial_merge_size**2 consecutive rows is one merge window.
        x = x.view(-1, self.hidden_size)

        mlp_fc1, mlp_act, mlp_fc2 = self.mlp
        x_parallel, _ = mlp_fc1(x)
        x_parallel = mlp_act(x_parallel)
        out, _ = mlp_fc2(x_parallel)
        return out


class Qwen2VisionRotaryEmbedding(nn.Module):
    """Lazily-built table of rotary angles ``position * inv_freq``.

    When a longer sequence than the cached one is requested, the table is
    rebuilt for twice the requested length, so slightly longer follow-up
    requests reuse it.
    """

    def __init__(self, dim: int, theta: float = 10000.0) -> None:
        super().__init__()
        self.dim = dim
        self.theta = theta
        inv_freq = 1.0 / (theta
                          **(torch.arange(0, dim, 2, dtype=torch.float) / dim))
        self.register_buffer("inv_freq", inv_freq, persistent=False)
        self._seq_len_cached = 0
        self._freqs_cached = None

    def update_freqs_cache(self, seqlen: int) -> None:
        """Make sure the cached table covers at least ``seqlen`` positions.

        ``seqlen`` may be an int or a 0-d integer tensor (the vision
        transformer passes ``grid_thw[:, 1:].max()``). It is normalized to a
        Python int so ``_seq_len_cached`` never holds a tensor.
        """
        seqlen = int(seqlen)
        if self._freqs_cached is not None and seqlen <= self._seq_len_cached:
            return
        seqlen *= 2
        self._seq_len_cached = seqlen
        # NOTE(review): inv_freq is recomputed in float32 on the buffer's
        # current device, presumably because model-wide dtype casts also
        # cast this buffer -- confirm.
        self.inv_freq = 1.0 / (self.theta**(torch.arange(
            0, self.dim, 2, dtype=torch.float, device=self.inv_freq.device)
                                            / self.dim))
        seq = torch.arange(seqlen,
                           device=self.inv_freq.device,
                           dtype=self.inv_freq.dtype)
        self._freqs_cached = torch.outer(seq, self.inv_freq)

    def forward(self, seqlen: int) -> torch.Tensor:
        """Return angles for positions ``0 .. seqlen - 1``.

        Shape: (seqlen, ceil(dim / 2)).
        """
        seqlen = int(seqlen)
        self.update_freqs_cache(seqlen)
        return self._freqs_cached[:seqlen]


class Qwen2VisionTransformer(nn.Module):
    """Qwen2-VL vision encoder: patch embedding, ViT blocks, patch merger.

    ``forward`` takes flattened pixel patches plus each item's ``(t, h, w)``
    patch grid. It returns the merger output, whose last dimension is
    ``vision_config.hidden_size``.
    """

    def __init__(
        self,
        vision_config: Qwen2VLVisionConfig,
        norm_eps: float = 1e-6,
        quant_config: Optional[QuantizationConfig] = None,
        prefix: str = "",
    ) -> None:
        super().__init__()

        patch_size = vision_config.patch_size
        temporal_patch_size = vision_config.temporal_patch_size
        spatial_merge_size = vision_config.spatial_merge_size
        in_channels = vision_config.in_channels
        hidden_size = vision_config.hidden_size
        embed_dim = vision_config.embed_dim
        depth = vision_config.depth
        num_heads = vision_config.num_heads
        mlp_ratio = vision_config.mlp_ratio

        self.spatial_merge_size = spatial_merge_size
        self.num_heads = num_heads
        self.embed_dim = embed_dim

        self.patch_embed = Qwen2VisionPatchEmbed(
            patch_size=patch_size,
            temporal_patch_size=temporal_patch_size,
            in_channels=in_channels,
            embed_dim=embed_dim,
        )

        norm_layer = partial(nn.LayerNorm, eps=norm_eps)
        head_dim = embed_dim // num_heads
        # The table has head_dim // 2 entries because rot_pos_emb looks up
        # both the row and the column position and concatenates the results.
        self.rotary_pos_emb = Qwen2VisionRotaryEmbedding(head_dim // 2)

        self.blocks = nn.ModuleList([
            Qwen2VisionBlock(dim=embed_dim,
                             num_heads=num_heads,
                             mlp_ratio=mlp_ratio,
                             norm_layer=norm_layer,
                             quant_config=quant_config,
                             prefix=f"{prefix}.blocks.{layer_idx}")
            for layer_idx in range(depth)
        ])
        self.merger = Qwen2VisionPatchMerger(
            d_model=hidden_size,
            context_dim=embed_dim,
            norm_layer=norm_layer,
            quant_config=quant_config,
            prefix=f"{prefix}.merger",
        )
        self.attn_backend: _Backend = get_vit_attn_backend(support_fa=True)

    @property
    def dtype(self) -> torch.dtype:
        return self.patch_embed.proj.weight.dtype

    @property
    def device(self) -> torch.device:
        return self.patch_embed.proj.weight.device

    def rot_pos_emb(self, grid_thw: torch.Tensor) -> torch.Tensor:
        """Return per-patch rotary angles from each patch's (row, col).

        Positions are listed window by window, in spatial_merge_size x
        spatial_merge_size windows. NOTE(review): this presumably matches the
        order in which patches are fed in; confirm. Each item's positions are
        repeated for its ``t`` frames.
        """
        pos_ids = []
        for t, h, w in grid_thw:
            hpos_ids = torch.arange(h).unsqueeze(1).expand(-1, w)
            wpos_ids = torch.arange(w).unsqueeze(0).expand(h, -1)
            hpos_ids = hpos_ids.reshape(
                h // self.spatial_merge_size,
                self.spatial_merge_size,
                w // self.spatial_merge_size,
                self.spatial_merge_size,
            ).permute(0, 2, 1, 3).flatten()
            wpos_ids = wpos_ids.reshape(
                h // self.spatial_merge_size,
                self.spatial_merge_size,
                w // self.spatial_merge_size,
                self.spatial_merge_size,
            ).permute(0, 2, 1, 3).flatten()
            pos_ids.append(
                torch.stack([hpos_ids, wpos_ids], dim=-1).repeat(t, 1))
        pos_ids = torch.cat(pos_ids, dim=0)
        max_grid_size = grid_thw[:, 1:].max()
        rotary_pos_emb_full = self.rotary_pos_emb(max_grid_size)
        rotary_pos_emb = rotary_pos_emb_full[pos_ids].flatten(1)
        return rotary_pos_emb

    def compute_attn_mask_seqlen(
            self, cu_seqlens: torch.Tensor
    ) -> tuple[Optional[int], Optional[list[int]]]:
        """Precompute the segment info the selected attention backend needs.

        Returns ``(max_seqlen, None)`` for flash-attention backends,
        ``(None, seqlens)`` for xFormers and ``(None, None)`` otherwise.
        """
        max_seqlen, seqlens = None, None
        if (self.attn_backend == _Backend.FLASH_ATTN
                or self.attn_backend == _Backend.ROCM_AITER_FA):
            max_seqlen = (cu_seqlens[1:] - cu_seqlens[:-1]).max().item()
        elif self.attn_backend == _Backend.XFORMERS:
            seqlens = (cu_seqlens[1:] - cu_seqlens[:-1]).tolist()
        return max_seqlen, seqlens

    def forward(
        self,
        x: torch.Tensor,
        grid_thw: torch.Tensor,
    ) -> torch.Tensor:
        """Encode patches ``x`` described by ``grid_thw`` (num_items, 3)."""
        # patchify
        x = x.to(device=self.device, dtype=self.dtype)
        x = self.patch_embed(x)

        # compute position embedding
        rotary_pos_emb = self.rot_pos_emb(grid_thw)

        # compute cu_seqlens: one attention segment of h * w patches per
        # frame, repeated t times for every item.
        cu_seqlens = torch.repeat_interleave(grid_thw[:, 1] * grid_thw[:, 2],
                                             grid_thw[:, 0]).cumsum(
                                                 dim=0, dtype=torch.int32)
        cu_seqlens = F.pad(cu_seqlens, (1, 0), "constant", 0)

        # transformers
        x = x.unsqueeze(1)

        # pre-compute seqlens for attn mask to reduce cuMemcpy operations
        max_seqlen, seqlens = self.compute_attn_mask_seqlen(cu_seqlens)
        for blk in self.blocks:
            x = blk(
                x,
                cu_seqlens=cu_seqlens,
                rotary_pos_emb=rotary_pos_emb,
                max_seqlen=max_seqlen,
                seqlens=seqlens,
            )

        # adapter
        x = self.merger(x)
        return x

    def load_weights(self, weights: Iterable[tuple[str,
                                                   torch.Tensor]]) -> set[str]:
        """Load ``(name, tensor)`` pairs; return the parameter names loaded.

        Names containing ``q_proj`` / ``k_proj`` / ``v_proj`` are redirected
        to the fused ``qkv_proj`` parameter with the matching shard id.
        """
        stacked_params_mapping = [
            # (param_name, shard_name, shard_id)
            ("qkv_proj", "q_proj", "q"),
            ("qkv_proj", "k_proj", "k"),
            ("qkv_proj", "v_proj", "v"),
        ]
        params_dict = dict(self.named_parameters(remove_duplicate=False))
        loaded_params: set[str] = set()

        for name, loaded_weight in weights:
            for (param_name, weight_name, shard_id) in stacked_params_mapping:
                if weight_name not in name:
                    continue
                name = name.replace(weight_name, param_name)

                param = params_dict[name]
                weight_loader = param.weight_loader
                weight_loader(param, loaded_weight, shard_id)
                break
            else:
                param = params_dict[name]
                weight_loader = getattr(param, "weight_loader",
                                        default_weight_loader)
                weight_loader(param, loaded_weight)
            loaded_params.add(name)
        return loaded_params

743

744
745
746
747
748
749
750
751
752
753
754
755
756
757
758
759
760
761
762
763
764
765
766
767
768
769
770
771
772
773
774
775
def _create_qwen2vl_field_factory(
    spatial_merge_size: int,
) -> Callable[[Mapping[str, torch.Tensor]],
              Mapping[str, MultiModalFieldConfig]]:
    """Build the field-config callback for Qwen2-VL multimodal inputs.

    The returned callback splits flat pixel tensors by ``t * h * w`` patches
    per item, and embedding tensors by that count floor-divided twice by
    ``spatial_merge_size``. Grid tensors are batched per item. Missing grids
    are treated as empty.
    """

    def _merged_counts(patch_counts: torch.Tensor) -> torch.Tensor:
        # Rows left per item after spatial merging.
        return patch_counts // spatial_merge_size // spatial_merge_size

    def _qwen2vl_field_config(hf_inputs: Mapping[str, torch.Tensor]):
        image_patch_counts = hf_inputs.get("image_grid_thw",
                                           torch.empty((0, 3))).prod(-1)
        video_patch_counts = hf_inputs.get("video_grid_thw",
                                           torch.empty((0, 3))).prod(-1)

        return {
            "pixel_values":
            MultiModalFieldConfig.flat_from_sizes("image", image_patch_counts),
            "image_embeds":
            MultiModalFieldConfig.flat_from_sizes(
                "image", _merged_counts(image_patch_counts)),
            "image_grid_thw":
            MultiModalFieldConfig.batched("image"),
            "pixel_values_videos":
            MultiModalFieldConfig.flat_from_sizes("video", video_patch_counts),
            "video_embeds":
            MultiModalFieldConfig.flat_from_sizes(
                "video", _merged_counts(video_patch_counts)),
            "video_grid_thw":
            MultiModalFieldConfig.batched("video"),
        }

    return _qwen2vl_field_config


class Qwen2VLMultiModalDataParser(MultiModalDataParser):
    """Data parser that also accepts precomputed Qwen2-VL embeddings.

    A ``dict`` passed as image (or video) data is wrapped as
    ``DictEmbeddingItems``. It must contain ``*_embeds`` and
    ``*_grid_thw``. Anything else goes to the base parser.
    """

    def __init__(self, spatial_merge_size: int, *args, **kwargs):
        self._spatial_merge_size = spatial_merge_size
        super().__init__(*args, **kwargs)

    def _parse_image_data(
        self,
        data: Union[dict[str, torch.Tensor], ModalityData[ImageItem]],
    ) -> Optional[ModalityDataItems[Any, Any]]:
        if not isinstance(data, dict):
            return super()._parse_image_data(data)

        factory = _create_qwen2vl_field_factory(self._spatial_merge_size)
        return DictEmbeddingItems(
            data,
            modality="image",
            required_fields={"image_embeds", "image_grid_thw"},
            fields_factory=factory,
        )

    def _parse_video_data(
        self,
        data: Union[dict[str, torch.Tensor], ModalityData[VideoItem]],
    ) -> Optional[ModalityDataItems[Any, Any]]:
        if not isinstance(data, dict):
            return super()._parse_video_data(data)

        factory = _create_qwen2vl_field_factory(self._spatial_merge_size)
        return DictEmbeddingItems(
            data,
            modality="video",
            required_fields={"video_embeds", "video_grid_thw"},
            fields_factory=factory,
        )


815
class Qwen2VLProcessingInfo(BaseProcessingInfo):
    """HF config/processor access and vision-token budgeting for Qwen2-VL."""

    def get_hf_config(self):
        """Return the model's HF config as a ``Qwen2VLConfig``."""
        return self.ctx.get_hf_config(Qwen2VLConfig)

    def get_hf_processor(self, **kwargs: object) -> Qwen2VLProcessor:
        """Return the HF processor.

        ``use_fast`` defaults to ``True``. Other kwargs are forwarded as-is.
        """
        use_fast = kwargs.pop("use_fast", True)
        return self.ctx.get_hf_processor(Qwen2VLProcessor,
                                         use_fast=use_fast,
                                         **kwargs)

    def get_image_processor(self, **kwargs: object) -> Qwen2VLImageProcessor:
        """Return the image processor of :meth:`get_hf_processor`."""
        hf_processor = self.get_hf_processor(**kwargs)
        return hf_processor.image_processor

    def get_supported_mm_limits(self) -> Mapping[str, Optional[int]]:
        # ``None`` means no fixed per-prompt limit for either modality.
        return dict(image=None, video=None)

    def get_mm_max_tokens_per_item(
        self,
        seq_len: int,
        mm_counts: Mapping[str, int],
    ) -> Mapping[str, int]:
        """Upper bound of placeholder tokens for a single image / video."""
        return {
            "image": self.get_max_image_tokens(),
            "video": self.get_max_video_tokens(seq_len, mm_counts),
        }

    def _get_vision_info(
        self,
        *,
        image_width: int,
        image_height: int,
        num_frames: int = 1,
        do_resize: bool = True,
        image_processor: Optional[Qwen2VLImageProcessor],
    ) -> tuple[ImageSize, int]:
        """Compute the preprocessed size and vision-token count of an input.

        When ``do_resize`` is set, the size is snapped with ``smart_resize``
        to a multiple of ``patch_size * merge_size``. The min/max pixel
        bounds come from the image processor. The token count is
        ``grid_t * grid_h * grid_w // merge_size**2``.
        """
        processor = (image_processor if image_processor is not None else
                     self.get_image_processor())

        vision_config = self.get_hf_config().vision_config
        patch_size = vision_config.patch_size
        merge_size = vision_config.spatial_merge_size
        temporal_patch_size = vision_config.temporal_patch_size

        if do_resize:
            height, width = smart_resize(
                height=image_height,
                width=image_width,
                factor=patch_size * merge_size,
                min_pixels=processor.min_pixels,
                max_pixels=processor.max_pixels,
            )
        else:
            height, width = image_height, image_width
        preprocessed_size = ImageSize(width=width, height=height)

        # NOTE: Frames are padded to be divisible by `temporal_patch_size`
        # https://github.com/huggingface/transformers/blob/v4.48.3/src/transformers/models/qwen2_vl/image_processing_qwen2_vl.py#L294
        # NOTE(review): `n + n % tps` only rounds up to a multiple of `tps`
        # when tps == 2 (Qwen2-VL's value); confirm before reusing elsewhere.
        padded_frames = num_frames + num_frames % temporal_patch_size
        grid_t = max(padded_frames // temporal_patch_size, 1)
        grid_h = preprocessed_size.height // patch_size
        grid_w = preprocessed_size.width // patch_size

        num_vision_tokens = (grid_t * grid_h * grid_w) // (merge_size**2)
        return preprocessed_size, num_vision_tokens

    def get_num_image_tokens(
        self,
        *,
        image_width: int,
        image_height: int,
        image_processor: Optional[Qwen2VLImageProcessor],
    ) -> int:
        """Number of placeholder tokens for one image of the given size."""
        return self._get_vision_info(
            image_width=image_width,
            image_height=image_height,
            image_processor=image_processor,
        )[1]

    def get_num_video_tokens(
        self,
        *,
        image_width: int,
        image_height: int,
        num_frames: int,
        image_processor: Optional[Qwen2VLImageProcessor],
    ) -> int:
        """Number of placeholder tokens for one video of the given shape."""
        return self._get_vision_info(
            image_width=image_width,
            image_height=image_height,
            num_frames=num_frames,
            image_processor=image_processor,
        )[1]

    def get_image_size_with_most_features(self) -> ImageSize:
        """Largest preprocessed size allowed by the processor's max_pixels."""
        # An absurdly large input forces smart_resize to clamp to the maximum.
        return self._get_vision_info(
            image_width=9999999,
            image_height=9999999,
            image_processor=None,
        )[0]

    def get_max_image_tokens(self) -> int:
        width, height = self.get_image_size_with_most_features()
        return self.get_num_image_tokens(
            image_width=width,
            image_height=height,
            image_processor=None,
        )

    def _get_max_video_frames(self, max_tokens: int) -> int:
        """Largest frame count at maximal resolution that fits ``max_tokens``.

        Returns 0 if even a single frame exceeds the budget.
        """
        width, height = self.get_image_size_with_most_features()

        frames = 0
        while self.get_num_video_tokens(
                image_width=width,
                image_height=height,
                num_frames=frames + 1,
                image_processor=None,
        ) <= max_tokens:
            frames += 1
        return frames

    def get_num_frames_with_most_features(
        self,
        seq_len: int,
        mm_counts: Mapping[str, int],
    ) -> int:
        """Frames per video when ``seq_len`` is shared by all videos.

        Capped at ``_MAX_FRAMES_PER_VIDEO`` and never below 1.
        """
        video_count = max(mm_counts.get("video", 0), 1)
        total_frames = self._get_max_video_frames(seq_len)
        per_video = min(total_frames // video_count, _MAX_FRAMES_PER_VIDEO)
        return max(per_video, 1)

    def get_max_video_tokens(
        self,
        seq_len: int,
        mm_counts: Mapping[str, int],
    ) -> int:
        width, height = self.get_image_size_with_most_features()
        num_frames = self.get_num_frames_with_most_features(seq_len, mm_counts)
        return self.get_num_video_tokens(
            image_width=width,
            image_height=height,
            num_frames=num_frames,
            image_processor=None,
        )

983
984
985

class Qwen2VLDummyInputsBuilder(BaseDummyInputsBuilder[Qwen2VLProcessingInfo]):
    """Builds worst-case dummy prompts and media for memory profiling."""

    def get_dummy_text(self, mm_counts: Mapping[str, int]) -> str:
        """One image token per image followed by one video token per video."""
        hf_processor = self.info.get_hf_processor()
        image_token: str = hf_processor.image_token
        video_token: str = hf_processor.video_token

        return (image_token * mm_counts.get("image", 0) +
                video_token * mm_counts.get("video", 0))

    def get_dummy_mm_data(
        self,
        seq_len: int,
        mm_counts: Mapping[str, int],
    ) -> MultiModalDataDict:
        """Dummy images/videos at the size that yields the most tokens."""
        info = self.info
        width, height = info.get_image_size_with_most_features()
        frames = info.get_num_frames_with_most_features(seq_len, mm_counts)

        images = self._get_dummy_images(
            width=width,
            height=height,
            num_images=mm_counts.get("image", 0),
        )
        videos = self._get_dummy_videos(
            width=width,
            height=height,
            num_frames=frames,
            num_videos=mm_counts.get("video", 0),
        )
        return {"image": images, "video": videos}

1023

1024
1025
class Qwen2VLMultiModalProcessor(BaseMultiModalProcessor[Qwen2VLProcessingInfo]
                                 ):
    """Multimodal processor expanding Qwen2-VL image/video placeholders."""

    def _get_data_parser(self) -> MultiModalDataParser:
        vision_config = self.info.get_hf_config().vision_config
        return Qwen2VLMultiModalDataParser(vision_config.spatial_merge_size)

    def _get_prompt_updates(
        self,
        mm_items: MultiModalDataItems,
        hf_processor_mm_kwargs: Mapping[str, Any],
        out_mm_kwargs: MultiModalKwargsItems,
    ) -> Sequence[PromptUpdate]:
        """Replace each image/video pad token with one token per merged patch.

        Item ``i`` of a modality expands to
        ``prod(<modality>_grid_thw[i]) // merge_size**2`` copies of its pad
        token id.
        """
        hf_processor = self.info.get_hf_processor(**hf_processor_mm_kwargs)
        image_processor = self.info.get_image_processor(
            **hf_processor_mm_kwargs)
        vocab = self.info.get_tokenizer().get_vocab()

        token_ids = {
            "image": vocab[hf_processor.image_token],
            "video": vocab[hf_processor.video_token],
        }
        patches_per_token = image_processor.merge_size**2

        def get_replacement_qwen2vl(item_idx: int, modality: str):
            item = out_mm_kwargs[modality][item_idx]
            grid_thw = item[f"{modality}_grid_thw"].data
            assert isinstance(grid_thw, torch.Tensor)

            count = int(grid_thw.prod()) // patches_per_token
            return [token_ids[modality]] * count

        updates = []
        for modality in ("image", "video"):
            updates.append(
                PromptReplacement(
                    modality=modality,
                    target=[token_ids[modality]],
                    replacement=partial(get_replacement_qwen2vl,
                                        modality=modality),
                ))
        return updates

    def _get_mm_fields_config(
        self,
        hf_inputs: BatchFeature,
        hf_processor_mm_kwargs: Mapping[str, object],
    ) -> Mapping[str, MultiModalFieldConfig]:
        merge_size = self.info.get_hf_config().vision_config.spatial_merge_size
        field_config_fn = _create_qwen2vl_field_factory(merge_size)
        return field_config_fn(hf_inputs)
1075

1076

1077
1078
1079
@MULTIMODAL_REGISTRY.register_processor(Qwen2VLMultiModalProcessor,
                                        info=Qwen2VLProcessingInfo,
                                        dummy_inputs=Qwen2VLDummyInputsBuilder)
class Qwen2VLForConditionalGeneration(nn.Module, SupportsMultiModal,
                                      SupportsLoRA, SupportsPP):
    """Qwen2-VL model for conditional generation.

    Wraps a ``Qwen2ForCausalLM`` (``self.language_model``). When images or
    videos are allowed, it also holds a vision transformer
    (``self.visual``). The vision embeddings are merged into the text
    embeddings at the image/video placeholder token positions.
    """

    # To ensure correct weight loading and mapping.
    hf_to_vllm_mapper = WeightsMapper(
        orig_to_new_prefix={
            # mapping for new names in checkpoint saved after transformers v4.52
            "model.language_model.": "language_model.model.",
            "model.visual.": "visual.",
            # mapping for original checkpoint
            "lm_head.": "language_model.lm_head.",
            "model.": "language_model.model.",
        })

    @classmethod
    def get_placeholder_str(cls, modality: str, i: int) -> Optional[str]:
        """Return the prompt placeholder string for one item of ``modality``.

        Raises:
            ValueError: if ``modality`` starts with neither "image" nor
                "video".
        """
        if modality.startswith("image"):
            return "<|vision_start|><|image_pad|><|vision_end|>"
        if modality.startswith("video"):
            return "<|vision_start|><|video_pad|><|vision_end|>"

        raise ValueError("Only image or video modality is supported")

    def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
        super().__init__()
        config: Qwen2VLConfig = vllm_config.model_config.hf_config
        quant_config = vllm_config.quant_config
        multimodal_config = vllm_config.model_config.multimodal_config

        self.config = config
        self.multimodal_config = multimodal_config

        # The vision tower is only built when at least one of the image or
        # video per-prompt limits is truthy. `load_weights` skips "visual."
        # weights otherwise.
        if multimodal_config.get_limit_per_prompt("image") or \
            multimodal_config.get_limit_per_prompt("video"):
            self.visual = Qwen2VisionTransformer(
                config.vision_config,
                norm_eps=getattr(config, "rms_norm_eps", 1e-6),
                quant_config=self._maybe_ignore_quant_config(quant_config),
                prefix=maybe_prefix(prefix, "visual"),
            )
        else:
            self.visual = None

        self.language_model = init_vllm_registered_model(
            vllm_config=vllm_config,
            prefix=maybe_prefix(prefix, "language_model"),
            architectures=["Qwen2ForCausalLM"],
        )

        self.make_empty_intermediate_tensors = (
            self.language_model.make_empty_intermediate_tensors)

    def _maybe_ignore_quant_config(
        self, quant_config: Optional[QuantizationConfig]
    ) -> Optional[QuantizationConfig]:
        """Return the quant config to use for the vision encoder.

        GPTQ / GPTQ-Marlin configs give ``None``, so the vision encoder
        runs unquantized. Any other value, including ``None``, is returned
        unchanged.
        """
        # GPTQ configs do not have a list of ignored modules, however AutoGPTQ
        # seems to avoid vision encoder sections for some models.
        # See: https://huggingface.co/Qwen/Qwen2-VL-2B-Instruct-GPTQ-Int4
        if isinstance(quant_config, (GPTQConfig, GPTQMarlinConfig)):
            return None
        return quant_config

    def _validate_and_reshape_mm_tensor(self, mm_input: object,
                                        name: str) -> torch.Tensor:
        """Flatten a multimodal input into a single 2D tensor.

        2D tensors are returned as-is. A 3D tensor, or a list of tensors,
        is concatenated along its first dimension.

        Raises:
            ValueError: on a non-tensor/non-list input or a tensor whose
                ndim is not 2 or 3.
        """
        if not isinstance(mm_input, (torch.Tensor, list)):
            raise ValueError(f"Incorrect type of {name}. "
                             f"Got type: {type(mm_input)}")
        if isinstance(mm_input, torch.Tensor):
            if mm_input.ndim == 2:
                return mm_input
            if mm_input.ndim != 3:
                raise ValueError(f"{name} should be 2D or batched 3D tensor. "
                                 f"Got ndim: {mm_input.ndim} "
                                 f"(shape={mm_input.shape})")
            return torch.concat(list(mm_input))
        else:
            return torch.concat(mm_input)

    def _parse_and_validate_image_input(
            self, **kwargs: object) -> Optional[Qwen2VLImageInputs]:
        """Pop image kwargs and build pixel or embedding inputs.

        Returns ``None`` when neither ``pixel_values`` nor ``image_embeds``
        is present. Pixel values take precedence over embeddings.
        """
        pixel_values = kwargs.pop("pixel_values", None)
        image_embeds = kwargs.pop("image_embeds", None)
        image_grid_thw = kwargs.pop("image_grid_thw", None)

        if pixel_values is None and image_embeds is None:
            return None

        if pixel_values is not None:
            pixel_values = self._validate_and_reshape_mm_tensor(
                pixel_values, "image pixel values")
            image_grid_thw = self._validate_and_reshape_mm_tensor(
                image_grid_thw, "image grid_thw")

            return Qwen2VLImagePixelInputs(type="pixel_values",
                                           pixel_values=pixel_values,
                                           image_grid_thw=image_grid_thw)

        if image_embeds is not None:
            image_embeds = self._validate_and_reshape_mm_tensor(
                image_embeds, "image embeds")
            image_grid_thw = self._validate_and_reshape_mm_tensor(
                image_grid_thw, "image grid_thw")

            return Qwen2VLImageEmbeddingInputs(type="image_embeds",
                                               image_embeds=image_embeds,
                                               image_grid_thw=image_grid_thw)

    def _parse_and_validate_video_input(
            self, **kwargs: object) -> Optional[Qwen2VLVideoInputs]:
        """Pop video kwargs and build pixel or embedding inputs.

        Returns ``None`` when neither ``pixel_values_videos`` nor
        ``video_embeds`` is present. Pixel values take precedence over
        embeddings.
        """
        pixel_values_videos = kwargs.pop("pixel_values_videos", None)
        video_embeds = kwargs.pop("video_embeds", None)
        video_grid_thw = kwargs.pop("video_grid_thw", None)

        if pixel_values_videos is None and video_embeds is None:
            return None

        if pixel_values_videos is not None:
            pixel_values_videos = self._validate_and_reshape_mm_tensor(
                pixel_values_videos, "video pixel values")
            video_grid_thw = self._validate_and_reshape_mm_tensor(
                video_grid_thw, "video grid_thw")

            return Qwen2VLVideoPixelInputs(
                type="pixel_values_videos",
                pixel_values_videos=pixel_values_videos,
                video_grid_thw=video_grid_thw,
            )

        if video_embeds is not None:
            video_embeds = self._validate_and_reshape_mm_tensor(
                video_embeds, "video embeds")
            video_grid_thw = self._validate_and_reshape_mm_tensor(
                video_grid_thw, "video grid_thw")

            return Qwen2VLVideoEmbeddingInputs(type="video_embeds",
                                               video_embeds=video_embeds,
                                               video_grid_thw=video_grid_thw)

    def _split_by_grid(
            self, embeddings: torch.Tensor,
            grid_thw_list: list[list[int]]) -> tuple[torch.Tensor, ...]:
        """Split concatenated vision embeddings into one tensor per item.

        Item ``i`` gets ``prod(grid_thw_list[i]) // spatial_merge_size**2``
        rows of ``embeddings``.
        """
        merge_size = self.visual.spatial_merge_size
        sizes = (torch.tensor(grid_thw_list, dtype=torch.long).prod(-1) //
                 (merge_size * merge_size)).tolist()
        return embeddings.split(sizes)

    def _process_image_input(
            self, image_input: Qwen2VLImageInputs) -> tuple[torch.Tensor, ...]:
        """Return one embedding tensor per image.

        The vision encoder runs only for pixel inputs; precomputed
        embeddings are used directly.
        """
        grid_thw = image_input["image_grid_thw"]
        assert grid_thw.ndim == 2
        # Convert to a host list before running the encoder.
        grid_thw_list = grid_thw.tolist()

        if image_input["type"] == "image_embeds":
            image_embeds = image_input["image_embeds"]
        else:
            pixel_values = image_input["pixel_values"]
            image_embeds = self.visual(pixel_values, grid_thw=grid_thw)

        # Split concatenated embeddings for each image item.
        return self._split_by_grid(image_embeds, grid_thw_list)

    def _process_video_input(
            self, video_input: Qwen2VLVideoInputs) -> tuple[torch.Tensor, ...]:
        """Return one embedding tensor per video.

        The vision encoder runs only for pixel inputs; precomputed
        embeddings are used directly.
        """
        grid_thw = video_input["video_grid_thw"]
        assert grid_thw.ndim == 2
        # Convert to a host list before running the encoder.
        grid_thw_list = grid_thw.tolist()

        if video_input["type"] == "video_embeds":
            video_embeds = video_input["video_embeds"]
        else:
            pixel_values_videos = video_input["pixel_values_videos"]
            video_embeds = self.visual(pixel_values_videos, grid_thw=grid_thw)

        # Split concatenated embeddings for each video item.
        return self._split_by_grid(video_embeds, grid_thw_list)

    def _parse_and_validate_multimodal_inputs(self, **kwargs: object) -> dict:
        """Parse image/video inputs, keyed "images"/"videos", in kwargs order.
        """
        modalities = {}

        # Preserve the order of modalities if there are multiple of them
        # from the order of kwargs.
        for input_key in kwargs:
            if input_key in ("pixel_values",
                             "image_embeds") and "images" not in modalities:
                modalities["images"] = self._parse_and_validate_image_input(
                    **kwargs)
            if input_key in ("pixel_values_videos",
                             "video_embeds") and "videos" not in modalities:
                modalities["videos"] = self._parse_and_validate_video_input(
                    **kwargs)

        return modalities

    def get_language_model(self) -> torch.nn.Module:
        return self.language_model

    def get_multimodal_embeddings(self,
                                  **kwargs: object) -> MultiModalEmbeddings:
        """Return per-item vision embeddings in kwargs modality order.

        Returns an empty list when no image or video inputs are present.
        """
        modalities = self._parse_and_validate_multimodal_inputs(**kwargs)
        if not modalities:
            return []

        # The result multimodal_embeddings is tuple of tensors, with each
        # tensor corresponding to a multimodal data item (image or video).
        multimodal_embeddings: tuple[torch.Tensor, ...] = ()

        # NOTE: It is important to iterate over the keys in this dictionary
        # to preserve the order of the modalities.
        for modality in modalities:
            if modality == "images":
                image_input = modalities["images"]
                vision_embeddings = self._process_image_input(image_input)
                multimodal_embeddings += vision_embeddings
            if modality == "videos":
                video_input = modalities["videos"]
                video_embeddings = self._process_video_input(video_input)
                multimodal_embeddings += video_embeddings

        return multimodal_embeddings

    def get_input_embeddings(
        self,
        input_ids: torch.Tensor,
        multimodal_embeddings: Optional[MultiModalEmbeddings] = None,
    ) -> torch.Tensor:
        """Embed ``input_ids``, merging in any multimodal embeddings.

        Multimodal embeddings replace positions holding the image or video
        token ids from the config.
        """
        inputs_embeds = self.language_model.get_input_embeddings(input_ids)
        if multimodal_embeddings is not None \
            and len(multimodal_embeddings) != 0:
            inputs_embeds = merge_multimodal_embeddings(
                input_ids, inputs_embeds, multimodal_embeddings,
                [self.config.image_token_id, self.config.video_token_id])
        return inputs_embeds

    def get_input_embeddings_v0(
        self,
        input_ids: torch.Tensor,
        image_input: Optional[Qwen2VLImageInputs] = None,
        video_input: Optional[Qwen2VLVideoInputs] = None,
    ) -> torch.Tensor:
        """V0 path: embed text, then merge image and video embeddings.

        Each input may be pixel-based or precomputed-embedding-based, as
        produced by ``_parse_and_validate_{image,video}_input``.
        """
        inputs_embeds = self.get_input_embeddings(input_ids)
        if image_input is not None:
            image_embeds = self._process_image_input(image_input)
            inputs_embeds = merge_multimodal_embeddings(
                input_ids,
                inputs_embeds,
                image_embeds,
                placeholder_token_id=self.config.image_token_id,
            )

        if video_input is not None:
            video_embeds = self._process_video_input(video_input)
            inputs_embeds = merge_multimodal_embeddings(
                input_ids,
                inputs_embeds,
                video_embeds,
                placeholder_token_id=self.config.video_token_id,
            )
        return inputs_embeds

    def forward(
        self,
        input_ids: torch.Tensor,
        positions: torch.Tensor,
        intermediate_tensors: Optional[IntermediateTensors] = None,
        inputs_embeds: Optional[torch.Tensor] = None,
        **kwargs: object,
    ) -> Union[torch.Tensor, IntermediateTensors]:
        """Run forward pass for Qwen2-VL.

        Args:
            input_ids: Flattened (concatenated) input_ids corresponding to a
                batch.
            positions: Flattened (concatenated) position ids corresponding to a
                batch.
                **NOTE**: If mrope is enabled (default setting for Qwen2-VL
                opensource models), the shape will be `(3, seq_len)`,
                otherwise it will be `(seq_len,)`.
            pixel_values: Pixel values to be fed to a model.
                `None` if no images are passed.
            image_grid_thw: Tensor `(n_images, 3)` of image 3D grid in LLM.
                `None` if no images are passed.
            pixel_values_videos: Pixel values of videos to be fed to a model.
                `None` if no videos are passed.
            video_grid_thw: Tensor `(n_videos, 3)` of video 3D grid in LLM.
                `None` if no videos are passed.
        """

        if intermediate_tensors is not None:
            inputs_embeds = None

        # NOTE: In v1, inputs_embeds is always generated at model runner from
        # `get_multimodal_embeddings` and `get_input_embeddings`, this
        # condition is only for v0 compatibility.
        elif inputs_embeds is None:
            image_input = self._parse_and_validate_image_input(**kwargs)
            video_input = self._parse_and_validate_video_input(**kwargs)

            if image_input is None and video_input is None:
                inputs_embeds = None
            else:
                if uses_mrope(self.config):
                    assert positions.ndim == 2 and positions.size(0) == 3, (
                        "multimodal section rotary embedding requires "
                        f"(3, seq_len) positions, but got {positions.size()}")
                inputs_embeds = self.get_input_embeddings_v0(
                    input_ids,
                    image_input=image_input,
                    video_input=video_input)
                input_ids = None

        hidden_states = self.language_model.model(
            input_ids=input_ids,
            positions=positions,
            intermediate_tensors=intermediate_tensors,
            inputs_embeds=inputs_embeds,
        )
        return hidden_states

    def compute_logits(
        self,
        hidden_states: torch.Tensor,
        sampling_metadata: SamplingMetadata,
    ) -> Optional[torch.Tensor]:
        return self.language_model.compute_logits(hidden_states,
                                                  sampling_metadata)

    def load_weights(self, weights: Iterable[tuple[str,
                                                   torch.Tensor]]) -> set[str]:
        """Load checkpoint weights via ``hf_to_vllm_mapper``.

        "visual." weights are skipped when no vision tower was built.
        """
        skip_prefixes = []
        if self.visual is None:
            skip_prefixes.extend(["visual."])
        loader = AutoWeightsLoader(self, skip_prefixes=skip_prefixes)
        return loader.load_weights(weights, mapper=self.hf_to_vllm_mapper)

    def get_mm_mapping(self) -> MultiModelKeys:
        """
        Get the module prefix in multimodal models
        """
        return MultiModelKeys.from_string_field(
            language_model="language_model",
            connector="visual.merger.",
            tower_model="visual.",
        )
1425
1426
1427
1428
1429
1430
1431
1432
1433
1434
1435
1436
1437
1438
1439
1440
1441
1442
1443
1444
1445
1446
1447
1448
1449
1450
1451
1452
1453
1454
1455
1456
1457


class Tarsier2MultiModalProcessor(Qwen2VLMultiModalProcessor):
    """Tarsier2 reuses the Qwen2-VL multimodal processing unchanged."""


class Tarsier2ImageProcessor(Qwen2VLImageProcessor):
    """Qwen2-VL image processor that also accepts Tarsier2's size format.

    Tarsier2 configs specify ``size={"min_pixels": ..., "max_pixels": ...}``.
    When both keys are present, they are translated to the
    ``shortest_edge`` / ``longest_edge`` keys expected by the base class.
    Any other ``size`` value is passed through untouched.
    """

    def __init__(
        self,
        size: Optional[dict[str, int]] = None,
        **kwargs,
    ) -> None:
        uses_tarsier2_keys = (size is not None and "min_pixels" in size
                              and "max_pixels" in size)
        if uses_tarsier2_keys:
            size = {
                "shortest_edge": size["min_pixels"],
                "longest_edge": size["max_pixels"],
            }
        super().__init__(size=size, **kwargs)


class Tarsier2Processor(Qwen2VLProcessor):
    """Qwen2-VL processor built from Tarsier2's image-processor config.

    The same ``vision_config`` dict configures both the image processor and
    the video processor. No chat template is attached.
    """

    def __init__(
        self,
        vision_config: dict,
        tokenizer: AnyTokenizer,
        **kwargs,
    ):
        image_processor = Tarsier2ImageProcessor(**vision_config)
        self.image_processor = image_processor
        video_processor = Qwen2VLVideoProcessor(**vision_config)
        super().__init__(image_processor=image_processor,
                         tokenizer=tokenizer,
                         video_processor=video_processor,
                         chat_template=None,
                         **kwargs)
1464
1465
1466
1467
1468
1469
1470
1471
1472
1473
1474
1475
1476
1477
1478
1479
1480
1481
1482
1483
1484
1485
1486
1487
1488
1489
1490
1491
1492
1493
1494
1495
1496
1497
1498
1499
1500
1501
1502
1503
1504
1505
1506
1507


class Tarsier2ProcessingInfo(Qwen2VLProcessingInfo):
    """Processing info for Tarsier2 checkpoints.

    The HF config is rebuilt as a ``Qwen2VLConfig``, and the processors are
    constructed from the checkpoint's image-processor config.
    """

    def get_hf_config(self) -> Qwen2VLConfig:
        """Re-create a ``Qwen2VLConfig`` from the checkpoint's config dict."""
        model_path = self.ctx.model_config.model
        original_config = AutoConfig.from_pretrained(model_path)
        config_dict = original_config.to_dict()
        correct_config = Qwen2VLConfig.from_dict(config_dict)

        return correct_config

    def get_hf_processor(self, **kwargs: object) -> Tarsier2Processor:
        return Tarsier2Processor(
            vision_config=self.ctx.get_hf_image_processor_config(),
            tokenizer=self.get_tokenizer(),
            **kwargs,
        )

    def get_image_processor(self, **kwargs: object) -> Tarsier2ImageProcessor:
        """Build the image processor from the checkpoint's config.

        ``**kwargs`` is accepted so the signature matches the base class.
        ``Qwen2VLMultiModalProcessor._get_prompt_updates`` calls this with
        ``**hf_processor_mm_kwargs``, and the previous zero-argument
        signature raised ``TypeError`` for any non-empty kwargs. The kwargs
        are intentionally ignored: the processor is always built from the
        checkpoint's image-processor config, as before.
        """
        return Tarsier2ImageProcessor(
            **self.ctx.get_hf_image_processor_config())


@MULTIMODAL_REGISTRY.register_processor(Tarsier2MultiModalProcessor,
                                        info=Tarsier2ProcessingInfo,
                                        dummy_inputs=Qwen2VLDummyInputsBuilder)
class Tarsier2ForConditionalGeneration(Qwen2VLForConditionalGeneration):
    """Tarsier2 model: Qwen2-VL loaded from a LLaVA-style checkpoint.

    The checkpoint's vision weights live under "vision_tower.", which is
    remapped to this model's "visual." prefix.
    """

    hf_to_vllm_mapper = WeightsMapper(orig_to_new_prefix={
        "vision_tower.": "visual.",
    })

    def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
        # Tarsier2 uses llava as model_type, which will create a Qwen2VLConfig
        # as text_config, we need to reconstruct Qwen2VLConfig from LlavaConfig.
        model_config = vllm_config.model_config
        llava_config = model_config.hf_config
        qwen2vl_config = llava_config.text_config
        qwen2vl_config.architectures = llava_config.architectures
        model_config.hf_config = qwen2vl_config
        super().__init__(vllm_config=vllm_config, prefix=prefix)

    def load_weights(self, weights: Iterable[tuple[str,
                                                   torch.Tensor]]) -> set[str]:
        """Load weights with the Tarsier2 prefix mapping.

        "visual." weights are skipped when no vision tower was built.
        """
        skip_prefixes = ["visual."] if self.visual is None else []
        return AutoWeightsLoader(
            self, skip_prefixes=skip_prefixes).load_weights(
                weights, mapper=self.hf_to_vllm_mapper)