qwen.py 16.1 KB
Newer Older
1
# SPDX-License-Identifier: Apache-2.0
2
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
3

Qing's avatar
Qing committed
4
5
6
7
# Adapted from
# https://huggingface.co/Qwen/Qwen-7B/blob/main/modeling_qwen.py
# Copyright (c) Alibaba Cloud.
# LICENSE: https://huggingface.co/Qwen/Qwen-7B/blob/main/LICENSE
Woosuk Kwon's avatar
Woosuk Kwon committed
8
"""Inference-only QWen model compatible with HuggingFace weights."""
9
import json
10
11
from collections.abc import Iterable
from typing import Any, Optional, Union
12

13
14
import torch
from torch import nn
15
from transformers import PretrainedConfig
Qing's avatar
Qing committed
16

gaoqiong's avatar
gaoqiong committed
17
18
import os
import re
19
from vllm.attention import Attention
gaoqiong's avatar
gaoqiong committed
20

21
from vllm.compilation.decorators import support_torch_compile
22
from vllm.config import CacheConfig, VllmConfig
23
from vllm.distributed import get_pp_group, get_tensor_model_parallel_world_size
24
from vllm.model_executor.layers.activation import SiluAndMul
25
from vllm.model_executor.layers.layernorm import RMSNorm
26
from vllm.model_executor.layers.linear import (MergedColumnParallelLinear,
27
28
                                               QKVParallelLinear,
                                               RowParallelLinear)
29
from vllm.model_executor.layers.logits_processor import LogitsProcessor
30
from vllm.model_executor.layers.quantization import QuantizationConfig
31
from vllm.model_executor.layers.rotary_embedding import get_rope
32
from vllm.model_executor.layers.vocab_parallel_embedding import (
33
    ParallelLMHead, VocabParallelEmbedding)
34
from vllm.model_executor.model_loader.weight_utils import default_weight_loader
35
from vllm.model_executor.sampling_metadata import SamplingMetadata
36
from vllm.sequence import IntermediateTensors
Qing's avatar
Qing committed
37

38
39
from .interfaces import SupportsLoRA, SupportsPP
from .utils import (is_pp_missing_parameter,
40
                    make_empty_intermediate_tensors_factory, make_layers,
41
                    maybe_prefix)
gaoqiong's avatar
gaoqiong committed
42
from vllm import _custom_ops as ops
43
44
45
from vllm.model_executor.utils import pad_weight, gemm_bank_conf


46
class QWenMLP(nn.Module):
    """MLP for the language component of the Qwen model.

    ``gate_up_proj`` is a ``MergedColumnParallelLinear`` producing two
    ``intermediate_size`` projections of the input, which ``SiluAndMul``
    combines; ``c_proj`` (``RowParallelLinear``) maps the result back to
    ``hidden_size``. Neither projection has a bias.
    """

    def __init__(
        self,
        hidden_size: int,
        intermediate_size: int,
        hidden_act: str = "silu",
        quant_config: Optional[QuantizationConfig] = None,
    ):
        """Build the MLP.

        Args:
            hidden_size: Input and output feature size.
            intermediate_size: Width of each of the two merged projections.
            hidden_act: Activation name; only ``"silu"`` is accepted.
            quant_config: Optional quantization config forwarded to both
                linear layers.

        Raises:
            ValueError: If ``hidden_act`` is not ``"silu"``.
        """
        super().__init__()
        self.gate_up_proj = MergedColumnParallelLinear(
            hidden_size, [intermediate_size] * 2,
            bias=False,
            quant_config=quant_config)
        self.c_proj = RowParallelLinear(intermediate_size,
                                        hidden_size,
                                        bias=False,
                                        quant_config=quant_config)
        if hidden_act != "silu":
            raise ValueError(f"Unsupported activation: {hidden_act}. "
                             "Only silu is supported for now.")
        self.act_fn = SiluAndMul()

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Apply the merged gate/up projection, SiLU-and-mul, then c_proj."""
        # The parallel linear layers return a pair; only the first element
        # (the projected tensor) is used here.
        gate_up, _ = self.gate_up_proj(x)
        x = self.act_fn(gate_up)
        x, _ = self.c_proj(x)
        return x


class QWenAttention(nn.Module):
    """Self-attention layer of the Qwen model with rotary position embedding.

    A fused ``c_attn`` QKV projection (with bias) feeds rotary embedding on
    q/k, vLLM's ``Attention`` wrapper, and a bias-free ``c_proj`` output
    projection. Attention heads are divided evenly across tensor-parallel
    ranks.
    """

    def __init__(
        self,
        hidden_size: int,
        num_heads: int,
        max_position_embeddings: int,
        rope_theta: float = 10000,
        rope_scaling: Optional[dict[str, Any]] = None,
        cache_config: Optional[CacheConfig] = None,
        quant_config: Optional[QuantizationConfig] = None,
        prefix: str = "",
    ):
        """Build the attention layer.

        Args:
            hidden_size: Model hidden size; ``head_dim`` is
                ``hidden_size // num_heads``.
            num_heads: Total number of attention heads across all TP ranks.
            max_position_embeddings: Maximum position passed to ``get_rope``.
            rope_theta: Rotary embedding base.
            rope_scaling: Optional rotary scaling config forwarded to
                ``get_rope``.
            cache_config: Optional KV-cache config forwarded to ``Attention``.
            quant_config: Optional quantization config for the projections
                and attention.
            prefix: Module name prefix; the attention op is registered as
                ``f"{prefix}.attn"``.
        """
        super().__init__()
        self.hidden_size = hidden_size
        tp_size = get_tensor_model_parallel_world_size()
        self.total_num_heads = num_heads
        assert self.total_num_heads % tp_size == 0, (
            f"num_heads ({num_heads}) must be divisible by the tensor "
            f"parallel world size ({tp_size})")
        self.num_heads = self.total_num_heads // tp_size
        self.head_dim = hidden_size // self.total_num_heads
        self.c_attn = QKVParallelLinear(
            hidden_size,
            self.head_dim,
            self.total_num_heads,
            bias=True,
            quant_config=quant_config,
        )
        self.c_proj = RowParallelLinear(
            self.total_num_heads * self.head_dim,
            hidden_size,
            bias=False,
            quant_config=quant_config,
        )
        self.scaling = self.head_dim**-0.5

        self.rotary_emb = get_rope(
            self.head_dim,
            rotary_dim=self.head_dim,
            max_position=max_position_embeddings,
            base=rope_theta,
            rope_scaling=rope_scaling,
        )
        self.attn = Attention(self.num_heads,
                              self.head_dim,
                              self.scaling,
                              cache_config=cache_config,
                              quant_config=quant_config,
                              prefix=f"{prefix}.attn")

        # Always define both attributes so callers need not probe for them;
        # quant_method holds the quantization method name or None.
        self.quant_config = quant_config
        self.quant_method = (quant_config.get_name()
                             if quant_config is not None else None)

    def forward(
        self,
        positions: torch.Tensor,
        hidden_states: torch.Tensor,
    ) -> torch.Tensor:
        """Run rotary self-attention over ``hidden_states``.

        Args:
            positions: Position ids consumed by the rotary embedding.
            hidden_states: Input activations.

        Returns:
            The attention output after the ``c_proj`` projection.
        """
        qkv, _ = self.c_attn(hidden_states)
        # No separate KV head count is given to QKVParallelLinear, so the
        # fused output is split into three equal chunks (presumably q, k, v
        # of equal width; confirm against QKVParallelLinear's defaults).
        q, k, v = qkv.chunk(chunks=3, dim=-1)
        q, k = self.rotary_emb(positions, q, k)
        attn_output = self.attn(q, k, v)
        output, _ = self.c_proj(attn_output)
        return output


class QWenBlock(nn.Module):
    """One Qwen decoder layer: ``ln_1`` -> attention -> ``ln_2`` -> MLP.

    The residual stream is passed alongside the hidden states between
    layers; the two-argument ``RMSNorm`` calls return an updated
    ``(hidden_states, residual)`` pair (presumably fusing the residual add
    into the norm; confirm against ``RMSNorm``).
    """

    def __init__(
        self,
        config: PretrainedConfig,
        cache_config: Optional[CacheConfig] = None,
        quant_config: Optional[QuantizationConfig] = None,
        prefix: str = "",
    ):
        """Build the layer from the HF config.

        Args:
            config: HF config; reads ``hidden_size``, ``layer_norm_epsilon``,
                ``num_attention_heads``, ``max_position_embeddings``,
                ``intermediate_size`` and optionally ``rope_theta`` /
                ``rope_scaling``.
            cache_config: Optional KV-cache config forwarded to attention.
            quant_config: Optional quantization config for attention and MLP.
            prefix: Module name prefix; attention lives at
                ``f"{prefix}.attn"``.
        """
        super().__init__()
        self.ln_1 = RMSNorm(config.hidden_size, eps=config.layer_norm_epsilon)

        rope_theta = getattr(config, "rope_theta", 10000)
        rope_scaling = getattr(config, "rope_scaling", None)
        self.attn = QWenAttention(config.hidden_size,
                                  config.num_attention_heads,
                                  config.max_position_embeddings,
                                  rope_theta=rope_theta,
                                  rope_scaling=rope_scaling,
                                  cache_config=cache_config,
                                  quant_config=quant_config,
                                  prefix=f"{prefix}.attn")

        self.ln_2 = RMSNorm(config.hidden_size, eps=config.layer_norm_epsilon)

        # The config's intermediate_size is halved and used as the width of
        # each of the MLP's two merged projections (presumably because the
        # HF value counts both w1 and w2; verify against the checkpoint).
        self.mlp = QWenMLP(config.hidden_size,
                           config.intermediate_size // 2,
                           quant_config=quant_config)

    def forward(
        self,
        positions: torch.Tensor,
        hidden_states: torch.Tensor,
        residual: Optional[torch.Tensor],
    ) -> tuple[torch.Tensor, torch.Tensor]:
        """Apply the layer.

        Args:
            positions: Position ids forwarded to attention.
            hidden_states: Output of the previous layer (or embeddings).
            residual: Residual stream from the previous layer, or ``None``
                for the first layer, in which case ``hidden_states`` seeds it.

        Returns:
            ``(hidden_states, residual)`` for the next layer.
        """
        # Self attention.
        if residual is None:
            residual = hidden_states
            hidden_states = self.ln_1(hidden_states)
        else:
            hidden_states, residual = self.ln_1(hidden_states, residual)
        hidden_states = self.attn(
            positions=positions,
            hidden_states=hidden_states,
        )

        # Fully connected.
        hidden_states, residual = self.ln_2(hidden_states, residual)
        hidden_states = self.mlp(hidden_states)
        return hidden_states, residual


200
@support_torch_compile
class QWenModel(nn.Module):
    """Qwen transformer trunk: token embedding, decoder layers, final norm.

    Supports pipeline parallelism: ``make_layers`` assigns this rank the
    layer range ``[start_layer, end_layer)``. Non-last ranks return
    ``IntermediateTensors`` holding ``hidden_states`` and ``residual``.
    """

    def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
        """Build the trunk from ``vllm_config``.

        Args:
            vllm_config: Provides the HF config, cache config and quant
                config.
            prefix: Module name prefix; layers live under ``f"{prefix}.h"``.
        """
        super().__init__()

        config = vllm_config.model_config.hf_config
        cache_config = vllm_config.cache_config
        quant_config = vllm_config.quant_config

        self.config = config
        self.vocab_size = config.vocab_size

        self.wte = VocabParallelEmbedding(
            config.vocab_size,
            config.hidden_size,
        )
        self.start_layer, self.end_layer, self.h = make_layers(
            config.num_hidden_layers,
            lambda prefix: QWenBlock(
                config, cache_config, quant_config, prefix=prefix),
            prefix=f"{prefix}.h")
        self.ln_f = RMSNorm(config.hidden_size, eps=config.layer_norm_epsilon)
        self.make_empty_intermediate_tensors = (
            make_empty_intermediate_tensors_factory(
                ["hidden_states", "residual"], config.hidden_size))

    def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor:
        """Look up token embeddings for ``input_ids``."""
        return self.wte(input_ids)

    def forward(
        self,
        input_ids: torch.Tensor,
        positions: torch.Tensor,
        intermediate_tensors: Optional[IntermediateTensors],
        inputs_embeds: Optional[torch.Tensor] = None,
    ) -> Union[torch.Tensor, IntermediateTensors]:
        """Run this rank's slice of the decoder stack.

        Args:
            input_ids: Token ids; ignored when ``inputs_embeds`` is given or
                on non-first pipeline ranks.
            positions: Position ids forwarded to every layer.
            intermediate_tensors: Activations from the previous pipeline
                rank; required on every rank except the first.
            inputs_embeds: Optional precomputed embeddings (first rank only).

        Returns:
            Final-normed hidden states on the last pipeline rank, otherwise
            ``IntermediateTensors`` for the next rank.
        """
        if get_pp_group().is_first_rank:
            if inputs_embeds is not None:
                hidden_states = inputs_embeds
            else:
                hidden_states = self.get_input_embeddings(input_ids)
            residual = None
        else:
            assert intermediate_tensors is not None
            hidden_states = intermediate_tensors["hidden_states"]
            residual = intermediate_tensors["residual"]

        for layer in self.h[self.start_layer:self.end_layer]:
            hidden_states, residual = layer(
                positions,
                hidden_states,
                residual,
            )

        if not get_pp_group().is_last_rank:
            return IntermediateTensors({
                "hidden_states": hidden_states,
                "residual": residual
            })

        hidden_states, _ = self.ln_f(hidden_states, residual)
        return hidden_states


263
class QWenBaseModel(nn.Module):
    """Shared Qwen LM wrapper: transformer trunk, LM head and weight loading.

    Subclasses may swap the trunk via ``transformer_type``. When the
    ``LLAMA_NN`` environment variable is ``"1"`` and the model is not
    quantized, selected weights are rewritten with ``ops.trans_w16_gemm``
    after loading (see ``load_weights``).
    """

    # Loaded parameter names whose weights receive the LLAMA_NN rewrite.
    # The names are escaped so "." matches a literal dot rather than any
    # character.
    _LLAMA_NN_WEIGHT_PATTERN = re.compile("|".join(
        re.escape(word) for word in (
            "attn.c_attn.weight",
            "attn.c_proj.weight",
            "mlp.gate_up_proj.weight",
            "mlp.c_proj.weight",
            "lm_head.weight",
        )))

    def __init__(
        self,
        *,
        vllm_config: VllmConfig,
        prefix: str = "",
        transformer_type: type[QWenModel] = QWenModel,
    ) -> None:
        """Build the trunk, LM head and logits processor.

        Args:
            vllm_config: Provides the HF config, quant config and
                multimodal config.
            prefix: Module name prefix; the trunk lives under
                ``maybe_prefix(prefix, "transformer")``.
            transformer_type: Trunk class to instantiate.
        """
        super().__init__()
        config = vllm_config.model_config.hf_config
        quant_config = vllm_config.quant_config
        multimodal_config = vllm_config.model_config.multimodal_config
        self.config = config
        self.multimodal_config = multimodal_config
        self.quant_config = quant_config

        self.transformer = transformer_type(vllm_config=vllm_config,
                                            prefix=maybe_prefix(
                                                prefix, "transformer"))
        self.lm_head = ParallelLMHead(config.vocab_size,
                                      config.hidden_size,
                                      quant_config=quant_config)
        if self.config.tie_word_embeddings:
            self.lm_head.weight = self.transformer.wte.weight
        self.logits_processor = LogitsProcessor(config.vocab_size)
        self.make_empty_intermediate_tensors = (
            self.transformer.make_empty_intermediate_tensors)

        # Quantization method name, or None for an unquantized model.
        self.quant_method = (quant_config.get_name()
                             if quant_config is not None else None)

        # Environment switches; each is enabled by the value "1".
        self.use_llama_nn = os.environ.get('LLAMA_NN') == '1'
        # NOTE(review): the three flags below are not read anywhere in this
        # class; kept as attributes in case other code consults them.
        self.use_gemm_pad = os.environ.get('GEMM_PAD') == '1'
        self.use_fa_pad = os.environ.get('FA_PAD') == '1'
        self.use_awq_pad = os.environ.get('AWQ_PAD') == '1'

    def compute_logits(
        self,
        hidden_states: torch.Tensor,
        sampling_metadata: SamplingMetadata,
    ) -> Optional[torch.Tensor]:
        """Project final hidden states to vocabulary logits via ``lm_head``."""
        logits = self.logits_processor(self.lm_head, hidden_states,
                                       sampling_metadata)
        return logits

    def load_weights(self, weights: Iterable[tuple[str,
                                                   torch.Tensor]]) -> set[str]:
        """Load checkpoint tensors into this module's parameters.

        Checkpoint ``w2``/``w1`` tensors are loaded as shards 0/1 of
        ``gate_up_proj``. ``rotary_emb.inv_freq`` entries, extra biases
        absent from the model (GPTQ), and parameters owned by other pipeline
        ranks are skipped.

        Args:
            weights: Iterable of ``(checkpoint_name, tensor)`` pairs.

        Returns:
            Names of the model parameters that were loaded.
        """
        stacked_params_mapping = [
            # (param_name, shard_name, shard_id)
            ("gate_up_proj", "w2", 0),
            ("gate_up_proj", "w1", 1),
        ]
        params_dict = dict(self.named_parameters())
        loaded_params: set[str] = set()
        for name, loaded_weight in weights:
            if "rotary_emb.inv_freq" in name:
                continue
            for (param_name, weight_name, shard_id) in stacked_params_mapping:
                if weight_name not in name:
                    continue
                name = name.replace(weight_name, param_name)
                # Skip loading extra bias for GPTQ models. `continue` here
                # advances the mapping loop; if nothing breaks, the `else`
                # branch re-checks the (renamed) name and skips it.
                if name.endswith(".bias") and name not in params_dict:
                    continue
                # Skip layers on other devices.
                if is_pp_missing_parameter(name, self):
                    continue
                param = params_dict[name]
                weight_loader = param.weight_loader
                weight_loader(param, loaded_weight, shard_id)
                break
            else:
                # Skip loading extra bias for GPTQ models.
                if name.endswith(".bias") and name not in params_dict:
                    continue
                # Skip layers on other devices.
                if is_pp_missing_parameter(name, self):
                    continue
                param = params_dict[name]
                weight_loader = getattr(param, "weight_loader",
                                        default_weight_loader)
                weight_loader(param, loaded_weight)
            loaded_params.add(name)

        if self.use_llama_nn and self.quant_method is None:
            self._apply_llama_nn_layout(params_dict, loaded_params)
        return loaded_params

    def _apply_llama_nn_layout(self, params_dict: dict[str, nn.Parameter],
                               loaded_params: set[str]) -> None:
        """Rewrite matching loaded weights in place with ``trans_w16_gemm``.

        For each loaded parameter whose name matches
        ``_LLAMA_NN_WEIGHT_PATTERN``, the kernel writes into a zeroed buffer
        of the same shape ``(rows, cols)``. The result is copied back and the
        parameter's data is then viewed as ``(cols, rows)`` (presumably a
        transposed GEMM layout; confirm against the ``trans_w16_gemm``
        kernel).
        """
        for layername in loaded_params:
            if not self._LLAMA_NN_WEIGHT_PATTERN.search(layername):
                continue
            weight = params_dict[layername]
            ori_shape = weight.data.shape
            _weight = torch.zeros_like(weight.data)
            ops.trans_w16_gemm(_weight, weight.data, ori_shape[0],
                               ori_shape[1])
            weight.data.copy_(_weight)
            weight.data = weight.data.reshape(ori_shape[1], -1)
388
389


390
class QWenLMHeadModel(QWenBaseModel, SupportsPP, SupportsLoRA):
    """Text-only Qwen causal language model.

    Refuses configs that declare a ``visual`` section and tells the user how
    to select the vision-language architecture instead.
    """

    # Maps each packed (fused) module to the checkpoint modules it contains.
    packed_modules_mapping = {
        "c_attn": ["c_attn"],
        "gate_up_proj": [
            "w2",
            "w1",
        ],
    }

    def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
        """Build the model.

        Raises:
            RuntimeError: If the HF config has a ``visual`` attribute, i.e.
                it describes the vision-language variant.
        """
        config = vllm_config.model_config.hf_config
        if hasattr(config, "visual"):
            hf_overrides = {
                "architectures": ["QwenVLForConditionalGeneration"]
            }
            raise RuntimeError(
                "The configuration of this model indicates that it supports "
                "vision inputs, but you instantiated the text-only version "
                "of this model. Please use the vision model by setting "
                f"`--hf-overrides '{json.dumps(hf_overrides)}'`")

        super().__init__(vllm_config=vllm_config, prefix=prefix)

    def forward(
        self,
        input_ids: torch.Tensor,
        positions: torch.Tensor,
        intermediate_tensors: Optional[IntermediateTensors] = None,
        inputs_embeds: Optional[torch.Tensor] = None,
    ) -> Union[torch.Tensor, IntermediateTensors]:
        """Delegate to the transformer trunk; logits come from
        ``compute_logits``."""
        hidden_states = self.transformer(input_ids, positions,
                                         intermediate_tensors, inputs_embeds)
        return hidden_states