# SPDX-License-Identifier: Apache-2.0

# Adapted from
# https://huggingface.co/Qwen/Qwen-7B/blob/main/modeling_qwen.py
# Copyright (c) Alibaba Cloud.
# LICENSE: https://huggingface.co/Qwen/Qwen-7B/blob/main/LICENSE
"""Inference-only QWen model compatible with HuggingFace weights."""

from typing import Any, Dict, Iterable, List, Optional, Set, Tuple, Union

import torch
from torch import nn
from transformers import PretrainedConfig

from vllm.attention import Attention, AttentionMetadata
from vllm.compilation.decorators import support_torch_compile
from vllm.config import CacheConfig, VllmConfig
from vllm.distributed import get_pp_group, get_tensor_model_parallel_world_size
from vllm.model_executor.layers.activation import SiluAndMul
from vllm.model_executor.layers.layernorm import RMSNorm
from vllm.model_executor.layers.linear import (MergedColumnParallelLinear,
                                               QKVParallelLinear,
                                               RowParallelLinear)
from vllm.model_executor.layers.logits_processor import LogitsProcessor
from vllm.model_executor.layers.quantization import QuantizationConfig
from vllm.model_executor.layers.rotary_embedding import get_rope
from vllm.model_executor.layers.sampler import SamplerOutput, get_sampler
from vllm.model_executor.layers.vocab_parallel_embedding import (
    ParallelLMHead, VocabParallelEmbedding)
from vllm.model_executor.model_loader.weight_utils import default_weight_loader
from vllm.model_executor.sampling_metadata import SamplingMetadata
from vllm.sequence import IntermediateTensors

from .interfaces import SupportsLoRA, SupportsPP
from .utils import (is_pp_missing_parameter,
                    make_empty_intermediate_tensors_factory, make_layers,
                    maybe_prefix)


class QWenMLP(nn.Module):
    """MLP for the language component of the Qwen model, which contains a
    MergedColumnParallelLinear merging 2 outputs via silu activation.

    Args:
        hidden_size: Width of the input and output of the MLP.
        intermediate_size: Width of each of the two merged projections.
        hidden_act: Activation name; only ``"silu"`` is accepted.
        quant_config: Optional quantization config forwarded to both
            linear layers.

    Raises:
        ValueError: If ``hidden_act`` is anything other than ``"silu"``.
    """

    def __init__(
        self,
        hidden_size: int,
        intermediate_size: int,
        hidden_act: str = "silu",
        quant_config: Optional[QuantizationConfig] = None,
    ):
        super().__init__()
        # Validate before building any layer so an unsupported activation
        # fails fast instead of after the projection weights are allocated.
        if hidden_act != "silu":
            raise ValueError(f"Unsupported activation: {hidden_act}. "
                             "Only silu is supported for now.")
        # One merged projection producing two outputs of width
        # `intermediate_size` each.
        self.gate_up_proj = MergedColumnParallelLinear(
            hidden_size, [intermediate_size] * 2,
            bias=False,
            quant_config=quant_config)
        # Projection from `intermediate_size` back to `hidden_size`.
        self.c_proj = RowParallelLinear(intermediate_size,
                                        hidden_size,
                                        bias=False,
                                        quant_config=quant_config)
        self.act_fn = SiluAndMul()

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Apply merged projection, activation, then the output projection.

        Both linear layers return a pair; only the first element is used.
        """
        gate_up, _ = self.gate_up_proj(x)
        x = self.act_fn(gate_up)
        x, _ = self.c_proj(x)
        return x


class QWenAttention(nn.Module):
    """Self-attention layer: fused QKV projection, rotary embedding,
    attention, and an output projection.

    The attention heads are divided evenly across the tensor-parallel
    world size; ``num_heads`` must be divisible by it.

    Args:
        hidden_size: Model hidden width.
        num_heads: Total number of attention heads (across all ranks).
        max_position_embeddings: Passed to ``get_rope`` as ``max_position``.
        rope_theta: Passed to ``get_rope`` as ``base``.
        rope_scaling: Optional rope scaling dict forwarded to ``get_rope``.
        cache_config: Optional cache config forwarded to ``Attention``.
        quant_config: Optional quantization config forwarded to the linear
            layers and to ``Attention``.
        prefix: Module-name prefix; ``Attention`` receives
            ``f"{prefix}.attn"``.
    """

    def __init__(
        self,
        hidden_size: int,
        num_heads: int,
        max_position_embeddings: int,
        rope_theta: float = 10000,
        rope_scaling: Optional[Dict[str, Any]] = None,
        cache_config: Optional[CacheConfig] = None,
        quant_config: Optional[QuantizationConfig] = None,
        prefix: str = "",
    ):
        super().__init__()
        self.hidden_size = hidden_size
        tensor_model_parallel_world_size = get_tensor_model_parallel_world_size(
        )
        self.total_num_heads = num_heads
        assert self.total_num_heads % tensor_model_parallel_world_size == 0
        # Heads handled by this tensor-parallel rank.
        self.num_heads = (self.total_num_heads //
                          tensor_model_parallel_world_size)
        self.head_dim = hidden_size // self.total_num_heads
        # Fused query/key/value projection (with bias).
        self.c_attn = QKVParallelLinear(
            hidden_size,
            self.head_dim,
            self.total_num_heads,
            bias=True,
            quant_config=quant_config,
        )
        # Output projection back to `hidden_size` (no bias).
        self.c_proj = RowParallelLinear(
            self.total_num_heads * self.head_dim,
            hidden_size,
            bias=False,
            quant_config=quant_config,
        )
        # Standard 1/sqrt(head_dim) attention scale.
        self.scaling = self.head_dim**-0.5

        # Rotary embedding spans the full head dimension.
        self.rotary_emb = get_rope(
            self.head_dim,
            rotary_dim=self.head_dim,
            max_position=max_position_embeddings,
            base=rope_theta,
            rope_scaling=rope_scaling,
        )
        self.attn = Attention(self.num_heads,
                              self.head_dim,
                              self.scaling,
                              cache_config=cache_config,
                              quant_config=quant_config,
                              prefix=f"{prefix}.attn")

    def forward(
        self,
        positions: torch.Tensor,
        hidden_states: torch.Tensor,
        kv_cache: torch.Tensor,
        attn_metadata: AttentionMetadata,
    ) -> torch.Tensor:
        """Run attention over ``hidden_states`` and return the projection.

        The fused QKV output is split into three equal chunks along the
        last dimension; rotary embedding is applied to q and k only.
        """
        qkv, _ = self.c_attn(hidden_states)
        q, k, v = qkv.chunk(chunks=3, dim=-1)
        q, k = self.rotary_emb(positions, q, k)
        attn_output = self.attn(q, k, v, kv_cache, attn_metadata)
        output, _ = self.c_proj(attn_output)
        return output


class QWenBlock(nn.Module):
    """One transformer layer: RMSNorm -> attention -> RMSNorm -> MLP.

    The residual stream is carried alongside ``hidden_states`` and handed
    to the norm layers rather than being added explicitly here.

    Args:
        config: HF config; reads ``hidden_size``, ``layer_norm_epsilon``,
            ``num_attention_heads``, ``max_position_embeddings``,
            ``intermediate_size`` and, if present, ``rope_theta`` and
            ``rope_scaling``.
        cache_config: Optional cache config forwarded to the attention.
        quant_config: Optional quantization config forwarded to the
            attention and the MLP.
        prefix: Module-name prefix; attention receives
            ``f"{prefix}.attn"``.
    """

    def __init__(
        self,
        config: PretrainedConfig,
        cache_config: Optional[CacheConfig] = None,
        quant_config: Optional[QuantizationConfig] = None,
        prefix: str = "",
    ):
        super().__init__()
        self.ln_1 = RMSNorm(config.hidden_size, eps=config.layer_norm_epsilon)

        # Rope settings are optional on the config; fall back to defaults.
        rope_theta = getattr(config, "rope_theta", 10000)
        rope_scaling = getattr(config, "rope_scaling", None)
        self.attn = QWenAttention(config.hidden_size,
                                  config.num_attention_heads,
                                  config.max_position_embeddings,
                                  rope_theta=rope_theta,
                                  rope_scaling=rope_scaling,
                                  cache_config=cache_config,
                                  quant_config=quant_config,
                                  prefix=f"{prefix}.attn")

        self.ln_2 = RMSNorm(config.hidden_size, eps=config.layer_norm_epsilon)

        # The MLP is built with half of config.intermediate_size.
        # NOTE(review): presumably the HF Qwen config counts both merged
        # branches in intermediate_size -- confirm against modeling_qwen.py.
        self.mlp = QWenMLP(config.hidden_size,
                           config.intermediate_size // 2,
                           quant_config=quant_config)

    def forward(
        self,
        positions: torch.Tensor,
        hidden_states: torch.Tensor,
        kv_cache: torch.Tensor,
        attn_metadata: AttentionMetadata,
        residual: Optional[torch.Tensor],
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """Run the layer and return ``(hidden_states, residual)``.

        ``residual`` is ``None`` for the first layer; in that case the
        incoming ``hidden_states`` becomes the residual.
        """
        # Self Attention
        if residual is None:
            residual = hidden_states
            hidden_states = self.ln_1(hidden_states)
        else:
            # Two-argument norm call returns the normed states together
            # with the updated residual.
            hidden_states, residual = self.ln_1(hidden_states, residual)
        hidden_states = self.attn(
            positions=positions,
            hidden_states=hidden_states,
            kv_cache=kv_cache,
            attn_metadata=attn_metadata,
        )

        # Fully Connected
        hidden_states, residual = self.ln_2(hidden_states, residual)
        hidden_states = self.mlp(hidden_states)
        return hidden_states, residual


194
@support_torch_compile
class QWenModel(nn.Module):
    """QWen transformer stack: token embedding, decoder layers, final norm.

    Supports pipeline parallelism: only the layers in
    ``[start_layer, end_layer)`` are run on this rank, and non-final ranks
    return ``IntermediateTensors`` instead of the final hidden states.

    Args:
        vllm_config: Supplies ``model_config.hf_config``, ``cache_config``
            and ``quant_config``.
        prefix: Module-name prefix; layers are created under
            ``f"{prefix}.h"``.
    """

    def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
        super().__init__()

        config = vllm_config.model_config.hf_config
        cache_config = vllm_config.cache_config
        quant_config = vllm_config.quant_config

        self.config = config
        self.vocab_size = config.vocab_size

        self.wte = VocabParallelEmbedding(
            config.vocab_size,
            config.hidden_size,
        )
        # make_layers returns this rank's layer range plus the layer list.
        self.start_layer, self.end_layer, self.h = make_layers(
            config.num_hidden_layers,
            lambda prefix: QWenBlock(
                config, cache_config, quant_config, prefix=prefix),
            prefix=f"{prefix}.h")
        self.ln_f = RMSNorm(config.hidden_size, eps=config.layer_norm_epsilon)
        # Factory for placeholder tensors matching the keys that forward()
        # exchanges between pipeline ranks.
        self.make_empty_intermediate_tensors = (
            make_empty_intermediate_tensors_factory(
                ["hidden_states", "residual"], config.hidden_size))

    def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor:
        """Look up token embeddings for ``input_ids``."""
        return self.wte(input_ids)

    def forward(
        self,
        input_ids: torch.Tensor,
        positions: torch.Tensor,
        kv_caches: List[torch.Tensor],
        attn_metadata: AttentionMetadata,
        intermediate_tensors: Optional[IntermediateTensors],
        inputs_embeds: Optional[torch.Tensor] = None,
    ) -> Union[torch.Tensor, IntermediateTensors]:
        """Run this rank's slice of the layer stack.

        On the first pipeline rank the input is ``inputs_embeds`` when
        given, otherwise the embedding of ``input_ids``. On later ranks
        the input comes from ``intermediate_tensors``, which must not be
        ``None``.

        Returns:
            The final-normed hidden states on the last pipeline rank;
            otherwise ``IntermediateTensors`` holding ``hidden_states``
            and ``residual`` for the next rank.
        """
        if get_pp_group().is_first_rank:
            if inputs_embeds is not None:
                hidden_states = inputs_embeds
            else:
                hidden_states = self.get_input_embeddings(input_ids)
            residual = None
        else:
            assert intermediate_tensors is not None
            hidden_states = intermediate_tensors["hidden_states"]
            residual = intermediate_tensors["residual"]

        for i in range(self.start_layer, self.end_layer):
            layer = self.h[i]
            hidden_states, residual = layer(
                positions,
                hidden_states,
                # kv_caches is indexed relative to this rank's first layer.
                kv_caches[i - self.start_layer],
                attn_metadata,
                residual,
            )
        if not get_pp_group().is_last_rank:
            return IntermediateTensors({
                "hidden_states": hidden_states,
                "residual": residual
            })
        hidden_states, _ = self.ln_f(hidden_states, residual)
        return hidden_states


262
class QWenBaseModel(nn.Module):
    """Shared wrapper around a QWen transformer: LM head, logits
    processing, sampling and checkpoint weight loading.

    Args:
        vllm_config: Supplies ``model_config.hf_config``,
            ``model_config.multimodal_config`` and ``quant_config``.
        prefix: Module-name prefix; the transformer is created under
            ``maybe_prefix(prefix, "transformer")``.
        transformer_type: Class used to build ``self.transformer``;
            defaults to ``QWenModel``.
    """

    def __init__(
        self,
        *,
        vllm_config: VllmConfig,
        prefix: str = "",
        transformer_type: type[QWenModel] = QWenModel,
    ) -> None:
        super().__init__()
        config = vllm_config.model_config.hf_config
        quant_config = vllm_config.quant_config
        multimodal_config = vllm_config.model_config.multimodal_config
        self.config = config
        self.multimodal_config = multimodal_config
        self.quant_config = quant_config
        self.transformer = transformer_type(vllm_config=vllm_config,
                                            prefix=maybe_prefix(
                                                prefix, "transformer"))
        self.lm_head = ParallelLMHead(config.vocab_size,
                                      config.hidden_size,
                                      quant_config=quant_config)
        # With tied embeddings the LM head shares the token-embedding weight.
        if self.config.tie_word_embeddings:
            self.lm_head.weight = self.transformer.wte.weight
        self.logits_processor = LogitsProcessor(config.vocab_size)
        self.sampler = get_sampler()
        # Re-export the transformer's factory on the wrapper.
        self.make_empty_intermediate_tensors = (
            self.transformer.make_empty_intermediate_tensors)

    def compute_logits(
        self,
        hidden_states: torch.Tensor,
        sampling_metadata: SamplingMetadata,
    ) -> Optional[torch.Tensor]:
        """Return the logits produced by the logits processor and LM head."""
        logits = self.logits_processor(self.lm_head, hidden_states,
                                       sampling_metadata)
        return logits

    def sample(
        self,
        logits: torch.Tensor,
        sampling_metadata: SamplingMetadata,
    ) -> Optional[SamplerOutput]:
        """Run the sampler on ``logits`` and return its output."""
        next_tokens = self.sampler(logits, sampling_metadata)
        return next_tokens

    def load_weights(self, weights: Iterable[Tuple[str,
                                                   torch.Tensor]]) -> Set[str]:
        """Load checkpoint tensors into this module's parameters.

        Checkpoint tensors whose name contains ``w2`` / ``w1`` are loaded
        as shards 0 / 1 of the merged ``gate_up_proj`` parameter. All other
        tensors are loaded by name. Skipped entirely: rotary ``inv_freq``
        buffers, ``.bias`` entries with no matching parameter, and
        parameters reported missing by ``is_pp_missing_parameter``.

        Args:
            weights: Iterable of ``(name, tensor)`` pairs.

        Returns:
            The set of (possibly remapped) parameter names that were loaded.
        """
        stacked_params_mapping = [
            # (param_name, shard_name, shard_id)
            ("gate_up_proj", "w2", 0),
            ("gate_up_proj", "w1", 1),
        ]
        params_dict = dict(self.named_parameters())
        loaded_params: Set[str] = set()
        for name, loaded_weight in weights:
            if "rotary_emb.inv_freq" in name:
                continue
            for (param_name, weight_name, shard_id) in stacked_params_mapping:
                if weight_name not in name:
                    continue
                name = name.replace(weight_name, param_name)
                # NOTE: the `continue`s below only advance this inner loop
                # with the already-remapped name; the for-else branch then
                # re-applies the same skip checks and moves on to the next
                # checkpoint tensor.
                # Skip loading extra bias for GPTQ models.
                if name.endswith(".bias") and name not in params_dict:
                    continue
                # Skip layers on other devices.
                if is_pp_missing_parameter(name, self):
                    continue
                param = params_dict[name]
                weight_loader = param.weight_loader
                weight_loader(param, loaded_weight, shard_id)
                break
            else:
                # Reached when no stacked mapping loaded this tensor.
                # Skip loading extra bias for GPTQ models.
                if name.endswith(".bias") and name not in params_dict:
                    continue
                # Skip layers on other devices.
                if is_pp_missing_parameter(name, self):
                    continue
                param = params_dict[name]
                # Parameters without a custom loader use the default one.
                weight_loader = getattr(param, "weight_loader",
                                        default_weight_loader)
                weight_loader(param, loaded_weight)
            loaded_params.add(name)
        return loaded_params
347
348


349
class QWenLMHeadModel(QWenBaseModel, SupportsPP, SupportsLoRA):
    """Text-only QWen causal language model.

    Refuses to instantiate when the HF config carries a ``visual``
    attribute, directing the user to the vision architecture instead.
    """

    # Packed parameter name -> names of the pieces it is built from.
    # The w2, w1 order matches the shard ids (0, 1) used in load_weights.
    packed_modules_mapping = {
        "c_attn": ["c_attn"],
        "gate_up_proj": [
            "w2",
            "w1",
        ],
    }
    # LoRA specific attributes
    supported_lora_modules = [
        "c_attn",
        "gate_up_proj",
        "c_proj",
    ]

    # No embedding modules are declared for this model.
    embedding_modules = {}
    embedding_padding_modules = []

    def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
        """Build the text-only model.

        Raises:
            RuntimeError: If the HF config has a ``visual`` attribute,
                i.e. the checkpoint is configured for vision inputs.
        """
        config = vllm_config.model_config.hf_config
        if hasattr(config, "visual"):
            hf_overrides = {
                "architectures": ["QwenVLForConditionalGeneration"]
            }
            raise RuntimeError(
                "The configuration of this model indicates that it supports "
                "vision inputs, but you instantiated the text-only version "
                "of this model. Please use the vision model by setting "
                f"`--hf-overrides {hf_overrides!r}`")

        super().__init__(vllm_config=vllm_config, prefix=prefix)

    def forward(
        self,
        input_ids: torch.Tensor,
        positions: torch.Tensor,
        kv_caches: List[torch.Tensor],
        attn_metadata: AttentionMetadata,
        intermediate_tensors: Optional[IntermediateTensors] = None,
        inputs_embeds: Optional[torch.Tensor] = None,
    ) -> Union[torch.Tensor, IntermediateTensors]:
        """Delegate to ``self.transformer`` and return its result as-is."""
        hidden_states = self.transformer(input_ids, positions, kv_caches,
                                         attn_metadata, intermediate_tensors,
                                         inputs_embeds)
        return hidden_states