qwen2_vl.py 60.5 KB
Newer Older
1
# SPDX-License-Identifier: Apache-2.0
2
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
3

4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
# Adapted from
# https://github.com/huggingface/transformers/blob/19e6e80e10118f855137b90740936c0b11ac397f/src/transformers/models/qwen2_vl/modeling_qwen2_vl.py
# Copyright 2024 The Qwen team.
# Copyright 2023 The vLLM team.
# Copyright 2022 EleutherAI and the HuggingFace Inc. team. All rights reserved.
#
# This code is based on EleutherAI's GPT-NeoX library and the GPT-NeoX
# and OPT implementations in this library. It has been modified from its
# original forms to accommodate minor architectural differences compared
# to GPT-NeoX and OPT used by the Meta AI team that trained the model.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Inference-only Qwen2-VL model compatible with HuggingFace weights."""
27
from collections.abc import Iterable, Mapping, Sequence
28
from functools import partial
29
from typing import Annotated, Any, Callable, Literal, Optional, Union
30
31
32
33
34

import torch
import torch.nn as nn
import torch.nn.functional as F
from einops import rearrange, repeat
35
from transformers import AutoConfig, BatchFeature, PretrainedConfig
36
37
from transformers.models.qwen2_vl import (Qwen2VLImageProcessor,
                                          Qwen2VLProcessor)
38
39
from transformers.models.qwen2_vl.configuration_qwen2_vl import (
    Qwen2VLConfig, Qwen2VLVisionConfig)
40
from transformers.models.qwen2_vl.image_processing_qwen2_vl import smart_resize
41
42
from transformers.models.qwen2_vl.video_processing_qwen2_vl import (
    Qwen2VLVideoProcessor)
43

44
from vllm.attention.backends.registry import _Backend
45
46
from vllm.attention.layer import (check_upstream_fa_availability,
                                  maybe_get_vit_flash_attn_backend)
47
from vllm.config import VllmConfig
48
from vllm.config.multimodal import BaseDummyOptions
49
from vllm.distributed import parallel_state, tensor_model_parallel_all_gather
50
51
52
53
54
from vllm.distributed import utils as dist_utils
from vllm.logger import init_logger
from vllm.model_executor.layers.activation import QuickGELU
from vllm.model_executor.layers.linear import (ColumnParallelLinear,
                                               RowParallelLinear)
55
from vllm.model_executor.layers.quantization import QuantizationConfig
56
57
from vllm.model_executor.layers.rotary_embedding.common import (
    dispatch_rotary_emb_function)
58
from vllm.model_executor.model_loader.weight_utils import default_weight_loader
59
from vllm.model_executor.models.module_mapping import MultiModelKeys
60
from vllm.multimodal import MULTIMODAL_REGISTRY
61
from vllm.multimodal.inputs import (ImageItem, ModalityData,
62
                                    MultiModalDataDict, MultiModalFieldConfig,
63
                                    MultiModalKwargsItems, VideoItem)
64
65
66
from vllm.multimodal.parse import (DictEmbeddingItems, ImageSize,
                                   ModalityDataItems, MultiModalDataItems,
                                   MultiModalDataParser)
67
from vllm.multimodal.processing import (BaseMultiModalProcessor,
68
69
                                        BaseProcessingInfo, PromptReplacement,
                                        PromptUpdate)
70
from vllm.multimodal.profiling import BaseDummyInputsBuilder
71
from vllm.sequence import IntermediateTensors
72
from vllm.transformers_utils.tokenizer import AnyTokenizer
73
from vllm.utils.tensor_schema import TensorSchema, TensorShape
74

75
from .interfaces import (MultiModalEmbeddings, SupportsLoRA, SupportsMRoPE,
76
                         SupportsMultiModal, SupportsPP)
77
from .utils import (AutoWeightsLoader, WeightsMapper,
78
                    init_vllm_registered_model, maybe_prefix)
79
from .vision import get_vit_attn_backend, run_dp_sharded_mrope_vision_model
logger = init_logger(__name__)

# Per-video frame cap used for the profile run.
_MAX_FRAMES_PER_VIDEO = 14


# === Vision Inputs === #


class Qwen2VLImagePixelInputs(TensorSchema):
    """
    Dimensions:
        - np: The total number of patches over each image over each prompt in
              the batch
        - ni: Number of images
        - cps: Number of channels * patch_size * patch_size

    NOTE(review): the vision encoder's patch embedding reshapes every row as
    (-1, temporal_patch_size, patch_size, patch_size), so image rows likely
    also include temporal_patch_size -- confirm against the processor.

    Historical context:
        - pixel_values shape: (num_patches, num_channels * patch_size *
          patch_size)
        - image_grid_thw shape: (num_images, 3) in (grid_t, grid_h, grid_w)
          format
    """
    # Discriminator for the image input union.
    type: Literal["pixel_values"]

    # Flattened patches of every image in the batch, concatenated.
    pixel_values: Annotated[
        torch.Tensor,
        TensorShape("np", "cps"),
    ]

    # One (grid_t, grid_h, grid_w) row per image.
    image_grid_thw: Annotated[
        torch.Tensor,
        TensorShape("ni", 3),
    ]


class Qwen2VLImageEmbeddingInputs(TensorSchema):
    """
    Dimensions:
        - nf: Number of image features
        - hs: Hidden size
        - ni: Number of images

    Historical context:
        - image_embeds shape: (num_image_features, hidden_size)
        - num_image_features varies based on the number and resolution of the
          images.
        - hidden_size must match the hidden size of language model backbone.
        - image_grid_thw shape: (num_images, 3) in (grid_t, grid_h, grid_w)
          format
    """
    # Discriminator for the image input union.
    type: Literal["image_embeds"]

    # Precomputed image features, concatenated over all images.
    image_embeds: Annotated[
        torch.Tensor,
        TensorShape("nf", "hs"),
    ]

    # One (grid_t, grid_h, grid_w) row per image.
    image_grid_thw: Annotated[
        torch.Tensor,
        TensorShape("ni", 3),
    ]


# Image input is either raw pixel patches or precomputed embeddings.
Qwen2VLImageInputs = Union[
    Qwen2VLImagePixelInputs,
    Qwen2VLImageEmbeddingInputs,
]


class Qwen2VLVideoPixelInputs(TensorSchema):
    """
    Dimensions:
        - np: The total number of patches over each video over each prompt in
              the batch
        - ctps: Number of channels * temporal_patch_size * patch_size *
          patch_size
        - nv: Number of videos

    Historical context:
        - pixel_values_videos shape: (num_patches, num_channels *
          temporal_patch_size * patch_size * patch_size)
        - video_grid_thw shape: (num_videos, 3) in (grid_t, grid_h, grid_w)
          format
    """
    # Discriminator for the video input union.
    type: Literal["pixel_values_videos"]

    # Flattened spatio-temporal patches of every video, concatenated.
    pixel_values_videos: Annotated[
        torch.Tensor,
        TensorShape("np", "ctps"),
    ]

    # One (grid_t, grid_h, grid_w) row per video.
    video_grid_thw: Annotated[
        torch.Tensor,
        TensorShape("nv", 3),
    ]


class Qwen2VLVideoEmbeddingInputs(TensorSchema):
    """
    Dimensions:
        - nf: Number of video features
        - hs: Hidden size
        - nv: Number of videos

    Historical context:
        - video_embeds shape: (num_video_features, hidden_size)
        - num_video_features varies based on the number and resolution of the
          videos.
        - hidden_size must match the hidden size of language model backbone.
        - video_grid_thw shape: (num_videos, 3) in (grid_t, grid_h, grid_w)
          format
    """
    # Discriminator for the video input union.
    type: Literal["video_embeds"]

    # Precomputed video features, concatenated over all videos.
    video_embeds: Annotated[
        torch.Tensor,
        TensorShape("nf", "hs"),
    ]

    # One (grid_t, grid_h, grid_w) row per video.
    video_grid_thw: Annotated[
        torch.Tensor,
        TensorShape("nv", 3),
    ]


# Video input is either raw pixel patches or precomputed embeddings.
Qwen2VLVideoInputs = Union[
    Qwen2VLVideoPixelInputs,
    Qwen2VLVideoEmbeddingInputs,
]

# === Vision Encoder === #


class Qwen2VisionMLP(nn.Module):
    """Feed-forward sub-layer of a vision block: ``fc2(act(fc1(x)))``.

    ``fc1`` is a ColumnParallelLinear (in_features -> hidden_features).
    ``fc2`` is a RowParallelLinear (hidden_features -> in_features).
    When ``use_data_parallel`` is set, both are created with
    ``disable_tp=True``.
    """

    def __init__(
        self,
        in_features: int,
        hidden_features: int,
        act_layer: type[nn.Module] = QuickGELU,
        quant_config: Optional[QuantizationConfig] = None,
        prefix: str = "",
        use_data_parallel: bool = False,
    ):
        super().__init__()
        self.fc1 = ColumnParallelLinear(in_features,
                                        hidden_features,
                                        quant_config=quant_config,
                                        prefix=f"{prefix}.fc1",
                                        disable_tp=use_data_parallel)
        self.act = act_layer()
        self.fc2 = RowParallelLinear(hidden_features,
                                     in_features,
                                     quant_config=quant_config,
                                     prefix=f"{prefix}.fc2",
                                     disable_tp=use_data_parallel)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # The parallel linear layers return a pair. The second element
        # (presumably the bias) is discarded.
        x_parallel, _ = self.fc1(x)
        x_parallel = self.act(x_parallel)
        x, _ = self.fc2(x_parallel)
        return x


def rotate_half(x: torch.Tensor, interleaved: bool = False) -> torch.Tensor:
    """Rotate feature pairs of ``x`` by 90 degrees along the last axis.

    Non-interleaved layout pairs element ``i`` with ``i + d/2``; the result
    is ``[-second_half, first_half]``. Interleaved layout pairs adjacent
    elements ``(2i, 2i+1)``; each pair ``(a, b)`` becomes ``(-b, a)``.
    """
    if interleaved:
        evens = x[..., 0::2]
        odds = x[..., 1::2]
        # Stack (-odd, even) on a new trailing axis and merge it back so
        # every rotated pair stays adjacent.
        return torch.stack((-odds, evens), dim=-1).flatten(start_dim=-2)
    first_half, second_half = x.chunk(2, dim=-1)
    return torch.cat((-second_half, first_half), dim=-1)


def apply_rotary_emb_torch(x: torch.Tensor,
                           cos: torch.Tensor,
                           sin: torch.Tensor,
                           interleaved: bool = False) -> torch.Tensor:
    """
    Apply rotary embeddings to the leading ``2 * cos.shape[-1]`` features
    of ``x``; the remaining features are passed through unchanged.

    x: (batch_size, seqlen, nheads, headdim)
    cos, sin: (seqlen, rotary_dim / 2) or (batch_size, seqlen, rotary_dim / 2)
    """
    rot_dim = 2 * cos.shape[-1]
    assert rot_dim <= x.shape[-1]

    if interleaved:
        # Duplicate each angle in place: [c0, c0, c1, c1, ...].
        cos = cos.repeat_interleave(2, dim=-1)
        sin = sin.repeat_interleave(2, dim=-1)
    else:
        # Tile the angle table: [c0, c1, ..., c0, c1, ...].
        cos = torch.cat((cos, cos), dim=-1)
        sin = torch.cat((sin, sin), dim=-1)
    # Singleton heads axis so the tables broadcast over nheads.
    cos = cos.unsqueeze(-2)
    sin = sin.unsqueeze(-2)

    rotated_part = x[..., :rot_dim]
    passthrough = x[..., rot_dim:]
    if interleaved:
        evens, odds = rotated_part[..., 0::2], rotated_part[..., 1::2]
        quarter_turn = torch.stack((-odds, evens),
                                   dim=-1).flatten(start_dim=-2)
    else:
        lower, upper = rotated_part.chunk(2, dim=-1)
        quarter_turn = torch.cat((-upper, lower), dim=-1)

    return torch.cat(
        [rotated_part * cos + quarter_turn * sin, passthrough],
        dim=-1,
    )


def apply_rotary_pos_emb_vision(t: torch.Tensor,
                                freqs: torch.Tensor) -> torch.Tensor:
    """Rotate ``t`` by the rotary angles ``freqs``.

    The rotation is computed in float32 with the kernel returned by
    ``dispatch_rotary_emb_function()``. The result is cast back to ``t``'s
    dtype.
    """
    rotary_emb_function = dispatch_rotary_emb_function()
    t_ = t.float()
    cos = freqs.cos()
    sin = freqs.sin()
    output = rotary_emb_function(t_, cos, sin).type_as(t)
    return output


class Qwen2VisionAttention(nn.Module):
    """Self-attention layer of the Qwen2-VL vision encoder.

    A fused ``qkv`` projection produces queries, keys and values. Rotary
    position embeddings are applied to queries and keys. Attention is then
    computed separately for each sequence delimited by ``cu_seqlens``, using
    one of the supported ViT attention backends: FlashAttention, ROCm AITER
    FlashAttention, torch SDPA or xFormers.
    """

    def __init__(
        self,
        embed_dim: int,
        num_heads: int,
        projection_size: int,
        quant_config: Optional[QuantizationConfig] = None,
        prefix: str = "",
        use_data_parallel: bool = False,
    ) -> None:
        super().__init__()
        # Per attention head and per partition values. With
        # use_data_parallel the projections are built with disable_tp, so
        # this layer behaves as a single tensor-parallel rank.
        self.tp_size = (1 if use_data_parallel else
                        parallel_state.get_tensor_model_parallel_world_size())
        self.tp_rank = parallel_state.get_tensor_model_parallel_rank()
        self.hidden_size_per_attention_head = dist_utils.divide(
            projection_size, num_heads)
        self.num_attention_heads_per_partition = dist_utils.divide(
            num_heads, self.tp_size)

        self.qkv = ColumnParallelLinear(input_size=embed_dim,
                                        output_size=3 * projection_size,
                                        quant_config=quant_config,
                                        prefix=f"{prefix}.qkv",
                                        disable_tp=use_data_parallel)
        self.proj = RowParallelLinear(input_size=projection_size,
                                      output_size=embed_dim,
                                      quant_config=quant_config,
                                      prefix=f"{prefix}.proj",
                                      disable_tp=use_data_parallel)

        # Detect attention implementation.
        self.attn_backend = get_vit_attn_backend(
            head_size=self.hidden_size_per_attention_head,
            dtype=torch.get_default_dtype())
        self.use_upstream_fa = False

        # The helper may substitute the backend, and it also returns the
        # varlen kernel called on the flash-attention path.
        self.attn_backend, self.flash_attn_varlen_func \
            = maybe_get_vit_flash_attn_backend(
                self.attn_backend,
                self.use_upstream_fa,
            )

        if self.attn_backend not in {
                _Backend.FLASH_ATTN, _Backend.TORCH_SDPA, _Backend.XFORMERS,
                _Backend.ROCM_AITER_FA
        }:
            raise RuntimeError(
                f"Qwen2-VL does not support {self.attn_backend} backend now.")

        self.is_flash_attn_backend = self.attn_backend in {
            _Backend.FLASH_ATTN, _Backend.ROCM_AITER_FA
        }

    def split_qkv(self, qkv: torch.Tensor) -> tuple[torch.Tensor, ...]:
        """Split a fused ``[s, b, 3 * head * head_dim]`` projection into q,
        k and v of shape ``[s, b, heads_per_partition, head_dim]``.

        Under tensor parallelism, the local shards are first all-gathered.
        This lets q, k and v be separated before this rank's slice of heads
        is taken back out.
        """
        # [s, b, 3 * head * head_dim]
        seq_len, bs, _ = qkv.shape
        if self.tp_size > 1:
            qkv = tensor_model_parallel_all_gather(qkv)

        # [s, b, 3 * head * head_dim] -> 3 * [s, b, head * head_dim]
        q, k, v = qkv.chunk(3, dim=2)

        # 3 * [s, b, head * head_dim]
        if self.tp_size > 1:
            splitter = partial(dist_utils.split_tensor_along_last_dim,
                               num_partitions=self.tp_size)
            q = splitter(q)[self.tp_rank]
            k = splitter(k)[self.tp_rank]
            v = splitter(v)[self.tp_rank]

        # 3 * [s, b, head * head_dim] -> 3 * [s, b, head, head_dim]
        new_shape = (seq_len, bs, self.num_attention_heads_per_partition,
                     self.hidden_size_per_attention_head)
        q, k, v = (x.view(*new_shape) for x in (q, k, v))
        return q, k, v

    def forward(
            self,
            x: torch.Tensor,
            cu_seqlens: torch.Tensor,
            rotary_pos_emb: torch.Tensor,
            max_seqlen: Optional[int] = None,  # Only used for Flash Attention
            seqlens: Optional[list[int]] = None,  # Only used for xFormers
    ) -> torch.Tensor:
        """Attend within each sequence delimited by ``cu_seqlens``.

        ``x`` is laid out as ``[s, b, c]``, and so is the returned tensor
        after the output projection. ``max_seqlen`` is consumed only by the
        flash-attention path, and ``seqlens`` only by the xFormers path.
        """
        # [s, b, c] --> [s, b, 3 * head * head_dim]
        x, _ = self.qkv(x)

        # [s, b, 3 * head * head_dim] -> 3 * [s, b, head, head_dim]
        q, k, v = self.split_qkv(x)
        batch_size = q.shape[1]

        q, k, v = (rearrange(x, "s b ... -> b s ...").contiguous()
                   for x in (q, k, v))
        if rotary_pos_emb is not None:
            # Rotate q and k with a single call.
            # [2 * b, s, heads, head_dim]
            qk_concat = torch.cat([q, k], dim=0)
            qk_rotated = apply_rotary_pos_emb_vision(qk_concat, rotary_pos_emb)
            q, k = torch.chunk(qk_rotated, 2, dim=0)

        if self.is_flash_attn_backend:
            # Varlen kernels take packed (total_tokens, heads, head_dim).
            q, k, v = (rearrange(x, "b s ... -> (b s) ...") for x in [q, k, v])

            output = self.flash_attn_varlen_func(q,
                                                 k,
                                                 v,
                                                 cu_seqlens_q=cu_seqlens,
                                                 cu_seqlens_k=cu_seqlens,
                                                 max_seqlen_q=max_seqlen,
                                                 max_seqlen_k=max_seqlen,
                                                 dropout_p=0.0,
                                                 causal=False)

            context_layer = rearrange(output,
                                      "(b s) h d -> s b (h d)",
                                      b=batch_size).contiguous()
        elif self.attn_backend == _Backend.TORCH_SDPA:
            # Execute attention entry by entry for speed & less VRAM.
            outputs = []
            for i in range(1, len(cu_seqlens)):
                start_idx = cu_seqlens[i - 1]
                end_idx = cu_seqlens[i]
                q_i = q[:, start_idx:end_idx]
                k_i = k[:, start_idx:end_idx]
                v_i = v[:, start_idx:end_idx]
                q_i, k_i, v_i = (rearrange(x, "b s h d -> b h s d")
                                 for x in [q_i, k_i, v_i])
                output_i = F.scaled_dot_product_attention(q_i,
                                                          k_i,
                                                          v_i,
                                                          dropout_p=0.0)
                output_i = rearrange(output_i, "b h s d -> b s h d ")
                outputs.append(output_i)
            context_layer = torch.cat(outputs, dim=1)
            context_layer = rearrange(context_layer,
                                      "b s h d -> s b (h d)").contiguous()
        elif self.attn_backend == _Backend.XFORMERS:
            from xformers import ops as xops
            from xformers.ops.fmha.attn_bias import BlockDiagonalMask

            # Block-diagonal bias keeps attention within each sequence.
            attn_bias = BlockDiagonalMask.from_seqlens(q_seqlen=seqlens,
                                                       kv_seqlen=None,
                                                       device=q.device)

            context_layer = xops.memory_efficient_attention_forward(
                q, k, v, attn_bias=attn_bias, p=0, scale=None)
            context_layer = rearrange(context_layer,
                                      "b s h d -> s b (h d)").contiguous()

        output, _ = self.proj(context_layer)
        return output


class Qwen2VisionBlock(nn.Module):
    """Pre-norm vision transformer layer.

    Computes ``x + attn(norm1(x))`` followed by ``x + mlp(norm2(x))``. The
    default norm is ``LayerNorm(eps=1e-6)`` when ``norm_layer`` is None.
    """

    def __init__(
        self,
        dim: int,
        num_heads: int,
        mlp_ratio: float,
        act_layer: type[nn.Module] = QuickGELU,
        norm_layer: Optional[Callable[[int], nn.Module]] = None,
        quant_config: Optional[QuantizationConfig] = None,
        prefix: str = "",
        use_data_parallel: bool = False,
    ) -> None:
        super().__init__()
        if norm_layer is None:
            norm_layer = partial(nn.LayerNorm, eps=1e-6)
        self.norm1 = norm_layer(dim)
        self.norm2 = norm_layer(dim)
        mlp_hidden_dim = int(dim * mlp_ratio)

        self.attn = Qwen2VisionAttention(embed_dim=dim,
                                         num_heads=num_heads,
                                         projection_size=dim,
                                         quant_config=quant_config,
                                         prefix=f"{prefix}.attn",
                                         use_data_parallel=use_data_parallel)
        self.mlp = Qwen2VisionMLP(dim,
                                  mlp_hidden_dim,
                                  act_layer=act_layer,
                                  quant_config=quant_config,
                                  prefix=f"{prefix}.mlp",
                                  use_data_parallel=use_data_parallel)

    def forward(
            self,
            x: torch.Tensor,
            cu_seqlens: torch.Tensor,
            rotary_pos_emb: torch.Tensor,
            max_seqlen: Optional[int] = None,  # Only used for Flash Attention
            seqlens: Optional[list[int]] = None,  # Only used for xFormers
    ) -> torch.Tensor:
        """Apply the attention and MLP residual sub-layers to ``x``."""
        x = x + self.attn(
            self.norm1(x),
            cu_seqlens=cu_seqlens,
            rotary_pos_emb=rotary_pos_emb,
            max_seqlen=max_seqlen,
            seqlens=seqlens,
        )

        x = x + self.mlp(self.norm2(x))
        return x


class Qwen2VisionPatchEmbed(nn.Module):
    """Project flattened spatio-temporal patches to ``embed_dim``.

    Each input row holds one patch of
    ``in_channels * temporal_patch_size * patch_size * patch_size`` values.
    A bias-free Conv3d whose kernel equals its stride, and which covers the
    whole patch, maps every row to a single ``embed_dim`` vector.
    """

    def __init__(
        self,
        patch_size: int = 14,
        temporal_patch_size: int = 2,
        in_channels: int = 3,
        embed_dim: int = 1152,
    ) -> None:
        super().__init__()
        self.patch_size = patch_size
        self.temporal_patch_size = temporal_patch_size
        self.embed_dim = embed_dim

        kernel_size = (temporal_patch_size, patch_size, patch_size)
        self.proj = nn.Conv3d(in_channels,
                              embed_dim,
                              kernel_size=kernel_size,
                              stride=kernel_size,
                              bias=False)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Map ``(num_patches, patch_numel)`` to ``(num_patches, embed_dim)``.

        Raises ValueError if ``x`` is not 2-D.
        """
        # Unpacking also enforces a 2-D input.
        num_patches, _ = x.shape
        x = x.view(num_patches, -1, self.temporal_patch_size, self.patch_size,
                   self.patch_size)
        x = self.proj(x).view(num_patches, self.embed_dim)
        return x


class Qwen2VisionPatchMerger(nn.Module):
    """Merge groups of patch features and project them to ``d_model``.

    Each patch feature is layer-normalised. Every ``spatial_merge_size**2``
    consecutive rows are then concatenated; presumably the upstream patch
    ordering makes each such run one spatial merge window. The result goes
    through ``Linear -> GELU -> Linear``.
    """

    def __init__(
        self,
        d_model: int,
        context_dim: int,
        norm_layer: Optional[Callable[[int], nn.Module]] = None,
        spatial_merge_size: int = 2,
        quant_config: Optional[QuantizationConfig] = None,
        prefix: str = "",
        use_data_parallel: bool = False,
    ) -> None:
        super().__init__()
        # Width of a concatenated merge window.
        self.hidden_size = context_dim * (spatial_merge_size**2)
        if norm_layer is None:
            norm_layer = partial(nn.LayerNorm, eps=1e-6)
        self.ln_q = norm_layer(context_dim)
        # Indices 0 and 2 match the checkpoint names "mlp.0" / "mlp.2".
        self.mlp = nn.ModuleList([
            ColumnParallelLinear(self.hidden_size,
                                 self.hidden_size,
                                 bias=True,
                                 quant_config=quant_config,
                                 prefix=f"{prefix}.mlp.0",
                                 disable_tp=use_data_parallel),
            nn.GELU(),
            RowParallelLinear(self.hidden_size,
                              d_model,
                              bias=True,
                              quant_config=quant_config,
                              prefix=f"{prefix}.mlp.2",
                              disable_tp=use_data_parallel),
        ])

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Normalise, merge and project ``x`` to ``(num_merged, d_model)``."""
        x = self.ln_q(x)
        x = x.view(-1, self.hidden_size)

        mlp_fc1, mlp_act, mlp_fc2 = self.mlp
        x_parallel, _ = mlp_fc1(x)
        x_parallel = mlp_act(x_parallel)
        out, _ = mlp_fc2(x_parallel)
        return out


class Qwen2VisionRotaryEmbedding(nn.Module):
    """Lazily grown table of rotary angles ``position * inv_freq``.

    ``forward(seqlen)`` returns the first ``seqlen`` rows. The table has one
    column per entry of ``inv_freq``.
    """

    def __init__(self, dim: int, theta: float = 10000.0) -> None:
        super().__init__()
        self.dim = dim
        self.theta = theta
        exponents = torch.arange(0, dim, 2, dtype=torch.float) / dim
        self.register_buffer("inv_freq",
                             1.0 / (theta**exponents),
                             persistent=False)
        self._seq_len_cached = 0
        self._freqs_cached = None

    def update_freqs_cache(self, seqlen: int) -> None:
        """Rebuild the table with room for ``2 * seqlen`` positions.

        Does nothing if the current table is already long enough.
        """
        if seqlen <= self._seq_len_cached:
            return
        # Over-allocate so that moderately longer requests reuse the table.
        capacity = seqlen * 2
        self._seq_len_cached = capacity
        device = self.inv_freq.device
        # inv_freq is recomputed on the buffer's current device.
        exponents = torch.arange(
            0, self.dim, 2, dtype=torch.float, device=device) / self.dim
        self.inv_freq = 1.0 / (self.theta**exponents)
        positions = torch.arange(capacity,
                                 device=device,
                                 dtype=self.inv_freq.dtype)
        self._freqs_cached = torch.outer(positions, self.inv_freq)

    def forward(self, seqlen: int) -> torch.Tensor:
        self.update_freqs_cache(seqlen)
        table = self._freqs_cached
        return table[:seqlen]


class Qwen2VisionTransformer(nn.Module):
    """Qwen2-VL vision encoder.

    The pipeline is: patch embedding, then ``depth`` pre-norm attention
    blocks with 2-D rotary position embeddings, then a patch merger that
    projects to ``vision_config.hidden_size``.
    """

    def __init__(
        self,
        vision_config: Qwen2VLVisionConfig,
        norm_eps: float = 1e-6,
        quant_config: Optional[QuantizationConfig] = None,
        prefix: str = "",
        use_data_parallel: bool = False,
    ) -> None:
        super().__init__()

        patch_size = vision_config.patch_size
        temporal_patch_size = vision_config.temporal_patch_size
        spatial_merge_size = vision_config.spatial_merge_size
        in_channels = vision_config.in_channels
        hidden_size = vision_config.hidden_size
        embed_dim = vision_config.embed_dim
        depth = vision_config.depth
        num_heads = vision_config.num_heads
        mlp_ratio = vision_config.mlp_ratio

        self.use_data_parallel = use_data_parallel
        self.out_hidden_size = vision_config.hidden_size

        self.spatial_merge_size = spatial_merge_size
        self.num_heads = num_heads
        self.embed_dim = embed_dim

        self.patch_embed = Qwen2VisionPatchEmbed(
            patch_size=patch_size,
            temporal_patch_size=temporal_patch_size,
            in_channels=in_channels,
            embed_dim=embed_dim,
        )

        norm_layer = partial(nn.LayerNorm, eps=norm_eps)
        head_dim = embed_dim // num_heads
        # Half of each head's dims encodes the row, half the column.
        self.rotary_pos_emb = Qwen2VisionRotaryEmbedding(head_dim // 2)

        self.blocks = nn.ModuleList([
            Qwen2VisionBlock(dim=embed_dim,
                             num_heads=num_heads,
                             mlp_ratio=mlp_ratio,
                             norm_layer=norm_layer,
                             quant_config=quant_config,
                             prefix=f"{prefix}.blocks.{layer_idx}",
                             use_data_parallel=use_data_parallel)
            for layer_idx in range(depth)
        ])
        self.merger = Qwen2VisionPatchMerger(
            d_model=hidden_size,
            context_dim=embed_dim,
            norm_layer=norm_layer,
            quant_config=quant_config,
            prefix=f"{prefix}.merger",
            use_data_parallel=use_data_parallel,
        )
        # Decides which per-sequence metadata forward() precomputes.
        self.attn_backend = get_vit_attn_backend(
            head_size=head_dim, dtype=torch.get_default_dtype())
        if self.attn_backend != _Backend.FLASH_ATTN and \
            check_upstream_fa_availability(
                torch.get_default_dtype()):
            self.attn_backend = _Backend.FLASH_ATTN

    @property
    def dtype(self) -> torch.dtype:
        return self.patch_embed.proj.weight.dtype

    @property
    def device(self) -> torch.device:
        return self.patch_embed.proj.weight.device

    def rot_pos_emb(self, grid_thw: list[list[int]]) -> torch.Tensor:
        """Return per-patch rotary angles for every item in ``grid_thw``.

        Row and column indices are permuted so that each
        ``spatial_merge_size x spatial_merge_size`` window is contiguous.
        They are repeated ``t`` times per item, and the angles for (row, col)
        are concatenated along the last axis.
        """
        pos_ids = []
        max_grid_size = 0
        for t, h, w in grid_thw:
            hpos_ids = torch.arange(h).unsqueeze(1).expand(-1, w)
            wpos_ids = torch.arange(w).unsqueeze(0).expand(h, -1)
            hpos_ids = hpos_ids.reshape(
                h // self.spatial_merge_size,
                self.spatial_merge_size,
                w // self.spatial_merge_size,
                self.spatial_merge_size,
            ).permute(0, 2, 1, 3).flatten()
            wpos_ids = wpos_ids.reshape(
                h // self.spatial_merge_size,
                self.spatial_merge_size,
                w // self.spatial_merge_size,
                self.spatial_merge_size,
            ).permute(0, 2, 1, 3).flatten()
            pos_ids.append(
                torch.stack([hpos_ids, wpos_ids], dim=-1).repeat(t, 1))
            max_grid_size = max(max_grid_size, h, w)
        pos_ids = torch.cat(pos_ids, dim=0)
        rotary_pos_emb_full = self.rotary_pos_emb(max_grid_size)
        rotary_pos_emb = rotary_pos_emb_full[pos_ids].flatten(1)
        return rotary_pos_emb

    def compute_attn_mask_seqlen(
            self, cu_seqlens: torch.Tensor
    ) -> tuple[Optional[int], Optional[list[int]]]:
        """Return ``(max_seqlen, seqlens)`` derived from ``cu_seqlens``.

        ``max_seqlen`` is set only for the flash-attention backends, and
        ``seqlens`` only for xFormers. The other value is None.
        """
        max_seqlen, seqlens = None, None
        if (self.attn_backend == _Backend.FLASH_ATTN
                or self.attn_backend == _Backend.ROCM_AITER_FA):
            max_seqlen = (cu_seqlens[1:] - cu_seqlens[:-1]).max().item()
        elif self.attn_backend == _Backend.XFORMERS:
            seqlens = (cu_seqlens[1:] - cu_seqlens[:-1]).tolist()
        return max_seqlen, seqlens

    def forward(
        self,
        x: torch.Tensor,
        grid_thw: list[list[int]],
    ) -> torch.Tensor:
        """Encode flattened patches ``x`` described by ``grid_thw``.

        ``grid_thw`` holds one ``[t, h, w]`` triple per image/video. The
        returned tensor comes from the patch merger.
        """
        # patchify
        x = x.to(device=self.device, dtype=self.dtype)
        x = self.patch_embed(x)

        # compute position embedding
        rotary_pos_emb = self.rot_pos_emb(grid_thw)

        # compute cu_seqlens: one sequence per temporal frame of each item.
        # Built on the host, so deriving max_seqlen / seqlens below needs no
        # device synchronisation.
        grid_thw_ = torch.tensor(grid_thw)
        cu_seqlens = torch.repeat_interleave(grid_thw_[:, 1] * grid_thw_[:, 2],
                                             grid_thw_[:, 0]).cumsum(
                                                 dim=0, dtype=torch.int32)
        cu_seqlens = F.pad(cu_seqlens, (1, 0), "constant", 0)

        # transformers
        x = x.unsqueeze(1)

        # pre-compute seqlens for attn mask to reduce cuMemcpy operations
        max_seqlen, seqlens = self.compute_attn_mask_seqlen(cu_seqlens)
        # Varlen flash-attention kernels require cu_seqlens on the same
        # device as q/k/v. Previously the host tensor was passed straight
        # through.
        cu_seqlens = cu_seqlens.to(device=self.device, non_blocking=True)
        for blk in self.blocks:
            x = blk(
                x,
                cu_seqlens=cu_seqlens,
                rotary_pos_emb=rotary_pos_emb,
                max_seqlen=max_seqlen,
                seqlens=seqlens,
            )

        # adapter
        x = self.merger(x)

        return x

    def load_weights(self, weights: Iterable[tuple[str,
                                                   torch.Tensor]]) -> set[str]:
        """Load ``(name, tensor)`` pairs into this module's parameters.

        Returns the set of parameter names that were loaded.
        """
        # NOTE(review): the fused projection in Qwen2VisionAttention is named
        # "qkv", not "qkv_proj", so this mapping only matters for checkpoints
        # that store separate q/k/v projections -- confirm it is still needed.
        stacked_params_mapping = [
            # (param_name, shard_name, shard_id)
            ("qkv_proj", "q_proj", "q"),
            ("qkv_proj", "k_proj", "k"),
            ("qkv_proj", "v_proj", "v"),
        ]
        params_dict = dict(self.named_parameters(remove_duplicate=False))
        loaded_params: set[str] = set()

        for name, loaded_weight in weights:
            for (param_name, weight_name, shard_id) in stacked_params_mapping:
                if weight_name not in name:
                    continue
                name = name.replace(weight_name, param_name)

                param = params_dict[name]
                weight_loader = param.weight_loader
                weight_loader(param, loaded_weight, shard_id)
                break
            else:
                param = params_dict[name]
                weight_loader = getattr(param, "weight_loader",
                                        default_weight_loader)
                weight_loader(param, loaded_weight)
            loaded_params.add(name)
        return loaded_params

776

777
778
779
780
781
782
783
784
785
786
787
788
789
790
791
792
793
794
795
796
797
798
799
800
801
802
803
804
805
806
807
808
def _create_qwen2vl_field_factory(
    spatial_merge_size: int,
) -> Callable[[Mapping[str, torch.Tensor]],
              Mapping[str, MultiModalFieldConfig]]:
    """Return a field-config callback bound to ``spatial_merge_size``.

    The callback reads ``image_grid_thw`` / ``video_grid_thw`` from the HF
    processor outputs. Each row's product gives that item's patch count;
    dividing twice by ``spatial_merge_size`` gives its merged token count.
    Pixel fields are split per item by patch count, embedding fields by
    token count, and the ``*_grid_thw`` fields are batched per item.
    """

    def _item_sizes(hf_inputs: Mapping[str, torch.Tensor], grid_key: str):
        # A missing grid key behaves like an empty (0, 3) grid: no items.
        grid = hf_inputs.get(grid_key, torch.empty((0, 3)))
        patch_counts = grid.prod(-1)
        token_counts = (patch_counts // spatial_merge_size //
                        spatial_merge_size)
        return patch_counts, token_counts

    def _qwen2vl_field_config(hf_inputs: Mapping[str, torch.Tensor]):
        image_patches, image_tokens = _item_sizes(hf_inputs, "image_grid_thw")
        video_patches, video_tokens = _item_sizes(hf_inputs, "video_grid_thw")

        flat = MultiModalFieldConfig.flat_from_sizes
        batched = MultiModalFieldConfig.batched
        return {
            "pixel_values": flat("image", image_patches),
            "image_embeds": flat("image", image_tokens),
            "image_grid_thw": batched("image"),
            "pixel_values_videos": flat("video", video_patches),
            "video_embeds": flat("video", video_tokens),
            "video_grid_thw": batched("video"),
        }

    return _qwen2vl_field_config
809

810

# Roger Wang's avatar
# Roger Wang committed
811
class Qwen2VLMultiModalDataParser(MultiModalDataParser):
    """Data parser that also accepts precomputed Qwen2-VL embedding dicts.

    Image or video data given as a ``dict`` is wrapped in
    ``DictEmbeddingItems``. The dict must contain ``<modality>_embeds`` and
    ``<modality>_grid_thw``, and it is split per item with the Qwen2-VL field
    layout. Any other data is handled by ``MultiModalDataParser``.
    """

    def __init__(self, spatial_merge_size: int, *args, **kwargs):
        self._spatial_merge_size = spatial_merge_size
        super().__init__(*args, **kwargs)

    def _wrap_qwen2vl_embedding_dict(
        self,
        data: dict[str, torch.Tensor],
        modality: str,
    ) -> DictEmbeddingItems:
        """Wrap a dict of precomputed ``modality`` embeddings and grids."""
        return DictEmbeddingItems(
            data,
            modality=modality,
            required_fields={f"{modality}_embeds", f"{modality}_grid_thw"},
            fields_factory=_create_qwen2vl_field_factory(
                self._spatial_merge_size),
        )

    def _parse_image_data(
        self,
        data: Union[dict[str, torch.Tensor], ModalityData[ImageItem]],
    ) -> Optional[ModalityDataItems[Any, Any]]:
        if isinstance(data, dict):
            return self._wrap_qwen2vl_embedding_dict(data, "image")

        return super()._parse_image_data(data)

    def _parse_video_data(
        self,
        data: Union[dict[str, torch.Tensor], ModalityData[VideoItem]],
    ) -> Optional[ModalityDataItems[Any, Any]]:
        if isinstance(data, dict):
            return self._wrap_qwen2vl_embedding_dict(data, "video")

        return super()._parse_video_data(data)


848
class Qwen2VLProcessingInfo(BaseProcessingInfo):
    """Qwen2-VL configs, HF processors and vision token-count helpers."""

    def get_hf_config(self):
        return self.ctx.get_hf_config(Qwen2VLConfig)

    def get_hf_processor(self, **kwargs: object) -> Qwen2VLProcessor:
        # Request the fast processor unless the caller overrides `use_fast`.
        return self.ctx.get_hf_processor(
            Qwen2VLProcessor,
            use_fast=kwargs.pop("use_fast", True),
            **kwargs,
        )

    def get_image_processor(self, **kwargs: object) -> Qwen2VLImageProcessor:
        return self.get_hf_processor(**kwargs).image_processor

    def get_supported_mm_limits(self) -> Mapping[str, Optional[int]]:
        # NOTE(review): `None` presumably means "no fixed per-prompt limit";
        # confirm against BaseProcessingInfo.
        return {"image": None, "video": None}

    def get_mm_max_tokens_per_item(
        self,
        seq_len: int,
        mm_counts: Mapping[str, int],
    ) -> Mapping[str, int]:
        """Return the largest token count of a single image and video."""
        max_image_tokens = self.get_max_image_tokens()
        max_video_tokens = self.get_max_video_tokens(seq_len, mm_counts)
        return {"image": max_image_tokens, "video": max_video_tokens}

    def _get_vision_info(
        self,
        *,
        image_width: int,
        image_height: int,
        num_frames: int = 1,
        do_resize: bool = True,
        image_processor: Optional[Qwen2VLImageProcessor],
    ) -> tuple[ImageSize, int]:
        """Compute the preprocessed size and vision-token count of an input.

        Args:
            image_width: Frame width in pixels.
            image_height: Frame height in pixels.
            num_frames: Number of frames (1 for a still image).
            do_resize: If True, resize with ``smart_resize`` using the
                processor's ``min_pixels``/``max_pixels`` and a factor of
                ``patch_size * merge_size``; otherwise keep the given size.
            image_processor: Source of the pixel limits; when None the
                processor from ``get_image_processor()`` is used.

        Returns:
            ``(preprocessed_size, num_vision_tokens)``, where the token
            count is ``grid_t * grid_h * grid_w // merge_size**2``.
        """
        if image_processor is None:
            image_processor = self.get_image_processor()

        hf_config = self.get_hf_config()
        vision_config = hf_config.vision_config
        patch_size = vision_config.patch_size
        merge_size = vision_config.spatial_merge_size
        temporal_patch_size = vision_config.temporal_patch_size

        if do_resize:
            resized_height, resized_width = smart_resize(
                height=image_height,
                width=image_width,
                factor=patch_size * merge_size,
                min_pixels=image_processor.min_pixels,
                max_pixels=image_processor.max_pixels,
            )
            preprocessed_size = ImageSize(width=resized_width,
                                          height=resized_height)
        else:
            preprocessed_size = ImageSize(width=image_width,
                                          height=image_height)

        # NOTE: Frames are padded to be divisible by `temporal_patch_size`
        # https://github.com/huggingface/transformers/blob/v4.48.3/src/transformers/models/qwen2_vl/image_processing_qwen2_vl.py#L294
        # `(-n) % p` is exactly the number of frames needed to reach the next
        # multiple of p (`n % p` is only correct when p == 2).
        padded_num_frames = num_frames + (-num_frames) % temporal_patch_size

        grid_t = max(padded_num_frames // temporal_patch_size, 1)
        grid_h = preprocessed_size.height // patch_size
        grid_w = preprocessed_size.width // patch_size

        num_patches = grid_t * grid_h * grid_w
        num_vision_tokens = num_patches // (merge_size**2)

        return preprocessed_size, num_vision_tokens

    def get_num_image_tokens(
        self,
        *,
        image_width: int,
        image_height: int,
        image_processor: Optional[Qwen2VLImageProcessor],
    ) -> int:
        """Return the number of vision tokens for one image."""
        _, num_image_tokens = self._get_vision_info(
            image_width=image_width,
            image_height=image_height,
            num_frames=1,
            image_processor=image_processor,
        )
        return num_image_tokens

    def get_num_video_tokens(
        self,
        *,
        image_width: int,
        image_height: int,
        num_frames: int,
        image_processor: Optional[Qwen2VLImageProcessor],
    ) -> int:
        """Return the number of vision tokens for one video."""
        _, num_video_tokens = self._get_vision_info(
            image_width=image_width,
            image_height=image_height,
            num_frames=num_frames,
            image_processor=image_processor,
        )
        return num_video_tokens

    def get_image_size_with_most_features(self) -> ImageSize:
        """Return the preprocessed size of an oversized (9999999^2) image.

        NOTE(review): this presumably yields the largest size `smart_resize`
        allows under the default processor's `max_pixels`.
        """
        max_image_size, _ = self._get_vision_info(
            image_width=9999999,
            image_height=9999999,
            num_frames=1,
            image_processor=None,
        )
        return max_image_size

    def get_max_image_tokens(self) -> int:
        """Return the token count of the image size with most features."""
        target_width, target_height = self.get_image_size_with_most_features()

        return self.get_num_image_tokens(
            image_width=target_width,
            image_height=target_height,
            image_processor=None,
        )

    def _get_max_video_frames(self,
                              max_tokens: int,
                              start_num_frames: int = 1) -> int:
        """Grow the frame count while it still fits in ``max_tokens``.

        Starting at ``start_num_frames`` (which itself is not checked), add
        one frame at a time as long as the video token count at the
        max-feature image size stays ``<= max_tokens``; return the last
        accepted count.
        """
        target_width, target_height = self.get_image_size_with_most_features()

        num_frames = start_num_frames

        while True:
            next_num_frames = num_frames + 1
            next_max_tokens = self.get_num_video_tokens(
                image_width=target_width,
                image_height=target_height,
                num_frames=next_num_frames,
                image_processor=None,
            )

            if next_max_tokens > max_tokens:
                break

            num_frames = next_num_frames

        return num_frames

    def get_num_frames_with_most_features(
        self,
        seq_len: int,
        mm_counts: Mapping[str, int],
        max_frames_per_video: int = _MAX_FRAMES_PER_VIDEO,
    ) -> int:
        """Return frames per video when ``seq_len`` is shared by all videos.

        The result is capped at ``max_frames_per_video`` and is at least 1.
        """
        max_videos = mm_counts.get("video", 0)

        max_total_frames = self._get_max_video_frames(seq_len)
        max_frames_per_video = min(max_total_frames // max(max_videos, 1),
                                   max_frames_per_video)

        return max(max_frames_per_video, 1)

    def get_max_video_tokens(
        self,
        seq_len: int,
        mm_counts: Mapping[str, int],
    ) -> int:
        """Return the token count of the largest dummy video."""
        target_width, target_height = self.get_image_size_with_most_features()

        return self.get_num_video_tokens(
            image_width=target_width,
            image_height=target_height,
            num_frames=self.get_num_frames_with_most_features(
                seq_len, mm_counts),
            image_processor=None,
        )

1021
1022
1023

class Qwen2VLDummyInputsBuilder(BaseDummyInputsBuilder[Qwen2VLProcessingInfo]):
    """Builds dummy prompts and media for Qwen2-VL profiling.

    Media are sized by the processing info's "most features" helpers.
    """

    def get_dummy_text(self, mm_counts: Mapping[str, int]) -> str:
        """Return one image/video placeholder token per requested item."""
        num_images = mm_counts.get("image", 0)
        num_videos = mm_counts.get("video", 0)

        hf_processor = self.info.get_hf_processor()
        image_token: str = hf_processor.image_token
        video_token: str = hf_processor.video_token

        return image_token * num_images + video_token * num_videos

    def get_dummy_mm_data(
        self,
        seq_len: int,
        mm_counts: Mapping[str, int],
        mm_options: Optional[Mapping[str, BaseDummyOptions]] = None,
    ) -> MultiModalDataDict:
        """Build dummy images and videos at the max-feature size.

        Videos use the frame count from
        ``get_num_frames_with_most_features(seq_len, mm_counts)``. Per-modality
        overrides are taken from ``mm_options`` when it is given.
        """
        num_images = mm_counts.get("image", 0)
        num_videos = mm_counts.get("video", 0)

        target_width, target_height = \
            self.info.get_image_size_with_most_features()
        target_num_frames = \
            self.info.get_num_frames_with_most_features(seq_len, mm_counts)

        image_overrides = mm_options.get("image") if mm_options else None
        video_overrides = mm_options.get("video") if mm_options else None

        return {
            "image":
            self._get_dummy_images(width=target_width,
                                   height=target_height,
                                   num_images=num_images,
                                   overrides=image_overrides),
            "video":
            self._get_dummy_videos(
                width=target_width,
                height=target_height,
                num_frames=target_num_frames,
                num_videos=num_videos,
                overrides=video_overrides,
            )
        }

1067

1068
1069
class Qwen2VLMultiModalProcessor(
        BaseMultiModalProcessor[Qwen2VLProcessingInfo]):
    """Multimodal processor for Qwen2-VL images and videos."""

    def _get_data_parser(self) -> MultiModalDataParser:
        return Qwen2VLMultiModalDataParser(
            self.info.get_hf_config().vision_config.spatial_merge_size)

    def _get_prompt_updates(
        self,
        mm_items: MultiModalDataItems,
        hf_processor_mm_kwargs: Mapping[str, Any],
        out_mm_kwargs: MultiModalKwargsItems,
    ) -> Sequence[PromptUpdate]:
        """Expand each image/video pad token to the item's token count.

        An item with grid ``(t, h, w)`` gets ``t * h * w // merge_size**2``
        copies of its modality's placeholder token id.
        """
        hf_processor = self.info.get_hf_processor(**hf_processor_mm_kwargs)
        image_processor = self.info.get_image_processor(
            **hf_processor_mm_kwargs)
        tokenizer = self.info.get_tokenizer()
        vocab = tokenizer.get_vocab()

        placeholder = {
            "image": vocab[hf_processor.image_token],
            "video": vocab[hf_processor.video_token],
        }

        merge_length = image_processor.merge_size**2

        def get_replacement_qwen2vl(item_idx: int, modality: str):
            out_item = out_mm_kwargs[modality][item_idx]
            grid_thw = out_item[f"{modality}_grid_thw"].data
            assert isinstance(grid_thw, torch.Tensor)

            num_tokens = int(grid_thw.prod()) // merge_length
            return [placeholder[modality]] * num_tokens

        return [
            PromptReplacement(
                modality=modality,
                target=[placeholder[modality]],
                replacement=partial(get_replacement_qwen2vl,
                                    modality=modality),
            ) for modality in ("image", "video")
        ]

    def _get_mm_fields_config(
        self,
        hf_inputs: BatchFeature,
        hf_processor_mm_kwargs: Mapping[str, object],
    ) -> Mapping[str, MultiModalFieldConfig]:
        return _create_qwen2vl_field_factory(
            self.info.get_hf_config().vision_config.spatial_merge_size)(
                hf_inputs)
1119

1120

1121
1122
1123
@MULTIMODAL_REGISTRY.register_processor(Qwen2VLMultiModalProcessor,
                                        info=Qwen2VLProcessingInfo,
                                        dummy_inputs=Qwen2VLDummyInputsBuilder)
class Qwen2VLForConditionalGeneration(nn.Module, SupportsMultiModal,
                                      SupportsLoRA, SupportsPP, SupportsMRoPE):
    """Qwen2-VL: a Qwen2 language model plus an optional vision encoder."""

    # To ensure correct weight loading and mapping.
    hf_to_vllm_mapper = WeightsMapper(
        orig_to_new_prefix={
            # mapping for new names in checkpoint saved after transformers v4.52
            "model.language_model.": "language_model.model.",
            "model.visual.": "visual.",
            # mapping for original checkpoint
            "lm_head.": "language_model.lm_head.",
            "model.": "language_model.model.",
        })

    supports_encoder_tp_data = True

    def get_mrope_input_positions(
        self,
        input_tokens: list[int],
        hf_config: PretrainedConfig,
        image_grid_thw: Optional[Union[list[list[int]], torch.Tensor]],
        video_grid_thw: Optional[Union[list[list[int]], torch.Tensor]],
        second_per_grid_ts: Optional[list[float]] = None,
        context_len: int = 0,
        seq_len: Optional[int] = None,
        audio_feature_lengths: Optional[torch.Tensor] = None,
        use_audio_in_video: bool = False,
    ) -> tuple[torch.Tensor, int]:
        """Get M-RoPE input positions for Qwen2-VL model.

        Items are counted by the token after each vision-start token and are
        then laid out in prompt order. Text spans get the same increasing id
        on all three rows. Each vision span gets temporal/height/width ids
        from its merged grid, offset past the preceding text.

        Returns:
            ``(positions[:, context_len:seq_len], delta)``. ``positions`` has
            3 rows, and ``delta`` is ``max_position + 1 - len(input_tokens)``.
            ``audio_feature_lengths`` and ``use_audio_in_video`` are unused
            here.
        """
        if image_grid_thw is None:
            image_grid_thw = []
        if video_grid_thw is None:
            video_grid_thw = []
        if second_per_grid_ts is None:
            second_per_grid_ts = []

        image_token_id = hf_config.image_token_id
        video_token_id = hf_config.video_token_id
        vision_start_token_id = hf_config.vision_start_token_id
        spatial_merge_size = hf_config.vision_config.spatial_merge_size
        tokens_per_second = getattr(hf_config.vision_config,
                                    "tokens_per_second", 1.0)

        input_tokens_tensor = torch.tensor(input_tokens)
        vision_start_indices = torch.argwhere(
            input_tokens_tensor == vision_start_token_id).squeeze(1)
        vision_tokens = input_tokens_tensor[vision_start_indices + 1]
        image_nums = (vision_tokens == image_token_id).sum()
        video_nums = (vision_tokens == video_token_id).sum()
        llm_pos_ids_list: list = []

        st = 0
        remain_images, remain_videos = image_nums, video_nums

        image_index, video_index = 0, 0
        for _ in range(image_nums + video_nums):
            video_second_per_grid_t = 0.0
            if remain_images > 0:
                try:
                    ed_image = input_tokens.index(image_token_id, st)
                except ValueError:
                    ed_image = len(input_tokens) + 1
            else:
                ed_image = len(input_tokens) + 1
            if remain_videos > 0:
                try:
                    ed_video = input_tokens.index(video_token_id, st)
                except ValueError:
                    ed_video = len(input_tokens) + 1
            else:
                ed_video = len(input_tokens) + 1
            if ed_image < ed_video:
                t, h, w = (
                    image_grid_thw[image_index][0],
                    image_grid_thw[image_index][1],
                    image_grid_thw[image_index][2],
                )
                image_index += 1
                remain_images -= 1
                ed = ed_image
            else:
                t, h, w = (
                    video_grid_thw[video_index][0],
                    video_grid_thw[video_index][1],
                    video_grid_thw[video_index][2],
                )
                video_second_per_grid_t = 1.0
                if second_per_grid_ts:
                    video_second_per_grid_t = second_per_grid_ts[video_index]
                video_index += 1
                remain_videos -= 1
                ed = ed_video

            llm_grid_t, llm_grid_h, llm_grid_w = \
                t, h // spatial_merge_size, w // spatial_merge_size
            text_len = ed - st

            st_idx = llm_pos_ids_list[-1].max() + 1 if len(
                llm_pos_ids_list) > 0 else 0
            llm_pos_ids_list.append(
                torch.arange(text_len).view(1, -1).expand(3, -1) + st_idx)

            t_index = (torch.arange(llm_grid_t).view(-1, 1).expand(
                -1, llm_grid_h * llm_grid_w) * video_second_per_grid_t *
                       tokens_per_second).long().flatten()

            h_index = torch.arange(llm_grid_h).view(1, -1, 1).expand(
                llm_grid_t, -1, llm_grid_w).flatten()
            w_index = torch.arange(llm_grid_w).view(1, 1, -1).expand(
                llm_grid_t, llm_grid_h, -1).flatten()
            llm_pos_ids_list.append(
                torch.stack([t_index, h_index, w_index]) + text_len + st_idx)
            st = ed + llm_grid_t * llm_grid_h * llm_grid_w

        if st < len(input_tokens):
            st_idx = llm_pos_ids_list[-1].max() + 1 if len(
                llm_pos_ids_list) > 0 else 0
            text_len = len(input_tokens) - st
            llm_pos_ids_list.append(
                torch.arange(text_len).view(1, -1).expand(3, -1) + st_idx)

        llm_positions = torch.cat(llm_pos_ids_list, dim=1).reshape(3, -1)
        mrope_position_delta = (llm_positions.max() + 1 -
                                len(input_tokens)).item()
        llm_positions = llm_positions[:, context_len:seq_len]

        return llm_positions, mrope_position_delta

    @classmethod
    def get_placeholder_str(cls, modality: str, i: int) -> Optional[str]:
        """Return the prompt placeholder for an image or video item."""
        if modality.startswith("image"):
            return "<|vision_start|><|image_pad|><|vision_end|>"
        if modality.startswith("video"):
            return "<|vision_start|><|video_pad|><|vision_end|>"

        raise ValueError("Only image or video modality is supported")

    def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
        super().__init__()
        config: Qwen2VLConfig = vllm_config.model_config.hf_config
        quant_config = vllm_config.quant_config
        multimodal_config = vllm_config.model_config.multimodal_config

        # In "data" mode the vision encoder is invoked through
        # run_dp_sharded_mrope_vision_model (see _process_*_input).
        self.use_data_parallel = multimodal_config.mm_encoder_tp_mode == "data"

        self.config = config
        self.multimodal_config = multimodal_config

        # Build the vision tower only when images or videos may appear in a
        # prompt; otherwise `visual` is None and load_weights() skips its
        # checkpoint weights.
        if multimodal_config.get_limit_per_prompt("image") or \
            multimodal_config.get_limit_per_prompt("video"):
            self.visual = Qwen2VisionTransformer(
                config.vision_config,
                norm_eps=getattr(config, "rms_norm_eps", 1e-6),
                quant_config=quant_config,
                prefix=maybe_prefix(prefix, "visual"),
                use_data_parallel=self.use_data_parallel,
            )
        else:
            self.visual = None

        self.language_model = init_vllm_registered_model(
            vllm_config=vllm_config,
            prefix=maybe_prefix(prefix, "language_model"),
            architectures=["Qwen2ForCausalLM"],
        )

        self.make_empty_intermediate_tensors = (
            self.language_model.make_empty_intermediate_tensors)

    def _validate_and_reshape_mm_tensor(self, mm_input: object,
                                        name: str) -> torch.Tensor:
        """Normalize a multimodal kwarg into a single 2D tensor.

        A 2D tensor is returned unchanged. A 3D (batched) tensor is reshaped
        to ``(-1, last_dim)``. A list is concatenated with ``torch.concat``.

        Raises:
            ValueError: If ``mm_input`` is neither a tensor nor a list, or is
                a tensor whose ndim is not 2 or 3.
        """
        if not isinstance(mm_input, (torch.Tensor, list)):
            raise ValueError(f"Incorrect type of {name}. "
                             f"Got type: {type(mm_input)}")
        if isinstance(mm_input, torch.Tensor):
            if mm_input.ndim == 2:
                return mm_input
            if mm_input.ndim != 3:
                raise ValueError(f"{name} should be 2D or batched 3D tensor. "
                                 f"Got ndim: {mm_input.ndim} "
                                 f"(shape={mm_input.shape})")
            return mm_input.reshape(-1, mm_input.shape[-1])
        else:
            return torch.concat(mm_input)

    def _parse_and_validate_image_input(
            self, **kwargs: object) -> Optional[Qwen2VLImageInputs]:
        """Build image inputs from kwargs; pixel values win over embeds."""
        pixel_values = kwargs.pop("pixel_values", None)
        image_embeds = kwargs.pop("image_embeds", None)
        image_grid_thw = kwargs.pop("image_grid_thw", None)

        if pixel_values is None and image_embeds is None:
            return None

        if pixel_values is not None:
            pixel_values = self._validate_and_reshape_mm_tensor(
                pixel_values, "image pixel values")
            image_grid_thw = self._validate_and_reshape_mm_tensor(
                image_grid_thw, "image grid_thw")

            return Qwen2VLImagePixelInputs(type="pixel_values",
                                           pixel_values=pixel_values,
                                           image_grid_thw=image_grid_thw)

        if image_embeds is not None:
            image_embeds = self._validate_and_reshape_mm_tensor(
                image_embeds, "image embeds")
            image_grid_thw = self._validate_and_reshape_mm_tensor(
                image_grid_thw, "image grid_thw")

            return Qwen2VLImageEmbeddingInputs(type="image_embeds",
                                               image_embeds=image_embeds,
                                               image_grid_thw=image_grid_thw)

    def _parse_and_validate_video_input(
            self, **kwargs: object) -> Optional[Qwen2VLVideoInputs]:
        """Build video inputs from kwargs; pixel values win over embeds."""
        pixel_values_videos = kwargs.pop("pixel_values_videos", None)
        video_embeds = kwargs.pop("video_embeds", None)
        video_grid_thw = kwargs.pop("video_grid_thw", None)

        if pixel_values_videos is None and video_embeds is None:
            return None

        if pixel_values_videos is not None:
            pixel_values_videos = self._validate_and_reshape_mm_tensor(
                pixel_values_videos, "video pixel values")
            video_grid_thw = self._validate_and_reshape_mm_tensor(
                video_grid_thw, "video grid_thw")

            return Qwen2VLVideoPixelInputs(
                type="pixel_values_videos",
                pixel_values_videos=pixel_values_videos,
                video_grid_thw=video_grid_thw,
            )

        if video_embeds is not None:
            video_embeds = self._validate_and_reshape_mm_tensor(
                video_embeds, "video embeds")
            video_grid_thw = self._validate_and_reshape_mm_tensor(
                video_grid_thw, "video grid_thw")

            return Qwen2VLVideoEmbeddingInputs(type="video_embeds",
                                               video_embeds=video_embeds,
                                               video_grid_thw=video_grid_thw)

    def _process_image_input(
            self, image_input: Qwen2VLImageInputs) -> tuple[torch.Tensor, ...]:
        """Encode images (or take given embeddings) and split per image.

        Each image's chunk has ``prod(grid_thw) // spatial_merge_size**2``
        rows. In data-parallel encoder mode the output of
        ``run_dp_sharded_mrope_vision_model`` is returned directly.
        """
        grid_thw = image_input["image_grid_thw"]
        assert grid_thw.ndim == 2
        grid_thw_list = grid_thw.tolist()

        if image_input["type"] == "image_embeds":
            image_embeds = image_input["image_embeds"]
        else:
            pixel_values = image_input["pixel_values"]

            if self.use_data_parallel:
                return run_dp_sharded_mrope_vision_model(self.visual,
                                                         pixel_values,
                                                         grid_thw_list,
                                                         rope_type="rope_3d")
            else:
                image_embeds = self.visual(pixel_values,
                                           grid_thw=grid_thw_list)

        # Split concatenated embeddings for each image item.
        merge_size = self.visual.spatial_merge_size
        sizes = (torch.tensor(grid_thw_list, dtype=torch.long).prod(-1) //
                 (merge_size * merge_size)).tolist()

        return image_embeds.split(sizes)

    def _process_video_input(
            self, video_input: Qwen2VLVideoInputs) -> tuple[torch.Tensor, ...]:
        """Encode videos (or take given embeddings) and split per video.

        Same splitting rule and data-parallel short-circuit as
        ``_process_image_input``.
        """
        grid_thw = video_input["video_grid_thw"]
        assert grid_thw.ndim == 2
        grid_thw_list = grid_thw.tolist()

        if video_input["type"] == "video_embeds":
            video_embeds = video_input["video_embeds"]
        else:
            pixel_values_videos = video_input["pixel_values_videos"]
            if self.use_data_parallel:
                return run_dp_sharded_mrope_vision_model(self.visual,
                                                         pixel_values_videos,
                                                         grid_thw_list,
                                                         rope_type="rope_3d")
            else:
                video_embeds = self.visual(pixel_values_videos,
                                           grid_thw=grid_thw_list)

        # Split concatenated embeddings for each video item.
        merge_size = self.visual.spatial_merge_size
        sizes = (torch.tensor(grid_thw_list, dtype=torch.long).prod(-1) //
                 (merge_size * merge_size)).tolist()

        return video_embeds.split(sizes)

    def _parse_and_validate_multimodal_inputs(self, **kwargs: object) -> dict:
        """Parse image/video kwargs into ``{"images": ..., "videos": ...}``.

        Keys appear in the order their first kwarg appears. A modality whose
        kwargs are all ``None`` is omitted instead of being stored as
        ``None``, which would crash ``_process_*_input`` later.
        """
        modalities = {}

        # Preserve the order of modalities if there are multiple of them
        # from the order of kwargs.
        for input_key in kwargs:
            if input_key in ("pixel_values",
                             "image_embeds") and "images" not in modalities:
                image_input = self._parse_and_validate_image_input(**kwargs)
                if image_input is not None:
                    modalities["images"] = image_input
            if input_key in ("pixel_values_videos",
                             "video_embeds") and "videos" not in modalities:
                video_input = self._parse_and_validate_video_input(**kwargs)
                if video_input is not None:
                    modalities["videos"] = video_input

        return modalities

    def get_language_model(self) -> torch.nn.Module:
        return self.language_model

    def get_multimodal_embeddings(self,
                                  **kwargs: object) -> MultiModalEmbeddings:
        """Return per-item embeddings for all images/videos in ``kwargs``.

        Returns an empty list when no multimodal input is present.
        """
        modalities = self._parse_and_validate_multimodal_inputs(**kwargs)
        if not modalities:
            return []

        # The result multimodal_embeddings is tuple of tensors, with each
        # tensor corresponding to a multimodal data item (image or video).
        multimodal_embeddings: tuple[torch.Tensor, ...] = ()

        # NOTE: It is important to iterate over the keys in this dictionary
        # to preserve the order of the modalities.
        for modality in modalities:
            if modality == "images":
                image_input = modalities["images"]
                vision_embeddings = self._process_image_input(image_input)
                multimodal_embeddings += vision_embeddings
            if modality == "videos":
                video_input = modalities["videos"]
                video_embeddings = self._process_video_input(video_input)
                multimodal_embeddings += video_embeddings

        return multimodal_embeddings

    def forward(
        self,
        input_ids: torch.Tensor,
        positions: torch.Tensor,
        intermediate_tensors: Optional[IntermediateTensors] = None,
        inputs_embeds: Optional[torch.Tensor] = None,
        **kwargs: object,
    ) -> Union[torch.Tensor, IntermediateTensors]:
        """Run forward pass for Qwen2-VL.

        Args:
            input_ids: Flattened (concatenated) input_ids corresponding to a
                batch.
            positions: Flattened (concatenated) position ids corresponding to a
                batch.
                **NOTE**: If mrope is enabled (default setting for Qwen2-VL
                opensource models), the shape will be `(3, seq_len)`,
                otherwise it will be `(seq_len,)`.
            intermediate_tensors: Intermediate tensors from prior forward pass.
            inputs_embeds: Optional tensor of input embeddings. It is ignored
                when ``intermediate_tensors`` is given.
        """
        if intermediate_tensors is not None:
            inputs_embeds = None

        hidden_states = self.language_model.model(
            input_ids=input_ids,
            positions=positions,
            intermediate_tensors=intermediate_tensors,
            inputs_embeds=inputs_embeds,
        )
        return hidden_states

    def compute_logits(
        self,
        hidden_states: torch.Tensor,
    ) -> Optional[torch.Tensor]:
        return self.language_model.compute_logits(hidden_states)

    def load_weights(self, weights: Iterable[tuple[str,
                                                   torch.Tensor]]) -> set[str]:
        """Load checkpoint weights through ``hf_to_vllm_mapper``.

        ``visual.*`` weights are skipped when the vision tower was not built.
        """
        skip_prefixes = []
        if self.visual is None:
            skip_prefixes.extend(["visual."])
        loader = AutoWeightsLoader(self, skip_prefixes=skip_prefixes)
        return loader.load_weights(weights, mapper=self.hf_to_vllm_mapper)

    def get_mm_mapping(self) -> MultiModelKeys:
        """
        Get the module prefix in multimodal models
        """
        return MultiModelKeys.from_string_field(
            language_model="language_model",
            connector="visual.merger.",
            tower_model="visual.",
        )
1525
1526
1527
1528
1529
1530
1531
1532
1533
1534
1535
1536
1537
1538
1539
1540
1541
1542
1543
1544
1545
1546
1547
1548
1549
1550
1551
1552
1553
1554
1555
1556
1557


class Tarsier2MultiModalProcessor(Qwen2VLMultiModalProcessor):
    """Multimodal processor for Tarsier2.

    Adds no behavior on top of ``Qwen2VLMultiModalProcessor``. It exists as a
    separate type so it can be registered for
    ``Tarsier2ForConditionalGeneration``.
    """


class Tarsier2ImageProcessor(Qwen2VLImageProcessor):
    """Qwen2-VL image processor that also understands Tarsier2's ``size``.

    If ``size`` carries both ``min_pixels`` and ``max_pixels`` keys, it is
    rewritten as ``{"shortest_edge": min_pixels, "longest_edge": max_pixels}``
    before reaching the parent class. Any other ``size`` value, including
    ``None``, is forwarded untouched. The caller's dict is never mutated.
    """

    def __init__(
        self,
        size: Optional[dict[str, int]] = None,
        **kwargs,
    ) -> None:
        uses_pixel_bounds = (size is not None and "min_pixels" in size
                             and "max_pixels" in size)
        if uses_pixel_bounds:
            # Build a fresh dict so the caller's mapping is left as-is.
            size = {
                "shortest_edge": size["min_pixels"],
                "longest_edge": size["max_pixels"],
            }
        super().__init__(size=size, **kwargs)


class Tarsier2Processor(Qwen2VLProcessor):
    """Qwen2-VL processor built from a Tarsier2 image-processor config.

    Args:
        vision_config: Keyword arguments unpacked into both
            ``Tarsier2ImageProcessor`` and ``Qwen2VLVideoProcessor``.
        tokenizer: Tokenizer handed to the parent processor.
        **kwargs: Forwarded to ``Qwen2VLProcessor.__init__``.
    """

    def __init__(
        self,
        vision_config: dict,
        tokenizer: AnyTokenizer,
        **kwargs,
    ):
        # Tarsier2ImageProcessor translates Tarsier2-style ``size`` keys
        # (min_pixels/max_pixels) into the form the Qwen2-VL parent expects.
        self.image_processor = Tarsier2ImageProcessor(**vision_config)
        super().__init__(
            image_processor=self.image_processor,
            tokenizer=tokenizer,
            video_processor=Qwen2VLVideoProcessor(**vision_config),
            # No chat template is supplied to the parent processor.
            chat_template=None,
            **kwargs)


class Tarsier2ProcessingInfo(Qwen2VLProcessingInfo):
    """Processing info for Tarsier2 checkpoints.

    Re-parses the checkpoint config as a ``Qwen2VLConfig``. Processors are
    built from the model's HF image-processor config using the
    Tarsier2-aware subclasses.
    """

    def get_hf_config(self) -> Qwen2VLConfig:
        """Load the checkpoint config and rebuild it as a ``Qwen2VLConfig``."""
        loaded = AutoConfig.from_pretrained(self.ctx.model_config.model)
        return Qwen2VLConfig.from_dict(loaded.to_dict())

    def get_hf_processor(self, **kwargs: object) -> Tarsier2Processor:
        """Construct a ``Tarsier2Processor`` for this model."""
        image_cfg = self.ctx.get_hf_image_processor_config()
        return Tarsier2Processor(
            vision_config=image_cfg,
            tokenizer=self.get_tokenizer(),
            **kwargs,
        )

    def get_image_processor(self) -> Tarsier2ImageProcessor:
        """Construct a ``Tarsier2ImageProcessor`` from the HF config."""
        image_cfg = self.ctx.get_hf_image_processor_config()
        return Tarsier2ImageProcessor(**image_cfg)


@MULTIMODAL_REGISTRY.register_processor(Tarsier2MultiModalProcessor,
                                        info=Tarsier2ProcessingInfo,
                                        dummy_inputs=Qwen2VLDummyInputsBuilder)
class Tarsier2ForConditionalGeneration(Qwen2VLForConditionalGeneration):
    hf_to_vllm_mapper = WeightsMapper(orig_to_new_prefix={
        "vision_tower.": "visual.",
    })

    def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
        """Replace the Llava wrapper config with its Qwen2-VL text config.

        Tarsier2 uses ``llava`` as its model type, so the loaded ``hf_config``
        carries a ``Qwen2VLConfig`` in ``text_config``. That inner config
        inherits the outer ``architectures`` and is written back into
        ``vllm_config.model_config.hf_config`` (mutating ``vllm_config``)
        before the Qwen2-VL initializer runs.
        """
        model_config = vllm_config.model_config
        llava_config = model_config.hf_config
        inner_config = llava_config.text_config
        inner_config.architectures = llava_config.architectures
        model_config.hf_config = inner_config
        super().__init__(vllm_config=vllm_config, prefix=prefix)

    def load_weights(self, weights: Iterable[tuple[str,
                                                   torch.Tensor]]) -> set[str]:
        """Load Tarsier2 checkpoint weights via ``AutoWeightsLoader``.

        ``self.hf_to_vllm_mapper`` is passed as the loader's name mapper. For
        this class it maps the ``vision_tower.`` prefix to ``visual.``. When
        this instance has no vision tower (``self.visual is None``), weights
        under ``visual.`` are skipped.

        Args:
            weights: Iterable of ``(name, tensor)`` pairs.

        Returns:
            The set of names returned by ``AutoWeightsLoader.load_weights``.
        """
        # Without a vision tower there is no module to receive its weights.
        skip_prefixes = ["visual."] if self.visual is None else []
        loader = AutoWeightsLoader(self, skip_prefixes=skip_prefixes)
        return loader.load_weights(weights, mapper=self.hf_to_vllm_mapper)