"tests/models/quantization/test_bitsandbytes.py" did not exist on "87525fab925edf549611a1a74a40699b0b5e316e"
qwen.py 11.5 KB
Newer Older
Qing's avatar
Qing committed
1
2
3
4
5
# coding=utf-8
# Adapted from
# https://huggingface.co/Qwen/Qwen-7B/blob/main/modeling_qwen.py
# Copyright (c) Alibaba Cloud.
# LICENSE: https://huggingface.co/Qwen/Qwen-7B/blob/main/LICENSE
Woosuk Kwon's avatar
Woosuk Kwon committed
6
"""Inference-only QWen model compatible with HuggingFace weights."""
7
from typing import Any, Dict, Iterable, List, Optional, Tuple
Qing's avatar
Qing committed
8

9
10
import torch
from torch import nn
11
from transformers import PretrainedConfig
Qing's avatar
Qing committed
12

gaoqiong's avatar
gaoqiong committed
13
14
15
import os
import re

16
from vllm.attention import Attention, AttentionMetadata
17
from vllm.config import CacheConfig
18
from vllm.distributed import get_tensor_model_parallel_world_size
19
from vllm.model_executor.layers.activation import SiluAndMul
20
from vllm.model_executor.layers.layernorm import RMSNorm
21
from vllm.model_executor.layers.linear import (MergedColumnParallelLinear,
22
23
                                               QKVParallelLinear,
                                               RowParallelLinear)
24
from vllm.model_executor.layers.logits_processor import LogitsProcessor
25
26
from vllm.model_executor.layers.quantization.base_config import (
    QuantizationConfig)
27
from vllm.model_executor.layers.rotary_embedding import get_rope
28
29
from vllm.model_executor.layers.sampler import Sampler
from vllm.model_executor.layers.vocab_parallel_embedding import (
30
    ParallelLMHead, VocabParallelEmbedding)
31
from vllm.model_executor.model_loader.weight_utils import default_weight_loader
32
33
from vllm.model_executor.sampling_metadata import SamplingMetadata
from vllm.sequence import SamplerOutput
Qing's avatar
Qing committed
34

gaoqiong's avatar
gaoqiong committed
35
from vllm import _custom_ops as ops
36
37
38
39
40
41
42
class QWenMLP(nn.Module):
    """Gated feed-forward network used inside each QWen decoder block.

    The two input projections are fused into one
    ``MergedColumnParallelLinear`` (``gate_up_proj``, output widths
    ``[intermediate_size] * 2``). Its output goes through ``SiluAndMul``
    (presumably SiLU on one half times the other half), and ``c_proj`` maps
    the result back to ``hidden_size``.

    Args:
        hidden_size: Width of the incoming/outgoing hidden states.
        intermediate_size: Output width of each of the two fused projections.
        hidden_act: Activation name; only ``"silu"`` is supported.
        quant_config: Optional quantization config forwarded to both linear
            layers.

    Raises:
        ValueError: If ``hidden_act`` is anything other than ``"silu"``.
    """

    def __init__(
        self,
        hidden_size: int,
        intermediate_size: int,
        hidden_act: str = "silu",
        quant_config: Optional[QuantizationConfig] = None,
    ):
        super().__init__()
        # Validate before allocating the projection weights so an
        # unsupported config fails fast instead of after the allocation.
        if hidden_act != "silu":
            raise ValueError(f"Unsupported activation: {hidden_act}. "
                             "Only silu is supported for now.")
        self.gate_up_proj = MergedColumnParallelLinear(
            hidden_size, [intermediate_size] * 2,
            bias=False,
            quant_config=quant_config)
        self.c_proj = RowParallelLinear(intermediate_size,
                                        hidden_size,
                                        bias=False,
                                        quant_config=quant_config)
        self.act_fn = SiluAndMul()

    def forward(self, x):
        """Apply fused gate/up projection, activation, then down projection."""
        # The parallel linear layers return a (output, bias) pair; only the
        # output is used here.
        gate_up, _ = self.gate_up_proj(x)
        x = self.act_fn(gate_up)
        x, _ = self.c_proj(x)
        return x


class QWenAttention(nn.Module):
    """Rotary multi-head self-attention for QWen, sharded across TP ranks.

    ``c_attn`` is a fused ``QKVParallelLinear`` (with bias) that produces
    queries, keys and values in a single matmul. Rotary embeddings from
    ``get_rope`` are applied to Q and K, the ``Attention`` layer consumes the
    KV cache, and ``c_proj`` (no bias) projects back to ``hidden_size``.

    Each tensor-parallel rank owns ``num_heads // world_size`` heads, and
    ``head_dim`` is ``hidden_size // num_heads``.
    """

    def __init__(
        self,
        hidden_size: int,
        num_heads: int,
        max_position_embeddings: int,
        rope_theta: float = 10000,
        rope_scaling: Optional[Dict[str, Any]] = None,
        cache_config: Optional[CacheConfig] = None,
        quant_config: Optional[QuantizationConfig] = None,
    ):
        super().__init__()
        tp_size = get_tensor_model_parallel_world_size()

        self.hidden_size = hidden_size
        self.total_num_heads = num_heads
        # Heads are split evenly across tensor-parallel ranks.
        assert self.total_num_heads % tp_size == 0
        self.num_heads = self.total_num_heads // tp_size
        self.head_dim = hidden_size // self.total_num_heads

        # Registration order (c_attn, c_proj, rotary_emb, attn) is kept stable
        # so named_parameters() enumerates the same way.
        self.c_attn = QKVParallelLinear(
            hidden_size,
            self.head_dim,
            self.total_num_heads,
            bias=True,
            quant_config=quant_config,
        )
        self.c_proj = RowParallelLinear(
            self.total_num_heads * self.head_dim,
            hidden_size,
            bias=False,
            quant_config=quant_config,
        )
        self.scaling = self.head_dim**-0.5

        self.rotary_emb = get_rope(
            self.head_dim,
            rotary_dim=self.head_dim,
            max_position=max_position_embeddings,
            base=rope_theta,
            rope_scaling=rope_scaling,
        )
        self.attn = Attention(
            self.num_heads,
            self.head_dim,
            self.scaling,
            cache_config=cache_config,
            quant_config=quant_config,
        )

    def forward(
        self,
        positions: torch.Tensor,
        hidden_states: torch.Tensor,
        kv_cache: torch.Tensor,
        attn_metadata: AttentionMetadata,
    ) -> torch.Tensor:
        """Run rotary self-attention over ``hidden_states``.

        Returns the attention output projected back to ``hidden_size``.
        """
        fused_qkv, _ = self.c_attn(hidden_states)
        # Q, K and V share one head count here, so the fused output splits
        # into three equal slices along the last dimension.
        query, key, value = fused_qkv.chunk(3, dim=-1)
        query, key = self.rotary_emb(positions, query, key)
        context = self.attn(query, key, value, kv_cache, attn_metadata)
        projected, _ = self.c_proj(context)
        return projected


class QWenBlock(nn.Module):
    """A single QWen decoder layer in pre-norm form.

    Data flow: ``ln_1`` -> ``attn`` -> ``ln_2`` -> ``mlp``. The residual
    stream is threaded through the two ``RMSNorm`` calls rather than added
    explicitly in this block.
    """

    def __init__(
        self,
        config: PretrainedConfig,
        cache_config: Optional[CacheConfig] = None,
        quant_config: Optional[QuantizationConfig] = None,
    ):
        super().__init__()
        self.ln_1 = RMSNorm(config.hidden_size, eps=config.layer_norm_epsilon)
        # rope_theta / rope_scaling are optional config fields; fall back to
        # 10000 and no scaling when absent.
        self.attn = QWenAttention(
            config.hidden_size,
            config.num_attention_heads,
            config.max_position_embeddings,
            rope_theta=getattr(config, "rope_theta", 10000),
            rope_scaling=getattr(config, "rope_scaling", None),
            cache_config=cache_config,
            quant_config=quant_config,
        )
        self.ln_2 = RMSNorm(config.hidden_size, eps=config.layer_norm_epsilon)
        # The config's intermediate_size is halved here; presumably it counts
        # both fused gate/up projections together (confirm against the HF
        # QWen config).
        self.mlp = QWenMLP(
            config.hidden_size,
            config.intermediate_size // 2,
            quant_config=quant_config,
        )

    def forward(
        self,
        positions: torch.Tensor,
        hidden_states: torch.Tensor,
        kv_cache: torch.Tensor,
        attn_metadata: AttentionMetadata,
        residual: Optional[torch.Tensor],
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """Run one decoder layer.

        ``residual`` is ``None`` for the first layer, in which case the raw
        input becomes the residual. Returns ``(hidden_states, residual)`` for
        the next layer.
        """
        # Self attention. RMSNorm called with a residual returns a
        # (normalized, residual) pair; presumably it fuses the residual add.
        if residual is not None:
            hidden_states, residual = self.ln_1(hidden_states, residual)
        else:
            residual = hidden_states
            hidden_states = self.ln_1(hidden_states)
        attn_out = self.attn(
            positions=positions,
            hidden_states=hidden_states,
            kv_cache=kv_cache,
            attn_metadata=attn_metadata,
        )

        # Fully connected.
        normed, residual = self.ln_2(attn_out, residual)
        return self.mlp(normed), residual


class QWenModel(nn.Module):
    """QWen decoder backbone (no LM head).

    Consists of a vocab-parallel token embedding (``wte``), a stack of
    ``config.num_hidden_layers`` ``QWenBlock`` layers (``h``), and a final
    ``RMSNorm`` (``ln_f``).
    """

    def __init__(
        self,
        config: PretrainedConfig,
        cache_config: Optional[CacheConfig] = None,
        quant_config: Optional[QuantizationConfig] = None,
    ):
        super().__init__()
        self.config = config
        self.vocab_size = config.vocab_size

        self.wte = VocabParallelEmbedding(
            config.vocab_size,
            config.hidden_size,
        )
        self.h = nn.ModuleList(
            QWenBlock(config, cache_config, quant_config)
            for _ in range(config.num_hidden_layers))
        self.ln_f = RMSNorm(config.hidden_size, eps=config.layer_norm_epsilon)
        # True when the LLAMA_NN environment variable is exactly '1'.
        # NOTE(review): not read anywhere in this class.
        self.use_llama_nn = os.environ.get('LLAMA_NN') == '1'

    def forward(
        self,
        input_ids: torch.Tensor,
        positions: torch.Tensor,
        kv_caches: List[torch.Tensor],
        attn_metadata: AttentionMetadata,
    ) -> torch.Tensor:
        """Embed ``input_ids`` and run every decoder layer.

        ``kv_caches`` is indexed by layer position, one entry per block.
        Returns the final normalized hidden states.
        """
        hidden_states = self.wte(input_ids)
        # The first block sees no residual; each block hands its residual to
        # the next, and the final norm consumes the last one.
        residual = None
        for layer_idx, block in enumerate(self.h):
            hidden_states, residual = block(
                positions,
                hidden_states,
                kv_caches[layer_idx],
                attn_metadata,
                residual,
            )
        normed, _ = self.ln_f(hidden_states, residual)
        return normed


class QWenLMHeadModel(nn.Module):
    """QWen causal LM: ``QWenModel`` backbone plus a parallel LM head.

    ``forward`` returns final hidden states, ``compute_logits`` projects them
    with ``lm_head.weight`` through the ``LogitsProcessor``, and ``sample``
    hands the logits to the ``Sampler``.

    When the ``LLAMA_NN`` environment variable equals ``'1'``,
    ``load_weights`` additionally re-lays-out the dense projection weights
    with the custom ``ops.trans_w16_gemm`` op after loading.
    """

    # Substrings of parameter names whose weights are re-laid-out when
    # LLAMA_NN=1. They are matched literally (escaped) when the regex is built.
    _LLAMA_NN_WEIGHT_KEYS = (
        "attn.c_attn.weight",
        "attn.c_proj.weight",
        "mlp.gate_up_proj.weight",
        "mlp.c_proj.weight",
    )

    def __init__(
        self,
        config: PretrainedConfig,
        cache_config: Optional[CacheConfig] = None,
        quant_config: Optional[QuantizationConfig] = None,
    ):
        super().__init__()
        self.config = config
        self.quant_config = quant_config
        self.transformer = QWenModel(config, cache_config, quant_config)
        self.lm_head = ParallelLMHead(config.vocab_size, config.hidden_size)
        self.logits_processor = LogitsProcessor(config.vocab_size)
        self.sampler = Sampler()
        # True when the LLAMA_NN environment variable is exactly '1';
        # enables the post-load weight re-layout in load_weights().
        self.use_llama_nn = os.environ.get('LLAMA_NN') == '1'

    def forward(
        self,
        input_ids: torch.Tensor,
        positions: torch.Tensor,
        kv_caches: List[torch.Tensor],
        attn_metadata: AttentionMetadata,
    ) -> torch.Tensor:
        """Return the backbone's final hidden states (logits are separate)."""
        hidden_states = self.transformer(input_ids, positions, kv_caches,
                                         attn_metadata)
        return hidden_states

    def compute_logits(self, hidden_states: torch.Tensor,
                       sampling_metadata: SamplingMetadata) -> torch.Tensor:
        """Project ``hidden_states`` to vocabulary logits via ``lm_head``."""
        logits = self.logits_processor(self.lm_head.weight, hidden_states,
                                       sampling_metadata)
        return logits

    def sample(
        self,
        logits: torch.Tensor,
        sampling_metadata: SamplingMetadata,
    ) -> Optional[SamplerOutput]:
        """Sample next tokens from ``logits`` using the model's ``Sampler``."""
        next_tokens = self.sampler(logits, sampling_metadata)
        return next_tokens

    def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
        """Copy checkpoint tensors into this model's parameters.

        Checkpoint MLP weights named with ``w2`` / ``w1`` are loaded as shards
        0 / 1 of the fused ``gate_up_proj``. Every other tensor is loaded by
        exact name, using the parameter's ``weight_loader`` when it has one
        and ``default_weight_loader`` otherwise. ``rotary_emb.inv_freq``
        entries and bias tensors with no matching parameter are skipped.

        Raises:
            KeyError: If a non-bias checkpoint name has no matching parameter.
        """
        stacked_params_mapping = [
            # (param_name, shard_name, shard_id)
            ("gate_up_proj", "w2", 0),
            ("gate_up_proj", "w1", 1),
        ]
        params_dict = dict(self.named_parameters())
        for name, loaded_weight in weights:
            if "rotary_emb.inv_freq" in name:
                continue
            for (param_name, weight_name, shard_id) in stacked_params_mapping:
                if weight_name not in name:
                    continue
                name = name.replace(weight_name, param_name)
                # Skip loading extra bias for GPTQ models.
                # NOTE: this `continue` targets the inner loop. The renamed
                # name matches no other mapping, so the for-else branch below
                # runs and skips the same missing bias.
                if name.endswith(".bias") and name not in params_dict:
                    continue
                param = params_dict[name]
                weight_loader = param.weight_loader
                weight_loader(param, loaded_weight, shard_id)
                break
            else:
                # Skip loading extra bias for GPTQ models.
                if name.endswith(".bias") and name not in params_dict:
                    continue
                param = params_dict[name]
                weight_loader = getattr(param, "weight_loader",
                                        default_weight_loader)
                weight_loader(param, loaded_weight)

        if self.use_llama_nn:
            self._apply_llama_nn_layout(params_dict)

    def _apply_llama_nn_layout(
            self, params_dict: Dict[str, nn.Parameter]) -> None:
        """Re-lay-out the LLAMA_NN projection weights in place.

        For every parameter whose name contains one of
        ``_LLAMA_NN_WEIGHT_KEYS``, the data is rewritten by
        ``ops.trans_w16_gemm`` and then reshaped with its first two
        dimensions swapped (assumes 2-D weights).
        """
        # Build the pattern once, escaping '.' so the keys match literally
        # instead of treating every dot as a wildcard.
        key_pattern = re.compile("|".join(
            re.escape(key) for key in self._LLAMA_NN_WEIGHT_KEYS))
        for layer_name, weight in params_dict.items():
            if key_pattern.search(layer_name) is None:
                continue
            ori_shape = weight.data.shape
            _weight = torch.zeros_like(weight.data)
            # Custom op (not defined in this file): presumably writes a
            # GEMM-friendly transposed copy of weight.data into _weight.
            # TODO confirm its layout contract.
            ops.trans_w16_gemm(_weight, weight.data, ori_shape[0],
                               ori_shape[1])
            weight.data.copy_(_weight)
            # Reinterpret the rewritten buffer as (ori_shape[1], -1).
            weight.data = weight.data.reshape(ori_shape[1], -1)