qwen.py 16.2 KB
Newer Older
1
# SPDX-License-Identifier: Apache-2.0
2
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
3

Qing's avatar
Qing committed
4
5
6
7
# Adapted from
# https://huggingface.co/Qwen/Qwen-7B/blob/main/modeling_qwen.py
# Copyright (c) Alibaba Cloud.
# LICENSE: https://huggingface.co/Qwen/Qwen-7B/blob/main/LICENSE
Woosuk Kwon's avatar
Woosuk Kwon committed
8
"""Inference-only QWen model compatible with HuggingFace weights."""
9
import json
10
from collections.abc import Iterable
11
from itertools import islice
12
from typing import Any, Optional, Union
13

14
15
import torch
from torch import nn
16
from transformers import PretrainedConfig
Qing's avatar
Qing committed
17

gaoqiong's avatar
gaoqiong committed
18
19
import os
import re
20
from vllm.attention import Attention
gaoqiong's avatar
gaoqiong committed
21

22
from vllm.compilation.decorators import support_torch_compile
23
from vllm.config import CacheConfig, VllmConfig
24
from vllm.distributed import get_pp_group, get_tensor_model_parallel_world_size
25
from vllm.model_executor.layers.activation import SiluAndMul
26
from vllm.model_executor.layers.layernorm import RMSNorm
27
from vllm.model_executor.layers.linear import (MergedColumnParallelLinear,
28
29
                                               QKVParallelLinear,
                                               RowParallelLinear)
30
from vllm.model_executor.layers.logits_processor import LogitsProcessor
31
from vllm.model_executor.layers.quantization import QuantizationConfig
32
from vllm.model_executor.layers.rotary_embedding import get_rope
33
from vllm.model_executor.layers.vocab_parallel_embedding import (
34
    ParallelLMHead, VocabParallelEmbedding)
35
from vllm.model_executor.model_loader.weight_utils import default_weight_loader
36
from vllm.model_executor.sampling_metadata import SamplingMetadata
37
from vllm.sequence import IntermediateTensors
Qing's avatar
Qing committed
38

39
40
from .interfaces import SupportsLoRA, SupportsPP
from .utils import (is_pp_missing_parameter,
41
                    make_empty_intermediate_tensors_factory, make_layers,
42
                    maybe_prefix)
gaoqiong's avatar
gaoqiong committed
43
from vllm import _custom_ops as ops
44
45
46
from vllm.model_executor.utils import pad_weight, gemm_bank_conf


47
class QWenMLP(nn.Module):
    """Gated feed-forward network of a Qwen decoder block.

    The gate and up projections are fused into one
    ``MergedColumnParallelLinear`` producing two ``intermediate_size`` halves,
    which ``SiluAndMul`` combines; ``c_proj`` then maps the result back to
    ``hidden_size``. Only the ``"silu"`` activation is accepted.
    """

    def __init__(
        self,
        hidden_size: int,
        intermediate_size: int,
        hidden_act: str = "silu",
        quant_config: Optional[QuantizationConfig] = None,
    ):
        super().__init__()
        fused_widths = [intermediate_size] * 2
        self.gate_up_proj = MergedColumnParallelLinear(
            hidden_size,
            fused_widths,
            bias=False,
            quant_config=quant_config,
        )
        self.c_proj = RowParallelLinear(
            intermediate_size,
            hidden_size,
            bias=False,
            quant_config=quant_config,
        )
        if hidden_act != "silu":
            raise ValueError(f"Unsupported activation: {hidden_act}. "
                             "Only silu is supported for now.")
        self.act_fn = SiluAndMul()

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Apply gate/up projection, SiLU gating, then the down projection."""
        fused, _bias = self.gate_up_proj(x)
        activated = self.act_fn(fused)
        projected, _bias = self.c_proj(activated)
        return projected


class QWenAttention(nn.Module):
    """Self-attention layer of a Qwen decoder block.

    ``c_attn`` is a fused QKV projection (with bias). Its output is split
    into three equal chunks, rotary embeddings are applied to q and k, and
    attention is computed by vLLM's ``Attention`` layer. ``c_proj`` then
    projects the result back to ``hidden_size``. Heads are divided evenly
    across tensor-parallel ranks.
    """

    def __init__(
        self,
        hidden_size: int,
        num_heads: int,
        max_position_embeddings: int,
        rope_theta: float = 10000,
        rope_scaling: Optional[dict[str, Any]] = None,
        cache_config: Optional[CacheConfig] = None,
        quant_config: Optional[QuantizationConfig] = None,
        prefix: str = "",
    ):
        super().__init__()
        self.hidden_size = hidden_size
        tensor_model_parallel_world_size = get_tensor_model_parallel_world_size(
        )
        self.total_num_heads = num_heads
        # Every tensor-parallel rank must receive the same number of heads.
        assert self.total_num_heads % tensor_model_parallel_world_size == 0
        self.num_heads = (self.total_num_heads //
                          tensor_model_parallel_world_size)
        self.head_dim = hidden_size // self.total_num_heads
        self.c_attn = QKVParallelLinear(
            hidden_size,
            self.head_dim,
            self.total_num_heads,
            bias=True,
            quant_config=quant_config,
        )
        self.c_proj = RowParallelLinear(
            self.total_num_heads * self.head_dim,
            hidden_size,
            bias=False,
            quant_config=quant_config,
        )
        self.scaling = self.head_dim**-0.5

        self.rotary_emb = get_rope(
            self.head_dim,
            rotary_dim=self.head_dim,
            max_position=max_position_embeddings,
            base=rope_theta,
            rope_scaling=rope_scaling,
        )
        self.attn = Attention(self.num_heads,
                              self.head_dim,
                              self.scaling,
                              cache_config=cache_config,
                              quant_config=quant_config,
                              prefix=f"{prefix}.attn")

        # Always define both attributes so they can be inspected on
        # unquantized models too; quant_method is None when unquantized.
        self.quant_config = quant_config
        self.quant_method = (quant_config.get_name()
                             if quant_config is not None else None)

    def forward(
        self,
        positions: torch.Tensor,
        hidden_states: torch.Tensor,
    ) -> torch.Tensor:
        """Run self-attention over ``hidden_states``.

        Args:
            positions: token positions consumed by the rotary embedding.
            hidden_states: input activations for this layer.

        Returns:
            The attention output after the ``c_proj`` projection.
        """
        qkv, _ = self.c_attn(hidden_states)
        # The fused projection is split as-is; no padding is trimmed here.
        q, k, v = qkv.chunk(chunks=3, dim=-1)
        q, k = self.rotary_emb(positions, q, k)
        attn_output = self.attn(q, k, v)
        output, _ = self.c_proj(attn_output)
        return output


class QWenBlock(nn.Module):
    """One Qwen decoder layer: pre-norm attention followed by a pre-norm MLP.

    The RMSNorm layers are called with the running residual (when present)
    and return an updated ``(hidden_states, residual)`` pair. ``forward``
    therefore returns the MLP output together with that residual.
    """

    def __init__(
        self,
        config: PretrainedConfig,
        cache_config: Optional[CacheConfig] = None,
        quant_config: Optional[QuantizationConfig] = None,
        prefix: str = "",
    ):
        super().__init__()
        hidden = config.hidden_size
        eps = config.layer_norm_epsilon

        self.ln_1 = RMSNorm(hidden, eps=eps)
        self.attn = QWenAttention(
            hidden,
            config.num_attention_heads,
            config.max_position_embeddings,
            rope_theta=getattr(config, "rope_theta", 10000),
            rope_scaling=getattr(config, "rope_scaling", None),
            cache_config=cache_config,
            quant_config=quant_config,
            prefix=f"{prefix}.attn",
        )
        self.ln_2 = RMSNorm(hidden, eps=eps)
        # config.intermediate_size is halved here; the MLP's merged gate/up
        # projection holds two halves of this width.
        self.mlp = QWenMLP(hidden,
                           config.intermediate_size // 2,
                           quant_config=quant_config)

    def forward(
        self,
        positions: torch.Tensor,
        hidden_states: torch.Tensor,
        residual: Optional[torch.Tensor],
    ) -> tuple[torch.Tensor, torch.Tensor]:
        """Apply attention and MLP sub-layers; return (output, residual)."""
        if residual is not None:
            hidden_states, residual = self.ln_1(hidden_states, residual)
        else:
            # First layer: the raw input becomes the residual stream.
            residual = hidden_states
            hidden_states = self.ln_1(hidden_states)

        attn_out = self.attn(positions=positions, hidden_states=hidden_states)

        normed, residual = self.ln_2(attn_out, residual)
        return self.mlp(normed), residual


201
@support_torch_compile
class QWenModel(nn.Module):
    """Qwen decoder stack: token embedding, decoder blocks, final RMSNorm.

    Under pipeline parallelism only layers ``[start_layer, end_layer)`` run
    on this rank. Ranks other than the last return ``hidden_states`` and
    ``residual`` as ``IntermediateTensors`` for the next stage.
    """

    def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
        super().__init__()
        config = vllm_config.model_config.hf_config
        cache_config = vllm_config.cache_config
        quant_config = vllm_config.quant_config

        self.config = config
        self.vocab_size = config.vocab_size
        self.wte = VocabParallelEmbedding(config.vocab_size,
                                          config.hidden_size)

        # make_layers invokes the factory with a ``prefix`` keyword.
        def build_block(prefix):
            return QWenBlock(config, cache_config, quant_config,
                             prefix=prefix)

        self.start_layer, self.end_layer, self.h = make_layers(
            config.num_hidden_layers, build_block, prefix=f"{prefix}.h")
        self.ln_f = RMSNorm(config.hidden_size, eps=config.layer_norm_epsilon)
        self.make_empty_intermediate_tensors = (
            make_empty_intermediate_tensors_factory(
                ["hidden_states", "residual"], config.hidden_size))

    def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor:
        """Look up token embeddings for ``input_ids``."""
        return self.wte(input_ids)

    def forward(
        self,
        input_ids: torch.Tensor,
        positions: torch.Tensor,
        intermediate_tensors: Optional[IntermediateTensors],
        inputs_embeds: Optional[torch.Tensor] = None,
    ) -> Union[torch.Tensor, IntermediateTensors]:
        """Run this rank's slice of the decoder.

        Returns final normalized hidden states on the last pipeline rank,
        otherwise the intermediate tensors to forward to the next rank.
        """
        if not get_pp_group().is_first_rank:
            assert intermediate_tensors is not None
            hidden_states = intermediate_tensors["hidden_states"]
            residual = intermediate_tensors["residual"]
        else:
            hidden_states = (inputs_embeds if inputs_embeds is not None else
                             self.get_input_embeddings(input_ids))
            residual = None

        for block in islice(self.h, self.start_layer, self.end_layer):
            hidden_states, residual = block(positions, hidden_states,
                                            residual)

        if get_pp_group().is_last_rank:
            hidden_states, _ = self.ln_f(hidden_states, residual)
            return hidden_states
        return IntermediateTensors({
            "hidden_states": hidden_states,
            "residual": residual
        })


264
class QWenBaseModel(nn.Module):
    """Common base for Qwen causal-LM variants.

    Builds the decoder (``transformer``, class chosen by
    ``transformer_type``), the ``lm_head`` output projection and the logits
    processor, and implements checkpoint loading.
    """

    # Parameter-name fragments selected for the LLAMA_NN post-load
    # rearrangement; matched as literal substrings of parameter names.
    _LLAMA_NN_WEIGHT_KEYS = (
        "attn.c_attn.weight",
        "attn.c_proj.weight",
        "mlp.gate_up_proj.weight",
        "mlp.c_proj.weight",
        "lm_head.weight",
    )

    def __init__(
        self,
        *,
        vllm_config: VllmConfig,
        prefix: str = "",
        transformer_type: type[QWenModel] = QWenModel,
    ) -> None:
        super().__init__()
        config = vllm_config.model_config.hf_config
        quant_config = vllm_config.quant_config
        multimodal_config = vllm_config.model_config.multimodal_config
        self.config = config
        self.multimodal_config = multimodal_config
        self.quant_config = quant_config

        self.transformer = transformer_type(vllm_config=vllm_config,
                                            prefix=maybe_prefix(
                                                prefix, "transformer"))
        self.lm_head = ParallelLMHead(config.vocab_size,
                                      config.hidden_size,
                                      quant_config=quant_config)
        if self.config.tie_word_embeddings:
            self.lm_head.weight = self.transformer.wte.weight
        self.logits_processor = LogitsProcessor(config.vocab_size)
        self.make_empty_intermediate_tensors = (
            self.transformer.make_empty_intermediate_tensors)

        # Name of the quantization method, or None for unquantized models.
        self.quant_method = (quant_config.get_name()
                             if quant_config is not None else None)

        # Environment feature switches. Only use_llama_nn is acted on by
        # active code in this class; the others are kept as attributes.
        self.use_llama_nn = os.environ.get('LLAMA_NN') == '1'
        self.use_gemm_pad = os.environ.get('GEMM_PAD') == '1'
        self.use_fa_pad = os.environ.get('FA_PAD') == '1'
        self.use_awq_pad = os.environ.get('AWQ_PAD') == '1'

    def compute_logits(
        self,
        hidden_states: torch.Tensor,
        sampling_metadata: SamplingMetadata,
    ) -> Optional[torch.Tensor]:
        """Project hidden states to vocabulary logits via ``lm_head``."""
        logits = self.logits_processor(self.lm_head, hidden_states,
                                       sampling_metadata)
        return logits

    def load_weights(self, weights: Iterable[tuple[str,
                                                   torch.Tensor]]) -> set[str]:
        """Load checkpoint tensors into this model's parameters.

        Checkpoint ``w2``/``w1`` tensors are renamed to ``gate_up_proj`` and
        loaded as shards 0/1 of that fused parameter. The loader skips:

        - ``rotary_emb.inv_freq`` entries;
        - biases with no matching parameter (extra GPTQ biases);
        - parameters that belong to other pipeline-parallel ranks.

        When ``LLAMA_NN=1`` and the model is unquantized, selected weights
        are rearranged in place after loading.

        Args:
            weights: iterable of ``(checkpoint_name, tensor)`` pairs.

        Returns:
            The set of (renamed) parameter names that were loaded.
        """
        stacked_params_mapping = [
            # (param_name, shard_name, shard_id)
            ("gate_up_proj", "w2", 0),
            ("gate_up_proj", "w1", 1),
        ]
        params_dict = dict(self.named_parameters())
        loaded_params: set[str] = set()
        for name, loaded_weight in weights:
            if "rotary_emb.inv_freq" in name:
                continue
            for (param_name, weight_name, shard_id) in stacked_params_mapping:
                if weight_name not in name:
                    continue
                name = name.replace(weight_name, param_name)
                # Skip loading extra bias for GPTQ models.
                if name.endswith(".bias") and name not in params_dict:
                    continue
                # Skip layers on other devices.
                if is_pp_missing_parameter(name, self):
                    continue
                param = params_dict[name]
                weight_loader = param.weight_loader
                weight_loader(param, loaded_weight, shard_id)
                break
            else:
                # Skip loading extra bias for GPTQ models.
                if name.endswith(".bias") and name not in params_dict:
                    continue
                # Skip layers on other devices.
                if is_pp_missing_parameter(name, self):
                    continue
                param = params_dict[name]
                weight_loader = getattr(param, "weight_loader",
                                        default_weight_loader)
                weight_loader(param, loaded_weight)
            loaded_params.add(name)

        if self.use_llama_nn and self.quant_method is None:
            self._rearrange_weights_for_llama_nn(params_dict, loaded_params)

        return loaded_params

    def _rearrange_weights_for_llama_nn(self, params_dict: dict,
                                        loaded_params: set[str]) -> None:
        """Rewrite selected loaded weights in place with ops.trans_w16_gemm.

        For each loaded parameter whose name contains one of
        ``_LLAMA_NN_WEIGHT_KEYS``, the data is passed through
        ``ops.trans_w16_gemm`` (presumably a transpose into a layout the
        custom 16-bit GEMM expects; TODO confirm) and then re-viewed as
        ``(original_cols, -1)``.
        """
        for name in loaded_params:
            if not any(key in name for key in self._LLAMA_NN_WEIGHT_KEYS):
                continue
            param = params_dict[name]
            rows, cols = param.data.shape[0], param.data.shape[1]
            staged = torch.zeros_like(param.data)
            ops.trans_w16_gemm(staged, param.data, rows, cols)
            param.data.copy_(staged)
            # Swap the leading dimension to match the rearranged layout.
            param.data = param.data.reshape(cols, -1)
389
390


391
class QWenLMHeadModel(QWenBaseModel, SupportsPP, SupportsLoRA):
    """Text-only Qwen causal language model.

    Refuses to build from a config that has a ``visual`` section. In that
    case the error message tells the user how to select
    ``QwenVLForConditionalGeneration`` instead.
    """

    # Fused module name -> checkpoint sub-module names it packs.
    packed_modules_mapping = {
        "c_attn": ["c_attn"],
        "gate_up_proj": ["w2", "w1"],
    }

    def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
        hf_config = vllm_config.model_config.hf_config
        if hasattr(hf_config, "visual"):
            overrides_json = json.dumps(
                {"architectures": ["QwenVLForConditionalGeneration"]})
            raise RuntimeError(
                "The configuration of this model indicates that it supports "
                "vision inputs, but you instantiated the text-only version "
                "of this model. Please use the vision model by setting "
                f"`--hf-overrides '{overrides_json}'`")
        super().__init__(vllm_config=vllm_config, prefix=prefix)

    def forward(
        self,
        input_ids: torch.Tensor,
        positions: torch.Tensor,
        intermediate_tensors: Optional[IntermediateTensors] = None,
        inputs_embeds: Optional[torch.Tensor] = None,
    ) -> Union[torch.Tensor, IntermediateTensors]:
        """Delegate to the decoder stack and return its output unchanged."""
        return self.transformer(input_ids, positions, intermediate_tensors,
                                inputs_embeds)