qwen.py 14.8 KB
Newer Older
Qing's avatar
Qing committed
1
2
3
4
5
# coding=utf-8
# Adapted from
# https://huggingface.co/Qwen/Qwen-7B/blob/main/modeling_qwen.py
# Copyright (c) Alibaba Cloud.
# LICENSE: https://huggingface.co/Qwen/Qwen-7B/blob/main/LICENSE
Woosuk Kwon's avatar
Woosuk Kwon committed
6
"""Inference-only QWen model compatible with HuggingFace weights."""
7
from typing import Any, Dict, Iterable, List, Optional, Tuple
Qing's avatar
Qing committed
8

9
10
import torch
from torch import nn
11
from transformers import PretrainedConfig
Qing's avatar
Qing committed
12

gaoqiong's avatar
gaoqiong committed
13
14
15
import os
import re

16
from vllm.attention import Attention, AttentionMetadata
17
from vllm.config import CacheConfig
18
from vllm.distributed import get_tensor_model_parallel_world_size
19
from vllm.model_executor.layers.activation import SiluAndMul
20
from vllm.model_executor.layers.layernorm import RMSNorm
21
from vllm.model_executor.layers.linear import (MergedColumnParallelLinear,
22
23
                                               QKVParallelLinear,
                                               RowParallelLinear)
24
from vllm.model_executor.layers.logits_processor import LogitsProcessor
25
26
from vllm.model_executor.layers.quantization.base_config import (
    QuantizationConfig)
27
from vllm.model_executor.layers.rotary_embedding import get_rope
28
29
from vllm.model_executor.layers.sampler import Sampler
from vllm.model_executor.layers.vocab_parallel_embedding import (
30
    ParallelLMHead, VocabParallelEmbedding)
31
from vllm.model_executor.model_loader.weight_utils import default_weight_loader
32
33
from vllm.model_executor.sampling_metadata import SamplingMetadata
from vllm.sequence import SamplerOutput
Qing's avatar
Qing committed
34

gaoqiong's avatar
gaoqiong committed
35
from vllm import _custom_ops as ops
36
37
38
from vllm.model_executor.utils import pad_weight, gemm_bank_conf


39
40
41
42
43
44
45
class QWenMLP(nn.Module):
    """Feed-forward sub-layer of a QWen block.

    The input is projected by a ``MergedColumnParallelLinear`` with two
    output shards of ``intermediate_size`` each, passed through
    ``SiluAndMul``, and projected back to ``hidden_size`` by a
    ``RowParallelLinear``. Neither projection has a bias.

    Args:
        hidden_size: Width of the layer's input and output.
        intermediate_size: Width of each of the two merged shards.
        hidden_act: Activation name; only ``"silu"`` is accepted.
        quant_config: Optional quantization settings forwarded to both
            linear layers.

    Raises:
        ValueError: If ``hidden_act`` is anything other than ``"silu"``.
    """

    def __init__(
        self,
        hidden_size: int,
        intermediate_size: int,
        hidden_act: str = "silu",
        quant_config: Optional[QuantizationConfig] = None,
    ):
        super().__init__()
        # Two equally sized shards fused into a single linear layer.
        merged_output_sizes = [intermediate_size, intermediate_size]
        self.gate_up_proj = MergedColumnParallelLinear(
            hidden_size,
            merged_output_sizes,
            bias=False,
            quant_config=quant_config,
        )
        self.c_proj = RowParallelLinear(
            intermediate_size,
            hidden_size,
            bias=False,
            quant_config=quant_config,
        )
        activation_supported = hidden_act == "silu"
        if not activation_supported:
            raise ValueError(f"Unsupported activation: {hidden_act}. "
                             "Only silu is supported for now.")
        self.act_fn = SiluAndMul()

    def forward(self, x):
        """Run merged projection -> activation -> output projection."""
        # Both linear layers return a tuple; only the first item is used.
        merged, _ = self.gate_up_proj(x)
        projected, _ = self.c_proj(self.act_fn(merged))
        return projected


class QWenAttention(nn.Module):
    """Self-attention sub-layer of a QWen block.

    A fused ``c_attn`` projection produces query, key and value in a single
    tensor. Rotary embeddings are applied to query and key, the
    ``Attention`` layer is run against the KV cache, and ``c_proj`` maps the
    result back to ``hidden_size``.

    Args:
        hidden_size: Model hidden dimension.
        num_heads: Total number of attention heads; must be divisible by the
            tensor-parallel world size.
        max_position_embeddings: Passed to ``get_rope`` as ``max_position``.
        rope_theta: Passed to ``get_rope`` as ``base``.
        rope_scaling: Optional rope-scaling settings passed to ``get_rope``.
        cache_config: Optional cache settings forwarded to ``Attention``.
        quant_config: Optional quantization settings forwarded to the linear
            layers and to ``Attention``.
    """

    def __init__(
        self,
        hidden_size: int,
        num_heads: int,
        max_position_embeddings: int,
        rope_theta: float = 10000,
        rope_scaling: Optional[Dict[str, Any]] = None,
        cache_config: Optional[CacheConfig] = None,
        quant_config: Optional[QuantizationConfig] = None,
    ):
        super().__init__()
        self.hidden_size = hidden_size
        tensor_model_parallel_world_size = get_tensor_model_parallel_world_size(
        )
        self.total_num_heads = num_heads
        assert self.total_num_heads % tensor_model_parallel_world_size == 0
        # Heads handled by this tensor-parallel rank.
        self.num_heads = (self.total_num_heads //
                          tensor_model_parallel_world_size)
        self.head_dim = hidden_size // self.total_num_heads
        self.c_attn = QKVParallelLinear(
            hidden_size,
            self.head_dim,
            self.total_num_heads,
            bias=True,
            quant_config=quant_config,
        )
        self.c_proj = RowParallelLinear(
            self.total_num_heads * self.head_dim,
            hidden_size,
            bias=False,
            quant_config=quant_config,
        )
        self.scaling = self.head_dim**-0.5

        self.rotary_emb = get_rope(
            self.head_dim,
            rotary_dim=self.head_dim,
            max_position=max_position_embeddings,
            base=rope_theta,
            rope_scaling=rope_scaling,
        )
        self.attn = Attention(self.num_heads,
                              self.head_dim,
                              self.scaling,
                              cache_config=cache_config,
                              quant_config=quant_config)

        # The FA_PAD switch is read once here instead of on every forward
        # call; the model class reads the same variable at construction time.
        self._use_fa_pad = os.environ.get('FA_PAD') == '1'
        # Width of the fused q/k/v output this rank needs: three equally
        # sized parts of num_heads * head_dim each.
        self._qkv_width = 3 * self.num_heads * self.head_dim

    def forward(
        self,
        positions: torch.Tensor,
        hidden_states: torch.Tensor,
        kv_cache: torch.Tensor,
        attn_metadata: AttentionMetadata,
    ) -> torch.Tensor:
        """Compute self-attention for ``hidden_states``.

        Args:
            positions: Positions passed to the rotary embedding.
            hidden_states: Input to the fused q/k/v projection.
            kv_cache: Cache tensor handed to the ``Attention`` layer.
            attn_metadata: Metadata handed to the ``Attention`` layer.

        Returns:
            The output of ``c_proj`` applied to the attention result.
        """
        qkv, _ = self.c_attn(hidden_states)
        # With FA_PAD=1 the weight loader may append 32 padding columns to
        # c_attn. Drop trailing columns only when they are really present:
        # an unconditional ``[..., :-32]`` would cut real value columns when
        # the weights were never padded (the loader pads only when LLAMA_NN
        # is enabled as well). For a tensor that carries exactly 32 padding
        # columns this is the same slice as ``[..., :-32]``.
        if self._use_fa_pad and qkv.shape[-1] > self._qkv_width:
            qkv = qkv[..., :self._qkv_width]
        q, k, v = qkv.chunk(chunks=3, dim=-1)
        q, k = self.rotary_emb(positions, q, k)
        attn_output = self.attn(q, k, v, kv_cache, attn_metadata)
        output, _ = self.c_proj(attn_output)
        return output


class QWenBlock(nn.Module):
    """One QWen transformer layer: RMSNorm, attention, RMSNorm, MLP.

    Args:
        config: Model configuration. Reads ``hidden_size``,
            ``layer_norm_epsilon``, ``num_attention_heads``,
            ``max_position_embeddings``, ``intermediate_size`` and, when
            present, ``rope_theta`` / ``rope_scaling``.
        cache_config: Optional cache settings forwarded to the attention
            sub-layer.
        quant_config: Optional quantization settings forwarded to the
            attention and MLP sub-layers.
    """

    def __init__(
        self,
        config: PretrainedConfig,
        cache_config: Optional[CacheConfig] = None,
        quant_config: Optional[QuantizationConfig] = None,
    ):
        super().__init__()
        self.ln_1 = RMSNorm(config.hidden_size, eps=config.layer_norm_epsilon)

        # Rope settings are optional on the config; fall back to defaults.
        self.attn = QWenAttention(
            config.hidden_size,
            config.num_attention_heads,
            config.max_position_embeddings,
            rope_theta=getattr(config, "rope_theta", 10000),
            rope_scaling=getattr(config, "rope_scaling", None),
            cache_config=cache_config,
            quant_config=quant_config,
        )

        self.ln_2 = RMSNorm(config.hidden_size, eps=config.layer_norm_epsilon)

        # The MLP is built with half of the configured intermediate size.
        mlp_width = config.intermediate_size // 2
        self.mlp = QWenMLP(config.hidden_size,
                           mlp_width,
                           quant_config=quant_config)

    def forward(
        self,
        positions: torch.Tensor,
        hidden_states: torch.Tensor,
        kv_cache: torch.Tensor,
        attn_metadata: AttentionMetadata,
        residual: Optional[torch.Tensor],
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """Apply the layer and return ``(hidden_states, residual)``.

        ``residual`` is ``None`` for the first layer; in that case the
        incoming ``hidden_states`` becomes the residual.
        """
        # Attention branch.
        if residual is not None:
            hidden_states, residual = self.ln_1(hidden_states, residual)
        else:
            # The right-hand side is evaluated before either name is rebound,
            # so the residual keeps the un-normalized input.
            residual, hidden_states = hidden_states, self.ln_1(hidden_states)
        attn_out = self.attn(
            positions=positions,
            hidden_states=hidden_states,
            kv_cache=kv_cache,
            attn_metadata=attn_metadata,
        )

        # Feed-forward branch.
        normed, residual = self.ln_2(attn_out, residual)
        return self.mlp(normed), residual


class QWenModel(nn.Module):
    """QWen backbone: token embedding, a stack of blocks, final RMSNorm.

    Args:
        config: Model configuration. Reads ``vocab_size``, ``hidden_size``,
            ``num_hidden_layers`` and ``layer_norm_epsilon``.
        cache_config: Optional cache settings forwarded to every block.
        quant_config: Optional quantization settings forwarded to every
            block.
    """

    def __init__(
        self,
        config: PretrainedConfig,
        cache_config: Optional[CacheConfig] = None,
        quant_config: Optional[QuantizationConfig] = None,
    ):
        super().__init__()
        self.config = config
        self.vocab_size = config.vocab_size

        self.wte = VocabParallelEmbedding(
            config.vocab_size,
            config.hidden_size,
        )
        blocks = [
            QWenBlock(config, cache_config, quant_config)
            for _ in range(config.num_hidden_layers)
        ]
        self.h = nn.ModuleList(blocks)
        self.ln_f = RMSNorm(config.hidden_size, eps=config.layer_norm_epsilon)

    def forward(
        self,
        input_ids: torch.Tensor,
        positions: torch.Tensor,
        kv_caches: List[torch.Tensor],
        attn_metadata: AttentionMetadata,
    ) -> torch.Tensor:
        """Embed ``input_ids`` and run them through every block.

        ``kv_caches`` is indexed by layer position, one entry per block.
        Returns the output of the final norm.
        """
        hidden_states = self.wte(input_ids)
        # No residual exists before the first block.
        residual = None
        for layer_idx, block in enumerate(self.h):
            hidden_states, residual = block(
                positions,
                hidden_states,
                kv_caches[layer_idx],
                attn_metadata,
                residual,
            )
        normed, _ = self.ln_f(hidden_states, residual)
        return normed


class QWenLMHeadModel(nn.Module):
    """QWen causal language model: backbone, LM head, logits and sampling.

    Three environment variables are read once at construction time and
    control weight post-processing in ``load_weights``:

    * ``LLAMA_NN=1``  - rewrite selected linear weights with
      ``ops.trans_w16_gemm`` after loading.
    * ``GEMM_PAD=1``  - pad those weights when ``gemm_bank_conf`` says so.
    * ``FA_PAD=1``    - pad the fused q/k/v weight and bias.

    Args:
        config: Model configuration (reads ``vocab_size``, ``hidden_size``
            and whatever the backbone reads).
        cache_config: Optional cache settings forwarded to the backbone.
        quant_config: Optional quantization settings forwarded to the
            backbone; its ``get_name()`` is stored as ``quant_method``.
    """

    def __init__(
        self,
        config: PretrainedConfig,
        cache_config: Optional[CacheConfig] = None,
        quant_config: Optional[QuantizationConfig] = None,
    ):
        super().__init__()
        self.config = config
        self.quant_config = quant_config
        self.transformer = QWenModel(config, cache_config, quant_config)
        self.lm_head = ParallelLMHead(config.vocab_size, config.hidden_size)
        self.logits_processor = LogitsProcessor(config.vocab_size)
        self.sampler = Sampler()

        # Name of the quantization scheme, or None for unquantized models.
        self.quant_method = None
        if quant_config is not None:
            self.quant_method = quant_config.get_name()

        self.use_llama_nn = os.environ.get('LLAMA_NN') == '1'
        self.use_gemm_pad = os.environ.get('GEMM_PAD') == '1'
        self.use_fa_pad = os.environ.get('FA_PAD') == '1'

    def forward(
        self,
        input_ids: torch.Tensor,
        positions: torch.Tensor,
        kv_caches: List[torch.Tensor],
        attn_metadata: AttentionMetadata,
    ) -> torch.Tensor:
        """Run the backbone and return its hidden states (not logits)."""
        hidden_states = self.transformer(input_ids, positions, kv_caches,
                                         attn_metadata)
        return hidden_states

    def compute_logits(self, hidden_states: torch.Tensor,
                       sampling_metadata: SamplingMetadata) -> torch.Tensor:
        """Turn hidden states into logits using the LM head weight."""
        logits = self.logits_processor(self.lm_head.weight, hidden_states,
                                       sampling_metadata)
        return logits

    def sample(
        self,
        logits: torch.Tensor,
        sampling_metadata: SamplingMetadata,
    ) -> Optional[SamplerOutput]:
        """Sample next tokens from ``logits`` with the model's sampler."""
        next_tokens = self.sampler(logits, sampling_metadata)
        return next_tokens

    def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
        """Load checkpoint tensors into the model's parameters.

        Checkpoint tensors named ``w2`` / ``w1`` are loaded into shards 0 / 1
        of the merged ``gate_up_proj`` parameter. ``rotary_emb.inv_freq``
        entries and bias tensors that have no matching parameter are
        skipped. After loading, optional post-processing runs depending on
        ``LLAMA_NN`` and on the quantization method.

        Args:
            weights: Iterable of ``(name, tensor)`` pairs from a checkpoint.

        Raises:
            KeyError: If a non-bias tensor has no matching parameter.
        """
        stacked_params_mapping = [
            # (param_name, shard_name, shard_id)
            ("gate_up_proj", "w2", 0),
            ("gate_up_proj", "w1", 1),
        ]
        params_dict = dict(self.named_parameters())
        for name, loaded_weight in weights:
            if "rotary_emb.inv_freq" in name:
                continue
            for (param_name, weight_name, shard_id) in stacked_params_mapping:
                if weight_name not in name:
                    continue
                name = name.replace(weight_name, param_name)
                # Skip loading extra bias for GPTQ models.
                if name.endswith(".bias") and name not in params_dict:
                    continue
                param = params_dict[name]
                weight_loader = param.weight_loader
                weight_loader(param, loaded_weight, shard_id)
                break
            else:
                # Not a stacked parameter: load it directly.
                # Skip loading extra bias for GPTQ models.
                if name.endswith(".bias") and name not in params_dict:
                    continue
                param = params_dict[name]
                weight_loader = getattr(param, "weight_loader",
                                        default_weight_loader)
                weight_loader(param, loaded_weight)

        if self.use_llama_nn:
            self._convert_weights_for_llama_nn(params_dict)
        if self.quant_method == "awq":
            self._repack_awq_weights(params_dict)

    def _convert_weights_for_llama_nn(self, params_dict):
        """Pad (optionally) and rewrite the linear weights for LLAMA_NN."""
        gemm_weight_keys = (
            "attn.c_attn.weight",
            "attn.c_proj.weight",
            "mlp.gate_up_proj.weight",
            "mlp.c_proj.weight",
        )
        qkv_weight_key = "attn.c_attn.weight"
        qkv_bias_key = "attn.c_attn.bias"

        # Names are matched as plain substrings. Building a regex from them
        # would treat every "." as a wildcard.
        for layername, weight in params_dict.items():
            if self.use_fa_pad and qkv_bias_key in layername:
                weight.data = pad_weight(weight.data, 32)

            if not any(key in layername for key in gemm_weight_keys):
                continue

            if self.use_gemm_pad and gemm_bank_conf(weight.data.shape[0]):
                weight.data = pad_weight(weight.data, 32)

            # The q/k/v weight is padded for FA_PAD only when the (possibly
            # already padded) first dimension does not satisfy gemm_bank_conf.
            if (self.use_fa_pad and qkv_weight_key in layername
                    and not gemm_bank_conf(weight.data.shape[0])):
                weight.data = pad_weight(weight.data, 32)

            # NOTE(review): presumably trans_w16_gemm writes a transposed /
            # re-laid-out copy of the weight into the scratch tensor - confirm
            # against the custom op. The parameter keeps its storage and is
            # then viewed with its two dimensions swapped.
            scratch = torch.zeros_like(weight.data)
            ori_shape = scratch.shape
            ops.trans_w16_gemm(scratch, weight.data, scratch.shape[0],
                               scratch.shape[1])
            weight.data.copy_(scratch)
            weight.data = weight.data.reshape(ori_shape[1], -1)

    def _repack_awq_weights(self, params_dict):
        """Convert AWQ qweight/qzeros/scales into the layout the kernels use."""
        qweight_keys = (
            "attn.c_attn.qweight",
            "attn.c_proj.qweight",
            "mlp.gate_up_proj.qweight",
            "mlp.c_proj.qweight",
        )
        # Number of padding columns appended to zeros_and_scales.
        pad_group = 2

        for layername, qweight in params_dict.items():
            if not any(key in layername for key in qweight_keys):
                continue

            # Sibling parameters of the same layer.
            qzeros = params_dict[layername.replace("qweight", "qzeros")]
            scales = params_dict[layername.replace("qweight", "scales")]
            zeros_and_scales = params_dict[layername.replace(
                "qweight", "zeros_and_scales")]

            group_size = self.quant_config.group_size

            dim_n = scales.data.shape[1]
            dim_k = qweight.data.shape[0]

            # NOTE(review): semantics of convert_s4 / sz_permute come from
            # the custom-op library and are not visible here.
            _qw, _sz = ops.convert_s4(qweight, qzeros, scales,
                                      int(group_size))
            sz = ops.sz_permute(_sz).reshape(-1, dim_n)

            zeros_and_scales.data.copy_(sz)
            qweight.data.copy_(_qw)

            # Reshape in place:
            #   zeros_and_scales: [k/group_size, n] -> [n, k/group_size]
            #   qweight:          [k, n/8]          -> [n, k/8]
            zeros_and_scales.data = zeros_and_scales.data.reshape(dim_n, -1)
            qweight.data = qweight.data.reshape(dim_n, -1)

            if dim_k % 4096 == 0:
                # Append zero columns to both tensors. The padding is created
                # on the parameter's own device rather than on whatever the
                # current default CUDA device happens to be.
                zeros_and_scales_pad = torch.zeros(
                    dim_n,
                    pad_group,
                    dtype=torch.int32,
                    device=zeros_and_scales.data.device)
                zeros_and_scales.data = torch.cat(
                    (zeros_and_scales.data, zeros_and_scales_pad),
                    dim=1).contiguous()
                qweight_pad = torch.zeros(dim_n,
                                          int(group_size // 4),
                                          dtype=torch.int32,
                                          device=qweight.data.device)
                qweight.data = torch.cat((qweight.data, qweight_pad),
                                         dim=1).contiguous()