# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project

# Adapted from
# https://huggingface.co/Qwen/Qwen-7B/blob/main/modeling_qwen.py
# Copyright (c) Alibaba Cloud.
# LICENSE: https://huggingface.co/Qwen/Qwen-7B/blob/main/LICENSE
"""Inference-only QWen model compatible with HuggingFace weights."""
9
import json
10
from collections.abc import Iterable
11
from itertools import islice
12
from typing import Any, Optional, Union
13

14
15
import torch
from torch import nn
16
from transformers import PretrainedConfig
Qing's avatar
Qing committed
17

18
from vllm.attention import Attention
19
from vllm.compilation.decorators import support_torch_compile
20
from vllm.config import CacheConfig, VllmConfig
21
from vllm.distributed import get_pp_group, get_tensor_model_parallel_world_size
22
from vllm.model_executor.layers.activation import SiluAndMul
23
from vllm.model_executor.layers.layernorm import RMSNorm
24
from vllm.model_executor.layers.linear import (MergedColumnParallelLinear,
25
26
                                               QKVParallelLinear,
                                               RowParallelLinear)
27
from vllm.model_executor.layers.logits_processor import LogitsProcessor
28
from vllm.model_executor.layers.quantization import QuantizationConfig
29
from vllm.model_executor.layers.rotary_embedding import get_rope
30
from vllm.model_executor.layers.vocab_parallel_embedding import (
31
    ParallelLMHead, VocabParallelEmbedding)
32
from vllm.model_executor.model_loader.weight_utils import default_weight_loader
33
from vllm.model_executor.sampling_metadata import SamplingMetadata
34
from vllm.sequence import IntermediateTensors
35

36
37
from .interfaces import SupportsLoRA, SupportsPP
from .utils import (is_pp_missing_parameter,
38
                    make_empty_intermediate_tensors_factory, make_layers,
39
                    maybe_prefix)
40

41
42

class QWenMLP(nn.Module):
    """Feed-forward sub-layer of the Qwen language model.

    One ``MergedColumnParallelLinear`` (``gate_up_proj``) emits two
    ``intermediate_size``-wide projections side by side. ``SiluAndMul``
    combines them, and ``c_proj`` maps the result back to ``hidden_size``.

    Args:
        hidden_size: Width of the input and output activations.
        intermediate_size: Width of each of the two merged projections.
        hidden_act: Activation name; only ``"silu"`` is accepted.
        quant_config: Optional quantization config for both linear layers.

    Raises:
        ValueError: If ``hidden_act`` is anything other than ``"silu"``.
    """

    def __init__(
        self,
        hidden_size: int,
        intermediate_size: int,
        hidden_act: str = "silu",
        quant_config: Optional[QuantizationConfig] = None,
    ):
        super().__init__()
        # Two equally sized output shards fused into a single matmul.
        self.gate_up_proj = MergedColumnParallelLinear(
            hidden_size,
            [intermediate_size, intermediate_size],
            bias=False,
            quant_config=quant_config,
        )
        self.c_proj = RowParallelLinear(
            intermediate_size,
            hidden_size,
            bias=False,
            quant_config=quant_config,
        )
        if hidden_act != "silu":
            raise ValueError(f"Unsupported activation: {hidden_act}. "
                             "Only silu is supported for now.")
        self.act_fn = SiluAndMul()

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Project up, apply the gated SiLU, and project back down."""
        merged, _ = self.gate_up_proj(x)
        activated = self.act_fn(merged)
        projected, _ = self.c_proj(activated)
        return projected


class QWenAttention(nn.Module):
    """Multi-head self-attention used by each Qwen decoder layer.

    A fused ``QKVParallelLinear`` (``c_attn``, with bias) produces Q, K and V.
    Rotary position embeddings are applied to Q and K. The attention output
    is projected back to ``hidden_size`` by ``c_proj``. Heads are split
    evenly across tensor-parallel ranks.

    Args:
        hidden_size: Model hidden dimension.
        num_heads: Total number of attention heads across all TP ranks.
        max_position_embeddings: Maximum position forwarded to ``get_rope``.
        rope_theta: Rotary embedding base.
        rope_scaling: Optional rotary scaling config forwarded to
            ``get_rope``.
        cache_config: KV-cache configuration forwarded to ``Attention``.
        quant_config: Quantization configuration for the linear layers and
            the attention layer.
        prefix: Fully-qualified module name of this layer. Sub-layer names
            (``.c_attn``, ``.c_proj``, ``.attn``) are derived from it.

    Raises:
        ValueError: If ``num_heads`` is not divisible by the tensor-parallel
            world size.
    """

    def __init__(
        self,
        hidden_size: int,
        num_heads: int,
        max_position_embeddings: int,
        rope_theta: float = 10000,
        rope_scaling: Optional[dict[str, Any]] = None,
        cache_config: Optional[CacheConfig] = None,
        quant_config: Optional[QuantizationConfig] = None,
        prefix: str = "",
    ):
        super().__init__()
        self.hidden_size = hidden_size
        tensor_model_parallel_world_size = get_tensor_model_parallel_world_size(
        )
        self.total_num_heads = num_heads
        # Use an explicit check rather than `assert`, which is stripped when
        # Python runs with -O.
        if self.total_num_heads % tensor_model_parallel_world_size != 0:
            raise ValueError(
                f"Number of attention heads ({self.total_num_heads}) must be "
                "divisible by the tensor parallel world size "
                f"({tensor_model_parallel_world_size}).")
        self.num_heads = (self.total_num_heads //
                          tensor_model_parallel_world_size)
        self.head_dim = hidden_size // self.total_num_heads
        # Pass fully-qualified prefixes so quantization methods that select
        # or skip layers by name can identify these projections.
        self.c_attn = QKVParallelLinear(
            hidden_size,
            self.head_dim,
            self.total_num_heads,
            bias=True,
            quant_config=quant_config,
            prefix=f"{prefix}.c_attn",
        )
        self.c_proj = RowParallelLinear(
            self.total_num_heads * self.head_dim,
            hidden_size,
            bias=False,
            quant_config=quant_config,
            prefix=f"{prefix}.c_proj",
        )
        self.scaling = self.head_dim**-0.5

        self.rotary_emb = get_rope(
            self.head_dim,
            rotary_dim=self.head_dim,
            max_position=max_position_embeddings,
            base=rope_theta,
            rope_scaling=rope_scaling,
        )
        self.attn = Attention(self.num_heads,
                              self.head_dim,
                              self.scaling,
                              cache_config=cache_config,
                              quant_config=quant_config,
                              prefix=f"{prefix}.attn")

    def forward(
        self,
        positions: torch.Tensor,
        hidden_states: torch.Tensor,
    ) -> torch.Tensor:
        """Run attention over ``hidden_states`` at the given ``positions``."""
        qkv, _ = self.c_attn(hidden_states)
        # The fused output is split into three equal chunks on the last dim.
        q, k, v = qkv.chunk(chunks=3, dim=-1)
        q, k = self.rotary_emb(positions, q, k)
        attn_output = self.attn(q, k, v)
        output, _ = self.c_proj(attn_output)
        return output


class QWenBlock(nn.Module):
    """A single Qwen decoder layer.

    The layer runs RMS-normalized self-attention (``ln_1`` then ``attn``),
    followed by an RMS-normalized MLP (``ln_2`` then ``mlp``). A separate
    ``residual`` tensor is threaded through the norms and returned alongside
    the hidden states.
    """

    def __init__(
        self,
        config: PretrainedConfig,
        cache_config: Optional[CacheConfig] = None,
        quant_config: Optional[QuantizationConfig] = None,
        prefix: str = "",
    ):
        super().__init__()
        width = config.hidden_size
        norm_eps = config.layer_norm_epsilon

        self.ln_1 = RMSNorm(width, eps=norm_eps)
        self.attn = QWenAttention(
            width,
            config.num_attention_heads,
            config.max_position_embeddings,
            # Fall back to these defaults when the config lacks the field.
            rope_theta=getattr(config, "rope_theta", 10000),
            rope_scaling=getattr(config, "rope_scaling", None),
            cache_config=cache_config,
            quant_config=quant_config,
            prefix=f"{prefix}.attn",
        )

        self.ln_2 = RMSNorm(width, eps=norm_eps)
        # NOTE(review): config.intermediate_size is halved before reaching the
        # MLP. Presumably the HF value counts both projected branches;
        # confirm against the Qwen config.
        self.mlp = QWenMLP(width,
                           config.intermediate_size // 2,
                           quant_config=quant_config)

    def forward(
        self,
        positions: torch.Tensor,
        hidden_states: torch.Tensor,
        residual: Optional[torch.Tensor],
    ) -> tuple[torch.Tensor, torch.Tensor]:
        """Apply attention then MLP; return ``(hidden_states, residual)``.

        When ``residual`` is None (no previous layer supplied one), the input
        itself becomes the residual and ``ln_1`` is applied on its own.
        """
        # Self-attention sub-layer.
        if residual is not None:
            hidden_states, residual = self.ln_1(hidden_states, residual)
        else:
            residual = hidden_states
            hidden_states = self.ln_1(hidden_states)
        attn_out = self.attn(positions=positions, hidden_states=hidden_states)

        # MLP sub-layer; ln_2 also returns the updated residual.
        normed, residual = self.ln_2(attn_out, residual)
        return self.mlp(normed), residual


190
@support_torch_compile
class QWenModel(nn.Module):
    """Qwen transformer trunk: token embedding, decoder layers, final norm.

    Under pipeline parallelism, only layers ``[start_layer, end_layer)`` run
    on this rank. Non-first ranks read their inputs from
    ``intermediate_tensors``. Non-last ranks return ``IntermediateTensors``
    instead of normalized hidden states.
    """

    def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
        super().__init__()
        config = vllm_config.model_config.hf_config
        cache_config = vllm_config.cache_config
        quant_config = vllm_config.quant_config

        self.config = config
        self.vocab_size = config.vocab_size

        self.wte = VocabParallelEmbedding(config.vocab_size,
                                          config.hidden_size)

        def build_layer(prefix: str) -> QWenBlock:
            # make_layers supplies each layer's own prefix.
            return QWenBlock(config, cache_config, quant_config, prefix=prefix)

        self.start_layer, self.end_layer, self.h = make_layers(
            config.num_hidden_layers, build_layer, prefix=f"{prefix}.h")
        self.ln_f = RMSNorm(config.hidden_size, eps=config.layer_norm_epsilon)
        self.make_empty_intermediate_tensors = (
            make_empty_intermediate_tensors_factory(
                ["hidden_states", "residual"], config.hidden_size))

    def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor:
        """Look up token embeddings for ``input_ids``."""
        return self.wte(input_ids)

    def forward(
        self,
        input_ids: torch.Tensor,
        positions: torch.Tensor,
        intermediate_tensors: Optional[IntermediateTensors],
        inputs_embeds: Optional[torch.Tensor] = None,
    ) -> Union[torch.Tensor, IntermediateTensors]:
        """Run this rank's slice of decoder layers.

        Returns final normalized hidden states on the last pipeline rank.
        Otherwise returns ``IntermediateTensors`` carrying ``hidden_states``
        and ``residual`` for the next rank.
        """
        pp_group = get_pp_group()
        if not pp_group.is_first_rank:
            # Later pipeline stages resume from the previous stage's output.
            assert intermediate_tensors is not None
            hidden_states = intermediate_tensors["hidden_states"]
            residual = intermediate_tensors["residual"]
        else:
            # Precomputed embeddings take priority over token ids.
            hidden_states = (inputs_embeds if inputs_embeds is not None else
                             self.get_input_embeddings(input_ids))
            residual = None

        for block in islice(self.h, self.start_layer, self.end_layer):
            hidden_states, residual = block(positions, hidden_states,
                                            residual)

        if not pp_group.is_last_rank:
            return IntermediateTensors({
                "hidden_states": hidden_states,
                "residual": residual
            })

        hidden_states, _ = self.ln_f(hidden_states, residual)
        return hidden_states


253
class QWenBaseModel(nn.Module):
    """Shared Qwen causal-LM scaffold: transformer trunk plus LM head.

    ``transformer_type`` selects the trunk class to instantiate, so
    subclasses can substitute their own ``QWenModel`` variant.
    """

    def __init__(
        self,
        *,
        vllm_config: VllmConfig,
        prefix: str = "",
        transformer_type: type[QWenModel] = QWenModel,
    ) -> None:
        super().__init__()
        config = vllm_config.model_config.hf_config
        quant_config = vllm_config.quant_config
        multimodal_config = vllm_config.model_config.multimodal_config
        self.config = config
        self.multimodal_config = multimodal_config
        self.quant_config = quant_config
        self.transformer = transformer_type(vllm_config=vllm_config,
                                            prefix=maybe_prefix(
                                                prefix, "transformer"))
        self.lm_head = ParallelLMHead(config.vocab_size,
                                      config.hidden_size,
                                      quant_config=quant_config,
                                      prefix=maybe_prefix(prefix, "lm_head"))
        if self.config.tie_word_embeddings:
            # lm_head shares the embedding Parameter object with wte.
            self.lm_head.weight = self.transformer.wte.weight
        self.logits_processor = LogitsProcessor(config.vocab_size)
        self.make_empty_intermediate_tensors = (
            self.transformer.make_empty_intermediate_tensors)

    def compute_logits(
        self,
        hidden_states: torch.Tensor,
        sampling_metadata: SamplingMetadata,
    ) -> Optional[torch.Tensor]:
        """Compute vocabulary logits from ``hidden_states`` via ``lm_head``."""
        logits = self.logits_processor(self.lm_head, hidden_states,
                                       sampling_metadata)
        return logits

    def load_weights(self, weights: Iterable[tuple[str,
                                                   torch.Tensor]]) -> set[str]:
        """Load checkpoint tensors into this model's parameters.

        The checkpoint's ``w2``/``w1`` MLP weights are written into shards
        0/1 of the fused ``gate_up_proj`` parameter. Everything else is
        loaded one-to-one. Cached rotary ``inv_freq`` buffers, extra GPTQ
        biases with no matching parameter, and parameters owned by other
        pipeline ranks are skipped.

        Args:
            weights: Iterable of ``(checkpoint_name, tensor)`` pairs.

        Returns:
            The set of parameter names that were loaded.
        """
        stacked_params_mapping = [
            # (param_name, shard_name, shard_id)
            ("gate_up_proj", "w2", 0),
            ("gate_up_proj", "w1", 1),
        ]
        params_dict = dict(self.named_parameters())
        loaded_params: set[str] = set()
        for name, loaded_weight in weights:
            if "rotary_emb.inv_freq" in name:
                continue
            # With tied embeddings, lm_head.weight is the same Parameter as
            # transformer.wte.weight. named_parameters() de-duplicates shared
            # parameters, so "lm_head.weight" has no params_dict entry and
            # looking it up would raise KeyError. The shared tensor is loaded
            # through the embedding's entry instead.
            if (self.config.tie_word_embeddings
                    and "lm_head.weight" in name):
                continue
            for (param_name, weight_name, shard_id) in stacked_params_mapping:
                if weight_name not in name:
                    continue
                name = name.replace(weight_name, param_name)
                # Skip loading extra bias for GPTQ models.
                if name.endswith(".bias") and name not in params_dict:
                    continue
                # Skip layers on other devices.
                if is_pp_missing_parameter(name, self):
                    continue
                param = params_dict[name]
                weight_loader = param.weight_loader
                weight_loader(param, loaded_weight, shard_id)
                break
            else:
                # Skip loading extra bias for GPTQ models.
                if name.endswith(".bias") and name not in params_dict:
                    continue
                # Skip layers on other devices.
                if is_pp_missing_parameter(name, self):
                    continue
                param = params_dict[name]
                weight_loader = getattr(param, "weight_loader",
                                        default_weight_loader)
                weight_loader(param, loaded_weight)
            loaded_params.add(name)
        return loaded_params
330
331


332
class QWenLMHeadModel(QWenBaseModel, SupportsPP, SupportsLoRA):
    """Text-only Qwen causal language model.

    Refuses configs that declare a ``visual`` section and points the user
    at the vision-language architecture override instead.
    """

    # Fused module name -> checkpoint module names it packs.
    packed_modules_mapping = {
        "c_attn": ["c_attn"],
        "gate_up_proj": [
            "w2",
            "w1",
        ],
    }

    def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
        hf_config = vllm_config.model_config.hf_config
        if hasattr(hf_config, "visual"):
            # Suggest the exact --hf-overrides payload that selects the VL
            # architecture.
            suggested = {"architectures": ["QwenVLForConditionalGeneration"]}
            raise RuntimeError(
                "The configuration of this model indicates that it supports "
                "vision inputs, but you instantiated the text-only version "
                "of this model. Please use the vision model by setting "
                f"`--hf-overrides '{json.dumps(suggested)}'`")

        super().__init__(vllm_config=vllm_config, prefix=prefix)

    def forward(
        self,
        input_ids: torch.Tensor,
        positions: torch.Tensor,
        intermediate_tensors: Optional[IntermediateTensors] = None,
        inputs_embeds: Optional[torch.Tensor] = None,
    ) -> Union[torch.Tensor, IntermediateTensors]:
        """Delegate to the transformer trunk and return its output as-is."""
        return self.transformer(input_ids, positions, intermediate_tensors,
                                inputs_embeds)