qwen.py 47.5 KB
Newer Older
1
2
# SPDX-License-Identifier: Apache-2.0

Qing's avatar
Qing committed
3
4
5
6
# Adapted from
# https://huggingface.co/Qwen/Qwen-7B/blob/main/modeling_qwen.py
# Copyright (c) Alibaba Cloud.
# LICENSE: https://huggingface.co/Qwen/Qwen-7B/blob/main/LICENSE
Woosuk Kwon's avatar
Woosuk Kwon committed
7
"""Inference-only QWen model compatible with HuggingFace weights."""
Qing's avatar
Qing committed
8

9
import copy
10
11
import math
import re
12
13
14
15
16
import unicodedata
from functools import lru_cache, partial
from typing import (AbstractSet, Any, Callable, Collection, Dict, Iterable,
                    List, Literal, Mapping, Optional, Set, Tuple, TypedDict,
                    Union)
17

18
19
import torch
from torch import nn
20
21
from torchvision import transforms
from torchvision.transforms import InterpolationMode
22
23
24
25
from transformers import (BatchFeature, PretrainedConfig, PreTrainedTokenizer,
                          TensorType)
from transformers.image_utils import ImageInput
from transformers.tokenization_utils_base import TextInput
Qing's avatar
Qing committed
26

gaoqiong's avatar
gaoqiong committed
27
28
29
import os
import re

30
from vllm.attention import Attention, AttentionMetadata
31
from vllm.compilation.decorators import support_torch_compile
32
from vllm.config import CacheConfig, VllmConfig
33
from vllm.distributed import get_pp_group, get_tensor_model_parallel_world_size
34
35
from vllm.logger import init_logger
from vllm.model_executor.layers.activation import SiluAndMul, get_act_fn
36
from vllm.model_executor.layers.layernorm import RMSNorm
37
38
from vllm.model_executor.layers.linear import (ColumnParallelLinear,
                                               MergedColumnParallelLinear,
39
                                               QKVParallelLinear,
40
                                               ReplicatedLinear,
41
                                               RowParallelLinear)
42
from vllm.model_executor.layers.logits_processor import LogitsProcessor
43
from vllm.model_executor.layers.quantization import QuantizationConfig
44
from vllm.model_executor.layers.resampler import Resampler2, get_abs_pos
45
from vllm.model_executor.layers.rotary_embedding import get_rope
Joe Runde's avatar
Joe Runde committed
46
from vllm.model_executor.layers.sampler import SamplerOutput, get_sampler
47
from vllm.model_executor.layers.vocab_parallel_embedding import (
48
    ParallelLMHead, VocabParallelEmbedding)
49
from vllm.model_executor.model_loader.weight_utils import default_weight_loader
50
from vllm.model_executor.models.module_mapping import MultiModelKeys
51
from vllm.model_executor.sampling_metadata import SamplingMetadata
52
53
54
55
56
57
58
59
60
from vllm.multimodal import MULTIMODAL_REGISTRY
from vllm.multimodal.inputs import (MultiModalFieldConfig, MultiModalKwargs,
                                    NestedTensors)
from vllm.multimodal.parse import MultiModalDataItems
from vllm.multimodal.processing import (BaseMultiModalProcessor,
                                        BaseProcessingInfo, PromptReplacement,
                                        PromptReplacementDetails)
from vllm.multimodal.profiling import BaseDummyInputsBuilder, ProcessorInputs
from vllm.sequence import IntermediateTensors
Qing's avatar
Qing committed
61

62
from .interfaces import SupportsLoRA, SupportsMultiModal, SupportsPP
63
from .utils import (flatten_bn, is_pp_missing_parameter,
64
                    make_empty_intermediate_tensors_factory, make_layers,
65
                    maybe_prefix, merge_multimodal_embeddings)
gaoqiong's avatar
gaoqiong committed
66
from vllm import _custom_ops as ops
67
from vllm.model_executor.utils import pad_weight, gemm_bank_conf
zhuwenwen's avatar
zhuwenwen committed
68
from vllm.utils import W8a8GetCacheJSON
69

70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
logger = init_logger(__name__)

# Prompt tags that delimit an image. Qwen defines a few more special tags
# (e.g. ref, bbox, quad); for now those are not treated as special at encoding
# time. This may change as vLLM's multimodal API evolves.
IMG_START, IMG_END, IMG_PAD = "<img>", "</img>", "<imgpad>"

# Every image is given a fixed context of this many tokens.
MAX_QWEN_IMG_TOKENS = 256

# CLIP per-channel mean / standard deviation used to normalize images.
CLIP_MEAN = (0.48145466, 0.4578275, 0.40821073)
CLIP_STD = (0.26862954, 0.26130258, 0.27577711)


class QwenImagePixelInputs(TypedDict):
    """Image input given as normalized pixel values."""

    # Discriminator tag for this input variant.
    type: Literal["pixel_values"]
    # Shape: (batch_size * num_images, 3, image_size, image_size).
    # image_size is the vision-config value that the normalization transform
    # resizes each image to. Multi-image support is currently only available
    # by passing image embeddings directly.
    data: torch.Tensor


class QwenImageEmbeddingInputs(TypedDict):
    """Image input given as precomputed image embeddings."""

    # Discriminator tag for this input variant.
    type: Literal["image_embeds"]
    # Shape: (batch_size * num_images, 256, hidden_size).
    # hidden_size must match the hidden size of the language-model backbone;
    # it is stored in the model's visual config when one exists.
    data: torch.Tensor


# Either accepted form of image input, discriminated by the "type" key.
QwenImageInputs = Union[QwenImagePixelInputs, QwenImageEmbeddingInputs]


class VisualAttention(nn.Module):
    """Multi-head self-attention layer of the Qwen-VL vision encoder.

    Takes input of shape ``[s, b, h]`` (sequence, batch, hidden) and returns
    output of the same shape. Q, K and V come from one fused ``in_proj``.
    Both projections are ``ReplicatedLinear``, and all ``num_heads`` heads
    are computed locally.

    Args:
        embed_dim: Hidden size ``h`` of the input and output.
        num_heads: Number of attention heads; must divide ``embed_dim``.
        bias: Accepted for interface compatibility; it is not forwarded to
            the projection layers.
        kdim: Key dimension; only ``None`` or ``embed_dim`` is supported.
        vdim: Value dimension; only ``None`` or ``embed_dim`` is supported.
    """

    def __init__(
        self,
        embed_dim: int,
        num_heads: int,
        bias: bool = True,
        kdim: Optional[int] = None,
        vdim: Optional[int] = None,
    ):
        super().__init__()
        self.embed_dim = embed_dim
        self.kdim = kdim if kdim is not None else embed_dim
        self.vdim = vdim if vdim is not None else embed_dim
        self._qkv_same_embed_dim = self.kdim == embed_dim \
            and self.vdim == embed_dim

        self.num_heads = num_heads

        # Per attention head and per partition values. Nothing is sharded
        # here, so a single "partition" holds every head.
        assert embed_dim % num_heads == 0
        self.hidden_size_per_attention_head = embed_dim // num_heads
        self.num_attention_heads_per_partition = num_heads
        self.hidden_size_per_partition = embed_dim

        # Strided (fused) QKV projection; requires K/V dims == embed_dim.
        assert self._qkv_same_embed_dim, \
            'Visual Attention implementation only supports self-attention'
        self.in_proj = ReplicatedLinear(embed_dim, 3 * embed_dim)
        self.out_proj = ReplicatedLinear(embed_dim, embed_dim)
        # Queries are divided by sqrt(head_dim) before the QK^T product.
        self.norm_factor = math.sqrt(self.hidden_size_per_attention_head)

    def forward(
        self,
        x: torch.Tensor,
        attn_mask: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        """Apply self-attention to ``x``.

        Args:
            x: Input of shape ``[sq, b, h]``.
            attn_mask: Optional additive mask. It is added to the raw
                attention scores of shape ``[b * np, sq, sq]``, so it must be
                broadcastable to that shape.

        Returns:
            Tensor of shape ``[sq, b, h]``.
        """
        # query/key/value: [sq, b, h]
        sq, b, _ = x.size()
        mixed_x_layer, _ = self.in_proj(x)

        # [sq, b, (np * 3 * hn)] --> [sq, b, np, 3 * hn]
        new_tensor_shape = mixed_x_layer.size()[:-1] + \
            (self.num_attention_heads_per_partition,
             3 * self.hidden_size_per_attention_head)
        mixed_x_layer = mixed_x_layer.view(*new_tensor_shape)

        # [sq, b, np, 3 * hn] --> 3 [sq, b, np, hn]
        query_layer, key_layer, value_layer = mixed_x_layer.split(
            self.hidden_size_per_attention_head, dim=-1)

        # [sq, b, np, hn] -> [sq, b * np, hn] -> [b * np, sq, hn]
        query_layer = query_layer.view(
            sq, b * self.num_attention_heads_per_partition,
            self.hidden_size_per_attention_head).transpose(0, 1)
        # [sk, b, np, hn] -> [sk, b * np, hn] -> [b * np, sk, hn]
        key_layer = key_layer.view(
            sq, b * self.num_attention_heads_per_partition,
            self.hidden_size_per_attention_head).transpose(0, 1)

        # Scores [b * np, sq, sk]; baddbmm computes attn_mask + q @ k^T.
        q_scaled = query_layer / self.norm_factor
        if attn_mask is not None:
            attention_probs = torch.baddbmm(attn_mask, q_scaled,
                                            key_layer.transpose(-2, -1))
        else:
            attention_probs = torch.bmm(q_scaled, key_layer.transpose(-2, -1))
        attention_probs = attention_probs.softmax(dim=-1)

        # [sk, b, np, hn] -> [sk, b * np, hn] -> [b * np, sk, hn]
        value_layer = value_layer.view(
            sq, b * self.num_attention_heads_per_partition,
            self.hidden_size_per_attention_head).transpose(0, 1)

        # matmul: [b * np, sq, hn]
        context_layer = torch.bmm(attention_probs, value_layer)

        # change view [b, np, sq, hn]
        context_layer = context_layer.view(
            b, self.num_attention_heads_per_partition, sq,
            self.hidden_size_per_attention_head)

        # [b, np, sq, hn] --> [sq, b, np, hn]
        context_layer = context_layer.permute(2, 0, 1, 3).contiguous()

        # [sq, b, np, hn] --> [sq, b, hp]
        new_context_layer_shape = context_layer.size()[:-2] + \
            (self.hidden_size_per_partition,)
        context_layer = context_layer.view(*new_context_layer_shape)

        output, _ = self.out_proj(context_layer)

        return output


class QwenVMLP(nn.Module):
    """MLP for the visual component of the Qwen model.

    ``c_fc`` (column-parallel, with bias) maps ``hidden_size`` to
    ``intermediate_size``, GELU is applied, and ``c_proj`` (row-parallel,
    with bias) maps the result back to ``hidden_size``.
    """

    def __init__(
        self,
        hidden_size: int,
        intermediate_size: int,
        quant_config: Optional[QuantizationConfig] = None,
    ):
        super().__init__()
        self.c_fc = ColumnParallelLinear(hidden_size,
                                         intermediate_size,
                                         bias=True,
                                         quant_config=quant_config)
        self.act_fn = get_act_fn("gelu")
        self.c_proj = RowParallelLinear(
            intermediate_size,
            hidden_size,
            bias=True,
            quant_config=quant_config,
        )

    def forward(self, x):
        # The parallel linear layers return a 2-tuple; only the first
        # element (the output tensor) is used.
        x, _ = self.c_fc(x)
        x = self.act_fn(x)
        x, _ = self.c_proj(x)
        return x


class VisualAttentionBlock(nn.Module):
    """Pre-norm transformer layer of the vision encoder.

    Computes ``x = x + attn(ln_1(x))`` and then ``x = x + mlp(ln_2(x))``.
    ``quant_config`` is forwarded to the MLP only; the attention layer is
    built without it.

    Args:
        d_model: Hidden size of the layer.
        n_head: Number of attention heads.
        mlp_ratio: MLP hidden width as a multiple of ``d_model``.
        norm_layer: Factory taking the feature size and returning a norm.
        quant_config: Optional quantization config for the MLP.
    """

    def __init__(
        self,
        d_model: int,
        n_head: int,
        mlp_ratio: float = 4.0,
        norm_layer: Callable[[int], nn.Module] = nn.LayerNorm,
        quant_config: Optional[QuantizationConfig] = None,
    ):
        super().__init__()

        self.ln_1 = norm_layer(d_model)
        self.ln_2 = norm_layer(d_model)
        # MLP hidden width, truncated to an int.
        mlp_width = int(d_model * mlp_ratio)
        self.attn = VisualAttention(d_model, n_head)
        self.mlp = QwenVMLP(
            hidden_size=d_model,
            intermediate_size=mlp_width,
            quant_config=quant_config,
        )

    def attention(
        self,
        x: torch.Tensor,
        attn_mask: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        """Run self-attention, first casting ``attn_mask`` to ``x.dtype``."""
        attn_mask = attn_mask.to(x.dtype) if attn_mask is not None else None
        return self.attn(x, attn_mask=attn_mask)

    def forward(
        self,
        x: torch.Tensor,
        attn_mask: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        x = x + self.attention(self.ln_1(x), attn_mask=attn_mask)
        x = x + self.mlp(self.ln_2(x))
        return x


class TransformerBlock(nn.Module):
    """Stack of ``layers`` :class:`VisualAttentionBlock` modules of width
    ``width``, applied in sequence with the same optional attention mask.

    Args:
        width: Hidden size of every layer.
        layers: Number of layers.
        heads: Number of attention heads per layer.
        mlp_ratio: MLP hidden width as a multiple of ``width``.
        norm_layer: Factory taking the feature size and returning a norm.
        quant_config: Optional quantization config for the layers' MLPs.
    """

    def __init__(
        self,
        width: int,
        layers: int,
        heads: int,
        mlp_ratio: float = 4.0,
        norm_layer: Callable[[int], nn.Module] = nn.LayerNorm,
        quant_config: Optional[QuantizationConfig] = None,
    ):
        super().__init__()
        self.width = width
        self.layers = layers

        self.resblocks = nn.ModuleList([
            VisualAttentionBlock(width,
                                 heads,
                                 mlp_ratio,
                                 norm_layer=norm_layer,
                                 quant_config=quant_config)
            for _ in range(layers)
        ])

    def get_cast_dtype(self) -> torch.dtype:
        """Return the dtype of the first layer's MLP ``c_fc`` weight.

        NOTE(review): assumes at least one layer, and that ``c_fc`` exposes
        a ``.weight`` attribute under the active quantization method.
        """
        return self.resblocks[0].mlp.c_fc.weight.dtype

    def get_cast_device(self) -> torch.device:
        """Return the device of the first layer's MLP ``c_fc`` weight."""
        return self.resblocks[0].mlp.c_fc.weight.device

    def forward(self,
                x: torch.Tensor,
                attn_mask: Optional[torch.Tensor] = None) -> torch.Tensor:
        for r in self.resblocks:
            x = r(x, attn_mask=attn_mask)
        return x


class VisionTransformer(nn.Module):
    """Vision encoder of Qwen-VL.

    The pipeline has four stages:

    1. A strided convolution splits the image into ``patch_size`` patches.
    2. Absolute positional embeddings are added.
    3. A :class:`TransformerBlock` encodes the patches.
    4. A ``Resampler2`` pools the features, followed by LayerNorm and a
       learned ``output_dim x output_dim`` projection.

    Args:
        image_size: Side length of the (square) input image.
        patch_size: Side length of the (square) patches.
        width: Hidden size of the transformer.
        layers: Number of transformer layers.
        heads: Number of attention heads per layer.
        mlp_ratio: MLP width as a multiple of ``width``.
        n_queries: Resampler query count; the resampler grid size is
            ``int(sqrt(n_queries))``.
        output_dim: Feature size of the encoder output.
        image_start_id: Token id of the image-start tag; the end and pad ids
            are set to ``image_start_id + 1`` and ``image_start_id + 2``.
        quant_config: Forwarded to the transformer layers' MLPs.
        **kwargs: Any other vision-config entries; ignored.
    """

    def __init__(self,
                 image_size: int,
                 patch_size: int,
                 width: int,
                 layers: int,
                 heads: int,
                 mlp_ratio: float,
                 n_queries: int = 256,
                 output_dim: int = 512,
                 image_start_id: int = 151857,
                 quant_config: Optional[QuantizationConfig] = None,
                 **kwargs):
        super().__init__()
        image_height, image_width = self.image_size = (image_size, image_size)
        patch_height, patch_width = self.patch_size = (patch_size, patch_size)
        self.grid_size = (image_height // patch_height,
                          image_width // patch_width)
        self.output_dim = output_dim
        # Non-overlapping patch embedding (kernel == stride == patch_size).
        self.conv1 = nn.Conv2d(in_channels=3,
                               out_channels=width,
                               kernel_size=patch_size,
                               stride=patch_size,
                               bias=False)

        # class embeddings and positional embeddings
        # 256 learned positions; presumably resized to the actual patch grid
        # by get_abs_pos in forward.
        scale = width**-0.5
        self.positional_embedding = nn.Parameter(scale *
                                                 torch.randn(256, width))

        norm_layer = partial(nn.LayerNorm, eps=1e-6)

        self.ln_pre = norm_layer(width)
        self.transformer = TransformerBlock(width,
                                            layers,
                                            heads,
                                            mlp_ratio,
                                            norm_layer=norm_layer,
                                            quant_config=quant_config)

        # Placed on the same device/dtype as the positional embedding.
        self.attn_pool = Resampler2(
            grid_size=int(math.sqrt(n_queries)),
            embed_dim=output_dim,
            num_heads=output_dim // 128,
            kv_dim=width,
            norm_layer=norm_layer,
            adaptive=False,
            do_post_projection=False,
        ).to(
            device=self.positional_embedding.device,
            dtype=self.positional_embedding.dtype,
        )

        self.ln_post = norm_layer(output_dim)
        self.proj = nn.Parameter(
            (output_dim**-0.5) * torch.randn(output_dim, output_dim))

        self.image_start_id = image_start_id
        self.image_end_id = image_start_id + 1
        self.image_pad_id = image_start_id + 2

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Encode a batch of 3-channel images.

        Args:
            x: Images of shape ``[*, 3, H, W]``. They are cast to the
                transformer's dtype and device before use.

        Returns:
            Resampled features whose last dimension is ``output_dim``.
        """
        x = x.to(
            dtype=self.transformer.get_cast_dtype(),
            device=self.transformer.get_cast_device(),
        )

        # to patches
        x = self.conv1(x)  # shape = [*, width, grid, grid]
        x = x.reshape(x.shape[0], x.shape[1],
                      -1)  # shape = [*, width, grid ** 2]
        x = x.permute(0, 2, 1)  # shape = [*, grid ** 2, width]

        x = x + get_abs_pos(self.positional_embedding, int(math.sqrt(
            x.size(1))))

        x = self.ln_pre(x)

        x = x.permute(1, 0, 2)  # NLD -> LND
        x = self.transformer(x)
        x = x.permute(1, 0, 2)  # LND -> NLD

        x = self.attn_pool(x)
        x = self.ln_post(x)
        x = x @ self.proj

        return x

403

404
class QWenMLP(nn.Module):
    """MLP for the language component of the Qwen model, which contains a
    MergedColumnParallelLinear merging 2 outputs via silu activation.

    ``gate_up_proj`` produces two ``intermediate_size`` outputs that
    :class:`SiluAndMul` combines; ``c_proj`` maps the result back to
    ``hidden_size``. Neither linear layer has a bias.

    Raises:
        ValueError: If ``hidden_act`` is anything other than ``"silu"``.
    """

    def __init__(
        self,
        hidden_size: int,
        intermediate_size: int,
        hidden_act: str = "silu",
        quant_config: Optional[QuantizationConfig] = None,
    ):
        super().__init__()
        self.gate_up_proj = MergedColumnParallelLinear(
            hidden_size, [intermediate_size] * 2,
            bias=False,
            quant_config=quant_config)
        self.c_proj = RowParallelLinear(intermediate_size,
                                        hidden_size,
                                        bias=False,
                                        quant_config=quant_config)
        if hidden_act != "silu":
            raise ValueError(f"Unsupported activation: {hidden_act}. "
                             "Only silu is supported for now.")
        self.act_fn = SiluAndMul()

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        gate_up, _ = self.gate_up_proj(x)
        x = self.act_fn(gate_up)
        x, _ = self.c_proj(x)
        return x


class QWenAttention(nn.Module):
    """Self-attention layer of the Qwen language model.

    A fused ``c_attn`` QKV projection (with bias) feeds rotary position
    embeddings and vLLM's :class:`Attention`. ``c_proj`` (no bias) maps the
    result back to ``hidden_size``. Heads are split across the
    tensor-parallel group, so each rank computes
    ``num_heads // tp_world_size`` heads.
    """

    def __init__(
        self,
        hidden_size: int,
        num_heads: int,
        max_position_embeddings: int,
        rope_theta: float = 10000,
        rope_scaling: Optional[Dict[str, Any]] = None,
        cache_config: Optional[CacheConfig] = None,
        quant_config: Optional[QuantizationConfig] = None,
        prefix: str = "",
    ):
        super().__init__()
        self.hidden_size = hidden_size
        tensor_model_parallel_world_size = get_tensor_model_parallel_world_size(
        )
        self.total_num_heads = num_heads
        assert self.total_num_heads % tensor_model_parallel_world_size == 0
        self.num_heads = (self.total_num_heads //
                          tensor_model_parallel_world_size)
        self.head_dim = hidden_size // self.total_num_heads
        self.c_attn = QKVParallelLinear(
            hidden_size,
            self.head_dim,
            self.total_num_heads,
            bias=True,
            quant_config=quant_config,
        )
        self.c_proj = RowParallelLinear(
            self.total_num_heads * self.head_dim,
            hidden_size,
            bias=False,
            quant_config=quant_config,
        )
        self.scaling = self.head_dim**-0.5

        self.rotary_emb = get_rope(
            self.head_dim,
            rotary_dim=self.head_dim,
            max_position=max_position_embeddings,
            base=rope_theta,
            rope_scaling=rope_scaling,
        )
        self.attn = Attention(self.num_heads,
                              self.head_dim,
                              self.scaling,
                              cache_config=cache_config,
                              quant_config=quant_config,
                              prefix=f"{prefix}.attn")

        # Quantization method name (None when unquantized). Together with
        # quant_config, this is presumably read by code outside this class.
        self.quant_method = None
        if quant_config is not None:
            self.quant_method = quant_config.get_name()
            self.quant_config = quant_config

    def forward(
        self,
        positions: torch.Tensor,
        hidden_states: torch.Tensor,
        kv_cache: torch.Tensor,
        attn_metadata: AttentionMetadata,
    ) -> torch.Tensor:
        qkv, _ = self.c_attn(hidden_states)
        # NOTE(review): disabled fork-specific path. With FA_PAD=1 and no
        # quantization, it trimmed 32 trailing columns from the fused QKV
        # output, presumably to undo padding added to c_attn at load time.
        # Confirm that no such padding is active, since chunk(3) below
        # assumes the last dim is exactly 3 equal Q/K/V parts.
        # if os.environ.get('FA_PAD') == '1' and self.quant_method is None:
        #     qkv = qkv[...,:-32]
        q, k, v = qkv.chunk(chunks=3, dim=-1)
        q, k = self.rotary_emb(positions, q, k)
        attn_output = self.attn(q, k, v, kv_cache, attn_metadata)
        output, _ = self.c_proj(attn_output)
        return output


class QWenBlock(nn.Module):
    """Qwen decoder layer: RMSNorm -> attention -> RMSNorm -> MLP.

    The residual stream is threaded through the RMSNorm calls. When called
    with a residual, ``RMSNorm`` returns ``(normalized, new_residual)``.
    Presumably it adds the residual before normalizing (vLLM fused-add
    norm). ``forward`` returns the MLP output together with the residual,
    for the next layer to combine.
    """

    def __init__(
        self,
        config: PretrainedConfig,
        cache_config: Optional[CacheConfig] = None,
        quant_config: Optional[QuantizationConfig] = None,
        prefix: str = "",
    ):
        super().__init__()
        self.ln_1 = RMSNorm(config.hidden_size, eps=config.layer_norm_epsilon)

        rope_theta = getattr(config, "rope_theta", 10000)
        rope_scaling = getattr(config, "rope_scaling", None)
        self.attn = QWenAttention(config.hidden_size,
                                  config.num_attention_heads,
                                  config.max_position_embeddings,
                                  rope_theta=rope_theta,
                                  rope_scaling=rope_scaling,
                                  cache_config=cache_config,
                                  quant_config=quant_config,
                                  prefix=f"{prefix}.attn")

        self.ln_2 = RMSNorm(config.hidden_size, eps=config.layer_norm_epsilon)

        # The config's intermediate_size is halved here; presumably it counts
        # both the gate and up projections together.
        self.mlp = QWenMLP(config.hidden_size,
                           config.intermediate_size // 2,
                           quant_config=quant_config)

    def forward(
        self,
        positions: torch.Tensor,
        hidden_states: torch.Tensor,
        kv_cache: torch.Tensor,
        attn_metadata: AttentionMetadata,
        residual: Optional[torch.Tensor],
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        # Self Attention
        if residual is None:
            # First layer: the raw input becomes the residual stream.
            residual = hidden_states
            hidden_states = self.ln_1(hidden_states)
        else:
            hidden_states, residual = self.ln_1(hidden_states, residual)
        hidden_states = self.attn(
            positions=positions,
            hidden_states=hidden_states,
            kv_cache=kv_cache,
            attn_metadata=attn_metadata,
        )

        # Fully Connected
        hidden_states, residual = self.ln_2(hidden_states, residual)
        hidden_states = self.mlp(hidden_states)
        return hidden_states, residual


564
@support_torch_compile
class QWenModel(nn.Module):
    """Qwen decoder backbone.

    Holds the following modules:

    - ``wte``: the token embedding.
    - ``h``: the decoder layers; this rank runs only the layers with index
      in ``[start_layer, end_layer)``.
    - ``ln_f``: the final RMSNorm.
    - ``visual``: a :class:`VisionTransformer`, created only when the HF
      config has a truthy ``visual`` entry (otherwise ``None``). It is
      built here but not used by :meth:`forward`.
    """

    def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
        super().__init__()

        config = vllm_config.model_config.hf_config
        cache_config = vllm_config.cache_config
        quant_config = vllm_config.quant_config

        self.config = config
        self.vocab_size = config.vocab_size

        self.wte = VocabParallelEmbedding(
            config.vocab_size,
            config.hidden_size,
        )
        self.start_layer, self.end_layer, self.h = make_layers(
            config.num_hidden_layers,
            lambda prefix: QWenBlock(
                config, cache_config, quant_config, prefix=prefix),
            prefix=f"{prefix}.h")
        self.ln_f = RMSNorm(config.hidden_size, eps=config.layer_norm_epsilon)
        # Builds empty "hidden_states"/"residual" tensors, presumably for
        # receiving activations from the previous pipeline stage.
        self.make_empty_intermediate_tensors = (
            make_empty_intermediate_tensors_factory(
                ["hidden_states", "residual"], config.hidden_size))

        if (vision_config := getattr(config, "visual", None)):
            self.visual = VisionTransformer(**vision_config,
                                            quant_config=quant_config)
        else:
            self.visual = None

    def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor:
        """Embed ``input_ids`` with ``wte``."""
        return self.wte(input_ids)

    def forward(
        self,
        input_ids: torch.Tensor,
        positions: torch.Tensor,
        kv_caches: List[torch.Tensor],
        attn_metadata: AttentionMetadata,
        intermediate_tensors: Optional[IntermediateTensors],
        inputs_embeds: Optional[torch.Tensor] = None,
    ) -> Union[torch.Tensor, IntermediateTensors]:
        """Run this rank's decoder layers.

        The first pipeline rank starts from ``inputs_embeds`` when given,
        otherwise from ``wte(input_ids)``. Later ranks start from
        ``intermediate_tensors``. ``kv_caches`` holds one entry per local
        layer.

        Returns:
            On non-last ranks, ``IntermediateTensors`` with
            ``hidden_states`` and ``residual``. On the last rank, the final
            ``ln_f``-normalized hidden states.
        """
        if get_pp_group().is_first_rank:
            if inputs_embeds is not None:
                hidden_states = inputs_embeds
            else:
                hidden_states = self.get_input_embeddings(input_ids)
            residual = None
        else:
            assert intermediate_tensors is not None
            hidden_states = intermediate_tensors["hidden_states"]
            residual = intermediate_tensors["residual"]
        for i in range(self.start_layer, self.end_layer):
            layer = self.h[i]
            hidden_states, residual = layer(
                positions,
                hidden_states,
                kv_caches[i - self.start_layer],
                attn_metadata,
                residual,
            )
        if not get_pp_group().is_last_rank:
            return IntermediateTensors({
                "hidden_states": hidden_states,
                "residual": residual
            })
        hidden_states, _ = self.ln_f(hidden_states, residual)
        return hidden_states


637
def build_normalization_transform(image_size: int) -> transforms.Compose:
    """
    Build a normalization transform which can be applied to one or
    more input images from which we want to extract visual features.

    The transform resizes the image to ``image_size x image_size`` with
    bicubic interpolation, converts it to a tensor, and normalizes it with
    the CLIP per-channel mean and standard deviation.

    Args:
        image_size: size of the image to be processed for visual embeddings.

    Returns:
        Callable transform for normalizing and resizing one RGB image.
    """
    return transforms.Compose([
        transforms.Resize((image_size, image_size),
                          interpolation=InterpolationMode.BICUBIC),
        transforms.ToTensor(),
        transforms.Normalize(mean=CLIP_MEAN, std=CLIP_STD),
    ])


656
657
658
659
660
661
@lru_cache(maxsize=1)
def _get_tokenizer_without_image_pad(
        tokenizer: PreTrainedTokenizer) -> PreTrainedTokenizer:
    """
    The logic of adding image pad tokens should only be applied in
    :class:`QWenVLProcessor`, so they are patched out here.

    Returns a deep copy of ``tokenizer`` whose class is replaced by a
    subclass. The subclass overrides ``tokenize`` and ``_decode`` to call
    the wrapped ``self.tokenizer`` directly. The original tokenizer is left
    unchanged. Results are memoized by argument; with ``maxsize=1`` only
    the most recently requested tokenizer is cached.

    The definition of the wrapped tokenizer can be found here:
    https://huggingface.co/Qwen/Qwen-VL/blob/main/tokenization_qwen.py
    """
    new_tokenizer = copy.deepcopy(tokenizer)

    class TokenizerWithoutImagePad(tokenizer.__class__):  # type: ignore

        def tokenize(
            self,
            text: str,
            allowed_special: Union[AbstractSet[str], str] = "all",
            disallowed_special: Union[Collection[str], str] = (),
            **kwargs,
        ) -> list[Union[bytes, str]]:
            # NFC-normalize, encode to ids with the wrapped encoder, then map
            # each id back to its token via ``self.decoder``. Extra kwargs
            # are ignored.
            text = unicodedata.normalize("NFC", text)

            return [
                self.decoder[t] for t in self.tokenizer.encode(
                    text,
                    allowed_special=allowed_special,
                    disallowed_special=disallowed_special,
                )
            ]

        def _decode(
            self,
            token_ids: Union[int, List[int]],
            skip_special_tokens: bool = False,
            errors: Optional[str] = None,
            **kwargs,
        ) -> str:
            # NOTE(review): ``skip_special_tokens`` is accepted but not
            # applied here; all ids go to the wrapped decoder unchanged.
            if isinstance(token_ids, int):
                token_ids = [token_ids]

            return self.tokenizer.decode(
                token_ids,
                errors=errors or self.errors,
            )

    # Give the patched class a descriptive name derived from the original.
    TokenizerWithoutImagePad.__name__ = \
        f"{tokenizer.__class__.__name__}WithoutImagePad"

    # Swap the class on the copy only.
    new_tokenizer.__class__ = TokenizerWithoutImagePad
    return new_tokenizer


class QWenVLProcessor:
    """
    This model doesn't define its own HF processor,
    so we implement our own one here.

    We call the wrapped tokenizer to automatically insert image pad tokens:
    https://huggingface.co/Qwen/Qwen-VL/blob/main/tokenization_qwen.py#L245

    The image processor is defined here:
    https://huggingface.co/Qwen/Qwen-VL/blob/main/visual.py#L354
    """

    def __init__(
        self,
        config: PretrainedConfig,
        tokenizer: PreTrainedTokenizer,
    ) -> None:
        """
        Args:
            config: HF config. When its ``visual`` entry is truthy, that
                entry's ``image_size`` configures the image transform.
            tokenizer: Qwen tokenizer exposing a ``special_tokens`` mapping
                that contains the ``<img>`` and ``</img>`` tags.
        """
        super().__init__()

        self.config = config
        self.tokenizer = tokenizer

        # A missing, None, or empty ``visual`` entry means text-only. This
        # matches how QWenModel decides whether to build a vision encoder,
        # and avoids a TypeError when ``visual`` is explicitly None.
        vision_config = getattr(self.config, "visual", None)
        if vision_config:
            self.image_transform = build_normalization_transform(
                vision_config["image_size"])
        else:
            self.image_transform = None

        special_tokens: dict[str,
                             int] = tokenizer.special_tokens  # type: ignore
        self.img_start_id = special_tokens[IMG_START]
        self.img_end_id = special_tokens[IMG_END]

    def __call__(
        self,
        text: Optional[Union[TextInput, list[TextInput]]] = None,
        images: Optional[Union[ImageInput, list[ImageInput]]] = None,
        return_tensors: Optional[Union[str, TensorType]] = None,
    ) -> BatchFeature:
        """Tokenize ``text`` and normalize ``images`` into a BatchFeature.

        Single (non-list) values are wrapped in lists. Images are
        transformed individually and stacked into ``pixel_values``.

        Raises:
            ValueError: If images are given but no image transform exists
                because the config has no vision settings.
        """
        if text is None:
            text = []
        if not isinstance(text, list):
            text = [text]
        if images is None:
            images = []
        if not isinstance(images, list):
            images = [images]

        text_inputs = self.tokenizer(text)

        if len(images) == 0:
            image_inputs = {}
        else:
            if self.image_transform is None:
                raise ValueError("This model does not support image inputs")

            pixel_values = [self.image_transform(image) for image in images]
            image_inputs = {"pixel_values": torch.stack(pixel_values)}

        return BatchFeature(
            {
                **text_inputs,
                **image_inputs,
            },
            tensor_type=return_tensors,
        )


class QWenVLProcessingInfo(BaseProcessingInfo):
    """Processing metadata for Qwen-VL: tokenizer, processor and the image
    token budget."""

    def get_tokenizer(self) -> PreTrainedTokenizer:
        """Return the context tokenizer with image-pad insertion patched out
        (via ``_get_tokenizer_without_image_pad``)."""
        tokenizer = self.ctx.tokenizer
        assert isinstance(tokenizer, PreTrainedTokenizer)

        return _get_tokenizer_without_image_pad(tokenizer)

    def get_hf_processor(self) -> QWenVLProcessor:
        """Build a :class:`QWenVLProcessor` around the unpatched context
        tokenizer, so that encoding inserts the image pad tokens."""
        tokenizer = self.ctx.tokenizer
        assert isinstance(tokenizer, PreTrainedTokenizer)

        return QWenVLProcessor(self.get_hf_config(), tokenizer)

    def get_supported_mm_limits(self) -> Mapping[str, Optional[int]]:
        # Images are the only supported modality; None presumably means
        # "no per-prompt limit".
        return {"image": None}

    def get_mm_max_tokens_per_item(
        self,
        seq_len: int,
        mm_counts: Mapping[str, int],
    ) -> Mapping[str, int]:
        # seq_len and mm_counts are ignored: every image has a fixed size.
        return {"image": self.get_num_image_tokens()}

    def get_num_image_tokens(self) -> int:
        """Return the fixed number of tokens each image occupies."""
        return MAX_QWEN_IMG_TOKENS


class QWenVLDummyInputsBuilder(BaseDummyInputsBuilder[QWenVLProcessingInfo]):
    """Builds dummy Qwen-VL processor inputs (prompt text plus images)."""

    def get_dummy_processor_inputs(
        self,
        seq_len: int,
        mm_counts: Mapping[str, int],
    ) -> ProcessorInputs:
        """Create ``mm_counts["image"]`` dummy images and a matching prompt.

        Each image is ``image_size`` x ``image_size`` (from the HF config's
        ``visual`` section), and the prompt contains one
        ``Picture {i}: {IMG_START}{IMG_END}`` line per image, numbered from 1.
        ``seq_len`` is not used.

        Returns:
            ``ProcessorInputs`` with the prompt and image data; an empty
            prompt and no data when the HF config has no ``visual`` section.
        """
        hf_config = self.info.get_hf_config()
        if not hasattr(hf_config, "visual"):
            # Text-only checkpoint: there is nothing multimodal to build.
            return ProcessorInputs(prompt_text="", mm_data={})

        vision_config = hf_config.visual
        target_width = target_height = vision_config["image_size"]
        num_images = mm_counts.get("image", 0)

        mm_data = {
            "image":
            self._get_dummy_images(width=target_width,
                                   height=target_height,
                                   num_images=num_images)
        }

        return ProcessorInputs(
            prompt_text="".join(f"Picture {i}: {IMG_START}{IMG_END}\n"
                                for i in range(1, num_images + 1)),
            mm_data=mm_data,
        )


class QWenVLMultiModalProcessor(BaseMultiModalProcessor[QWenVLProcessingInfo]):
    """Multimodal processor that expands Qwen-VL ``<img></img>`` spans into
    image pad tokens."""

    def _call_hf_processor(
        self,
        prompt: str,
        mm_data: Mapping[str, object],
        mm_kwargs: Mapping[str, object],
    ) -> BatchFeature:
        """Empty every ``Picture N: <img>...</img>`` span in ``prompt`` and
        delegate to the base implementation.

        Logs a warning when the number of emptied spans differs from the
        number of entries in ``mm_data["images"]``.
        """
        # Whatever sits between the <img>/</img> tags is dropped; encoding
        # with the tokenizer will add the image pads for the context.
        cleaned_prompt, placeholder_count = re.subn(
            r"(Picture \d*: <img>).*?(<\/img>\n)",
            r"\1\2",
            prompt,
        )

        images = mm_data.get("images")
        if images is not None:
            assert isinstance(images, list)

            expected_count = len(images)
            if placeholder_count != expected_count:
                logger.warning(
                    "Number of matched image placeholders %s doesn't match "
                    "the number of expected images %s; check your placeholder "
                    "formatting.", placeholder_count, expected_count)

        return super()._call_hf_processor(
            prompt=cleaned_prompt,
            mm_data=mm_data,
            mm_kwargs=mm_kwargs,
        )

    def _get_mm_fields_config(
        self,
        hf_inputs: BatchFeature,
        hf_processor_mm_kwargs: Mapping[str, object],
    ) -> Mapping[str, MultiModalFieldConfig]:
        """Declare ``pixel_values`` and ``image_embeds`` as batched per
        image."""
        return {
            "pixel_values": MultiModalFieldConfig.batched("image"),
            "image_embeds": MultiModalFieldConfig.batched("image"),
        }

    def _get_prompt_replacements(
        self,
        mm_items: MultiModalDataItems,
        hf_processor_mm_kwargs: Mapping[str, object],
        out_mm_kwargs: MultiModalKwargs,
    ) -> list[PromptReplacement]:
        """Replace each ``IMG_START IMG_END`` token pair with the start token,
        a fixed-length run of ``IMG_PAD`` tokens, and the end token.

        The pad run is also reported as the feature span.
        """
        tokenizer = self.info.get_tokenizer()
        token_ids: dict[str, int] = tokenizer.special_tokens  # type: ignore
        start_id, end_id, pad_id = (token_ids[tok]
                                    for tok in (IMG_START, IMG_END, IMG_PAD))

        pad_run = [pad_id] * self.info.get_num_image_tokens()
        details = PromptReplacementDetails(
            full=[start_id, *pad_run, end_id],
            features=pad_run,
        )

        return [
            PromptReplacement(modality="image",
                              target=[start_id, end_id],
                              replacement=details)
        ]
905
906


907
class QWenBaseModel(nn.Module, SupportsPP, SupportsLoRA):
    """Shared implementation of the text-only and vision-language Qwen LMs.

    Wraps ``QWenModel`` (``self.transformer``) with an LM head, a logits
    processor and a sampler. After loading, it re-lays out weights for the
    custom ``ops`` kernels, depending on the quantization method and on the
    ``LLAMA_NN`` / ``AWQ_PAD`` / ``W8A8_SUPPORT_METHODS`` environment
    variables.
    """

    def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
        super().__init__()
        config = vllm_config.model_config.hf_config
        quant_config = vllm_config.quant_config
        multimodal_config = vllm_config.model_config.multimodal_config
        self.config = config
        self.multimodal_config = multimodal_config
        self.quant_config = quant_config

        self.transformer = QWenModel(vllm_config=vllm_config,
                                     prefix=maybe_prefix(
                                         prefix, "transformer"))
        self.lm_head = ParallelLMHead(config.vocab_size,
                                      config.hidden_size,
                                      quant_config=quant_config)
        if self.config.tie_word_embeddings:
            self.lm_head.weight = self.transformer.wte.weight
        self.logits_processor = LogitsProcessor(config.vocab_size)
        self.sampler = get_sampler()
        self.make_empty_intermediate_tensors = (
            self.transformer.make_empty_intermediate_tensors)

        # Name reported by the quantization config (e.g. "awq"), or None for
        # unquantized models; selects the post-load repacking path.
        self.quant_method = (quant_config.get_name()
                             if quant_config is not None else None)
        self.tritonsingleton = W8a8GetCacheJSON()

        # Feature switches, read from the environment once at construction.
        self.use_llama_nn = os.environ.get('LLAMA_NN') == '1'
        self.use_gemm_pad = os.environ.get('GEMM_PAD') == '1'
        self.use_fa_pad = os.environ.get('FA_PAD') == '1'
        self.use_awq_pad = os.environ.get('AWQ_PAD') == '1'
        self.w8a8_strategy = int(os.getenv('W8A8_SUPPORT_METHODS', '1'))

    def _validate_pixel_values(self, data: torch.Tensor) -> torch.Tensor:
        """Check that every image in ``data`` has shape
        ``(3, image_size, image_size)``.

        Raises:
            ValueError: If ``data.shape[1:]`` differs from the expected shape.
        """
        h = w = self.config.visual["image_size"]
        expected_dims = (3, h, w)
        actual_dims = tuple(data.shape[1:])

        if actual_dims != expected_dims:
            expected_expr = ("batch_size", *map(str, expected_dims))
            raise ValueError(
                f"The expected shape of pixel values is {expected_expr}. "
                f"You supplied {tuple(data.shape)}.")

        return data

    def _parse_and_validate_image_input(
            self, **kwargs: object) -> Optional[QwenImageInputs]:
        """Pop ``pixel_values`` / ``image_embeds`` from ``kwargs`` and wrap
        them in the matching typed dict.

        ``pixel_values`` takes precedence when both are present. Returns
        ``None`` when neither is given.

        Raises:
            ValueError: If the supplied value is not a ``torch.Tensor`` or has
                the wrong per-image shape.
        """
        pixel_values = kwargs.pop("pixel_values", None)
        image_embeds = kwargs.pop("image_embeds", None)

        if pixel_values is not None:
            if not isinstance(pixel_values, torch.Tensor):
                raise ValueError("Incorrect type of pixel values. "
                                 f"Got type: {type(pixel_values)}")

            return QwenImagePixelInputs(
                type="pixel_values",
                data=self._validate_pixel_values(
                    flatten_bn(pixel_values, concat=True)),
            )

        if image_embeds is not None:
            if not isinstance(image_embeds, torch.Tensor):
                raise ValueError("Incorrect type of image embeddings. "
                                 f"Got type: {type(image_embeds)}")

            return QwenImageEmbeddingInputs(
                type="image_embeds",
                data=flatten_bn(image_embeds),
            )

        return None

    def _process_image_input(self,
                             image_input: QwenImageInputs) -> torch.Tensor:
        """Return precomputed embeddings as-is, or run pixel values through
        the visual encoder."""
        if image_input["type"] == "image_embeds":
            return image_input["data"]

        assert self.transformer.visual is not None
        return self.transformer.visual(image_input["data"])

    def get_multimodal_embeddings(self, **kwargs) -> Optional[NestedTensors]:
        """Return image embeddings for the image inputs in ``kwargs``, or
        ``None`` when there are none."""
        image_input = self._parse_and_validate_image_input(**kwargs)
        if image_input is None:
            return None

        vision_embeddings = self._process_image_input(image_input)
        return vision_embeddings

    def get_input_embeddings(
        self,
        input_ids: torch.Tensor,
        multimodal_embeddings: Optional[NestedTensors] = None,
    ) -> torch.Tensor:
        """Embed ``input_ids`` and merge in ``multimodal_embeddings``, using
        the visual module's ``image_pad_id`` as the placeholder token id."""
        inputs_embeds = self.transformer.get_input_embeddings(input_ids)

        if multimodal_embeddings is not None:
            assert self.transformer.visual is not None
            inputs_embeds = merge_multimodal_embeddings(
                input_ids, inputs_embeds, multimodal_embeddings,
                self.transformer.visual.image_pad_id)

        return inputs_embeds

    def forward(
        self,
        input_ids: torch.Tensor,
        positions: torch.Tensor,
        kv_caches: List[torch.Tensor],
        attn_metadata: AttentionMetadata,
        intermediate_tensors: Optional[IntermediateTensors] = None,
        inputs_embeds: Optional[torch.Tensor] = None,
        **kwargs: object,
    ) -> Union[torch.Tensor, IntermediateTensors]:
        """Run the transformer and return its hidden states.

        On non-first pipeline stages (``intermediate_tensors`` given) any
        ``inputs_embeds`` are discarded. Otherwise, if no embeddings were
        passed in, they are computed here, including image embeddings parsed
        from ``kwargs``.
        """
        if intermediate_tensors is not None:
            inputs_embeds = None
        elif inputs_embeds is None:
            # NOTE: In v1, inputs_embeds is always generated at model runner,
            # this condition is for v0 compatibility.
            vision_embeddings = self.get_multimodal_embeddings(**kwargs)
            inputs_embeds = self.get_input_embeddings(input_ids,
                                                      vision_embeddings)
            input_ids = None

        hidden_states = self.transformer(input_ids, positions, kv_caches,
                                         attn_metadata, intermediate_tensors,
                                         inputs_embeds)
        return hidden_states

    def compute_logits(
        self,
        hidden_states: torch.Tensor,
        sampling_metadata: SamplingMetadata,
    ) -> Optional[torch.Tensor]:
        """Project hidden states to vocabulary logits through ``lm_head``."""
        logits = self.logits_processor(self.lm_head, hidden_states,
                                       sampling_metadata)
        return logits

    def sample(
        self,
        logits: torch.Tensor,
        sampling_metadata: SamplingMetadata,
    ) -> Optional[SamplerOutput]:
        """Sample next tokens from ``logits`` with ``self.sampler``."""
        next_tokens = self.sampler(logits, sampling_metadata)
        return next_tokens

    def load_weights(self, weights: Iterable[Tuple[str,
                                                   torch.Tensor]]) -> Set[str]:
        """Load checkpoint ``weights`` and re-lay them out for the kernels
        selected at construction time.

        The checkpoint's ``w2``/``w1`` tensors are loaded as shards 0/1 of the
        fused ``gate_up_proj`` parameter.

        Returns:
            The names of the parameters that were loaded.
        """
        stacked_params_mapping = [
            # (param_name, shard_name, shard_id)
            ("gate_up_proj", "w2", 0),
            ("gate_up_proj", "w1", 1),
        ]
        params_dict = dict(self.named_parameters())
        loaded_params: Set[str] = set()
        for name, loaded_weight in weights:
            if "rotary_emb.inv_freq" in name:
                continue
            for (param_name, weight_name, shard_id) in stacked_params_mapping:
                if weight_name not in name:
                    continue
                name = name.replace(weight_name, param_name)
                # A `continue` below advances to the next mapping entry; when
                # no entry breaks, the `else` clause re-runs the same skip
                # checks on the renamed name and skips the weight.
                # Skip loading extra bias for GPTQ models.
                if name.endswith(".bias") and name not in params_dict:
                    continue
                # Skip layers on other devices.
                if is_pp_missing_parameter(name, self):
                    continue
                param = params_dict[name]
                weight_loader = param.weight_loader
                weight_loader(param, loaded_weight, shard_id)
                break
            else:
                # Skip loading extra bias for GPTQ models.
                if name.endswith(".bias") and name not in params_dict:
                    continue
                # Skip layers on other devices.
                if is_pp_missing_parameter(name, self):
                    continue
                param = params_dict[name]
                weight_loader = getattr(param, "weight_loader",
                                        default_weight_loader)
                weight_loader(param, loaded_weight)
            loaded_params.add(name)

        if self.use_llama_nn and self.quant_method is None:
            self._repack_llama_nn_weights(params_dict, loaded_params)

        if self.quant_method == "awq":
            os.environ['LM_NN'] = '0'
            self._repack_awq_weights(params_dict, loaded_params)

        if self.quant_method == "compressed_tensors":
            os.environ['LM_NN'] = '0'
            self._prepare_w8a8_weights(params_dict, loaded_params)

        return loaded_params

    def _repack_llama_nn_weights(self, params_dict: Dict[str, nn.Parameter],
                                 loaded_params: Set[str]) -> None:
        """Re-lay out unquantized dense weights for the ``LLAMA_NN`` path.

        Each matching weight is written through ``ops.trans_w16_gemm`` into a
        scratch tensor, copied back, and reshaped to
        ``(original_shape[1], -1)``.
        """
        pattern = re.compile("|".join([
            "attn.c_attn.weight",
            "attn.c_proj.weight",
            "mlp.gate_up_proj.weight",
            "mlp.c_proj.weight",
            "lm_head.weight",
        ]))
        # GEMM / flash-attention weight padding (use_gemm_pad / use_fa_pad)
        # used to be applied here and is currently disabled.
        for layername in loaded_params:
            if not pattern.search(layername):
                continue
            weight = params_dict[layername]
            _weight = torch.zeros_like(weight.data)
            ori_shape = _weight.shape

            ops.trans_w16_gemm(_weight, weight.data, _weight.shape[0],
                               _weight.shape[1])
            weight.data.copy_(_weight)
            weight.data = weight.data.reshape(ori_shape[1], -1)

    def _repack_awq_weights(self, params_dict: Dict[str, nn.Parameter],
                            loaded_params: Set[str]) -> None:
        """Convert AWQ ``qweight``/``qzeros``/``scales`` into the packed
        ``qweight`` + ``zeros_and_scales`` layout used by the custom kernels.

        When ``AWQ_PAD=1`` and ``k`` is a multiple of 4096, both tensors are
        padded with zero columns.
        """
        pattern = re.compile("|".join([
            "attn.c_attn.qweight",
            "attn.c_proj.qweight",
            "mlp.gate_up_proj.qweight",
            "mlp.c_proj.qweight",
        ]))
        pad_group = 2

        for layername in loaded_params:
            if not pattern.search(layername):
                continue
            qweight = params_dict[layername]
            qzeros = params_dict[layername.replace("qweight", "qzeros")]
            scales = params_dict[layername.replace("qweight", "scales")]
            zeros_and_scales = params_dict[layername.replace(
                "qweight", "zeros_and_scales")]

            group_size = self.quant_config.group_size

            dim_n = scales.data.shape[1]
            dim_k = qweight.data.shape[0]

            _qw, _sz = ops.convert_s4(qweight, qzeros, scales,
                                      int(group_size))
            sz = ops.sz_permute(_sz).reshape(-1, dim_n)

            zeros_and_scales.data.copy_(sz)
            qweight.data.copy_(_qw)

            # [k / group_size, n] -> [n, k / group_size]
            zeros_and_scales.data = zeros_and_scales.reshape(dim_n, -1)
            # [k, n / 8] -> [n, k / 8]
            qweight.data = qweight.data.reshape(dim_n, -1)

            if dim_k % 4096 == 0 and self.use_awq_pad:
                zeros_and_scales_pad = torch.zeros(dim_n,
                                                   pad_group,
                                                   dtype=torch.int32).cuda()
                zeros_and_scales.data = torch.cat(
                    (zeros_and_scales.data, zeros_and_scales_pad),
                    dim=1).contiguous()
                qweight_pad = torch.zeros(dim_n,
                                          int(group_size // 4),
                                          dtype=torch.int32).cuda()
                qweight.data = torch.cat((qweight.data, qweight_pad),
                                         dim=1).contiguous()

    def _prepare_w8a8_weights(self, params_dict: Dict[str, nn.Parameter],
                              loaded_params: Set[str]) -> None:
        """Prepare compressed-tensors W8A8 weights for the selected backend.

        For ``W8A8_SUPPORT_METHODS != 1`` (rocBLAS / CUTLASS), each weight is
        re-laid out in place. For strategy 1 (Triton), tuned configs are
        looked up once per distinct ``(n, k)`` weight shape, registered with
        ``self.tritonsingleton`` and warmed up.
        """
        pattern = re.compile("|".join([
            "attn.c_attn.weight",
            "attn.c_proj.weight",
            "mlp.gate_up_proj.weight",
            "mlp.c_proj.weight",
        ]))
        # Distinct (n, k) shapes already queried. `loaded_params` is a set
        # with hash-randomized order, so deduplicating by shape (rather than
        # taking the first few names) is what guarantees every shape is
        # covered.
        seen_shapes: Set[Tuple[int, int]] = set()
        all_json = {}

        for layername in loaded_params:
            # The patterns also match the "...weight_scale" tensors; skip them.
            if not pattern.search(layername) or "scale" in layername:
                continue
            weight_data = params_dict[layername]
            n = weight_data.shape[0]

            if self.w8a8_strategy != 1:
                # rocBLAS and CUTLASS currently need the weight re-laid out;
                # Triton does not.
                _weight = weight_data.T.contiguous().reshape(n, -1)
                weight_data.data.copy_(_weight)
                continue

            k = weight_data.shape[1]
            if (n, k) in seen_shapes:
                continue
            seen_shapes.add((n, k))

            json_file = self.tritonsingleton.get_w8a8json_name(n, k)
            configs_dict = self.tritonsingleton.get_triton_cache(
                json_file, n, k)
            if configs_dict:
                all_json.update(configs_dict)

        if self.w8a8_strategy == 1:
            self.tritonsingleton.triton_json_dict.append(all_json)
            # Warm up every config that was found; keys start with "m_n_k".
            for key, value in all_json.items():
                m, n, k = (int(part) for part in key.split('_')[:3])
                ops.triton_int8_gemm_helper(m=m,
                                            n=n,
                                            k=k,
                                            per_token_act_quant=True,
                                            per_out_channel_weight_quant=True,
                                            use_bias=False,
                                            best_config=value)
1221
1222
1223
1224
1225
1226
1227
1228
1229
1230
1231
1232
1233
1234
1235
1236
1237
1238
1239
1240
1241


class QWenLLM(QWenBaseModel):
    """Text-only Qwen model: the shared base plus text-only LoRA metadata."""

    # Checkpoint tensors fused into each packed parameter (see the
    # w2/w1 -> gate_up_proj mapping used when loading weights).
    packed_modules_mapping = dict(
        c_attn=["c_attn"],
        gate_up_proj=["w2", "w1"],
    )

    # LoRA specific attributes
    supported_lora_modules = ["c_attn", "gate_up_proj", "c_proj"]

    # No embedding modules are LoRA targets for this model.
    embedding_modules = {}
    embedding_padding_modules = []


1242
class QWenVL(QWenBaseModel, SupportsMultiModal):
    """Qwen-VL: the shared base model plus LoRA metadata that also covers the
    vision tower and resampler modules."""

    packed_modules_mapping = {
        "c_attn": ["c_attn"],
        "gate_up_proj": [
            "w2",
            "w1",
        ],
    }
    # LoRA specific attributes
    supported_lora_modules = [
        "c_attn",
        "gate_up_proj",
        "c_proj",
        # visual module
        "out_proj",
        "in_proj",
        "c_fc",
        # resampler
        "kv_proj",
    ]

    embedding_modules = {}
    embedding_padding_modules = []

    def get_mm_mapping(self) -> MultiModelKeys:
        """
        Get the module prefix in multimodal models: language model, connector
        (attention-pool resampler) and vision tower.
        """
        return MultiModelKeys.from_string_field(
            language_model="transformer.h",
            connector="transformer.visual.attn_pool",
            tower_model="transformer.visual.transformer")


1276
1277
1278
@MULTIMODAL_REGISTRY.register_processor(QWenVLMultiModalProcessor,
                                        info=QWenVLProcessingInfo,
                                        dummy_inputs=QWenVLDummyInputsBuilder)
class QWenLMHeadModel(QWenBaseModel, SupportsMultiModal, SupportsLoRA):
    """
    QWenLMHeadModel is applicable not only to LLM but also to VL, which is not
    conducive to the current integration logic of LoRA in vLLM. Therefore, it
    is necessary to separate them: ``__new__`` returns a ``QWenVL`` or
    ``QWenLLM`` instance depending on the HF config.
    """
    # Ensure that the LoRA support check passes when the class is not
    # initialized, but set all these attributes to empty.
    # These will be updated when an instance class is selected
    packed_modules_mapping = {}
    supported_lora_modules = []
    embedding_modules = {}
    embedding_padding_modules = []

    @staticmethod
    def _extend_unique(target: list, items: list) -> None:
        """Append each item of ``items`` not already in ``target``, in place."""
        for item in items:
            if item not in target:
                target.append(item)

    def __new__(
        cls,
        vllm_config: VllmConfig,
        prefix: str = "",
    ) -> QWenBaseModel:
        """Instantiate ``QWenVL`` when the HF config has a ``visual`` section,
        otherwise ``QWenLLM``, after merging that class's LoRA / packing
        metadata into this class's attributes.
        """
        config = vllm_config.model_config.hf_config

        if hasattr(config, "visual"):  # noqa: SIM108
            # Initialize VL
            instance_cls = QWenVL
        else:
            # Initialize LLM
            instance_cls = QWenLLM

        # quant_config references base class members, so update values before
        # init is called. The containers are mutated in place (never rebound)
        # so such references stay valid; list entries are added only once so
        # repeated construction does not accumulate duplicates.
        cls.packed_modules_mapping.update(instance_cls.packed_modules_mapping)
        cls._extend_unique(cls.supported_lora_modules,
                           instance_cls.supported_lora_modules)
        cls.embedding_modules.update(instance_cls.embedding_modules)
        cls._extend_unique(cls.embedding_padding_modules,
                           instance_cls.embedding_padding_modules)
        return instance_cls(vllm_config=vllm_config, prefix=prefix)