qwen_vl.py 21.2 KB
Newer Older
1
# SPDX-License-Identifier: Apache-2.0
2
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
3
4
5
6
7
8
9

# Adapted from
# https://huggingface.co/Qwen/Qwen-VL/blob/main/modeling_qwen.py
# Copyright (c) Alibaba Cloud.
"""Inference-only Qwen-VL model compatible with HuggingFace weights."""

import math
10
11
from collections.abc import Callable, Mapping, Sequence
from functools import partial
12
from typing import Annotated, Literal, TypeAlias
13

14
import regex as re
15
16
import torch
from torch import nn
17
from transformers import BatchFeature
18
19

from vllm.config import VllmConfig
20
from vllm.config.multimodal import BaseDummyOptions
21
from vllm.inputs import MultiModalDataDict
22
from vllm.model_executor.layers.activation import get_act_fn
23
from vllm.model_executor.layers.conv import Conv2dLayer
24
25
26
27
28
from vllm.model_executor.layers.linear import (
    ColumnParallelLinear,
    ReplicatedLinear,
    RowParallelLinear,
)
29
30
31
32
from vllm.model_executor.layers.quantization import QuantizationConfig
from vllm.model_executor.layers.resampler import Resampler2, get_abs_pos
from vllm.model_executor.models.module_mapping import MultiModelKeys
from vllm.multimodal import MULTIMODAL_REGISTRY
33
34
35
36
from vllm.multimodal.inputs import (
    MultiModalFieldConfig,
    MultiModalKwargsItems,
)
37
from vllm.multimodal.parse import MultiModalDataItems
38
from vllm.multimodal.processing import (
39
    BaseDummyInputsBuilder,
40
41
42
43
44
45
    BaseMultiModalProcessor,
    BaseProcessingInfo,
    PromptReplacement,
    PromptUpdate,
    PromptUpdateDetails,
)
46
from vllm.sequence import IntermediateTensors
47
48
49
50
from vllm.transformers_utils.processors.qwen_vl import (
    QwenVLImageProcessorFast,
    QwenVLProcessor,
)
51
from vllm.utils.tensor_schema import TensorSchema, TensorShape
52

53
54
55
56
57
58
from .interfaces import (
    MultiModalEmbeddings,
    SupportsLoRA,
    SupportsMultiModal,
    SupportsPP,
)
59
from .qwen import QWenBaseModel, QWenBlock, QWenModel
60
61


62
class QwenImagePixelInputs(TensorSchema):
    """
    Raw pixel inputs for the Qwen-VL vision encoder.

    Dimensions:
        - bn: Batch size * number of images
        - c: Number of channels (3)
        - h: Height
        - w: Width

    `h` and `w` are validated against `image_size` from the vision config,
    which is also the default size the image processor resizes images to.
    Every image is a separate entry along `bn`, so prompts with several
    images can be served from pixel values as well as from precomputed
    image embeddings.
    """

    type: Literal["pixel_values"] = "pixel_values"
    data: Annotated[torch.Tensor, TensorShape("bn", 3, "h", "w")]
77
78


79
80
81
82
83
84
class QwenImageEmbeddingInputs(TensorSchema):
    """
    Precomputed image features that skip the vision encoder entirely.

    Dimensions:
        - bn: Batch size * number of images
        - ifs: Image feature size (256)
        - hs: Hidden size

    The trailing `hs` dimension must equal the hidden size of the language
    model backbone. This is the same width the vision encoder would produce
    (its `output_dim` from the visual config), because these tensors are used
    in place of the encoder output.
    """

    type: Literal["image_embeds"] = "image_embeds"
    data: Annotated[torch.Tensor, TensorShape("bn", 256, "hs")]
92
93


94
# Either raw pixels for the vision encoder or ready-made image embeddings.
QwenImageInputs: TypeAlias = (
    QwenImagePixelInputs | QwenImageEmbeddingInputs
)
95
96
97
98
99
100
101
102
103
104
105
106
107


class VisualAttention(nn.Module):
    """self-attention layer class.
    Self-attention layer takes input with size [s, b, h]
    and returns output of the same size.
    """

    def __init__(
        self,
        embed_dim: int,
        num_heads: int,
        bias: bool = True,
108
109
        kdim: int | None = None,
        vdim: int | None = None,
110
        prefix: str = "",
111
112
113
114
115
    ):
        super().__init__()
        self.embed_dim = embed_dim
        self.kdim = kdim if kdim is not None else embed_dim
        self.vdim = vdim if vdim is not None else embed_dim
116
        self._qkv_same_embed_dim = self.kdim == embed_dim and self.vdim == embed_dim
117
118
119
120
121
122
123
124
125
126

        self.num_heads = num_heads

        # Per attention head and per partition values.
        assert embed_dim % num_heads == 0
        self.hidden_size_per_attention_head = embed_dim // num_heads
        self.num_attention_heads_per_partition = num_heads
        self.hidden_size_per_partition = embed_dim

        # Strided linear layer.
127
128
129
        assert self._qkv_same_embed_dim, (
            "Visual Attention implementation only supports self-attention"
        )
130
131
132
133
134
135
        self.in_proj = ReplicatedLinear(
            embed_dim, 3 * embed_dim, prefix=f"{prefix}.in_proj"
        )
        self.out_proj = ReplicatedLinear(
            embed_dim, embed_dim, prefix=f"{prefix}.out_proj"
        )
136
137
138
139
140
        self.norm_factor = math.sqrt(self.hidden_size_per_attention_head)

    def forward(
        self,
        x: torch.Tensor,
141
        attn_mask: torch.Tensor | None = None,
142
143
144
145
146
147
    ) -> torch.Tensor:
        # query/key/value: [sq, b, h]
        sq, b, _ = x.size()
        mixed_x_layer, _ = self.in_proj(x)

        # [sq, b, (np * 3 * hn)] --> [sq, b, np, 3 * hn]
148
149
150
151
        new_tensor_shape = mixed_x_layer.size()[:-1] + (
            self.num_attention_heads_per_partition,
            3 * self.hidden_size_per_attention_head,
        )
152
153
154
155
        mixed_x_layer = mixed_x_layer.view(*new_tensor_shape)

        # [sq, b, np, 3 * hn] --> 3 [sq, b, np, hn]
        query_layer, key_layer, value_layer = mixed_x_layer.split(
156
157
            self.hidden_size_per_attention_head, dim=-1
        )
158
159
160

        # [sq, b, np, hn] -> [sq, b * np, hn]
        query_layer = query_layer.view(
161
162
163
164
            sq,
            b * self.num_attention_heads_per_partition,
            self.hidden_size_per_attention_head,
        ).transpose(0, 1)
165
166
        # [sk, b, np, hn] -> [sk, b * np, hn]
        key_layer = key_layer.view(
167
168
169
170
            sq,
            b * self.num_attention_heads_per_partition,
            self.hidden_size_per_attention_head,
        ).transpose(0, 1)
171
172
173

        q_scaled = query_layer / self.norm_factor
        if attn_mask is not None:
174
175
176
            attention_probs = torch.baddbmm(
                attn_mask, q_scaled, key_layer.transpose(-2, -1)
            )
177
178
179
180
181
        else:
            attention_probs = torch.bmm(q_scaled, key_layer.transpose(-2, -1))
        attention_probs = attention_probs.softmax(dim=-1)

        value_layer = value_layer.view(
182
183
184
185
            sq,
            b * self.num_attention_heads_per_partition,
            self.hidden_size_per_attention_head,
        ).transpose(0, 1)
186
187
188
189
190
191

        # matmul: [b * np, sq, hn]
        context_layer = torch.bmm(attention_probs, value_layer)

        # change view [b, np, sq, hn]
        context_layer = context_layer.view(
192
193
194
195
196
            b,
            self.num_attention_heads_per_partition,
            sq,
            self.hidden_size_per_attention_head,
        )
197
198
199
200
201

        # [b, np, sq, hn] --> [sq, b, np, hn]
        context_layer = context_layer.permute(2, 0, 1, 3).contiguous()

        # [sq, b, np, hn] --> [sq, b, hp]
202
203
204
        new_context_layer_shape = context_layer.size()[:-2] + (
            self.hidden_size_per_partition,
        )
205
206
207
208
209
210
211
212
213
214
215
216
217
218
        context_layer = context_layer.view(*new_context_layer_shape)

        output, _ = self.out_proj(context_layer)

        return output


class QwenVLMLP(nn.Module):
    """Feed-forward sub-layer of the visual blocks.

    ``c_fc`` (column parallel) widens ``hidden_size`` to
    ``intermediate_size``, a GELU is applied, and ``c_proj`` (row parallel)
    maps back to ``hidden_size``. Both projections have a bias.
    """

    def __init__(
        self,
        hidden_size: int,
        intermediate_size: int,
        quant_config: QuantizationConfig | None = None,
        prefix: str = "",
    ):
        super().__init__()
        shared = dict(bias=True, quant_config=quant_config)
        self.c_fc = ColumnParallelLinear(
            hidden_size,
            intermediate_size,
            prefix=f"{prefix}.c_fc",
            **shared,
        )
        self.act_fn = get_act_fn("gelu")
        self.c_proj = RowParallelLinear(
            intermediate_size,
            hidden_size,
            prefix=f"{prefix}.c_proj",
            **shared,
        )

    def forward(self, x):
        # The parallel linear layers return (output, bias); only output is used.
        widened, _ = self.c_fc(x)
        narrowed, _ = self.c_proj(self.act_fn(widened))
        return narrowed


class VisualAttentionBlock(nn.Module):
    """One pre-norm residual layer of the vision transformer.

    Computes ``x + attn(ln_1(x))`` and then adds ``mlp(ln_2(.))`` to that
    result.
    """

    def __init__(
        self,
        d_model: int,
        n_head: int,
        mlp_ratio: float = 4.0,
        norm_layer: Callable[[int], nn.Module] = nn.LayerNorm,
        quant_config: QuantizationConfig | None = None,
        prefix: str = "",
    ):
        super().__init__()

        self.ln_1 = norm_layer(d_model)
        self.ln_2 = norm_layer(d_model)
        self.attn = VisualAttention(d_model, n_head, prefix=f"{prefix}.attn")
        self.mlp = QwenVLMLP(
            hidden_size=d_model,
            intermediate_size=int(d_model * mlp_ratio),
            quant_config=quant_config,
            prefix=f"{prefix}.mlp",
        )

    def attention(
        self,
        x: torch.Tensor,
        attn_mask: torch.Tensor | None = None,
    ) -> torch.Tensor:
        """Run self-attention, casting the mask (if any) to ``x.dtype``."""
        if attn_mask is not None:
            attn_mask = attn_mask.to(x.dtype)
        return self.attn(x, attn_mask=attn_mask)

    def forward(
        self,
        x: torch.Tensor,
        attn_mask: torch.Tensor | None = None,
    ) -> torch.Tensor:
        after_attn = x + self.attention(self.ln_1(x), attn_mask=attn_mask)
        return after_attn + self.mlp(self.ln_2(after_attn))


class TransformerBlock(nn.Module):
    """Sequential stack of ``layers`` VisualAttentionBlocks of equal width."""

    def __init__(
        self,
        width: int,
        layers: int,
        heads: int,
        mlp_ratio: float = 4.0,
        norm_layer: Callable[[int], nn.Module] = nn.LayerNorm,
        quant_config: QuantizationConfig | None = None,
        prefix: str = "",
    ):
        super().__init__()
        self.width = width
        self.layers = layers

        self.resblocks = nn.ModuleList(
            VisualAttentionBlock(
                width,
                heads,
                mlp_ratio,
                norm_layer=norm_layer,
                quant_config=quant_config,
                prefix=f"{prefix}.resblocks.{layer_idx}",
            )
            for layer_idx in range(layers)
        )

    def _reference_weight(self) -> torch.Tensor:
        # The first block's c_fc weight decides the dtype/device inputs get.
        return self.resblocks[0].mlp.c_fc.weight

    def get_cast_dtype(self) -> torch.dtype:
        """Dtype of the first block's ``mlp.c_fc`` weight."""
        return self._reference_weight().dtype

    def get_cast_device(self) -> torch.device:
        """Device of the first block's ``mlp.c_fc`` weight."""
        return self._reference_weight().device

    def forward(
        self, x: torch.Tensor, attn_mask: torch.Tensor | None = None
    ) -> torch.Tensor:
        """Apply every residual block in order with the same mask."""
        for block in self.resblocks:
            x = block(x, attn_mask=attn_mask)
        return x


class VisionTransformer(nn.Module):
    """Qwen-VL image encoder: patch ViT followed by an attention resampler.

    ``forward`` patchifies images with ``conv1``, adds positional
    embeddings, runs the tokens through ``transformer`` and pools them with
    ``attn_pool``. The pooled tokens are then normalized by ``ln_post`` and
    projected by ``proj``, giving features of width ``output_dim``.
    """

    def __init__(
        self,
        image_size: int,
        patch_size: int,
        width: int,
        layers: int,
        heads: int,
        mlp_ratio: float,
        n_queries: int = 256,
        output_dim: int = 512,
        image_start_id: int = 151857,
        quant_config: QuantizationConfig | None = None,
        prefix: str = "",
        **kwargs,
    ):
        super().__init__()
        # Square images are cut into square, non-overlapping patches.
        self.image_size = (image_size, image_size)
        self.patch_size = (patch_size, patch_size)
        patches_per_side = image_size // patch_size
        self.grid_size = (patches_per_side, patches_per_side)
        self.output_dim = output_dim

        # Patch embedding: kernel == stride == patch_size.
        self.conv1 = Conv2dLayer(
            in_channels=3,
            out_channels=width,
            kernel_size=patch_size,
            stride=patch_size,
            bias=False,
        )

        # Learned absolute positional table with 256 rows; forward() adapts it
        # to the actual patch grid via get_abs_pos.
        self.positional_embedding = nn.Parameter(
            width**-0.5 * torch.randn(256, width)
        )

        eps_layer_norm = partial(nn.LayerNorm, eps=1e-6)

        self.ln_pre = eps_layer_norm(width)
        self.transformer = TransformerBlock(
            width,
            layers,
            heads,
            mlp_ratio,
            norm_layer=eps_layer_norm,
            quant_config=quant_config,
            prefix=f"{prefix}.transformer",
        )

        # Resampler from `width`-dim patch tokens to `output_dim`-dim queries
        # (grid_size = int(sqrt(n_queries))), moved to the positional
        # embedding's device/dtype.
        ref_param = self.positional_embedding
        self.attn_pool = Resampler2(
            grid_size=int(math.sqrt(n_queries)),
            embed_dim=output_dim,
            num_heads=output_dim // 128,
            kv_dim=width,
            norm_layer=eps_layer_norm,
            adaptive=False,
            do_post_projection=False,
            prefix=f"{prefix}.attn_pool",
        ).to(device=ref_param.device, dtype=ref_param.dtype)

        self.ln_post = eps_layer_norm(output_dim)
        self.proj = nn.Parameter(
            (output_dim**-0.5) * torch.randn(output_dim, output_dim)
        )

        # Image special-token ids are three consecutive values.
        self.image_start_id = image_start_id
        self.image_end_id = image_start_id + 1
        self.image_pad_id = image_start_id + 2

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # Align input dtype/device with the transformer weights.
        x = x.to(
            dtype=self.transformer.get_cast_dtype(),
            device=self.transformer.get_cast_device(),
        )

        # [*, 3, H, W] -> [*, width, grid, grid] -> [*, grid ** 2, width]
        tokens = self.conv1(x).flatten(start_dim=2).transpose(1, 2)
        side = int(math.sqrt(tokens.size(1)))
        tokens = self.ln_pre(tokens + get_abs_pos(self.positional_embedding, side))

        # The transformer stack is sequence-first: NLD -> LND -> NLD.
        tokens = self.transformer(tokens.permute(1, 0, 2)).permute(1, 0, 2)

        pooled = self.ln_post(self.attn_pool(tokens))
        return pooled @ self.proj


class QwenVLModel(QWenModel):
    """QWen language model with the Qwen-VL vision encoder attached as ``visual``."""

    def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
        super().__init__(vllm_config=vllm_config, prefix=prefix)

        # The HF config's `visual` mapping supplies the encoder's keyword args.
        hf_config = vllm_config.model_config.hf_config
        self.visual = VisionTransformer(
            **hf_config.visual,
            quant_config=vllm_config.quant_config,
            prefix=f"{prefix}.visual",
        )
435
436


437
class QwenVLProcessingInfo(BaseProcessingInfo):
    """Model-specific processing information for Qwen-VL."""

    def get_image_processor(self, **kwargs):
        """Create the image processor.

        ``kwargs`` are first passed through ``ctx.get_merged_mm_kwargs``. Unless
        the result already contains ``size``, images are resized to the square
        ``image_size`` from the vision config.
        """
        config = self.get_hf_config()
        vision_config = config.visual

        image_size = vision_config["image_size"]
        kwargs = self.ctx.get_merged_mm_kwargs(kwargs)
        kwargs.setdefault("size", {"width": image_size, "height": image_size})

        return QwenVLImageProcessorFast(**kwargs)

    def get_hf_processor(self, **kwargs: object) -> QwenVLProcessor:
        """Pair the tokenizer with an image processor built from ``kwargs``."""
        return QwenVLProcessor(
            tokenizer=self.get_tokenizer(),
            image_processor=self.get_image_processor(**kwargs),
        )

    def get_supported_mm_limits(self) -> Mapping[str, int | None]:
        """Only the image modality is supported; no numeric cap is declared."""
        return {"image": None}

    def get_num_image_tokens(self) -> int:
        """Return the number of placeholder tokens inserted per image.

        This must equal the number of embeddings the vision encoder emits per
        image. The encoder builds its resampler with
        ``grid_size = int(math.sqrt(n_queries))``, where ``n_queries``
        defaults to 256 when absent from the visual config. The count is
        therefore derived from the same value. Resampler2 is assumed to emit
        ``grid_size ** 2`` query tokens; verify if that class changes.

        The previous formula ``(image_size // patch_size // 2) ** 2`` only
        agrees with this for configs where image_size / patch_size happens to
        be ``2 * grid_size`` (e.g. 448 / 14 -> 16 * 16 = 256).
        """
        hf_config = self.get_hf_config()
        vision_config = hf_config.visual

        n_queries = vision_config.get("n_queries", 256)
        grid_length = int(math.sqrt(n_queries))
        return grid_length * grid_length


class QwenVLDummyInputsBuilder(BaseDummyInputsBuilder[QwenVLProcessingInfo]):
    """Produces placeholder prompt text and images for Qwen-VL."""

    def get_dummy_text(self, mm_counts: Mapping[str, int]) -> str:
        """Return one ``Picture i: <start><end>`` line per image (1-based).

        The start/end markers are the HF processor's image tags.
        """
        image_count = mm_counts.get("image", 0)

        processor = self.info.get_hf_processor()
        tag_pair = f"{processor.image_start_tag}{processor.image_end_tag}"

        lines = [f"Picture {n}: {tag_pair}\n" for n in range(1, image_count + 1)]
        return "".join(lines)

    def get_dummy_mm_data(
        self,
        seq_len: int,
        mm_counts: Mapping[str, int],
        mm_options: Mapping[str, BaseDummyOptions],
    ) -> MultiModalDataDict:
        """Build square dummy images of the configured ``image_size``."""
        side = self.info.get_hf_config().visual["image_size"]

        dummy_images = self._get_dummy_images(
            width=side,
            height=side,
            num_images=mm_counts.get("image", 0),
            overrides=mm_options.get("image"),
        )
        return {"image": dummy_images}


class QwenVLMultiModalProcessor(BaseMultiModalProcessor[QwenVLProcessingInfo]):
    """Multimodal processor mapping Qwen-VL image tag pairs to image pads."""

    def _call_hf_processor(
        self,
        prompt: str,
        mm_data: Mapping[str, object],
        mm_kwargs: Mapping[str, object],
        tok_kwargs: Mapping[str, object],
    ) -> BatchFeature:
        """Empty every ``Picture N: <img>...</img>`` span, then call the HF processor.

        Raises:
            ValueError: if images are provided and the number of
                ``Picture N: <img>...</img>`` spans (each followed by a
                newline) in ``prompt`` differs from the number of images.
        """
        # Drops anything between <img>/</img> tags; encoding with the tokenizer
        # will automatically add the image pads for the context.
        prompt, num_matched_images = re.subn(
            r"(Picture \d*: <img>).*?(<\/img>\n)",
            r"\1\2",
            prompt,
        )

        image_data = mm_data.get("images")
        if image_data is not None:
            assert isinstance(image_data, list)

            num_images = len(image_data)
            # This guards user-provided prompts, so raise instead of asserting:
            # asserts vanish under `python -O` and carry no message.
            if num_matched_images != num_images:
                raise ValueError(
                    f"The prompt contains {num_matched_images} image "
                    "placeholder(s) of the form 'Picture N: <img></img>' "
                    f"followed by a newline, but {num_images} image(s) "
                    "were provided."
                )

        return super()._call_hf_processor(
            prompt=prompt,
            mm_data=mm_data,
            mm_kwargs=mm_kwargs,
            tok_kwargs=tok_kwargs,
        )

    def _hf_processor_applies_updates(
        self,
        prompt_text: str,
        mm_items: MultiModalDataItems,
        hf_processor_mm_kwargs: Mapping[str, object],
        tokenization_kwargs: Mapping[str, object],
    ) -> bool:
        """Always False: the HF processor is not treated as applying the
        prompt updates from ``_get_prompt_updates`` itself."""
        return False

    def _get_mm_fields_config(
        self,
        hf_inputs: BatchFeature,
        hf_processor_mm_kwargs: Mapping[str, object],
    ) -> Mapping[str, MultiModalFieldConfig]:
        """Both ``pixel_values`` and ``image_embeds`` are batched per image."""
        return dict(
            pixel_values=MultiModalFieldConfig.batched("image"),
            image_embeds=MultiModalFieldConfig.batched("image"),
        )

    def _get_prompt_updates(
        self,
        mm_items: MultiModalDataItems,
        hf_processor_mm_kwargs: Mapping[str, object],
        out_mm_kwargs: MultiModalKwargsItems,
    ) -> Sequence[PromptUpdate]:
        """Expand each ``[img_start, img_end]`` token pair into
        ``[img_start, img_pad * N, img_end]``.

        N is ``get_num_image_tokens()``. Embeddings are placed on the pad
        tokens only.
        """
        tokenizer = self.info.get_tokenizer()
        special_tokens: dict[str, int] = tokenizer.special_tokens  # type: ignore

        processor = self.info.get_hf_processor()
        img_start_id = special_tokens[processor.image_start_tag]
        img_end_id = special_tokens[processor.image_end_tag]
        img_pad_id = special_tokens[processor.image_pad_tag]

        num_image_tokens = self.info.get_num_image_tokens()
        image_tokens = [img_pad_id] * num_image_tokens

        return [
            PromptReplacement(
                modality="image",
                target=[img_start_id, img_end_id],
                replacement=PromptUpdateDetails.select_token_id(
                    [img_start_id] + image_tokens + [img_end_id],
                    embed_token_id=img_pad_id,
                ),
            )
        ]


581
582
583
584
585
586
587
588
@MULTIMODAL_REGISTRY.register_processor(
    QwenVLMultiModalProcessor,
    info=QwenVLProcessingInfo,
    dummy_inputs=QwenVLDummyInputsBuilder,
)
class QwenVLForConditionalGeneration(
    QWenBaseModel, SupportsPP, SupportsLoRA, SupportsMultiModal
):
    """Qwen-VL: QWen language model plus a ViT/resampler image encoder."""

    packed_modules_mapping = {
        "c_attn": ["c_attn"],
        "gate_up_proj": ["w2", "w1"],
    }

    embed_input_ids = SupportsMultiModal.embed_input_ids

    def get_mm_mapping(self) -> MultiModelKeys:
        """
        Get the module prefix in multimodal models
        """
        return MultiModelKeys.from_string_field(
            language_model="transformer.h",
            connector="transformer.visual.attn_pool",
            tower_model="transformer.visual.transformer",
        )

    @classmethod
    def get_placeholder_str(cls, modality: str, i: int) -> str | None:
        """Return the prompt placeholder for the ``i``-th image."""
        if not modality.startswith("image"):
            raise ValueError("Only image modality is supported")
        return f"Picture {i}: <img></img>"

    def __init__(
        self,
        *,
        vllm_config: VllmConfig,
        prefix: str = "",
        transformer_type: type[QwenVLModel] = QwenVLModel,
    ) -> None:
        # While the base model builds its submodules, declare QWenBlock as the
        # language part and VisionTransformer as the image tower.
        with self._mark_composite_model(
            vllm_config,
            language_targets=QWenBlock,
            tower_targets={"image": VisionTransformer},
        ):
            super().__init__(
                vllm_config=vllm_config,
                prefix=prefix,
                transformer_type=transformer_type,
            )

        self.transformer: QwenVLModel

    def _parse_and_validate_image_input(
        self, **kwargs: object
    ) -> QwenImageInputs | None:
        """Wrap raw kwargs in a schema; pixel values take priority over embeds."""
        pixel_values = kwargs.pop("pixel_values", None)
        image_embeds = kwargs.pop("image_embeds", None)

        if pixel_values is not None:
            side = self.config.visual["image_size"]
            return QwenImagePixelInputs(
                type="pixel_values",
                data=pixel_values,
                resolve_bindings={"h": side, "w": side},
            )

        if image_embeds is None:
            return None

        return QwenImageEmbeddingInputs(
            type="image_embeds",
            data=image_embeds,
        )

    def _process_image_input(self, image_input: QwenImageInputs) -> torch.Tensor:
        """Encode pixels with the vision tower; pass embeddings through as-is."""
        if image_input["type"] != "image_embeds":
            return self.transformer.visual(image_input["data"])
        return image_input["data"]

    def embed_multimodal(self, **kwargs: object) -> MultiModalEmbeddings:
        """Return image embeddings, or an empty list when no image is given."""
        image_input = self._parse_and_validate_image_input(**kwargs)
        if image_input is None:
            return []
        return self._process_image_input(image_input)

    def forward(
        self,
        input_ids: torch.Tensor | None,
        positions: torch.Tensor,
        intermediate_tensors: IntermediateTensors | None = None,
        inputs_embeds: torch.Tensor | None = None,
        **kwargs: object,
    ) -> torch.Tensor | IntermediateTensors:
        # Input embeddings are ignored whenever intermediate tensors are given.
        embeds = None if intermediate_tensors is not None else inputs_embeds
        return self.transformer(input_ids, positions, intermediate_tensors, embeds)