qwen.py 13.5 KB
Newer Older
1
2
# SPDX-License-Identifier: Apache-2.0

Qing's avatar
Qing committed
3
4
5
6
# Adapted from
# https://huggingface.co/Qwen/Qwen-7B/blob/main/modeling_qwen.py
# Copyright (c) Alibaba Cloud.
# LICENSE: https://huggingface.co/Qwen/Qwen-7B/blob/main/LICENSE
Woosuk Kwon's avatar
Woosuk Kwon committed
7
"""Inference-only QWen model compatible with HuggingFace weights."""
8
import json
9
from typing import Any, Dict, Iterable, Optional, Set, Tuple, Union
10

11
12
import torch
from torch import nn
13
from transformers import PretrainedConfig
Qing's avatar
Qing committed
14

15
from vllm.attention import Attention
16
from vllm.compilation.decorators import support_torch_compile
17
from vllm.config import CacheConfig, VllmConfig
18
from vllm.distributed import get_pp_group, get_tensor_model_parallel_world_size
19
from vllm.model_executor.layers.activation import SiluAndMul
20
from vllm.model_executor.layers.layernorm import RMSNorm
21
from vllm.model_executor.layers.linear import (MergedColumnParallelLinear,
22
23
                                               QKVParallelLinear,
                                               RowParallelLinear)
24
from vllm.model_executor.layers.logits_processor import LogitsProcessor
25
from vllm.model_executor.layers.quantization import QuantizationConfig
26
from vllm.model_executor.layers.rotary_embedding import get_rope
27
from vllm.model_executor.layers.vocab_parallel_embedding import (
28
    ParallelLMHead, VocabParallelEmbedding)
29
from vllm.model_executor.model_loader.weight_utils import default_weight_loader
30
from vllm.model_executor.sampling_metadata import SamplingMetadata
31
from vllm.sequence import IntermediateTensors
32

33
34
from .interfaces import SupportsLoRA, SupportsPP
from .utils import (is_pp_missing_parameter,
35
                    make_empty_intermediate_tensors_factory, make_layers,
36
                    maybe_prefix)
37

38
39

class QWenMLP(nn.Module):
    """Feed-forward sub-layer used by the QWen language model.

    Two input projections are fused into a single
    ``MergedColumnParallelLinear`` (two outputs of ``intermediate_size``
    each). ``SiluAndMul`` is applied to the fused output and ``c_proj``
    projects the result back to ``hidden_size``.

    Args:
        hidden_size: Input and output width of the MLP.
        intermediate_size: Width of each of the two fused projections.
        hidden_act: Activation name; only ``"silu"`` is accepted.
        quant_config: Optional quantization config forwarded to both
            linear layers.

    Raises:
        ValueError: If ``hidden_act`` is anything other than ``"silu"``.
    """

    def __init__(
        self,
        hidden_size: int,
        intermediate_size: int,
        hidden_act: str = "silu",
        quant_config: Optional[QuantizationConfig] = None,
    ):
        super().__init__()

        # Both fused outputs have the same width.
        fused_output_sizes = [intermediate_size, intermediate_size]
        self.gate_up_proj = MergedColumnParallelLinear(
            hidden_size,
            fused_output_sizes,
            bias=False,
            quant_config=quant_config,
        )
        self.c_proj = RowParallelLinear(
            intermediate_size,
            hidden_size,
            bias=False,
            quant_config=quant_config,
        )

        silu_requested = hidden_act == "silu"
        if not silu_requested:
            raise ValueError(f"Unsupported activation: {hidden_act}. "
                             "Only silu is supported for now.")
        self.act_fn = SiluAndMul()

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Run the fused projection, the activation, and the output projection."""
        # Each parallel linear layer returns a pair; only the first item is used.
        fused, _ = self.gate_up_proj(x)
        activated = self.act_fn(fused)
        projected, _ = self.c_proj(activated)
        return projected


class QWenAttention(nn.Module):
    """Self-attention layer of a QWen block.

    Projects the input with a fused, biased QKV linear layer, applies rotary
    position embeddings to the query and key, runs attention, and maps the
    result back to ``hidden_size`` with a bias-free output projection. The
    attention heads are divided evenly across the tensor-parallel group.

    Args:
        hidden_size: Model hidden dimension.
        num_heads: Total number of attention heads, before the
            tensor-parallel split.
        max_position_embeddings: Maximum position handed to the rotary
            embedding.
        rope_theta: Base value handed to the rotary embedding.
        rope_scaling: Optional scaling configuration forwarded to
            ``get_rope``.
        cache_config: Optional cache config forwarded to ``Attention``.
        quant_config: Optional quantization config forwarded to the linear
            layers and to ``Attention``.
        prefix: Qualified name of this module; used to derive the names of
            its sub-layers.

    Raises:
        AssertionError: If ``num_heads`` is not divisible by the
            tensor-parallel world size.
    """

    def __init__(
        self,
        hidden_size: int,
        num_heads: int,
        max_position_embeddings: int,
        rope_theta: float = 10000,
        rope_scaling: Optional[Dict[str, Any]] = None,
        cache_config: Optional[CacheConfig] = None,
        quant_config: Optional[QuantizationConfig] = None,
        prefix: str = "",
    ):
        super().__init__()
        self.hidden_size = hidden_size

        tp_world_size = get_tensor_model_parallel_world_size()
        self.total_num_heads = num_heads
        assert self.total_num_heads % tp_world_size == 0, (
            f"num_heads ({self.total_num_heads}) must be divisible by the "
            f"tensor-parallel world size ({tp_world_size})")
        # Heads handled by this tensor-parallel rank.
        self.num_heads = self.total_num_heads // tp_world_size
        self.head_dim = hidden_size // self.total_num_heads

        # The linear layers receive their qualified names (as the Attention
        # layer below already does) so that name-based quantization handling
        # can identify them instead of seeing an empty prefix.
        self.c_attn = QKVParallelLinear(
            hidden_size,
            self.head_dim,
            self.total_num_heads,
            bias=True,
            quant_config=quant_config,
            prefix=f"{prefix}.c_attn",
        )
        self.c_proj = RowParallelLinear(
            self.total_num_heads * self.head_dim,
            hidden_size,
            bias=False,
            quant_config=quant_config,
            prefix=f"{prefix}.c_proj",
        )
        self.scaling = self.head_dim**-0.5

        # Rotary embedding spans the full head dimension.
        self.rotary_emb = get_rope(
            self.head_dim,
            rotary_dim=self.head_dim,
            max_position=max_position_embeddings,
            base=rope_theta,
            rope_scaling=rope_scaling,
        )
        self.attn = Attention(
            self.num_heads,
            self.head_dim,
            self.scaling,
            cache_config=cache_config,
            quant_config=quant_config,
            prefix=f"{prefix}.attn",
        )

    def forward(
        self,
        positions: torch.Tensor,
        hidden_states: torch.Tensor,
    ) -> torch.Tensor:
        """Compute attention over ``hidden_states`` at the given ``positions``.

        Args:
            positions: Position tensor passed to the rotary embedding.
            hidden_states: Input activations for this layer.

        Returns:
            The output of the ``c_proj`` projection.
        """
        qkv, _ = self.c_attn(hidden_states)
        # The fused projection is split into three equal parts along the
        # last dimension: query, key, value.
        q, k, v = qkv.chunk(chunks=3, dim=-1)
        q, k = self.rotary_emb(positions, q, k)
        attn_output = self.attn(q, k, v)
        output, _ = self.c_proj(attn_output)
        return output


class QWenBlock(nn.Module):
    """A single QWen decoder layer: norm, attention, norm, MLP.

    The residual tensor is handed to, and returned by, the ``RMSNorm`` calls
    rather than being added explicitly in this class.

    Args:
        config: HuggingFace model config. Reads ``hidden_size``,
            ``layer_norm_epsilon``, ``num_attention_heads``,
            ``max_position_embeddings``, ``intermediate_size`` and, when
            present, ``rope_theta`` / ``rope_scaling``.
        cache_config: Optional cache config forwarded to the attention layer.
        quant_config: Optional quantization config forwarded to the
            attention and MLP sub-layers.
        prefix: Qualified name of this block.
    """

    def __init__(
        self,
        config: PretrainedConfig,
        cache_config: Optional[CacheConfig] = None,
        quant_config: Optional[QuantizationConfig] = None,
        prefix: str = "",
    ):
        super().__init__()
        width = config.hidden_size
        norm_eps = config.layer_norm_epsilon

        self.ln_1 = RMSNorm(width, eps=norm_eps)

        # The rotary settings are optional in the config; fall back to
        # defaults when they are absent.
        self.attn = QWenAttention(
            width,
            config.num_attention_heads,
            config.max_position_embeddings,
            rope_theta=getattr(config, "rope_theta", 10000),
            rope_scaling=getattr(config, "rope_scaling", None),
            cache_config=cache_config,
            quant_config=quant_config,
            prefix=f"{prefix}.attn",
        )

        self.ln_2 = RMSNorm(width, eps=norm_eps)

        # The MLP is built with half of the config's intermediate_size.
        mlp_width = config.intermediate_size // 2
        self.mlp = QWenMLP(width, mlp_width, quant_config=quant_config)

    def forward(
        self,
        positions: torch.Tensor,
        hidden_states: torch.Tensor,
        residual: Optional[torch.Tensor],
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """Run the block and return ``(hidden_states, residual)``.

        When ``residual`` is ``None`` the incoming ``hidden_states`` is used
        as the residual.
        """
        # Self-attention
        if residual is not None:
            hidden_states, residual = self.ln_1(hidden_states, residual)
        else:
            residual = hidden_states
            hidden_states = self.ln_1(hidden_states)
        attn_out = self.attn(positions=positions, hidden_states=hidden_states)

        # Feed-forward
        normed, residual = self.ln_2(attn_out, residual)
        mlp_out = self.mlp(normed)
        return mlp_out, residual


187
@support_torch_compile
class QWenModel(nn.Module):
    """QWen transformer backbone: token embedding, decoder blocks, final norm.

    The decoder blocks are created through ``make_layers``, which also yields
    the ``start_layer`` / ``end_layer`` bounds that ``forward`` iterates over.
    """

    def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
        super().__init__()

        hf_config = vllm_config.model_config.hf_config
        cache_config = vllm_config.cache_config
        quant_config = vllm_config.quant_config

        self.config = hf_config
        self.vocab_size = hf_config.vocab_size

        self.wte = VocabParallelEmbedding(
            hf_config.vocab_size,
            hf_config.hidden_size,
        )

        def build_block(prefix: str) -> QWenBlock:
            # Factory handed to make_layers; the parameter is named
            # `prefix` exactly as in the original lambda.
            return QWenBlock(hf_config,
                             cache_config,
                             quant_config,
                             prefix=prefix)

        self.start_layer, self.end_layer, self.h = make_layers(
            hf_config.num_hidden_layers,
            build_block,
            prefix=f"{prefix}.h",
        )
        self.ln_f = RMSNorm(hf_config.hidden_size,
                            eps=hf_config.layer_norm_epsilon)
        self.make_empty_intermediate_tensors = (
            make_empty_intermediate_tensors_factory(
                ["hidden_states", "residual"], hf_config.hidden_size))

    def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor:
        """Look up the token embeddings for ``input_ids``."""
        return self.wte(input_ids)

    def forward(
        self,
        input_ids: torch.Tensor,
        positions: torch.Tensor,
        intermediate_tensors: Optional[IntermediateTensors],
        inputs_embeds: Optional[torch.Tensor] = None,
    ) -> Union[torch.Tensor, IntermediateTensors]:
        """Run this rank's slice of decoder blocks.

        On the first pipeline rank the input is ``inputs_embeds`` when given,
        otherwise the embedding of ``input_ids``. On other ranks it is read
        from ``intermediate_tensors``. Ranks other than the last return an
        ``IntermediateTensors`` holding ``hidden_states`` and ``residual``;
        the last rank returns the output of the final norm.
        """
        if not get_pp_group().is_first_rank:
            assert intermediate_tensors is not None
            hidden_states = intermediate_tensors["hidden_states"]
            residual = intermediate_tensors["residual"]
        else:
            if inputs_embeds is None:
                hidden_states = self.get_input_embeddings(input_ids)
            else:
                hidden_states = inputs_embeds
            residual = None

        for block in self.h[self.start_layer:self.end_layer]:
            hidden_states, residual = block(positions, hidden_states,
                                            residual)

        if get_pp_group().is_last_rank:
            normed, _ = self.ln_f(hidden_states, residual)
            return normed

        return IntermediateTensors({
            "hidden_states": hidden_states,
            "residual": residual
        })


250
class QWenBaseModel(nn.Module):
    """Shared base for QWen models: backbone, LM head, logits, weight loading.

    Args:
        vllm_config: Engine configuration; the HF config, quantization config
            and multimodal config are read from it.
        prefix: Qualified name prefix for the sub-modules.
        transformer_type: Class used to build the ``transformer`` backbone.
    """

    def __init__(
        self,
        *,
        vllm_config: VllmConfig,
        prefix: str = "",
        transformer_type: type[QWenModel] = QWenModel,
    ) -> None:
        super().__init__()
        model_config = vllm_config.model_config
        config = model_config.hf_config

        self.config = config
        self.multimodal_config = model_config.multimodal_config
        self.quant_config = vllm_config.quant_config

        self.transformer = transformer_type(
            vllm_config=vllm_config,
            prefix=maybe_prefix(prefix, "transformer"),
        )
        self.lm_head = ParallelLMHead(
            config.vocab_size,
            config.hidden_size,
            quant_config=self.quant_config,
        )
        # With tied embeddings the LM head reuses the input embedding weight.
        if self.config.tie_word_embeddings:
            self.lm_head.weight = self.transformer.wte.weight
        self.logits_processor = LogitsProcessor(config.vocab_size)
        self.make_empty_intermediate_tensors = (
            self.transformer.make_empty_intermediate_tensors)

    def compute_logits(
        self,
        hidden_states: torch.Tensor,
        sampling_metadata: SamplingMetadata,
    ) -> Optional[torch.Tensor]:
        """Pass ``hidden_states`` and the LM head to the logits processor."""
        return self.logits_processor(self.lm_head, hidden_states,
                                     sampling_metadata)

    def load_weights(self, weights: Iterable[Tuple[str,
                                                   torch.Tensor]]) -> Set[str]:
        """Copy checkpoint tensors into this model's parameters.

        Checkpoint entries whose name contains ``w2`` / ``w1`` are loaded as
        shards 0 / 1 of the fused ``gate_up_proj`` parameter. Every other
        entry is loaded into the parameter of the same name.

        Returns:
            The set of parameter names that were loaded.
        """
        # (fused parameter name, checkpoint shard name, shard id)
        stacked_params_mapping = [
            ("gate_up_proj", "w2", 0),
            ("gate_up_proj", "w1", 1),
        ]
        params_dict = dict(self.named_parameters())
        loaded_params: Set[str] = set()

        def should_skip(param_name: str) -> bool:
            # Extra biases (e.g. from GPTQ checkpoints) have no matching
            # parameter here.
            if param_name.endswith(".bias") and param_name not in params_dict:
                return True
            # Parameters of layers on other devices are not loaded.
            return is_pp_missing_parameter(param_name, self)

        for name, loaded_weight in weights:
            if "rotary_emb.inv_freq" in name:
                continue

            for fused_name, shard_name, shard_id in stacked_params_mapping:
                if shard_name not in name:
                    continue
                # `name` is rewritten in place and keeps the rewritten value
                # for the rest of this outer iteration.
                name = name.replace(shard_name, fused_name)
                if should_skip(name):
                    continue
                target = params_dict[name]
                target.weight_loader(target, loaded_weight, shard_id)
                break
            else:
                # No shard was loaded above: treat as a regular parameter.
                if should_skip(name):
                    continue
                target = params_dict[name]
                loader = getattr(target, "weight_loader",
                                 default_weight_loader)
                loader(target, loaded_weight)

            loaded_params.add(name)
        return loaded_params
326
327


328
class QWenLMHeadModel(QWenBaseModel, SupportsPP, SupportsLoRA):
    """Text-only QWen causal language model.

    Construction is refused when the HF config has a ``visual`` attribute;
    the error message points the user at the vision architecture override.
    """

    # Fused module name -> names of the modules packed into it.
    # NOTE(review): presumably consumed by the LoRA support; confirm
    # against SupportsLoRA.
    packed_modules_mapping = {
        "c_attn": ["c_attn"],
        "gate_up_proj": ["w2", "w1"],
    }

    def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
        hf_config = vllm_config.model_config.hf_config

        # Checked before the base class builds any layers.
        if hasattr(hf_config, "visual"):
            overrides_json = json.dumps(
                {"architectures": ["QwenVLForConditionalGeneration"]})
            raise RuntimeError(
                "The configuration of this model indicates that it supports "
                "vision inputs, but you instantiated the text-only version "
                "of this model. Please use the vision model by setting "
                f"`--hf-overrides '{overrides_json}'`")

        super().__init__(vllm_config=vllm_config, prefix=prefix)

    def forward(
        self,
        input_ids: torch.Tensor,
        positions: torch.Tensor,
        intermediate_tensors: Optional[IntermediateTensors] = None,
        inputs_embeds: Optional[torch.Tensor] = None,
    ) -> Union[torch.Tensor, IntermediateTensors]:
        """Delegate to the transformer backbone and return its output."""
        return self.transformer(input_ids, positions, intermediate_tensors,
                                inputs_embeds)