qwen.py 40.6 KB
Newer Older
1
2
# SPDX-License-Identifier: Apache-2.0

Qing's avatar
Qing committed
3
4
5
6
# Adapted from
# https://huggingface.co/Qwen/Qwen-7B/blob/main/modeling_qwen.py
# Copyright (c) Alibaba Cloud.
# LICENSE: https://huggingface.co/Qwen/Qwen-7B/blob/main/LICENSE
Woosuk Kwon's avatar
Woosuk Kwon committed
7
"""Inference-only QWen model compatible with HuggingFace weights."""
Qing's avatar
Qing committed
8

9
import copy
10
11
import math
import re
12
13
14
15
16
import unicodedata
from functools import lru_cache, partial
from typing import (AbstractSet, Any, Callable, Collection, Dict, Iterable,
                    List, Literal, Mapping, Optional, Set, Tuple, TypedDict,
                    Union)
17

18
19
import torch
from torch import nn
20
21
from torchvision import transforms
from torchvision.transforms import InterpolationMode
22
23
24
25
from transformers import (BatchFeature, PretrainedConfig, PreTrainedTokenizer,
                          TensorType)
from transformers.image_utils import ImageInput
from transformers.tokenization_utils_base import TextInput
Qing's avatar
Qing committed
26

27
from vllm.attention import Attention, AttentionMetadata
28
from vllm.compilation.decorators import support_torch_compile
29
from vllm.config import CacheConfig, VllmConfig
30
from vllm.distributed import get_pp_group, get_tensor_model_parallel_world_size
31
32
from vllm.logger import init_logger
from vllm.model_executor.layers.activation import SiluAndMul, get_act_fn
33
from vllm.model_executor.layers.layernorm import RMSNorm
34
35
from vllm.model_executor.layers.linear import (ColumnParallelLinear,
                                               MergedColumnParallelLinear,
36
                                               QKVParallelLinear,
37
                                               ReplicatedLinear,
38
                                               RowParallelLinear)
39
from vllm.model_executor.layers.logits_processor import LogitsProcessor
40
from vllm.model_executor.layers.quantization import QuantizationConfig
41
from vllm.model_executor.layers.resampler import Resampler2, get_abs_pos
42
from vllm.model_executor.layers.rotary_embedding import get_rope
Joe Runde's avatar
Joe Runde committed
43
from vllm.model_executor.layers.sampler import SamplerOutput, get_sampler
44
from vllm.model_executor.layers.vocab_parallel_embedding import (
45
    ParallelLMHead, VocabParallelEmbedding)
46
from vllm.model_executor.model_loader.weight_utils import default_weight_loader
47
from vllm.model_executor.models.module_mapping import MultiModelKeys
48
from vllm.model_executor.sampling_metadata import SamplingMetadata
49
50
51
52
53
54
55
56
57
from vllm.multimodal import MULTIMODAL_REGISTRY
from vllm.multimodal.inputs import (MultiModalFieldConfig, MultiModalKwargs,
                                    NestedTensors)
from vllm.multimodal.parse import MultiModalDataItems
from vllm.multimodal.processing import (BaseMultiModalProcessor,
                                        BaseProcessingInfo, PromptReplacement,
                                        PromptReplacementDetails)
from vllm.multimodal.profiling import BaseDummyInputsBuilder, ProcessorInputs
from vllm.sequence import IntermediateTensors
58

59
from .interfaces import SupportsLoRA, SupportsMultiModal, SupportsPP
60
from .utils import (flatten_bn, is_pp_missing_parameter,
61
                    make_empty_intermediate_tensors_factory, make_layers,
62
                    maybe_prefix, merge_multimodal_embeddings)
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123

# Module-level logger for this file.
# NOTE(review): not referenced in the portion of the file shown here; it may be
# used further down — confirm before removing.
logger = init_logger(__name__)


class QwenImagePixelInputs(TypedDict):
    type: Literal["pixel_values"]
    data: torch.Tensor
    """
    Shape: `(batch_size * num_images, 3, image_size, image_size)`

    Note that image_size is the value in the vision config to which the image
    is resized by the normalization transform.

    NOTE(review): an earlier revision stated that multi-image support is only
    available when passing image embeddings directly; the processor in this
    file stacks several images into `pixel_values`, so that restriction may
    be stale — confirm.
    """


class QwenImageEmbeddingInputs(TypedDict):
    type: Literal["image_embeds"]
    data: torch.Tensor
    """Shape: `(batch_size * num_images, 256, hidden_size)`

    `hidden_size` must match the hidden size of the language model backbone.
    When the model has a visual config, this is its `output_dim` entry, i.e.
    the width produced by the vision tower's final projection.
    """


# Either raw pixels to run through the vision tower, or precomputed
# image embeddings that bypass it.
QwenImageInputs = Union[QwenImagePixelInputs, QwenImageEmbeddingInputs]


class VisualAttention(nn.Module):
    """Multi-head self-attention layer for the Qwen-VL vision tower.

    The layer takes input of shape ``[s, b, h]`` (sequence-first) and returns
    an output of the same shape. Both projections are ``ReplicatedLinear``,
    so every tensor-parallel rank holds all heads.

    Args:
        embed_dim: Hidden size ``h`` of the input and output.
        num_heads: Number of attention heads; must divide ``embed_dim``.
        bias: Whether the fused QKV projection and the output projection
            carry an additive bias.
        kdim: Key width. Only ``None`` or ``embed_dim`` is supported.
        vdim: Value width. Only ``None`` or ``embed_dim`` is supported.

    Raises:
        ValueError: If ``embed_dim`` is not divisible by ``num_heads``, or if
            ``kdim``/``vdim`` differ from ``embed_dim`` (only self-attention
            is implemented).
    """

    def __init__(
        self,
        embed_dim: int,
        num_heads: int,
        bias: bool = True,
        kdim: Optional[int] = None,
        vdim: Optional[int] = None,
    ):
        super().__init__()
        self.embed_dim = embed_dim
        self.kdim = kdim if kdim is not None else embed_dim
        self.vdim = vdim if vdim is not None else embed_dim
        self._qkv_same_embed_dim = self.kdim == embed_dim \
            and self.vdim == embed_dim

        self.num_heads = num_heads

        # Per attention head and per partition values. The projections are
        # replicated, so a "partition" here holds every head.
        if embed_dim % num_heads != 0:
            raise ValueError(
                f"embed_dim ({embed_dim}) must be divisible by "
                f"num_heads ({num_heads})")
        self.hidden_size_per_attention_head = embed_dim // num_heads
        self.num_attention_heads_per_partition = num_heads
        self.hidden_size_per_partition = embed_dim

        # Strided linear layer.
        if not self._qkv_same_embed_dim:
            raise ValueError(
                'Visual Attention implementation only supports self-attention')

        # `bias` used to be accepted but ignored; forward it so callers can
        # actually disable the projection biases (default stays True).
        self.in_proj = ReplicatedLinear(embed_dim, 3 * embed_dim, bias=bias)
        self.out_proj = ReplicatedLinear(embed_dim, embed_dim, bias=bias)
        self.norm_factor = math.sqrt(self.hidden_size_per_attention_head)

    def forward(
        self,
        x: torch.Tensor,
        attn_mask: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        """Attend over ``x`` (``[sq, b, h]``) and return ``[sq, b, h]``.

        ``attn_mask``, when given, is added to the raw attention scores
        before the softmax (via ``torch.baddbmm``).
        """
        # query/key/value: [sq, b, h]
        sq, b, _ = x.size()
        mixed_x_layer, _ = self.in_proj(x)

        # [sq, b, (np * 3 * hn)] --> [sq, b, np, 3 * hn]
        new_tensor_shape = mixed_x_layer.size()[:-1] + \
            (self.num_attention_heads_per_partition,
             3 * self.hidden_size_per_attention_head)
        mixed_x_layer = mixed_x_layer.view(*new_tensor_shape)

        # [sq, b, np, 3 * hn] --> 3 [sq, b, np, hn]
        query_layer, key_layer, value_layer = mixed_x_layer.split(
            self.hidden_size_per_attention_head, dim=-1)

        # [sq, b, np, hn] -> [b * np, sq, hn]
        query_layer = query_layer.view(
            sq, b * self.num_attention_heads_per_partition,
            self.hidden_size_per_attention_head).transpose(0, 1)
        # Self-attention, so the key length equals sq.
        # [sq, b, np, hn] -> [b * np, sq, hn]
        key_layer = key_layer.view(
            sq, b * self.num_attention_heads_per_partition,
            self.hidden_size_per_attention_head).transpose(0, 1)

        q_scaled = query_layer / self.norm_factor
        if attn_mask is not None:
            # attn_mask + q @ k^T in a single fused call.
            attention_probs = torch.baddbmm(attn_mask, q_scaled,
                                            key_layer.transpose(-2, -1))
        else:
            attention_probs = torch.bmm(q_scaled, key_layer.transpose(-2, -1))
        attention_probs = attention_probs.softmax(dim=-1)

        value_layer = value_layer.view(
            sq, b * self.num_attention_heads_per_partition,
            self.hidden_size_per_attention_head).transpose(0, 1)

        # matmul: [b * np, sq, hn]
        context_layer = torch.bmm(attention_probs, value_layer)

        # change view [b, np, sq, hn]
        context_layer = context_layer.view(
            b, self.num_attention_heads_per_partition, sq,
            self.hidden_size_per_attention_head)

        # [b, np, sq, hn] --> [sq, b, np, hn]
        context_layer = context_layer.permute(2, 0, 1, 3).contiguous()

        # [sq, b, np, hn] --> [sq, b, hp]
        new_context_layer_shape = context_layer.size()[:-2] + \
            (self.hidden_size_per_partition,)
        context_layer = context_layer.view(*new_context_layer_shape)

        output, _ = self.out_proj(context_layer)

        return output


class QwenVMLP(nn.Module):
    """Feed-forward sub-layer of the vision tower.

    Computes ``c_proj(gelu(c_fc(x)))``: a column-parallel expansion to
    ``intermediate_size``, a GELU, then a row-parallel projection back to
    ``hidden_size``. Both linears carry a bias.
    """

    def __init__(
        self,
        hidden_size: int,
        intermediate_size: int,
        quant_config: Optional[QuantizationConfig] = None,
    ):
        super().__init__()
        # Expansion: hidden_size -> intermediate_size.
        self.c_fc = ColumnParallelLinear(
            hidden_size,
            intermediate_size,
            bias=True,
            quant_config=quant_config,
        )
        self.act_fn = get_act_fn("gelu")
        # Contraction: intermediate_size -> hidden_size.
        self.c_proj = RowParallelLinear(
            intermediate_size,
            hidden_size,
            bias=True,
            quant_config=quant_config,
        )

    def forward(self, x):
        # The parallel linears return (output, bias); only output is used.
        expanded, _ = self.c_fc(x)
        activated = self.act_fn(expanded)
        projected, _ = self.c_proj(activated)
        return projected


class VisualAttentionBlock(nn.Module):
    """Pre-norm residual block of the vision tower.

    The block applies ``x + attn(ln_1(x))``, then ``x + mlp(ln_2(x))``. The
    MLP width is ``int(d_model * mlp_ratio)``.
    """

    def __init__(
        self,
        d_model: int,
        n_head: int,
        mlp_ratio: float = 4.0,
        norm_layer: Callable[[int], nn.Module] = nn.LayerNorm,
        quant_config: Optional[QuantizationConfig] = None,
    ):
        super().__init__()

        self.ln_1 = norm_layer(d_model)
        self.ln_2 = norm_layer(d_model)
        self.attn = VisualAttention(d_model, n_head)
        self.mlp = QwenVMLP(
            hidden_size=d_model,
            intermediate_size=int(d_model * mlp_ratio),
            quant_config=quant_config,
        )

    def attention(
        self,
        x: torch.Tensor,
        attn_mask: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        """Run self-attention, casting any mask to ``x``'s dtype first."""
        if attn_mask is not None:
            attn_mask = attn_mask.to(x.dtype)
        return self.attn(x, attn_mask=attn_mask)

    def forward(
        self,
        x: torch.Tensor,
        attn_mask: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        attn_out = self.attention(self.ln_1(x), attn_mask=attn_mask)
        x = x + attn_out
        mlp_out = self.mlp(self.ln_2(x))
        return x + mlp_out


class TransformerBlock(nn.Module):
    """Sequential stack of ``layers`` identical ``VisualAttentionBlock`` s.

    Every block shares the same width, head count, MLP ratio, norm factory
    and quantization config. The optional attention mask is forwarded to
    each block in turn.
    """

    def __init__(
        self,
        width: int,
        layers: int,
        heads: int,
        mlp_ratio: float = 4.0,
        norm_layer: Callable[[int], nn.Module] = nn.LayerNorm,
        quant_config: Optional[QuantizationConfig] = None,
    ):
        super().__init__()
        self.width = width
        self.layers = layers

        self.resblocks = nn.ModuleList(
            VisualAttentionBlock(width,
                                 heads,
                                 mlp_ratio,
                                 norm_layer=norm_layer,
                                 quant_config=quant_config)
            for _ in range(layers))

    def get_cast_dtype(self) -> torch.dtype:
        """Dtype of the first block's MLP input-projection weight."""
        fc_weight = self.resblocks[0].mlp.c_fc.weight
        return fc_weight.dtype

    def get_cast_device(self) -> torch.device:
        """Device of the first block's MLP input-projection weight."""
        fc_weight = self.resblocks[0].mlp.c_fc.weight
        return fc_weight.device

    def forward(self,
                x: torch.Tensor,
                attn_mask: Optional[torch.Tensor] = None) -> torch.Tensor:
        hidden = x
        for block in self.resblocks:
            hidden = block(hidden, attn_mask=attn_mask)
        return hidden


class VisionTransformer(nn.Module):
    """Qwen-VL image encoder: patch embedding, ViT stack, resampler, projection.

    The encoder works in these steps:

    1. A strided ``Conv2d`` turns each ``patch_size`` x ``patch_size`` patch
       into a ``width``-dim token.
    2. An absolute positional table is added to the tokens.
    3. The tokens pass through ``layers`` ``VisualAttentionBlock`` s.
    4. ``Resampler2`` pools the result; its grid is ``sqrt(n_queries)``, so it
       presumably yields ``n_queries`` tokens.
    5. ``ln_post`` is applied, followed by ``x @ proj`` (``output_dim`` wide).

    ``image_start_id`` and the next two ids are stored as the image start,
    end and pad token ids. Extra config keys are absorbed by ``**kwargs``.
    """

    def __init__(self,
                 image_size: int,
                 patch_size: int,
                 width: int,
                 layers: int,
                 heads: int,
                 mlp_ratio: float,
                 n_queries: int = 256,
                 output_dim: int = 512,
                 image_start_id: int = 151857,
                 quant_config: Optional[QuantizationConfig] = None,
                 **kwargs):
        super().__init__()
        self.image_size = (image_size, image_size)
        self.patch_size = (patch_size, patch_size)
        patches_per_side = image_size // patch_size
        self.grid_size = (patches_per_side, patches_per_side)
        self.output_dim = output_dim

        # Non-overlapping patch embedding (kernel == stride).
        self.conv1 = nn.Conv2d(in_channels=3,
                               out_channels=width,
                               kernel_size=patch_size,
                               stride=patch_size,
                               bias=False)

        # Fixed 256-entry positional table; get_abs_pos adapts it to the
        # actual number of patches at forward time.
        self.positional_embedding = nn.Parameter(width**-0.5 *
                                                 torch.randn(256, width))

        layer_norm = partial(nn.LayerNorm, eps=1e-6)

        self.ln_pre = layer_norm(width)
        self.transformer = TransformerBlock(width,
                                            layers,
                                            heads,
                                            mlp_ratio,
                                            norm_layer=layer_norm,
                                            quant_config=quant_config)

        pos_table = self.positional_embedding
        self.attn_pool = Resampler2(
            grid_size=int(math.sqrt(n_queries)),
            embed_dim=output_dim,
            num_heads=output_dim // 128,
            kv_dim=width,
            norm_layer=layer_norm,
            adaptive=False,
            do_post_projection=False,
        ).to(device=pos_table.device, dtype=pos_table.dtype)

        self.ln_post = layer_norm(output_dim)
        self.proj = nn.Parameter(
            (output_dim**-0.5) * torch.randn(output_dim, output_dim))

        self.image_start_id = image_start_id
        self.image_end_id = image_start_id + 1
        self.image_pad_id = image_start_id + 2

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # Match the weights' dtype/device before the first conv.
        target_dtype = self.transformer.get_cast_dtype()
        target_device = self.transformer.get_cast_device()
        x = x.to(dtype=target_dtype, device=target_device)

        # [N, width, grid, grid] -> [N, grid * grid, width]
        tokens = self.conv1(x).flatten(2).transpose(1, 2)

        side = int(math.sqrt(tokens.size(1)))
        tokens = tokens + get_abs_pos(self.positional_embedding, side)
        tokens = self.ln_pre(tokens)

        # The transformer stack expects sequence-first layout (LND).
        tokens = self.transformer(tokens.permute(1, 0, 2)).permute(1, 0, 2)

        pooled = self.ln_post(self.attn_pool(tokens))
        return pooled @ self.proj

385
386

class QWenMLP(nn.Module):
    """Gated feed-forward network of the Qwen language model.

    A single ``MergedColumnParallelLinear`` emits the gate and up projections
    side by side. ``SiluAndMul`` combines them, and ``c_proj`` maps the
    result back to ``hidden_size``. Only ``hidden_act="silu"`` is accepted;
    anything else raises ``ValueError``.
    """

    def __init__(
        self,
        hidden_size: int,
        intermediate_size: int,
        hidden_act: str = "silu",
        quant_config: Optional[QuantizationConfig] = None,
    ):
        super().__init__()
        # Gate and up projections fused into one matmul.
        self.gate_up_proj = MergedColumnParallelLinear(
            hidden_size,
            [intermediate_size, intermediate_size],
            bias=False,
            quant_config=quant_config,
        )
        self.c_proj = RowParallelLinear(
            intermediate_size,
            hidden_size,
            bias=False,
            quant_config=quant_config,
        )
        if hidden_act == "silu":
            self.act_fn = SiluAndMul()
        else:
            raise ValueError(f"Unsupported activation: {hidden_act}. "
                             "Only silu is supported for now.")

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        fused, _ = self.gate_up_proj(x)
        gated = self.act_fn(fused)
        out, _ = self.c_proj(gated)
        return out


class QWenAttention(nn.Module):
    """Rotary multi-head attention of the Qwen language model.

    Heads are split evenly across tensor-parallel ranks. ``c_attn`` is a
    fused, biased QKV projection, and ``c_proj`` is a bias-free output
    projection. Rotary embeddings cover the full head dimension, and the
    attention itself is delegated to vLLM's ``Attention`` (KV-cache aware).
    """

    def __init__(
        self,
        hidden_size: int,
        num_heads: int,
        max_position_embeddings: int,
        rope_theta: float = 10000,
        rope_scaling: Optional[Dict[str, Any]] = None,
        cache_config: Optional[CacheConfig] = None,
        quant_config: Optional[QuantizationConfig] = None,
        prefix: str = "",
    ):
        super().__init__()
        self.hidden_size = hidden_size
        tp_size = get_tensor_model_parallel_world_size()
        self.total_num_heads = num_heads
        assert self.total_num_heads % tp_size == 0
        self.num_heads = self.total_num_heads // tp_size
        self.head_dim = hidden_size // self.total_num_heads

        self.c_attn = QKVParallelLinear(
            hidden_size,
            self.head_dim,
            self.total_num_heads,
            bias=True,
            quant_config=quant_config,
        )
        self.c_proj = RowParallelLinear(
            self.total_num_heads * self.head_dim,
            hidden_size,
            bias=False,
            quant_config=quant_config,
        )
        self.scaling = self.head_dim**-0.5

        self.rotary_emb = get_rope(
            self.head_dim,
            rotary_dim=self.head_dim,
            max_position=max_position_embeddings,
            base=rope_theta,
            rope_scaling=rope_scaling,
        )
        self.attn = Attention(
            self.num_heads,
            self.head_dim,
            self.scaling,
            cache_config=cache_config,
            quant_config=quant_config,
            prefix=f"{prefix}.attn",
        )

    def forward(
        self,
        positions: torch.Tensor,
        hidden_states: torch.Tensor,
        kv_cache: torch.Tensor,
        attn_metadata: AttentionMetadata,
    ) -> torch.Tensor:
        fused_qkv, _ = self.c_attn(hidden_states)
        # Local Q, K and V are laid out back to back on the last dim.
        query, key, value = fused_qkv.chunk(chunks=3, dim=-1)
        query, key = self.rotary_emb(positions, query, key)
        context = self.attn(query, key, value, kv_cache, attn_metadata)
        projected, _ = self.c_proj(context)
        return projected


class QWenBlock(nn.Module):
    """One decoder layer of the Qwen language model.

    The layer runs RMSNorm, then attention, then RMSNorm, then the gated MLP.
    The residual stream is threaded through the norms: the first layer
    receives ``residual=None``, and later layers call RMSNorm's two-argument
    form, which returns the updated residual alongside the normed states.
    """

    def __init__(
        self,
        config: PretrainedConfig,
        cache_config: Optional[CacheConfig] = None,
        quant_config: Optional[QuantizationConfig] = None,
        prefix: str = "",
    ):
        super().__init__()
        hidden = config.hidden_size
        eps = config.layer_norm_epsilon

        self.ln_1 = RMSNorm(hidden, eps=eps)

        self.attn = QWenAttention(
            hidden,
            config.num_attention_heads,
            config.max_position_embeddings,
            rope_theta=getattr(config, "rope_theta", 10000),
            rope_scaling=getattr(config, "rope_scaling", None),
            cache_config=cache_config,
            quant_config=quant_config,
            prefix=f"{prefix}.attn",
        )

        self.ln_2 = RMSNorm(hidden, eps=eps)

        # NOTE(review): halved presumably because the HF config's
        # intermediate_size counts both merged gate/up halves — verify.
        self.mlp = QWenMLP(hidden,
                           config.intermediate_size // 2,
                           quant_config=quant_config)

    def forward(
        self,
        positions: torch.Tensor,
        hidden_states: torch.Tensor,
        kv_cache: torch.Tensor,
        attn_metadata: AttentionMetadata,
        residual: Optional[torch.Tensor],
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        # Pre-attention norm; the first layer seeds the residual stream.
        if residual is not None:
            hidden_states, residual = self.ln_1(hidden_states, residual)
        else:
            residual = hidden_states
            hidden_states = self.ln_1(hidden_states)

        hidden_states = self.attn(
            positions=positions,
            hidden_states=hidden_states,
            kv_cache=kv_cache,
            attn_metadata=attn_metadata,
        )

        # Pre-MLP norm, then the feed-forward network.
        hidden_states, residual = self.ln_2(hidden_states, residual)
        return self.mlp(hidden_states), residual


540
@support_torch_compile
class QWenModel(nn.Module):
    """Qwen language-model backbone, with an optional vision tower.

    The model consists of:

    - ``wte``: the token embedding.
    - ``h``: the decoder layers this pipeline-parallel rank owns, covering
      ``[start_layer, end_layer)``.
    - ``ln_f``: the final RMSNorm.

    When the HF config carries a truthy ``visual`` entry, a
    ``VisionTransformer`` is also built as ``self.visual``; it is not used in
    ``forward`` here.
    """

    def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
        super().__init__()

        hf_config = vllm_config.model_config.hf_config
        cache_config = vllm_config.cache_config
        quant_config = vllm_config.quant_config

        self.config = hf_config
        self.vocab_size = hf_config.vocab_size

        self.wte = VocabParallelEmbedding(
            hf_config.vocab_size,
            hf_config.hidden_size,
        )

        # Keep the parameter name `prefix` (as in the original lambda);
        # make_layers may supply it by keyword.
        def _block_factory(prefix: str) -> QWenBlock:
            return QWenBlock(hf_config, cache_config, quant_config,
                             prefix=prefix)

        self.start_layer, self.end_layer, self.h = make_layers(
            hf_config.num_hidden_layers,
            _block_factory,
            prefix=f"{prefix}.h")
        self.ln_f = RMSNorm(hf_config.hidden_size,
                            eps=hf_config.layer_norm_epsilon)
        self.make_empty_intermediate_tensors = (
            make_empty_intermediate_tensors_factory(
                ["hidden_states", "residual"], hf_config.hidden_size))

        vision_config = getattr(hf_config, "visual", None)
        self.visual = (VisionTransformer(**vision_config,
                                         quant_config=quant_config)
                       if vision_config else None)

    def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor:
        return self.wte(input_ids)

    def forward(
        self,
        input_ids: torch.Tensor,
        positions: torch.Tensor,
        kv_caches: List[torch.Tensor],
        attn_metadata: AttentionMetadata,
        intermediate_tensors: Optional[IntermediateTensors],
        inputs_embeds: Optional[torch.Tensor] = None,
    ) -> Union[torch.Tensor, IntermediateTensors]:
        # Entry: embed tokens on the first PP rank, otherwise resume from
        # the previous rank's hand-off.
        if not get_pp_group().is_first_rank:
            assert intermediate_tensors is not None
            hidden_states = intermediate_tensors["hidden_states"]
            residual = intermediate_tensors["residual"]
        else:
            hidden_states = (inputs_embeds if inputs_embeds is not None else
                             self.get_input_embeddings(input_ids))
            residual = None

        # kv_caches is indexed relative to this rank's first layer.
        for layer_idx in range(self.start_layer, self.end_layer):
            hidden_states, residual = self.h[layer_idx](
                positions,
                hidden_states,
                kv_caches[layer_idx - self.start_layer],
                attn_metadata,
                residual,
            )

        # Exit: hand off to the next PP rank, or apply the final norm.
        if get_pp_group().is_last_rank:
            hidden_states, _ = self.ln_f(hidden_states, residual)
            return hidden_states
        return IntermediateTensors({
            "hidden_states": hidden_states,
            "residual": residual
        })


613
614
615
616
617
618
@lru_cache(maxsize=1)
def _get_tokenizer_without_image_pad(
        tokenizer: PreTrainedTokenizer) -> PreTrainedTokenizer:
    """
    Return a deep copy of ``tokenizer`` with image-pad insertion disabled.

    Image pad tokens should only be added by :class:`QWenVLProcessor`, so the
    copy's class is swapped for a subclass that overrides two methods:

    - ``tokenize``: NFC-normalizes the text, encodes it with the inner
      ``self.tokenizer``, and maps the ids back through ``self.decoder``.
    - ``_decode``: passes the ids straight to ``self.tokenizer.decode``.

    Only the most recently seen tokenizer is cached (``maxsize=1``).

    The definition of the wrapped tokenizer can be found here:
    https://huggingface.co/Qwen/Qwen-VL/blob/main/tokenization_qwen.py
    """
    patched = copy.deepcopy(tokenizer)
    base_cls = tokenizer.__class__

    class TokenizerWithoutImagePad(base_cls):  # type: ignore

        def tokenize(
            self,
            text: str,
            allowed_special: Union[AbstractSet[str], str] = "all",
            disallowed_special: Union[Collection[str], str] = (),
            **kwargs,
        ) -> list[Union[bytes, str]]:
            normalized = unicodedata.normalize("NFC", text)
            token_ids = self.tokenizer.encode(
                normalized,
                allowed_special=allowed_special,
                disallowed_special=disallowed_special,
            )
            return [self.decoder[tid] for tid in token_ids]

        def _decode(
            self,
            token_ids: Union[int, List[int]],
            skip_special_tokens: bool = False,
            errors: Optional[str] = None,
            **kwargs,
        ) -> str:
            id_list = [token_ids] if isinstance(token_ids, int) else token_ids
            return self.tokenizer.decode(
                id_list,
                errors=errors or self.errors,
            )

    TokenizerWithoutImagePad.__name__ = f"{base_cls.__name__}WithoutImagePad"

    patched.__class__ = TokenizerWithoutImagePad
    return patched


class QWenVLProcessor:
    """
    Minimal stand-in for an HF processor, since Qwen-VL does not ship one.

    Text is passed to the wrapped tokenizer, which inserts the image pad
    tokens itself:
    https://huggingface.co/Qwen/Qwen-VL/blob/main/tokenization_qwen.py#L245

    Images are preprocessed as in the reference implementation:
    https://huggingface.co/Qwen/Qwen-VL/blob/main/visual.py#L354
    """

    def __init__(
        self,
        config: PretrainedConfig,
        tokenizer: PreTrainedTokenizer,
    ) -> None:
        super().__init__()

        self.config = config
        self.tokenizer = tokenizer

        vision_config = getattr(self.config, "visual", None)
        if not vision_config:
            # Text-only checkpoint: image inputs will be rejected.
            self.image_transform = None
            return

        side = vision_config["image_size"]
        # Bicubic resize to a square, then normalize with the OpenAI CLIP
        # mean/std constants.
        self.image_transform = transforms.Compose([
            transforms.Resize(
                (side, side),
                interpolation=InterpolationMode.BICUBIC,
            ),
            transforms.ToTensor(),
            transforms.Normalize(
                mean=(0.48145466, 0.4578275, 0.40821073),
                std=(0.26862954, 0.26130258, 0.27577711),
            ),
        ])

    @property
    def image_start_tag(self) -> str:
        return self.tokenizer.image_start_tag  # type: ignore

    @property
    def image_end_tag(self) -> str:
        return self.tokenizer.image_end_tag  # type: ignore

    @property
    def image_pad_tag(self) -> str:
        return self.tokenizer.image_pad_tag  # type: ignore

    def __call__(
        self,
        text: Optional[Union[TextInput, list[TextInput]]] = None,
        images: Optional[Union[ImageInput, list[ImageInput]]] = None,
        return_tensors: Optional[Union[str, TensorType]] = None,
    ) -> BatchFeature:
        # Normalize both inputs to lists.
        if text is None:
            text = []
        elif not isinstance(text, list):
            text = [text]
        if images is None:
            images = []
        elif not isinstance(images, list):
            images = [images]

        text_inputs = self.tokenizer(text)

        image_inputs = {}
        if images:
            if self.image_transform is None:
                raise ValueError("This model does not support image inputs")
            image_inputs["pixel_values"] = torch.stack(
                [self.image_transform(img) for img in images])

        return BatchFeature(
            {
                **text_inputs,
                **image_inputs,
            },
            tensor_type=return_tensors,
        )


class QWenVLProcessingInfo(BaseProcessingInfo):
    """Processing metadata for Qwen-VL: tokenizer, processor, token budgets."""

    def get_tokenizer(self) -> PreTrainedTokenizer:
        """Return the context tokenizer with image-pad insertion patched out."""
        base_tokenizer = self.ctx.tokenizer
        assert isinstance(base_tokenizer, PreTrainedTokenizer)

        return _get_tokenizer_without_image_pad(base_tokenizer)

    def get_hf_processor(self) -> QWenVLProcessor:
        """Build a QWenVLProcessor around the unpatched context tokenizer."""
        base_tokenizer = self.ctx.tokenizer
        assert isinstance(base_tokenizer, PreTrainedTokenizer)

        return QWenVLProcessor(self.get_hf_config(), base_tokenizer)

    def get_supported_mm_limits(self) -> Mapping[str, Optional[int]]:
        # No explicit cap on the number of images.
        return {"image": None}

    def get_mm_max_tokens_per_item(
        self,
        seq_len: int,
        mm_counts: Mapping[str, int],
    ) -> Mapping[str, int]:
        return {"image": self.get_num_image_tokens()}

    def get_num_image_tokens(self) -> int:
        """Tokens per image: ``(image_size // patch_size // 2) ** 2``.

        The result is 0 when the config has no (or an empty) ``visual``
        entry.
        """
        vision_config = getattr(self.get_hf_config(), "visual", None)
        if not vision_config:
            return 0

        grid_length = (vision_config["image_size"] //
                       vision_config["patch_size"] // 2)
        return grid_length**2


class QWenVLDummyInputsBuilder(BaseDummyInputsBuilder[QWenVLProcessingInfo]):
    """Builds placeholder (dummy) prompt text and images for Qwen-VL."""

    def get_dummy_processor_inputs(
        self,
        seq_len: int,
        mm_counts: Mapping[str, int],
    ) -> ProcessorInputs:
        """Return dummy processor inputs for ``mm_counts["image"]`` images.

        The prompt contains one ``Picture {i}: {start}{end}`` line (1-based
        ``i``, newline-terminated) per image, where ``start``/``end`` are the
        processor's image start/end tags. The images are square with side
        ``visual["image_size"]``. A config without a ``visual`` section yields
        an empty prompt and no multimodal data. ``seq_len`` is not used.
        """
        hf_config = self.info.get_hf_config()
        if not (vision_config := getattr(hf_config, "visual", None)):
            return ProcessorInputs(prompt_text="", mm_data={})

        processor = self.info.get_hf_processor()
        img_start = processor.image_start_tag
        img_end = processor.image_end_tag

        target_width = target_height = vision_config["image_size"]
        num_images = mm_counts.get("image", 0)

        mm_data = {
            "image":
            self._get_dummy_images(width=target_width,
                                   height=target_height,
                                   num_images=num_images)
        }

        # The "Picture N: <start><end>\n" layout is what the placeholder regex
        # in QWenVLMultiModalProcessor._call_hf_processor looks for.
        return ProcessorInputs(
            prompt_text="".join(f"Picture {i}: {img_start}{img_end}\n"
                                for i in range(1, num_images + 1)),
            mm_data=mm_data,
        )


class QWenVLMultiModalProcessor(BaseMultiModalProcessor[QWenVLProcessingInfo]):
    """Multimodal processor for Qwen-VL ``Picture N: <img>...</img>`` style
    prompts."""

    def _call_hf_processor(
        self,
        prompt: str,
        mm_data: Mapping[str, object],
        mm_kwargs: Mapping[str, object],
    ) -> BatchFeature:
        """Empty every ``Picture N: <img>...</img>`` placeholder (newline
        terminated), then delegate to the base implementation.

        Logs a warning when the number of rewritten placeholders differs from
        the number of images in ``mm_data["images"]``.

        Raises:
            AssertionError: if ``mm_data["images"]`` is present but not a
                list.
        """
        # Drops anything between <img>/</img> tags; encoding with the tokenizer
        # will automatically add the image pads for the context.
        prompt, num_matched_images = re.subn(
            r"(Picture \d*: <img>).*?(<\/img>\n)",
            r"\1\2",
            prompt,
        )

        image_data = mm_data.get("images")
        if image_data is not None:
            assert isinstance(image_data, list)

            num_images = len(image_data)
            if num_matched_images != num_images:
                logger.warning(
                    "Number of matched image placeholders %s doesn't match "
                    "the number of expected images %s; check your placeholder "
                    "formatting.", num_matched_images, num_images)

        return super()._call_hf_processor(
            prompt=prompt,
            mm_data=mm_data,
            mm_kwargs=mm_kwargs,
        )

    def _get_mm_fields_config(
        self,
        hf_inputs: BatchFeature,
        hf_processor_mm_kwargs: Mapping[str, object],
    ) -> Mapping[str, MultiModalFieldConfig]:
        """Both ``pixel_values`` and ``image_embeds`` are batched along the
        ``image`` modality."""
        return dict(
            pixel_values=MultiModalFieldConfig.batched("image"),
            image_embeds=MultiModalFieldConfig.batched("image"),
        )

    def _get_prompt_replacements(
        self,
        mm_items: MultiModalDataItems,
        hf_processor_mm_kwargs: Mapping[str, object],
        out_mm_kwargs: MultiModalKwargs,
    ) -> list[PromptReplacement]:
        """Expand each ``[img_start, img_end]`` token pair into
        ``img_start + [img_pad] * N + img_end``.

        ``N`` comes from ``get_num_image_tokens``, and only the pad tokens are
        marked as image features. Text-only configs (no ``visual``
        attribute) return no replacements.
        """
        hf_config = self.info.get_hf_config()
        if not hasattr(hf_config, "visual"):
            return []

        tokenizer = self.info.get_tokenizer()
        special_tokens: dict[str,
                             int] = tokenizer.special_tokens  # type: ignore

        processor = self.info.get_hf_processor()
        img_start_id = special_tokens[processor.image_start_tag]
        img_end_id = special_tokens[processor.image_end_tag]
        img_pad_id = special_tokens[processor.image_pad_tag]

        num_image_tokens = self.info.get_num_image_tokens()
        image_tokens = [img_pad_id] * num_image_tokens

        return [
            PromptReplacement(
                modality="image",
                target=[img_start_id, img_end_id],
                replacement=PromptReplacementDetails(
                    full=[img_start_id] + image_tokens + [img_end_id],
                    features=image_tokens,
                ),
            )
        ]
894
895


896
class QWenBaseModel(nn.Module, SupportsPP, SupportsLoRA):
    """QWen model body shared by the text-only and vision-language variants.

    Owns the ``QWenModel`` trunk, the LM head, logits processor and sampler,
    and the optional image path: pixel values are encoded by
    ``transformer.visual``, while precomputed image embeddings are used as-is.
    """

    def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
        super().__init__()
        config = vllm_config.model_config.hf_config
        quant_config = vllm_config.quant_config
        multimodal_config = vllm_config.model_config.multimodal_config
        self.config = config
        self.multimodal_config = multimodal_config
        self.quant_config = quant_config

        self.transformer = QWenModel(vllm_config=vllm_config,
                                     prefix=maybe_prefix(
                                         prefix, "transformer"))
        self.lm_head = ParallelLMHead(config.vocab_size,
                                      config.hidden_size,
                                      quant_config=quant_config)
        # Tied checkpoints reuse the token-embedding matrix as the output
        # projection.
        if self.config.tie_word_embeddings:
            self.lm_head.weight = self.transformer.wte.weight
        self.logits_processor = LogitsProcessor(config.vocab_size)
        self.sampler = get_sampler()
        self.make_empty_intermediate_tensors = (
            self.transformer.make_empty_intermediate_tensors)

    def _validate_pixel_values(self, data: torch.Tensor) -> torch.Tensor:
        """Return ``data`` unchanged if its shape is
        ``(batch_size, 3, image_size, image_size)``.

        Raises:
            ValueError: if the dimensions after the first one differ.
        """
        h = w = self.config.visual["image_size"]
        expected_dims = (3, h, w)
        actual_dims = tuple(data.shape[1:])

        if actual_dims != expected_dims:
            expected_expr = ("batch_size", *map(str, expected_dims))
            raise ValueError(
                f"The expected shape of pixel values is {expected_expr}. "
                f"You supplied {tuple(data.shape)}.")

        return data

    def _parse_and_validate_image_input(
            self, **kwargs: object) -> Optional[QwenImageInputs]:
        """Pop ``pixel_values`` / ``image_embeds`` from ``kwargs`` and wrap
        the first one present (pixel values take precedence).

        Returns:
            A ``QwenImagePixelInputs`` or ``QwenImageEmbeddingInputs``, or
            ``None`` when neither key carries a value.

        Raises:
            ValueError: if the supplied value is not a ``torch.Tensor``, or
                pixel values have the wrong shape.
        """
        pixel_values = kwargs.pop("pixel_values", None)
        image_embeds = kwargs.pop("image_embeds", None)

        if pixel_values is not None:
            if not isinstance(pixel_values, torch.Tensor):
                raise ValueError("Incorrect type of pixel values. "
                                 f"Got type: {type(pixel_values)}")

            return QwenImagePixelInputs(
                type="pixel_values",
                data=self._validate_pixel_values(
                    flatten_bn(pixel_values, concat=True)),
            )

        if image_embeds is not None:
            if not isinstance(image_embeds, torch.Tensor):
                raise ValueError("Incorrect type of image embeddings. "
                                 f"Got type: {type(image_embeds)}")

            return QwenImageEmbeddingInputs(
                type="image_embeds",
                data=flatten_bn(image_embeds),
            )

        return None

    def _process_image_input(self,
                             image_input: QwenImageInputs) -> torch.Tensor:
        """Return image embeddings: passed through for ``image_embeds``
        inputs, computed by ``transformer.visual`` for pixel inputs."""
        if image_input["type"] == "image_embeds":
            return image_input["data"]

        assert self.transformer.visual is not None
        return self.transformer.visual(image_input["data"])

    def get_multimodal_embeddings(self, **kwargs) -> Optional[NestedTensors]:
        """Return image embeddings built from ``kwargs``, or ``None`` when no
        image input is present."""
        image_input = self._parse_and_validate_image_input(**kwargs)
        if image_input is None:
            return None

        vision_embeddings = self._process_image_input(image_input)
        return vision_embeddings

    def get_input_embeddings(
        self,
        input_ids: torch.Tensor,
        multimodal_embeddings: Optional[NestedTensors] = None,
    ) -> torch.Tensor:
        """Embed ``input_ids`` and, if given, merge ``multimodal_embeddings``
        in at the visual tower's ``image_pad_id`` positions via
        ``merge_multimodal_embeddings``."""
        inputs_embeds = self.transformer.get_input_embeddings(input_ids)

        if multimodal_embeddings is not None:
            assert self.transformer.visual is not None
            inputs_embeds = merge_multimodal_embeddings(
                input_ids, inputs_embeds, multimodal_embeddings,
                self.transformer.visual.image_pad_id)

        return inputs_embeds

    def forward(
        self,
        input_ids: torch.Tensor,
        positions: torch.Tensor,
        kv_caches: List[torch.Tensor],
        attn_metadata: AttentionMetadata,
        intermediate_tensors: Optional[IntermediateTensors] = None,
        inputs_embeds: Optional[torch.Tensor] = None,
        **kwargs: object,
    ) -> Union[torch.Tensor, IntermediateTensors]:
        """Run the transformer trunk.

        With ``intermediate_tensors`` (presumably a non-first pipeline
        stage) any ``inputs_embeds`` are discarded. Otherwise, when no
        embeddings are supplied, they are built here from ``input_ids`` plus
        any image inputs in ``kwargs``, and ``input_ids`` is then dropped.
        """
        if intermediate_tensors is not None:
            inputs_embeds = None

        # NOTE: In v1, inputs_embeds is always generated at model runner, this
        # condition is for v0 compatibility.
        elif inputs_embeds is None:
            vision_embeddings = self.get_multimodal_embeddings(**kwargs)
            inputs_embeds = self.get_input_embeddings(input_ids,
                                                      vision_embeddings)
            input_ids = None

        hidden_states = self.transformer(input_ids, positions, kv_caches,
                                         attn_metadata, intermediate_tensors,
                                         inputs_embeds)
        return hidden_states

    def compute_logits(
        self,
        hidden_states: torch.Tensor,
        sampling_metadata: SamplingMetadata,
    ) -> Optional[torch.Tensor]:
        """Project hidden states to vocabulary logits through ``lm_head``."""
        logits = self.logits_processor(self.lm_head, hidden_states,
                                       sampling_metadata)
        return logits

    def sample(
        self,
        logits: torch.Tensor,
        sampling_metadata: SamplingMetadata,
    ) -> Optional[SamplerOutput]:
        """Sample next tokens from ``logits`` with the model's sampler."""
        next_tokens = self.sampler(logits, sampling_metadata)
        return next_tokens

    def load_weights(self, weights: Iterable[Tuple[str,
                                                   torch.Tensor]]) -> Set[str]:
        """Load ``(name, tensor)`` checkpoint pairs into this model.

        Checkpoint ``w2`` / ``w1`` weights are loaded into shards 0 / 1 of the
        fused ``gate_up_proj``. ``rotary_emb.inv_freq`` entries, extra GPTQ
        biases with no matching parameter, and parameters missing on this
        pipeline rank are skipped.

        Returns:
            The set of (possibly renamed) parameter names that were loaded.
        """
        stacked_params_mapping = [
            # (param_name, shard_name, shard_id)
            ("gate_up_proj", "w2", 0),
            ("gate_up_proj", "w1", 1),
        ]
        params_dict = dict(self.named_parameters())
        loaded_params: Set[str] = set()
        for name, loaded_weight in weights:
            if "rotary_emb.inv_freq" in name:
                continue
            for (param_name, weight_name, shard_id) in stacked_params_mapping:
                if weight_name not in name:
                    continue
                name = name.replace(weight_name, param_name)
                # Skip loading extra bias for GPTQ models.
                if name.endswith(".bias") and name not in params_dict:
                    continue
                # Skip layers on other devices.
                if is_pp_missing_parameter(name, self):
                    continue
                param = params_dict[name]
                weight_loader = param.weight_loader
                weight_loader(param, loaded_weight, shard_id)
                break
            else:
                # Not a stacked weight (or skipped above): load directly.
                # Skip loading extra bias for GPTQ models.
                if name.endswith(".bias") and name not in params_dict:
                    continue
                # Skip layers on other devices.
                if is_pp_missing_parameter(name, self):
                    continue
                param = params_dict[name]
                weight_loader = getattr(param, "weight_loader",
                                        default_weight_loader)
                weight_loader(param, loaded_weight)
            loaded_params.add(name)
        return loaded_params
1073
1074
1075
1076
1077
1078
1079
1080
1081
1082
1083
1084
1085
1086
1087
1088
1089
1090
1091
1092
1093


class QWenLLM(QWenBaseModel):
    """Text-only QWen variant.

    LoRA may target only the language model's projections.
    """

    # Checkpoint ``w2`` / ``w1`` are fused into ``gate_up_proj``;
    # ``c_attn`` maps to itself (no packing).
    packed_modules_mapping = {
        "c_attn": ["c_attn"],
        "gate_up_proj": ["w2", "w1"],
    }

    # LoRA specific attributes
    supported_lora_modules = ["c_attn", "gate_up_proj", "c_proj"]
    embedding_modules = {}
    embedding_padding_modules = []


1094
class QWenVL(QWenBaseModel, SupportsMultiModal):
    """QWen variant with a visual tower.

    In addition to the language-model projections, LoRA may target
    projections of the visual module (``out_proj``, ``in_proj``, ``c_fc``)
    and of the resampler (``kv_proj``).
    """
    packed_modules_mapping = {
        "c_attn": ["c_attn"],
        "gate_up_proj": [
            "w2",
            "w1",
        ],
    }
    # LoRA specific attributes
    supported_lora_modules = [
        "c_attn",
        "gate_up_proj",
        "c_proj",
        # visual module
        "out_proj",
        "in_proj",
        "c_fc",
        # resampler
        "kv_proj",
    ]

    embedding_modules = {}
    embedding_padding_modules = []

    def get_mm_mapping(self) -> MultiModelKeys:
        """
        Get the module prefix in multimodal models: language model under
        ``transformer.h``, connector ``transformer.visual.attn_pool``, tower
        ``transformer.visual.transformer``.
        """
        return MultiModelKeys.from_string_field(
            language_model="transformer.h",
            connector="transformer.visual.attn_pool",
            tower_model="transformer.visual.transformer")


1128
1129
1130
@MULTIMODAL_REGISTRY.register_processor(QWenVLMultiModalProcessor,
                                        info=QWenVLProcessingInfo,
                                        dummy_inputs=QWenVLDummyInputsBuilder)
class QWenLMHeadModel(QWenBaseModel, SupportsMultiModal, SupportsLoRA):
    """
    QWenLMHeadModel serves both LLM and VL checkpoints, which does not fit
    vLLM's current LoRA integration logic. The two are therefore separated:
    ``__new__`` returns a ``QWenVL`` or a ``QWenLLM`` instance depending on
    the HF config.
    """
    # Ensure that the LoRA support check passes when the class is not
    # initialized, but set all these attributes to empty.
    # These will be updated when an instance class is selected
    packed_modules_mapping = {}
    supported_lora_modules = []
    embedding_modules = {}
    embedding_padding_modules = []

    def __new__(
        cls,
        vllm_config: VllmConfig,
        prefix: str = "",
    ) -> QWenBaseModel:
        """Instantiate the concrete variant chosen by the HF config.

        A config with a ``visual`` attribute selects ``QWenVL``; otherwise
        ``QWenLLM`` is used. The returned object is not an instance of
        ``cls``, so Python does not run ``__init__`` on it a second time.
        """
        config = vllm_config.model_config.hf_config

        # Initialize VL when a vision config is present, else LLM.
        instance_cls = QWenVL if hasattr(config, "visual") else QWenLLM

        # quant_config references base class members,
        # so update values before init is called.
        # Merge in place so existing references see the update. Append list
        # entries only if missing: a plain `+=` would keep appending
        # duplicates every time a model is constructed.
        cls.packed_modules_mapping.update(instance_cls.packed_modules_mapping)
        cls.embedding_modules.update(instance_cls.embedding_modules)
        for module in instance_cls.supported_lora_modules:
            if module not in cls.supported_lora_modules:
                cls.supported_lora_modules.append(module)
        for module in instance_cls.embedding_padding_modules:
            if module not in cls.embedding_padding_modules:
                cls.embedding_padding_modules.append(module)
        return instance_cls(vllm_config=vllm_config, prefix=prefix)