# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project

# Adapted from
# https://github.com/huggingface/transformers/blob/19e6e80e10118f855137b90740936c0b11ac397f/src/transformers/models/qwen2_vl/modeling_qwen2_vl.py
# Copyright 2024 The Qwen team.
# Copyright 2023 The vLLM team.
# Copyright 2022 EleutherAI and the HuggingFace Inc. team. All rights reserved.
#
# This code is based on EleutherAI's GPT-NeoX library and the GPT-NeoX
# and OPT implementations in this library. It has been modified from its
# original forms to accommodate minor architectural differences compared
# to GPT-NeoX and OPT used by the Meta AI team that trained the model.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Inference-only Qwen2-VL model compatible with HuggingFace weights."""
27
from collections.abc import Iterable, Mapping, Sequence
28
from functools import partial
29
from typing import Any, Callable, Literal, Optional, TypedDict, Union
30
31
32
33
34

import torch
import torch.nn as nn
import torch.nn.functional as F
from einops import rearrange, repeat
35
from transformers import AutoConfig, BatchFeature
36
37
from transformers.models.qwen2_vl import (Qwen2VLImageProcessor,
                                          Qwen2VLProcessor)
38
39
from transformers.models.qwen2_vl.configuration_qwen2_vl import (
    Qwen2VLConfig, Qwen2VLVisionConfig)
40
from transformers.models.qwen2_vl.image_processing_qwen2_vl import smart_resize
41
42
from transformers.models.qwen2_vl.video_processing_qwen2_vl import (
    Qwen2VLVideoProcessor)
43

44
from vllm.config import VllmConfig
45
from vllm.distributed import parallel_state, tensor_model_parallel_all_gather
46
47
48
49
50
51
from vllm.distributed import utils as dist_utils
from vllm.logger import init_logger
from vllm.model_executor import SamplingMetadata
from vllm.model_executor.layers.activation import QuickGELU
from vllm.model_executor.layers.linear import (ColumnParallelLinear,
                                               RowParallelLinear)
52
53
54
55
from vllm.model_executor.layers.quantization import QuantizationConfig
from vllm.model_executor.layers.quantization.gptq import GPTQConfig
from vllm.model_executor.layers.quantization.gptq_marlin import (
    GPTQMarlinConfig)
56
from vllm.model_executor.model_loader.weight_utils import default_weight_loader
57
from vllm.model_executor.models.module_mapping import MultiModelKeys
58
from vllm.multimodal import MULTIMODAL_REGISTRY
59
from vllm.multimodal.inputs import (ImageItem, ModalityData,
60
61
                                    MultiModalDataDict, MultiModalFieldConfig,
                                    MultiModalKwargs, VideoItem)
62
63
64
from vllm.multimodal.parse import (DictEmbeddingItems, ImageSize,
                                   ModalityDataItems, MultiModalDataItems,
                                   MultiModalDataParser)
65
from vllm.multimodal.processing import (BaseMultiModalProcessor,
66
67
                                        BaseProcessingInfo, PromptReplacement,
                                        PromptUpdate)
68
from vllm.multimodal.profiling import BaseDummyInputsBuilder
69
from vllm.platforms import _Backend, current_platform
70
from vllm.sequence import IntermediateTensors
71
from vllm.transformers_utils.config import uses_mrope
72
from vllm.transformers_utils.tokenizer import AnyTokenizer
73

74
75
from .interfaces import (MultiModalEmbeddings, SupportsLoRA,
                         SupportsMultiModal, SupportsPP)
76
from .utils import (AutoWeightsLoader, WeightsMapper,
77
78
                    init_vllm_registered_model, maybe_prefix,
                    merge_multimodal_embeddings)
79
from .vision import get_vit_attn_backend

logger = init_logger(__name__)

# Cap on the number of frames per video assumed during the profile run.
_MAX_FRAMES_PER_VIDEO = 16

# === Vision Inputs === #


class Qwen2VLImagePixelInputs(TypedDict):
    """Image inputs given as flattened pixel patches."""
    type: Literal["pixel_values"]
    pixel_values: torch.Tensor
    """Shape:
    `(num_patches,
      num_channels * temporal_patch_size * patch_size * patch_size)`

    `Qwen2VisionPatchEmbed.forward` views each row as
    `(num_channels, temporal_patch_size, patch_size, patch_size)`, so the
    row width includes `temporal_patch_size` (assuming image patches go
    through the same vision tower as video patches).
    """

    image_grid_thw: torch.Tensor
    """Shape: `(num_images, 3)`
    This should be in `(grid_t, grid_h, grid_w)` format.
    """


class Qwen2VLImageEmbeddingInputs(TypedDict):
    """Image inputs given as precomputed vision embeddings."""
    type: Literal["image_embeds"]
    image_embeds: torch.Tensor
    """Supported types:
    - list[`torch.Tensor`]: A list of tensors holding all images' features.
        Each tensor holds an image's features.
    - `torch.Tensor`: A tensor holding all images' features
        (concatenation of all images' feature tensors).

    Tensor shape: `(num_image_features, hidden_size)`
    - `num_image_features` varies based on
        the number and resolution of the images.
    - `hidden_size` must match the hidden size of language model backbone.
    """

    image_grid_thw: torch.Tensor
    """Shape: `(num_images, 3)`
    This should be in `(grid_t, grid_h, grid_w)` format.
    """


# Image inputs arrive either as raw pixel patches or as precomputed embeds.
Qwen2VLImageInputs = Union[Qwen2VLImagePixelInputs,
                           Qwen2VLImageEmbeddingInputs]


class Qwen2VLVideoPixelInputs(TypedDict):
    """Video inputs given as flattened spatio-temporal pixel patches."""
    type: Literal["pixel_values_videos"]
    pixel_values_videos: torch.Tensor
    """Shape:
    `(num_patches,
      num_channels * temporal_patch_size * patch_size * patch_size)`
    """

    video_grid_thw: torch.Tensor
    """Shape: `(num_videos, 3)`
    This should be in `(grid_t, grid_h, grid_w)` format.
    """


class Qwen2VLVideoEmbeddingInputs(TypedDict):
    """Video inputs given as precomputed vision embeddings."""
    type: Literal["video_embeds"]
    video_embeds: torch.Tensor
    """Supported types:
    - list[`torch.Tensor`]: A list of tensors holding all videos' features.
        Each tensor holds a video's features.
    - `torch.Tensor`: A tensor holding all videos' features
        (concatenation of all videos' feature tensors).

    Tensor shape: `(num_video_features, hidden_size)`
    - `num_video_features` varies based on
        the number and resolution of the videos.
    - `hidden_size` must match the hidden size of language model backbone.
    """

    video_grid_thw: torch.Tensor
    """Shape: `(num_videos, 3)`
    This should be in `(grid_t, grid_h, grid_w)` format.
    """


# Video inputs arrive either as raw pixel patches or as precomputed embeds.
Qwen2VLVideoInputs = Union[Qwen2VLVideoPixelInputs,
                           Qwen2VLVideoEmbeddingInputs]


# === Vision Encoder === #


class Qwen2VisionMLP(nn.Module):
    """Feed-forward sub-layer of a vision block: fc1 -> activation -> fc2.

    ``fc1`` is a ``ColumnParallelLinear`` and ``fc2`` a ``RowParallelLinear``;
    each returns a pair of which only the first element is used.

    Args:
        in_features: Input (and output) feature size.
        hidden_features: Width of the intermediate projection.
        act_layer: Activation module class, instantiated with no arguments.
        quant_config: Optional quantization config forwarded to both layers.
        prefix: Parameter-name prefix; layers are named ``{prefix}.fc1`` and
            ``{prefix}.fc2``.
    """

    def __init__(
        self,
        in_features: int,
        hidden_features: int,
        act_layer: type[nn.Module] = QuickGELU,
        quant_config: Optional[QuantizationConfig] = None,
        prefix: str = "",
    ):
        super().__init__()
        self.fc1 = ColumnParallelLinear(in_features,
                                        hidden_features,
                                        quant_config=quant_config,
                                        prefix=f"{prefix}.fc1")
        self.act = act_layer()
        self.fc2 = RowParallelLinear(hidden_features,
                                     in_features,
                                     quant_config=quant_config,
                                     prefix=f"{prefix}.fc2")

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        x_parallel, _ = self.fc1(x)
        x_parallel = self.act(x_parallel)
        x, _ = self.fc2(x_parallel)
        return x


def rotate_half(x: torch.Tensor, interleaved: bool = False) -> torch.Tensor:
    """Rotate feature pairs of ``x`` by 90 degrees for rotary embeddings.

    Non-interleaved: the last dimension is split into halves ``(a, b)`` and
    the result is ``(-b, a)``.  Interleaved: even/odd positions form the
    pairs, and each adjacent ``(a, b)`` pair becomes ``(-b, a)`` in place.
    """
    if interleaved:
        evens = x[..., ::2]
        odds = x[..., 1::2]
        # Put each rotated pair on a new trailing axis, then fold that axis
        # back so the two members of every pair are adjacent again.
        return torch.stack((-odds, evens), dim=-1).flatten(start_dim=-2)
    first, second = x.chunk(2, dim=-1)
    return torch.cat((-second, first), dim=-1)


def apply_rotary_emb_torch(x: torch.Tensor,
                           cos: torch.Tensor,
                           sin: torch.Tensor,
                           interleaved: bool = False) -> torch.Tensor:
    """Pure-PyTorch rotary embedding.

    x: (batch_size, seqlen, nheads, headdim)
    cos, sin: (seqlen, rotary_dim / 2) or (batch_size, seqlen, rotary_dim / 2)

    Only the first ``rotary_dim`` channels of ``x`` are rotated; any
    remaining channels are concatenated back unchanged.
    """
    rotary_dim = 2 * cos.shape[-1]
    assert rotary_dim <= x.shape[-1]

    def _widen(table: torch.Tensor) -> torch.Tensor:
        # Bring a half-width table to full rotary width, either as
        # element-wise pairs (interleaved) or two back-to-back copies, then
        # add a size-1 axis that broadcasts over the heads dimension.
        if interleaved:
            full = table.repeat_interleave(2, dim=-1)
        else:
            full = torch.cat((table, table), dim=-1)
        return full.unsqueeze(-2)

    cos_full = _widen(cos)
    sin_full = _widen(sin)
    head = x[..., :rotary_dim]
    tail = x[..., rotary_dim:]
    rotated = head * cos_full + rotate_half(head, interleaved) * sin_full
    return torch.cat([rotated, tail], dim=-1)


def apply_rotary_pos_emb_vision(t: torch.Tensor,
                                freqs: torch.Tensor) -> torch.Tensor:
    """Rotate ``t`` by the rotary angles in ``freqs``.

    The rotation is computed on a float32 copy of ``t`` and the result is
    cast back to ``t``'s dtype.  On CUDA the rotary implementation from
    ``vllm.vllm_flash_attn`` is used; elsewhere the pure-PyTorch
    ``apply_rotary_emb_torch`` fallback.

    Args:
        t: Tensor to rotate (query or key).
        freqs: Rotary angles; their ``cos``/``sin`` are passed to the kernel.

    Returns:
        The rotated tensor, with the same dtype as ``t``.
    """
    t_ = t.float()
    cos = freqs.cos()
    sin = freqs.sin()
    apply_rotary_emb = apply_rotary_emb_torch
    if current_platform.is_cuda():
        from vllm.vllm_flash_attn.layers.rotary import apply_rotary_emb
    output = apply_rotary_emb(t_, cos, sin).type_as(t)
    return output


class Qwen2VisionAttention(nn.Module):
    """Multi-head self-attention for the Qwen2-VL vision tower.

    Attention is computed separately for each sequence delimited by
    ``cu_seqlens``, using the backend returned by ``get_vit_attn_backend``
    (FlashAttention, torch SDPA or xFormers; anything else raises).

    Args:
        embed_dim: Input/output feature size.
        num_heads: Total number of attention heads (across all TP ranks).
        projection_size: Total q/k/v projection width.
        quant_config: Optional quantization config for ``qkv``/``proj``.
        prefix: Parameter-name prefix.

    Raises:
        RuntimeError: If the selected attention backend is unsupported.
    """

    def __init__(
        self,
        embed_dim: int,
        num_heads: int,
        projection_size: int,
        quant_config: Optional[QuantizationConfig] = None,
        prefix: str = "",
    ) -> None:
        super().__init__()
        # Per attention head and per partition values.
        world_size = parallel_state.get_tensor_model_parallel_world_size()
        self.tp_size = world_size
        self.tp_rank = parallel_state.get_tensor_model_parallel_rank()
        self.hidden_size_per_attention_head = dist_utils.divide(
            projection_size, num_heads)
        self.num_attention_heads_per_partition = dist_utils.divide(
            num_heads, world_size)

        self.qkv = ColumnParallelLinear(input_size=embed_dim,
                                        output_size=3 * projection_size,
                                        quant_config=quant_config,
                                        prefix=f"{prefix}.qkv")
        self.proj = RowParallelLinear(input_size=projection_size,
                                      output_size=embed_dim,
                                      quant_config=quant_config,
                                      prefix=f"{prefix}.proj")

        # Detect attention implementation.
        self.attn_backend: _Backend = get_vit_attn_backend(support_fa=True)
        if self.attn_backend not in {
                _Backend.FLASH_ATTN, _Backend.TORCH_SDPA, _Backend.XFORMERS
        }:
            raise RuntimeError(
                f"Qwen2-VL does not support {self.attn_backend} backend now.")

    def split_qkv(self, qkv: torch.Tensor) -> tuple[torch.Tensor, ...]:
        """Split a fused qkv activation into this rank's q, k and v heads.

        With tensor parallelism the shards are all-gathered first, so that
        q/k/v can be separated before each is re-split along the last dim
        and this rank's slice is kept.
        """
        # [s, b, 3 * head * head_dim]
        seq_len, bs, _ = qkv.shape
        if self.tp_size > 1:
            qkv = tensor_model_parallel_all_gather(qkv)

        # [s, b, 3 * head * head_dim] -> 3 * [s, b, head * head_dim]
        q, k, v = qkv.chunk(3, dim=2)

        # 3 * [s, b, head * head_dim]
        if self.tp_size > 1:
            splitter = partial(dist_utils.split_tensor_along_last_dim,
                               num_partitions=self.tp_size)
            q = splitter(q)[self.tp_rank]
            k = splitter(k)[self.tp_rank]
            v = splitter(v)[self.tp_rank]

        # 3 * [s, b, head * head_dim] -> 3 * [s, b, head, head_dim]
        new_shape = (seq_len, bs, self.num_attention_heads_per_partition,
                     self.hidden_size_per_attention_head)
        q, k, v = (x.view(*new_shape) for x in (q, k, v))
        return q, k, v

    def forward(
            self,
            x: torch.Tensor,
            cu_seqlens: torch.Tensor,
            rotary_pos_emb: torch.Tensor,
            max_seqlen: Optional[int] = None,  # Only used for Flash Attention
            seqlens: Optional[list[int]] = None,  # Only used for xFormers
    ) -> torch.Tensor:
        """Run attention over ``x`` laid out as ``[s, b, c]``.

        Args:
            x: Input activations ``[s, b, c]``.
            cu_seqlens: Cumulative sequence boundaries (starting at 0).
            rotary_pos_emb: Rotary angles applied to q and k; skipped if None.
            max_seqlen: Longest sequence length (FlashAttention only).
            seqlens: Per-sequence lengths (xFormers only).

        Returns:
            Output of ``proj`` applied to the ``[s, b, heads * head_dim]``
            attention context.
        """
        # [s, b, c] --> [s, b, 3 * head * head_dim]
        x, _ = self.qkv(x)

        # [s, b, 3 * head * head_dim] -> 3 * [s, b, head, head_dim]
        q, k, v = self.split_qkv(x)
        batch_size = q.shape[1]

        q, k, v = (rearrange(x, "s b ... -> b s ...").contiguous()
                   for x in (q, k, v))
        if rotary_pos_emb is not None:
            q = apply_rotary_pos_emb_vision(q, rotary_pos_emb)
            k = apply_rotary_pos_emb_vision(k, rotary_pos_emb)

        if self.attn_backend == _Backend.FLASH_ATTN:
            # from vllm_flash_attn.flash_attn_interface import (
            #   flash_attn_varlen_func)
            from flash_attn import flash_attn_varlen_func

            q, k, v = (rearrange(x, "b s ... -> (b s) ...") for x in [q, k, v])

            output = flash_attn_varlen_func(q,
                                            k,
                                            v,
                                            cu_seqlens_q=cu_seqlens,
                                            cu_seqlens_k=cu_seqlens,
                                            max_seqlen_q=max_seqlen,
                                            max_seqlen_k=max_seqlen,
                                            dropout_p=0,
                                            causal=False)

            context_layer = rearrange(output,
                                      "(b s) ... -> b s ...",
                                      b=batch_size)
        elif self.attn_backend == _Backend.TORCH_SDPA:
            # Execute attention entry by entry for speed & less VRAM.
            # Read the boundaries to host once: slicing with device scalars
            # would force a synchronization on every iteration.
            boundaries = cu_seqlens.tolist()
            outputs = []
            for start_idx, end_idx in zip(boundaries[:-1], boundaries[1:]):
                q_i = q[:, start_idx:end_idx]
                k_i = k[:, start_idx:end_idx]
                v_i = v[:, start_idx:end_idx]
                q_i, k_i, v_i = (rearrange(x, "b s h d -> b h s d")
                                 for x in [q_i, k_i, v_i])
                output_i = F.scaled_dot_product_attention(q_i,
                                                          k_i,
                                                          v_i,
                                                          dropout_p=0.0)
                output_i = rearrange(output_i, "b h s d -> b s h d ")
                outputs.append(output_i)
            context_layer = torch.cat(outputs, dim=1)
        elif self.attn_backend == _Backend.XFORMERS:
            from xformers import ops as xops
            from xformers.ops.fmha.attn_bias import BlockDiagonalMask

            attn_bias = BlockDiagonalMask.from_seqlens(q_seqlen=seqlens,
                                                       kv_seqlen=None,
                                                       device=q.device)

            context_layer = xops.memory_efficient_attention_forward(
                q, k, v, attn_bias=attn_bias, p=0, scale=None)
        context_layer = rearrange(context_layer,
                                  "b s h d -> s b (h d)").contiguous()

        output, _ = self.proj(context_layer)
        return output


class Qwen2VisionBlock(nn.Module):
    """One vision transformer layer: pre-norm attention and pre-norm MLP.

    ``forward`` computes ``x + attn(norm1(x))`` followed by
    ``x + mlp(norm2(x))``.

    Args:
        dim: Model width.
        num_heads: Number of attention heads.
        mlp_ratio: MLP hidden width as a multiple of ``dim``.
        act_layer: Activation class used inside the MLP.
        norm_layer: Norm factory taking ``dim``; defaults to
            ``nn.LayerNorm(eps=1e-6)``.
        quant_config: Optional quantization config.
        prefix: Parameter-name prefix.
    """

    def __init__(
        self,
        dim: int,
        num_heads: int,
        mlp_ratio: float,
        act_layer: type[nn.Module] = QuickGELU,
        norm_layer: Optional[Callable[[int], nn.Module]] = None,
        quant_config: Optional[QuantizationConfig] = None,
        prefix: str = "",
    ) -> None:
        super().__init__()
        if norm_layer is None:
            norm_layer = partial(nn.LayerNorm, eps=1e-6)
        self.norm1 = norm_layer(dim)
        self.norm2 = norm_layer(dim)
        mlp_hidden_dim = int(dim * mlp_ratio)

        self.attn = Qwen2VisionAttention(embed_dim=dim,
                                         num_heads=num_heads,
                                         projection_size=dim,
                                         quant_config=quant_config,
                                         prefix=f"{prefix}.attn")
        self.mlp = Qwen2VisionMLP(dim,
                                  mlp_hidden_dim,
                                  act_layer=act_layer,
                                  quant_config=quant_config,
                                  prefix=f"{prefix}.mlp")

    def forward(
            self,
            x: torch.Tensor,
            cu_seqlens: torch.Tensor,
            rotary_pos_emb: torch.Tensor,
            max_seqlen: Optional[int] = None,  # Only used for Flash Attention
            seqlens: Optional[list[int]] = None,  # Only used for xFormers
    ) -> torch.Tensor:
        x = x + self.attn(
            self.norm1(x),
            cu_seqlens=cu_seqlens,
            rotary_pos_emb=rotary_pos_emb,
            max_seqlen=max_seqlen,
            seqlens=seqlens,
        )

        x = x + self.mlp(self.norm2(x))
        return x


class Qwen2VisionPatchEmbed(nn.Module):
    """Project flattened pixel patches to ``embed_dim`` with a Conv3d.

    The convolution kernel and stride both equal
    ``(temporal_patch_size, patch_size, patch_size)``. Each input row is
    therefore reduced to a single output vector, so the layer acts as a
    linear projection of every patch.

    Args:
        patch_size: Spatial patch edge length.
        temporal_patch_size: Number of frames per patch.
        in_channels: Number of image channels.
        embed_dim: Output feature size.
    """

    def __init__(
        self,
        patch_size: int = 14,
        temporal_patch_size: int = 2,
        in_channels: int = 3,
        embed_dim: int = 1152,
    ) -> None:
        super().__init__()
        self.patch_size = patch_size
        self.temporal_patch_size = temporal_patch_size
        self.embed_dim = embed_dim

        kernel_size = (temporal_patch_size, patch_size, patch_size)
        self.proj = nn.Conv3d(in_channels,
                              embed_dim,
                              kernel_size=kernel_size,
                              stride=kernel_size,
                              bias=False)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Embed ``x`` of shape ``(num_patches, in_channels * t * p * p)``.

        Returns a ``(num_patches, embed_dim)`` tensor.
        """
        # Unpacking also enforces that x is 2-D.
        num_patches, _ = x.shape
        x = x.view(num_patches, -1, self.temporal_patch_size,
                   self.patch_size, self.patch_size)
        x = self.proj(x).view(num_patches, self.embed_dim)
        return x


class Qwen2VisionPatchMerger(nn.Module):
    """Merge groups of vision tokens and project them to ``d_model``.

    The layer first applies ``ln_q`` over ``context_dim`` features. It then
    reshapes so that each row concatenates ``spatial_merge_size ** 2``
    consecutive tokens (``hidden_size = context_dim * spatial_merge_size**2``).
    Finally, each row is projected with Linear -> GELU -> Linear.

    Args:
        d_model: Output feature size.
        context_dim: Feature size of each incoming token.
        norm_layer: Norm factory taking ``context_dim``; defaults to
            ``nn.LayerNorm(eps=1e-6)``.
        spatial_merge_size: Tokens merged per side of each group.
        quant_config: Optional quantization config for the MLP layers.
        prefix: Parameter-name prefix (layers are ``mlp.0`` and ``mlp.2``).
    """

    def __init__(
        self,
        d_model: int,
        context_dim: int,
        norm_layer: Optional[Callable[[int], nn.Module]] = None,
        spatial_merge_size: int = 2,
        quant_config: Optional[QuantizationConfig] = None,
        prefix: str = "",
    ) -> None:
        super().__init__()
        self.hidden_size = context_dim * (spatial_merge_size**2)
        if norm_layer is None:
            norm_layer = partial(nn.LayerNorm, eps=1e-6)
        self.ln_q = norm_layer(context_dim)
        self.mlp = nn.ModuleList([
            ColumnParallelLinear(self.hidden_size,
                                 self.hidden_size,
                                 bias=True,
                                 quant_config=quant_config,
                                 prefix=f"{prefix}.mlp.0"),
            nn.GELU(),
            RowParallelLinear(self.hidden_size,
                              d_model,
                              bias=True,
                              quant_config=quant_config,
                              prefix=f"{prefix}.mlp.2"),
        ])

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        x = self.ln_q(x)
        x = x.view(-1, self.hidden_size)

        mlp_fc1, mlp_act, mlp_fc2 = self.mlp
        x_parallel, _ = mlp_fc1(x)
        x_parallel = mlp_act(x_parallel)
        out, _ = mlp_fc2(x_parallel)
        return out


class Qwen2VisionRotaryEmbedding(nn.Module):
    """Lazily grown table of rotary angles ``position * inv_freq``.

    ``forward(seqlen)`` returns a ``(seqlen, (dim + 1) // 2)`` table whose
    row ``p`` holds the angles for position ``p``.  The cached table is
    rebuilt at twice the requested length whenever a longer one is needed.
    """

    def __init__(self, dim: int, theta: float = 10000.0) -> None:
        super().__init__()
        self.dim = dim
        self.theta = theta
        inv_freq = 1.0 / (theta
                          **(torch.arange(0, dim, 2, dtype=torch.float) / dim))
        self.register_buffer("inv_freq", inv_freq, persistent=False)
        self._seq_len_cached = 0
        self._freqs_cached = None

    def update_freqs_cache(self, seqlen: int) -> None:
        """Ensure the cached table covers at least ``seqlen`` positions.

        ``seqlen`` may be a Python int or a 0-d tensor (``rot_pos_emb``
        passes the result of ``Tensor.max()``).  It is never modified in
        place: an in-place ``*=`` on a tensor would also change the
        caller's object, making ``forward`` return twice as many rows.
        """
        if seqlen > self._seq_len_cached:
            # Over-allocate so that slightly longer requests reuse the cache.
            cache_len = seqlen * 2
            self._seq_len_cached = cache_len
            # Recompute inv_freq in float32 on the buffer's current device.
            self.inv_freq = 1.0 / (self.theta**(torch.arange(
                0, self.dim, 2, dtype=torch.float, device=self.inv_freq.device)
                                                / self.dim))
            seq = torch.arange(cache_len,
                               device=self.inv_freq.device,
                               dtype=self.inv_freq.dtype)
            freqs = torch.outer(seq, self.inv_freq)
            self._freqs_cached = freqs

    def forward(self, seqlen: int) -> torch.Tensor:
        self.update_freqs_cache(seqlen)
        return self._freqs_cached[:seqlen]


class Qwen2VisionTransformer(nn.Module):
    """Qwen2-VL vision encoder: patch embedding, transformer blocks, merger.

    ``forward`` takes the flattened patches of every image/video in a batch
    together with their ``(t, h, w)`` grid sizes. It returns the output of
    ``Qwen2VisionPatchMerger`` applied to the final hidden states.

    Args:
        vision_config: HF vision config supplying all layer sizes.
        norm_eps: Epsilon for every LayerNorm in the blocks and merger.
        quant_config: Optional quantization config.
        prefix: Parameter-name prefix.
    """

    def __init__(
        self,
        vision_config: Qwen2VLVisionConfig,
        norm_eps: float = 1e-6,
        quant_config: Optional[QuantizationConfig] = None,
        prefix: str = "",
    ) -> None:
        super().__init__()

        patch_size = vision_config.patch_size
        temporal_patch_size = vision_config.temporal_patch_size
        spatial_merge_size = vision_config.spatial_merge_size
        in_channels = vision_config.in_channels
        hidden_size = vision_config.hidden_size
        embed_dim = vision_config.embed_dim
        depth = vision_config.depth
        num_heads = vision_config.num_heads
        mlp_ratio = vision_config.mlp_ratio

        self.spatial_merge_size = spatial_merge_size
        self.num_heads = num_heads
        self.embed_dim = embed_dim

        self.patch_embed = Qwen2VisionPatchEmbed(
            patch_size=patch_size,
            temporal_patch_size=temporal_patch_size,
            in_channels=in_channels,
            embed_dim=embed_dim,
        )

        norm_layer = partial(nn.LayerNorm, eps=norm_eps)
        head_dim = embed_dim // num_heads
        self.rotary_pos_emb = Qwen2VisionRotaryEmbedding(head_dim // 2)

        self.blocks = nn.ModuleList([
            Qwen2VisionBlock(dim=embed_dim,
                             num_heads=num_heads,
                             mlp_ratio=mlp_ratio,
                             norm_layer=norm_layer,
                             quant_config=quant_config,
                             prefix=f"{prefix}.blocks.{layer_idx}")
            for layer_idx in range(depth)
        ])
        self.merger = Qwen2VisionPatchMerger(
            d_model=hidden_size,
            context_dim=embed_dim,
            norm_layer=norm_layer,
            quant_config=quant_config,
            prefix=f"{prefix}.merger",
        )
        self.attn_backend: _Backend = get_vit_attn_backend(support_fa=True)

    @property
    def dtype(self) -> torch.dtype:
        return self.patch_embed.proj.weight.dtype

    @property
    def device(self) -> torch.device:
        return self.patch_embed.proj.weight.device

    def rot_pos_emb(self, grid_thw: torch.Tensor) -> torch.Tensor:
        """Build per-token rotary angles for every ``(t, h, w)`` grid.

        The row/column position ids of each grid are reordered into
        ``spatial_merge_size x spatial_merge_size`` windows and repeated
        ``t`` times. Each token's row and column angles are then looked up
        and concatenated. The window ordering presumably keeps every merge
        group contiguous for the merger.
        """
        pos_ids = []
        for t, h, w in grid_thw:
            hpos_ids = torch.arange(h).unsqueeze(1).expand(-1, w)
            wpos_ids = torch.arange(w).unsqueeze(0).expand(h, -1)
            hpos_ids = hpos_ids.reshape(
                h // self.spatial_merge_size,
                self.spatial_merge_size,
                w // self.spatial_merge_size,
                self.spatial_merge_size,
            ).permute(0, 2, 1, 3).flatten()
            wpos_ids = wpos_ids.reshape(
                h // self.spatial_merge_size,
                self.spatial_merge_size,
                w // self.spatial_merge_size,
                self.spatial_merge_size,
            ).permute(0, 2, 1, 3).flatten()
            pos_ids.append(
                torch.stack([hpos_ids, wpos_ids], dim=-1).repeat(t, 1))
        pos_ids = torch.cat(pos_ids, dim=0)
        max_grid_size = grid_thw[:, 1:].max()
        rotary_pos_emb_full = self.rotary_pos_emb(max_grid_size)
        rotary_pos_emb = rotary_pos_emb_full[pos_ids].flatten(1)
        return rotary_pos_emb

    def compute_attn_mask_seqlen(
            self, cu_seqlens: torch.Tensor
    ) -> tuple[Optional[int], Optional[list[int]]]:
        """Precompute the per-backend sequence metadata once per forward.

        Returns ``(max_seqlen, seqlens)``. ``max_seqlen`` is set only for
        FlashAttention and ``seqlens`` only for xFormers; the other is None.
        """
        max_seqlen, seqlens = None, None
        if self.attn_backend == _Backend.FLASH_ATTN:
            max_seqlen = (cu_seqlens[1:] - cu_seqlens[:-1]).max().item()
        elif self.attn_backend == _Backend.XFORMERS:
            seqlens = (cu_seqlens[1:] - cu_seqlens[:-1]).tolist()
        return max_seqlen, seqlens

    def forward(
        self,
        x: torch.Tensor,
        grid_thw: torch.Tensor,
    ) -> torch.Tensor:
        """Encode flattened patches ``x`` laid out by ``grid_thw``.

        Args:
            x: Flattened patches of all items, concatenated along dim 0.
            grid_thw: One ``(t, h, w)`` row per item.

        Returns:
            The merger output for all items.
        """
        # patchify
        x = x.to(device=self.device, dtype=self.dtype)
        x = self.patch_embed(x)

        # compute position embedding
        rotary_pos_emb = self.rot_pos_emb(grid_thw)

        # compute cu_seqlens: one sequence of h * w tokens per temporal
        # slice, so each grid contributes t sequences.
        cu_seqlens = torch.repeat_interleave(grid_thw[:, 1] * grid_thw[:, 2],
                                             grid_thw[:, 0]).cumsum(
                                                 dim=0, dtype=torch.int32)
        cu_seqlens = F.pad(cu_seqlens, (1, 0), "constant", 0)

        # transformers: add the batch axis expected as [s, b, c]
        x = x.unsqueeze(1)

        # pre-compute seqlens for attn mask to reduce cuMemcpy operations
        max_seqlen, seqlens = self.compute_attn_mask_seqlen(cu_seqlens)
        for blk in self.blocks:
            x = blk(
                x,
                cu_seqlens=cu_seqlens,
                rotary_pos_emb=rotary_pos_emb,
                max_seqlen=max_seqlen,
                seqlens=seqlens,
            )

        # adapter
        x = self.merger(x)

        return x

    def load_weights(self, weights: Iterable[tuple[str,
                                                   torch.Tensor]]) -> set[str]:
        """Load checkpoint tensors into this module's parameters.

        Names matching an entry of ``stacked_params_mapping`` are renamed
        and loaded as a shard of the fused parameter. All others go through
        the parameter's own ``weight_loader`` (or ``default_weight_loader``).
        Unknown names raise ``KeyError``.

        Returns:
            The set of (possibly renamed) parameter names that were loaded.
        """
        stacked_params_mapping = [
            # (param_name, shard_name, shard_id)
            ("qkv_proj", "q_proj", "q"),
            ("qkv_proj", "k_proj", "k"),
            ("qkv_proj", "v_proj", "v"),
        ]
        params_dict = dict(self.named_parameters(remove_duplicate=False))
        loaded_params: set[str] = set()

        for name, loaded_weight in weights:
            for (param_name, weight_name, shard_id) in stacked_params_mapping:
                if weight_name not in name:
                    continue
                name = name.replace(weight_name, param_name)

                param = params_dict[name]
                weight_loader = param.weight_loader
                weight_loader(param, loaded_weight, shard_id)
                break
            else:
                param = params_dict[name]
                weight_loader = getattr(param, "weight_loader",
                                        default_weight_loader)
                weight_loader(param, loaded_weight)
            loaded_params.add(name)
        return loaded_params


def _qwen2vl_field_config(hf_inputs: Mapping[str, torch.Tensor]):
    """Describe how Qwen2-VL processor outputs split into per-item fields.

    Pixel and embedding fields are split with ``flat_from_sizes``. Each
    item's size is the product of its ``(t, h, w)`` grid row. The
    ``*_grid_thw`` fields use ``batched``. A missing grid key behaves like
    an empty ``(0, 3)`` grid, i.e. no items of that modality.
    """
    no_items = torch.empty((0, 3))
    image_sizes = hf_inputs.get("image_grid_thw", no_items).prod(-1)
    video_sizes = hf_inputs.get("video_grid_thw", no_items).prod(-1)

    return {
        "pixel_values":
        MultiModalFieldConfig.flat_from_sizes("image", image_sizes),
        "image_embeds":
        MultiModalFieldConfig.flat_from_sizes("image", image_sizes),
        "image_grid_thw":
        MultiModalFieldConfig.batched("image"),
        "pixel_values_videos":
        MultiModalFieldConfig.flat_from_sizes("video", video_sizes),
        "video_embeds":
        MultiModalFieldConfig.flat_from_sizes("video", video_sizes),
        "video_grid_thw":
        MultiModalFieldConfig.batched("video"),
    }


class Qwen2VLMultiModalDataParser(MultiModalDataParser):
    """Data parser that also accepts precomputed embeddings as dicts.

    A dict passed as image (or video) data is wrapped in
    ``DictEmbeddingItems``. The dict must contain the ``*_embeds`` and
    ``*_grid_thw`` keys, and its fields are laid out by
    ``_qwen2vl_field_config``. Any other data is handled by the base parser.
    """

    def _parse_image_data(
        self,
        data: Union[dict[str, torch.Tensor], ModalityData[ImageItem]],
    ) -> Optional[ModalityDataItems[Any, Any]]:
        if isinstance(data, dict):
            return DictEmbeddingItems(
                data,
                modality="image",
                required_fields={"image_embeds", "image_grid_thw"},
                fields_factory=_qwen2vl_field_config,
            )

        return super()._parse_image_data(data)

    def _parse_video_data(
        self,
        data: Union[dict[str, torch.Tensor], ModalityData[VideoItem]],
    ) -> Optional[ModalityDataItems[Any, Any]]:
        if isinstance(data, dict):
            return DictEmbeddingItems(
                data,
                modality="video",
                required_fields={"video_embeds", "video_grid_thw"},
                fields_factory=_qwen2vl_field_config,
            )

        return super()._parse_video_data(data)


class Qwen2VLProcessingInfo(BaseProcessingInfo):
    """Processor access and worst-case token budgeting for Qwen2-VL."""

    def get_hf_config(self):
        return self.ctx.get_hf_config(Qwen2VLConfig)

    def get_hf_processor(self, **kwargs: object) -> Qwen2VLProcessor:
        # Use the fast processor unless the caller explicitly overrides it.
        return self.ctx.get_hf_processor(
            Qwen2VLProcessor,
            use_fast=kwargs.pop("use_fast", True),
            **kwargs,
        )

    def get_image_processor(self, **kwargs: object) -> Qwen2VLImageProcessor:
        return self.get_hf_processor(**kwargs).image_processor

    def get_supported_mm_limits(self) -> Mapping[str, Optional[int]]:
        # No explicit per-prompt item cap is declared for either modality.
        return {"image": None, "video": None}

    def get_mm_max_tokens_per_item(
        self,
        seq_len: int,
        mm_counts: Mapping[str, int],
    ) -> Mapping[str, int]:
        """Return the largest placeholder-token count of one image/video."""
        max_image_tokens = self.get_max_image_tokens()
        max_video_tokens = self.get_max_video_tokens(seq_len, mm_counts)
        return {"image": max_image_tokens, "video": max_video_tokens}

    def _get_vision_info(
        self,
        *,
        image_width: int,
        image_height: int,
        num_frames: int = 1,
        do_resize: bool = True,
        image_processor: Optional[Qwen2VLImageProcessor],
    ) -> tuple[ImageSize, int]:
        """Compute the preprocessed size and vision-token count of an input.

        Args:
            image_width: Input width in pixels.
            image_height: Input height in pixels.
            num_frames: Number of frames (1 for a still image).
            do_resize: Whether to apply ``smart_resize`` using the image
                processor's ``min_pixels``/``max_pixels``.
            image_processor: Source of those pixel bounds; the default image
                processor is fetched when ``None``.

        Returns:
            ``(preprocessed_size, num_vision_tokens)``.
        """
        if image_processor is None:
            image_processor = self.get_image_processor()

        hf_config = self.get_hf_config()
        vision_config = hf_config.vision_config
        patch_size = vision_config.patch_size
        merge_size = vision_config.spatial_merge_size
        temporal_patch_size = vision_config.temporal_patch_size

        if do_resize:
            resized_height, resized_width = smart_resize(
                height=image_height,
                width=image_width,
                factor=patch_size * merge_size,
                min_pixels=image_processor.min_pixels,
                max_pixels=image_processor.max_pixels,
            )
            preprocessed_size = ImageSize(width=resized_width,
                                          height=resized_height)
        else:
            preprocessed_size = ImageSize(width=image_width,
                                          height=image_height)

        # NOTE: Frames are padded to be divisible by `temporal_patch_size`
        # https://github.com/huggingface/transformers/blob/v4.48.3/src/transformers/models/qwen2_vl/image_processing_qwen2_vl.py#L294
        # `(-n) % k` is the smallest pad making `n + pad` a multiple of `k`;
        # `n % k` would only coincide with it when `k == 2`.
        padded_num_frames = num_frames + (-num_frames) % temporal_patch_size

        grid_t = max(padded_num_frames // temporal_patch_size, 1)
        grid_h = preprocessed_size.height // patch_size
        grid_w = preprocessed_size.width // patch_size

        num_patches = grid_t * grid_h * grid_w
        # Every `merge_size**2` patches are merged into one vision token.
        num_vision_tokens = num_patches // (merge_size**2)

        return preprocessed_size, num_vision_tokens

    def get_num_image_tokens(
        self,
        *,
        image_width: int,
        image_height: int,
        image_processor: Optional[Qwen2VLImageProcessor],
    ) -> int:
        """Return the number of vision tokens for a single image."""
        _, num_image_tokens = self._get_vision_info(
            image_width=image_width,
            image_height=image_height,
            image_processor=image_processor,
        )
        return num_image_tokens

    def get_num_video_tokens(
        self,
        *,
        image_width: int,
        image_height: int,
        num_frames: int,
        image_processor: Optional[Qwen2VLImageProcessor],
    ) -> int:
        """Return the number of vision tokens for a video of ``num_frames``."""
        _, num_video_tokens = self._get_vision_info(
            image_width=image_width,
            image_height=image_height,
            num_frames=num_frames,
            image_processor=image_processor,
        )
        return num_video_tokens

    def get_image_size_with_most_features(self) -> ImageSize:
        """Return the preprocessed size of an arbitrarily large image.

        A huge input is passed so that ``smart_resize`` clamps it to the
        processor's maximum pixel budget.
        """
        max_image_size, _ = self._get_vision_info(
            image_width=9999999,
            image_height=9999999,
            image_processor=None,
        )
        return max_image_size

    def get_max_image_tokens(self) -> int:
        """Return the vision-token count of the largest possible image."""
        target_width, target_height = self.get_image_size_with_most_features()

        return self.get_num_image_tokens(
            image_width=target_width,
            image_height=target_height,
            image_processor=None,
        )

    def _get_max_video_frames(self, max_tokens: int) -> int:
        """Return the most frames of a max-size video fitting ``max_tokens``.

        Returns 0 when even a single frame exceeds the budget.
        """
        target_width, target_height = self.get_image_size_with_most_features()

        num_frames = 0

        while True:
            next_num_frames = num_frames + 1
            next_max_tokens = self.get_num_video_tokens(
                image_width=target_width,
                image_height=target_height,
                num_frames=next_num_frames,
                image_processor=None,
            )

            if next_max_tokens > max_tokens:
                break

            num_frames = next_num_frames

        return num_frames

    def get_num_frames_with_most_features(
        self,
        seq_len: int,
        mm_counts: Mapping[str, int],
    ) -> int:
        """Return the per-video frame count used for worst-case profiling.

        The token budget left after the maximum image tokens is shared
        evenly across videos, capped at ``_MAX_FRAMES_PER_VIDEO`` and never
        below one frame.
        """
        max_images = mm_counts.get("image", 0)
        max_videos = mm_counts.get("video", 0)

        max_image_tokens = self.get_max_image_tokens() * max_images
        max_total_frames = self._get_max_video_frames(seq_len -
                                                      max_image_tokens)
        max_frames_per_video = min(max_total_frames // max(max_videos, 1),
                                   _MAX_FRAMES_PER_VIDEO)

        return max(max_frames_per_video, 1)

    def get_max_video_tokens(
        self,
        seq_len: int,
        mm_counts: Mapping[str, int],
    ) -> int:
        """Return the vision-token count of the largest profiling video."""
        target_width, target_height = self.get_image_size_with_most_features()

        return self.get_num_video_tokens(
            image_width=target_width,
            image_height=target_height,
            num_frames=self.get_num_frames_with_most_features(
                seq_len, mm_counts),
            image_processor=None,
        )
class Qwen2VLDummyInputsBuilder(BaseDummyInputsBuilder[Qwen2VLProcessingInfo]):
    """Builds dummy prompts and media of maximal size for Qwen2-VL."""

    def get_dummy_text(self, mm_counts: Mapping[str, int]) -> str:
        """Return one image/video placeholder token per requested item."""
        image_count = mm_counts.get("image", 0)
        video_count = mm_counts.get("video", 0)

        processor = self.info.get_hf_processor()
        segments = (
            (processor.image_token, image_count),
            (processor.video_token, video_count),
        )
        return "".join(token * count for token, count in segments)

    def get_dummy_mm_data(
        self,
        seq_len: int,
        mm_counts: Mapping[str, int],
    ) -> MultiModalDataDict:
        """Return dummy images and videos at the worst-case size/length."""
        image_count = mm_counts.get("image", 0)
        video_count = mm_counts.get("video", 0)

        width, height = self.info.get_image_size_with_most_features()
        frame_count = self.info.get_num_frames_with_most_features(
            seq_len, mm_counts)

        dummy_images = self._get_dummy_images(width=width,
                                              height=height,
                                              num_images=image_count)
        dummy_videos = self._get_dummy_videos(
            width=width,
            height=height,
            num_frames=frame_count,
            num_videos=video_count,
        )
        return {"image": dummy_images, "video": dummy_videos}
class Qwen2VLMultiModalProcessor(
        BaseMultiModalProcessor[Qwen2VLProcessingInfo]):
    """Expands each Qwen2-VL media placeholder into its vision tokens."""

    def _get_data_parser(self) -> MultiModalDataParser:
        return Qwen2VLMultiModalDataParser()

    def _get_prompt_updates(
        self,
        mm_items: MultiModalDataItems,
        hf_processor_mm_kwargs: Mapping[str, Any],
        out_mm_kwargs: MultiModalKwargs,
    ) -> Sequence[PromptUpdate]:
        hf_processor = self.info.get_hf_processor(**hf_processor_mm_kwargs)
        image_processor = self.info.get_image_processor(
            **hf_processor_mm_kwargs)
        vocab = self.info.get_tokenizer().get_vocab()

        token_id_of = {
            "image": vocab[hf_processor.image_token],
            "video": vocab[hf_processor.video_token],
        }
        # Every `merge_size**2` grid cells collapse into one placeholder.
        cells_per_token = image_processor.merge_size**2

        def _expand(item_idx: int, modality: str):
            thw = out_mm_kwargs[f"{modality}_grid_thw"][item_idx]
            assert isinstance(thw, torch.Tensor)

            token_count = int(thw.prod()) // cells_per_token
            return [token_id_of[modality]] * token_count

        replacements = []
        for modality in ("image", "video"):
            replacements.append(
                PromptReplacement(
                    modality=modality,
                    target=[token_id_of[modality]],
                    replacement=partial(_expand, modality=modality),
                ))
        return replacements

    def _get_mm_fields_config(
        self,
        hf_inputs: BatchFeature,
        hf_processor_mm_kwargs: Mapping[str, object],
    ) -> Mapping[str, MultiModalFieldConfig]:
        return _qwen2vl_field_config(hf_inputs)
@MULTIMODAL_REGISTRY.register_processor(Qwen2VLMultiModalProcessor,
                                        info=Qwen2VLProcessingInfo,
                                        dummy_inputs=Qwen2VLDummyInputsBuilder)
class Qwen2VLForConditionalGeneration(nn.Module, SupportsMultiModal,
                                      SupportsLoRA, SupportsPP):
    """Qwen2-VL: a vision transformer feeding a Qwen2 causal language model."""

    # To ensure correct weight loading and mapping.
    hf_to_vllm_mapper = WeightsMapper(
        orig_to_new_prefix={
            # mapping for new names in checkpoint saved after transformers v4.52
            "model.language_model.": "language_model.model.",
            "model.visual.": "visual.",
            # mapping for original checkpoint
            "lm_head.": "language_model.lm_head.",
            "model.": "language_model.model.",
        })

    @classmethod
    def get_placeholder_str(cls, modality: str, i: int) -> Optional[str]:
        """Return the prompt placeholder for one image or video item.

        Raises:
            ValueError: If ``modality`` is neither image nor video.
        """
        if modality.startswith("image"):
            return "<|vision_start|><|image_pad|><|vision_end|>"
        if modality.startswith("video"):
            return "<|vision_start|><|video_pad|><|vision_end|>"

        raise ValueError("Only image or video modality is supported")

    def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
        super().__init__()
        config: Qwen2VLConfig = vllm_config.model_config.hf_config
        quant_config = vllm_config.quant_config
        multimodal_config = vllm_config.model_config.multimodal_config

        self.config = config
        self.multimodal_config = multimodal_config

        self.visual = Qwen2VisionTransformer(
            config.vision_config,
            norm_eps=getattr(config, "rms_norm_eps", 1e-6),
            quant_config=self._maybe_ignore_quant_config(quant_config),
            prefix=maybe_prefix(prefix, "visual"),
        )

        self.language_model = init_vllm_registered_model(
            vllm_config=vllm_config,
            prefix=maybe_prefix(prefix, "language_model"),
            architectures=["Qwen2ForCausalLM"],
        )

        self.make_empty_intermediate_tensors = (
            self.language_model.make_empty_intermediate_tensors)

    def _maybe_ignore_quant_config(
        self, quant_config: Optional[QuantizationConfig]
    ) -> Optional[QuantizationConfig]:
        # GPTQ configs do not have a list of ignored modules, however AutoGPTQ
        # seems to avoid vision encoder sections for some models.
        # See: https://huggingface.co/Qwen/Qwen2-VL-2B-Instruct-GPTQ-Int4
        if isinstance(quant_config, (GPTQConfig, GPTQMarlinConfig)):
            return None
        return quant_config

    def _validate_and_reshape_mm_tensor(self, mm_input: object,
                                        name: str) -> torch.Tensor:
        """Coerce a multimodal input into a single concatenated tensor.

        A 2D tensor is returned as is. A 3D tensor has its leading-dimension
        slices concatenated. A list of tensors is concatenated.

        Raises:
            ValueError: On any other type, or a tensor that is not 2D/3D.
        """
        if not isinstance(mm_input, (torch.Tensor, list)):
            raise ValueError(f"Incorrect type of {name}. "
                             f"Got type: {type(mm_input)}")
        if isinstance(mm_input, torch.Tensor):
            if mm_input.ndim == 2:
                return mm_input
            if mm_input.ndim != 3:
                raise ValueError(f"{name} should be 2D or batched 3D tensor. "
                                 f"Got ndim: {mm_input.ndim} "
                                 f"(shape={mm_input.shape})")
            return torch.concat(list(mm_input))
        return torch.concat(mm_input)

    def _parse_and_validate_image_input(
            self, **kwargs: object) -> Optional[Qwen2VLImageInputs]:
        """Extract image inputs from ``kwargs``.

        Returns ``None`` when neither ``pixel_values`` nor ``image_embeds`` is
        given. Pixel values take precedence when both are present.

        Raises:
            ValueError: If a present field (including ``image_grid_thw``) has
                an unsupported type or rank.
        """
        pixel_values = kwargs.pop("pixel_values", None)
        image_embeds = kwargs.pop("image_embeds", None)
        image_grid_thw = kwargs.pop("image_grid_thw", None)

        if pixel_values is None and image_embeds is None:
            return None

        if pixel_values is not None:
            pixel_values = self._validate_and_reshape_mm_tensor(
                pixel_values, "image pixel values")
            image_grid_thw = self._validate_and_reshape_mm_tensor(
                image_grid_thw, "image grid_thw")
            return Qwen2VLImagePixelInputs(type="pixel_values",
                                           pixel_values=pixel_values,
                                           image_grid_thw=image_grid_thw)

        # Only precomputed embeddings remain at this point.
        image_embeds = self._validate_and_reshape_mm_tensor(
            image_embeds, "image embeds")
        image_grid_thw = self._validate_and_reshape_mm_tensor(
            image_grid_thw, "image grid_thw")
        return Qwen2VLImageEmbeddingInputs(type="image_embeds",
                                           image_embeds=image_embeds,
                                           image_grid_thw=image_grid_thw)

    def _parse_and_validate_video_input(
            self, **kwargs: object) -> Optional[Qwen2VLVideoInputs]:
        """Extract video inputs from ``kwargs``.

        Returns ``None`` when neither ``pixel_values_videos`` nor
        ``video_embeds`` is given. Pixel values take precedence when both
        are present.

        Raises:
            ValueError: If a present field (including ``video_grid_thw``) has
                an unsupported type or rank.
        """
        pixel_values_videos = kwargs.pop("pixel_values_videos", None)
        video_embeds = kwargs.pop("video_embeds", None)
        video_grid_thw = kwargs.pop("video_grid_thw", None)

        if pixel_values_videos is None and video_embeds is None:
            return None

        if pixel_values_videos is not None:
            pixel_values_videos = self._validate_and_reshape_mm_tensor(
                pixel_values_videos, "video pixel values")
            video_grid_thw = self._validate_and_reshape_mm_tensor(
                video_grid_thw, "video grid_thw")

            return Qwen2VLVideoPixelInputs(
                type="pixel_values_videos",
                pixel_values_videos=pixel_values_videos,
                video_grid_thw=video_grid_thw,
            )

        # Only precomputed embeddings remain at this point.
        video_embeds = self._validate_and_reshape_mm_tensor(
            video_embeds, "video embeds")
        video_grid_thw = self._validate_and_reshape_mm_tensor(
            video_grid_thw, "video grid_thw")
        return Qwen2VLVideoEmbeddingInputs(type="video_embeds",
                                           video_embeds=video_embeds,
                                           video_grid_thw=video_grid_thw)

    def _split_by_grid(self, embeds: torch.Tensor,
                       grid_thw: torch.Tensor) -> tuple[torch.Tensor, ...]:
        """Split concatenated embeddings into one chunk per media item.

        Item ``i`` owns ``prod(grid_thw[i]) // spatial_merge_size**2`` rows.
        """
        merge_size = self.visual.spatial_merge_size
        sizes = grid_thw.prod(-1) // merge_size // merge_size
        return embeds.split(sizes.tolist())

    def _process_image_input(
            self, image_input: Qwen2VLImageInputs) -> tuple[torch.Tensor, ...]:
        """Return one embedding tensor per image, running the ViT if needed."""
        grid_thw = image_input["image_grid_thw"]
        assert grid_thw.ndim == 2

        if image_input["type"] == "image_embeds":
            image_embeds = image_input["image_embeds"]
        else:
            pixel_values = image_input["pixel_values"]
            image_embeds = self.visual(pixel_values, grid_thw=grid_thw)

        # Split concatenated embeddings for each image item.
        return self._split_by_grid(image_embeds, grid_thw)

    def _process_video_input(
            self, video_input: Qwen2VLVideoInputs) -> tuple[torch.Tensor, ...]:
        """Return one embedding tensor per video, running the ViT if needed."""
        grid_thw = video_input["video_grid_thw"]
        assert grid_thw.ndim == 2

        if video_input["type"] == "video_embeds":
            video_embeds = video_input["video_embeds"]
        else:
            pixel_values_videos = video_input["pixel_values_videos"]
            video_embeds = self.visual(pixel_values_videos, grid_thw=grid_thw)

        # Split concatenated embeddings for each video item.
        return self._split_by_grid(video_embeds, grid_thw)

    def _parse_and_validate_multimodal_inputs(self, **kwargs: object) -> dict:
        """Parse image/video inputs, keyed in first-seen kwarg order."""
        modalities = {}

        # Preserve the order of modalities if there are multiple of them
        # from the order of kwargs.
        for input_key in kwargs:
            if input_key in ("pixel_values",
                             "image_embeds") and "images" not in modalities:
                modalities["images"] = self._parse_and_validate_image_input(
                    **kwargs)
            if input_key in ("pixel_values_videos",
                             "video_embeds") and "videos" not in modalities:
                modalities["videos"] = self._parse_and_validate_video_input(
                    **kwargs)

        return modalities

    def get_language_model(self) -> torch.nn.Module:
        return self.language_model

    def get_multimodal_embeddings(self,
                                  **kwargs: object) -> MultiModalEmbeddings:
        """Return per-item embeddings for all images/videos in ``kwargs``.

        Returns an empty list when no multimodal input is present.
        """
        modalities = self._parse_and_validate_multimodal_inputs(**kwargs)
        if not modalities:
            return []

        # The result multimodal_embeddings is tuple of tensors, with each
        # tensor corresponding to a multimodal data item (image or video).
        multimodal_embeddings: tuple[torch.Tensor, ...] = ()

        # NOTE: It is important to iterate over the keys in this dictionary
        # to preserve the order of the modalities.
        for modality in modalities:
            if modality == "images":
                image_input = modalities["images"]
                vision_embeddings = self._process_image_input(image_input)
                multimodal_embeddings += vision_embeddings
            if modality == "videos":
                video_input = modalities["videos"]
                video_embeddings = self._process_video_input(video_input)
                multimodal_embeddings += video_embeddings

        return multimodal_embeddings

    def get_input_embeddings(
        self,
        input_ids: torch.Tensor,
        multimodal_embeddings: Optional[MultiModalEmbeddings] = None,
    ) -> torch.Tensor:
        """Embed ``input_ids``, splicing in multimodal embeddings if given.

        Multimodal embeddings replace positions holding the image or video
        placeholder token ids from the config.
        """
        inputs_embeds = self.language_model.get_input_embeddings(input_ids)
        if multimodal_embeddings is not None \
            and len(multimodal_embeddings) != 0:
            inputs_embeds = merge_multimodal_embeddings(
                input_ids, inputs_embeds, multimodal_embeddings,
                [self.config.image_token_id, self.config.video_token_id])
        return inputs_embeds

    def get_input_embeddings_v0(
        self,
        input_ids: torch.Tensor,
        image_input: Optional[Qwen2VLImageInputs] = None,
        video_input: Optional[Qwen2VLVideoInputs] = None,
    ) -> torch.Tensor:
        """V0 path: embed text, then merge image and video embeddings.

        Both pixel and precomputed-embedding inputs are accepted, since
        ``_process_*_input`` handles either kind.
        """
        inputs_embeds = self.get_input_embeddings(input_ids)
        if image_input is not None:
            image_embeds = self._process_image_input(image_input)
            inputs_embeds = merge_multimodal_embeddings(
                input_ids,
                inputs_embeds,
                image_embeds,
                placeholder_token_id=self.config.image_token_id,
            )

        if video_input is not None:
            video_embeds = self._process_video_input(video_input)
            inputs_embeds = merge_multimodal_embeddings(
                input_ids,
                inputs_embeds,
                video_embeds,
                placeholder_token_id=self.config.video_token_id,
            )
        return inputs_embeds

    def forward(
        self,
        input_ids: torch.Tensor,
        positions: torch.Tensor,
        intermediate_tensors: Optional[IntermediateTensors] = None,
        inputs_embeds: Optional[torch.Tensor] = None,
        **kwargs: object,
    ) -> Union[torch.Tensor, IntermediateTensors]:
        """Run forward pass for Qwen2-VL.

        Args:
            input_ids: Flattened (concatenated) input_ids corresponding to a
                batch.
            positions: Flattened (concatenated) position ids corresponding to a
                batch.
                **NOTE**: If mrope is enabled (default setting for Qwen2-VL
                opensource models), the shape will be `(3, seq_len)`,
                otherwise it will be `(seq_len,)`.
            pixel_values: Pixel values to be fed to a model.
                `None` if no images are passed.
            image_embeds: Precomputed image embeddings, an alternative to
                `pixel_values`.
            image_grid_thw: Tensor `(n_images, 3)` of image 3D grid in LLM.
                `None` if no images are passed.
            pixel_values_videos: Pixel values of videos to be fed to a model.
                `None` if no videos are passed.
            video_embeds: Precomputed video embeddings, an alternative to
                `pixel_values_videos`.
            video_grid_thw: Tensor `(n_videos, 3)` of video 3D grid in LLM.
                `None` if no videos are passed.
        """

        if intermediate_tensors is not None:
            inputs_embeds = None

        # NOTE: In v1, inputs_embeds is always generated at model runner from
        # `get_multimodal_embeddings` and `get_input_embeddings`, this
        # condition is only for v0 compatibility.
        elif inputs_embeds is None:
            image_input = self._parse_and_validate_image_input(**kwargs)
            video_input = self._parse_and_validate_video_input(**kwargs)

            if image_input is None and video_input is None:
                inputs_embeds = None
            else:
                if uses_mrope(self.config):
                    assert positions.ndim == 2 and positions.size(0) == 3, (
                        "multimodal section rotary embedding requires "
                        f"(3, seq_len) positions, but got {positions.size()}")
                inputs_embeds = self.get_input_embeddings_v0(
                    input_ids,
                    image_input=image_input,
                    video_input=video_input)
                input_ids = None

        hidden_states = self.language_model.model(
            input_ids=input_ids,
            positions=positions,
            intermediate_tensors=intermediate_tensors,
            inputs_embeds=inputs_embeds,
        )
        return hidden_states

    def compute_logits(
        self,
        hidden_states: torch.Tensor,
        sampling_metadata: SamplingMetadata,
    ) -> Optional[torch.Tensor]:
        return self.language_model.compute_logits(hidden_states,
                                                  sampling_metadata)

    def load_weights(self, weights: Iterable[tuple[str,
                                                   torch.Tensor]]) -> set[str]:
        """Load checkpoint weights, renaming them via ``hf_to_vllm_mapper``."""
        loader = AutoWeightsLoader(self)
        return loader.load_weights(weights, mapper=self.hf_to_vllm_mapper)

    def get_mm_mapping(self) -> MultiModelKeys:
        """
        Get the module prefix in multimodal models
        """
        return MultiModelKeys.from_string_field(
            language_model="language_model",
            connector="visual.merger.",
            tower_model="visual.",
        )


class Tarsier2MultiModalProcessor(Qwen2VLMultiModalProcessor):
    """Tarsier2 reuses the Qwen2-VL multimodal processing unchanged."""


class Tarsier2ImageProcessor(Qwen2VLImageProcessor):
    """Qwen2-VL image processor that accepts Tarsier2's ``size`` layout.

    When ``size`` carries both ``min_pixels`` and ``max_pixels``, they are
    renamed to ``shortest_edge`` / ``longest_edge`` before being passed to the
    base class. Any other ``size`` (including ``None``) is forwarded as is.
    """

    def __init__(
        self,
        size: Optional[dict[str, int]] = None,
        **kwargs,
    ) -> None:
        uses_tarsier2_keys = (size is not None and "min_pixels" in size
                              and "max_pixels" in size)
        if uses_tarsier2_keys:
            size = {
                "shortest_edge": size["min_pixels"],
                "longest_edge": size["max_pixels"],
            }
        super().__init__(size=size, **kwargs)


class Tarsier2Processor(Qwen2VLProcessor):
    """Qwen2-VL processor wired to a ``Tarsier2ImageProcessor``.

    The image processor is built from ``vision_config``; a default
    ``Qwen2VLVideoProcessor`` is used and no chat template is set.
    """

    def __init__(
        self,
        vision_config: dict,
        tokenizer: AnyTokenizer,
        **kwargs,
    ):
        tarsier2_image_processor = Tarsier2ImageProcessor(**vision_config)
        self.image_processor = tarsier2_image_processor
        super().__init__(
            image_processor=tarsier2_image_processor,
            tokenizer=tokenizer,
            video_processor=Qwen2VLVideoProcessor(),
            chat_template=None,
            **kwargs,
        )


class Tarsier2ProcessingInfo(Qwen2VLProcessingInfo):
    """Processing info for Tarsier2 checkpoints (Qwen2-VL under a LLaVA config)."""

    def get_hf_config(self) -> Qwen2VLConfig:
        """Re-read the checkpoint config and rebuild it as a Qwen2VLConfig.

        The config is loaded with ``AutoConfig`` from the model path and
        round-tripped through its dict form.
        """
        model_path = self.ctx.model_config.model
        original_config = AutoConfig.from_pretrained(model_path)
        config_dict = original_config.to_dict()
        correct_config = Qwen2VLConfig.from_dict(config_dict)

        return correct_config

    def get_hf_processor(self, **kwargs: object) -> Tarsier2Processor:
        return Tarsier2Processor(
            vision_config=self.ctx.get_hf_image_processor_config(),
            tokenizer=self.get_tokenizer(),
            **kwargs,
        )

    def get_image_processor(self, **kwargs: object) -> Tarsier2ImageProcessor:
        """Build the Tarsier2 image processor from the checkpoint config.

        ``kwargs`` is accepted so this override matches the parent signature:
        ``Qwen2VLMultiModalProcessor._get_prompt_updates`` forwards the mm
        processor kwargs here. Without it, any non-empty kwargs raised
        ``TypeError``. The image processor is configured only from the
        checkpoint's image-processor config, so the kwargs are unused.
        """
        return Tarsier2ImageProcessor(
            **self.ctx.get_hf_image_processor_config())


@MULTIMODAL_REGISTRY.register_processor(Tarsier2MultiModalProcessor,
                                        info=Tarsier2ProcessingInfo,
                                        dummy_inputs=Qwen2VLDummyInputsBuilder)
class Tarsier2ForConditionalGeneration(Qwen2VLForConditionalGeneration):
    """Qwen2-VL model whose checkpoint ships under a LLaVA-style config."""

    # Tarsier2 checkpoints prefix the vision encoder with `vision_tower.`.
    hf_to_vllm_mapper = WeightsMapper(orig_to_new_prefix={
        "vision_tower.": "visual.",
    })

    def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
        # Tarsier2 uses llava as model_type, which will create a Qwen2VLConfig
        # as text_config, we need to reconstruct Qwen2VLConfig from LlavaConfig.
        # NOTE: this rewrites `vllm_config.model_config.hf_config` in place.
        model_config = vllm_config.model_config
        llava_config = model_config.hf_config
        inner_config = llava_config.text_config
        inner_config.architectures = llava_config.architectures
        model_config.hf_config = inner_config
        super().__init__(vllm_config=vllm_config, prefix=prefix)

    def load_weights(self, weights: Iterable[tuple[str,
                                                   torch.Tensor]]) -> set[str]:
        """Load weights, renaming them via this class's mapper."""
        return AutoWeightsLoader(self).load_weights(
            weights, mapper=self.hf_to_vllm_mapper)