qwen.py 13.2 KB
Newer Older
Qing's avatar
Qing committed
1
2
3
4
5
# coding=utf-8
# Adapted from
# https://huggingface.co/Qwen/Qwen-7B/blob/main/modeling_qwen.py
# Copyright (c) Alibaba Cloud.
# LICENSE: https://huggingface.co/Qwen/Qwen-7B/blob/main/LICENSE
Woosuk Kwon's avatar
Woosuk Kwon committed
6
"""Inference-only QWen model compatible with HuggingFace weights."""
7
from typing import Any, Dict, Iterable, List, Optional, Tuple
Qing's avatar
Qing committed
8

9
10
import torch
from torch import nn
11
from transformers import PretrainedConfig
Qing's avatar
Qing committed
12

gaoqiong's avatar
gaoqiong committed
13
14
15
import os
import re

16
from vllm.attention import Attention, AttentionMetadata
17
from vllm.config import CacheConfig
18
from vllm.distributed import get_tensor_model_parallel_world_size
19
from vllm.model_executor.layers.activation import SiluAndMul
20
from vllm.model_executor.layers.layernorm import RMSNorm
21
from vllm.model_executor.layers.linear import (MergedColumnParallelLinear,
22
23
                                               QKVParallelLinear,
                                               RowParallelLinear)
24
from vllm.model_executor.layers.logits_processor import LogitsProcessor
25
26
from vllm.model_executor.layers.quantization.base_config import (
    QuantizationConfig)
27
from vllm.model_executor.layers.rotary_embedding import get_rope
28
29
from vllm.model_executor.layers.sampler import Sampler
from vllm.model_executor.layers.vocab_parallel_embedding import (
30
    ParallelLMHead, VocabParallelEmbedding)
31
from vllm.model_executor.model_loader.weight_utils import default_weight_loader
32
from vllm.model_executor.sampling_metadata import SamplingMetadata
33
from vllm.sequence import IntermediateTensors, SamplerOutput
34
from vllm.utils import print_warning_once
Qing's avatar
Qing committed
35

gaoqiong's avatar
gaoqiong committed
36
from vllm import _custom_ops as ops
37
38
39
from vllm.model_executor.utils import pad_weight, gemm_bank_conf


40
41
42
43
44
45
46
class QWenMLP(nn.Module):
    """Feed-forward sub-layer of a QWen block.

    The input goes through a ``MergedColumnParallelLinear`` producing two
    outputs of width ``intermediate_size``, then through ``SiluAndMul``, and
    finally through a ``RowParallelLinear`` that maps back to
    ``hidden_size``. Neither linear layer uses a bias.

    Args:
        hidden_size: Width of the input and output features.
        intermediate_size: Width of each of the two merged projections.
        hidden_act: Activation name; only ``"silu"`` is accepted.
        quant_config: Optional quantization config forwarded to both
            linear layers.

    Raises:
        ValueError: If ``hidden_act`` is anything other than ``"silu"``.
    """

    def __init__(
        self,
        hidden_size: int,
        intermediate_size: int,
        hidden_act: str = "silu",
        quant_config: Optional[QuantizationConfig] = None,
    ):
        super().__init__()
        # The merged projection has two outputs of identical width.
        merged_output_sizes = [intermediate_size, intermediate_size]
        self.gate_up_proj = MergedColumnParallelLinear(
            hidden_size,
            merged_output_sizes,
            bias=False,
            quant_config=quant_config,
        )
        self.c_proj = RowParallelLinear(
            intermediate_size,
            hidden_size,
            bias=False,
            quant_config=quant_config,
        )
        if hidden_act != "silu":
            raise ValueError(f"Unsupported activation: {hidden_act}. "
                             "Only silu is supported for now.")
        self.act_fn = SiluAndMul()

    def forward(self, x):
        """Apply merged projection, activation and output projection to ``x``.

        Both linear layers return a pair; only the first element is used.
        """
        merged, _ = self.gate_up_proj(x)
        activated = self.act_fn(merged)
        projected, _ = self.c_proj(activated)
        return projected


class QWenAttention(nn.Module):
    """Self-attention sub-layer of a QWen block.

    A fused ``QKVParallelLinear`` (``c_attn``, with bias) produces the
    query/key/value tensor, rotary embeddings from ``get_rope`` are applied
    to query and key, the ``Attention`` layer is run against the KV cache,
    and a bias-free ``RowParallelLinear`` (``c_proj``) produces the output.

    Args:
        hidden_size: Model hidden width.
        num_heads: Total number of attention heads across all
            tensor-parallel ranks; must be divisible by the tensor-parallel
            world size.
        max_position_embeddings: Maximum position passed to ``get_rope``.
        rope_theta: Rotary base passed to ``get_rope``.
        rope_scaling: Optional rotary scaling config passed to ``get_rope``.
        cache_config: Optional cache config forwarded to ``Attention``.
        quant_config: Optional quantization config forwarded to the linear
            layers and to ``Attention``.
    """

    def __init__(
        self,
        hidden_size: int,
        num_heads: int,
        max_position_embeddings: int,
        rope_theta: float = 10000,
        rope_scaling: Optional[Dict[str, Any]] = None,
        cache_config: Optional[CacheConfig] = None,
        quant_config: Optional[QuantizationConfig] = None,
    ):
        super().__init__()
        self.hidden_size = hidden_size
        tensor_model_parallel_world_size = get_tensor_model_parallel_world_size(
        )
        self.total_num_heads = num_heads
        assert self.total_num_heads % tensor_model_parallel_world_size == 0
        # Heads handled by this tensor-parallel rank.
        self.num_heads = (self.total_num_heads //
                          tensor_model_parallel_world_size)
        self.head_dim = hidden_size // self.total_num_heads
        self.c_attn = QKVParallelLinear(
            hidden_size,
            self.head_dim,
            self.total_num_heads,
            bias=True,
            quant_config=quant_config,
        )
        self.c_proj = RowParallelLinear(
            self.total_num_heads * self.head_dim,
            hidden_size,
            bias=False,
            quant_config=quant_config,
        )
        self.scaling = self.head_dim**-0.5

        self.rotary_emb = get_rope(
            self.head_dim,
            rotary_dim=self.head_dim,
            max_position=max_position_embeddings,
            base=rope_theta,
            rope_scaling=rope_scaling,
        )
        self.attn = Attention(self.num_heads,
                              self.head_dim,
                              self.scaling,
                              cache_config=cache_config,
                              quant_config=quant_config)
        # Read the FA_PAD switch once at construction instead of on every
        # forward pass; QWenLMHeadModel caches the same flag in its __init__.
        self.use_fa_pad = os.environ.get('FA_PAD') == '1'

    def forward(
        self,
        positions: torch.Tensor,
        hidden_states: torch.Tensor,
        kv_cache: torch.Tensor,
        attn_metadata: AttentionMetadata,
    ) -> torch.Tensor:
        """Run attention over ``hidden_states`` and return the projection.

        Args:
            positions: Positions passed to the rotary embedding.
            hidden_states: Input to the fused QKV projection.
            kv_cache: KV cache passed to the ``Attention`` layer.
            attn_metadata: Attention metadata passed to the ``Attention``
                layer.

        Returns:
            The first element of the ``c_proj`` output.
        """
        qkv, _ = self.c_attn(hidden_states)
        if self.use_fa_pad:
            # Drop the trailing 32 entries of the last dimension.
            # NOTE(review): presumably this removes the padding that
            # load_weights adds to c_attn when FA_PAD is set; that padding
            # is only applied when LLAMA_NN is also set - confirm the two
            # switches are always enabled together.
            qkv = qkv[..., :-32]
        q, k, v = qkv.chunk(chunks=3, dim=-1)
        q, k = self.rotary_emb(positions, q, k)
        attn_output = self.attn(q, k, v, kv_cache, attn_metadata)
        output, _ = self.c_proj(attn_output)
        return output


class QWenBlock(nn.Module):
    """One QWen transformer layer: RMSNorm, attention, RMSNorm, MLP.

    Args:
        config: Model config; ``hidden_size``, ``layer_norm_epsilon``,
            ``num_attention_heads``, ``max_position_embeddings`` and
            ``intermediate_size`` are read, plus the optional
            ``rope_theta`` (default 10000) and ``rope_scaling``
            (default ``None``).
        cache_config: Optional cache config forwarded to the attention
            layer.
        quant_config: Optional quantization config forwarded to the
            attention layer and the MLP.
    """

    def __init__(
        self,
        config: PretrainedConfig,
        cache_config: Optional[CacheConfig] = None,
        quant_config: Optional[QuantizationConfig] = None,
    ):
        super().__init__()
        # Sub-modules are created in the order ln_1, attn, ln_2, mlp.
        self.ln_1 = RMSNorm(config.hidden_size, eps=config.layer_norm_epsilon)

        self.attn = QWenAttention(
            config.hidden_size,
            config.num_attention_heads,
            config.max_position_embeddings,
            rope_theta=getattr(config, "rope_theta", 10000),
            rope_scaling=getattr(config, "rope_scaling", None),
            cache_config=cache_config,
            quant_config=quant_config,
        )

        self.ln_2 = RMSNorm(config.hidden_size, eps=config.layer_norm_epsilon)

        # The MLP is built with half of the configured intermediate size.
        mlp_width = config.intermediate_size // 2
        self.mlp = QWenMLP(config.hidden_size,
                           mlp_width,
                           quant_config=quant_config)

    def forward(
        self,
        positions: torch.Tensor,
        hidden_states: torch.Tensor,
        kv_cache: torch.Tensor,
        attn_metadata: AttentionMetadata,
        residual: Optional[torch.Tensor],
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """Run the layer and return ``(hidden_states, residual)``.

        When ``residual`` is ``None`` the incoming ``hidden_states`` become
        the residual and ``ln_1`` is called with one argument; otherwise
        ``ln_1`` is called with both and returns the updated pair.
        """
        # Self Attention
        if residual is not None:
            hidden_states, residual = self.ln_1(hidden_states, residual)
        else:
            residual = hidden_states
            hidden_states = self.ln_1(hidden_states)
        attn_out = self.attn(
            positions=positions,
            hidden_states=hidden_states,
            kv_cache=kv_cache,
            attn_metadata=attn_metadata,
        )

        # Fully Connected
        normed, residual = self.ln_2(attn_out, residual)
        mlp_out = self.mlp(normed)
        return mlp_out, residual


class QWenModel(nn.Module):
    """Token embedding, a stack of ``QWenBlock`` layers and a final RMSNorm.

    Args:
        config: Model config; ``vocab_size``, ``hidden_size``,
            ``num_hidden_layers`` and ``layer_norm_epsilon`` are read here.
        cache_config: Optional cache config forwarded to every block.
        quant_config: Optional quantization config forwarded to every block.
    """

    def __init__(
        self,
        config: PretrainedConfig,
        cache_config: Optional[CacheConfig] = None,
        quant_config: Optional[QuantizationConfig] = None,
    ):
        super().__init__()
        self.config = config
        self.vocab_size = config.vocab_size

        self.wte = VocabParallelEmbedding(
            config.vocab_size,
            config.hidden_size,
        )
        blocks = [
            QWenBlock(config, cache_config, quant_config)
            for _ in range(config.num_hidden_layers)
        ]
        self.h = nn.ModuleList(blocks)
        self.ln_f = RMSNorm(config.hidden_size, eps=config.layer_norm_epsilon)

    def forward(
        self,
        input_ids: torch.Tensor,
        positions: torch.Tensor,
        kv_caches: List[torch.Tensor],
        attn_metadata: AttentionMetadata,
    ) -> torch.Tensor:
        """Embed ``input_ids``, run every block and apply the final norm.

        ``kv_caches`` is indexed by layer position, so it needs at least
        one entry per block.
        """
        hidden_states = self.wte(input_ids)
        # The first block receives no residual and creates it itself.
        residual = None
        for layer_idx, layer in enumerate(self.h):
            hidden_states, residual = layer(
                positions,
                hidden_states,
                kv_caches[layer_idx],
                attn_metadata,
                residual,
            )
        normed, _ = self.ln_f(hidden_states, residual)
        return normed


class QWenLMHeadModel(nn.Module):
    """QWen transformer with an LM head, logits processor and sampler.

    Three environment switches are read once at construction and stored as
    booleans; each is enabled only by the exact value ``'1'``:

    * ``LLAMA_NN`` -> ``use_llama_nn``: run the post-load weight layout
      conversion in ``load_weights``.
    * ``GEMM_PAD`` -> ``use_gemm_pad``: pad converted weights for which
      ``gemm_bank_conf`` returns true.
    * ``FA_PAD`` -> ``use_fa_pad``: pad the ``attn.c_attn`` weight and bias.

    Args:
        config: Model config; ``vocab_size`` and ``hidden_size`` are read
            here, the rest is consumed by ``QWenModel``.
        cache_config: Optional cache config forwarded to ``QWenModel``.
        quant_config: Optional quantization config forwarded to
            ``QWenModel`` and ``ParallelLMHead``.
    """

    def __init__(
        self,
        config: PretrainedConfig,
        cache_config: Optional[CacheConfig] = None,
        quant_config: Optional[QuantizationConfig] = None,
    ):
        super().__init__()
        self.config = config
        self.quant_config = quant_config
        self.transformer = QWenModel(config, cache_config, quant_config)
        self.lm_head = ParallelLMHead(config.vocab_size,
                                      config.hidden_size,
                                      quant_config=quant_config)
        self.logits_processor = LogitsProcessor(config.vocab_size)
        self.sampler = Sampler()
        self.use_llama_nn = os.environ.get('LLAMA_NN') == '1'
        self.use_gemm_pad = os.environ.get('GEMM_PAD') == '1'
        self.use_fa_pad = os.environ.get('FA_PAD') == '1'

    def forward(
        self,
        input_ids: torch.Tensor,
        positions: torch.Tensor,
        kv_caches: List[torch.Tensor],
        attn_metadata: AttentionMetadata,
        intermediate_tensors: Optional[IntermediateTensors] = None,
    ) -> torch.Tensor:
        """Return the transformer's hidden states.

        ``intermediate_tensors`` is accepted but not used by this method.
        """
        hidden_states = self.transformer(input_ids, positions, kv_caches,
                                         attn_metadata)
        return hidden_states

    def compute_logits(self, hidden_states: torch.Tensor,
                       sampling_metadata: SamplingMetadata) -> torch.Tensor:
        """Run the logits processor with the LM head on ``hidden_states``."""
        logits = self.logits_processor(self.lm_head, hidden_states,
                                       sampling_metadata)
        return logits

    def sample(
        self,
        logits: torch.Tensor,
        sampling_metadata: SamplingMetadata,
    ) -> Optional[SamplerOutput]:
        """Run the sampler on ``logits`` and return its output."""
        next_tokens = self.sampler(logits, sampling_metadata)
        return next_tokens

    def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
        """Load checkpoint tensors into this model's parameters.

        Checkpoint names containing ``w2`` / ``w1`` are loaded into shards
        0 / 1 of the corresponding ``gate_up_proj`` parameter. Rotary
        ``inv_freq`` entries, ``.bias`` entries with no matching parameter
        and unknown ``transformer.visual.*`` entries are skipped. When
        ``use_llama_nn`` is set, the loaded weights are converted
        afterwards by ``_apply_llama_nn_layout``.

        Args:
            weights: Iterable of ``(name, tensor)`` pairs.

        Raises:
            KeyError: If a name that is not skipped has no matching
                parameter.
        """
        stacked_params_mapping = [
            # (param_name, shard_name, shard_id)
            ("gate_up_proj", "w2", 0),
            ("gate_up_proj", "w1", 1),
        ]
        params_dict = dict(self.named_parameters())
        for name, loaded_weight in weights:
            if "rotary_emb.inv_freq" in name:
                continue
            for (param_name, weight_name, shard_id) in stacked_params_mapping:
                if weight_name not in name:
                    continue
                name = name.replace(weight_name, param_name)
                # Skip loading extra bias for GPTQ models.
                if name.endswith(".bias") and name not in params_dict:
                    continue
                param = params_dict[name]
                weight_loader = param.weight_loader
                weight_loader(param, loaded_weight, shard_id)
                break
            else:
                # Reached only when no stacked mapping loaded the tensor.
                # Skip loading extra bias for GPTQ models.
                if name.endswith(".bias") and name not in params_dict:
                    continue
                # Skip loading visual weights to support Qwen-VL models
                # in cases with text-only inputs
                # TODO: add support for Qwen-VL
                if (name not in params_dict
                        and name.startswith("transformer.visual.")):
                    print_warning_once(
                        "Only text inputs are allowed. Images won't be handled "
                        "until Qwen-VL models are fully supported.")
                    continue
                param = params_dict[name]
                weight_loader = getattr(param, "weight_loader",
                                        default_weight_loader)
                weight_loader(param, loaded_weight)

        if self.use_llama_nn:
            self._apply_llama_nn_layout(params_dict)

    def _apply_llama_nn_layout(self, params_dict):
        """Pad and re-lay-out the linear weights after loading.

        Parameter names are matched as literal substrings. For every
        parameter in ``params_dict``:

        1. With ``use_fa_pad``, an ``attn.c_attn.bias`` parameter is padded
           via ``pad_weight(..., 32)``.
        2. If the name contains one of the layout keys, the weight is
           optionally padded (see inline comments), passed through
           ``ops.trans_w16_gemm`` and reshaped to ``(ori_shape[1], -1)``.

        The order of the padding checks matters because each one looks at
        the current ``weight.data.shape[0]``.
        """
        layout_keys = (
            "attn.c_attn.weight",
            "attn.c_proj.weight",
            "mlp.gate_up_proj.weight",
            "mlp.c_proj.weight",
            "lm_head.weight",
        )
        qkv_weight_key = "attn.c_attn.weight"
        qkv_bias_key = "attn.c_attn.bias"

        for layername, weight in params_dict.items():
            if self.use_fa_pad and qkv_bias_key in layername:
                weight.data = pad_weight(weight.data, 32)

            if not any(key in layername for key in layout_keys):
                continue

            if self.use_gemm_pad and gemm_bank_conf(weight.data.shape[0]):
                weight.data = pad_weight(weight.data, 32)

            # The QKV weight is padded for FA_PAD only when gemm_bank_conf
            # is false for its current first dimension.
            if (self.use_fa_pad and qkv_weight_key in layername
                    and not gemm_bank_conf(weight.data.shape[0])):
                weight.data = pad_weight(weight.data, 32)

            _weight = torch.zeros_like(weight.data)
            ori_shape = _weight.shape

            # NOTE(review): presumably trans_w16_gemm writes a re-laid-out
            # copy of weight.data into _weight - confirm against the op.
            ops.trans_w16_gemm(_weight, weight.data, _weight.shape[0],
                               _weight.shape[1])
            weight.data.copy_(_weight)

            # Swap the leading dimension to the original second dimension.
            weight.data = weight.data.reshape(ori_shape[1], -1)