qwen_vl.py 24.6 KB
Newer Older
1
# SPDX-License-Identifier: Apache-2.0
2
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
3
4
5
6
7
8
9
10
11

# Adapted from
# https://huggingface.co/Qwen/Qwen-VL/blob/main/modeling_qwen.py
# Copyright (c) Alibaba Cloud.
"""Inference-only Qwen-VL model compatible with HuggingFace weights."""

import copy
import math
import unicodedata
12
from collections.abc import Callable, Collection, Mapping, Sequence, Set
13
from functools import lru_cache, partial
14
from typing import Annotated, Literal, TypeAlias
15

16
import regex as re
17
18
19
20
import torch
from torch import nn
from torchvision import transforms
from torchvision.transforms import InterpolationMode
21
from transformers import BatchFeature, PretrainedConfig, PreTrainedTokenizer, TensorType
22
23
24
25
from transformers.image_utils import ImageInput
from transformers.tokenization_utils_base import TextInput

from vllm.config import VllmConfig
26
from vllm.config.multimodal import BaseDummyOptions
27
from vllm.model_executor.layers.activation import get_act_fn
28
29
30
31
32
from vllm.model_executor.layers.linear import (
    ColumnParallelLinear,
    ReplicatedLinear,
    RowParallelLinear,
)
33
34
35
36
from vllm.model_executor.layers.quantization import QuantizationConfig
from vllm.model_executor.layers.resampler import Resampler2, get_abs_pos
from vllm.model_executor.models.module_mapping import MultiModelKeys
from vllm.multimodal import MULTIMODAL_REGISTRY
37
38
39
40
41
from vllm.multimodal.inputs import (
    MultiModalDataDict,
    MultiModalFieldConfig,
    MultiModalKwargsItems,
)
42
from vllm.multimodal.parse import MultiModalDataItems
43
44
45
46
47
48
49
from vllm.multimodal.processing import (
    BaseMultiModalProcessor,
    BaseProcessingInfo,
    PromptReplacement,
    PromptUpdate,
    PromptUpdateDetails,
)
50
from vllm.multimodal.profiling import BaseDummyInputsBuilder
51
from vllm.sequence import IntermediateTensors
52
from vllm.utils.tensor_schema import TensorSchema, TensorShape
53

54
55
56
57
58
59
from .interfaces import (
    MultiModalEmbeddings,
    SupportsLoRA,
    SupportsMultiModal,
    SupportsPP,
)
60
61
62
from .qwen import QWenBaseModel, QWenModel


63
class QwenImagePixelInputs(TensorSchema):
    """
    Raw pixel inputs for the Qwen-VL vision tower.

    Dimensions:
        - bn: Batch size * number of images
        - c: Number of channels (3)
        - h: Height
        - w: Width

    `h` and `w` are bound to `image_size` from the vision config when the
    schema is instantiated; that is also the square size `QwenVLProcessor`
    resizes every image to before normalization. Multiple images are stacked
    along the `bn` dimension.
    """

    type: Literal["pixel_values"] = "pixel_values"
    data: Annotated[torch.Tensor, TensorShape("bn", 3, "h", "w")]
78
79


80
81
82
83
84
85
class QwenImageEmbeddingInputs(TensorSchema):
    """
    Precomputed image embeddings that skip the vision tower entirely.

    Dimensions:
        - bn: Batch size * number of images
        - ifs: Image feature size (256)
        - hs: Hidden size

    The trailing `hs` size has to agree with the hidden size of the language
    model backbone, because these tensors are handed back unchanged as the
    image embeddings. When the model has a visual config, the matching size
    is recorded there.
    """

    type: Literal["image_embeds"] = "image_embeds"
    data: Annotated[torch.Tensor, TensorShape("bn", 256, "hs")]
93
94


95
# Image input accepted by the model: either raw pixels (encoded by the vision
# tower) or precomputed embeddings. Consumers dispatch on the `type` field.
QwenImageInputs: TypeAlias = QwenImagePixelInputs | QwenImageEmbeddingInputs
96
97
98
99
100
101
102
103
104
105
106
107
108


class VisualAttention(nn.Module):
    """Multi-head self-attention layer used by the Qwen-VL vision tower.

    Takes input of size [s, b, h] (sequence, batch, hidden) and returns an
    output of the same size. Only self-attention is supported, so `kdim` and
    `vdim` (when given) must equal `embed_dim`.

    Args:
        embed_dim: Hidden size ``h`` of the input and output.
        num_heads: Number of attention heads; must divide ``embed_dim``.
        bias: Whether the QKV input projection and the output projection
            carry a bias term.
        kdim: Key dimension; defaults to ``embed_dim``.
        vdim: Value dimension; defaults to ``embed_dim``.
    """

    def __init__(
        self,
        embed_dim: int,
        num_heads: int,
        bias: bool = True,
        kdim: int | None = None,
        vdim: int | None = None,
    ):
        super().__init__()
        self.embed_dim = embed_dim
        self.kdim = kdim if kdim is not None else embed_dim
        self.vdim = vdim if vdim is not None else embed_dim
        self._qkv_same_embed_dim = self.kdim == embed_dim and self.vdim == embed_dim

        self.num_heads = num_heads

        # Per attention head and per partition values.
        assert embed_dim % num_heads == 0
        self.hidden_size_per_attention_head = embed_dim // num_heads
        self.num_attention_heads_per_partition = num_heads
        self.hidden_size_per_partition = embed_dim

        # Strided linear layer.
        assert self._qkv_same_embed_dim, (
            "Visual Attention implementation only supports self-attention"
        )
        # Honour `bias` for both projections; the default (True) keeps the
        # biased projections the checkpoints are stored with.
        self.in_proj = ReplicatedLinear(embed_dim, 3 * embed_dim, bias=bias)
        self.out_proj = ReplicatedLinear(embed_dim, embed_dim, bias=bias)
        self.norm_factor = math.sqrt(self.hidden_size_per_attention_head)

    def forward(
        self,
        x: torch.Tensor,
        attn_mask: torch.Tensor | None = None,
    ) -> torch.Tensor:
        """Apply self-attention to ``x`` of shape [sq, b, h].

        Args:
            x: Sequence-first input tensor.
            attn_mask: Optional additive mask, added to the raw attention
                scores before the softmax.

        Returns:
            Tensor of shape [sq, b, h].
        """
        # query/key/value: [sq, b, h]
        sq, b, _ = x.size()
        mixed_x_layer, _ = self.in_proj(x)

        # [sq, b, (np * 3 * hn)] --> [sq, b, np, 3 * hn]
        new_tensor_shape = mixed_x_layer.size()[:-1] + (
            self.num_attention_heads_per_partition,
            3 * self.hidden_size_per_attention_head,
        )
        mixed_x_layer = mixed_x_layer.view(*new_tensor_shape)

        # [sq, b, np, 3 * hn] --> 3 [sq, b, np, hn]
        query_layer, key_layer, value_layer = mixed_x_layer.split(
            self.hidden_size_per_attention_head, dim=-1
        )

        # [sq, b, np, hn] -> [sq, b * np, hn] -> [b * np, sq, hn]
        query_layer = query_layer.view(
            sq,
            b * self.num_attention_heads_per_partition,
            self.hidden_size_per_attention_head,
        ).transpose(0, 1)
        # [sk, b, np, hn] -> [sk, b * np, hn] -> [b * np, sk, hn]
        key_layer = key_layer.view(
            sq,
            b * self.num_attention_heads_per_partition,
            self.hidden_size_per_attention_head,
        ).transpose(0, 1)

        # Scores = (q / sqrt(hn)) @ k^T; with a mask, baddbmm computes
        # mask + (q / sqrt(hn)) @ k^T.
        q_scaled = query_layer / self.norm_factor
        if attn_mask is not None:
            attention_probs = torch.baddbmm(
                attn_mask, q_scaled, key_layer.transpose(-2, -1)
            )
        else:
            attention_probs = torch.bmm(q_scaled, key_layer.transpose(-2, -1))
        attention_probs = attention_probs.softmax(dim=-1)

        value_layer = value_layer.view(
            sq,
            b * self.num_attention_heads_per_partition,
            self.hidden_size_per_attention_head,
        ).transpose(0, 1)

        # matmul: [b * np, sq, hn]
        context_layer = torch.bmm(attention_probs, value_layer)

        # change view [b, np, sq, hn]
        context_layer = context_layer.view(
            b,
            self.num_attention_heads_per_partition,
            sq,
            self.hidden_size_per_attention_head,
        )

        # [b, np, sq, hn] --> [sq, b, np, hn]
        context_layer = context_layer.permute(2, 0, 1, 3).contiguous()

        # [sq, b, np, hn] --> [sq, b, hp]
        new_context_layer_shape = context_layer.size()[:-2] + (
            self.hidden_size_per_partition,
        )
        context_layer = context_layer.view(*new_context_layer_shape)

        output, _ = self.out_proj(context_layer)

        return output


class QwenVLMLP(nn.Module):
    """Two-layer GELU feed-forward network of a Qwen-VL vision block.

    Computes ``c_proj(gelu(c_fc(x)))``. `c_fc` is a `ColumnParallelLinear`
    (hidden -> intermediate) and `c_proj` a `RowParallelLinear`
    (intermediate -> hidden); both carry biases and use `quant_config`.
    """

    def __init__(
        self,
        hidden_size: int,
        intermediate_size: int,
        quant_config: QuantizationConfig | None = None,
    ):
        super().__init__()
        # Up-projection into the intermediate width.
        self.c_fc = ColumnParallelLinear(
            hidden_size,
            intermediate_size,
            bias=True,
            quant_config=quant_config,
        )
        self.act_fn = get_act_fn("gelu")
        # Down-projection back to the model width.
        self.c_proj = RowParallelLinear(
            intermediate_size,
            hidden_size,
            bias=True,
            quant_config=quant_config,
        )

    def forward(self, x):
        # Each parallel linear returns (output, bias); only the output is used.
        expanded, _unused_bias = self.c_fc(x)
        activated = self.act_fn(expanded)
        projected, _unused_bias = self.c_proj(activated)
        return projected


class VisualAttentionBlock(nn.Module):
    """Pre-norm vision transformer layer.

    Computes ``x + attn(ln_1(x))`` followed by ``x + mlp(ln_2(x))``, where the
    MLP width is ``int(d_model * mlp_ratio)``.
    """

    def __init__(
        self,
        d_model: int,
        n_head: int,
        mlp_ratio: float = 4.0,
        norm_layer: Callable[[int], nn.Module] = nn.LayerNorm,
        quant_config: QuantizationConfig | None = None,
    ):
        super().__init__()

        self.ln_1 = norm_layer(d_model)
        self.ln_2 = norm_layer(d_model)
        self.attn = VisualAttention(d_model, n_head)
        self.mlp = QwenVLMLP(
            hidden_size=d_model,
            intermediate_size=int(d_model * mlp_ratio),
            quant_config=quant_config,
        )

    def attention(
        self,
        x: torch.Tensor,
        attn_mask: torch.Tensor | None = None,
    ) -> torch.Tensor:
        # Bring the mask to the activation dtype before it is added to scores.
        if attn_mask is not None:
            attn_mask = attn_mask.to(x.dtype)
        return self.attn(x, attn_mask=attn_mask)

    def forward(
        self,
        x: torch.Tensor,
        attn_mask: torch.Tensor | None = None,
    ) -> torch.Tensor:
        attn_residual = self.attention(self.ln_1(x), attn_mask=attn_mask)
        x = x + attn_residual
        mlp_residual = self.mlp(self.ln_2(x))
        return x + mlp_residual


class TransformerBlock(nn.Module):
    """Stack of `layers` `VisualAttentionBlock`s applied sequentially.

    Args:
        width: Hidden size of every block.
        layers: Number of blocks.
        heads: Attention heads per block.
        mlp_ratio: MLP width multiplier per block.
        norm_layer: Factory for the per-block normalization layers.
        quant_config: Quantization config forwarded to each block's MLP.
    """

    def __init__(
        self,
        width: int,
        layers: int,
        heads: int,
        mlp_ratio: float = 4.0,
        norm_layer: Callable[[int], nn.Module] = nn.LayerNorm,
        quant_config: QuantizationConfig | None = None,
    ):
        super().__init__()
        self.width = width
        self.layers = layers

        self.resblocks = nn.ModuleList(
            [
                VisualAttentionBlock(
                    width,
                    heads,
                    mlp_ratio,
                    norm_layer=norm_layer,
                    quant_config=quant_config,
                )
                for _ in range(layers)
            ]
        )

    def _reference_weight(self) -> torch.Tensor:
        # Use the first block's LayerNorm weight as the dtype/device reference.
        # The MLP's `c_fc` is built with `quant_config`, so depending on the
        # quantization method its `weight` may be absent or stored in a
        # quantized dtype; the norm layer never receives `quant_config`.
        return self.resblocks[0].ln_1.weight

    def get_cast_dtype(self) -> torch.dtype:
        """Return the (unquantized) parameter dtype inputs should be cast to."""
        return self._reference_weight().dtype

    def get_cast_device(self) -> torch.device:
        """Return the device holding this stack's parameters."""
        return self._reference_weight().device

    def forward(
        self, x: torch.Tensor, attn_mask: torch.Tensor | None = None
    ) -> torch.Tensor:
        """Run ``x`` through every block in order with the same mask."""
        for r in self.resblocks:
            x = r(x, attn_mask=attn_mask)
        return x


class VisionTransformer(nn.Module):
    """Qwen-VL vision encoder.

    Pipeline: patchify with `conv1` -> add positional embedding -> `ln_pre`
    -> `transformer` stack -> `attn_pool` resampler -> `ln_post` -> `proj`.
    Extra keyword arguments (e.g. unused keys from the HF visual config) are
    accepted and ignored.
    """

    def __init__(
        self,
        image_size: int,
        patch_size: int,
        width: int,
        layers: int,
        heads: int,
        mlp_ratio: float,
        n_queries: int = 256,
        output_dim: int = 512,
        image_start_id: int = 151857,
        quant_config: QuantizationConfig | None = None,
        **kwargs,
    ):
        super().__init__()
        self.image_size = (image_size, image_size)
        self.patch_size = (patch_size, patch_size)
        patches_per_side = image_size // patch_size
        self.grid_size = (patches_per_side, patches_per_side)
        self.output_dim = output_dim

        # Patch embedding: kernel == stride == patch_size, no bias.
        self.conv1 = nn.Conv2d(
            in_channels=3,
            out_channels=width,
            kernel_size=patch_size,
            stride=patch_size,
            bias=False,
        )

        # Learned positional table with 256 rows; forward passes it through
        # get_abs_pos together with the actual patch-grid side length.
        pos_scale = width**-0.5
        self.positional_embedding = nn.Parameter(pos_scale * torch.randn(256, width))

        ln_eps6 = partial(nn.LayerNorm, eps=1e-6)

        self.ln_pre = ln_eps6(width)
        self.transformer = TransformerBlock(
            width,
            layers,
            heads,
            mlp_ratio,
            norm_layer=ln_eps6,
            quant_config=quant_config,
        )

        # Cross-attention pooling down to int(sqrt(n_queries)) ** 2 queries
        # per side-grid, created with the positional table's device/dtype.
        self.attn_pool = Resampler2(
            grid_size=int(math.sqrt(n_queries)),
            embed_dim=output_dim,
            num_heads=output_dim // 128,
            kv_dim=width,
            norm_layer=ln_eps6,
            adaptive=False,
            do_post_projection=False,
        ).to(
            device=self.positional_embedding.device,
            dtype=self.positional_embedding.dtype,
        )

        self.ln_post = ln_eps6(output_dim)
        proj_scale = output_dim**-0.5
        self.proj = nn.Parameter(proj_scale * torch.randn(output_dim, output_dim))

        # Consecutive special-token ids: start, end, pad.
        self.image_start_id = image_start_id
        self.image_end_id = image_start_id + 1
        self.image_pad_id = image_start_id + 2

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # Match the transformer's parameter dtype and device first.
        pixels = x.to(
            dtype=self.transformer.get_cast_dtype(),
            device=self.transformer.get_cast_device(),
        )

        # [*, 3, H, W] -> [*, width, grid, grid] -> [*, grid ** 2, width]
        tokens = self.conv1(pixels).flatten(2).transpose(1, 2)

        side = int(math.sqrt(tokens.size(1)))
        tokens = tokens + get_abs_pos(self.positional_embedding, side)
        tokens = self.ln_pre(tokens)

        # The transformer stack works sequence-first: NLD -> LND -> NLD.
        tokens = self.transformer(tokens.transpose(0, 1)).transpose(0, 1)

        pooled = self.ln_post(self.attn_pool(tokens))
        return pooled @ self.proj


class QwenVLModel(QWenModel):
    """QWen language model with the Qwen-VL vision tower attached as `visual`."""

    def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
        super().__init__(vllm_config=vllm_config, prefix=prefix)

        hf_config = vllm_config.model_config.hf_config
        # The vision tower is configured straight from the `visual` section
        # of the HF config, plus the model-wide quantization config.
        self.visual = VisionTransformer(
            **hf_config.visual,
            quant_config=vllm_config.quant_config,
        )
416
417
418
419


@lru_cache(maxsize=1)
def _get_tokenizer_without_image_pad(
    tokenizer: PreTrainedTokenizer,
) -> PreTrainedTokenizer:
    """
    Return a copy of `tokenizer` whose `tokenize`/`_decode` skip image padding.

    Image pad tokens should only be inserted by
    [`QwenVLProcessor`][vllm.model_executor.models.qwen_vl.QwenVLProcessor],
    so the wrapped tokenizer's own handling is overridden here. The argument
    is deep-copied and left unchanged; only the copy's class is swapped for a
    dynamically created subclass. The most recent result is memoized.

    The definition of the wrapped tokenizer can be found here:
    https://huggingface.co/Qwen/Qwen-VL/blob/main/tokenization_qwen.py
    """
    patched = copy.deepcopy(tokenizer)

    class TokenizerWithoutImagePad(tokenizer.__class__):  # type: ignore
        def tokenize(
            self,
            text: str,
            allowed_special: Set[str] | str = "all",
            disallowed_special: Collection[str] | str = (),
            **kwargs,
        ) -> list[bytes | str]:
            # NFC-normalize, encode with the inner tokenizer, then map every
            # id back to its token via `self.decoder`.
            normalized = unicodedata.normalize("NFC", text)
            encoded_ids = self.tokenizer.encode(
                normalized,
                allowed_special=allowed_special,
                disallowed_special=disallowed_special,
            )
            return [self.decoder[token_id] for token_id in encoded_ids]

        def _decode(
            self,
            token_ids: int | list[int],
            skip_special_tokens: bool = False,
            errors: str | None = None,
            **kwargs,
        ) -> str:
            # NOTE(review): `skip_special_tokens` is accepted for signature
            # compatibility but not applied here — confirm this is intended.
            id_list = [token_ids] if isinstance(token_ids, int) else token_ids
            return self.tokenizer.decode(
                id_list,
                errors=errors or self.errors,
            )

    TokenizerWithoutImagePad.__name__ = f"{tokenizer.__class__.__name__}WithoutImagePad"

    patched.__class__ = TokenizerWithoutImagePad
    return patched


class QwenVLProcessor:
    """
    Minimal HF-style processor for Qwen-VL, which ships without its own.

    Text goes through the wrapped tokenizer, which automatically inserts
    image pad tokens:
    https://huggingface.co/Qwen/Qwen-VL/blob/main/tokenization_qwen.py#L245

    Images are resized (bicubic) to the configured square `image_size`,
    converted to tensors and normalized, mirroring the original image
    processor:
    https://huggingface.co/Qwen/Qwen-VL/blob/main/visual.py#L354
    """

    def __init__(
        self,
        config: PretrainedConfig,
        tokenizer: PreTrainedTokenizer,
    ) -> None:
        super().__init__()

        self.config = config
        self.tokenizer = tokenizer

        side = config.visual["image_size"]

        resize = transforms.Resize(
            (side, side),
            interpolation=InterpolationMode.BICUBIC,
        )
        to_tensor = transforms.ToTensor()
        normalize = transforms.Normalize(
            mean=(0.48145466, 0.4578275, 0.40821073),
            std=(0.26862954, 0.26130258, 0.27577711),
        )
        self.image_transform = transforms.Compose([resize, to_tensor, normalize])

    @property
    def image_start_tag(self) -> str:
        """Opening image tag string, taken from the tokenizer."""
        return self.tokenizer.image_start_tag  # type: ignore

    @property
    def image_end_tag(self) -> str:
        """Closing image tag string, taken from the tokenizer."""
        return self.tokenizer.image_end_tag  # type: ignore

    @property
    def image_pad_tag(self) -> str:
        """Image padding tag string, taken from the tokenizer."""
        return self.tokenizer.image_pad_tag  # type: ignore

    def __call__(
        self,
        text: TextInput | list[TextInput] | None = None,
        images: ImageInput | list[ImageInput] | None = None,
        return_tensors: str | TensorType | None = None,
    ) -> BatchFeature:
        """Tokenize `text` and transform `images` into one `BatchFeature`.

        Single items are wrapped into lists; `None` becomes an empty list.
        `pixel_values` is included only when at least one image is given.
        """
        texts = [] if text is None else text
        if not isinstance(texts, list):
            texts = [texts]

        image_list = [] if images is None else images
        if not isinstance(image_list, list):
            image_list = [image_list]

        text_inputs = self.tokenizer(texts)

        image_inputs = {}
        if image_list:
            transformed = [self.image_transform(img) for img in image_list]
            image_inputs["pixel_values"] = torch.stack(transformed)

        merged = {**text_inputs, **image_inputs}
        return BatchFeature(merged, tensor_type=return_tensors)


class QwenVLProcessingInfo(BaseProcessingInfo):
    """Processing metadata for Qwen-VL (tokenizer, processor, token counts)."""

    def get_tokenizer(self) -> PreTrainedTokenizer:
        """Return the context tokenizer with image-pad insertion patched out."""
        tokenizer = self.ctx.tokenizer
        assert isinstance(tokenizer, PreTrainedTokenizer)

        return _get_tokenizer_without_image_pad(tokenizer)

    def get_hf_processor(self, **kwargs: object) -> QwenVLProcessor:
        """Build a `QwenVLProcessor` from the HF config and patched tokenizer."""
        return self.ctx.init_processor(
            QwenVLProcessor,
            config=self.get_hf_config(),
            tokenizer=self.get_tokenizer(),
            **kwargs,
        )

    def get_supported_mm_limits(self) -> Mapping[str, int | None]:
        """Images are the only modality; no fixed per-prompt limit."""
        return {"image": None}

    def get_num_image_tokens(self) -> int:
        """Return the number of placeholder tokens inserted per image.

        This has to equal the number of embeddings the vision tower returns
        per image. `VisionTransformer` pools its patch features through
        `Resampler2(grid_size=int(math.sqrt(n_queries)), ...)`, with
        `n_queries` defaulting to 256 when the visual config omits it, so the
        count follows `n_queries` rather than `image_size`/`patch_size`
        (assumes the resampler emits ``grid_size ** 2`` queries — confirm
        against `Resampler2`). For image_size=448 and patch_size=14 with the
        default `n_queries`, this equals the former
        ``(image_size // patch_size // 2) ** 2`` = 256.
        """
        vision_config = self.get_hf_config().visual

        # Same default as VisionTransformer's `n_queries` parameter.
        n_queries = vision_config.get("n_queries", 256)
        grid_length = int(math.sqrt(n_queries))
        return grid_length * grid_length


class QwenVLDummyInputsBuilder(BaseDummyInputsBuilder[QwenVLProcessingInfo]):
    """Builds dummy prompts and images for Qwen-VL memory profiling."""

    def get_dummy_text(self, mm_counts: Mapping[str, int]) -> str:
        """Return one ``Picture i: <start><end>\\n`` line per dummy image.

        The start/end tag strings come from the HF processor; numbering
        starts at 1.
        """
        image_count = mm_counts.get("image", 0)

        processor = self.info.get_hf_processor()
        start_tag = processor.image_start_tag
        end_tag = processor.image_end_tag

        pictures = [
            f"Picture {number}: {start_tag}{end_tag}\n"
            for number in range(1, image_count + 1)
        ]
        return "".join(pictures)

    def get_dummy_mm_data(
        self,
        seq_len: int,
        mm_counts: Mapping[str, int],
        mm_options: Mapping[str, BaseDummyOptions] | None = None,
    ) -> MultiModalDataDict:
        """Return `mm_counts["image"]` square dummy images of `image_size`."""
        side = self.info.get_hf_config().visual["image_size"]
        image_count = mm_counts.get("image", 0)

        overrides = None if not mm_options else mm_options.get("image")

        dummy_images = self._get_dummy_images(
            width=side,
            height=side,
            num_images=image_count,
            overrides=overrides,
        )
        return {"image": dummy_images}


class QwenVLMultiModalProcessor(BaseMultiModalProcessor[QwenVLProcessingInfo]):
    """Multimodal processor that expands each image into pad-token spans."""

    def _call_hf_processor(
        self,
        prompt: str,
        mm_data: Mapping[str, object],
        mm_kwargs: Mapping[str, object],
        tok_kwargs: Mapping[str, object],
    ) -> BatchFeature:
        """Strip image contents from the prompt, then run the HF processor.

        Raises:
            ValueError: If images are supplied and the number of
                ``Picture N: <img>...</img>`` spans (each followed by a
                newline) in the prompt differs from the number of images.
        """
        # Drops anything between <img>/</img> tags; encoding with the tokenizer
        # will automatically add the image pads for the context. Only spans
        # of the form "Picture <digits>: <img>...</img>\n" are matched.
        prompt, num_matched_images = re.subn(
            r"(Picture \d*: <img>).*?(<\/img>\n)",
            r"\1\2",
            prompt,
        )

        image_data = mm_data.get("images")
        if image_data is not None:
            assert isinstance(image_data, list)

            num_images = len(image_data)
            # Explicit check rather than `assert`: this validates the user's
            # prompt and must not disappear under `python -O`.
            if num_matched_images != num_images:
                raise ValueError(
                    f"The prompt contains {num_matched_images} image "
                    "placeholder(s) of the form 'Picture N: <img></img>' "
                    "followed by a newline, but "
                    f"{num_images} image(s) were provided."
                )

        return super()._call_hf_processor(
            prompt=prompt,
            mm_data=mm_data,
            mm_kwargs=mm_kwargs,
            tok_kwargs=tok_kwargs,
        )

    def _hf_processor_applies_updates(
        self,
        prompt_text: str,
        mm_items: MultiModalDataItems,
        hf_processor_mm_kwargs: Mapping[str, object],
        tokenization_kwargs: Mapping[str, object],
    ) -> bool:
        # The tokenizer from `get_tokenizer` has image-pad insertion patched
        # out, so the HF processor output never contains the expanded spans.
        return False

    def _get_mm_fields_config(
        self,
        hf_inputs: BatchFeature,
        hf_processor_mm_kwargs: Mapping[str, object],
    ) -> Mapping[str, MultiModalFieldConfig]:
        """Both pixel values and embeddings are batched per image."""
        return dict(
            pixel_values=MultiModalFieldConfig.batched("image"),
            image_embeds=MultiModalFieldConfig.batched("image"),
        )

    def _get_prompt_updates(
        self,
        mm_items: MultiModalDataItems,
        hf_processor_mm_kwargs: Mapping[str, object],
        out_mm_kwargs: MultiModalKwargsItems,
    ) -> Sequence[PromptUpdate]:
        """Replace each ``[start, end]`` pair with ``start, pad * N, end``.

        Only the pad tokens are marked as embedding positions.
        """
        tokenizer = self.info.get_tokenizer()
        special_tokens: dict[str, int] = tokenizer.special_tokens  # type: ignore

        processor = self.info.get_hf_processor()
        img_start_id = special_tokens[processor.image_start_tag]
        img_end_id = special_tokens[processor.image_end_tag]
        img_pad_id = special_tokens[processor.image_pad_tag]

        num_image_tokens = self.info.get_num_image_tokens()
        image_tokens = [img_pad_id] * num_image_tokens

        return [
            PromptReplacement(
                modality="image",
                target=[img_start_id, img_end_id],
                replacement=PromptUpdateDetails.select_token_id(
                    [img_start_id] + image_tokens + [img_end_id],
                    embed_token_id=img_pad_id,
                ),
            )
        ]


697
698
699
700
701
702
703
704
@MULTIMODAL_REGISTRY.register_processor(
    QwenVLMultiModalProcessor,
    info=QwenVLProcessingInfo,
    dummy_inputs=QwenVLDummyInputsBuilder,
)
class QwenVLForConditionalGeneration(
    QWenBaseModel, SupportsPP, SupportsLoRA, SupportsMultiModal
):
    """Qwen-VL: the QWen language model with an image encoder."""

    merge_by_field_config = True

    packed_modules_mapping = {
        "c_attn": ["c_attn"],
        "gate_up_proj": [
            "w2",
            "w1",
        ],
    }

    def get_mm_mapping(self) -> MultiModelKeys:
        """
        Get the module prefix in multimodal models
        """
        return MultiModelKeys.from_string_field(
            language_model="transformer.h",
            connector="transformer.visual.attn_pool",
            tower_model="transformer.visual.transformer",
        )

    @classmethod
    def get_placeholder_str(cls, modality: str, i: int) -> str | None:
        """Return the prompt placeholder for image number `i`.

        Raises:
            ValueError: For any modality not starting with "image".
        """
        if not modality.startswith("image"):
            raise ValueError("Only image modality is supported")
        return f"Picture {i}: <img></img>"

    def __init__(
        self,
        *,
        vllm_config: VllmConfig,
        prefix: str = "",
        transformer_type: type[QwenVLModel] = QwenVLModel,
    ) -> None:
        super().__init__(
            vllm_config=vllm_config,
            prefix=prefix,
            transformer_type=transformer_type,
        )

        self.transformer: QwenVLModel

    def _parse_and_validate_image_input(
        self, **kwargs: object
    ) -> QwenImageInputs | None:
        """Wrap `pixel_values` or `image_embeds` from kwargs in a schema.

        Pixel values take precedence; returns None if neither is present.
        """
        pixel_values = kwargs.pop("pixel_values", None)
        image_embeds = kwargs.pop("image_embeds", None)

        if pixel_values is not None:
            side = self.config.visual["image_size"]
            return QwenImagePixelInputs(
                type="pixel_values",
                data=pixel_values,
                resolve_bindings={"h": side, "w": side},
            )

        if image_embeds is None:
            return None

        return QwenImageEmbeddingInputs(
            type="image_embeds",
            data=image_embeds,
        )

    def _process_image_input(self, image_input: QwenImageInputs) -> torch.Tensor:
        """Encode pixels with the vision tower; pass embeddings through."""
        if image_input["type"] != "image_embeds":
            return self.transformer.visual(image_input["data"])
        return image_input["data"]

    def get_language_model(self) -> torch.nn.Module:
        return self.transformer

    def embed_multimodal(self, **kwargs: object) -> MultiModalEmbeddings:
        """Return image embeddings, or an empty list when there is no image."""
        image_input = self._parse_and_validate_image_input(**kwargs)
        if image_input is None:
            return []
        return self._process_image_input(image_input)

    def forward(
        self,
        input_ids: torch.Tensor,
        positions: torch.Tensor,
        intermediate_tensors: IntermediateTensors | None = None,
        inputs_embeds: torch.Tensor | None = None,
        **kwargs: object,
    ) -> torch.Tensor | IntermediateTensors:
        # Non-first pipeline stages consume intermediate tensors instead of
        # input embeddings.
        embeds = None if intermediate_tensors is not None else inputs_embeds
        return self.transformer(
            input_ids, positions, intermediate_tensors, embeds
        )