# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project

# Adapted from
# https://huggingface.co/Qwen/Qwen-7B/blob/main/modeling_qwen.py
# Copyright (c) Alibaba Cloud.
# LICENSE: https://huggingface.co/Qwen/Qwen-7B/blob/main/LICENSE
"""Inference-only QWen model compatible with HuggingFace weights."""
9

10
import json
11
from collections.abc import Iterable
12
from itertools import islice
13
from typing import Any
14

15
16
import torch
from torch import nn
17
from transformers import PretrainedConfig
Qing's avatar
Qing committed
18

19
from vllm.attention.layer import Attention
20
from vllm.compilation.decorators import support_torch_compile
21
from vllm.config import CacheConfig, VllmConfig
22
from vllm.distributed import get_pp_group, get_tensor_model_parallel_world_size
23
from vllm.model_executor.layers.activation import SiluAndMul
24
from vllm.model_executor.layers.layernorm import RMSNorm
25
26
27
28
29
from vllm.model_executor.layers.linear import (
    MergedColumnParallelLinear,
    QKVParallelLinear,
    RowParallelLinear,
)
30
from vllm.model_executor.layers.logits_processor import LogitsProcessor
31
from vllm.model_executor.layers.quantization import QuantizationConfig
32
from vllm.model_executor.layers.rotary_embedding import get_rope
33
from vllm.model_executor.layers.vocab_parallel_embedding import (
34
35
36
    ParallelLMHead,
    VocabParallelEmbedding,
)
37
from vllm.model_executor.model_loader.weight_utils import default_weight_loader
38
from vllm.sequence import IntermediateTensors
39

40
from .interfaces import SupportsLoRA, SupportsPP
41
42
43
44
45
46
from .utils import (
    is_pp_missing_parameter,
    make_empty_intermediate_tensors_factory,
    make_layers,
    maybe_prefix,
)
47

48
49

class QWenMLP(nn.Module):
    """MLP for the language component of the Qwen model.

    A single ``MergedColumnParallelLinear`` produces two outputs of width
    ``intermediate_size`` in one matmul, ``SiluAndMul`` merges them, and a
    ``RowParallelLinear`` projects the result back to ``hidden_size``.

    Args:
        hidden_size: Width of the input and output features.
        intermediate_size: Width of each of the two merged projections.
        hidden_act: Activation name; only ``"silu"`` is accepted.
        quant_config: Quantization settings forwarded to both linear layers.
        prefix: Dotted module path used to name the sub-layers.

    Raises:
        ValueError: If ``hidden_act`` is anything other than ``"silu"``.
    """

    def __init__(
        self,
        hidden_size: int,
        intermediate_size: int,
        hidden_act: str = "silu",
        quant_config: QuantizationConfig | None = None,
        prefix: str = "",
    ):
        super().__init__()
        # Validate first so an unsupported activation is rejected before any
        # parallel linear layer (and its weights) is created.
        if hidden_act != "silu":
            raise ValueError(
                f"Unsupported activation: {hidden_act}. Only silu is supported for now."
            )
        self.gate_up_proj = MergedColumnParallelLinear(
            hidden_size,
            [intermediate_size] * 2,
            bias=False,
            quant_config=quant_config,
            prefix=f"{prefix}.gate_up_proj",
        )
        self.c_proj = RowParallelLinear(
            intermediate_size,
            hidden_size,
            bias=False,
            quant_config=quant_config,
            prefix=f"{prefix}.c_proj",
        )
        self.act_fn = SiluAndMul()

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Apply the merged projection, the activation and the output projection.

        Both linear layers return a pair; only the first element is used here.
        """
        gate_up, _ = self.gate_up_proj(x)
        x = self.act_fn(gate_up)
        x, _ = self.c_proj(x)
        return x


class QWenAttention(nn.Module):
    """Self-attention layer of a Qwen block.

    Holds a fused QKV projection (``c_attn``), rotary position embedding, the
    attention backend wrapper and the output projection (``c_proj``).

    Args:
        hidden_size: Width of the hidden states.
        num_heads: Total number of attention heads across all tensor-parallel
            ranks; must be divisible by the tensor-parallel world size.
        max_position_embeddings: Passed to ``get_rope`` as ``max_position``.
        rope_parameters: Passed through to ``get_rope`` unchanged.
        cache_config: Forwarded to the ``Attention`` layer.
        quant_config: Forwarded to the linear layers and the ``Attention`` layer.
        prefix: Dotted module path used to name the sub-layers.
    """

    def __init__(
        self,
        hidden_size: int,
        num_heads: int,
        max_position_embeddings: int,
        rope_parameters: dict[str, Any] | None = None,
        cache_config: CacheConfig | None = None,
        quant_config: QuantizationConfig | None = None,
        prefix: str = "",
    ):
        super().__init__()
        self.hidden_size = hidden_size
        tensor_model_parallel_world_size = get_tensor_model_parallel_world_size()
        self.total_num_heads = num_heads
        assert self.total_num_heads % tensor_model_parallel_world_size == 0, (
            f"num_heads ({self.total_num_heads}) must be divisible by the "
            f"tensor parallel world size ({tensor_model_parallel_world_size})"
        )
        # Heads handled by this rank.
        self.num_heads = self.total_num_heads // tensor_model_parallel_world_size
        self.head_dim = hidden_size // self.total_num_heads
        self.c_attn = QKVParallelLinear(
            hidden_size,
            self.head_dim,
            self.total_num_heads,
            bias=True,
            quant_config=quant_config,
            prefix=f"{prefix}.c_attn",
        )
        self.c_proj = RowParallelLinear(
            self.total_num_heads * self.head_dim,
            hidden_size,
            bias=False,
            quant_config=quant_config,
            prefix=f"{prefix}.c_proj",
        )
        self.scaling = self.head_dim**-0.5

        self.rotary_emb = get_rope(
            self.head_dim,
            max_position=max_position_embeddings,
            rope_parameters=rope_parameters,
        )
        self.attn = Attention(
            self.num_heads,
            self.head_dim,
            self.scaling,
            cache_config=cache_config,
            quant_config=quant_config,
            prefix=f"{prefix}.attn",
        )

    def forward(
        self,
        positions: torch.Tensor,
        hidden_states: torch.Tensor,
    ) -> torch.Tensor:
        """Project to Q/K/V, apply rotary embedding, attend, and project out."""
        qkv, _ = self.c_attn(hidden_states)
        # The fused projection output is split into three equal parts along
        # the last dimension.
        q, k, v = qkv.chunk(chunks=3, dim=-1)
        q, k = self.rotary_emb(positions, q, k)
        attn_output = self.attn(q, k, v)
        output, _ = self.c_proj(attn_output)
        return output


class QWenBlock(nn.Module):
    """One Qwen transformer layer: RMSNorm -> attention -> RMSNorm -> MLP.

    Args:
        config: Model configuration; this class reads ``hidden_size``,
            ``layer_norm_epsilon``, ``num_attention_heads``,
            ``max_position_embeddings``, ``rope_parameters`` and
            ``intermediate_size`` from it.
        cache_config: Forwarded to the attention layer.
        quant_config: Forwarded to the attention and MLP layers.
        prefix: Dotted module path used to name the sub-layers.
    """

    def __init__(
        self,
        config: PretrainedConfig,
        cache_config: CacheConfig | None = None,
        quant_config: QuantizationConfig | None = None,
        prefix: str = "",
    ):
        super().__init__()
        self.ln_1 = RMSNorm(config.hidden_size, eps=config.layer_norm_epsilon)

        self.attn = QWenAttention(
            config.hidden_size,
            config.num_attention_heads,
            config.max_position_embeddings,
            rope_parameters=config.rope_parameters,
            cache_config=cache_config,
            quant_config=quant_config,
            prefix=f"{prefix}.attn",
        )

        self.ln_2 = RMSNorm(config.hidden_size, eps=config.layer_norm_epsilon)

        # NOTE(review): the MLP is built with half of config.intermediate_size;
        # presumably the checkpoint config counts both merged projections
        # together -- confirm against the upstream Qwen implementation.
        self.mlp = QWenMLP(
            config.hidden_size,
            config.intermediate_size // 2,
            quant_config=quant_config,
            prefix=f"{prefix}.mlp",
        )

    def forward(
        self,
        positions: torch.Tensor,
        hidden_states: torch.Tensor,
        residual: torch.Tensor | None,
    ) -> tuple[torch.Tensor, torch.Tensor]:
        """Run the layer and return ``(hidden_states, residual)``.

        When ``residual`` is ``None`` the incoming hidden states become the
        residual and ``ln_1`` is called with a single argument; otherwise
        ``ln_1`` receives both tensors and returns the updated pair.
        """
        # Self Attention
        if residual is None:
            residual = hidden_states
            hidden_states = self.ln_1(hidden_states)
        else:
            hidden_states, residual = self.ln_1(hidden_states, residual)
        hidden_states = self.attn(
            positions=positions,
            hidden_states=hidden_states,
        )

        # Fully Connected
        hidden_states, residual = self.ln_2(hidden_states, residual)
        hidden_states = self.mlp(hidden_states)
        return hidden_states, residual


204
@support_torch_compile
class QWenModel(nn.Module):
    """Qwen decoder stack: token embedding, transformer blocks, final RMSNorm.

    Only the layers in ``[start_layer, end_layer)`` returned by
    ``make_layers`` are executed by ``forward`` on this rank.

    Args:
        vllm_config: Engine configuration; the HF model config, cache config
            and quantization config are read from it.
        prefix: Dotted module path used to name the sub-layers.
    """

    def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
        super().__init__()

        config = vllm_config.model_config.hf_config
        cache_config = vllm_config.cache_config
        quant_config = vllm_config.quant_config

        self.config = config
        self.vocab_size = config.vocab_size

        self.wte = VocabParallelEmbedding(
            config.vocab_size,
            config.hidden_size,
        )
        # The lambda parameter is deliberately named ``prefix``; it shadows
        # the outer ``prefix`` and receives the per-layer prefix instead.
        self.start_layer, self.end_layer, self.h = make_layers(
            config.num_hidden_layers,
            lambda prefix: QWenBlock(config, cache_config, quant_config, prefix=prefix),
            prefix=f"{prefix}.h",
        )
        self.ln_f = RMSNorm(config.hidden_size, eps=config.layer_norm_epsilon)
        self.make_empty_intermediate_tensors = make_empty_intermediate_tensors_factory(
            ["hidden_states", "residual"], config.hidden_size
        )

    def embed_input_ids(self, input_ids: torch.Tensor) -> torch.Tensor:
        """Look up token embeddings for ``input_ids``."""
        return self.wte(input_ids)

    def forward(
        self,
        input_ids: torch.Tensor | None,
        positions: torch.Tensor,
        intermediate_tensors: IntermediateTensors | None,
        inputs_embeds: torch.Tensor | None = None,
    ) -> torch.Tensor | IntermediateTensors:
        """Run this rank's slice of the decoder stack.

        On the first pipeline rank the input is ``inputs_embeds`` when given,
        otherwise the embedding of ``input_ids``. On later ranks the hidden
        states and residual are taken from ``intermediate_tensors``.

        Returns:
            The normalized hidden states on the last pipeline rank, otherwise
            an ``IntermediateTensors`` carrying ``hidden_states`` and
            ``residual`` for the next rank.
        """
        if get_pp_group().is_first_rank:
            if inputs_embeds is not None:
                hidden_states = inputs_embeds
            else:
                hidden_states = self.embed_input_ids(input_ids)
            residual = None
        else:
            assert intermediate_tensors is not None
            hidden_states = intermediate_tensors["hidden_states"]
            residual = intermediate_tensors["residual"]

        for layer in islice(self.h, self.start_layer, self.end_layer):
            hidden_states, residual = layer(
                positions,
                hidden_states,
                residual,
            )
        if not get_pp_group().is_last_rank:
            return IntermediateTensors(
                {"hidden_states": hidden_states, "residual": residual}
            )
        hidden_states, _ = self.ln_f(hidden_states, residual)
        return hidden_states


265
class QWenBaseModel(nn.Module):
    """Shared base for Qwen models: transformer body, LM head, weight loading.

    Args:
        vllm_config: Engine configuration; the HF model config, quantization
            config and multimodal config are read from it.
        prefix: Dotted module path used to name the sub-modules.
        transformer_type: Class instantiated as ``self.transformer``; it is
            called with ``vllm_config`` and ``prefix`` keyword arguments.
    """

    def __init__(
        self,
        *,
        vllm_config: VllmConfig,
        prefix: str = "",
        transformer_type: type[QWenModel] = QWenModel,
    ) -> None:
        super().__init__()
        config = vllm_config.model_config.hf_config
        quant_config = vllm_config.quant_config
        multimodal_config = vllm_config.model_config.multimodal_config
        self.config = config
        self.multimodal_config = multimodal_config
        self.quant_config = quant_config
        self.transformer = transformer_type(
            vllm_config=vllm_config, prefix=maybe_prefix(prefix, "transformer")
        )
        self.lm_head = ParallelLMHead(
            config.vocab_size,
            config.hidden_size,
            quant_config=quant_config,
            prefix=maybe_prefix(prefix, "lm_head"),
        )
        # With tied embeddings the LM head reuses the input embedding weight.
        if self.config.tie_word_embeddings:
            self.lm_head.weight = self.transformer.wte.weight
        self.logits_processor = LogitsProcessor(config.vocab_size)
        self.make_empty_intermediate_tensors = (
            self.transformer.make_empty_intermediate_tensors
        )

    def embed_input_ids(self, input_ids: torch.Tensor) -> torch.Tensor:
        """Look up token embeddings for ``input_ids``."""
        return self.transformer.wte(input_ids)

    def compute_logits(
        self,
        hidden_states: torch.Tensor,
    ) -> torch.Tensor | None:
        """Run the logits processor over ``hidden_states`` using the LM head."""
        return self.logits_processor(self.lm_head, hidden_states)

    def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]:
        """Load checkpoint tensors into this module's parameters.

        Tensors whose name contains ``w2`` or ``w1`` are loaded into shard 0
        and shard 1 of the merged ``gate_up_proj`` parameter respectively.
        Everything else is loaded by name with the parameter's own
        ``weight_loader`` (or ``default_weight_loader`` when it has none).

        Args:
            weights: Iterable of ``(checkpoint_name, tensor)`` pairs.

        Returns:
            The set of parameter names that were loaded, using the merged
            parameter name for stacked weights.
        """
        stacked_params_mapping = [
            # (param_name, shard_name, shard_id)
            ("gate_up_proj", "w2", 0),
            ("gate_up_proj", "w1", 1),
        ]
        params_dict = dict(self.named_parameters())
        loaded_params: set[str] = set()
        for name, loaded_weight in weights:
            if "rotary_emb.inv_freq" in name:
                continue
            for param_name, weight_name, shard_id in stacked_params_mapping:
                if weight_name not in name:
                    continue
                # ``name`` is rewritten in place so that the bookkeeping at
                # the end of the outer loop records the merged parameter name.
                name = name.replace(weight_name, param_name)
                # Skip loading extra bias for GPTQ models.
                if name.endswith(".bias") and name not in params_dict:
                    continue
                # Skip layers on other devices.
                if is_pp_missing_parameter(name, self):
                    continue
                param = params_dict[name]
                weight_loader = param.weight_loader
                weight_loader(param, loaded_weight, shard_id)
                break
            else:
                # Reached when no stacked mapping consumed the tensor (the
                # inner loop finished without ``break``).
                # Skip loading extra bias for GPTQ models.
                if name.endswith(".bias") and name not in params_dict:
                    continue
                # Skip layers on other devices.
                if is_pp_missing_parameter(name, self):
                    continue
                param = params_dict[name]
                weight_loader = getattr(param, "weight_loader", default_weight_loader)
                weight_loader(param, loaded_weight)
            loaded_params.add(name)
        return loaded_params
343
344


345
class QWenLMHeadModel(QWenBaseModel, SupportsPP, SupportsLoRA):
    """Text-only Qwen causal language model.

    Construction is refused when the HF config carries a ``visual``
    attribute; the error message tells the user which ``--hf-overrides``
    value selects the vision model instead.
    """

    # Maps each packed module to the checkpoint sub-module names it contains.
    packed_modules_mapping = {
        "c_attn": ["c_attn"],
        "gate_up_proj": [
            "w2",
            "w1",
        ],
    }

    def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
        """Build the model.

        Raises:
            RuntimeError: If the HF config has a ``visual`` attribute.
        """
        config = vllm_config.model_config.hf_config
        if hasattr(config, "visual"):
            hf_overrides = {"architectures": ["QwenVLForConditionalGeneration"]}
            raise RuntimeError(
                "The configuration of this model indicates that it supports "
                "vision inputs, but you instantiated the text-only version "
                "of this model. Please use the vision model by setting "
                f"`--hf-overrides '{json.dumps(hf_overrides)}'`"
            )

        super().__init__(vllm_config=vllm_config, prefix=prefix)

    def forward(
        self,
        input_ids: torch.Tensor | None,
        positions: torch.Tensor,
        intermediate_tensors: IntermediateTensors | None = None,
        inputs_embeds: torch.Tensor | None = None,
    ) -> torch.Tensor | IntermediateTensors:
        """Delegate to the transformer body and return its output unchanged."""
        hidden_states = self.transformer(
            input_ids, positions, intermediate_tensors, inputs_embeds
        )
        return hidden_states