# SPDX-License-Identifier: Apache-2.0

# Adapted from
# https://huggingface.co/Qwen/Qwen-7B/blob/main/modeling_qwen.py
# Copyright (c) Alibaba Cloud.
# LICENSE: https://huggingface.co/Qwen/Qwen-7B/blob/main/LICENSE
"""Inference-only QWen model compatible with HuggingFace weights."""
import json
from collections.abc import Iterable
from typing import Any, Optional, Union

import torch
from torch import nn
from transformers import PretrainedConfig

from vllm.attention import Attention
from vllm.compilation.decorators import support_torch_compile
from vllm.config import CacheConfig, VllmConfig
from vllm.distributed import get_pp_group, get_tensor_model_parallel_world_size
from vllm.model_executor.layers.activation import SiluAndMul
from vllm.model_executor.layers.layernorm import RMSNorm
from vllm.model_executor.layers.linear import (MergedColumnParallelLinear,
                                               QKVParallelLinear,
                                               RowParallelLinear)
from vllm.model_executor.layers.logits_processor import LogitsProcessor
from vllm.model_executor.layers.quantization import QuantizationConfig
from vllm.model_executor.layers.rotary_embedding import get_rope
from vllm.model_executor.layers.vocab_parallel_embedding import (
    ParallelLMHead, VocabParallelEmbedding)
from vllm.model_executor.model_loader.weight_utils import default_weight_loader
from vllm.model_executor.sampling_metadata import SamplingMetadata
from vllm.sequence import IntermediateTensors

from .interfaces import SupportsLoRA, SupportsPP
from .utils import (is_pp_missing_parameter,
                    make_empty_intermediate_tensors_factory, make_layers,
                    maybe_prefix)


class QWenMLP(nn.Module):
    """Feed-forward sub-layer of a QWen decoder block.

    The two input projections are fused into a single
    ``MergedColumnParallelLinear`` (``gate_up_proj``). ``SiluAndMul``
    combines its two output halves, and ``c_proj`` projects the result back
    to ``hidden_size``.
    """

    def __init__(
        self,
        hidden_size: int,
        intermediate_size: int,
        hidden_act: str = "silu",
        quant_config: Optional[QuantizationConfig] = None,
    ):
        super().__init__()
        # Both fused halves have the same width.
        fused_widths = [intermediate_size, intermediate_size]
        self.gate_up_proj = MergedColumnParallelLinear(
            hidden_size,
            fused_widths,
            bias=False,
            quant_config=quant_config,
        )
        self.c_proj = RowParallelLinear(
            intermediate_size,
            hidden_size,
            bias=False,
            quant_config=quant_config,
        )
        # Only SiLU gating is implemented; any other activation is rejected.
        if hidden_act == "silu":
            self.act_fn = SiluAndMul()
        else:
            raise ValueError(f"Unsupported activation: {hidden_act}. "
                             "Only silu is supported for now.")

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Apply the gated MLP to ``x``; output is projected by ``c_proj``."""
        fused, _bias = self.gate_up_proj(x)
        gated = self.act_fn(fused)
        out, _bias = self.c_proj(gated)
        return out


class QWenAttention(nn.Module):
    """Multi-head self-attention with rotary position embeddings for QWen.

    A single fused ``c_attn`` projection (with bias) produces queries, keys
    and values. Heads are divided evenly across tensor-parallel ranks, and
    ``c_proj`` projects the attention output back to ``hidden_size``.
    """

    def __init__(
        self,
        hidden_size: int,
        num_heads: int,
        max_position_embeddings: int,
        rope_theta: float = 10000,
        rope_scaling: Optional[dict[str, Any]] = None,
        cache_config: Optional[CacheConfig] = None,
        quant_config: Optional[QuantizationConfig] = None,
        prefix: str = "",
    ):
        """Build the attention layer.

        Args:
            hidden_size: Model width; also the input/output size.
            num_heads: Total number of attention heads across all TP ranks.
            max_position_embeddings: Maximum position passed to ``get_rope``.
            rope_theta: RoPE base frequency.
            rope_scaling: Optional RoPE scaling config forwarded to
                ``get_rope``.
            cache_config: Forwarded to the ``Attention`` backend.
            quant_config: Quantization config for the linear layers and
                attention.
            prefix: Module name prefix; ``.attn`` is appended for the backend.

        Raises:
            ValueError: If ``num_heads`` is not divisible by the tensor
                parallel world size.
        """
        super().__init__()
        self.hidden_size = hidden_size
        tensor_model_parallel_world_size = get_tensor_model_parallel_world_size(
        )
        self.total_num_heads = num_heads
        # Explicit check rather than ``assert``: asserts are stripped when
        # Python runs with ``-O``, which would let an uneven head split
        # through silently.
        if self.total_num_heads % tensor_model_parallel_world_size != 0:
            raise ValueError(
                f"Total number of attention heads ({self.total_num_heads}) "
                "must be divisible by the tensor parallel size "
                f"({tensor_model_parallel_world_size}).")
        self.num_heads = (self.total_num_heads //
                          tensor_model_parallel_world_size)
        self.head_dim = hidden_size // self.total_num_heads
        self.c_attn = QKVParallelLinear(
            hidden_size,
            self.head_dim,
            self.total_num_heads,
            bias=True,
            quant_config=quant_config,
        )
        self.c_proj = RowParallelLinear(
            self.total_num_heads * self.head_dim,
            hidden_size,
            bias=False,
            quant_config=quant_config,
        )
        self.scaling = self.head_dim**-0.5

        # Rotary embedding is applied over the full head dimension.
        self.rotary_emb = get_rope(
            self.head_dim,
            rotary_dim=self.head_dim,
            max_position=max_position_embeddings,
            base=rope_theta,
            rope_scaling=rope_scaling,
        )
        self.attn = Attention(self.num_heads,
                              self.head_dim,
                              self.scaling,
                              cache_config=cache_config,
                              quant_config=quant_config,
                              prefix=f"{prefix}.attn")

    def forward(
        self,
        positions: torch.Tensor,
        hidden_states: torch.Tensor,
    ) -> torch.Tensor:
        """Run self-attention on ``hidden_states``.

        Args:
            positions: Position ids used by the rotary embedding.
            hidden_states: Input activations.

        Returns:
            The attention output after the ``c_proj`` projection.
        """
        qkv, _ = self.c_attn(hidden_states)
        # Split the fused projection into equal query/key/value parts along
        # the last dimension.
        q, k, v = qkv.chunk(chunks=3, dim=-1)
        q, k = self.rotary_emb(positions, q, k)
        attn_output = self.attn(q, k, v)
        output, _ = self.c_proj(attn_output)
        return output


class QWenBlock(nn.Module):
    """One QWen decoder layer.

    The layer runs RMSNorm then self-attention, followed by RMSNorm then the
    gated MLP. The residual stream is threaded through ``forward``'s
    ``residual`` argument and return value.
    """

    def __init__(
        self,
        config: PretrainedConfig,
        cache_config: Optional[CacheConfig] = None,
        quant_config: Optional[QuantizationConfig] = None,
        prefix: str = "",
    ):
        """Build the layer from the HF ``config``.

        Args:
            config: HF model config (hidden size, heads, norm epsilon, ...).
            cache_config: Forwarded to the attention layer.
            quant_config: Forwarded to the attention and MLP layers.
            prefix: Module name prefix for this layer.
        """
        super().__init__()
        self.ln_1 = RMSNorm(config.hidden_size, eps=config.layer_norm_epsilon)

        # The original QWen HF config names the RoPE base ``rotary_emb_base``
        # (the reference modeling code reads that field), not ``rope_theta``.
        # Honour it so checkpoints whose base differs from 10000 are not
        # silently run with the wrong rotary frequencies. ``rope_theta`` still
        # wins when a config defines it.
        rope_theta = getattr(config, "rope_theta",
                             getattr(config, "rotary_emb_base", 10000))
        rope_scaling = getattr(config, "rope_scaling", None)
        self.attn = QWenAttention(config.hidden_size,
                                  config.num_attention_heads,
                                  config.max_position_embeddings,
                                  rope_theta=rope_theta,
                                  rope_scaling=rope_scaling,
                                  cache_config=cache_config,
                                  quant_config=quant_config,
                                  prefix=f"{prefix}.attn")

        self.ln_2 = RMSNorm(config.hidden_size, eps=config.layer_norm_epsilon)

        # NOTE: halved to get the per-projection width; this mirrors the HF
        # reference, which also uses ``intermediate_size // 2`` — verify when
        # porting a new variant.
        self.mlp = QWenMLP(config.hidden_size,
                           config.intermediate_size // 2,
                           quant_config=quant_config)

    def forward(
        self,
        positions: torch.Tensor,
        hidden_states: torch.Tensor,
        residual: Optional[torch.Tensor],
    ) -> tuple[torch.Tensor, torch.Tensor]:
        """Run the layer.

        Args:
            positions: Position ids for the rotary embedding.
            hidden_states: Output of the previous layer (or the embeddings).
            residual: Running residual, or ``None``. With ``None``, the
                input itself becomes the residual.

        Returns:
            ``(hidden_states, residual)`` to feed into the next layer.
        """
        # Self Attention
        if residual is None:
            residual = hidden_states
            hidden_states = self.ln_1(hidden_states)
        else:
            # Two-argument RMSNorm call returns (normed, new residual);
            # presumably a fused residual-add + norm — see vLLM RMSNorm.
            hidden_states, residual = self.ln_1(hidden_states, residual)
        hidden_states = self.attn(
            positions=positions,
            hidden_states=hidden_states,
        )

        # Fully Connected
        hidden_states, residual = self.ln_2(hidden_states, residual)
        hidden_states = self.mlp(hidden_states)
        return hidden_states, residual


@support_torch_compile
class QWenModel(nn.Module):
    """QWen decoder stack: token embedding, decoder layers and final norm.

    Only layers ``[start_layer, end_layer)`` of ``h`` are run by this
    instance. A non-last pipeline rank returns ``IntermediateTensors`` with
    ``hidden_states`` and ``residual``, not final hidden states.
    """

    def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
        super().__init__()

        hf_config = vllm_config.model_config.hf_config
        cache_config = vllm_config.cache_config
        quant_config = vllm_config.quant_config

        self.config = hf_config
        self.vocab_size = hf_config.vocab_size

        self.wte = VocabParallelEmbedding(
            hf_config.vocab_size,
            hf_config.hidden_size,
        )

        # Per-layer factory handed to make_layers. The parameter keeps the
        # name ``prefix`` in case make_layers passes it by keyword.
        def build_block(prefix: str) -> QWenBlock:
            return QWenBlock(hf_config,
                             cache_config,
                             quant_config,
                             prefix=prefix)

        self.start_layer, self.end_layer, self.h = make_layers(
            hf_config.num_hidden_layers,
            build_block,
            prefix=f"{prefix}.h",
        )
        self.ln_f = RMSNorm(hf_config.hidden_size,
                            eps=hf_config.layer_norm_epsilon)
        self.make_empty_intermediate_tensors = (
            make_empty_intermediate_tensors_factory(
                ["hidden_states", "residual"], hf_config.hidden_size))

    def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor:
        """Look up token embeddings for ``input_ids``."""
        return self.wte(input_ids)

    def forward(
        self,
        input_ids: torch.Tensor,
        positions: torch.Tensor,
        intermediate_tensors: Optional[IntermediateTensors],
        inputs_embeds: Optional[torch.Tensor] = None,
    ) -> Union[torch.Tensor, IntermediateTensors]:
        """Run this rank's slice of the decoder.

        The first pipeline rank starts from ``inputs_embeds`` when given,
        otherwise from embedded ``input_ids``. Other ranks resume from
        ``intermediate_tensors``. The last rank applies ``ln_f`` and returns
        hidden states; earlier ranks return ``IntermediateTensors``.
        """
        pp_group = get_pp_group()
        if pp_group.is_first_rank:
            hidden_states = (inputs_embeds if inputs_embeds is not None else
                             self.get_input_embeddings(input_ids))
            residual = None
        else:
            assert intermediate_tensors is not None
            hidden_states = intermediate_tensors["hidden_states"]
            residual = intermediate_tensors["residual"]

        for block in self.h[self.start_layer:self.end_layer]:
            hidden_states, residual = block(positions, hidden_states,
                                            residual)

        if not pp_group.is_last_rank:
            return IntermediateTensors({
                "hidden_states": hidden_states,
                "residual": residual
            })

        normed, _ = self.ln_f(hidden_states, residual)
        return normed


class QWenBaseModel(nn.Module):
    """Shared QWen language-model wrapper: decoder stack, LM head, logits.

    ``transformer_type`` lets a subclass substitute a different decoder
    class for ``self.transformer``.
    """

    def __init__(
        self,
        *,
        vllm_config: VllmConfig,
        prefix: str = "",
        transformer_type: type[QWenModel] = QWenModel,
    ) -> None:
        super().__init__()
        config = vllm_config.model_config.hf_config
        quant_config = vllm_config.quant_config
        multimodal_config = vllm_config.model_config.multimodal_config
        self.config = config
        self.multimodal_config = multimodal_config
        self.quant_config = quant_config

        self.transformer = transformer_type(vllm_config=vllm_config,
                                            prefix=maybe_prefix(
                                                prefix, "transformer"))
        self.lm_head = ParallelLMHead(config.vocab_size,
                                      config.hidden_size,
                                      quant_config=quant_config)
        if self.config.tie_word_embeddings:
            # Reuse the input-embedding tensor as the output head weight.
            self.lm_head.weight = self.transformer.wte.weight
        self.logits_processor = LogitsProcessor(config.vocab_size)
        self.make_empty_intermediate_tensors = (
            self.transformer.make_empty_intermediate_tensors)

    def compute_logits(
        self,
        hidden_states: torch.Tensor,
        sampling_metadata: SamplingMetadata,
    ) -> Optional[torch.Tensor]:
        """Turn final hidden states into vocabulary logits via ``lm_head``.

        May return ``None``, as annotated, depending on ``LogitsProcessor``.
        """
        logits = self.logits_processor(self.lm_head, hidden_states,
                                       sampling_metadata)
        return logits

    def load_weights(self, weights: Iterable[tuple[str,
                                                   torch.Tensor]]) -> set[str]:
        """Load checkpoint tensors into this model's parameters.

        The checkpoint's separate MLP projections ``w2`` and ``w1`` are
        renamed to the fused ``gate_up_proj`` parameter and loaded as shards
        0 and 1. Other tensors are loaded by name with the parameter's own
        ``weight_loader``, or ``default_weight_loader`` if it has none.

        Args:
            weights: Iterable of ``(checkpoint_name, tensor)`` pairs.

        Returns:
            The set of (renamed) parameter names that were loaded.
        """
        stacked_params_mapping = [
            # (param_name, shard_name, shard_id)
            ("gate_up_proj", "w2", 0),
            ("gate_up_proj", "w1", 1),
        ]
        params_dict = dict(self.named_parameters())
        loaded_params: set[str] = set()
        for name, loaded_weight in weights:
            if "rotary_emb.inv_freq" in name:
                continue

            # Rename the first matching stacked shard; a single linear path
            # replaces the old inner-loop ``continue`` that only worked by
            # falling through to a ``for ... else`` duplicating these checks.
            shard_id: Optional[int] = None
            for param_name, weight_name, stacked_shard_id in (
                    stacked_params_mapping):
                if weight_name in name:
                    name = name.replace(weight_name, param_name)
                    shard_id = stacked_shard_id
                    break

            # Skip loading extra bias for GPTQ models.
            if name.endswith(".bias") and name not in params_dict:
                continue
            # Skip layers on other devices.
            if is_pp_missing_parameter(name, self):
                continue

            param = params_dict[name]
            if shard_id is None:
                weight_loader = getattr(param, "weight_loader",
                                        default_weight_loader)
                weight_loader(param, loaded_weight)
            else:
                # Stacked params must provide a shard-aware weight_loader.
                param.weight_loader(param, loaded_weight, shard_id)
            loaded_params.add(name)
        return loaded_params


class QWenLMHeadModel(QWenBaseModel, SupportsPP, SupportsLoRA):
    """Text-only QWen causal language model.

    Refuses to build from a config that has a ``visual`` attribute. The
    error tells the user how to select the vision-language architecture.
    """

    # The checkpoint's w2/w1 projections are packed into ``gate_up_proj``.
    packed_modules_mapping = {
        "c_attn": ["c_attn"],
        "gate_up_proj": ["w2", "w1"],
    }

    def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
        hf_config = vllm_config.model_config.hf_config
        if not hasattr(hf_config, "visual"):
            super().__init__(vllm_config=vllm_config, prefix=prefix)
            return

        override_json = json.dumps(
            {"architectures": ["QwenVLForConditionalGeneration"]})
        raise RuntimeError(
            "The configuration of this model indicates that it supports "
            "vision inputs, but you instantiated the text-only version "
            "of this model. Please use the vision model by setting "
            f"`--hf-overrides '{override_json}'`")

    def forward(
        self,
        input_ids: torch.Tensor,
        positions: torch.Tensor,
        intermediate_tensors: Optional[IntermediateTensors] = None,
        inputs_embeds: Optional[torch.Tensor] = None,
    ) -> Union[torch.Tensor, IntermediateTensors]:
        """Delegate to the decoder stack and return its output unchanged."""
        return self.transformer(input_ids, positions, intermediate_tensors,
                                inputs_embeds)