qwen.py 13.6 KB
Newer Older
1
# SPDX-License-Identifier: Apache-2.0
2
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
3

Qing's avatar
Qing committed
4
5
6
7
# Adapted from
# https://huggingface.co/Qwen/Qwen-7B/blob/main/modeling_qwen.py
# Copyright (c) Alibaba Cloud.
# LICENSE: https://huggingface.co/Qwen/Qwen-7B/blob/main/LICENSE
Woosuk Kwon's avatar
Woosuk Kwon committed
8
"""Inference-only QWen model compatible with HuggingFace weights."""
9
import json
10
11
from collections.abc import Iterable
from typing import Any, Optional, Union
12

13
14
import torch
from torch import nn
15
from transformers import PretrainedConfig
Qing's avatar
Qing committed
16

17
from vllm.attention import Attention
18
from vllm.compilation.decorators import support_torch_compile
19
from vllm.config import CacheConfig, VllmConfig
20
from vllm.distributed import get_pp_group, get_tensor_model_parallel_world_size
21
from vllm.model_executor.layers.activation import SiluAndMul
22
from vllm.model_executor.layers.layernorm import RMSNorm
23
from vllm.model_executor.layers.linear import (MergedColumnParallelLinear,
24
25
                                               QKVParallelLinear,
                                               RowParallelLinear)
26
from vllm.model_executor.layers.logits_processor import LogitsProcessor
27
from vllm.model_executor.layers.quantization import QuantizationConfig
28
from vllm.model_executor.layers.rotary_embedding import get_rope
29
from vllm.model_executor.layers.vocab_parallel_embedding import (
30
    ParallelLMHead, VocabParallelEmbedding)
31
from vllm.model_executor.model_loader.weight_utils import default_weight_loader
32
from vllm.model_executor.sampling_metadata import SamplingMetadata
33
from vllm.sequence import IntermediateTensors
34

35
36
from .interfaces import SupportsLoRA, SupportsPP
from .utils import (is_pp_missing_parameter,
37
                    make_empty_intermediate_tensors_factory, make_layers,
38
                    maybe_prefix)
39

40
41

class QWenMLP(nn.Module):
    """MLP for the language component of the Qwen model.

    The block consists of a MergedColumnParallelLinear (``gate_up_proj``)
    producing two merged outputs of width ``intermediate_size`` each, a
    SiluAndMul activation applied to that merged output, and a
    RowParallelLinear (``c_proj``) mapping ``intermediate_size`` back to
    ``hidden_size``.

    Args:
        hidden_size: Feature size of the input and of the final output.
        intermediate_size: Width of each of the two merged projections.
        hidden_act: Name of the activation; only ``"silu"`` is accepted.
        quant_config: Optional quantization config forwarded to both
            linear layers.

    Raises:
        ValueError: If ``hidden_act`` is anything other than ``"silu"``.
    """

    def __init__(
        self,
        hidden_size: int,
        intermediate_size: int,
        hidden_act: str = "silu",
        quant_config: Optional[QuantizationConfig] = None,
    ):
        super().__init__()
        # Validate first so that an unsupported configuration fails before
        # any projection layers (and their weights) are constructed.
        if hidden_act != "silu":
            raise ValueError(f"Unsupported activation: {hidden_act}. "
                             "Only silu is supported for now.")

        self.gate_up_proj = MergedColumnParallelLinear(
            hidden_size, [intermediate_size] * 2,
            bias=False,
            quant_config=quant_config)
        self.c_proj = RowParallelLinear(intermediate_size,
                                        hidden_size,
                                        bias=False,
                                        quant_config=quant_config)
        self.act_fn = SiluAndMul()

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Apply ``gate_up_proj``, the activation, then ``c_proj`` to ``x``."""
        # Both linear layers return a tuple; only the first element is used.
        gate_up, _ = self.gate_up_proj(x)
        x = self.act_fn(gate_up)
        x, _ = self.c_proj(x)
        return x


class QWenAttention(nn.Module):
    """Self-attention layer used inside a QWen block.

    Built from a fused QKV projection (``c_attn``), a rotary embedding
    applied to the query and key tensors, the vLLM ``Attention`` layer, and
    an output projection (``c_proj``).
    """

    def __init__(
        self,
        hidden_size: int,
        num_heads: int,
        max_position_embeddings: int,
        rope_theta: float = 10000,
        rope_scaling: Optional[dict[str, Any]] = None,
        cache_config: Optional[CacheConfig] = None,
        quant_config: Optional[QuantizationConfig] = None,
        prefix: str = "",
    ):
        super().__init__()
        self.hidden_size = hidden_size

        tp_size = get_tensor_model_parallel_world_size()
        self.total_num_heads = num_heads
        # The heads must divide evenly across the tensor-parallel ranks.
        assert self.total_num_heads % tp_size == 0
        self.num_heads = self.total_num_heads // tp_size
        self.head_dim = hidden_size // self.total_num_heads
        self.scaling = self.head_dim**-0.5

        # Fused query/key/value projection (with bias).
        self.c_attn = QKVParallelLinear(
            hidden_size,
            self.head_dim,
            self.total_num_heads,
            bias=True,
            quant_config=quant_config,
        )
        # Output projection back to the hidden size (no bias).
        self.c_proj = RowParallelLinear(
            self.total_num_heads * self.head_dim,
            hidden_size,
            bias=False,
            quant_config=quant_config,
        )

        # The rotary embedding spans the full head dimension.
        self.rotary_emb = get_rope(
            self.head_dim,
            rotary_dim=self.head_dim,
            max_position=max_position_embeddings,
            base=rope_theta,
            rope_scaling=rope_scaling,
        )
        self.attn = Attention(
            self.num_heads,
            self.head_dim,
            self.scaling,
            cache_config=cache_config,
            quant_config=quant_config,
            prefix=f"{prefix}.attn",
        )

    def forward(
        self,
        positions: torch.Tensor,
        hidden_states: torch.Tensor,
    ) -> torch.Tensor:
        """Run attention over ``hidden_states`` at the given ``positions``."""
        fused_qkv, _ = self.c_attn(hidden_states)
        # Split the fused projection into three equal chunks on the last dim.
        query, key, value = fused_qkv.chunk(chunks=3, dim=-1)
        query, key = self.rotary_emb(positions, query, key)
        context = self.attn(query, key, value)
        projected, _ = self.c_proj(context)
        return projected


class QWenBlock(nn.Module):
    """One QWen transformer layer: RMSNorm, attention, RMSNorm, MLP.

    The residual stream is carried separately from ``hidden_states`` and is
    passed through the RMSNorm calls, which return both the normalized
    activations and the updated residual.
    """

    def __init__(
        self,
        config: PretrainedConfig,
        cache_config: Optional[CacheConfig] = None,
        quant_config: Optional[QuantizationConfig] = None,
        prefix: str = "",
    ):
        super().__init__()
        hidden_size = config.hidden_size
        norm_eps = config.layer_norm_epsilon

        self.ln_1 = RMSNorm(hidden_size, eps=norm_eps)

        # RoPE settings are optional on the config; fall back to defaults.
        self.attn = QWenAttention(
            hidden_size,
            config.num_attention_heads,
            config.max_position_embeddings,
            rope_theta=getattr(config, "rope_theta", 10000),
            rope_scaling=getattr(config, "rope_scaling", None),
            cache_config=cache_config,
            quant_config=quant_config,
            prefix=f"{prefix}.attn",
        )

        self.ln_2 = RMSNorm(hidden_size, eps=norm_eps)

        # The MLP is built with half of config.intermediate_size.
        # NOTE(review): presumably the HF config value counts both merged
        # projections - confirm against the upstream modeling_qwen.py.
        self.mlp = QWenMLP(
            hidden_size,
            config.intermediate_size // 2,
            quant_config=quant_config,
        )

    def forward(
        self,
        positions: torch.Tensor,
        hidden_states: torch.Tensor,
        residual: Optional[torch.Tensor],
    ) -> tuple[torch.Tensor, torch.Tensor]:
        """Apply the layer and return ``(hidden_states, residual)``."""
        # Pre-attention norm. With no incoming residual, the raw input
        # becomes the residual and only the activations are normalized.
        if residual is not None:
            hidden_states, residual = self.ln_1(hidden_states, residual)
        else:
            residual = hidden_states
            hidden_states = self.ln_1(hidden_states)

        hidden_states = self.attn(
            positions=positions,
            hidden_states=hidden_states,
        )

        # Pre-MLP norm, then the feed-forward block.
        hidden_states, residual = self.ln_2(hidden_states, residual)
        mlp_out = self.mlp(hidden_states)
        return mlp_out, residual


189
@support_torch_compile
class QWenModel(nn.Module):
    """QWen transformer backbone: embedding, layer stack, final RMSNorm.

    Layers are created through ``make_layers``; ``forward`` only runs the
    layers in ``[start_layer, end_layer)`` and exchanges
    ``hidden_states``/``residual`` with neighbouring pipeline ranks through
    ``IntermediateTensors``.
    """

    def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
        super().__init__()

        hf_config = vllm_config.model_config.hf_config
        cache_config = vllm_config.cache_config
        quant_config = vllm_config.quant_config

        self.config = hf_config
        self.vocab_size = hf_config.vocab_size

        self.wte = VocabParallelEmbedding(
            hf_config.vocab_size,
            hf_config.hidden_size,
        )

        # Factory handed to make_layers; the parameter stays named `prefix`
        # in case make_layers passes it by keyword.
        def build_block(prefix: str) -> QWenBlock:
            return QWenBlock(hf_config,
                             cache_config,
                             quant_config,
                             prefix=prefix)

        self.start_layer, self.end_layer, self.h = make_layers(
            hf_config.num_hidden_layers,
            build_block,
            prefix=f"{prefix}.h",
        )
        self.ln_f = RMSNorm(hf_config.hidden_size,
                            eps=hf_config.layer_norm_epsilon)
        self.make_empty_intermediate_tensors = (
            make_empty_intermediate_tensors_factory(
                ["hidden_states", "residual"], hf_config.hidden_size))

    def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor:
        """Look up token embeddings for ``input_ids`` via ``wte``."""
        return self.wte(input_ids)

    def forward(
        self,
        input_ids: torch.Tensor,
        positions: torch.Tensor,
        intermediate_tensors: Optional[IntermediateTensors],
        inputs_embeds: Optional[torch.Tensor] = None,
    ) -> Union[torch.Tensor, IntermediateTensors]:
        """Run this rank's layers.

        Returns the final normalized hidden states on the last pipeline
        rank, otherwise an ``IntermediateTensors`` holding
        ``hidden_states`` and ``residual``.
        """
        if get_pp_group().is_first_rank:
            # Caller-provided embeddings take precedence over token lookup.
            hidden_states = (inputs_embeds if inputs_embeds is not None else
                             self.get_input_embeddings(input_ids))
            residual = None
        else:
            assert intermediate_tensors is not None
            hidden_states = intermediate_tensors["hidden_states"]
            residual = intermediate_tensors["residual"]

        for block in self.h[self.start_layer:self.end_layer]:
            hidden_states, residual = block(positions, hidden_states,
                                            residual)

        if get_pp_group().is_last_rank:
            hidden_states, _ = self.ln_f(hidden_states, residual)
            return hidden_states

        return IntermediateTensors({
            "hidden_states": hidden_states,
            "residual": residual
        })


252
class QWenBaseModel(nn.Module):
    """Shared base for QWen models: transformer, LM head, weight loading.

    ``transformer_type`` selects the backbone class that is instantiated as
    ``self.transformer``.
    """

    def __init__(
        self,
        *,
        vllm_config: VllmConfig,
        prefix: str = "",
        transformer_type: type[QWenModel] = QWenModel,
    ) -> None:
        super().__init__()
        model_config = vllm_config.model_config
        hf_config = model_config.hf_config
        quant_config = vllm_config.quant_config

        self.config = hf_config
        self.multimodal_config = model_config.multimodal_config
        self.quant_config = quant_config

        self.transformer = transformer_type(
            vllm_config=vllm_config,
            prefix=maybe_prefix(prefix, "transformer"),
        )
        self.lm_head = ParallelLMHead(
            hf_config.vocab_size,
            hf_config.hidden_size,
            quant_config=quant_config,
        )
        # With tied embeddings the LM head shares the input embedding weight.
        if self.config.tie_word_embeddings:
            self.lm_head.weight = self.transformer.wte.weight
        self.logits_processor = LogitsProcessor(hf_config.vocab_size)
        self.make_empty_intermediate_tensors = (
            self.transformer.make_empty_intermediate_tensors)

    def compute_logits(
        self,
        hidden_states: torch.Tensor,
        sampling_metadata: SamplingMetadata,
    ) -> Optional[torch.Tensor]:
        """Return the logits processor's output for ``hidden_states``."""
        return self.logits_processor(self.lm_head, hidden_states,
                                     sampling_metadata)

    def load_weights(self, weights: Iterable[tuple[str,
                                                   torch.Tensor]]) -> set[str]:
        """Load checkpoint tensors into this module's parameters.

        Checkpoint names containing ``w2``/``w1`` are routed into shards 0/1
        of the fused ``gate_up_proj`` parameter; everything else is loaded
        under its own name.

        Returns:
            The set of parameter names recorded as loaded.
        """
        # (fused parameter name, checkpoint shard name, shard index)
        fused_shards = [
            ("gate_up_proj", "w2", 0),
            ("gate_up_proj", "w1", 1),
        ]
        params_by_name = dict(self.named_parameters())
        loaded: set[str] = set()

        for name, loaded_weight in weights:
            if "rotary_emb.inv_freq" in name:
                continue

            for (fused_name, shard_name, shard_idx) in fused_shards:
                if shard_name not in name:
                    continue
                name = name.replace(shard_name, fused_name)
                # NOTE: the two `continue`s below advance this inner loop
                # with the already-rewritten name; if no entry ends in
                # `break`, the `else` branch applies the same checks again.
                # Skip loading extra bias for GPTQ models.
                if name.endswith(".bias") and name not in params_by_name:
                    continue
                # Skip layers on other devices.
                if is_pp_missing_parameter(name, self):
                    continue
                target = params_by_name[name]
                target.weight_loader(target, loaded_weight, shard_idx)
                break
            else:
                # Skip loading extra bias for GPTQ models.
                if name.endswith(".bias") and name not in params_by_name:
                    continue
                # Skip layers on other devices.
                if is_pp_missing_parameter(name, self):
                    continue
                target = params_by_name[name]
                load_fn = getattr(target, "weight_loader",
                                  default_weight_loader)
                load_fn(target, loaded_weight)

            loaded.add(name)
        return loaded
328
329


330
class QWenLMHeadModel(QWenBaseModel, SupportsPP, SupportsLoRA):
    """Text-only QWen causal language model.

    Construction is refused when the HF config carries a ``visual``
    attribute, pointing the user to the vision architecture override.
    """

    # Maps each packed module to the checkpoint sub-module names it packs.
    packed_modules_mapping = {
        "c_attn": ["c_attn"],
        "gate_up_proj": ["w2", "w1"],
    }

    def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
        hf_config = vllm_config.model_config.hf_config

        # A `visual` attribute on the config marks a vision-capable
        # checkpoint, which this text-only class rejects.
        if hasattr(hf_config, "visual"):
            overrides_json = json.dumps(
                {"architectures": ["QwenVLForConditionalGeneration"]})
            raise RuntimeError(
                "The configuration of this model indicates that it supports "
                "vision inputs, but you instantiated the text-only version "
                "of this model. Please use the vision model by setting "
                f"`--hf-overrides '{overrides_json}'`")

        super().__init__(vllm_config=vllm_config, prefix=prefix)

    def forward(
        self,
        input_ids: torch.Tensor,
        positions: torch.Tensor,
        intermediate_tensors: Optional[IntermediateTensors] = None,
        inputs_embeds: Optional[torch.Tensor] = None,
    ) -> Union[torch.Tensor, IntermediateTensors]:
        """Delegate to the transformer backbone and return its output."""
        return self.transformer(input_ids, positions, intermediate_tensors,
                                inputs_embeds)