# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project

# Adapted from
# https://huggingface.co/Qwen/Qwen-7B/blob/main/modeling_qwen.py
# Copyright (c) Alibaba Cloud.
# LICENSE: https://huggingface.co/Qwen/Qwen-7B/blob/main/LICENSE
"""Inference-only QWen model compatible with HuggingFace weights."""
9

10
import json
11
from collections.abc import Iterable
12
from itertools import islice
13
from typing import Any
14

15
16
import torch
from torch import nn
17
from transformers import PretrainedConfig
Qing's avatar
Qing committed
18

19
from vllm.attention.layer import Attention
20
from vllm.compilation.decorators import support_torch_compile
21
from vllm.config import CacheConfig, VllmConfig
22
from vllm.distributed import get_pp_group, get_tensor_model_parallel_world_size
23
from vllm.model_executor.layers.activation import SiluAndMul
24
from vllm.model_executor.layers.layernorm import RMSNorm
25
26
27
28
29
from vllm.model_executor.layers.linear import (
    MergedColumnParallelLinear,
    QKVParallelLinear,
    RowParallelLinear,
)
30
from vllm.model_executor.layers.logits_processor import LogitsProcessor
31
from vllm.model_executor.layers.quantization import QuantizationConfig
32
from vllm.model_executor.layers.rotary_embedding import get_rope
33
from vllm.model_executor.layers.vocab_parallel_embedding import (
34
35
36
    ParallelLMHead,
    VocabParallelEmbedding,
)
37
from vllm.model_executor.model_loader.weight_utils import default_weight_loader
38
from vllm.sequence import IntermediateTensors
39

40
from .interfaces import SupportsLoRA, SupportsPP
41
42
43
44
45
46
from .utils import (
    is_pp_missing_parameter,
    make_empty_intermediate_tensors_factory,
    make_layers,
    maybe_prefix,
)
47

48
49

class QWenMLP(nn.Module):
    """MLP for the language component of the Qwen model.

    A single ``MergedColumnParallelLinear`` (``gate_up_proj``) produces two
    outputs of width ``intermediate_size`` each; they are merged through the
    silu-based ``SiluAndMul`` activation and projected back to ``hidden_size``
    by a ``RowParallelLinear`` (``c_proj``). Neither linear layer has a bias.

    Args:
        hidden_size: Width of the input and of the returned tensor.
        intermediate_size: Width of each of the two merged projections.
        hidden_act: Name of the activation; only ``"silu"`` is accepted.
        quant_config: Optional quantization config forwarded to both linears.

    Raises:
        ValueError: If ``hidden_act`` is anything other than ``"silu"``.
    """

    def __init__(
        self,
        hidden_size: int,
        intermediate_size: int,
        hidden_act: str = "silu",
        quant_config: QuantizationConfig | None = None,
    ):
        super().__init__()
        # Validate first: an unsupported activation should fail before any
        # projection weights are allocated, not after.
        if hidden_act != "silu":
            raise ValueError(
                f"Unsupported activation: {hidden_act}. Only silu is supported for now."
            )
        self.gate_up_proj = MergedColumnParallelLinear(
            hidden_size, [intermediate_size] * 2, bias=False, quant_config=quant_config
        )
        self.c_proj = RowParallelLinear(
            intermediate_size, hidden_size, bias=False, quant_config=quant_config
        )
        self.act_fn = SiluAndMul()

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Apply ``gate_up_proj`` -> ``SiluAndMul`` -> ``c_proj`` to ``x``."""
        # Both linear layers return a 2-tuple; only the first item is used.
        gate_up, _ = self.gate_up_proj(x)
        x = self.act_fn(gate_up)
        x, _ = self.c_proj(x)
        return x


class QWenAttention(nn.Module):
    """Self-attention layer used inside each QWen block.

    ``c_attn`` is a biased ``QKVParallelLinear`` whose output is split into
    three equal chunks (query, key, value) along the last dimension. Rotary
    embeddings are applied to query and key, the result goes through
    ``Attention``, and ``c_proj`` (a bias-free ``RowParallelLinear``) produces
    the layer output.

    Args:
        hidden_size: Model width; also the output width of ``c_proj``.
        num_heads: Total number of attention heads across all tensor-parallel
            ranks. Must be divisible by the tensor-parallel world size.
        max_position_embeddings: Passed to ``get_rope`` as ``max_position``.
        rope_parameters: Optional rotary-embedding settings for ``get_rope``.
        cache_config: Optional cache config forwarded to ``Attention``.
        quant_config: Optional quantization config for the linears/attention.
        prefix: Name prefix; sub-layers get ``.c_attn``, ``.c_proj``, ``.attn``.
    """

    def __init__(
        self,
        hidden_size: int,
        num_heads: int,
        max_position_embeddings: int,
        rope_parameters: dict[str, Any] | None = None,
        cache_config: CacheConfig | None = None,
        quant_config: QuantizationConfig | None = None,
        prefix: str = "",
    ):
        super().__init__()
        tp_size = get_tensor_model_parallel_world_size()
        assert num_heads % tp_size == 0

        # Plain scalar bookkeeping; heads are divided evenly among TP ranks.
        self.hidden_size = hidden_size
        self.total_num_heads = num_heads
        self.num_heads = num_heads // tp_size
        self.head_dim = hidden_size // num_heads
        self.scaling = self.head_dim**-0.5

        # Sub-modules are created in the same order as before so that module
        # registration (and hence parameter iteration order) is unchanged.
        self.c_attn = QKVParallelLinear(
            hidden_size,
            self.head_dim,
            self.total_num_heads,
            bias=True,
            quant_config=quant_config,
            prefix=f"{prefix}.c_attn",
        )
        self.c_proj = RowParallelLinear(
            self.total_num_heads * self.head_dim,
            hidden_size,
            bias=False,
            quant_config=quant_config,
            prefix=f"{prefix}.c_proj",
        )
        self.rotary_emb = get_rope(
            self.head_dim,
            rotary_dim=self.head_dim,
            max_position=max_position_embeddings,
            rope_parameters=rope_parameters,
        )
        self.attn = Attention(
            self.num_heads,
            self.head_dim,
            self.scaling,
            cache_config=cache_config,
            quant_config=quant_config,
            prefix=f"{prefix}.attn",
        )

    def forward(
        self,
        positions: torch.Tensor,
        hidden_states: torch.Tensor,
    ) -> torch.Tensor:
        """Run attention over ``hidden_states`` using ``positions`` for rotary."""
        fused_qkv, _ = self.c_attn(hidden_states)
        # Equal three-way split of the fused projection along the last dim.
        query, key, value = fused_qkv.chunk(chunks=3, dim=-1)
        query, key = self.rotary_emb(positions, query, key)
        context = self.attn(query, key, value)
        projected, _ = self.c_proj(context)
        return projected


class QWenBlock(nn.Module):
    """One QWen transformer layer: ``ln_1`` -> attention -> ``ln_2`` -> MLP.

    Args:
        config: HF config; this block reads ``hidden_size``,
            ``layer_norm_epsilon``, ``num_attention_heads``,
            ``max_position_embeddings``, ``rope_parameters`` and
            ``intermediate_size`` from it.
        cache_config: Optional cache config forwarded to the attention layer.
        quant_config: Optional quantization config for attention and MLP.
        prefix: Name prefix; the attention layer is created as ``{prefix}.attn``.
    """

    def __init__(
        self,
        config: PretrainedConfig,
        cache_config: CacheConfig | None = None,
        quant_config: QuantizationConfig | None = None,
        prefix: str = "",
    ):
        super().__init__()
        width = config.hidden_size
        norm_eps = config.layer_norm_epsilon

        self.ln_1 = RMSNorm(width, eps=norm_eps)
        self.attn = QWenAttention(
            width,
            config.num_attention_heads,
            config.max_position_embeddings,
            rope_parameters=config.rope_parameters,
            cache_config=cache_config,
            quant_config=quant_config,
            prefix=f"{prefix}.attn",
        )
        self.ln_2 = RMSNorm(width, eps=norm_eps)
        # The MLP is built with half of the configured intermediate size.
        # NOTE(review): presumably the HF config counts both merged
        # projections in ``intermediate_size`` -- confirm against the config.
        self.mlp = QWenMLP(
            width, config.intermediate_size // 2, quant_config=quant_config
        )

    def forward(
        self,
        positions: torch.Tensor,
        hidden_states: torch.Tensor,
        residual: torch.Tensor | None,
    ) -> tuple[torch.Tensor, torch.Tensor]:
        """Apply the layer and return ``(hidden_states, residual)``.

        ``residual`` is ``None`` only when no residual stream exists yet; in
        that case the incoming ``hidden_states`` becomes the residual.
        """
        # Self Attention
        if residual is not None:
            # Two-argument norm call returns both the normed states and the
            # updated residual.
            hidden_states, residual = self.ln_1(hidden_states, residual)
        else:
            residual = hidden_states
            hidden_states = self.ln_1(hidden_states)
        hidden_states = self.attn(positions=positions, hidden_states=hidden_states)

        # Fully Connected
        hidden_states, residual = self.ln_2(hidden_states, residual)
        return self.mlp(hidden_states), residual


193
@support_torch_compile
class QWenModel(nn.Module):
    """QWen transformer stack: token embedding, layers and final norm.

    Only the layers in ``[start_layer, end_layer)`` are executed by this
    instance; the bounds come from ``make_layers``. Hidden states and the
    residual are exchanged between pipeline stages via ``IntermediateTensors``.

    Args:
        vllm_config: Supplies the HF model config, cache config and
            quantization config.
        prefix: Name prefix; layers are created under ``{prefix}.h``.
    """

    def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
        super().__init__()

        hf_config = vllm_config.model_config.hf_config
        kv_cache_config = vllm_config.cache_config
        quantization = vllm_config.quant_config

        self.config = hf_config
        self.vocab_size = hf_config.vocab_size

        self.wte = VocabParallelEmbedding(
            hf_config.vocab_size,
            hf_config.hidden_size,
        )

        def build_block(prefix: str) -> QWenBlock:
            # Factory handed to make_layers; it receives each layer's prefix.
            return QWenBlock(hf_config, kv_cache_config, quantization, prefix=prefix)

        self.start_layer, self.end_layer, self.h = make_layers(
            hf_config.num_hidden_layers,
            build_block,
            prefix=f"{prefix}.h",
        )
        self.ln_f = RMSNorm(hf_config.hidden_size, eps=hf_config.layer_norm_epsilon)
        self.make_empty_intermediate_tensors = make_empty_intermediate_tensors_factory(
            ["hidden_states", "residual"], hf_config.hidden_size
        )

    def embed_input_ids(self, input_ids: torch.Tensor) -> torch.Tensor:
        """Look up token embeddings for ``input_ids`` via ``wte``."""
        return self.wte(input_ids)

    def forward(
        self,
        input_ids: torch.Tensor,
        positions: torch.Tensor,
        intermediate_tensors: IntermediateTensors | None,
        inputs_embeds: torch.Tensor | None = None,
    ) -> torch.Tensor | IntermediateTensors:
        """Run this stage's layers.

        On the first pipeline rank the input is ``inputs_embeds`` if given,
        otherwise the embedding of ``input_ids``. On later ranks the input is
        read from ``intermediate_tensors``. The last rank returns the final
        normed hidden states; every other rank returns ``IntermediateTensors``
        holding ``hidden_states`` and ``residual`` for the next stage.
        """
        if get_pp_group().is_first_rank:
            hidden_states = (
                inputs_embeds
                if inputs_embeds is not None
                else self.embed_input_ids(input_ids)
            )
            residual = None
        else:
            assert intermediate_tensors is not None
            hidden_states = intermediate_tensors["hidden_states"]
            residual = intermediate_tensors["residual"]

        for block in islice(self.h, self.start_layer, self.end_layer):
            hidden_states, residual = block(positions, hidden_states, residual)

        if get_pp_group().is_last_rank:
            hidden_states, _ = self.ln_f(hidden_states, residual)
            return hidden_states

        return IntermediateTensors(
            {"hidden_states": hidden_states, "residual": residual}
        )


254
class QWenBaseModel(nn.Module):
    """Shared base for QWen models: transformer, LM head and weight loading.

    Args:
        vllm_config: Supplies the HF config, quantization config and
            multimodal config.
        prefix: Name prefix for the ``transformer`` and ``lm_head`` modules.
        transformer_type: Class instantiated as ``self.transformer``; it is
            called with ``vllm_config`` and ``prefix`` keyword arguments.
    """

    def __init__(
        self,
        *,
        vllm_config: VllmConfig,
        prefix: str = "",
        transformer_type: type[QWenModel] = QWenModel,
    ) -> None:
        super().__init__()
        model_config = vllm_config.model_config
        hf_config = model_config.hf_config
        quantization = vllm_config.quant_config

        self.config = hf_config
        self.multimodal_config = model_config.multimodal_config
        self.quant_config = quantization

        self.transformer = transformer_type(
            vllm_config=vllm_config, prefix=maybe_prefix(prefix, "transformer")
        )
        self.lm_head = ParallelLMHead(
            hf_config.vocab_size,
            hf_config.hidden_size,
            quant_config=quantization,
            prefix=maybe_prefix(prefix, "lm_head"),
        )
        if hf_config.tie_word_embeddings:
            # Tied embeddings: the LM head reuses the token-embedding weight.
            self.lm_head.weight = self.transformer.wte.weight
        self.logits_processor = LogitsProcessor(hf_config.vocab_size)
        # Re-export the transformer's factory under the same attribute name.
        self.make_empty_intermediate_tensors = (
            self.transformer.make_empty_intermediate_tensors
        )

    def compute_logits(
        self,
        hidden_states: torch.Tensor,
    ) -> torch.Tensor | None:
        """Return the logits produced by the logits processor and LM head."""
        return self.logits_processor(self.lm_head, hidden_states)

    def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]:
        """Copy checkpoint tensors into this model's parameters.

        Checkpoint entries whose name contains ``w2`` / ``w1`` are loaded as
        shard 0 / shard 1 of the merged ``gate_up_proj`` parameter. Entries
        named ``rotary_emb.inv_freq``, ``.bias`` entries with no matching
        parameter, and parameters reported missing on this pipeline stage are
        skipped.

        Returns:
            The (possibly renamed) names of every parameter that was loaded.
        """
        # (merged parameter name, checkpoint shard name, shard index)
        stacked_params_mapping = [
            ("gate_up_proj", "w2", 0),
            ("gate_up_proj", "w1", 1),
        ]
        params_dict = dict(self.named_parameters())
        loaded_params: set[str] = set()

        def should_skip(key: str) -> bool:
            # Extra biases (e.g. from GPTQ checkpoints) have no parameter
            # here; layers on other pipeline stages are not present either.
            return (
                key.endswith(".bias") and key not in params_dict
            ) or is_pp_missing_parameter(key, self)

        for name, loaded_weight in weights:
            if "rotary_emb.inv_freq" in name:
                continue
            for merged_name, shard_name, shard_idx in stacked_params_mapping:
                if shard_name not in name:
                    continue
                name = name.replace(shard_name, merged_name)
                if should_skip(name):
                    # `name` stays rewritten; the inner loop finishes without
                    # a break and the `else` arm repeats the same skip check.
                    continue
                param = params_dict[name]
                param.weight_loader(param, loaded_weight, shard_idx)
                break
            else:
                # No stacked shard was loaded: treat as a regular parameter.
                if should_skip(name):
                    continue
                param = params_dict[name]
                loader = getattr(param, "weight_loader", default_weight_loader)
                loader(param, loaded_weight)
            loaded_params.add(name)
        return loaded_params
329
330


331
class QWenLMHeadModel(QWenBaseModel, SupportsPP, SupportsLoRA):
    """Text-only QWen causal LM.

    Construction is refused when the HF config has a ``visual`` attribute,
    since that marks a vision-capable checkpoint.
    """

    # Maps each packed module to the checkpoint sub-module names it holds.
    packed_modules_mapping = {
        "c_attn": ["c_attn"],
        "gate_up_proj": ["w2", "w1"],
    }

    def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
        """Build the model, rejecting configs that declare vision support.

        Raises:
            RuntimeError: If the HF config has a ``visual`` attribute.
        """
        hf_config = vllm_config.model_config.hf_config
        if hasattr(hf_config, "visual"):
            hf_overrides = {"architectures": ["QwenVLForConditionalGeneration"]}
            overrides_json = json.dumps(hf_overrides)
            raise RuntimeError(
                "The configuration of this model indicates that it supports "
                "vision inputs, but you instantiated the text-only version "
                "of this model. Please use the vision model by setting "
                f"`--hf-overrides '{overrides_json}'`"
            )

        super().__init__(vllm_config=vllm_config, prefix=prefix)

    def forward(
        self,
        input_ids: torch.Tensor,
        positions: torch.Tensor,
        intermediate_tensors: IntermediateTensors | None = None,
        inputs_embeds: torch.Tensor | None = None,
    ) -> torch.Tensor | IntermediateTensors:
        """Delegate to ``self.transformer`` and return its result unchanged."""
        return self.transformer(
            input_ids, positions, intermediate_tensors, inputs_embeds
        )