qwen.py 15.5 KB
Newer Older
1
# SPDX-License-Identifier: Apache-2.0
2
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
3

Qing's avatar
Qing committed
4
5
6
7
# Adapted from
# https://huggingface.co/Qwen/Qwen-7B/blob/main/modeling_qwen.py
# Copyright (c) Alibaba Cloud.
# LICENSE: https://huggingface.co/Qwen/Qwen-7B/blob/main/LICENSE
Woosuk Kwon's avatar
Woosuk Kwon committed
8
"""Inference-only QWen model compatible with HuggingFace weights."""
9

10
import json
11
from collections.abc import Iterable
12
from itertools import islice
13
from typing import Any
14

15
16
import torch
from torch import nn
17
from transformers import PretrainedConfig
Qing's avatar
Qing committed
18

gaoqiong's avatar
gaoqiong committed
19
20
21
import os
import re

22
from vllm.attention.layer import Attention
23
from vllm.compilation.decorators import support_torch_compile
24
from vllm.config import CacheConfig, VllmConfig
25
from vllm.distributed import get_pp_group, get_tensor_model_parallel_world_size
26
from vllm.model_executor.layers.activation import SiluAndMul
27
from vllm.model_executor.layers.layernorm import RMSNorm
28
29
30
31
32
from vllm.model_executor.layers.linear import (
    MergedColumnParallelLinear,
    QKVParallelLinear,
    RowParallelLinear,
)
33
from vllm.model_executor.layers.logits_processor import LogitsProcessor
34
from vllm.model_executor.layers.quantization import QuantizationConfig
35
from vllm.model_executor.layers.rotary_embedding import get_rope
36
from vllm.model_executor.layers.vocab_parallel_embedding import (
37
38
39
    ParallelLMHead,
    VocabParallelEmbedding,
)
40
from vllm.model_executor.model_loader.weight_utils import default_weight_loader
41
from vllm.sequence import IntermediateTensors
Qing's avatar
Qing committed
42

43
from .interfaces import SupportsLoRA, SupportsPP
44
45
46
47
48
49
from .utils import (
    is_pp_missing_parameter,
    make_empty_intermediate_tensors_factory,
    make_layers,
    maybe_prefix,
)
gaoqiong's avatar
gaoqiong committed
50
from vllm import _custom_ops as ops
51
52
53
from vllm.model_executor.utils import pad_weight, gemm_bank_conf


54
class QWenMLP(nn.Module):
    """Gated feed-forward sub-layer of a QWen decoder block.

    A single ``MergedColumnParallelLinear`` produces two ``intermediate_size``
    halves, ``SiluAndMul`` combines them, and a ``RowParallelLinear``
    (``c_proj``) maps the result back to ``hidden_size``.

    Raises:
        ValueError: if ``hidden_act`` is anything other than ``"silu"``.
    """

    def __init__(
        self,
        hidden_size: int,
        intermediate_size: int,
        hidden_act: str = "silu",
        quant_config: QuantizationConfig | None = None,
        prefix: str = "",
    ):
        super().__init__()
        fused_widths = [intermediate_size] * 2
        self.gate_up_proj = MergedColumnParallelLinear(
            hidden_size,
            fused_widths,
            bias=False,
            quant_config=quant_config,
            prefix=f"{prefix}.gate_up_proj",
        )
        self.c_proj = RowParallelLinear(
            intermediate_size,
            hidden_size,
            bias=False,
            quant_config=quant_config,
            prefix=f"{prefix}.c_proj",
        )
        # Only the fused SiLU-and-multiply activation is implemented.
        if hidden_act == "silu":
            self.act_fn = SiluAndMul()
        else:
            raise ValueError(
                f"Unsupported activation: {hidden_act}. Only silu is supported for now."
            )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Apply gate/up projection, SiLU gating, then the down projection."""
        fused, _bias = self.gate_up_proj(x)
        gated = self.act_fn(fused)
        projected, _bias = self.c_proj(gated)
        return projected


class QWenAttention(nn.Module):
    """Multi-head self-attention for QWen with rotary position embeddings.

    Q, K and V come from one fused ``c_attn`` projection (with bias). No
    separate KV-head count is passed to ``QKVParallelLinear``. The heads are
    split evenly across tensor-parallel ranks, and ``c_proj`` projects the
    attention output back to ``hidden_size``.
    """

    def __init__(
        self,
        hidden_size: int,
        num_heads: int,
        max_position_embeddings: int,
        rope_parameters: dict[str, Any] | None = None,
        cache_config: CacheConfig | None = None,
        quant_config: QuantizationConfig | None = None,
        prefix: str = "",
    ):
        super().__init__()
        self.hidden_size = hidden_size
        tp_size = get_tensor_model_parallel_world_size()
        self.total_num_heads = num_heads
        # Every TP rank must own the same number of heads.
        assert self.total_num_heads % tp_size == 0
        self.num_heads = self.total_num_heads // tp_size
        self.head_dim = hidden_size // self.total_num_heads

        self.c_attn = QKVParallelLinear(
            hidden_size,
            self.head_dim,
            self.total_num_heads,
            bias=True,
            quant_config=quant_config,
            prefix=f"{prefix}.c_attn",
        )
        self.c_proj = RowParallelLinear(
            self.total_num_heads * self.head_dim,
            hidden_size,
            bias=False,
            quant_config=quant_config,
            prefix=f"{prefix}.c_proj",
        )
        self.scaling = self.head_dim**-0.5

        self.rotary_emb = get_rope(
            self.head_dim,
            max_position=max_position_embeddings,
            rope_parameters=rope_parameters,
        )
        self.attn = Attention(
            self.num_heads,
            self.head_dim,
            self.scaling,
            cache_config=cache_config,
            quant_config=quant_config,
            prefix=f"{prefix}.attn",
        )

        # Recorded for reference only; forward() does not currently branch on
        # it. `quant_config` is only stored when a config is supplied.
        self.quant_method = None if quant_config is None else quant_config.get_name()
        if quant_config is not None:
            self.quant_config = quant_config

    def forward(
        self,
        positions: torch.Tensor,
        hidden_states: torch.Tensor,
    ) -> torch.Tensor:
        """Project to QKV, apply rotary embedding, attend, and project out."""
        fused_qkv, _bias = self.c_attn(hidden_states)
        # NOTE: a disabled FA_PAD variant used to trim 32 padded columns from
        # the fused QKV here (only for unquantized models); it is not active.
        query, key, value = fused_qkv.chunk(chunks=3, dim=-1)
        query, key = self.rotary_emb(positions, query, key)
        context = self.attn(query, key, value)
        projected, _bias = self.c_proj(context)
        return projected


class QWenBlock(nn.Module):
    """One QWen decoder layer: RMSNorm -> attention, then RMSNorm -> MLP.

    When called with a residual, ``RMSNorm`` returns a
    ``(hidden_states, residual)`` pair. In vLLM this fuses the residual add
    into the norm, so the block threads ``residual`` through instead of
    adding it explicitly.
    """

    def __init__(
        self,
        config: PretrainedConfig,
        cache_config: CacheConfig | None = None,
        quant_config: QuantizationConfig | None = None,
        prefix: str = "",
    ):
        super().__init__()
        width = config.hidden_size
        norm_eps = config.layer_norm_epsilon

        self.ln_1 = RMSNorm(width, eps=norm_eps)
        self.attn = QWenAttention(
            width,
            config.num_attention_heads,
            config.max_position_embeddings,
            rope_parameters=config.rope_parameters,
            cache_config=cache_config,
            quant_config=quant_config,
            prefix=f"{prefix}.attn",
        )
        self.ln_2 = RMSNorm(width, eps=norm_eps)
        # The config's intermediate_size is halved here. Presumably it counts
        # the w1 and w2 halves together; confirm against the HF config.
        self.mlp = QWenMLP(
            width,
            config.intermediate_size // 2,
            quant_config=quant_config,
            prefix=f"{prefix}.mlp",
        )

    def forward(
        self,
        positions: torch.Tensor,
        hidden_states: torch.Tensor,
        residual: torch.Tensor | None,
    ) -> tuple[torch.Tensor, torch.Tensor]:
        """Run attention and MLP sub-layers; return (hidden_states, residual)."""
        # Pre-attention norm. The first layer has no residual stream yet.
        if residual is not None:
            hidden_states, residual = self.ln_1(hidden_states, residual)
        else:
            residual = hidden_states
            hidden_states = self.ln_1(hidden_states)

        attn_out = self.attn(positions=positions, hidden_states=hidden_states)

        # Pre-MLP norm, carrying the residual forward.
        normed, residual = self.ln_2(attn_out, residual)
        return self.mlp(normed), residual


215
@support_torch_compile
class QWenModel(nn.Module):
    """QWen decoder stack: token embedding, decoder blocks and final norm.

    With pipeline parallelism, only layers ``[start_layer, end_layer)`` run
    on this rank. The first rank embeds the tokens. Later ranks resume from
    ``intermediate_tensors``. Every rank except the last returns
    ``hidden_states``/``residual`` as ``IntermediateTensors`` instead of
    applying ``ln_f``.
    """

    def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
        super().__init__()

        config = vllm_config.model_config.hf_config
        cache_config = vllm_config.cache_config
        quant_config = vllm_config.quant_config

        self.config = config
        self.vocab_size = config.vocab_size

        self.wte = VocabParallelEmbedding(config.vocab_size, config.hidden_size)

        # make_layers invokes the factory with a `prefix=` keyword argument.
        def block_factory(prefix: str) -> QWenBlock:
            return QWenBlock(config, cache_config, quant_config, prefix=prefix)

        self.start_layer, self.end_layer, self.h = make_layers(
            config.num_hidden_layers,
            block_factory,
            prefix=f"{prefix}.h",
        )
        self.ln_f = RMSNorm(config.hidden_size, eps=config.layer_norm_epsilon)
        self.make_empty_intermediate_tensors = make_empty_intermediate_tensors_factory(
            ["hidden_states", "residual"], config.hidden_size
        )

    def embed_input_ids(self, input_ids: torch.Tensor) -> torch.Tensor:
        """Look up token embeddings for ``input_ids``."""
        return self.wte(input_ids)

    def forward(
        self,
        input_ids: torch.Tensor | None,
        positions: torch.Tensor,
        intermediate_tensors: IntermediateTensors | None,
        inputs_embeds: torch.Tensor | None = None,
    ) -> torch.Tensor | IntermediateTensors:
        """Run this rank's slice of the decoder stack."""
        pp_group = get_pp_group()

        if not pp_group.is_first_rank:
            # Resume from the activations handed over by the previous stage.
            assert intermediate_tensors is not None
            hidden_states = intermediate_tensors["hidden_states"]
            residual = intermediate_tensors["residual"]
        else:
            hidden_states = (
                inputs_embeds
                if inputs_embeds is not None
                else self.embed_input_ids(input_ids)
            )
            residual = None

        for block in islice(self.h, self.start_layer, self.end_layer):
            hidden_states, residual = block(positions, hidden_states, residual)

        if not pp_group.is_last_rank:
            return IntermediateTensors(
                {"hidden_states": hidden_states, "residual": residual}
            )

        hidden_states, _ = self.ln_f(hidden_states, residual)
        return hidden_states


276
class QWenBaseModel(nn.Module):
    """Shared QWen causal-LM scaffolding: decoder stack, LM head and loading.

    ``transformer_type`` selects the decoder stack class (``QWenModel`` by
    default). When the config ties word embeddings, the LM head reuses the
    embedding weight.

    Environment switches are read once at construction: ``LLAMA_NN``,
    ``GEMM_PAD``, ``FA_PAD`` and ``AWQ_PAD``. Within this class only
    ``LLAMA_NN`` changes behavior: it enables a post-load weight re-layout
    for unquantized models. The other flags are stored as attributes but not
    read elsewhere in this class.
    """

    # Suffixes of the dense weights re-laid-out by the LLAMA_NN path.
    # Matching is exact suffix matching on the parameter name.
    _LLAMA_NN_WEIGHT_SUFFIXES = (
        "attn.c_attn.weight",
        "attn.c_proj.weight",
        "mlp.gate_up_proj.weight",
        "mlp.c_proj.weight",
        "lm_head.weight",
    )

    def __init__(
        self,
        *,
        vllm_config: VllmConfig,
        prefix: str = "",
        transformer_type: type[QWenModel] = QWenModel,
    ) -> None:
        super().__init__()
        config = vllm_config.model_config.hf_config
        quant_config = vllm_config.quant_config
        multimodal_config = vllm_config.model_config.multimodal_config
        self.config = config
        self.multimodal_config = multimodal_config
        self.quant_config = quant_config

        self.transformer = transformer_type(
            vllm_config=vllm_config, prefix=maybe_prefix(prefix, "transformer")
        )
        self.lm_head = ParallelLMHead(
            config.vocab_size,
            config.hidden_size,
            quant_config=quant_config,
            prefix=maybe_prefix(prefix, "lm_head"),
        )
        if self.config.tie_word_embeddings:
            self.lm_head.weight = self.transformer.wte.weight
        self.logits_processor = LogitsProcessor(config.vocab_size)
        self.make_empty_intermediate_tensors = (
            self.transformer.make_empty_intermediate_tensors
        )

        # Name of the quantization method, or None for an unquantized model.
        self.quant_method = (
            quant_config.get_name() if quant_config is not None else None
        )

        # Backend feature switches, read once from the environment.
        self.use_llama_nn = os.environ.get("LLAMA_NN") == "1"
        self.use_gemm_pad = os.environ.get("GEMM_PAD") == "1"
        self.use_fa_pad = os.environ.get("FA_PAD") == "1"
        self.use_awq_pad = os.environ.get("AWQ_PAD") == "1"

    def embed_input_ids(self, input_ids: torch.Tensor) -> torch.Tensor:
        """Look up token embeddings via the transformer's ``wte`` table."""
        return self.transformer.wte(input_ids)

    def compute_logits(
        self,
        hidden_states: torch.Tensor,
    ) -> torch.Tensor | None:
        """Project final hidden states to vocabulary logits."""
        logits = self.logits_processor(self.lm_head, hidden_states)
        return logits

    def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]:
        """Load checkpoint ``(name, tensor)`` pairs into this model.

        The checkpoint's ``w2``/``w1`` MLP weights are loaded as shards 0/1 of
        the fused ``gate_up_proj``. Extra GPTQ biases and parameters owned by
        other pipeline stages are skipped.

        When ``LLAMA_NN`` is enabled and the model is unquantized, the loaded
        dense GEMM weights are then re-laid-out in place (see
        ``_relayout_llama_nn_weights``).

        Returns:
            The set of parameter names that were loaded on this rank.
        """
        stacked_params_mapping = [
            # (param_name, shard_name, shard_id)
            ("gate_up_proj", "w2", 0),
            ("gate_up_proj", "w1", 1),
        ]
        params_dict = dict(self.named_parameters())
        loaded_params: set[str] = set()
        for name, loaded_weight in weights:
            if "rotary_emb.inv_freq" in name:
                continue
            for param_name, weight_name, shard_id in stacked_params_mapping:
                if weight_name not in name:
                    continue
                name = name.replace(weight_name, param_name)
                # Skip loading extra bias for GPTQ models.
                if name.endswith(".bias") and name not in params_dict:
                    continue
                # Skip layers on other devices.
                if is_pp_missing_parameter(name, self):
                    continue
                param = params_dict[name]
                weight_loader = param.weight_loader
                weight_loader(param, loaded_weight, shard_id)
                break
            else:
                # Skip loading extra bias for GPTQ models.
                if name.endswith(".bias") and name not in params_dict:
                    continue
                # Skip layers on other devices.
                if is_pp_missing_parameter(name, self):
                    continue
                param = params_dict[name]
                weight_loader = getattr(param, "weight_loader", default_weight_loader)
                weight_loader(param, loaded_weight)
            loaded_params.add(name)

        if self.use_llama_nn and self.quant_method is None:
            self._relayout_llama_nn_weights(params_dict, loaded_params)
        return loaded_params

    def _relayout_llama_nn_weights(
        self,
        params_dict: dict[str, torch.nn.Parameter],
        loaded_params: set[str],
    ) -> None:
        """Re-lay-out the loaded dense GEMM weights in place for LLAMA_NN.

        Only weights whose names end with one of
        ``_LLAMA_NN_WEIGHT_SUFFIXES`` are processed. For each one:

        1. ``ops.trans_w16_gemm`` writes into a zeroed scratch buffer of the
           same shape.
        2. The result is copied back into the parameter.
        3. The parameter is re-viewed with shape ``(shape[1], -1)``.

        Presumably the kernel transposes the 16-bit weight into the layout the
        LLAMA_NN GEMM expects; confirm against the kernel.

        NOTE(review): the parameter's shape changes here, so running
        load_weights a second time on the same model would feed checkpoint
        tensors into a re-shaped parameter. Confirm reload is not required.
        """
        suffixes = self._LLAMA_NN_WEIGHT_SUFFIXES
        for name in loaded_params:
            if not name.endswith(suffixes):
                continue
            param = params_dict[name]
            scratch = torch.zeros_like(param.data)
            rows, cols = scratch.shape[0], scratch.shape[1]
            ops.trans_w16_gemm(scratch, param.data, rows, cols)
            param.data.copy_(scratch)
            param.data = param.data.reshape(cols, -1)
402
403


404
class QWenLMHeadModel(QWenBaseModel, SupportsPP, SupportsLoRA):
    """Text-only QWen causal language model.

    Construction fails fast when the HF config has a ``visual`` attribute.
    The error message tells the user which ``--hf-overrides`` value selects
    the vision-language architecture instead.
    """

    packed_modules_mapping = {
        "c_attn": ["c_attn"],
        "gate_up_proj": ["w2", "w1"],
    }

    def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
        hf_config = vllm_config.model_config.hf_config
        if hasattr(hf_config, "visual"):
            override_json = json.dumps(
                {"architectures": ["QwenVLForConditionalGeneration"]}
            )
            raise RuntimeError(
                "The configuration of this model indicates that it supports "
                "vision inputs, but you instantiated the text-only version "
                "of this model. Please use the vision model by setting "
                f"`--hf-overrides '{override_json}'`"
            )

        super().__init__(vllm_config=vllm_config, prefix=prefix)

    def forward(
        self,
        input_ids: torch.Tensor | None,
        positions: torch.Tensor,
        intermediate_tensors: IntermediateTensors | None = None,
        inputs_embeds: torch.Tensor | None = None,
    ) -> torch.Tensor | IntermediateTensors:
        """Delegate to the decoder stack; logits are computed separately."""
        return self.transformer(
            input_ids, positions, intermediate_tensors, inputs_embeds
        )