qwen2_vl.py 46.6 KB
Newer Older
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
# Adapted from
# https://github.com/huggingface/transformers/blob/19e6e80e10118f855137b90740936c0b11ac397f/src/transformers/models/qwen2_vl/modeling_qwen2_vl.py
# Copyright 2024 The Qwen team.
# Copyright 2023 The vLLM team.
# Copyright 2022 EleutherAI and the HuggingFace Inc. team. All rights reserved.
#
# This code is based on EleutherAI's GPT-NeoX library and the GPT-NeoX
# and OPT implementations in this library. It has been modified from its
# original forms to accommodate minor architectural differences compared
# to GPT-NeoX and OPT used by the Meta AI team that trained the model.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Inference-only Qwen2-VL model compatible with HuggingFace weights."""
24
from functools import cached_property, partial
25
26
from typing import (Any, Callable, Iterable, List, Literal, Mapping, Optional,
                    Set, Tuple, Type, TypedDict, Union)
27
28
29
30
31

import torch
import torch.nn as nn
import torch.nn.functional as F
from einops import rearrange, repeat
32
33
34
from transformers import BatchFeature
from transformers.models.qwen2_vl import (Qwen2VLImageProcessor,
                                          Qwen2VLProcessor)
35
36
from transformers.models.qwen2_vl.configuration_qwen2_vl import (
    Qwen2VLConfig, Qwen2VLVisionConfig)
37
from transformers.models.qwen2_vl.image_processing_qwen2_vl import smart_resize
38
39

from vllm.attention import AttentionMetadata
40
from vllm.config import VllmConfig
41
from vllm.distributed import parallel_state, tensor_model_parallel_all_gather
42
43
44
45
46
47
from vllm.distributed import utils as dist_utils
from vllm.logger import init_logger
from vllm.model_executor import SamplingMetadata
from vllm.model_executor.layers.activation import QuickGELU
from vllm.model_executor.layers.linear import (ColumnParallelLinear,
                                               RowParallelLinear)
48
49
50
51
from vllm.model_executor.layers.quantization import QuantizationConfig
from vllm.model_executor.layers.quantization.gptq import GPTQConfig
from vllm.model_executor.layers.quantization.gptq_marlin import (
    GPTQMarlinConfig)
Joe Runde's avatar
Joe Runde committed
52
from vllm.model_executor.layers.sampler import SamplerOutput, get_sampler
53
from vllm.model_executor.model_loader.weight_utils import default_weight_loader
54
from vllm.model_executor.models.module_mapping import MultiModelKeys
55
from vllm.multimodal import MULTIMODAL_REGISTRY
56
from vllm.multimodal.inputs import (ImageItem, ModalityData,
57
                                    MultiModalFieldConfig, MultiModalKwargs,
58
59
                                    NestedTensors, VideoItem)
from vllm.multimodal.parse import ModalityDataItems, MultiModalDataParser
60
from vllm.multimodal.processing import (BaseMultiModalProcessor,
61
62
                                        MultiModalDataItems, ProcessorInputs,
                                        PromptReplacement)
63
from vllm.platforms import _Backend
64
from vllm.sequence import IntermediateTensors
65
from vllm.transformers_utils.config import uses_mrope
66

67
from .interfaces import SupportsLoRA, SupportsMultiModal, SupportsPP
68
69
from .utils import (AutoWeightsLoader, WeightsMapper, get_vit_attn_backend,
                    init_vllm_registered_model, maybe_prefix)
70

71
72
73
74
75
# Module-level logger created through vLLM's ``init_logger``.
logger = init_logger(__name__)

# === Vision Inputs === #


76
77
class Qwen2VLImagePixelInputs(TypedDict):
    """Image input given as flattened pixel patches."""

    type: Literal["pixel_values"]

    pixel_values: torch.Tensor
    """Shape:
    `(num_patches,
      num_channels * temporal_patch_size * patch_size * patch_size)`

    Images use the same per-row layout as videos: the vision patch
    embedding views every row as
    `(-1, temporal_patch_size, patch_size, patch_size)`, so each row must
    include the temporal dimension.
    """

    image_grid_thw: torch.Tensor
    """Shape: `(num_images, 3)`
    This should be in `(grid_t, grid_h, grid_w)` format.
    """


89
90
class Qwen2VLImageEmbeddingInputs(TypedDict):
    """Image input given as precomputed vision embeddings."""

    type: Literal["image_embeds"]

    image_embeds: torch.Tensor
    """Image features, supplied either as a list holding one tensor per
    image, or as a single tensor that concatenates every image's features.

    Tensor shape: `(num_image_features, hidden_size)`, where
    `num_image_features` depends on the number of images and on their
    resolution, and `hidden_size` must equal the hidden size of the
    language model backbone.
    """

    image_grid_thw: torch.Tensor
    """Shape `(num_images, 3)`: one `(grid_t, grid_h, grid_w)` row per
    image.
    """


# Union of the two image-input TypedDicts defined above: flattened pixel
# patches or precomputed embeddings.
Qwen2VLImageInputs = Union[Qwen2VLImagePixelInputs,
                           Qwen2VLImageEmbeddingInputs]


114
115
class Qwen2VLVideoPixelInputs(TypedDict):
    """Video input given as flattened pixel patches."""

    type: Literal["pixel_values_videos"]

    pixel_values_videos: torch.Tensor
    """Flattened patches of all videos, shaped
    `(num_patches,
      num_channels * temporal_patch_size * patch_size * patch_size)`.
    """

    video_grid_thw: torch.Tensor
    """Shape `(num_videos, 3)`: one `(grid_t, grid_h, grid_w)` row per
    video.
    """


129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
class Qwen2VLVideoEmbeddingInputs(TypedDict):
    """Video input given as precomputed vision embeddings."""

    type: Literal["video_embeds"]
    video_embeds: torch.Tensor
    """Supported types:
    - List[`torch.Tensor`]: A list of tensors holding all videos' features.
        Each tensor holds a video's features.
    - `torch.Tensor`: A tensor holding all videos' features
        (concatenation of all videos' feature tensors).

    Tensor shape: `(num_video_features, hidden_size)`
    - `num_video_features` varies based on
        the number and resolution of the videos.
    - `hidden_size` must match the hidden size of language model backbone.
    """

    video_grid_thw: torch.Tensor
    """Shape: `(num_videos, 3)`
    This should be in `(grid_t, grid_h, grid_w)` format.
    """


# Union of the two video-input TypedDicts defined above: flattened pixel
# patches or precomputed embeddings.
Qwen2VLVideoInputs = Union[Qwen2VLVideoPixelInputs,
                           Qwen2VLVideoEmbeddingInputs]

153
154
155
156
157
158
159
160
# === Vision Encoder === #


class Qwen2VisionMLP(nn.Module):
    """Feed-forward sub-layer of a vision block: fc1 -> activation -> fc2.

    ``fc1`` is a ``ColumnParallelLinear`` and ``fc2`` a
    ``RowParallelLinear``. The intermediate activation is presumably
    sharded across tensor-parallel ranks between the two.
    """

    def __init__(
        self,
        in_features: int,
        hidden_features: int,
        act_layer: Type[nn.Module] = QuickGELU,
        quant_config: Optional[QuantizationConfig] = None,
        prefix: str = "",
    ):
        super().__init__()
        # Expand: in_features -> hidden_features.
        self.fc1 = ColumnParallelLinear(input_size=in_features,
                                        output_size=hidden_features,
                                        quant_config=quant_config,
                                        prefix=f"{prefix}.fc1")
        self.act = act_layer()
        # Project back: hidden_features -> in_features.
        self.fc2 = RowParallelLinear(input_size=hidden_features,
                                     output_size=in_features,
                                     quant_config=quant_config,
                                     prefix=f"{prefix}.fc2")

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Apply fc1, the activation, then fc2; output has x's last dim."""
        # vLLM parallel linears return (output, bias); only output is used.
        hidden, _ = self.fc1(x)
        out, _ = self.fc2(self.act(hidden))
        return out


def rotate_half(x: torch.Tensor, interleaved: bool = False) -> torch.Tensor:
    """Rotate channel pairs of ``x`` by 90 degrees: ``(a, b) -> (-b, a)``.

    Args:
        x: Tensor whose last dimension holds the channels.
        interleaved: If True, pairs are adjacent channels ``(0, 1), (2, 3),
            ...``. Otherwise channel ``i`` in the first half pairs with
            channel ``i`` in the second half.

    Returns:
        Tensor with the same shape as ``x``.
    """
    if interleaved:
        evens, odds = x[..., ::2], x[..., 1::2]
        # Stack each rotated pair on a new trailing axis, then merge it back
        # so the pair members end up adjacent again.
        return torch.stack((-odds, evens), dim=-1).flatten(-2)
    front, back = x.chunk(2, dim=-1)
    return torch.cat((-back, front), dim=-1)


def apply_rotary_emb_torch(x: torch.Tensor,
                           cos: torch.Tensor,
                           sin: torch.Tensor,
                           interleaved: bool = False) -> torch.Tensor:
    """Apply rotary embedding to the leading ``2 * cos.shape[-1]`` channels.

    Args:
        x: ``(batch_size, seqlen, nheads, headdim)``.
        cos, sin: ``(seqlen, rotary_dim / 2)`` or
            ``(batch_size, seqlen, rotary_dim / 2)``.
        interleaved: Pair adjacent channels instead of the two halves.

    Returns:
        Tensor shaped like ``x``. Channels beyond the rotary dimension are
        copied through unchanged.
    """
    rot_dim = 2 * cos.shape[-1]
    assert rot_dim <= x.shape[-1]

    def _widen(table: torch.Tensor) -> torch.Tensor:
        # Give both members of every channel pair their angle, then add a
        # size-1 axis so the table broadcasts over the heads dimension.
        if interleaved:
            doubled = table.repeat_interleave(2, dim=-1)
        else:
            doubled = torch.cat((table, table), dim=-1)
        return doubled.unsqueeze(-2)

    cos_full = _widen(cos)
    sin_full = _widen(sin)
    head = x[..., :rot_dim]
    rotated = head * cos_full + rotate_half(head, interleaved) * sin_full
    return torch.cat([rotated, x[..., rot_dim:]], dim=-1)


def apply_rotary_pos_emb_vision(t: torch.Tensor,
                                freqs: torch.Tensor) -> torch.Tensor:
    """Rotate ``t`` by the angles in ``freqs``.

    The rotation is computed in float32 and the result is cast back to
    ``t``'s dtype.
    """
    rotated = apply_rotary_emb_torch(t.float(), freqs.cos(), freqs.sin())
    return rotated.type_as(t)


class Qwen2VisionAttention(nn.Module):
    """Self-attention layer of the Qwen2-VL vision tower.

    Projects tokens to fused q/k/v, applies optional rotary embeddings, and
    runs attention restricted to the segments delimited by ``cu_seqlens``.
    Heads are split across tensor-parallel ranks.
    """

    def __init__(
        self,
        embed_dim: int,
        num_heads: int,
        projection_size: int,
        quant_config: Optional[QuantizationConfig] = None,
        prefix: str = "",
    ) -> None:
        super().__init__()
        tp_world = parallel_state.get_tensor_model_parallel_world_size()
        self.tp_size = tp_world
        self.tp_rank = parallel_state.get_tensor_model_parallel_rank()
        # Width of one head, and the number of heads this rank owns.
        self.hidden_size_per_attention_head = dist_utils.divide(
            projection_size, num_heads)
        self.num_attention_heads_per_partition = dist_utils.divide(
            num_heads, tp_world)

        self.qkv = ColumnParallelLinear(input_size=embed_dim,
                                        output_size=3 * projection_size,
                                        quant_config=quant_config,
                                        prefix=f"{prefix}.qkv")
        self.proj = RowParallelLinear(input_size=projection_size,
                                      output_size=embed_dim,
                                      quant_config=quant_config,
                                      prefix=f"{prefix}.proj")

        # forward() implements exactly these three attention kernels.
        self.attn_backend: _Backend = get_vit_attn_backend(support_fa=True)
        implemented = {
            _Backend.FLASH_ATTN, _Backend.TORCH_SDPA, _Backend.XFORMERS
        }
        if self.attn_backend not in implemented:
            raise RuntimeError(
                f"Qwen2-VL does not support {self.attn_backend} backend now.")

    def split_qkv(self, qkv: torch.Tensor) -> tuple[torch.Tensor, ...]:
        """Split fused ``[s, b, 3 * head * head_dim]`` activations into
        ``q, k, v`` of shape ``[s, b, head, head_dim]`` for this rank.

        With tensor parallelism, the rank-local activations are first
        all-gathered and chunked into q/k/v. Each of q, k and v is then split
        again along the last dim, and this rank keeps its own partition.
        """
        seq_len, bs, _ = qkv.shape
        sharded = self.tp_size > 1
        if sharded:
            qkv = tensor_model_parallel_all_gather(qkv)

        q, k, v = qkv.chunk(3, dim=2)

        if sharded:
            take_local = partial(dist_utils.split_tensor_along_last_dim,
                                 num_partitions=self.tp_size)
            q, k, v = (take_local(part)[self.tp_rank] for part in (q, k, v))

        per_head = (seq_len, bs, self.num_attention_heads_per_partition,
                    self.hidden_size_per_attention_head)
        return tuple(part.view(*per_head) for part in (q, k, v))

    def forward(
        self,
        x: torch.Tensor,
        cu_seqlens: torch.Tensor,
        rotary_pos_emb: torch.Tensor,
    ) -> torch.Tensor:
        """Attend within the segments delimited by ``cu_seqlens``.

        Args:
            x: ``[s, b, c]`` input.
            cu_seqlens: Cumulative segment boundaries, starting at 0.
            rotary_pos_emb: Rotary angles applied to q and k; skipped when
                ``None``.

        Returns:
            ``[s, b, c]`` output of the output projection.
        """
        # [s, b, c] -> [s, b, 3 * head * head_dim]
        fused, _ = self.qkv(x)
        # -> 3 x [s, b, head, head_dim]
        q, k, v = self.split_qkv(fused)
        batch_size = q.shape[1]

        # Batch-first layout for rotary and the attention kernels.
        q, k, v = (rearrange(part, "s b ... -> b s ...").contiguous()
                   for part in (q, k, v))
        if rotary_pos_emb is not None:
            q = apply_rotary_pos_emb_vision(q, rotary_pos_emb)
            k = apply_rotary_pos_emb_vision(k, rotary_pos_emb)

        backend = self.attn_backend
        if backend == _Backend.FLASH_ATTN:
            # Alternative kernel source (not used here):
            # vllm_flash_attn.flash_attn_interface.flash_attn_varlen_func
            from flash_attn import flash_attn_varlen_func

            # Varlen kernel takes packed [(b s), head, head_dim] tensors.
            q, k, v = (rearrange(part, "b s ... -> (b s) ...")
                       for part in [q, k, v])

            segment_lens = cu_seqlens[1:] - cu_seqlens[:-1]
            max_seqlen = segment_lens.max().item()
            output = flash_attn_varlen_func(q,
                                            k,
                                            v,
                                            cu_seqlens_q=cu_seqlens,
                                            cu_seqlens_k=cu_seqlens,
                                            max_seqlen_q=max_seqlen,
                                            max_seqlen_k=max_seqlen,
                                            dropout_p=0,
                                            causal=False)

            context_layer = rearrange(output,
                                      "(b s) ... -> b s ...",
                                      b=batch_size)
        elif backend == _Backend.TORCH_SDPA:
            total_len = q.size(1)
            q, k, v = (rearrange(part, "b s h d -> b h s d")
                       for part in [q, k, v])
            # Boolean mask that is True only inside each segment's diagonal
            # block, so tokens see only tokens of their own segment.
            block_mask = torch.zeros([1, total_len, total_len],
                                     device=q.device,
                                     dtype=torch.bool)
            for start, end in zip(cu_seqlens[:-1], cu_seqlens[1:]):
                block_mask[..., start:end, start:end] = True
            output = F.scaled_dot_product_attention(q,
                                                    k,
                                                    v,
                                                    block_mask,
                                                    dropout_p=0.0)
            context_layer = rearrange(output, "b h s d -> b s h d ")
        elif backend == _Backend.XFORMERS:
            from xformers import ops as xops
            from xformers.ops.fmha.attn_bias import BlockDiagonalMask

            segment_list = (cu_seqlens[1:] - cu_seqlens[:-1]).tolist()
            attn_bias = BlockDiagonalMask.from_seqlens(q_seqlen=segment_list,
                                                       kv_seqlen=None)

            context_layer = xops.memory_efficient_attention_forward(
                q, k, v, attn_bias=attn_bias, p=0, scale=None)

        # [b, s, head, head_dim] -> [s, b, head * head_dim]
        context_layer = rearrange(context_layer,
                                  "b s h d -> s b (h d)").contiguous()

        output, _ = self.proj(context_layer)
        return output


class Qwen2VisionBlock(nn.Module):
    """Pre-norm vision transformer layer.

    Computes ``x + attn(norm1(x))`` followed by ``x + mlp(norm2(x))``.
    """

    def __init__(
        self,
        dim: int,
        num_heads: int,
        mlp_ratio: float,
        act_layer: Type[nn.Module] = QuickGELU,
        norm_layer: Optional[Callable[[int], nn.Module]] = None,
        quant_config: Optional[QuantizationConfig] = None,
        prefix: str = "",
    ) -> None:
        super().__init__()
        # Default norm: LayerNorm with eps=1e-6.
        make_norm = (norm_layer if norm_layer is not None else partial(
            nn.LayerNorm, eps=1e-6))
        self.norm1 = make_norm(dim)
        self.norm2 = make_norm(dim)
        mlp_width = int(dim * mlp_ratio)

        self.attn = Qwen2VisionAttention(embed_dim=dim,
                                         num_heads=num_heads,
                                         projection_size=dim,
                                         quant_config=quant_config,
                                         prefix=f"{prefix}.attn")
        self.mlp = Qwen2VisionMLP(dim,
                                  mlp_width,
                                  act_layer=act_layer,
                                  quant_config=quant_config,
                                  prefix=f"{prefix}.mlp")

    def forward(self, x: torch.Tensor, cu_seqlens: torch.Tensor,
                rotary_pos_emb: torch.Tensor) -> torch.Tensor:
        """Run the attention and MLP sub-layers, each with a residual add."""
        attended = self.attn(self.norm1(x),
                             cu_seqlens=cu_seqlens,
                             rotary_pos_emb=rotary_pos_emb)
        hidden = x + attended
        return hidden + self.mlp(self.norm2(hidden))


class Qwen2VisionPatchEmbed(nn.Module):
    """Embed flattened patches with a bias-free strided 3-D convolution.

    Each input row is one patch, viewed as ``(channels,
    temporal_patch_size, patch_size, patch_size)``. Kernel and stride are
    equal, so every patch maps to exactly one ``embed_dim`` vector.
    """

    def __init__(
        self,
        patch_size: int = 14,
        temporal_patch_size: int = 2,
        in_channels: int = 3,
        embed_dim: int = 1152,
    ) -> None:
        super().__init__()
        self.patch_size = patch_size
        self.temporal_patch_size = temporal_patch_size
        self.embed_dim = embed_dim

        patch_kernel = (temporal_patch_size, patch_size, patch_size)
        self.proj = nn.Conv3d(in_channels,
                              embed_dim,
                              kernel_size=patch_kernel,
                              stride=patch_kernel,
                              bias=False)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Map a 2-D ``(num_patches, patch_numel)`` tensor to
        ``(num_patches, embed_dim)``."""
        num_patches, _ = x.shape
        volume = x.view(num_patches, -1, self.temporal_patch_size,
                        self.patch_size, self.patch_size)
        embedded = self.proj(volume)
        return embedded.view(num_patches, self.embed_dim)


class Qwen2VisionPatchMerger(nn.Module):
    """Fuse each run of ``spatial_merge_size ** 2`` consecutive patch
    features into one output token of width ``d_model``.

    Features are layer-normed per patch and regrouped with
    ``view(-1, context_dim * spatial_merge_size ** 2)``, so the members of a
    group must be adjacent rows. The groups then pass through
    Linear -> GELU -> Linear.
    """

    def __init__(
        self,
        d_model: int,
        context_dim: int,
        norm_layer: Optional[Callable[[int], nn.Module]] = None,
        spatial_merge_size: int = 2,
        quant_config: Optional[QuantizationConfig] = None,
        prefix: str = "",
    ) -> None:
        super().__init__()
        self.hidden_size = context_dim * (spatial_merge_size**2)
        # Default norm: LayerNorm with eps=1e-6.
        make_norm = (norm_layer if norm_layer is not None else partial(
            nn.LayerNorm, eps=1e-6))
        self.ln_q = make_norm(context_dim)

        # ModuleList positions 0/1/2 produce the submodule names mlp.0 and
        # mlp.2, which the prefixes below also use.
        fc_in = ColumnParallelLinear(self.hidden_size,
                                     self.hidden_size,
                                     bias=True,
                                     quant_config=quant_config,
                                     prefix=f"{prefix}.mlp.0")
        activation = nn.GELU()
        fc_out = RowParallelLinear(self.hidden_size,
                                   d_model,
                                   bias=True,
                                   quant_config=quant_config,
                                   prefix=f"{prefix}.mlp.2")
        self.mlp = nn.ModuleList([fc_in, activation, fc_out])

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Normalize, group ``spatial_merge_size ** 2`` rows, and project."""
        grouped = self.ln_q(x).view(-1, self.hidden_size)

        fc_in, activation, fc_out = self.mlp
        hidden, _ = fc_in(grouped)
        merged, _ = fc_out(activation(hidden))
        return merged


class Qwen2VisionRotaryEmbedding(nn.Module):
    """Lazily grown table of rotary angles.

    ``forward(seqlen)`` returns a float tensor of shape
    ``(seqlen, len(inv_freq))`` whose entry ``[p, j]`` is
    ``p * inv_freq[j]``, with ``inv_freq[j] = 1 / theta ** (2 * j / dim)``.
    """

    def __init__(self, dim: int, theta: float = 10000.0) -> None:
        super().__init__()
        self.dim = dim
        self.theta = theta
        inv_freq = 1.0 / (theta
                          **(torch.arange(0, dim, 2, dtype=torch.float) / dim))
        # Non-persistent: derived from dim/theta, not stored in state dicts.
        self.register_buffer("inv_freq", inv_freq, persistent=False)
        # Number of positions currently covered by ``_freqs_cached``.
        self._seq_len_cached = 0
        self._freqs_cached = None

    def update_freqs_cache(self, seqlen: int) -> None:
        """Ensure the cached table covers at least ``seqlen`` positions.

        On a miss, the table is rebuilt for ``2 * seqlen`` positions, with
        ``inv_freq`` recomputed on its current device.
        """
        # Normalize to a Python int. Callers may pass a 0-d tensor (e.g. the
        # result of ``Tensor.max()``); ``seqlen *= 2`` on a tensor would
        # modify the caller's tensor in place and leave the cached length
        # aliasing it.
        seqlen = int(seqlen)
        if seqlen <= self._seq_len_cached:
            return
        seqlen *= 2
        self._seq_len_cached = seqlen
        self.inv_freq = 1.0 / (self.theta**(torch.arange(
            0, self.dim, 2, dtype=torch.float, device=self.inv_freq.device)
                                            / self.dim))
        seq = torch.arange(seqlen,
                           device=self.inv_freq.device,
                           dtype=self.inv_freq.dtype)
        self._freqs_cached = torch.outer(seq, self.inv_freq)

    def forward(self, seqlen: int) -> torch.Tensor:
        """Return the angles for positions ``0 .. seqlen - 1``.

        ``seqlen`` may be an int or a 0-d integer tensor. The argument is
        never modified.
        """
        seqlen = int(seqlen)
        self.update_freqs_cache(seqlen)
        return self._freqs_cached[:seqlen]


class Qwen2VisionTransformer(nn.Module):
    """Qwen2-VL vision encoder.

    Pipeline: ``Qwen2VisionPatchEmbed`` -> ``depth`` ``Qwen2VisionBlock``
    layers with 2-D rotary position embeddings -> ``Qwen2VisionPatchMerger``.
    Patches of all images/videos in one call are concatenated along dim 0,
    and ``grid_thw`` gives each item's ``(t, h, w)`` patch grid.
    """

    def __init__(
        self,
        vision_config: Qwen2VLVisionConfig,
        norm_eps: float = 1e-6,
        quant_config: Optional[QuantizationConfig] = None,
        prefix: str = "",
    ) -> None:
        super().__init__()

        patch_size = vision_config.patch_size
        temporal_patch_size = vision_config.temporal_patch_size
        spatial_merge_size = vision_config.spatial_merge_size
        in_channels = vision_config.in_channels
        hidden_size = vision_config.hidden_size
        embed_dim = vision_config.embed_dim
        depth = vision_config.depth
        num_heads = vision_config.num_heads
        mlp_ratio = vision_config.mlp_ratio

        self.spatial_merge_size = spatial_merge_size
        self.num_heads = num_heads
        self.embed_dim = embed_dim

        self.patch_embed = Qwen2VisionPatchEmbed(
            patch_size=patch_size,
            temporal_patch_size=temporal_patch_size,
            in_channels=in_channels,
            embed_dim=embed_dim,
        )

        norm_layer = partial(nn.LayerNorm, eps=norm_eps)
        head_dim = embed_dim // num_heads
        # rot_pos_emb concatenates row-position and column-position angles,
        # each of this table's width, so a per-token angle vector covers
        # half of head_dim.
        self.rotary_pos_emb = Qwen2VisionRotaryEmbedding(head_dim // 2)

        self.blocks = nn.ModuleList([
            Qwen2VisionBlock(dim=embed_dim,
                             num_heads=num_heads,
                             mlp_ratio=mlp_ratio,
                             norm_layer=norm_layer,
                             quant_config=quant_config,
                             prefix=f"{prefix}.blocks.{layer_idx}")
            for layer_idx in range(depth)
        ])
        self.merger = Qwen2VisionPatchMerger(
            d_model=hidden_size,
            context_dim=embed_dim,
            norm_layer=norm_layer,
            quant_config=quant_config,
            prefix=f"{prefix}.merger",
        )

    @property
    def dtype(self) -> torch.dtype:
        """Dtype of the patch-embedding weight; inputs are cast to it."""
        return self.patch_embed.proj.weight.dtype

    @property
    def device(self) -> torch.device:
        """Device of the patch-embedding weight; inputs are moved to it."""
        return self.patch_embed.proj.weight.device

    def rot_pos_emb(self, grid_thw: torch.Tensor) -> torch.Tensor:
        """Build per-patch rotary angles from each item's ``(t, h, w)`` grid.

        Returns:
            Tensor with one row per patch: the row-position angles followed
            by the column-position angles.
        """
        pos_ids = []
        for t, h, w in grid_thw:
            hpos_ids = torch.arange(h).unsqueeze(1).expand(-1, w)
            wpos_ids = torch.arange(w).unsqueeze(0).expand(h, -1)
            # Reorder so that each spatial_merge_size x spatial_merge_size
            # window is contiguous, matching the consecutive-row grouping
            # done by the patch merger.
            hpos_ids = hpos_ids.reshape(
                h // self.spatial_merge_size,
                self.spatial_merge_size,
                w // self.spatial_merge_size,
                self.spatial_merge_size,
            ).permute(0, 2, 1, 3).flatten()
            wpos_ids = wpos_ids.reshape(
                h // self.spatial_merge_size,
                self.spatial_merge_size,
                w // self.spatial_merge_size,
                self.spatial_merge_size,
            ).permute(0, 2, 1, 3).flatten()
            # Every temporal slice reuses the same spatial positions.
            pos_ids.append(
                torch.stack([hpos_ids, wpos_ids], dim=-1).repeat(t, 1))
        pos_ids = torch.cat(pos_ids, dim=0)
        # Pass a plain Python int, not a 0-d tensor, so the rotary cache
        # length never aliases (or modifies) this tensor.
        max_grid_size = int(grid_thw[:, 1:].max())
        rotary_pos_emb_full = self.rotary_pos_emb(max_grid_size)
        rotary_pos_emb = rotary_pos_emb_full[pos_ids].flatten(1)
        return rotary_pos_emb

    def forward(
        self,
        x: torch.Tensor,
        grid_thw: torch.Tensor,
    ) -> torch.Tensor:
        """Encode flattened patches into merged vision tokens.

        Args:
            x: One row per patch, for all items concatenated.
            grid_thw: ``(num_items, 3)`` patch grid of each item.

        Returns:
            Merger output: one row per ``spatial_merge_size ** 2`` patches.
        """
        # patchify
        x = x.to(device=self.device, dtype=self.dtype)
        x = self.patch_embed(x)

        # compute position embedding
        rotary_pos_emb = self.rot_pos_emb(grid_thw)

        # One attention segment per temporal slice of each item (h * w
        # patches each), as cumulative boundaries prefixed with 0.
        cu_seqlens = torch.repeat_interleave(grid_thw[:, 1] * grid_thw[:, 2],
                                             grid_thw[:, 0]).cumsum(
                                                 dim=0, dtype=torch.int32)
        cu_seqlens = F.pad(cu_seqlens, (1, 0), "constant", 0)

        # Blocks use an [s, b, c] layout; the batch dim is 1.
        x = x.unsqueeze(1)
        for blk in self.blocks:
            x = blk(x, cu_seqlens=cu_seqlens, rotary_pos_emb=rotary_pos_emb)

        # adapter
        x = self.merger(x)
        return x

    def load_weights(self, weights: Iterable[Tuple[str,
                                                   torch.Tensor]]) -> Set[str]:
        """Load ``(name, tensor)`` pairs into this module's parameters.

        Names containing ``q_proj``/``k_proj``/``v_proj`` are renamed to
        ``qkv_proj`` and loaded as the matching shard. All other names are
        loaded with the parameter's ``weight_loader`` (or
        ``default_weight_loader``).

        Returns:
            The set of parameter names that were loaded.

        Raises:
            KeyError: If a (renamed) weight name has no matching parameter.
        """
        # NOTE(review): the attention layers here name their fused projection
        # ``qkv`` rather than ``qkv_proj``, so these entries may never match
        # in practice — confirm against the checkpoint names.
        stacked_params_mapping = [
            # (param_name, shard_name, shard_id)
            ("qkv_proj", "q_proj", "q"),
            ("qkv_proj", "k_proj", "k"),
            ("qkv_proj", "v_proj", "v"),
        ]
        params_dict = dict(self.named_parameters(remove_duplicate=False))
        loaded_params: Set[str] = set()

        for name, loaded_weight in weights:
            for (param_name, weight_name, shard_id) in stacked_params_mapping:
                if weight_name not in name:
                    continue
                name = name.replace(weight_name, param_name)

                param = params_dict[name]
                weight_loader = param.weight_loader
                weight_loader(param, loaded_weight, shard_id)
                break
            else:
                param = params_dict[name]
                weight_loader = getattr(param, "weight_loader",
                                        default_weight_loader)
                weight_loader(param, loaded_weight)
            loaded_params.add(name)
        return loaded_params

643
644
645
646
647

# === Vision input helpers === #


def _get_vision_info(
    vision_config: Qwen2VLVisionConfig,
    height: int,
    width: int,
    min_pixels: int,
    max_pixels: int,
    *,
    do_resize: bool = True,
    modality: str = "image",
    mm_count: int = 1,
):
    """Compute the resized size and LLM-side token count of a visual input.

    Args:
        vision_config: Vision-tower config. ``patch_size``,
            ``spatial_merge_size`` and ``temporal_patch_size`` are read.
        height, width: Original size of the image / video frame.
        min_pixels, max_pixels: Pixel-count bounds forwarded to
            ``smart_resize``.
        do_resize: When False, the original size is kept unchanged.
        modality: ``"image"`` or ``"video"``.
        mm_count: Number of images (image) or frames (video).

    Returns:
        ``(resized_height, resized_width, llm_num_vision_tokens)``.

    Raises:
        ValueError: If ``modality`` is neither ``"image"`` nor ``"video"``.
    """
    patch = vision_config.patch_size
    merge = vision_config.spatial_merge_size
    temporal_patch = vision_config.temporal_patch_size

    if not do_resize:
        new_height, new_width = height, width
    else:
        new_height, new_width = smart_resize(
            height=height,
            width=width,
            factor=patch * merge,
            min_pixels=min_pixels,
            max_pixels=max_pixels,
        )

    if modality == "video":
        # Frames are grouped temporal_patch at a time, but at least one
        # temporal slice is produced.
        grid_t = max(mm_count // temporal_patch, 1)
    elif modality == "image":
        grid_t = mm_count
    else:
        raise ValueError(f"Modality {modality} is not supported")

    num_patches = grid_t * (new_height // patch) * (new_width // patch)
    # The merger fuses merge x merge patches into one LLM token.
    return new_height, new_width, num_patches // (merge**2)


690
691
692
693
def _get_image_processor(hf_processor: Qwen2VLProcessor):
    """Return the Qwen2-VL image processor owned by ``hf_processor``.

    Raises:
        AssertionError: If the attached image processor is not a
            ``Qwen2VLImageProcessor`` (the check also narrows the type).
    """
    processor = hf_processor.image_processor  # type: ignore
    assert isinstance(processor, Qwen2VLImageProcessor)
    return processor
694
695


696
697
class Qwen2EmbeddingItems(ModalityDataItems[dict[str, torch.Tensor],
                                            dict[str, torch.Tensor]]):
    """Precomputed embedding inputs for one modality, supplied as a dict.

    ``data`` must hold ``f"{modality}_grid_thw"`` with one ``(t, h, w)`` row
    per item. The other values are tensors whose rows for all items are
    concatenated along dim 0; ``get`` splits them per item.
    """

    def __init__(self, data: dict, modality: str) -> None:
        super().__init__(data, modality)

        grid_thw = data[f"{modality}_grid_thw"]
        # Row range [start, end) of each item, from cumulative t*h*w counts.
        # NOTE(review): t*h*w counts pre-merge patches. If the embedding
        # tensors hold post-merger features (fewer rows per item), these
        # boundaries would need dividing by spatial_merge_size**2 — confirm.
        slice_idxs = [0] + grid_thw.prod(-1).cumsum_(0).tolist()
        self._slices = [
            slice(slice_idxs[i], slice_idxs[i + 1])
            for i in range(len(grid_thw))
        ]

    def get_count(self) -> int:
        """Number of items, i.e. rows of the grid_thw tensor."""
        return len(self.data[f"{self.modality}_grid_thw"])

    def get(self, index: int) -> dict[str, torch.Tensor]:
        """Return the data of item ``index``.

        Every value except the grid_thw entry is sliced to the item's row
        range. The grid_thw entry is returned unchanged.
        """
        grid_key = f"{self.modality}_grid_thw"
        item_rows = self._slices[index]
        # Test the *key*: grid_thw has one row per item rather than per
        # patch, so it must not be sliced with the patch row range.
        # NOTE(review): the whole grid_thw (all items) is returned; callers
        # wanting only this item's row may need ``v[index:index + 1]``.
        return {
            key: (value if key == grid_key else value[item_rows])
            for key, value in self.data.items()
        }

    def get_processor_data(self) -> Mapping[str, object]:
        """Nothing is sent to the HF processor for embedding inputs."""
        return {}

    def get_passthrough_data(self) -> Mapping[str, object]:
        """All data (embeddings and grid_thw) is passed through as-is."""
        return self.data


class Qwen2ImageEmbeddingItems(Qwen2EmbeddingItems):
    """Embedding items bound to the ``"image"`` modality."""

    def __init__(self, data: dict) -> None:
        super().__init__(data, modality="image")


class Qwen2VideoEmbeddingItems(Qwen2EmbeddingItems):
    """Embedding items bound to the ``"video"`` modality."""

    def __init__(self, data: dict) -> None:
        super().__init__(data, modality="video")
739

740
741
742
743
744
745
746
747
748
749
750
751
752

class Qwen2MultiModalDataParser(MultiModalDataParser):
    """Data parser that also accepts precomputed embeddings given as dicts.

    A dict input for images or videos is wrapped in the modality-specific
    embedding-items class. Any other input is handled by the base parser.
    """

    def _parse_image_data(
        self,
        data: Union[dict[str, torch.Tensor], ModalityData[ImageItem]],
    ) -> ModalityDataItems[Any, Any]:
        if isinstance(data, dict):
            return Qwen2ImageEmbeddingItems(data)

        return super()._parse_image_data(data)

    def _parse_video_data(
        self,
        data: Union[dict[str, torch.Tensor], ModalityData[VideoItem]],
    ) -> ModalityDataItems[Any, Any]:
        if isinstance(data, dict):
            return Qwen2VideoEmbeddingItems(data)

        return super()._parse_video_data(data)


class Qwen2VLMultiModalProcessor(BaseMultiModalProcessor):
    """Multi-modal processor for Qwen2-VL image and video inputs."""

    def get_supported_mm_limits(self) -> Mapping[str, Optional[int]]:
        # ``None`` for both modalities; presumably "no fixed per-prompt
        # limit" — confirm against BaseMultiModalProcessor.
        return {"image": None, "video": None}

    def _get_max_mm_tokens(self, modality: str) -> int:
        """Return the maximum LLM token count of one ``modality`` item.

        The value comes from ``_get_vision_info`` evaluated with the image
        processor's current ``min_pixels`` / ``max_pixels`` bounds.
        """
        hf_config = self.ctx.get_hf_config(Qwen2VLConfig)
        vision_config = hf_config.vision_config

        hf_processor = self._get_hf_processor()
        image_processor = _get_image_processor(hf_processor)

        # Deliberately oversized nominal resolution; presumably
        # _get_vision_info clamps it to ``max_pixels`` so the result is the
        # per-item upper bound — confirm against its implementation.
        _, _, max_llm_image_tokens = _get_vision_info(
            vision_config,
            height=9999999,
            width=9999999,
            min_pixels=image_processor.min_pixels,
            max_pixels=image_processor.max_pixels,
            modality=modality,
        )
        return max_llm_image_tokens

    def get_mm_max_tokens_per_item(self) -> Mapping[str, int]:
        return {
            "image": self._get_max_mm_tokens("image"),
            "video": self._get_max_mm_tokens("video"),
        }

    def _get_data_parser(self) -> MultiModalDataParser:
        # Accepts dicts of precomputed tensors in addition to raw media.
        return Qwen2MultiModalDataParser()

    def _get_hf_processor(
        self,
        *,
        min_pixels: Optional[int] = None,
        max_pixels: Optional[int] = None,
    ) -> Qwen2VLProcessor:
        """Return the HF processor, optionally overriding pixel bounds.

        Falsy values (``None`` or ``0``) leave the corresponding bound of the
        image processor unchanged. When either bound is overridden,
        ``image_processor.size`` is rebuilt from the resulting bounds.
        """
        hf_processor = self.ctx.get_hf_processor(Qwen2VLProcessor)
        image_processor = _get_image_processor(hf_processor)

        # NOTE(review): the overrides mutate the image processor in place; if
        # ``ctx.get_hf_processor`` returns a cached instance they persist into
        # later calls made without overrides — confirm this is intended.
        if min_pixels:
            image_processor.min_pixels = min_pixels
        if max_pixels:
            image_processor.max_pixels = max_pixels
        if max_pixels or min_pixels:
            image_processor.size = {
                "min_pixels": image_processor.min_pixels,
                "max_pixels": image_processor.max_pixels,
            }

        return hf_processor

    def _get_prompt_replacements(
        self,
        mm_items: MultiModalDataItems,
        hf_processor_mm_kwargs: Mapping[str, object],
        out_mm_kwargs: MultiModalKwargs,
    ) -> list[PromptReplacement]:
        """Expand each image/video placeholder token into one token per
        merged vision token of the corresponding item."""
        hf_processor = self._get_hf_processor()
        image_processor = _get_image_processor(hf_processor)

        # NOTE: Only Qwen2VLProcessor in transformers 4.47.0 has
        # image_token and video_token registered
        placeholder = {
            "image": hf_processor.image_token,
            "video": hf_processor.video_token,
        }
        # Each merge_size x merge_size block of patches becomes one token.
        merge_length = image_processor.merge_size**2

        def get_replacement_qwen2vl(item_idx: int, modality: str):
            grid_thw = out_mm_kwargs[f"{modality}_grid_thw"][item_idx]
            assert isinstance(grid_thw, torch.Tensor)

            # Convert to a Python int so the string repetition below does not
            # depend on tensor ``__index__`` support.
            num_tokens = int(grid_thw.prod()) // merge_length
            return placeholder[modality] * num_tokens

        return [
            PromptReplacement(
                modality=modality,
                target=placeholder[modality],
                replacement=partial(get_replacement_qwen2vl,
                                    modality=modality),
            ) for modality in ("image", "video")
        ]

    @staticmethod
    def _grid_thw_to_slices(grid_thw: torch.Tensor) -> list[slice]:
        """Map per-item grids to contiguous row ranges of a flat tensor.

        Item ``i`` covers ``grid_thw[i].prod()`` consecutive rows, starting
        right after the rows of items ``0 .. i-1``.
        """
        slice_idxs = [0] + grid_thw.prod(-1).cumsum_(0).tolist()
        return [
            slice(slice_idxs[i], slice_idxs[i + 1])
            for i in range(len(grid_thw))
        ]

    def _get_mm_fields_config(
        self,
        hf_inputs: BatchFeature,
        hf_processor_mm_kwargs: Mapping[str, object],
    ) -> Mapping[str, MultiModalFieldConfig]:
        # A missing grid means no items of that modality: an empty (0, 3)
        # grid yields no slices.
        image_slices = self._grid_thw_to_slices(
            hf_inputs.get("image_grid_thw", torch.empty((0, 3))))
        video_slices = self._grid_thw_to_slices(
            hf_inputs.get("video_grid_thw", torch.empty((0, 3))))

        return dict(
            pixel_values=MultiModalFieldConfig.flat("image", image_slices),
            image_embeds=MultiModalFieldConfig.flat("image", image_slices),
            image_grid_thw=MultiModalFieldConfig.batched("image"),
            pixel_values_videos=MultiModalFieldConfig.flat(
                "video", video_slices),
            video_embeds=MultiModalFieldConfig.flat("video", video_slices),
            video_grid_thw=MultiModalFieldConfig.batched("video"),
        )

    def _get_dummy_mm_inputs(
        self,
        mm_counts: Mapping[str, int],
    ) -> ProcessorInputs:
        """Build a dummy prompt with ``mm_counts["image"]`` images sized by
        ``smart_resize`` under the image processor's pixel bounds."""
        hf_processor = self._get_hf_processor()
        image_processor = _get_image_processor(hf_processor)

        image_token: str = hf_processor.image_token
        # Oversized request; presumably smart_resize scales it down to
        # ``max_pixels``, giving the largest allowed image.
        resized_height, resized_width = smart_resize(
            height=9999999,
            width=9999999,
            factor=image_processor.patch_size * image_processor.merge_size,
            min_pixels=image_processor.min_pixels,
            max_pixels=image_processor.max_pixels,
        )
        num_images = mm_counts.get("image", 0)

        # NOTE(review): only image dummies are produced; any "video" count in
        # ``mm_counts`` is ignored here — confirm this is intended.
        mm_data = {
            "image":
            self._get_dummy_images(width=resized_width,
                                   height=resized_height,
                                   num_images=num_images)
        }

        return ProcessorInputs(
            prompt_text=image_token * num_images,
            mm_data=mm_data,
        )


@MULTIMODAL_REGISTRY.register_processor(Qwen2VLMultiModalProcessor)
class Qwen2VLForConditionalGeneration(nn.Module, SupportsMultiModal,
                                      SupportsLoRA, SupportsPP):
    # Fused module name -> names of the sub-modules packed into it.
    packed_modules_mapping = {
        "qkv_proj": ["q_proj", "k_proj", "v_proj"],
        "gate_up_proj": ["gate_proj", "up_proj"],
    }

    # LoRA specific attributes
    supported_lora_modules = [
        "qkv_proj", "o_proj", "gate_up_proj", "down_proj",
        # vision tower
        "qkv",
        "attn.proj",  # Distinguish patch_embed.proj
        "fc1", "fc2",
        # projector
        "mlp.0", "mlp.2",
    ]
    embedding_modules = {}
    embedding_padding_modules = []

    # BitsAndBytes specific attributes:
    # shard_name -> (stacked weight_name, shard index)
    bitsandbytes_stacked_params_mapping = {
        "q_proj": ("qkv_proj", 0),
        "k_proj": ("qkv_proj", 1),
        "v_proj": ("qkv_proj", 2),
        "gate_proj": ("gate_up_proj", 0),
        "up_proj": ("gate_up_proj", 1),
    }

    # Rename HF checkpoint prefixes onto the nested ``language_model``
    # module so weights load into the right place.
    hf_to_vllm_mapper = WeightsMapper(orig_to_new_prefix={
        "lm_head.": "language_model.lm_head.",
        "model.": "language_model.model.",
    })

    def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
        """Build the vision tower and the Qwen2 language model.

        Raises ``AssertionError`` when prefix caching is enabled, since it is
        not supported for this model.
        """
        super().__init__()
        model_config = vllm_config.model_config
        config: Qwen2VLConfig = model_config.hf_config
        cache_config = vllm_config.cache_config
        quant_config = vllm_config.quant_config
        multimodal_config = model_config.multimodal_config

        assert not cache_config.enable_prefix_caching, \
            "Qwen2-VL currently does not support prefix caching"

        self.config = config
        self.multimodal_config = multimodal_config

        # GPTQ configs are dropped for the vision encoder
        # (see _maybe_ignore_quant_config).
        vision_quant_config = self._maybe_ignore_quant_config(quant_config)
        self.visual = Qwen2VisionTransformer(
            config.vision_config,
            norm_eps=getattr(config, "rms_norm_eps", 1e-6),
            quant_config=vision_quant_config,
            prefix=maybe_prefix(prefix, "visual"),
        )

        self.language_model = init_vllm_registered_model(
            vllm_config=vllm_config,
            prefix=maybe_prefix(prefix, "language_model"),
            architectures=["Qwen2ForCausalLM"],
        )

        lm = self.language_model
        self.make_empty_intermediate_tensors = (
            lm.make_empty_intermediate_tensors)

    @cached_property
    def sampler(self):
        """Return the language model's sampler if it has one, otherwise a
        sampler from ``get_sampler()``; computed once per instance."""
        if not hasattr(self.language_model, "sampler"):
            return get_sampler()
        return self.language_model.sampler

    def _maybe_ignore_quant_config(self, quant_config: QuantizationConfig):
        """Return the quantization config to use for the vision encoder.

        GPTQ configs do not have a list of ignored modules, however AutoGPTQ
        seems to avoid vision encoder sections for some models, so GPTQ /
        GPTQ-Marlin configs are replaced by ``None``.
        See: https://huggingface.co/Qwen/Qwen2-VL-2B-Instruct-GPTQ-Int4
        """
        gptq_config_types = (GPTQConfig, GPTQMarlinConfig)
        if not isinstance(quant_config, gptq_config_types):
            return quant_config
        return None

    def _validate_and_reshape_mm_tensor(self, mm_input: object,
999
1000
1001
1002
1003
1004
1005
1006
1007
                                        name: str) -> torch.Tensor:
        if not isinstance(mm_input, (torch.Tensor, list)):
            raise ValueError(f"Incorrect type of {name}. "
                             f"Got type: {type(mm_input)}")
        if isinstance(mm_input, torch.Tensor):
            if mm_input.ndim == 2:
                return mm_input
            if mm_input.ndim != 3:
                raise ValueError(f"{name} should be 2D or batched 3D tensor. "
1008
1009
                                 f"Got ndim: {mm_input.ndim} "
                                 f"(shape={mm_input.shape})")
1010
1011
1012
1013
1014
1015
1016
            return torch.concat(list(mm_input))
        else:
            return torch.concat(mm_input)

    def _parse_and_validate_image_input(
            self, **kwargs: object) -> Optional[Qwen2VLImageInputs]:
        """Build the image input description from model kwargs.

        Pops ``pixel_values``, ``image_embeds`` and ``image_grid_thw`` from
        ``kwargs``. Returns ``None`` when neither pixel values nor embeddings
        are present; pixel values take precedence when both are given.

        Raises:
            ValueError: (from ``_validate_and_reshape_mm_tensor``) when a
                provided tensor, or ``image_grid_thw`` (including a missing
                one), has an unsupported type or rank.
        """
        pixel_values = kwargs.pop("pixel_values", None)
        image_embeds = kwargs.pop("image_embeds", None)
        image_grid_thw = kwargs.pop("image_grid_thw", None)

        if pixel_values is None and image_embeds is None:
            return None

        # _validate_and_reshape_mm_tensor returns a tensor or raises, so its
        # results need no further type checks.
        if pixel_values is not None:
            pixel_values = self._validate_and_reshape_mm_tensor(
                pixel_values, "image pixel values")
            image_grid_thw = self._validate_and_reshape_mm_tensor(
                image_grid_thw, "image grid_thw")
            return Qwen2VLImagePixelInputs(type="pixel_values",
                                           pixel_values=pixel_values,
                                           image_grid_thw=image_grid_thw)

        # Only ``image_embeds`` was provided.
        image_embeds = self._validate_and_reshape_mm_tensor(
            image_embeds, "image embeds")
        image_grid_thw = self._validate_and_reshape_mm_tensor(
            image_grid_thw, "image grid_thw")
        return Qwen2VLImageEmbeddingInputs(type="image_embeds",
                                           image_embeds=image_embeds,
                                           image_grid_thw=image_grid_thw)

    def _parse_and_validate_video_input(
            self, **kwargs: object) -> Optional[Qwen2VLVideoInputs]:
        """Build the video input description from model kwargs.

        Pops ``pixel_values_videos``, ``video_embeds`` and ``video_grid_thw``
        from ``kwargs``. Returns ``None`` when neither pixel values nor
        embeddings are present; pixel values take precedence when both are
        given.

        Raises:
            ValueError: (from ``_validate_and_reshape_mm_tensor``) when a
                provided tensor, or ``video_grid_thw`` (including a missing
                one), has an unsupported type or rank.
        """
        pixel_values_videos = kwargs.pop("pixel_values_videos", None)
        video_embeds = kwargs.pop("video_embeds", None)
        video_grid_thw = kwargs.pop("video_grid_thw", None)

        if pixel_values_videos is None and video_embeds is None:
            return None

        # _validate_and_reshape_mm_tensor returns a tensor or raises, so its
        # results need no further type checks.
        if pixel_values_videos is not None:
            pixel_values_videos = self._validate_and_reshape_mm_tensor(
                pixel_values_videos, "video pixel values")
            video_grid_thw = self._validate_and_reshape_mm_tensor(
                video_grid_thw, "video grid_thw")
            return Qwen2VLVideoPixelInputs(
                type="pixel_values_videos",
                pixel_values_videos=pixel_values_videos,
                video_grid_thw=video_grid_thw,
            )

        # Only ``video_embeds`` was provided.
        video_embeds = self._validate_and_reshape_mm_tensor(
            video_embeds, "video embeds")
        video_grid_thw = self._validate_and_reshape_mm_tensor(
            video_grid_thw, "video grid_thw")
        return Qwen2VLVideoEmbeddingInputs(type="video_embeds",
                                           video_embeds=video_embeds,
                                           video_grid_thw=video_grid_thw)

    def _process_image_input(self,
                             image_input: Qwen2VLImageInputs) -> torch.Tensor:
        """Return image embeddings cast to the vision tower's dtype.

        Precomputed embeddings are only cast; pixel values are cast and then
        encoded by ``self.visual`` using ``image_grid_thw``.
        """
        target_dtype = self.visual.dtype
        if image_input["type"] != "image_embeds":
            pixels = image_input["pixel_values"].type(target_dtype)
            return self.visual(pixels, grid_thw=image_input["image_grid_thw"])
        return image_input["image_embeds"].type(target_dtype)

    def _process_video_input(self,
                             video_input: Qwen2VLVideoInputs) -> torch.Tensor:
        """Return video embeddings cast to the vision tower's dtype.

        Precomputed embeddings are only cast; pixel values are cast and then
        encoded by ``self.visual`` using ``video_grid_thw``.
        """
        target_dtype = self.visual.dtype
        if video_input["type"] != "video_embeds":
            frames = video_input["pixel_values_videos"].type(target_dtype)
            return self.visual(frames, grid_thw=video_input["video_grid_thw"])
        return video_input["video_embeds"].type(target_dtype)

    def _merge_multimodal_embeddings(
        self,
        input_ids: torch.Tensor,
        inputs_embeds: torch.Tensor,
        multimodal_embeddings: torch.Tensor,
        placeholder_token_id: int,
    ) -> torch.Tensor:
        mask = (input_ids == placeholder_token_id)
        inputs_embeds[mask, :] = multimodal_embeddings
        return inputs_embeds

1116
1117
1118
1119
1120
1121
1122
1123
1124
1125
1126
1127
1128
1129
1130
1131
1132
1133
1134
1135
1136
1137
1138
1139
1140
1141
1142
1143
1144
1145
    def get_multimodal_embeddings(
            self, **kwargs) -> Optional[List[Tuple[NestedTensors, str]]]:
        """Encode the image and/or video inputs found in ``kwargs``.

        Returns ``None`` when there are no multimodal inputs; otherwise a
        list of ``(embeddings, modality)`` pairs, image first.
        """
        image_input = self._parse_and_validate_image_input(**kwargs)
        video_input = self._parse_and_validate_video_input(**kwargs)
        if image_input is None and video_input is None:
            return None

        # We make a tuple of each embedding with its modality string. This is a
        # temporary workaround for models to handle mixed modalities when
        # get_multimodal_embeddings and get_input_embeddings are called
        # separately.
        # TODO(ywang96): Add support for mixed-modality inference for v1.
        tagged_embeddings: List[Tuple[NestedTensors, str]] = []
        if image_input is not None:
            tagged_embeddings.append(
                (self._process_image_input(image_input), "image"))
        if video_input is not None:
            tagged_embeddings.append(
                (self._process_video_input(video_input), "video"))
        return tagged_embeddings

    def get_input_embeddings(
        self,
        input_ids: torch.Tensor,
        multimodal_embeddings: Optional[List[Tuple[NestedTensors,
                                                   str]]] = None,
    ) -> torch.Tensor:
        """Embed ``input_ids`` and overwrite placeholder positions with the
        given multimodal embeddings.

        Image embeddings replace ``config.image_token_id`` positions and
        video embeddings replace ``config.video_token_id`` positions; entries
        with any other modality tag are skipped.
        """
        inputs_embeds = self.language_model.get_input_embeddings(input_ids)
        if multimodal_embeddings is None:
            return inputs_embeds

        for embeddings, modality in multimodal_embeddings:
            if modality == "image":
                token_id = self.config.image_token_id
            elif modality == "video":
                token_id = self.config.video_token_id
            else:
                continue
            inputs_embeds = self._merge_multimodal_embeddings(
                input_ids,
                inputs_embeds,
                embeddings,
                placeholder_token_id=token_id,
            )
        return inputs_embeds

    def forward(
        self,
        input_ids: torch.Tensor,
        positions: torch.Tensor,
        kv_caches: List[torch.Tensor],
        attn_metadata: AttentionMetadata,
        intermediate_tensors: Optional[IntermediateTensors] = None,
        inputs_embeds: Optional[torch.Tensor] = None,
        **kwargs: object,
    ) -> Union[torch.Tensor, IntermediateTensors]:
        """Run the forward pass for Qwen2-VL.

        Args:
            input_ids: Flattened (concatenated) token ids of a batch.
            positions: Flattened (concatenated) position ids of a batch.
                **NOTE**: If mrope is enabled (the default for the
                open-source Qwen2-VL models), the shape is `(3, seq_len)`;
                otherwise it is `(seq_len,)`.
            kv_caches: Passed unchanged to the language model.
            attn_metadata: Passed unchanged to the language model.
            intermediate_tensors: When given, ``inputs_embeds`` is discarded
                and these tensors are passed to the language model.
            inputs_embeds: Precomputed input embeddings. When ``None`` (and
                no intermediate tensors are given) they are built here from
                ``input_ids`` and the multimodal ``kwargs``.
            **kwargs: Multimodal inputs: ``pixel_values`` / ``image_embeds``
                with ``image_grid_thw`` (`(n_images, 3)`), and
                ``pixel_values_videos`` / ``video_embeds`` with
                ``video_grid_thw`` (`(n_videos, 3)`). Absent or ``None`` when
                that input is not passed.

        Returns:
            Whatever ``self.language_model.model`` returns (hidden states or
            intermediate tensors).
        """
        if intermediate_tensors is not None:
            # Presumably a non-first pipeline-parallel stage: it works from
            # the handed-over intermediate tensors, not from embeddings.
            inputs_embeds = None
        # NOTE: In v1, inputs_embeds is always generated at model runner, this
        # condition is for v0 compatibility.
        elif inputs_embeds is None:
            mm_embeds = self.get_multimodal_embeddings(**kwargs)

            # We need to check for usage of mrope here in case there is
            # multimodal data.
            # TODO (ywang96): move this to model runner in V1.
            if mm_embeds is not None and uses_mrope(self.config):
                assert positions.ndim == 2 and positions.size(0) == 3, (
                    "multimodal section rotary embedding requires "
                    f"(3, seq_len) positions, but got {positions.size()}")

            inputs_embeds = self.get_input_embeddings(input_ids, mm_embeds)
            # The language model consumes the embeddings instead of ids.
            input_ids = None

        return self.language_model.model(
            input_ids=input_ids,
            positions=positions,
            kv_caches=kv_caches,
            attn_metadata=attn_metadata,
            intermediate_tensors=intermediate_tensors,
            inputs_embeds=inputs_embeds,
        )

    def compute_logits(
        self,
        hidden_states: torch.Tensor,
        sampling_metadata: SamplingMetadata,
    ) -> Optional[torch.Tensor]:
        """Delegate logit computation to the wrapped language model."""
        language_model = self.language_model
        return language_model.compute_logits(hidden_states, sampling_metadata)

    def sample(
        self,
        logits: torch.Tensor,
        sampling_metadata: SamplingMetadata,
    ) -> Optional[SamplerOutput]:
        """Delegate sampling to the wrapped language model."""
        language_model = self.language_model
        return language_model.sample(logits, sampling_metadata)

    def load_weights(self, weights: Iterable[Tuple[str,
                                                   torch.Tensor]]) -> Set[str]:
        """Load checkpoint weights through ``AutoWeightsLoader``.

        Checkpoint names are renamed with ``hf_to_vllm_mapper`` first; the
        return value is whatever the loader reports.
        """
        weights_loader = AutoWeightsLoader(self)
        return weights_loader.load_weights(weights,
                                           mapper=self.hf_to_vllm_mapper)

    def get_mm_mapping(self) -> MultiModelKeys:
        """Get the module prefix in multimodal models.

        ``visual.`` is the vision tower and ``visual.merger.`` is the
        projector that connects it to the language model (its ``mlp.0`` /
        ``mlp.2`` layers are the "projector" LoRA targets declared on this
        class).
        """
        return MultiModelKeys.from_string_field(
            language_model="language_model",
            connector="visual.merger.",
            tower_model="visual.")