qwen.py 12.8 KB
Newer Older
1
# SPDX-License-Identifier: Apache-2.0
2
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
3

Qing's avatar
Qing committed
4
5
6
7
# Adapted from
# https://huggingface.co/Qwen/Qwen-7B/blob/main/modeling_qwen.py
# Copyright (c) Alibaba Cloud.
# LICENSE: https://huggingface.co/Qwen/Qwen-7B/blob/main/LICENSE
Woosuk Kwon's avatar
Woosuk Kwon committed
8
"""Inference-only QWen model compatible with HuggingFace weights."""
9

10
import json
11
from collections.abc import Iterable
12
from itertools import islice
13
from typing import Any, Optional, Union
14

15
16
import torch
from torch import nn
17
from transformers import PretrainedConfig
Qing's avatar
Qing committed
18

19
from vllm.attention import Attention
20
from vllm.compilation.decorators import support_torch_compile
21
from vllm.config import CacheConfig, VllmConfig
22
from vllm.distributed import get_pp_group, get_tensor_model_parallel_world_size
23
from vllm.model_executor.layers.activation import SiluAndMul
24
from vllm.model_executor.layers.layernorm import RMSNorm
25
26
27
28
29
from vllm.model_executor.layers.linear import (
    MergedColumnParallelLinear,
    QKVParallelLinear,
    RowParallelLinear,
)
30
from vllm.model_executor.layers.logits_processor import LogitsProcessor
31
from vllm.model_executor.layers.quantization import QuantizationConfig
32
from vllm.model_executor.layers.rotary_embedding import get_rope
33
from vllm.model_executor.layers.vocab_parallel_embedding import (
34
35
36
    ParallelLMHead,
    VocabParallelEmbedding,
)
37
from vllm.model_executor.model_loader.weight_utils import default_weight_loader
38
from vllm.sequence import IntermediateTensors
39

40
from .interfaces import SupportsLoRA, SupportsPP
41
42
43
44
45
46
from .utils import (
    is_pp_missing_parameter,
    make_empty_intermediate_tensors_factory,
    make_layers,
    maybe_prefix,
)
47

48
49

class QWenMLP(nn.Module):
    """Feed-forward sub-layer of the Qwen language model.

    One MergedColumnParallelLinear produces two equal-width projections in a
    single matmul, SiluAndMul combines them, and ``c_proj`` maps the result
    back to ``hidden_size``. Only the ``"silu"`` activation is accepted.
    """

    def __init__(
        self,
        hidden_size: int,
        intermediate_size: int,
        hidden_act: str = "silu",
        quant_config: Optional[QuantizationConfig] = None,
    ):
        super().__init__()
        # Two outputs of width `intermediate_size`, packed into one projection.
        packed_widths = [intermediate_size, intermediate_size]
        self.gate_up_proj = MergedColumnParallelLinear(
            hidden_size,
            packed_widths,
            bias=False,
            quant_config=quant_config,
        )
        self.c_proj = RowParallelLinear(
            intermediate_size,
            hidden_size,
            bias=False,
            quant_config=quant_config,
        )
        if hidden_act == "silu":
            self.act_fn = SiluAndMul()
        else:
            raise ValueError(
                f"Unsupported activation: {hidden_act}. Only silu is supported for now."
            )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # Both parallel linears return a pair; only the first element is used.
        packed, _unused = self.gate_up_proj(x)
        activated = self.act_fn(packed)
        projected, _unused = self.c_proj(activated)
        return projected


class QWenAttention(nn.Module):
    """Multi-head self-attention for one Qwen (v1) decoder layer.

    The fused ``c_attn`` projection (with bias) produces Q, K and V. Rotary
    position embeddings are applied to Q and K, and ``c_proj`` maps the
    attention output back to ``hidden_size``. The heads are divided evenly
    across tensor-parallel ranks.

    Args:
        hidden_size: Model width; ``head_dim = hidden_size // num_heads``.
        num_heads: Total number of attention heads (across all TP ranks).
        max_position_embeddings: Maximum position passed to the rotary embedding.
        rope_theta: Rotary embedding base.
        rope_scaling: Optional rotary scaling config, forwarded to ``get_rope``.
        cache_config: KV-cache configuration forwarded to ``Attention``.
        quant_config: Quantization config shared by the linear layers.
        prefix: Module path of this layer (e.g. ``"transformer.h.0.attn"``),
            used to name the sub-layers.
    """

    def __init__(
        self,
        hidden_size: int,
        num_heads: int,
        max_position_embeddings: int,
        rope_theta: float = 10000,
        rope_scaling: Optional[dict[str, Any]] = None,
        cache_config: Optional[CacheConfig] = None,
        quant_config: Optional[QuantizationConfig] = None,
        prefix: str = "",
    ):
        super().__init__()
        self.hidden_size = hidden_size
        tensor_model_parallel_world_size = get_tensor_model_parallel_world_size()
        self.total_num_heads = num_heads
        assert self.total_num_heads % tensor_model_parallel_world_size == 0, (
            f"num_heads ({self.total_num_heads}) must be divisible by the "
            f"tensor parallel world size ({tensor_model_parallel_world_size})"
        )
        # Number of heads handled by this tensor-parallel rank.
        self.num_heads = self.total_num_heads // tensor_model_parallel_world_size
        self.head_dim = hidden_size // self.total_num_heads
        self.c_attn = QKVParallelLinear(
            hidden_size,
            self.head_dim,
            self.total_num_heads,
            bias=True,
            quant_config=quant_config,
            # Name the layer after its module path, as is already done for
            # `self.attn`, so name-based quantization settings apply to it.
            prefix=f"{prefix}.c_attn",
        )
        self.c_proj = RowParallelLinear(
            self.total_num_heads * self.head_dim,
            hidden_size,
            bias=False,
            quant_config=quant_config,
            prefix=f"{prefix}.c_proj",
        )
        self.scaling = self.head_dim**-0.5

        self.rotary_emb = get_rope(
            self.head_dim,
            rotary_dim=self.head_dim,
            max_position=max_position_embeddings,
            base=rope_theta,
            rope_scaling=rope_scaling,
        )
        self.attn = Attention(
            self.num_heads,
            self.head_dim,
            self.scaling,
            cache_config=cache_config,
            quant_config=quant_config,
            prefix=f"{prefix}.attn",
        )

    def forward(
        self,
        positions: torch.Tensor,
        hidden_states: torch.Tensor,
    ) -> torch.Tensor:
        qkv, _ = self.c_attn(hidden_states)
        # No separate KV-head count is given to c_attn, so Q, K and V are
        # presumably equal width and an even 3-way split recovers them.
        q, k, v = qkv.chunk(chunks=3, dim=-1)
        q, k = self.rotary_emb(positions, q, k)
        attn_output = self.attn(q, k, v)
        output, _ = self.c_proj(attn_output)
        return output


class QWenBlock(nn.Module):
    """One Qwen decoder layer: pre-norm self-attention, then pre-norm MLP.

    Each RMSNorm is called with the running residual and returns an updated
    ``(normed, residual)`` pair, which is vLLM's fused add-and-norm. As a
    result, ``forward`` returns the MLP output together with the residual
    stream rather than their sum.
    """

    def __init__(
        self,
        config: PretrainedConfig,
        cache_config: Optional[CacheConfig] = None,
        quant_config: Optional[QuantizationConfig] = None,
        prefix: str = "",
    ):
        super().__init__()
        width = config.hidden_size
        norm_eps = config.layer_norm_epsilon

        self.ln_1 = RMSNorm(width, eps=norm_eps)
        self.attn = QWenAttention(
            width,
            config.num_attention_heads,
            config.max_position_embeddings,
            rope_theta=getattr(config, "rope_theta", 10000),
            rope_scaling=getattr(config, "rope_scaling", None),
            cache_config=cache_config,
            quant_config=quant_config,
            prefix=f"{prefix}.attn",
        )
        self.ln_2 = RMSNorm(width, eps=norm_eps)
        # Each of the two packed MLP projections gets half of the configured
        # intermediate size.
        half_intermediate = config.intermediate_size // 2
        self.mlp = QWenMLP(width, half_intermediate, quant_config=quant_config)

    def forward(
        self,
        positions: torch.Tensor,
        hidden_states: torch.Tensor,
        residual: Optional[torch.Tensor],
    ) -> tuple[torch.Tensor, torch.Tensor]:
        # Attention sub-layer. The first block has no residual yet, so it
        # seeds the residual with its input.
        if residual is not None:
            normed, residual = self.ln_1(hidden_states, residual)
        else:
            residual = hidden_states
            normed = self.ln_1(hidden_states)
        attended = self.attn(positions=positions, hidden_states=normed)

        # MLP sub-layer.
        normed, residual = self.ln_2(attended, residual)
        return self.mlp(normed), residual


196
@support_torch_compile
class QWenModel(nn.Module):
    """Qwen transformer trunk: token embedding, decoder blocks, final norm.

    With pipeline parallelism, this rank runs only the
    ``[start_layer, end_layer)`` slice of ``h`` that ``make_layers``
    assigned to it. Non-last ranks return ``hidden_states``/``residual`` as
    IntermediateTensors instead of applying ``ln_f``.
    """

    def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
        super().__init__()

        hf_config = vllm_config.model_config.hf_config
        cache_config = vllm_config.cache_config
        quant_config = vllm_config.quant_config

        self.config = hf_config
        self.vocab_size = hf_config.vocab_size

        self.wte = VocabParallelEmbedding(
            hf_config.vocab_size,
            hf_config.hidden_size,
        )

        # make_layers supplies each layer's module path via the `prefix`
        # keyword, so the factory must keep that parameter name.
        def build_block(prefix: str) -> QWenBlock:
            return QWenBlock(hf_config, cache_config, quant_config, prefix=prefix)

        self.start_layer, self.end_layer, self.h = make_layers(
            hf_config.num_hidden_layers,
            build_block,
            prefix=f"{prefix}.h",
        )
        self.ln_f = RMSNorm(hf_config.hidden_size, eps=hf_config.layer_norm_epsilon)
        self.make_empty_intermediate_tensors = make_empty_intermediate_tensors_factory(
            ["hidden_states", "residual"], hf_config.hidden_size
        )

    def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor:
        return self.wte(input_ids)

    # NOTE(review): @support_torch_compile likely derives its dynamic-shape
    # inputs from these annotations; keep the signature unchanged.
    def forward(
        self,
        input_ids: torch.Tensor,
        positions: torch.Tensor,
        intermediate_tensors: Optional[IntermediateTensors],
        inputs_embeds: Optional[torch.Tensor] = None,
    ) -> Union[torch.Tensor, IntermediateTensors]:
        pp_group = get_pp_group()

        if not pp_group.is_first_rank:
            # Later stages resume from the previous stage's hand-off.
            assert intermediate_tensors is not None
            hidden_states = intermediate_tensors["hidden_states"]
            residual = intermediate_tensors["residual"]
        else:
            hidden_states = (
                self.get_input_embeddings(input_ids)
                if inputs_embeds is None
                else inputs_embeds
            )
            residual = None

        for block in islice(self.h, self.start_layer, self.end_layer):
            hidden_states, residual = block(positions, hidden_states, residual)

        if pp_group.is_last_rank:
            hidden_states, _ = self.ln_f(hidden_states, residual)
            return hidden_states
        return IntermediateTensors(
            {"hidden_states": hidden_states, "residual": residual}
        )


257
class QWenBaseModel(nn.Module):
    """Shared Qwen causal-LM scaffolding: a transformer trunk plus an LM head.

    ``transformer_type`` lets subclasses substitute a different trunk class.
    The embedding is shared with the LM head when
    ``config.tie_word_embeddings`` is set.
    """

    def __init__(
        self,
        *,
        vllm_config: VllmConfig,
        prefix: str = "",
        transformer_type: type[QWenModel] = QWenModel,
    ) -> None:
        super().__init__()
        model_config = vllm_config.model_config
        hf_config = model_config.hf_config
        quant_config = vllm_config.quant_config

        self.config = hf_config
        self.multimodal_config = model_config.multimodal_config
        self.quant_config = quant_config

        self.transformer = transformer_type(
            vllm_config=vllm_config,
            prefix=maybe_prefix(prefix, "transformer"),
        )
        self.lm_head = ParallelLMHead(
            hf_config.vocab_size,
            hf_config.hidden_size,
            quant_config=quant_config,
            prefix=maybe_prefix(prefix, "lm_head"),
        )
        if hf_config.tie_word_embeddings:
            # Reuse the input embedding matrix as the output projection.
            self.lm_head.weight = self.transformer.wte.weight
        self.logits_processor = LogitsProcessor(hf_config.vocab_size)
        self.make_empty_intermediate_tensors = (
            self.transformer.make_empty_intermediate_tensors
        )

    def compute_logits(
        self,
        hidden_states: torch.Tensor,
    ) -> Optional[torch.Tensor]:
        return self.logits_processor(self.lm_head, hidden_states)

    def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]:
        """Load checkpoint tensors into this model's parameters.

        Checkpoint ``w2``/``w1`` MLP weights are written into shards 0/1 of
        the fused ``gate_up_proj``. Other tensors are loaded with the
        parameter's own ``weight_loader``, or ``default_weight_loader`` if it
        has none.

        The following are skipped:

        * ``rotary_emb.inv_freq`` entries.
        * Extra biases with no matching parameter (GPTQ checkpoints).
        * Parameters that belong to another pipeline-parallel rank.

        Returns:
            The set of (possibly renamed) parameter names that were loaded.
        """
        # (fused parameter name, checkpoint shard name, shard index)
        shard_table = (
            ("gate_up_proj", "w2", 0),
            ("gate_up_proj", "w1", 1),
        )
        params = dict(self.named_parameters())
        loaded: set[str] = set()

        def should_skip(param_name: str) -> bool:
            # Extra GPTQ biases have no parameter here; PP-missing layers live
            # on other ranks.
            if param_name.endswith(".bias") and param_name not in params:
                return True
            return is_pp_missing_parameter(param_name, self)

        for ckpt_name, tensor in weights:
            if "rotary_emb.inv_freq" in ckpt_name:
                continue
            for fused_name, shard_name, shard_idx in shard_table:
                if shard_name not in ckpt_name:
                    continue
                ckpt_name = ckpt_name.replace(shard_name, fused_name)
                if should_skip(ckpt_name):
                    # Try the next table entry; if none loads, the `else`
                    # branch below re-checks and drops this tensor.
                    continue
                target = params[ckpt_name]
                target.weight_loader(target, tensor, shard_idx)
                break
            else:
                if should_skip(ckpt_name):
                    continue
                target = params[ckpt_name]
                loader = getattr(target, "weight_loader", default_weight_loader)
                loader(target, tensor)
            loaded.add(ckpt_name)
        return loaded
332
333


334
class QWenLMHeadModel(QWenBaseModel, SupportsPP, SupportsLoRA):
    """Text-only Qwen causal language model with pipeline-parallel and LoRA support."""

    # Fused module name -> checkpoint module names packed into it.
    packed_modules_mapping = {
        "c_attn": ["c_attn"],
        "gate_up_proj": ["w2", "w1"],
    }

    def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
        hf_config = vllm_config.model_config.hf_config
        # A `visual` section means a vision-language checkpoint; refuse to load
        # it as text-only and point the user at the override that selects the
        # vision architecture.
        if hasattr(hf_config, "visual"):
            vision_override = {"architectures": ["QwenVLForConditionalGeneration"]}
            raise RuntimeError(
                "The configuration of this model indicates that it supports "
                "vision inputs, but you instantiated the text-only version "
                "of this model. Please use the vision model by setting "
                f"`--hf-overrides '{json.dumps(vision_override)}'`"
            )

        super().__init__(vllm_config=vllm_config, prefix=prefix)

    def forward(
        self,
        input_ids: torch.Tensor,
        positions: torch.Tensor,
        intermediate_tensors: Optional[IntermediateTensors] = None,
        inputs_embeds: Optional[torch.Tensor] = None,
    ) -> Union[torch.Tensor, IntermediateTensors]:
        # Logits are computed separately via compute_logits.
        return self.transformer(
            input_ids, positions, intermediate_tensors, inputs_embeds
        )