qwen.py 12.2 KB
Newer Older
Qing's avatar
Qing committed
1
2
3
4
5
# coding=utf-8
# Adapted from
# https://huggingface.co/Qwen/Qwen-7B/blob/main/modeling_qwen.py
# Copyright (c) Alibaba Cloud.
# LICENSE: https://huggingface.co/Qwen/Qwen-7B/blob/main/LICENSE
Woosuk Kwon's avatar
Woosuk Kwon committed
6
"""Inference-only QWen model compatible with HuggingFace weights."""
7
from typing import Any, Dict, Iterable, List, Optional, Tuple
Qing's avatar
Qing committed
8

9
10
import torch
from torch import nn
11
from transformers import PretrainedConfig
Qing's avatar
Qing committed
12

gaoqiong's avatar
gaoqiong committed
13
14
15
import os
import re

16
from vllm.attention import Attention, AttentionMetadata
17
from vllm.config import CacheConfig
18
from vllm.distributed import get_tensor_model_parallel_world_size
19
from vllm.model_executor.layers.activation import SiluAndMul
20
from vllm.model_executor.layers.layernorm import RMSNorm
21
from vllm.model_executor.layers.linear import (MergedColumnParallelLinear,
22
23
                                               QKVParallelLinear,
                                               RowParallelLinear)
24
from vllm.model_executor.layers.logits_processor import LogitsProcessor
25
26
from vllm.model_executor.layers.quantization.base_config import (
    QuantizationConfig)
27
from vllm.model_executor.layers.rotary_embedding import get_rope
28
29
from vllm.model_executor.layers.sampler import Sampler
from vllm.model_executor.layers.vocab_parallel_embedding import (
30
    ParallelLMHead, VocabParallelEmbedding)
31
from vllm.model_executor.model_loader.weight_utils import default_weight_loader
32
from vllm.model_executor.sampling_metadata import SamplingMetadata
33
from vllm.sequence import IntermediateTensors, SamplerOutput
34
from vllm.utils import print_warning_once
Qing's avatar
Qing committed
35

gaoqiong's avatar
gaoqiong committed
36
from vllm import _custom_ops as ops
37
38
39
40
41
42
43
class QWenMLP(nn.Module):
    """Gated feed-forward block of a QWen decoder layer.

    The two input projections are fused into one
    ``MergedColumnParallelLinear`` with two output shards of
    ``intermediate_size`` each. The fused output is passed through
    ``SiluAndMul``, and the row-parallel ``c_proj`` projects the result back
    to ``hidden_size``.
    """

    def __init__(
        self,
        hidden_size: int,
        intermediate_size: int,
        hidden_act: str = "silu",
        quant_config: Optional[QuantizationConfig] = None,
    ):
        """Create the fused gate/up projection, activation and down projection.

        Args:
            hidden_size: Model hidden size (input and output width).
            intermediate_size: Width of each of the two fused projection shards.
            hidden_act: Activation name; only ``"silu"`` is supported.
            quant_config: Optional quantization config for both linears.

        Raises:
            ValueError: If ``hidden_act`` is not ``"silu"``.
        """
        super().__init__()
        # Reject unsupported activations before allocating the projection
        # weights, so a bad config fails fast.
        if hidden_act != "silu":
            raise ValueError(f"Unsupported activation: {hidden_act}. "
                             "Only silu is supported for now.")
        self.gate_up_proj = MergedColumnParallelLinear(
            hidden_size, [intermediate_size] * 2,
            bias=False,
            quant_config=quant_config)
        self.c_proj = RowParallelLinear(intermediate_size,
                                        hidden_size,
                                        bias=False,
                                        quant_config=quant_config)
        self.act_fn = SiluAndMul()

    def forward(self, x):
        """Apply fused gate/up projection, SiluAndMul, then ``c_proj``."""
        # The parallel linear layers return a pair; only the first element
        # (the output tensor) is used.
        gate_up, _ = self.gate_up_proj(x)
        x = self.act_fn(gate_up)
        x, _ = self.c_proj(x)
        return x


class QWenAttention(nn.Module):
    """Multi-head self-attention for QWen with a fused QKV projection.

    Heads are split evenly across tensor-parallel ranks. Each rank holds
    ``num_heads // world_size`` heads of size ``hidden_size // num_heads``.
    Rotary embeddings are applied to queries and keys before vLLM's
    ``Attention`` layer.
    """

    def __init__(
        self,
        hidden_size: int,
        num_heads: int,
        max_position_embeddings: int,
        rope_theta: float = 10000,
        rope_scaling: Optional[Dict[str, Any]] = None,
        cache_config: Optional[CacheConfig] = None,
        quant_config: Optional[QuantizationConfig] = None,
    ):
        """Build the QKV/output projections, rotary embedding and attention.

        Args:
            hidden_size: Model hidden size.
            num_heads: Total number of attention heads across all ranks.
            max_position_embeddings: Maximum position passed to ``get_rope``.
            rope_theta: Rotary embedding base.
            rope_scaling: Optional rotary scaling configuration.
            cache_config: KV-cache configuration forwarded to ``Attention``.
            quant_config: Optional quantization config for the linears and
                ``Attention``.
        """
        super().__init__()
        self.hidden_size = hidden_size
        tp_size = get_tensor_model_parallel_world_size()
        self.total_num_heads = num_heads
        assert self.total_num_heads % tp_size == 0, (
            f"num_heads ({num_heads}) must be divisible by the tensor "
            f"parallel world size ({tp_size})")
        self.num_heads = self.total_num_heads // tp_size
        self.head_dim = hidden_size // self.total_num_heads
        # One fused projection produces Q, K and V.
        self.c_attn = QKVParallelLinear(
            hidden_size,
            self.head_dim,
            self.total_num_heads,
            bias=True,
            quant_config=quant_config,
        )
        self.c_proj = RowParallelLinear(
            self.total_num_heads * self.head_dim,
            hidden_size,
            bias=False,
            quant_config=quant_config,
        )
        self.scaling = self.head_dim**-0.5

        self.rotary_emb = get_rope(
            self.head_dim,
            rotary_dim=self.head_dim,
            max_position=max_position_embeddings,
            base=rope_theta,
            rope_scaling=rope_scaling,
        )
        self.attn = Attention(self.num_heads,
                              self.head_dim,
                              self.scaling,
                              cache_config=cache_config,
                              quant_config=quant_config)

    def forward(
        self,
        positions: torch.Tensor,
        hidden_states: torch.Tensor,
        kv_cache: torch.Tensor,
        attn_metadata: AttentionMetadata,
    ) -> torch.Tensor:
        """Compute self-attention for ``hidden_states`` at ``positions``."""
        qkv, _ = self.c_attn(hidden_states)
        # Split the fused projection into three equal parts: Q, K, V.
        q, k, v = qkv.chunk(chunks=3, dim=-1)
        q, k = self.rotary_emb(positions, q, k)
        attn_output = self.attn(q, k, v, kv_cache, attn_metadata)
        output, _ = self.c_proj(attn_output)
        return output


class QWenBlock(nn.Module):
    """One QWen decoder layer: normed attention followed by a normed MLP.

    The residual stream is threaded through the ``RMSNorm`` calls. When a
    residual tensor is given, ``norm(x, residual)`` returns a
    ``(hidden_states, residual)`` pair. This layer never adds the residual
    itself (presumably the fused norm does).
    """

    def __init__(
        self,
        config: PretrainedConfig,
        cache_config: Optional[CacheConfig] = None,
        quant_config: Optional[QuantizationConfig] = None,
    ):
        """Build norms, attention and MLP from a HuggingFace QWen config."""
        super().__init__()
        self.ln_1 = RMSNorm(config.hidden_size, eps=config.layer_norm_epsilon)

        # Optional config fields; fall back to the defaults used here.
        rope_theta = getattr(config, "rope_theta", 10000)
        rope_scaling = getattr(config, "rope_scaling", None)
        self.attn = QWenAttention(config.hidden_size,
                                  config.num_attention_heads,
                                  config.max_position_embeddings,
                                  rope_theta=rope_theta,
                                  rope_scaling=rope_scaling,
                                  cache_config=cache_config,
                                  quant_config=quant_config)

        self.ln_2 = RMSNorm(config.hidden_size, eps=config.layer_norm_epsilon)

        # NOTE(review): QWen configs appear to store the combined width of
        # both MLP projections in ``intermediate_size``, hence the halving —
        # confirm against the checkpoint's config.json.
        self.mlp = QWenMLP(config.hidden_size,
                           config.intermediate_size // 2,
                           quant_config=quant_config)

    def forward(
        self,
        positions: torch.Tensor,
        hidden_states: torch.Tensor,
        kv_cache: torch.Tensor,
        attn_metadata: AttentionMetadata,
        residual: Optional[torch.Tensor],
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """Run one layer.

        Args:
            residual: ``None`` for the first layer; otherwise the residual
                returned by the previous layer.

        Returns:
            ``(hidden_states, residual)`` for the next layer.
        """
        # Self Attention
        if residual is None:
            # First layer: the un-normalized input becomes the residual.
            residual = hidden_states
            hidden_states = self.ln_1(hidden_states)
        else:
            hidden_states, residual = self.ln_1(hidden_states, residual)
        hidden_states = self.attn(
            positions=positions,
            hidden_states=hidden_states,
            kv_cache=kv_cache,
            attn_metadata=attn_metadata,
        )

        # Fully Connected
        hidden_states, residual = self.ln_2(hidden_states, residual)
        hidden_states = self.mlp(hidden_states)
        return hidden_states, residual


class QWenModel(nn.Module):
    """QWen decoder stack: embedding, ``num_hidden_layers`` blocks, final norm."""

    def __init__(
        self,
        config: PretrainedConfig,
        cache_config: Optional[CacheConfig] = None,
        quant_config: Optional[QuantizationConfig] = None,
    ):
        """Build the embedding, the decoder blocks and the final RMSNorm."""
        super().__init__()
        self.config = config
        self.vocab_size = config.vocab_size

        self.wte = VocabParallelEmbedding(
            config.vocab_size,
            config.hidden_size,
        )
        self.h = nn.ModuleList([
            QWenBlock(config, cache_config, quant_config)
            for _ in range(config.num_hidden_layers)
        ])
        self.ln_f = RMSNorm(config.hidden_size, eps=config.layer_norm_epsilon)
        # True when the LLAMA_NN environment variable is exactly "1". Not
        # read inside this class; kept in case other code inspects it.
        self.use_llama_nn = os.environ.get('LLAMA_NN') == '1'

    def forward(
        self,
        input_ids: torch.Tensor,
        positions: torch.Tensor,
        kv_caches: List[torch.Tensor],
        attn_metadata: AttentionMetadata,
    ) -> torch.Tensor:
        """Embed ``input_ids``, run every block and apply the final norm.

        ``kv_caches[i]`` is passed to block ``i``, so it needs at least one
        entry per block.
        """
        hidden_states = self.wte(input_ids)
        residual = None
        for i, layer in enumerate(self.h):
            hidden_states, residual = layer(
                positions,
                hidden_states,
                kv_caches[i],
                attn_metadata,
                residual,
            )
        hidden_states, _ = self.ln_f(hidden_states, residual)
        return hidden_states


class QWenLMHeadModel(nn.Module):

    def __init__(
        self,
        config: PretrainedConfig,
        cache_config: Optional[CacheConfig] = None,
        quant_config: Optional[QuantizationConfig] = None,
    ):
        """Build the QWen backbone, LM head, logits processor and sampler.

        Args:
            config: HuggingFace config of the checkpoint.
            cache_config: KV-cache configuration forwarded to the backbone.
            quant_config: Optional quantization config forwarded to the
                backbone and the LM head.
        """
        super().__init__()
        self.config = config
        self.quant_config = quant_config
        self.transformer = QWenModel(config, cache_config, quant_config)
        self.lm_head = ParallelLMHead(config.vocab_size,
                                      config.hidden_size,
                                      quant_config=quant_config)
        self.logits_processor = LogitsProcessor(config.vocab_size)
        self.sampler = Sampler()
        # When LLAMA_NN == "1", load_weights re-lays out the attention/MLP
        # projection weights after loading (via ops.trans_w16_gemm).
        self.use_llama_nn = os.environ.get('LLAMA_NN') == '1'

    def forward(
        self,
        input_ids: torch.Tensor,
        positions: torch.Tensor,
        kv_caches: List[torch.Tensor],
        attn_metadata: AttentionMetadata,
        intermediate_tensors: Optional[IntermediateTensors] = None,
    ) -> torch.Tensor:
        """Run the backbone and return the final hidden states.

        ``intermediate_tensors`` is accepted for interface compatibility but
        is ignored.
        """
        hidden_states = self.transformer(input_ids, positions, kv_caches,
                                         attn_metadata)
        return hidden_states

    def compute_logits(self, hidden_states: torch.Tensor,
                       sampling_metadata: SamplingMetadata) -> torch.Tensor:
        """Project ``hidden_states`` to vocabulary logits via ``lm_head``."""
        logits = self.logits_processor(self.lm_head, hidden_states,
                                       sampling_metadata)
        return logits

    def sample(
        self,
        logits: torch.Tensor,
        sampling_metadata: SamplingMetadata,
    ) -> Optional[SamplerOutput]:
        """Sample next tokens from ``logits`` using the model's ``Sampler``."""
        next_tokens = self.sampler(logits, sampling_metadata)
        return next_tokens

    def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
276
277
278
279
280
281
        stacked_params_mapping = [
            # (param_name, shard_name, shard_id)
            ("gate_up_proj", "w2", 0),
            ("gate_up_proj", "w1", 1),
        ]
        params_dict = dict(self.named_parameters())
282
        for name, loaded_weight in weights:
Qing's avatar
Qing committed
283
284
            if "rotary_emb.inv_freq" in name:
                continue
285
            for (param_name, weight_name, shard_id) in stacked_params_mapping:
Qing's avatar
Qing committed
286
287
                if weight_name not in name:
                    continue
CHU Tianxiang's avatar
CHU Tianxiang committed
288
289
290
291
292
                name = name.replace(weight_name, param_name)
                # Skip loading extra bias for GPTQ models.
                if name.endswith(".bias") and name not in params_dict:
                    continue
                param = params_dict[name]
293
294
                weight_loader = param.weight_loader
                weight_loader(param, loaded_weight, shard_id)
Qing's avatar
Qing committed
295
                break
296
            else:
CHU Tianxiang's avatar
CHU Tianxiang committed
297
298
299
                # Skip loading extra bias for GPTQ models.
                if name.endswith(".bias") and name not in params_dict:
                    continue
300
301
302
303
304
305
306
307
308
                # Skip loading visual weights to support Qwen-VL models
                # in cases with text-only inputs
                # TODO: add support for Qwen-VL
                if (name not in params_dict
                        and name.startswith("transformer.visual.")):
                    print_warning_once(
                        "Only text inputs are allowed. Images won't be handled "
                        "until Qwen-VL models are fully supported.")
                    continue
309
310
311
312
                param = params_dict[name]
                weight_loader = getattr(param, "weight_loader",
                                        default_weight_loader)
                weight_loader(param, loaded_weight)
gaoqiong's avatar
gaoqiong committed
313
314
        if self.use_llama_nn:
            lay_key_words = [
zhuwenwen's avatar
zhuwenwen committed
315
316
                "attn.c_attn.weight",
                "attn.c_proj.weight",
gaoqiong's avatar
gaoqiong committed
317
                "mlp.gate_up_proj.weight",
zhuwenwen's avatar
zhuwenwen committed
318
                "mlp.c_proj.weight"
gaoqiong's avatar
gaoqiong committed
319
320
321
322
323
            ]
            combined_words = "|".join(lay_key_words)
            
            for layername, weight in params_dict.items():
                matches = re.findall(combined_words, layername)
zhuwenwen's avatar
zhuwenwen committed
324
                if matches:         
gaoqiong's avatar
gaoqiong committed
325
326
327
                    _weight = torch.zeros_like(weight.data)
                    ori_shape =_weight.shape
                    
zhuwenwen's avatar
zhuwenwen committed
328
                    ops.trans_w16_gemm(_weight, weight.data, _weight.shape[0], _weight.shape[1])
gaoqiong's avatar
gaoqiong committed
329
330
331
332
333
                    weight.data.copy_(_weight)
                    
                    weight.data=weight.data.reshape(ori_shape[1],-1)