qwen2_vl.py 58.3 KB
Newer Older
1
# SPDX-License-Identifier: Apache-2.0
2
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
3

4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
# Adapted from
# https://github.com/huggingface/transformers/blob/19e6e80e10118f855137b90740936c0b11ac397f/src/transformers/models/qwen2_vl/modeling_qwen2_vl.py
# Copyright 2024 The Qwen team.
# Copyright 2023 The vLLM team.
# Copyright 2022 EleutherAI and the HuggingFace Inc. team. All rights reserved.
#
# This code is based on EleutherAI's GPT-NeoX library and the GPT-NeoX
# and OPT implementations in this library. It has been modified from its
# original forms to accommodate minor architectural differences compared
# to GPT-NeoX and OPT used by the Meta AI team that trained the model.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Inference-only Qwen2-VL model compatible with HuggingFace weights."""
27

28
from collections.abc import Callable, Iterable, Mapping, Sequence
29
from functools import partial
30
from typing import Annotated, Any, Literal, TypeAlias
31
32
33
34
35

import torch
import torch.nn as nn
import torch.nn.functional as F
from einops import rearrange, repeat
36
from transformers import AutoConfig, BatchFeature, PretrainedConfig
37
from transformers.models.qwen2_vl import Qwen2VLImageProcessor, Qwen2VLProcessor
38
from transformers.models.qwen2_vl.configuration_qwen2_vl import (
39
40
41
    Qwen2VLConfig,
    Qwen2VLVisionConfig,
)
42
from transformers.models.qwen2_vl.image_processing_qwen2_vl import smart_resize
43
from transformers.models.qwen2_vl.video_processing_qwen2_vl import Qwen2VLVideoProcessor
44

45
from vllm.attention.backends.registry import _Backend
46
47
48
49
from vllm.attention.layer import (
    check_upstream_fa_availability,
    maybe_get_vit_flash_attn_backend,
)
50
from vllm.config import VllmConfig
51
from vllm.config.multimodal import BaseDummyOptions
52
from vllm.distributed import parallel_state, tensor_model_parallel_all_gather
53
54
55
from vllm.distributed import utils as dist_utils
from vllm.logger import init_logger
from vllm.model_executor.layers.activation import QuickGELU
56
from vllm.model_executor.layers.linear import ColumnParallelLinear, RowParallelLinear
57
from vllm.model_executor.layers.quantization import QuantizationConfig
58
from vllm.model_executor.layers.rotary_embedding.common import (
59
60
    dispatch_rotary_emb_function,
)
61
from vllm.model_executor.model_loader.weight_utils import default_weight_loader
62
from vllm.model_executor.models.module_mapping import MultiModelKeys
63
from vllm.multimodal import MULTIMODAL_REGISTRY
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
from vllm.multimodal.inputs import (
    ImageItem,
    ModalityData,
    MultiModalDataDict,
    MultiModalFieldConfig,
    MultiModalKwargsItems,
    VideoItem,
)
from vllm.multimodal.parse import (
    DictEmbeddingItems,
    ImageSize,
    ModalityDataItems,
    MultiModalDataItems,
    MultiModalDataParser,
)
from vllm.multimodal.processing import (
    BaseMultiModalProcessor,
    BaseProcessingInfo,
    PromptReplacement,
    PromptUpdate,
)
85
from vllm.multimodal.profiling import BaseDummyInputsBuilder
86
from vllm.sequence import IntermediateTensors
87
from vllm.transformers_utils.tokenizer import AnyTokenizer
88
from vllm.utils.tensor_schema import TensorSchema, TensorShape
89

90
91
92
93
94
95
96
97
98
99
100
101
102
from .interfaces import (
    MultiModalEmbeddings,
    SupportsLoRA,
    SupportsMRoPE,
    SupportsMultiModal,
    SupportsPP,
)
from .utils import (
    AutoWeightsLoader,
    WeightsMapper,
    init_vllm_registered_model,
    maybe_prefix,
)
103
from .vision import get_vit_attn_backend, run_dp_sharded_mrope_vision_model
104

105
106
logger = init_logger(__name__)

# Per-video frame limit used for the profile run.
_MAX_FRAMES_PER_VIDEO = 14
109

110
111
112
# === Vision Inputs === #


113
class Qwen2VLImagePixelInputs(TensorSchema):
    """
    Schema for image inputs supplied as pixel patches.

    Dimensions:
        - np: The total number of patches over each image over each prompt in
              the batch
        - ni: Number of images
        - cps: Number of channels * patch_size * patch_size

    Historical context:
        - pixel_values shape: (num_patches, num_channels * patch_size *
          patch_size)
        - image_grid_thw shape: (num_images, 3) in (grid_t, grid_h, grid_w)
          format
    """

    type: Literal["pixel_values"]

    pixel_values: Annotated[
        torch.Tensor,
        TensorShape("np", "cps"),
    ]

    image_grid_thw: Annotated[
        torch.Tensor,
        TensorShape("ni", 3),
    ]


class Qwen2VLImageEmbeddingInputs(TensorSchema):
    """
    Schema for image inputs supplied as precomputed embeddings.

    Dimensions:
        - nf: Number of image features
        - hs: Hidden size
        - ni: Number of images

    Historical context:
        - image_embeds shape: (num_image_features, hidden_size)
        - num_image_features varies based on the number and resolution of the
          images.
        - hidden_size must match the hidden size of language model backbone.
        - image_grid_thw shape: (num_images, 3) in (grid_t, grid_h, grid_w)
          format
    """

    type: Literal["image_embeds"]

    image_embeds: Annotated[
        torch.Tensor,
        TensorShape("nf", "hs"),
    ]

    image_grid_thw: Annotated[
        torch.Tensor,
        TensorShape("ni", 3),
    ]


# Image inputs arrive either as pixel patches or as precomputed embeddings.
Qwen2VLImageInputs: TypeAlias = Qwen2VLImagePixelInputs | Qwen2VLImageEmbeddingInputs
171
172


173
174
175
176
177
class Qwen2VLVideoPixelInputs(TensorSchema):
    """
    Schema for video inputs supplied as pixel patches.

    Dimensions:
        - np: The total number of patches over each video over each prompt in
              the batch
        - ctps: Number of channels * temporal_patch_size * patch_size *
          patch_size
        - nv: Number of videos

    Historical context:
        - pixel_values_videos shape: (num_patches, num_channels *
          temporal_patch_size * patch_size * patch_size)
        - video_grid_thw shape: (num_videos, 3) in (grid_t, grid_h, grid_w)
          format
    """

    type: Literal["pixel_values_videos"]

    pixel_values_videos: Annotated[
        torch.Tensor,
        TensorShape("np", "ctps"),
    ]

    video_grid_thw: Annotated[
        torch.Tensor,
        TensorShape("nv", 3),
    ]
200
201


202
203
204
205
206
207
class Qwen2VLVideoEmbeddingInputs(TensorSchema):
    """
    Schema for video inputs supplied as precomputed embeddings.

    Dimensions:
        - nf: Number of video features
        - hs: Hidden size
        - nv: Number of videos

    Historical context:
        - video_embeds shape: (num_video_features, hidden_size)
        - num_video_features varies based on the number and resolution of the
          videos.
        - hidden_size must match the hidden size of language model backbone.
        - video_grid_thw shape: (num_videos, 3) in (grid_t, grid_h, grid_w)
          format
    """

    type: Literal["video_embeds"]

    video_embeds: Annotated[
        torch.Tensor,
        TensorShape("nf", "hs"),
    ]

    video_grid_thw: Annotated[
        torch.Tensor,
        TensorShape("nv", 3),
    ]


# Video inputs arrive either as pixel patches or as precomputed embeddings.
Qwen2VLVideoInputs: TypeAlias = Qwen2VLVideoPixelInputs | Qwen2VLVideoEmbeddingInputs
232

233
234
235
236
237
238
239
# === Vision Encoder === #


class Qwen2VisionMLP(nn.Module):
    """Two-layer feed-forward network of a vision block: fc1 -> act -> fc2.

    ``fc1`` is a ColumnParallelLinear and ``fc2`` a RowParallelLinear; both
    receive ``disable_tp=use_data_parallel``.
    """

    def __init__(
        self,
        in_features: int,
        hidden_features: int,
        act_layer: type[nn.Module] = QuickGELU,
        quant_config: QuantizationConfig | None = None,
        prefix: str = "",
        use_data_parallel: bool = False,
    ):
        super().__init__()
        self.fc1 = ColumnParallelLinear(
            in_features,
            hidden_features,
            quant_config=quant_config,
            prefix=f"{prefix}.fc1",
            disable_tp=use_data_parallel,
        )
        self.act = act_layer()
        self.fc2 = RowParallelLinear(
            hidden_features,
            in_features,
            quant_config=quant_config,
            prefix=f"{prefix}.fc2",
            disable_tp=use_data_parallel,
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # The parallel linear layers return a 2-tuple; only the first
        # element (the output tensor) is used.
        x_parallel, _ = self.fc1(x)
        x_parallel = self.act(x_parallel)
        x, _ = self.fc2(x_parallel)
        return x


def rotate_half(x: torch.Tensor, interleaved: bool = False) -> torch.Tensor:
    """Rotate feature pairs of the last dimension: each pair (a, b) -> (-b, a).

    Non-interleaved: the last dim is split into halves ``[x1, x2]`` and the
    result is ``[-x2, x1]``.
    Interleaved: pairs are the (even, odd) features, so
    ``[a0, b0, a1, b1, ...]`` becomes ``[-b0, a0, -b1, a1, ...]``.
    """
    if not interleaved:
        x1, x2 = x.chunk(2, dim=-1)
        return torch.cat((-x2, x1), dim=-1)
    x1, x2 = x[..., ::2], x[..., 1::2]
    # Stack to (..., d, 2) and merge the last two dims to re-interleave the
    # pairs (same as einops "... d two -> ... (d two)").
    return torch.stack((-x2, x1), dim=-1).flatten(-2)
279
280


281
282
283
def apply_rotary_emb_torch(
    x: torch.Tensor, cos: torch.Tensor, sin: torch.Tensor, interleaved: bool = False
) -> torch.Tensor:
    """Apply rotary position embedding to the leading features of ``x``.

    x: (batch_size, seqlen, nheads, headdim)
    cos, sin: (seqlen, rotary_dim / 2) or (batch_size, seqlen, rotary_dim / 2)

    Only the first ``ro_dim = 2 * cos.shape[-1]`` features of the last
    dimension are rotated; the remaining features are passed through.
    """
    ro_dim = cos.shape[-1] * 2
    assert ro_dim <= x.shape[-1]
    # (..., d) -> (..., 1, 2d): the inserted axis broadcasts over heads; the
    # angles are tiled ("2 d") or repeated pairwise ("d 2") to match the
    # pair layout used by rotate_half.
    pattern = "... d -> ... 1 (d 2)" if interleaved else "... d -> ... 1 (2 d)"
    cos = repeat(cos, pattern)
    sin = repeat(sin, pattern)
    x_rot = x[..., :ro_dim]
    return torch.cat(
        [x_rot * cos + rotate_half(x_rot, interleaved) * sin, x[..., ro_dim:]],
        dim=-1,
    )


305
306
def apply_rotary_pos_emb_vision(t: torch.Tensor, freqs: torch.Tensor) -> torch.Tensor:
    """Rotate ``t`` by the angles in ``freqs``.

    The rotation is computed on a float32 copy of ``t`` and the result is
    cast back to ``t``'s dtype. The implementation is chosen by
    ``dispatch_rotary_emb_function``, with ``apply_rotary_emb_torch`` passed
    as the default.
    """
    rotary_emb_function = dispatch_rotary_emb_function(default=apply_rotary_emb_torch)
    t_ = t.float()
    cos = freqs.cos()
    sin = freqs.sin()
    return rotary_emb_function(t_, cos, sin).type_as(t)


class Qwen2VisionAttention(nn.Module):
    """Multi-head self-attention over the packed vision token sequence.

    Segments (see ``cu_seqlens``) are attended independently: flash attention
    uses varlen cu_seqlens, SDPA loops per segment and xFormers uses a
    block-diagonal mask. Supported backends: FLASH_ATTN, ROCM_AITER_FA,
    TORCH_SDPA and XFORMERS.
    """

    def __init__(
        self,
        embed_dim: int,
        num_heads: int,
        projection_size: int,
        quant_config: QuantizationConfig | None = None,
        prefix: str = "",
        use_data_parallel: bool = False,
    ) -> None:
        super().__init__()
        # Per attention head and per partition values. With data parallelism
        # the layers are not TP-sharded, so behave like a single TP rank.
        self.tp_size = (
            1
            if use_data_parallel
            else parallel_state.get_tensor_model_parallel_world_size()
        )
        self.tp_rank = (
            0 if use_data_parallel else parallel_state.get_tensor_model_parallel_rank()
        )
        self.hidden_size_per_attention_head = dist_utils.divide(
            projection_size, num_heads
        )
        self.num_attention_heads_per_partition = dist_utils.divide(
            num_heads, self.tp_size
        )

        self.qkv = ColumnParallelLinear(
            input_size=embed_dim,
            output_size=3 * projection_size,
            quant_config=quant_config,
            prefix=f"{prefix}.qkv",
            disable_tp=use_data_parallel,
        )
        self.proj = RowParallelLinear(
            input_size=projection_size,
            output_size=embed_dim,
            quant_config=quant_config,
            prefix=f"{prefix}.proj",
            disable_tp=use_data_parallel,
        )

        # Detect attention implementation.
        self.attn_backend = get_vit_attn_backend(
            head_size=self.hidden_size_per_attention_head,
            dtype=torch.get_default_dtype(),
        )
        self.use_upstream_fa = False

        self.attn_backend, self.flash_attn_varlen_func = (
            maybe_get_vit_flash_attn_backend(
                self.attn_backend,
                self.use_upstream_fa,
            )
        )

        if self.attn_backend not in {
            _Backend.FLASH_ATTN,
            _Backend.TORCH_SDPA,
            _Backend.XFORMERS,
            _Backend.ROCM_AITER_FA,
        }:
            raise RuntimeError(
                f"Qwen2-VL does not support {self.attn_backend} backend now."
            )

        self.is_flash_attn_backend = self.attn_backend in {
            _Backend.FLASH_ATTN,
            _Backend.ROCM_AITER_FA,
        }

    def split_qkv(self, qkv: torch.Tensor) -> tuple[torch.Tensor, ...]:
        """Split the fused projection into per-head q, k, v for this rank."""
        # [s, b, 3 * head * head_dim]
        seq_len, bs, _ = qkv.shape
        # Gather the TP-sharded projection so q/k/v can be separated before
        # this rank's slice of each is taken below.
        if self.tp_size > 1:
            qkv = tensor_model_parallel_all_gather(qkv)

        # [s, b, 3 * head * head_dim] -> 3 * [s, b, head * head_dim]
        q, k, v = qkv.chunk(3, dim=2)

        # 3 * [s, b, head * head_dim]
        if self.tp_size > 1:
            splitter = partial(
                dist_utils.split_tensor_along_last_dim, num_partitions=self.tp_size
            )
            q = splitter(q)[self.tp_rank]
            k = splitter(k)[self.tp_rank]
            v = splitter(v)[self.tp_rank]

        # 3 * [s, b, head * head_dim] -> 3 * [s, b, head, head_dim]
        new_shape = (
            seq_len,
            bs,
            self.num_attention_heads_per_partition,
            self.hidden_size_per_attention_head,
        )
        q, k, v = (x.view(*new_shape) for x in (q, k, v))
        return q, k, v

    def forward(
        self,
        x: torch.Tensor,
        cu_seqlens: torch.Tensor,
        rotary_pos_emb: torch.Tensor,
        max_seqlen: int | None = None,  # Only used for Flash Attention
        seqlens: list[int] | None = None,  # Only used for xFormers
    ) -> torch.Tensor:
        """Attend within each ``cu_seqlens`` segment; input/output are [s, b, c].

        ``max_seqlen`` / ``seqlens`` are optional precomputed metadata. If the
        caller did not supply the one this layer's backend needs (e.g. it
        precomputed for a different backend choice), it is derived from
        ``cu_seqlens`` here.
        """
        # [s, b, c] --> [s, b, 3 * head * head_dim]
        x, _ = self.qkv(x)

        # [s, b, 3 * head * head_dim] -> 3 * [s, b, head, head_dim]
        q, k, v = self.split_qkv(x)
        batch_size = q.shape[1]

        q, k, v = (rearrange(x, "s b ... -> b s ...").contiguous() for x in (q, k, v))
        if rotary_pos_emb is not None:
            # [2 * b, s, heads, head_dim]
            qk_concat = torch.cat([q, k], dim=0)
            qk_rotated = apply_rotary_pos_emb_vision(qk_concat, rotary_pos_emb)
            q, k = torch.chunk(qk_rotated, 2, dim=0)

        if self.is_flash_attn_backend:
            if max_seqlen is None:
                max_seqlen = (cu_seqlens[1:] - cu_seqlens[:-1]).max().item()
            q, k, v = (rearrange(x, "b s ... -> (b s) ...") for x in [q, k, v])

            output = self.flash_attn_varlen_func(
                q,
                k,
                v,
                cu_seqlens_q=cu_seqlens,
                cu_seqlens_k=cu_seqlens,
                max_seqlen_q=max_seqlen,
                max_seqlen_k=max_seqlen,
                dropout_p=0.0,
                causal=False,
            )

            context_layer = rearrange(
                output, "(b s) h d -> s b (h d)", b=batch_size
            ).contiguous()
        elif self.attn_backend == _Backend.TORCH_SDPA:
            # Execute attention entry by entry for speed & less VRAM.
            # Convert the segment bounds to Python ints once instead of
            # slicing with 0-d device tensors, which needs a device-to-host
            # copy for every slice bound.
            bounds = cu_seqlens.tolist()
            outputs = []
            for start_idx, end_idx in zip(bounds[:-1], bounds[1:]):
                q_i, k_i, v_i = (
                    rearrange(t[:, start_idx:end_idx], "b s h d -> b h s d")
                    for t in (q, k, v)
                )
                output_i = F.scaled_dot_product_attention(q_i, k_i, v_i, dropout_p=0.0)
                output_i = rearrange(output_i, "b h s d -> b s h d ")
                outputs.append(output_i)
            context_layer = torch.cat(outputs, dim=1)
            context_layer = rearrange(
                context_layer, "b s h d -> s b (h d)"
            ).contiguous()
        elif self.attn_backend == _Backend.XFORMERS:
            from xformers import ops as xops
            from xformers.ops.fmha.attn_bias import BlockDiagonalMask

            if seqlens is None:
                seqlens = (cu_seqlens[1:] - cu_seqlens[:-1]).tolist()
            attn_bias = BlockDiagonalMask.from_seqlens(
                q_seqlen=seqlens, kv_seqlen=None, device=q.device
            )

            context_layer = xops.memory_efficient_attention_forward(
                q, k, v, attn_bias=attn_bias, p=0, scale=None
            )
            context_layer = rearrange(
                context_layer, "b s h d -> s b (h d)"
            ).contiguous()

        output, _ = self.proj(context_layer)
        return output


class Qwen2VisionBlock(nn.Module):
    """Pre-norm vision transformer block.

    ``x = x + attn(norm1(x))`` followed by ``x = x + mlp(norm2(x))``.
    ``norm_layer`` defaults to ``LayerNorm(eps=1e-6)`` and the MLP hidden
    width is ``int(dim * mlp_ratio)``.
    """

    def __init__(
        self,
        dim: int,
        num_heads: int,
        mlp_ratio: float,
        act_layer: type[nn.Module] = QuickGELU,
        norm_layer: Callable[[int], nn.Module] | None = None,
        quant_config: QuantizationConfig | None = None,
        prefix: str = "",
        use_data_parallel: bool = False,
    ) -> None:
        super().__init__()
        if norm_layer is None:
            norm_layer = partial(nn.LayerNorm, eps=1e-6)
        self.norm1 = norm_layer(dim)
        self.norm2 = norm_layer(dim)
        mlp_hidden_dim = int(dim * mlp_ratio)

        self.attn = Qwen2VisionAttention(
            embed_dim=dim,
            num_heads=num_heads,
            projection_size=dim,
            quant_config=quant_config,
            prefix=f"{prefix}.attn",
            use_data_parallel=use_data_parallel,
        )
        self.mlp = Qwen2VisionMLP(
            dim,
            mlp_hidden_dim,
            act_layer=act_layer,
            quant_config=quant_config,
            prefix=f"{prefix}.mlp",
            use_data_parallel=use_data_parallel,
        )

    def forward(
        self,
        x: torch.Tensor,
        cu_seqlens: torch.Tensor,
        rotary_pos_emb: torch.Tensor,
        max_seqlen: int | None = None,  # Only used for Flash Attention
        seqlens: list[int] | None = None,  # Only used for xFormers
    ) -> torch.Tensor:
        x = x + self.attn(
            self.norm1(x),
            cu_seqlens=cu_seqlens,
            rotary_pos_emb=rotary_pos_emb,
            max_seqlen=max_seqlen,
            seqlens=seqlens,
        )
        x = x + self.mlp(self.norm2(x))
        return x


class Qwen2VisionPatchEmbed(nn.Module):
    """Project flattened pixel patches to ``embed_dim`` with a 3-D convolution.

    Each input row holds one patch of ``in_channels * temporal_patch_size *
    patch_size * patch_size`` values. Kernel size equals stride equals the
    patch extent, so the convolution acts as a per-patch linear projection.
    """

    def __init__(
        self,
        patch_size: int = 14,
        temporal_patch_size: int = 2,
        in_channels: int = 3,
        embed_dim: int = 1152,
    ) -> None:
        super().__init__()
        self.patch_size = patch_size
        self.temporal_patch_size = temporal_patch_size
        self.embed_dim = embed_dim

        kernel_size = (temporal_patch_size, patch_size, patch_size)
        self.proj = nn.Conv3d(
            in_channels,
            embed_dim,
            kernel_size=kernel_size,
            stride=kernel_size,
            bias=False,
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Map ``(num_patches, C * T * P * P)`` to ``(num_patches, embed_dim)``."""
        # Unpacking enforces a 2-D input; the feature width is not needed.
        num_patches, _ = x.shape
        # Recover (channels, T, P, P) per patch; channels is inferred.
        x = x.view(
            num_patches, -1, self.temporal_patch_size, self.patch_size, self.patch_size
        )
        # Conv output is (num_patches, embed_dim, 1, 1, 1); drop the unit dims.
        return self.proj(x).view(num_patches, self.embed_dim)


class Qwen2VisionPatchMerger(nn.Module):
    """Merge groups of ``spatial_merge_size**2`` patch features and project.

    Features are normalized with ``ln_q``, regrouped so that each row
    concatenates ``spatial_merge_size**2`` consecutive tokens, then passed
    through a Linear -> GELU -> Linear MLP that outputs ``d_model`` features.
    """

    def __init__(
        self,
        d_model: int,
        context_dim: int,
        norm_layer: Callable[[int], nn.Module] | None = None,
        spatial_merge_size: int = 2,
        quant_config: QuantizationConfig | None = None,
        prefix: str = "",
        use_data_parallel: bool = False,
    ) -> None:
        super().__init__()
        self.hidden_size = context_dim * (spatial_merge_size**2)
        if norm_layer is None:
            norm_layer = partial(nn.LayerNorm, eps=1e-6)
        self.ln_q = norm_layer(context_dim)
        self.mlp = nn.ModuleList(
            [
                ColumnParallelLinear(
                    self.hidden_size,
                    self.hidden_size,
                    bias=True,
                    quant_config=quant_config,
                    prefix=f"{prefix}.mlp.0",
                    disable_tp=use_data_parallel,
                ),
                nn.GELU(),
                RowParallelLinear(
                    self.hidden_size,
                    d_model,
                    bias=True,
                    quant_config=quant_config,
                    prefix=f"{prefix}.mlp.2",
                    disable_tp=use_data_parallel,
                ),
            ]
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        x = self.ln_q(x)
        # Each row now concatenates spatial_merge_size**2 consecutive tokens.
        x = x.view(-1, self.hidden_size)

        mlp_fc1, mlp_act, mlp_fc2 = self.mlp
        x_parallel, _ = mlp_fc1(x)
        x_parallel = mlp_act(x_parallel)
        out, _ = mlp_fc2(x_parallel)
        return out


class Qwen2VisionRotaryEmbedding(nn.Module):
    """Cached table of rotary angles for integer positions.

    ``forward(seqlen)`` returns a ``(seqlen, len(inv_freq))`` tensor whose
    entry ``[p, i]`` is ``p * inv_freq[i]``. The table is grown (to twice the
    requested length) whenever a longer ``seqlen`` is requested.
    """

    def __init__(self, dim: int, theta: float = 10000.0) -> None:
        super().__init__()
        self.dim = dim
        self.theta = theta
        inv_freq = 1.0 / (theta ** (torch.arange(0, dim, 2, dtype=torch.float) / dim))
        self.register_buffer("inv_freq", inv_freq, persistent=False)
        self._seq_len_cached = 0
        self._freqs_cached = None

    def update_freqs_cache(self, seqlen: int) -> None:
        """Ensure the cached table covers at least ``seqlen`` positions."""
        # Also build the table on first use, so that seqlen == 0 on a fresh
        # instance does not leave the cache as None.
        if seqlen > self._seq_len_cached or self._freqs_cached is None:
            # Over-allocate to amortize future growth.
            seqlen *= 2
            self._seq_len_cached = seqlen
            # Recompute inv_freq in float32 on the buffer's device instead of
            # reusing the buffer, whose dtype may have been changed by a
            # module-wide dtype cast.
            self.inv_freq = 1.0 / (
                self.theta
                ** (
                    torch.arange(
                        0, self.dim, 2, dtype=torch.float, device=self.inv_freq.device
                    )
                    / self.dim
                )
            )
            seq = torch.arange(
                seqlen, device=self.inv_freq.device, dtype=self.inv_freq.dtype
            )
            freqs = torch.outer(seq, self.inv_freq)
            self._freqs_cached = freqs

    def forward(self, seqlen: int) -> torch.Tensor:
        self.update_freqs_cache(seqlen)
        return self._freqs_cached[:seqlen]


class Qwen2VisionTransformer(nn.Module):
    """Qwen2-VL vision tower.

    Pipeline: patch embedding -> ``depth`` Qwen2VisionBlocks with 2-D (height,
    width) rotary position embedding -> Qwen2VisionPatchMerger projecting to
    ``vision_config.hidden_size``.
    """

    def __init__(
        self,
        vision_config: Qwen2VLVisionConfig,
        norm_eps: float = 1e-6,
        quant_config: QuantizationConfig | None = None,
        prefix: str = "",
        use_data_parallel: bool = False,
    ) -> None:
        super().__init__()

        patch_size = vision_config.patch_size
        temporal_patch_size = vision_config.temporal_patch_size
        spatial_merge_size = vision_config.spatial_merge_size
        in_channels = vision_config.in_channels
        hidden_size = vision_config.hidden_size
        embed_dim = vision_config.embed_dim
        depth = vision_config.depth
        num_heads = vision_config.num_heads
        mlp_ratio = vision_config.mlp_ratio

        self.use_data_parallel = use_data_parallel
        self.out_hidden_size = vision_config.hidden_size

        self.spatial_merge_size = spatial_merge_size
        self.num_heads = num_heads
        self.embed_dim = embed_dim

        self.patch_embed = Qwen2VisionPatchEmbed(
            patch_size=patch_size,
            temporal_patch_size=temporal_patch_size,
            in_channels=in_channels,
            embed_dim=embed_dim,
        )

        norm_layer = partial(nn.LayerNorm, eps=norm_eps)
        head_dim = embed_dim // num_heads
        # rot_pos_emb concatenates height and width angles looked up in this
        # table for every token.
        self.rotary_pos_emb = Qwen2VisionRotaryEmbedding(head_dim // 2)

        self.blocks = nn.ModuleList(
            [
                Qwen2VisionBlock(
                    dim=embed_dim,
                    num_heads=num_heads,
                    mlp_ratio=mlp_ratio,
                    norm_layer=norm_layer,
                    quant_config=quant_config,
                    prefix=f"{prefix}.blocks.{layer_idx}",
                    use_data_parallel=use_data_parallel,
                )
                for layer_idx in range(depth)
            ]
        )
        self.merger = Qwen2VisionPatchMerger(
            d_model=hidden_size,
            context_dim=embed_dim,
            norm_layer=norm_layer,
            quant_config=quant_config,
            prefix=f"{prefix}.merger",
            use_data_parallel=use_data_parallel,
        )
        # Within this class, only compute_attn_mask_seqlen reads this. Each
        # block's attention selects its own backend independently.
        self.attn_backend = get_vit_attn_backend(
            head_size=head_dim, dtype=torch.get_default_dtype()
        )
        if self.attn_backend != _Backend.FLASH_ATTN and check_upstream_fa_availability(
            torch.get_default_dtype()
        ):
            self.attn_backend = _Backend.FLASH_ATTN

    @property
    def dtype(self) -> torch.dtype:
        return self.patch_embed.proj.weight.dtype

    @property
    def device(self) -> torch.device:
        return self.patch_embed.proj.weight.device

    def rot_pos_emb(self, grid_thw: list[list[int]]) -> torch.Tensor:
        """Return per-token rotary angles: [height angles, width angles]."""
        pos_ids = []
        max_grid_size = 0
        for t, h, w in grid_thw:
            hpos_ids = torch.arange(h).unsqueeze(1).expand(-1, w)
            wpos_ids = torch.arange(w).unsqueeze(0).expand(h, -1)
            # Reorder positions so each spatial_merge_size x spatial_merge_size
            # window is contiguous, matching the consecutive-token grouping
            # done by the merger's view().
            hpos_ids = (
                hpos_ids.reshape(
                    h // self.spatial_merge_size,
                    self.spatial_merge_size,
                    w // self.spatial_merge_size,
                    self.spatial_merge_size,
                )
                .permute(0, 2, 1, 3)
                .flatten()
            )
            wpos_ids = (
                wpos_ids.reshape(
                    h // self.spatial_merge_size,
                    self.spatial_merge_size,
                    w // self.spatial_merge_size,
                    self.spatial_merge_size,
                )
                .permute(0, 2, 1, 3)
                .flatten()
            )
            # Every temporal frame reuses the same spatial positions.
            pos_ids.append(torch.stack([hpos_ids, wpos_ids], dim=-1).repeat(t, 1))
            max_grid_size = max(max_grid_size, h, w)
        pos_ids = torch.cat(pos_ids, dim=0)
        rotary_pos_emb_full = self.rotary_pos_emb(max_grid_size)
        rotary_pos_emb = rotary_pos_emb_full[pos_ids].flatten(1)
        return rotary_pos_emb

    def compute_attn_mask_seqlen(
        self, cu_seqlens: torch.Tensor
    ) -> tuple[int | None, list[int] | None]:
        """Precompute attention metadata once per forward for all blocks.

        Returns ``(max_seqlen, None)`` for flash-attention backends,
        ``(None, seqlens)`` for xFormers and ``(None, None)`` otherwise.
        """
        max_seqlen, seqlens = None, None
        if (
            self.attn_backend == _Backend.FLASH_ATTN
            or self.attn_backend == _Backend.ROCM_AITER_FA
        ):
            max_seqlen = (cu_seqlens[1:] - cu_seqlens[:-1]).max().item()
        elif self.attn_backend == _Backend.XFORMERS:
            seqlens = (cu_seqlens[1:] - cu_seqlens[:-1]).tolist()
        return max_seqlen, seqlens

    def forward(
        self,
        x: torch.Tensor,
        grid_thw: list[list[int]],
    ) -> torch.Tensor:
        # patchify
        x = x.to(device=self.device, dtype=self.dtype)
        x = self.patch_embed(x)

        # compute position embedding
        rotary_pos_emb = self.rot_pos_emb(grid_thw)

        # compute cu_seqlens: every temporal frame of every item is its own
        # attention segment of h * w patches.
        grid_thw_ = torch.tensor(grid_thw, device=x.device, dtype=torch.long)
        cu_seqlens = torch.repeat_interleave(
            grid_thw_[:, 1] * grid_thw_[:, 2], grid_thw_[:, 0]
        ).cumsum(dim=0, dtype=torch.int32)
        cu_seqlens = F.pad(cu_seqlens, (1, 0), "constant", 0)

        # transformers: [s, c] -> [s, 1, c]
        x = x.unsqueeze(1)

        # pre-compute seqlens for attn mask to reduce cuMemcpy operations
        max_seqlen, seqlens = self.compute_attn_mask_seqlen(cu_seqlens)
        for blk in self.blocks:
            x = blk(
                x,
                cu_seqlens=cu_seqlens,
                rotary_pos_emb=rotary_pos_emb,
                max_seqlen=max_seqlen,
                seqlens=seqlens,
            )

        # adapter
        x = self.merger(x)

        return x

    def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]:
        """Load checkpoint tensors by name; return the loaded parameter names."""
        # NOTE(review): the attention projection in this file is named `qkv`,
        # not `qkv_proj`, so this mapping only matters for checkpoints with
        # separate q/k/v_proj names -- confirm it is still needed.
        stacked_params_mapping = [
            # (param_name, shard_name, shard_id)
            ("qkv_proj", "q_proj", "q"),
            ("qkv_proj", "k_proj", "k"),
            ("qkv_proj", "v_proj", "v"),
        ]
        params_dict = dict(self.named_parameters(remove_duplicate=False))
        loaded_params: set[str] = set()

        for name, loaded_weight in weights:
            for param_name, weight_name, shard_id in stacked_params_mapping:
                if weight_name not in name:
                    continue
                name = name.replace(weight_name, param_name)

                param = params_dict[name]
                weight_loader = param.weight_loader
                weight_loader(param, loaded_weight, shard_id)
                break
            else:
                param = params_dict[name]
                weight_loader = getattr(param, "weight_loader", default_weight_loader)
                weight_loader(param, loaded_weight)
            loaded_params.add(name)
        return loaded_params

845

846
def _create_qwen2vl_field_factory(
    spatial_merge_size: int,
) -> Callable[
    [Mapping[str, torch.Tensor]],
    Mapping[str, MultiModalFieldConfig],
]:
    """Build a field-config factory for Qwen2-VL processor outputs.

    The returned callable reads ``image_grid_thw`` / ``video_grid_thw`` from
    the processor outputs and describes how the flattened pixel and embedding
    tensors are split back into per-item chunks. A missing grid is treated as
    an empty ``(0, 3)`` tensor.

    Args:
        spatial_merge_size: Merge window side length; the embedding chunk of
            an item is its patch count floor-divided by this value twice.
    """

    def _embed_sizes(patch_counts: torch.Tensor) -> torch.Tensor:
        # One floor-division per spatial axis of the merge window.
        return patch_counts // spatial_merge_size // spatial_merge_size

    def _qwen2vl_field_config(hf_inputs: Mapping[str, torch.Tensor]):
        image_patch_counts = hf_inputs.get(
            "image_grid_thw", torch.empty((0, 3))
        ).prod(-1)
        video_patch_counts = hf_inputs.get(
            "video_grid_thw", torch.empty((0, 3))
        ).prod(-1)

        return {
            "pixel_values": MultiModalFieldConfig.flat_from_sizes(
                "image", image_patch_counts
            ),
            "image_embeds": MultiModalFieldConfig.flat_from_sizes(
                "image", _embed_sizes(image_patch_counts)
            ),
            "image_grid_thw": MultiModalFieldConfig.batched("image"),
            "pixel_values_videos": MultiModalFieldConfig.flat_from_sizes(
                "video", video_patch_counts
            ),
            "video_embeds": MultiModalFieldConfig.flat_from_sizes(
                "video", _embed_sizes(video_patch_counts)
            ),
            "video_grid_thw": MultiModalFieldConfig.batched("video"),
        }

    return _qwen2vl_field_config
883

884

Roger Wang's avatar
Roger Wang committed
885
class Qwen2VLMultiModalDataParser(MultiModalDataParser):
    """Data parser that also accepts precomputed Qwen2-VL embeddings.

    A ``dict`` passed as image or video data is wrapped as precomputed
    embeddings plus grid metadata. Any other input is delegated to the generic
    ``MultiModalDataParser`` handling.
    """

    def __init__(self, spatial_merge_size: int, *args, **kwargs):
        self._spatial_merge_size = spatial_merge_size
        super().__init__(*args, **kwargs)

    def _embedding_items(
        self,
        data: dict[str, torch.Tensor],
        modality: str,
        required_fields: set[str],
    ) -> ModalityDataItems[Any, Any]:
        # Precomputed embeddings must come with their grid so per-item sizes
        # can be recovered by the field factory.
        return DictEmbeddingItems(
            data,
            modality=modality,
            required_fields=required_fields,
            fields_factory=_create_qwen2vl_field_factory(self._spatial_merge_size),
        )

    def _parse_image_data(
        self,
        data: dict[str, torch.Tensor] | ModalityData[ImageItem],
    ) -> ModalityDataItems[Any, Any] | None:
        if not isinstance(data, dict):
            return super()._parse_image_data(data)
        return self._embedding_items(
            data, "image", {"image_embeds", "image_grid_thw"}
        )

    def _parse_video_data(
        self,
        data: dict[str, torch.Tensor] | ModalityData[VideoItem],
    ) -> ModalityDataItems[Any, Any] | None:
        if not isinstance(data, dict):
            return super()._parse_video_data(data)
        return self._embedding_items(
            data, "video", {"video_embeds", "video_grid_thw"}
        )


919
920
class Qwen2VLProcessingInfo(BaseProcessingInfo):
    """Processing metadata for Qwen2-VL.

    Exposes the HF config and processors, and computes how many placeholder
    tokens an image or video of a given size expands to. These counts size
    the worst-case inputs.
    """

    def get_hf_config(self):
        """Return the model's HF config as a ``Qwen2VLConfig``."""
        return self.ctx.get_hf_config(Qwen2VLConfig)

    def get_hf_processor(self, **kwargs: object) -> Qwen2VLProcessor:
        """Return the HF ``Qwen2VLProcessor``.

        ``use_fast`` defaults to ``True`` unless the caller passes it.
        """
        return self.ctx.get_hf_processor(
            Qwen2VLProcessor,
            use_fast=kwargs.pop("use_fast", True),
            **kwargs,
        )

    def get_image_processor(self, **kwargs: object) -> Qwen2VLImageProcessor:
        """Return the image processor owned by the HF processor."""
        return self.get_hf_processor(**kwargs).image_processor

    def get_supported_mm_limits(self) -> Mapping[str, int | None]:
        # NOTE(review): ``None`` presumably means "no model-imposed limit";
        # confirm against the BaseProcessingInfo contract.
        return {"image": None, "video": None}

    def get_mm_max_tokens_per_item(
        self,
        seq_len: int,
        mm_counts: Mapping[str, int],
    ) -> Mapping[str, int]:
        """Return the worst-case placeholder-token count per image / video."""
        max_image_tokens = self.get_max_image_tokens()
        max_video_tokens = self.get_max_video_tokens(seq_len, mm_counts)
        return {"image": max_image_tokens, "video": max_video_tokens}

    def _get_vision_info(
        self,
        *,
        image_width: int,
        image_height: int,
        num_frames: int = 1,
        do_resize: bool = True,
        image_processor: Qwen2VLImageProcessor | None,
    ) -> tuple[ImageSize, int]:
        """Compute the preprocessed size and vision-token count of an input.

        Args:
            image_width: Original frame width in pixels.
            image_height: Original frame height in pixels.
            num_frames: Number of frames (1 for a still image).
            do_resize: Whether to apply ``smart_resize`` bounded by the image
                processor's ``min_pixels`` / ``max_pixels``.
            image_processor: Source of the pixel bounds; fetched from the HF
                processor when ``None``.

        Returns:
            ``(preprocessed_size, num_vision_tokens)``. The token count is
            ``grid_t * grid_h * grid_w // spatial_merge_size**2``.
        """
        if image_processor is None:
            image_processor = self.get_image_processor()

        hf_config = self.get_hf_config()
        vision_config = hf_config.vision_config
        patch_size = vision_config.patch_size
        merge_size = vision_config.spatial_merge_size
        temporal_patch_size = vision_config.temporal_patch_size

        if do_resize:
            resized_height, resized_width = smart_resize(
                height=image_height,
                width=image_width,
                factor=patch_size * merge_size,
                min_pixels=image_processor.min_pixels,
                max_pixels=image_processor.max_pixels,
            )
            preprocessed_size = ImageSize(width=resized_width, height=resized_height)
        else:
            preprocessed_size = ImageSize(width=image_width, height=image_height)

        # NOTE: Frames are padded to be divisible by `temporal_patch_size`
        # https://github.com/huggingface/transformers/blob/v4.48.3/src/transformers/models/qwen2_vl/image_processing_qwen2_vl.py#L294
        # `(-n) % k` is the number of frames needed to reach the next multiple
        # of k. The former `n % k` only coincided with it when k == 2.
        padded_num_frames = num_frames + (-num_frames) % temporal_patch_size

        grid_t = max(padded_num_frames // temporal_patch_size, 1)
        grid_h = preprocessed_size.height // patch_size
        grid_w = preprocessed_size.width // patch_size

        num_patches = grid_t * grid_h * grid_w
        num_vision_tokens = num_patches // (merge_size**2)

        return preprocessed_size, num_vision_tokens

    def get_num_image_tokens(
        self,
        *,
        image_width: int,
        image_height: int,
        image_processor: Qwen2VLImageProcessor | None,
    ) -> int:
        """Return the number of vision tokens for a single image."""
        _, num_image_tokens = self._get_vision_info(
            image_width=image_width,
            image_height=image_height,
            num_frames=1,
            image_processor=image_processor,
        )
        return num_image_tokens

    def get_num_video_tokens(
        self,
        *,
        image_width: int,
        image_height: int,
        num_frames: int,
        image_processor: Qwen2VLImageProcessor | None,
    ) -> int:
        """Return the number of vision tokens for a video of ``num_frames``."""
        _, num_video_tokens = self._get_vision_info(
            image_width=image_width,
            image_height=image_height,
            num_frames=num_frames,
            image_processor=image_processor,
        )
        return num_video_tokens

    def get_image_size_with_most_features(self) -> ImageSize:
        """Return the preprocessed size of an arbitrarily large image.

        A huge input is clamped by ``smart_resize``'s ``max_pixels`` bound,
        yielding the largest size the processor produces.
        """
        max_image_size, _ = self._get_vision_info(
            image_width=9999999,
            image_height=9999999,
            num_frames=1,
            image_processor=None,
        )
        return max_image_size

    def get_max_image_tokens(self) -> int:
        """Return the vision-token count of the largest possible image."""
        target_width, target_height = self.get_image_size_with_most_features()

        return self.get_num_image_tokens(
            image_width=target_width,
            image_height=target_height,
            image_processor=None,
        )

    def _get_max_video_frames(self, max_tokens: int, start_num_frames: int = 1) -> int:
        """Grow the frame count from ``start_num_frames`` while within budget.

        Frames are added one at a time, at the largest frame size, until the
        next count would exceed ``max_tokens``. ``start_num_frames`` itself is
        returned without being checked against the budget.
        """
        target_width, target_height = self.get_image_size_with_most_features()

        num_frames = start_num_frames

        while True:
            next_num_frames = num_frames + 1
            next_max_tokens = self.get_num_video_tokens(
                image_width=target_width,
                image_height=target_height,
                num_frames=next_num_frames,
                image_processor=None,
            )

            if next_max_tokens > max_tokens:
                break

            num_frames = next_num_frames

        return num_frames

    def get_num_frames_with_most_features(
        self,
        seq_len: int,
        mm_counts: Mapping[str, int],
        max_frames_per_video: int = _MAX_FRAMES_PER_VIDEO,
    ) -> int:
        """Return the per-video frame count used for worst-case inputs.

        The frame budget that fits in ``seq_len`` is shared evenly across the
        requested videos and capped at ``max_frames_per_video``. The result is
        always at least 1.
        """
        max_videos = mm_counts.get("video", 0)

        max_total_frames = self._get_max_video_frames(seq_len)
        max_frames_per_video = min(
            max_total_frames // max(max_videos, 1), max_frames_per_video
        )

        return max(max_frames_per_video, 1)

    def get_max_video_tokens(
        self,
        seq_len: int,
        mm_counts: Mapping[str, int],
    ) -> int:
        """Return the vision-token count of a worst-case video."""
        target_width, target_height = self.get_image_size_with_most_features()

        return self.get_num_video_tokens(
            image_width=target_width,
            image_height=target_height,
            num_frames=self.get_num_frames_with_most_features(seq_len, mm_counts),
            image_processor=None,
        )

1087
1088

class Qwen2VLDummyInputsBuilder(BaseDummyInputsBuilder[Qwen2VLProcessingInfo]):
    """Builds dummy prompts and media sized for Qwen2-VL's largest inputs."""

    def get_dummy_text(self, mm_counts: Mapping[str, int]) -> str:
        """Return one image token per image followed by one video token per
        video."""
        processor = self.info.get_hf_processor()
        image_part: str = processor.image_token * mm_counts.get("image", 0)
        video_part: str = processor.video_token * mm_counts.get("video", 0)
        return "".join((image_part, video_part))

    def get_dummy_mm_data(
        self,
        seq_len: int,
        mm_counts: Mapping[str, int],
        mm_options: Mapping[str, BaseDummyOptions] | None = None,
    ) -> MultiModalDataDict:
        """Return dummy images/videos at the worst-case size and frame count.

        Per-modality overrides are taken from ``mm_options`` when it is given
        and non-empty.
        """
        width, height = self.info.get_image_size_with_most_features()
        frame_count = self.info.get_num_frames_with_most_features(
            seq_len, mm_counts
        )
        options = mm_options or {}

        dummy_images = self._get_dummy_images(
            width=width,
            height=height,
            num_images=mm_counts.get("image", 0),
            overrides=options.get("image"),
        )
        dummy_videos = self._get_dummy_videos(
            width=width,
            height=height,
            num_frames=frame_count,
            num_videos=mm_counts.get("video", 0),
            overrides=options.get("video"),
        )
        return {"image": dummy_images, "video": dummy_videos}

1132

1133
class Qwen2VLMultiModalProcessor(BaseMultiModalProcessor[Qwen2VLProcessingInfo]):
    """Multimodal processor that expands Qwen2-VL image/video placeholders."""

    def _get_data_parser(self) -> MultiModalDataParser:
        vision_config = self.info.get_hf_config().vision_config
        return Qwen2VLMultiModalDataParser(vision_config.spatial_merge_size)

    def _get_prompt_updates(
        self,
        mm_items: MultiModalDataItems,
        hf_processor_mm_kwargs: Mapping[str, Any],
        out_mm_kwargs: MultiModalKwargsItems,
    ) -> Sequence[PromptUpdate]:
        """Replace each single placeholder token with one token per merged
        patch.

        An item with grid ``(t, h, w)`` expands to
        ``t * h * w // merge_size**2`` copies of its modality's placeholder.
        """
        hf_processor = self.info.get_hf_processor(**hf_processor_mm_kwargs)
        image_processor = self.info.get_image_processor(**hf_processor_mm_kwargs)
        vocab = self.info.get_tokenizer().get_vocab()

        placeholder_ids = {
            "image": vocab[hf_processor.image_token],
            "video": vocab[hf_processor.video_token],
        }
        tokens_per_merge = image_processor.merge_size**2

        def _expand_placeholder(item_idx: int, modality: str):
            item_kwargs = out_mm_kwargs[modality][item_idx]
            grid_thw = item_kwargs[f"{modality}_grid_thw"].data
            assert isinstance(grid_thw, torch.Tensor)
            token_count = int(grid_thw.prod()) // tokens_per_merge
            return [placeholder_ids[modality]] * token_count

        updates = []
        for modality in ("image", "video"):
            updates.append(
                PromptReplacement(
                    modality=modality,
                    target=[placeholder_ids[modality]],
                    replacement=partial(_expand_placeholder, modality=modality),
                )
            )
        return updates

    def _get_mm_fields_config(
        self,
        hf_inputs: BatchFeature,
        hf_processor_mm_kwargs: Mapping[str, object],
    ) -> Mapping[str, MultiModalFieldConfig]:
        merge_size = self.info.get_hf_config().vision_config.spatial_merge_size
        field_factory = _create_qwen2vl_field_factory(merge_size)
        return field_factory(hf_inputs)


@MULTIMODAL_REGISTRY.register_processor(
    Qwen2VLMultiModalProcessor,
    info=Qwen2VLProcessingInfo,
    dummy_inputs=Qwen2VLDummyInputsBuilder,
)
class Qwen2VLForConditionalGeneration(
    nn.Module, SupportsMultiModal, SupportsLoRA, SupportsPP, SupportsMRoPE
):
    """Qwen2-VL: ``Qwen2VisionTransformer`` encoder feeding ``Qwen2ForCausalLM``.

    The vision tower is only built when the multimodal config allows images
    or videos. Otherwise ``self.visual`` is ``None`` and its weights are
    skipped on load.
    """

    # To ensure correct weight loading and mapping.
    hf_to_vllm_mapper = WeightsMapper(
        orig_to_new_prefix={
            # mapping for new names in checkpoint saved after transformers v4.52
            "model.language_model.": "language_model.model.",
            "model.visual.": "visual.",
            # mapping for original checkpoint
            "lm_head.": "language_model.lm_head.",
            "model.": "language_model.model.",
        }
    )

    supports_encoder_tp_data = True

    def get_mrope_input_positions(
        self,
        input_tokens: list[int],
        hf_config: PretrainedConfig,
        image_grid_thw: list[list[int]] | torch.Tensor | None,
        video_grid_thw: list[list[int]] | torch.Tensor | None,
        second_per_grid_ts: list[float] | None = None,
        context_len: int = 0,
        seq_len: int | None = None,
        audio_feature_lengths: torch.Tensor | None = None,
        use_audio_in_video: bool = False,
    ) -> tuple[torch.Tensor, int]:
        """Get M-RoPE input positions for Qwen2-VL model.

        Text tokens advance all three position rows (t, h, w) together. Each
        vision item gets a 3-D grid of positions built from its ``grid_thw``,
        with ``h`` and ``w`` divided by ``spatial_merge_size``. The grid is
        offset by the preceding text. Video temporal positions are scaled by
        ``second_per_grid_ts`` (default 1.0) and ``tokens_per_second``; image
        temporal positions are scaled by 0.

        ``audio_feature_lengths`` and ``use_audio_in_video`` are accepted for
        interface compatibility and are not used here.

        Returns:
            ``(positions, delta)``. ``positions`` has 3 rows and is sliced to
            columns ``[context_len:seq_len]``. ``delta`` is
            ``max_position + 1 - len(input_tokens)``.
        """
        if image_grid_thw is None:
            image_grid_thw = []
        if video_grid_thw is None:
            video_grid_thw = []
        if second_per_grid_ts is None:
            second_per_grid_ts = []

        image_token_id = hf_config.image_token_id
        video_token_id = hf_config.video_token_id
        vision_start_token_id = hf_config.vision_start_token_id
        spatial_merge_size = hf_config.vision_config.spatial_merge_size
        tokens_per_second = getattr(hf_config.vision_config, "tokens_per_second", 1.0)

        input_tokens_tensor = torch.tensor(input_tokens)
        vision_start_indices = torch.argwhere(
            input_tokens_tensor == vision_start_token_id
        ).squeeze(1)
        # The token right after each vision-start marker tells which modality
        # the span holds.
        vision_tokens = input_tokens_tensor[vision_start_indices + 1]
        image_nums = (vision_tokens == image_token_id).sum()
        video_nums = (vision_tokens == video_token_id).sum()
        llm_pos_ids_list: list = []

        st = 0
        remain_images, remain_videos = image_nums, video_nums

        image_index, video_index = 0, 0
        for _ in range(image_nums + video_nums):
            # Images keep a temporal scale of 0, so their t positions are 0.
            video_second_per_grid_t = 0.0
            # `len(input_tokens) + 1` is a sentinel past any real index, so the
            # other modality is picked when this one has no occurrence left.
            if remain_images > 0:
                try:
                    ed_image = input_tokens.index(image_token_id, st)
                except ValueError:
                    ed_image = len(input_tokens) + 1
            else:
                ed_image = len(input_tokens) + 1
            if remain_videos > 0:
                try:
                    ed_video = input_tokens.index(video_token_id, st)
                except ValueError:
                    ed_video = len(input_tokens) + 1
            else:
                ed_video = len(input_tokens) + 1
            if ed_image < ed_video:
                t, h, w = (
                    image_grid_thw[image_index][0],
                    image_grid_thw[image_index][1],
                    image_grid_thw[image_index][2],
                )
                image_index += 1
                remain_images -= 1
                ed = ed_image
            else:
                t, h, w = (
                    video_grid_thw[video_index][0],
                    video_grid_thw[video_index][1],
                    video_grid_thw[video_index][2],
                )
                video_second_per_grid_t = 1.0
                if second_per_grid_ts:
                    video_second_per_grid_t = second_per_grid_ts[video_index]
                video_index += 1
                remain_videos -= 1
                ed = ed_video

            llm_grid_t, llm_grid_h, llm_grid_w = (
                t,
                h // spatial_merge_size,
                w // spatial_merge_size,
            )
            text_len = ed - st

            # Text preceding this vision item: identical positions on all rows.
            st_idx = llm_pos_ids_list[-1].max() + 1 if len(llm_pos_ids_list) > 0 else 0
            llm_pos_ids_list.append(
                torch.arange(text_len).view(1, -1).expand(3, -1) + st_idx
            )

            t_index = (
                (
                    torch.arange(llm_grid_t)
                    .view(-1, 1)
                    .expand(-1, llm_grid_h * llm_grid_w)
                    * video_second_per_grid_t
                    * tokens_per_second
                )
                .long()
                .flatten()
            )
            h_index = (
                torch.arange(llm_grid_h)
                .view(1, -1, 1)
                .expand(llm_grid_t, -1, llm_grid_w)
                .flatten()
            )
            w_index = (
                torch.arange(llm_grid_w)
                .view(1, 1, -1)
                .expand(llm_grid_t, llm_grid_h, -1)
                .flatten()
            )
            llm_pos_ids_list.append(
                torch.stack([t_index, h_index, w_index]) + text_len + st_idx
            )
            # Skip past this item's placeholder tokens (one per merged patch).
            st = ed + llm_grid_t * llm_grid_h * llm_grid_w

        if st < len(input_tokens):
            st_idx = llm_pos_ids_list[-1].max() + 1 if len(llm_pos_ids_list) > 0 else 0
            text_len = len(input_tokens) - st
            llm_pos_ids_list.append(
                torch.arange(text_len).view(1, -1).expand(3, -1) + st_idx
            )

        llm_positions = torch.cat(llm_pos_ids_list, dim=1).reshape(3, -1)
        mrope_position_delta = (llm_positions.max() + 1 - len(input_tokens)).item()
        llm_positions = llm_positions[:, context_len:seq_len]

        return llm_positions, mrope_position_delta

    @classmethod
    def get_placeholder_str(cls, modality: str, i: int) -> str | None:
        """Return the prompt placeholder for an image or video item.

        Raises:
            ValueError: For any modality other than image or video.
        """
        if modality.startswith("image"):
            return "<|vision_start|><|image_pad|><|vision_end|>"
        if modality.startswith("video"):
            return "<|vision_start|><|video_pad|><|vision_end|>"

        raise ValueError("Only image or video modality is supported")

    def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
        super().__init__()
        config: Qwen2VLConfig = vllm_config.model_config.hf_config
        quant_config = vllm_config.quant_config
        multimodal_config = vllm_config.model_config.multimodal_config

        self.use_data_parallel = multimodal_config.mm_encoder_tp_mode == "data"

        self.config = config
        self.multimodal_config = multimodal_config

        # Only build the vision tower if images or videos may appear.
        if multimodal_config.get_limit_per_prompt(
            "image"
        ) or multimodal_config.get_limit_per_prompt("video"):
            self.visual = Qwen2VisionTransformer(
                config.vision_config,
                norm_eps=getattr(config, "rms_norm_eps", 1e-6),
                quant_config=quant_config,
                prefix=maybe_prefix(prefix, "visual"),
                use_data_parallel=self.use_data_parallel,
            )
        else:
            self.visual = None

        self.language_model = init_vllm_registered_model(
            vllm_config=vllm_config,
            prefix=maybe_prefix(prefix, "language_model"),
            architectures=["Qwen2ForCausalLM"],
        )

        self.make_empty_intermediate_tensors = (
            self.language_model.make_empty_intermediate_tensors
        )

    def _validate_and_reshape_mm_tensor(
        self, mm_input: object, name: str
    ) -> torch.Tensor:
        """Flatten a multimodal tensor (or list of tensors) to 2-D.

        A 2-D tensor is returned unchanged. A 3-D tensor has its leading two
        dims merged. A list is concatenated along dim 0.

        Raises:
            ValueError: If ``mm_input`` is neither a tensor nor a list, or is a
                tensor whose rank is neither 2 nor 3.
        """
        if not isinstance(mm_input, (torch.Tensor, list)):
            raise ValueError(f"Incorrect type of {name}. Got type: {type(mm_input)}")
        if isinstance(mm_input, torch.Tensor):
            if mm_input.ndim == 2:
                return mm_input
            if mm_input.ndim != 3:
                raise ValueError(
                    f"{name} should be 2D or batched 3D tensor. "
                    f"Got ndim: {mm_input.ndim} "
                    f"(shape={mm_input.shape})"
                )
            return mm_input.reshape(-1, mm_input.shape[-1])
        else:
            return torch.concat(mm_input)

    def _parse_and_validate_image_input(
        self, **kwargs: object
    ) -> Qwen2VLImageInputs | None:
        """Build image inputs from kwargs; pixel values take precedence over
        precomputed embeddings. Returns ``None`` when neither is present."""
        pixel_values = kwargs.pop("pixel_values", None)
        image_embeds = kwargs.pop("image_embeds", None)
        image_grid_thw = kwargs.pop("image_grid_thw", None)

        if pixel_values is None and image_embeds is None:
            return None

        if pixel_values is not None:
            pixel_values = self._validate_and_reshape_mm_tensor(
                pixel_values, "image pixel values"
            )
            image_grid_thw = self._validate_and_reshape_mm_tensor(
                image_grid_thw, "image grid_thw"
            )

            return Qwen2VLImagePixelInputs(
                type="pixel_values",
                pixel_values=pixel_values,
                image_grid_thw=image_grid_thw,
            )

        if image_embeds is not None:
            image_embeds = self._validate_and_reshape_mm_tensor(
                image_embeds, "image embeds"
            )
            image_grid_thw = self._validate_and_reshape_mm_tensor(
                image_grid_thw, "image grid_thw"
            )

            return Qwen2VLImageEmbeddingInputs(
                type="image_embeds",
                image_embeds=image_embeds,
                image_grid_thw=image_grid_thw,
            )

    def _parse_and_validate_video_input(
        self, **kwargs: object
    ) -> Qwen2VLVideoInputs | None:
        """Build video inputs from kwargs; pixel values take precedence over
        precomputed embeddings. Returns ``None`` when neither is present."""
        pixel_values_videos = kwargs.pop("pixel_values_videos", None)
        video_embeds = kwargs.pop("video_embeds", None)
        video_grid_thw = kwargs.pop("video_grid_thw", None)

        if pixel_values_videos is None and video_embeds is None:
            return None

        if pixel_values_videos is not None:
            pixel_values_videos = self._validate_and_reshape_mm_tensor(
                pixel_values_videos, "video pixel values"
            )
            video_grid_thw = self._validate_and_reshape_mm_tensor(
                video_grid_thw, "video grid_thw"
            )

            return Qwen2VLVideoPixelInputs(
                type="pixel_values_videos",
                pixel_values_videos=pixel_values_videos,
                video_grid_thw=video_grid_thw,
            )

        if video_embeds is not None:
            video_embeds = self._validate_and_reshape_mm_tensor(
                video_embeds, "video embeds"
            )
            video_grid_thw = self._validate_and_reshape_mm_tensor(
                video_grid_thw, "video grid_thw"
            )

            return Qwen2VLVideoEmbeddingInputs(
                type="video_embeds",
                video_embeds=video_embeds,
                video_grid_thw=video_grid_thw,
            )

    def _split_by_grid(
        self, embeds: torch.Tensor, grid_thw_list: list[list[int]]
    ) -> tuple[torch.Tensor, ...]:
        """Split concatenated vision embeddings into one tensor per item.

        Item ``i`` contributes ``prod(grid_thw_list[i]) // merge_size**2`` rows.
        """
        merge_size = self.visual.spatial_merge_size
        sizes = (
            torch.tensor(grid_thw_list, dtype=torch.long).prod(-1)
            // (merge_size * merge_size)
        ).tolist()
        return embeds.split(sizes)

    def _process_image_input(
        self, image_input: Qwen2VLImageInputs
    ) -> tuple[torch.Tensor, ...]:
        """Return per-image embeddings, running the vision tower if needed."""
        grid_thw = image_input["image_grid_thw"]
        assert grid_thw.ndim == 2
        grid_thw_list = grid_thw.tolist()

        if image_input["type"] == "image_embeds":
            image_embeds = image_input["image_embeds"]
        else:
            pixel_values = image_input["pixel_values"]

            if self.use_data_parallel:
                # The data-parallel runner's result is returned as-is, without
                # the split below.
                return run_dp_sharded_mrope_vision_model(
                    self.visual, pixel_values, grid_thw_list, rope_type="rope_3d"
                )
            else:
                image_embeds = self.visual(pixel_values, grid_thw=grid_thw_list)

        # Split concatenated embeddings for each image item.
        return self._split_by_grid(image_embeds, grid_thw_list)

    def _process_video_input(
        self, video_input: Qwen2VLVideoInputs
    ) -> tuple[torch.Tensor, ...]:
        """Return per-video embeddings, running the vision tower if needed."""
        grid_thw = video_input["video_grid_thw"]
        assert grid_thw.ndim == 2
        grid_thw_list = grid_thw.tolist()

        if video_input["type"] == "video_embeds":
            video_embeds = video_input["video_embeds"]
        else:
            pixel_values_videos = video_input["pixel_values_videos"]
            if self.use_data_parallel:
                # The data-parallel runner's result is returned as-is, without
                # the split below.
                return run_dp_sharded_mrope_vision_model(
                    self.visual, pixel_values_videos, grid_thw_list, rope_type="rope_3d"
                )
            else:
                video_embeds = self.visual(pixel_values_videos, grid_thw=grid_thw_list)

        # Split concatenated embeddings for each video item.
        return self._split_by_grid(video_embeds, grid_thw_list)

    def _parse_and_validate_multimodal_inputs(self, **kwargs: object) -> dict:
        """Parse image/video inputs, keyed "images"/"videos" in kwargs order."""
        modalities = {}

        # Preserve the order of modalities if there are multiple of them
        # from the order of kwargs.
        for input_key in kwargs:
            if (
                input_key in ("pixel_values", "image_embeds")
                and "images" not in modalities
            ):
                modalities["images"] = self._parse_and_validate_image_input(**kwargs)
            if (
                input_key in ("pixel_values_videos", "video_embeds")
                and "videos" not in modalities
            ):
                modalities["videos"] = self._parse_and_validate_video_input(**kwargs)

        return modalities

    def get_language_model(self) -> torch.nn.Module:
        return self.language_model

    def get_multimodal_embeddings(self, **kwargs: object) -> MultiModalEmbeddings:
        """Return one embedding tensor per image/video item, in input order."""
        modalities = self._parse_and_validate_multimodal_inputs(**kwargs)
        if not modalities:
            return []

        # The result multimodal_embeddings is tuple of tensors, with each
        # tensor corresponding to a multimodal data item (image or video).
        multimodal_embeddings: tuple[torch.Tensor, ...] = ()

        # NOTE: It is important to iterate over the keys in this dictionary
        # to preserve the order of the modalities.
        for modality in modalities:
            if modality == "images":
                image_input = modalities["images"]
                image_embeddings = self._process_image_input(image_input)
                multimodal_embeddings += tuple(image_embeddings)
            if modality == "videos":
                video_input = modalities["videos"]
                video_embeddings = self._process_video_input(video_input)
                multimodal_embeddings += tuple(video_embeddings)

        return multimodal_embeddings

    def forward(
        self,
        input_ids: torch.Tensor,
        positions: torch.Tensor,
        intermediate_tensors: IntermediateTensors | None = None,
        inputs_embeds: torch.Tensor | None = None,
        **kwargs: object,
    ) -> torch.Tensor | IntermediateTensors:
        """Run forward pass for Qwen2-VL.

        Args:
            input_ids: Flattened (concatenated) input_ids corresponding to a
                batch.
            positions: Flattened (concatenated) position ids corresponding to a
                batch.
                **NOTE**: If mrope is enabled (default setting for Qwen2-VL
                opensource models), the shape will be `(3, seq_len)`,
                otherwise it will be `(seq_len,)`.
            intermediate_tensors: Intermediate tensors from prior forward pass.
            inputs_embeds: Optional tensor of input embeddings.
        """

        # Non-first pipeline stages consume intermediate tensors, not embeds.
        if intermediate_tensors is not None:
            inputs_embeds = None

        hidden_states = self.language_model.model(
            input_ids=input_ids,
            positions=positions,
            intermediate_tensors=intermediate_tensors,
            inputs_embeds=inputs_embeds,
        )
        return hidden_states

    def compute_logits(
        self,
        hidden_states: torch.Tensor,
    ) -> torch.Tensor | None:
        return self.language_model.compute_logits(hidden_states)

    def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]:
        """Load checkpoint weights; vision weights are skipped when the vision
        tower was not built."""
        skip_prefixes = []
        if self.visual is None:
            skip_prefixes.extend(["visual."])
        loader = AutoWeightsLoader(self, skip_prefixes=skip_prefixes)
        return loader.load_weights(weights, mapper=self.hf_to_vllm_mapper)

    def get_mm_mapping(self) -> MultiModelKeys:
        """
        Get the module prefix in multimodal models
        """
        return MultiModelKeys.from_string_field(
            language_model="language_model",
            connector="visual.merger.",
            tower_model="visual.",
        )
1628
1629
1630
1631
1632
1633
1634
1635
1636


class Tarsier2MultiModalProcessor(Qwen2VLMultiModalProcessor):
    """Multimodal processor for Tarsier2.

    Adds nothing to :class:`Qwen2VLMultiModalProcessor`; it is a distinct class
    so it can be registered for ``Tarsier2ForConditionalGeneration``.
    """


class Tarsier2ImageProcessor(Qwen2VLImageProcessor):
    """Qwen2-VL image processor that also accepts Tarsier2's ``size`` format.

    A ``size`` dict holding both ``min_pixels`` and ``max_pixels`` (the
    Tarsier2-specific format) is remapped to ``shortest_edge`` /
    ``longest_edge`` before being passed to :class:`Qwen2VLImageProcessor`.
    Any other ``size`` value, including ``None``, is forwarded unchanged.
    """

    def __init__(
        self,
        size: dict[str, int] | None = None,
        **kwargs,
    ) -> None:
        if size is not None and "min_pixels" in size and "max_pixels" in size:
            # Build a new dict so the caller's ``size`` is not mutated; only
            # the two pixel bounds are carried over.
            size = {
                "shortest_edge": size["min_pixels"],
                "longest_edge": size["max_pixels"],
            }
        super().__init__(size=size, **kwargs)


class Tarsier2Processor(Qwen2VLProcessor):
    """Qwen2-VL processor assembled from a Tarsier2 image-processor config.

    Args:
        vision_config: Keyword arguments used to build both the
            :class:`Tarsier2ImageProcessor` and the ``Qwen2VLVideoProcessor``.
        tokenizer: Tokenizer passed through to :class:`Qwen2VLProcessor`.
        **kwargs: Extra keyword arguments forwarded to
            :class:`Qwen2VLProcessor`.
    """

    def __init__(
        self,
        vision_config: dict,
        tokenizer: AnyTokenizer,
        **kwargs,
    ):
        # Use the Tarsier2 image processor so its ``size`` format is remapped.
        self.image_processor = Tarsier2ImageProcessor(**vision_config)
        super().__init__(
            image_processor=self.image_processor,
            tokenizer=tokenizer,
            video_processor=Qwen2VLVideoProcessor(**vision_config),
            chat_template=None,
            **kwargs,
        )
1666
1667
1668
1669
1670
1671
1672
1673
1674
1675
1676
1677
1678
1679
1680
1681
1682
1683
1684


class Tarsier2ProcessingInfo(Qwen2VLProcessingInfo):
    """Processing info for Tarsier2 checkpoints.

    The checkpoint config is reloaded and rebuilt as a ``Qwen2VLConfig``. The
    processors it returns are the Tarsier2 variants, so Tarsier2's image
    ``size`` format is accepted.
    """

    def get_hf_config(self) -> Qwen2VLConfig:
        """Load the checkpoint config and rebuild it as a ``Qwen2VLConfig``.

        The config is loaded with ``AutoConfig`` from the model path, using
        the model config's ``revision`` and ``trust_remote_code`` settings.
        It is serialized to a dict and reconstructed with
        ``Qwen2VLConfig.from_dict``.
        """
        model_config = self.ctx.model_config
        # Forward the pinned revision and remote-code policy so the processor
        # reads the same config revision the rest of loading uses.
        original_config = AutoConfig.from_pretrained(
            model_config.model,
            revision=model_config.revision,
            trust_remote_code=model_config.trust_remote_code,
        )
        return Qwen2VLConfig.from_dict(original_config.to_dict())

    def get_hf_processor(self, **kwargs: object) -> Tarsier2Processor:
        """Build a ``Tarsier2Processor`` from the HF image-processor config."""
        return Tarsier2Processor(
            vision_config=self.ctx.get_hf_image_processor_config(),
            tokenizer=self.get_tokenizer(),
            **kwargs,
        )

    def get_image_processor(self) -> Tarsier2ImageProcessor:
        """Build a ``Tarsier2ImageProcessor`` from the HF image-processor config."""
        return Tarsier2ImageProcessor(**self.ctx.get_hf_image_processor_config())
1686
1687


1688
1689
1690
1691
1692
@MULTIMODAL_REGISTRY.register_processor(
    Tarsier2MultiModalProcessor,
    info=Tarsier2ProcessingInfo,
    dummy_inputs=Qwen2VLDummyInputsBuilder,
)
class Tarsier2ForConditionalGeneration(Qwen2VLForConditionalGeneration):
    """Tarsier2: a Qwen2-VL model served from a llava-typed checkpoint.

    ``load_weights`` is inherited from :class:`Qwen2VLForConditionalGeneration`.
    That method reads ``self.hf_to_vllm_mapper``, so the Tarsier2 mapping
    defined below is the one applied.
    """

    # Rename the checkpoint's ``vision_tower.`` prefix to ``visual.``.
    hf_to_vllm_mapper = WeightsMapper(
        orig_to_new_prefix={
            "vision_tower.": "visual.",
        }
    )

    def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
        """Swap in the Qwen2-VL config, then build the Qwen2-VL model.

        Note: this mutates ``vllm_config.model_config.hf_config`` in place.
        """
        # Tarsier2 uses llava as model_type, so the loaded hf_config is a
        # llava-style config whose ``text_config`` holds the Qwen2VLConfig.
        # Promote that text_config to be the hf_config, carrying over the
        # original ``architectures``, before delegating to the parent.
        config = vllm_config.model_config.hf_config
        qwen2vl_config = config.text_config
        qwen2vl_config.architectures = config.architectures
        vllm_config.model_config.hf_config = qwen2vl_config
        super().__init__(vllm_config=vllm_config, prefix=prefix)