qwen.py 17.2 KB
Newer Older
Qing's avatar
Qing committed
1
2
3
4
5
# coding=utf-8
# Adapted from
# https://huggingface.co/Qwen/Qwen-7B/blob/main/modeling_qwen.py
# Copyright (c) Alibaba Cloud.
# LICENSE: https://huggingface.co/Qwen/Qwen-7B/blob/main/LICENSE
Woosuk Kwon's avatar
Woosuk Kwon committed
6
"""Inference-only QWen model compatible with HuggingFace weights."""
7
from typing import Any, Dict, Iterable, List, Optional, Tuple
Qing's avatar
Qing committed
8

9
10
import torch
from torch import nn
11
from transformers import PretrainedConfig
Qing's avatar
Qing committed
12

gaoqiong's avatar
gaoqiong committed
13
14
15
import os
import re

16
from vllm.attention import Attention, AttentionMetadata
17
from vllm.config import CacheConfig
18
from vllm.distributed import get_pp_group, get_tensor_model_parallel_world_size
19
from vllm.model_executor.layers.activation import SiluAndMul
20
from vllm.model_executor.layers.layernorm import RMSNorm
21
from vllm.model_executor.layers.linear import (MergedColumnParallelLinear,
22
23
                                               QKVParallelLinear,
                                               RowParallelLinear)
24
from vllm.model_executor.layers.logits_processor import LogitsProcessor
25
26
from vllm.model_executor.layers.quantization.base_config import (
    QuantizationConfig)
27
from vllm.model_executor.layers.rotary_embedding import get_rope
28
29
from vllm.model_executor.layers.sampler import Sampler
from vllm.model_executor.layers.vocab_parallel_embedding import (
30
    ParallelLMHead, VocabParallelEmbedding)
31
from vllm.model_executor.model_loader.weight_utils import default_weight_loader
32
from vllm.model_executor.sampling_metadata import SamplingMetadata
33
from vllm.sequence import IntermediateTensors, SamplerOutput
34
from vllm.utils import print_warning_once
Qing's avatar
Qing committed
35

gaoqiong's avatar
gaoqiong committed
36
from vllm import _custom_ops as ops
37
from vllm.model_executor.utils import pad_weight, gemm_bank_conf
38
from .utils import is_pp_missing_parameter, make_layers
39
40


41
42
43
44
45
46
47
class QWenMLP(nn.Module):
    """Gated feed-forward sub-layer of a QWen decoder block.

    A single merged projection produces the gate and up halves (each
    ``intermediate_size`` wide), ``SiluAndMul`` combines them, and
    ``c_proj`` maps the result back to ``hidden_size``. Only the ``"silu"``
    activation is accepted.
    """

    def __init__(
        self,
        hidden_size: int,
        intermediate_size: int,
        hidden_act: str = "silu",
        quant_config: Optional[QuantizationConfig] = None,
    ):
        super().__init__()
        # Two output shards of equal width: gate and up.
        shard_sizes = [intermediate_size, intermediate_size]
        self.gate_up_proj = MergedColumnParallelLinear(
            hidden_size,
            shard_sizes,
            bias=False,
            quant_config=quant_config,
        )
        self.c_proj = RowParallelLinear(
            intermediate_size,
            hidden_size,
            bias=False,
            quant_config=quant_config,
        )
        if hidden_act == "silu":
            self.act_fn = SiluAndMul()
        else:
            raise ValueError(f"Unsupported activation: {hidden_act}. "
                             "Only silu is supported for now.")

    def forward(self, x):
        """Apply gate/up projection, SiLU-gating, and the down projection."""
        projected, _bias = self.gate_up_proj(x)
        gated = self.act_fn(projected)
        out, _bias = self.c_proj(gated)
        return out


class QWenAttention(nn.Module):
    """Self-attention sub-layer of a QWen decoder block.

    ``c_attn`` is a fused QKV projection (with bias), ``rotary_emb`` applies
    rotary position embeddings to q and k, ``attn`` runs attention against
    ``kv_cache``, and ``c_proj`` projects the concatenated heads back to
    ``hidden_size``. Heads are split evenly across tensor-parallel ranks.
    """

    # Number of trailing columns removed from the fused QKV output when the
    # FA_PAD path is active. NOTE(review): presumably must equal the pad width
    # passed to pad_weight(..., 32) for c_attn in QWenLMHeadModel.load_weights
    # — confirm against pad_weight's implementation.
    _FA_PAD_COLS = 32

    def __init__(
        self,
        hidden_size: int,
        num_heads: int,
        max_position_embeddings: int,
        rope_theta: float = 10000,
        rope_scaling: Optional[Dict[str, Any]] = None,
        cache_config: Optional[CacheConfig] = None,
        quant_config: Optional[QuantizationConfig] = None,
    ):
        super().__init__()
        self.hidden_size = hidden_size
        tensor_model_parallel_world_size = get_tensor_model_parallel_world_size(
        )
        self.total_num_heads = num_heads
        assert self.total_num_heads % tensor_model_parallel_world_size == 0
        # Heads owned by this tensor-parallel rank.
        self.num_heads = (self.total_num_heads //
                          tensor_model_parallel_world_size)
        self.head_dim = hidden_size // self.total_num_heads
        self.c_attn = QKVParallelLinear(
            hidden_size,
            self.head_dim,
            self.total_num_heads,
            bias=True,
            quant_config=quant_config,
        )
        self.c_proj = RowParallelLinear(
            self.total_num_heads * self.head_dim,
            hidden_size,
            bias=False,
            quant_config=quant_config,
        )
        self.scaling = self.head_dim**-0.5

        self.rotary_emb = get_rope(
            self.head_dim,
            rotary_dim=self.head_dim,
            max_position=max_position_embeddings,
            base=rope_theta,
            rope_scaling=rope_scaling,
        )
        self.attn = Attention(self.num_heads,
                              self.head_dim,
                              self.scaling,
                              cache_config=cache_config,
                              quant_config=quant_config)

        # Always define both attributes so they exist on unquantized models
        # too; quant_method is the quantization scheme name or None.
        self.quant_config = quant_config
        self.quant_method = (quant_config.get_name()
                             if quant_config is not None else None)

        # Read FA_PAD once at construction (as QWenLMHeadModel does) instead
        # of querying the environment on every forward call of every layer.
        self.use_fa_pad = os.environ.get('FA_PAD') == '1'

    def forward(
        self,
        positions: torch.Tensor,
        hidden_states: torch.Tensor,
        kv_cache: torch.Tensor,
        attn_metadata: AttentionMetadata,
    ) -> torch.Tensor:
        """Compute attention output for ``hidden_states``.

        Args:
            positions: token positions used by the rotary embedding.
            hidden_states: input activations for this layer.
            kv_cache: this layer's KV cache tensor.
            attn_metadata: backend attention metadata for the batch.

        Returns:
            The attention output after the ``c_proj`` projection.
        """
        qkv, _ = self.c_attn(hidden_states)
        if self.use_fa_pad and self.quant_method is None:
            # Drop the padding columns so the fused output splits evenly into
            # q, k and v. NOTE(review): load_weights pads c_attn only when
            # LLAMA_NN=1 as well; with FA_PAD=1 and LLAMA_NN unset this slice
            # would remove real columns — confirm the flags are always paired.
            qkv = qkv[..., :-self._FA_PAD_COLS]
        q, k, v = qkv.chunk(chunks=3, dim=-1)
        q, k = self.rotary_emb(positions, q, k)
        attn_output = self.attn(q, k, v, kv_cache, attn_metadata)
        output, _ = self.c_proj(attn_output)
        return output


class QWenBlock(nn.Module):
    """One QWen decoder layer: ln_1 -> attention -> ln_2 -> MLP.

    The residual stream is carried alongside the hidden states. The first
    layer receives ``residual=None`` and seeds the residual from its input;
    later layers pass the residual into ``RMSNorm``, which returns
    ``(normed_hidden_states, updated_residual)``.
    """

    def __init__(
        self,
        config: PretrainedConfig,
        cache_config: Optional[CacheConfig] = None,
        quant_config: Optional[QuantizationConfig] = None,
    ):
        super().__init__()
        width = config.hidden_size
        norm_eps = config.layer_norm_epsilon

        # Submodules are registered in the order ln_1, attn, ln_2, mlp.
        self.ln_1 = RMSNorm(width, eps=norm_eps)
        self.attn = QWenAttention(
            width,
            config.num_attention_heads,
            config.max_position_embeddings,
            rope_theta=getattr(config, "rope_theta", 10000),
            rope_scaling=getattr(config, "rope_scaling", None),
            cache_config=cache_config,
            quant_config=quant_config,
        )
        self.ln_2 = RMSNorm(width, eps=norm_eps)
        # Half of config.intermediate_size per shard, so the merged gate/up
        # projection inside QWenMLP is config.intermediate_size wide in total.
        self.mlp = QWenMLP(
            width,
            config.intermediate_size // 2,
            quant_config=quant_config,
        )

    def forward(
        self,
        positions: torch.Tensor,
        hidden_states: torch.Tensor,
        kv_cache: torch.Tensor,
        attn_metadata: AttentionMetadata,
        residual: Optional[torch.Tensor],
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """Run one decoder layer; returns ``(hidden_states, residual)``."""
        # Pre-attention norm (seed the residual on the first layer).
        if residual is not None:
            hidden_states, residual = self.ln_1(hidden_states, residual)
        else:
            residual = hidden_states
            hidden_states = self.ln_1(hidden_states)

        attended = self.attn(
            positions=positions,
            hidden_states=hidden_states,
            kv_cache=kv_cache,
            attn_metadata=attn_metadata,
        )

        # Pre-MLP norm, then the feed-forward sub-layer.
        normed, residual = self.ln_2(attended, residual)
        return self.mlp(normed), residual


class QWenModel(nn.Module):
    """QWen decoder stack: token embedding, decoder layers, final RMSNorm.

    Pipeline-parallel aware: only the first PP rank embeds ``input_ids``;
    other ranks resume from ``intermediate_tensors``. Every rank except the
    last returns its ``hidden_states``/``residual`` as ``IntermediateTensors``
    instead of applying ``ln_f``.
    """

    def __init__(
        self,
        config: PretrainedConfig,
        cache_config: Optional[CacheConfig] = None,
        quant_config: Optional[QuantizationConfig] = None,
        prefix: str = "",
    ):
        super().__init__()
        self.config = config
        self.vocab_size = config.vocab_size

        self.wte = VocabParallelEmbedding(
            config.vocab_size,
            config.hidden_size,
        )

        def build_block(prefix):
            # make_layers supplies a per-layer prefix; it is not used here.
            return QWenBlock(config, cache_config, quant_config)

        self.start_layer, self.end_layer, self.h = make_layers(
            config.num_hidden_layers,
            build_block,
            prefix=f"{prefix}.h")
        self.ln_f = RMSNorm(config.hidden_size, eps=config.layer_norm_epsilon)

    def forward(
        self,
        input_ids: torch.Tensor,
        positions: torch.Tensor,
        kv_caches: List[torch.Tensor],
        attn_metadata: AttentionMetadata,
        intermediate_tensors: Optional[IntermediateTensors],
    ) -> torch.Tensor:
        """Run this rank's slice of decoder layers.

        ``kv_caches`` is indexed relative to ``self.start_layer``. Returns the
        final normed hidden states on the last PP rank, otherwise an
        ``IntermediateTensors`` with ``hidden_states`` and ``residual``.
        """
        pp_group = get_pp_group()
        if not pp_group.is_first_rank:
            assert intermediate_tensors is not None
            hidden_states = intermediate_tensors["hidden_states"]
            residual = intermediate_tensors["residual"]
        else:
            hidden_states = self.wte(input_ids)
            residual = None

        first = self.start_layer
        for layer_idx in range(first, self.end_layer):
            hidden_states, residual = self.h[layer_idx](
                positions,
                hidden_states,
                kv_caches[layer_idx - first],
                attn_metadata,
                residual,
            )

        if not pp_group.is_last_rank:
            return IntermediateTensors({
                "hidden_states": hidden_states,
                "residual": residual
            })

        normed, _ = self.ln_f(hidden_states, residual)
        return normed


class QWenLMHeadModel(nn.Module):

    def __init__(
        self,
        config: PretrainedConfig,
        cache_config: Optional[CacheConfig] = None,
        quant_config: Optional[QuantizationConfig] = None,
    ):
        """Build the QWen causal-LM: decoder stack, LM head, logits, sampler.

        Args:
            config: HuggingFace model config.
            cache_config: KV-cache configuration forwarded to attention.
            quant_config: optional quantization config; its name selects the
                post-load weight conversion path in ``load_weights``.
        """
        super().__init__()
        self.config = config
        # Assigned exactly once (the original re-assigned it redundantly).
        self.quant_config = quant_config
        self.transformer = QWenModel(config, cache_config, quant_config)
        self.lm_head = ParallelLMHead(config.vocab_size,
                                      config.hidden_size,
                                      quant_config=quant_config)
        self.logits_processor = LogitsProcessor(config.vocab_size)
        self.sampler = Sampler()

        # Quantization scheme name (load_weights checks for e.g. "awq"), or
        # None for unquantized weights.
        self.quant_method = (None if quant_config is None else
                             quant_config.get_name())

        # Environment switches, read once, that control how load_weights
        # rearranges unquantized weights:
        #   LLAMA_NN -> transpose weights with ops.trans_w16_gemm
        #   GEMM_PAD -> pad selected GEMM weights via pad_weight(..., 32)
        #   FA_PAD   -> pad the fused c_attn weight/bias via pad_weight(..., 32)
        self.use_llama_nn = os.environ.get('LLAMA_NN') == '1'
        self.use_gemm_pad = os.environ.get('GEMM_PAD') == '1'
        self.use_fa_pad = os.environ.get('FA_PAD') == '1'

    def forward(
        self,
        input_ids: torch.Tensor,
        positions: torch.Tensor,
        kv_caches: List[torch.Tensor],
        attn_metadata: AttentionMetadata,
        intermediate_tensors: Optional[IntermediateTensors] = None,
    ) -> torch.Tensor:
        """Delegate to the decoder stack and return its output unchanged.

        On non-last pipeline ranks the result is whatever ``QWenModel``
        returns for intermediate ranks (an ``IntermediateTensors``).
        """
        return self.transformer(input_ids, positions, kv_caches,
                                attn_metadata, intermediate_tensors)

291
292
293
294
295
296
297
298
299
300
301
302
303
304
    def make_empty_intermediate_tensors(
            self, batch_size: int, dtype: torch.dtype,
            device: torch.device) -> IntermediateTensors:
        """Allocate zero-filled pipeline buffers for one batch.

        Returns an ``IntermediateTensors`` holding two independent tensors,
        ``hidden_states`` and ``residual``, each of shape
        ``(batch_size, config.hidden_size)`` with the given dtype and device.
        """
        buffer_shape = (batch_size, self.config.hidden_size)
        return IntermediateTensors({
            key: torch.zeros(buffer_shape, dtype=dtype, device=device)
            for key in ("hidden_states", "residual")
        })

305
306
    def compute_logits(self, hidden_states: torch.Tensor,
                       sampling_metadata: SamplingMetadata) -> torch.Tensor:
        """Turn final hidden states into logits via ``lm_head``."""
        return self.logits_processor(self.lm_head, hidden_states,
                                     sampling_metadata)

311
312
    def sample(
        self,
        logits: torch.Tensor,
        sampling_metadata: SamplingMetadata,
    ) -> Optional[SamplerOutput]:
        """Select next tokens from ``logits`` with the model's sampler."""
        return self.sampler(logits, sampling_metadata)
Qing's avatar
Qing committed
318

319
    def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
320
321
322
323
324
325
        stacked_params_mapping = [
            # (param_name, shard_name, shard_id)
            ("gate_up_proj", "w2", 0),
            ("gate_up_proj", "w1", 1),
        ]
        params_dict = dict(self.named_parameters())
326
        for name, loaded_weight in weights:
Qing's avatar
Qing committed
327
328
            if "rotary_emb.inv_freq" in name:
                continue
329
            for (param_name, weight_name, shard_id) in stacked_params_mapping:
Qing's avatar
Qing committed
330
331
                if weight_name not in name:
                    continue
CHU Tianxiang's avatar
CHU Tianxiang committed
332
333
334
335
                name = name.replace(weight_name, param_name)
                # Skip loading extra bias for GPTQ models.
                if name.endswith(".bias") and name not in params_dict:
                    continue
336
337
338
                # Skip layers on other devices.
                if is_pp_missing_parameter(name, self):
                    continue
CHU Tianxiang's avatar
CHU Tianxiang committed
339
                param = params_dict[name]
340
341
                weight_loader = param.weight_loader
                weight_loader(param, loaded_weight, shard_id)
Qing's avatar
Qing committed
342
                break
343
            else:
CHU Tianxiang's avatar
CHU Tianxiang committed
344
345
346
                # Skip loading extra bias for GPTQ models.
                if name.endswith(".bias") and name not in params_dict:
                    continue
347
348
349
350
351
352
353
354
355
                # Skip loading visual weights to support Qwen-VL models
                # in cases with text-only inputs
                # TODO: add support for Qwen-VL
                if (name not in params_dict
                        and name.startswith("transformer.visual.")):
                    print_warning_once(
                        "Only text inputs are allowed. Images won't be handled "
                        "until Qwen-VL models are fully supported.")
                    continue
356
357
358
                # Skip layers on other devices.
                if is_pp_missing_parameter(name, self):
                    continue
359
360
361
362
                param = params_dict[name]
                weight_loader = getattr(param, "weight_loader",
                                        default_weight_loader)
                weight_loader(param, loaded_weight)
363
        if self.use_llama_nn and self.quant_method is None :
gaoqiong's avatar
gaoqiong committed
364
            lay_key_words = [
zhuwenwen's avatar
zhuwenwen committed
365
366
                "attn.c_attn.weight",
                "attn.c_proj.weight",
gaoqiong's avatar
gaoqiong committed
367
                "mlp.gate_up_proj.weight",
368
369
                "mlp.c_proj.weight",
                "lm_head.weight"
gaoqiong's avatar
gaoqiong committed
370
371
372
            ]
            combined_words = "|".join(lay_key_words)
            
zhuwenwen's avatar
zhuwenwen committed
373
374
375
376
377
378
            lay_qkv_words = ["attn.c_attn.weight"]   
            qkv_words = "|".join(lay_qkv_words)  
            
            lay_qkv_bias_words = ["attn.c_attn.bias"]   
            qkv_bias_words = "|".join(lay_qkv_bias_words) 
                      
gaoqiong's avatar
gaoqiong committed
379
            for layername, weight in params_dict.items():
zhuwenwen's avatar
zhuwenwen committed
380
381
382
                if self.use_fa_pad and (re.findall(qkv_bias_words, layername)):
                    weight.data = pad_weight(weight.data, 32)
                
gaoqiong's avatar
gaoqiong committed
383
                matches = re.findall(combined_words, layername)
zhuwenwen's avatar
zhuwenwen committed
384
                if matches:         
385
386
387
                    if self.use_gemm_pad and gemm_bank_conf(weight.data.shape[0]):
                        weight.data = pad_weight(weight.data, 32)  
                        
zhuwenwen's avatar
zhuwenwen committed
388
389
390
                    if self.use_fa_pad and (re.findall(qkv_words, layername)):
                        if not gemm_bank_conf(weight.data.shape[0]):
                            weight.data = pad_weight(weight.data, 32)
391
                        
gaoqiong's avatar
gaoqiong committed
392
393
394
                    _weight = torch.zeros_like(weight.data)
                    ori_shape =_weight.shape
                    
zhuwenwen's avatar
zhuwenwen committed
395
                    ops.trans_w16_gemm(_weight, weight.data, _weight.shape[0], _weight.shape[1])
gaoqiong's avatar
gaoqiong committed
396
397
398
399
                    weight.data.copy_(_weight)
                    
                    weight.data=weight.data.reshape(ori_shape[1],-1)
                    
zhuwenwen's avatar
zhuwenwen committed
400
        if self.quant_method == "awq":
gaoqiong's avatar
gaoqiong committed
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
            lay_key_words = [
                "attn.c_attn.qweight",
                "attn.c_proj.qweight",
                "mlp.gate_up_proj.qweight",
                "mlp.c_proj.qweight"
            ]
            combined_words = "|".join(lay_key_words)
            
            for layername, weight in params_dict.items():
                
                matches = re.findall(combined_words, layername)
                if matches:
                    qweight =params_dict[layername]
                    qzeros=params_dict[layername.replace("qweight", "qzeros")]
                    scales=params_dict[layername.replace("qweight", "scales")]
                    zeros_and_scalse =params_dict[layername.replace("qweight", "zeros_and_scales")]
                    
                    group_size= self.quant_config.group_size 
                   
                    dim_n = scales.data.shape[1]
                    dim_k = qweight.data.shape[0]
                    pad_group=2              
                    
gaoqiong's avatar
gaoqiong committed
424
                    _qw, _sz=ops.convert_s4(qweight,qzeros,scales,int(group_size)) 
gaoqiong's avatar
gaoqiong committed
425
                    
gaoqiong's avatar
gaoqiong committed
426
                    sz = ops.sz_permute(_sz).reshape(-1,dim_n)       
gaoqiong's avatar
gaoqiong committed
427
428
429
430
431
432
433
434
435
436
437
438
439
440
                    
                    zeros_and_scalse.data.copy_(sz)
                    qweight.data.copy_(_qw)
                    
                    #reshape
                    zeros_and_scalse.data=zeros_and_scalse.reshape(dim_n,-1)    #[k/greop_size,n]------>[n,k/group_size]
                    qweight.data=qweight.data.reshape(dim_n,-1)                      #[k,n/8]---->[n,k/8]  
                
                    if dim_k % 4096==0:
                        zeros_and_scalse_pad= torch.zeros(dim_n,pad_group,dtype=torch.int32).cuda()
                        zeros_and_scalse.data=torch.cat((zeros_and_scalse.data,zeros_and_scalse_pad),dim=1).contiguous()
                        qweight_pad= torch.zeros(dim_n,int(group_size//4),dtype=torch.int32).cuda()
                        qweight.data=torch.cat((qweight.data,qweight_pad),dim=1).contiguous()