"docs/backends/vllm/gpt-oss.md" did not exist on "77e66ae5c1ceb97ca1d6afdab29866dab63ed394"
qwen.py 12.7 KB
Newer Older
1
# SPDX-License-Identifier: Apache-2.0
2
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
3

Qing's avatar
Qing committed
4
5
6
7
# Adapted from
# https://huggingface.co/Qwen/Qwen-7B/blob/main/modeling_qwen.py
# Copyright (c) Alibaba Cloud.
# LICENSE: https://huggingface.co/Qwen/Qwen-7B/blob/main/LICENSE
Woosuk Kwon's avatar
Woosuk Kwon committed
8
"""Inference-only QWen model compatible with HuggingFace weights."""
9

10
import json
11
from collections.abc import Iterable
12
from itertools import islice
13
from typing import Any
14

15
16
import torch
from torch import nn
17
from transformers import PretrainedConfig
Qing's avatar
Qing committed
18

19
from vllm.attention.layer import Attention
20
from vllm.compilation.decorators import support_torch_compile
21
from vllm.config import CacheConfig, VllmConfig
22
from vllm.distributed import get_pp_group, get_tensor_model_parallel_world_size
23
from vllm.model_executor.layers.activation import SiluAndMul
24
from vllm.model_executor.layers.layernorm import RMSNorm
25
26
27
28
29
from vllm.model_executor.layers.linear import (
    MergedColumnParallelLinear,
    QKVParallelLinear,
    RowParallelLinear,
)
30
from vllm.model_executor.layers.logits_processor import LogitsProcessor
31
from vllm.model_executor.layers.quantization import QuantizationConfig
32
from vllm.model_executor.layers.rotary_embedding import get_rope
33
from vllm.model_executor.layers.vocab_parallel_embedding import (
34
35
36
    ParallelLMHead,
    VocabParallelEmbedding,
)
37
from vllm.model_executor.model_loader.weight_utils import default_weight_loader
38
from vllm.sequence import IntermediateTensors
39

40
from .interfaces import SupportsLoRA, SupportsPP
41
42
43
44
45
46
from .utils import (
    is_pp_missing_parameter,
    make_empty_intermediate_tensors_factory,
    make_layers,
    maybe_prefix,
)
47

48
49

class QWenMLP(nn.Module):
    """MLP for the language component of the Qwen model, which contains a
    MergedColumnParallelLinear merging 2 outputs via silu activation.

    Args:
        hidden_size: Input and output feature size of the MLP.
        intermediate_size: Output size of each of the two fused projections
            held by ``gate_up_proj`` (and the input size of ``c_proj``).
        hidden_act: Activation name; only ``"silu"`` is supported.
        quant_config: Optional quantization config forwarded to both linears.
        prefix: Module path of this MLP (e.g. ``"transformer.h.0.mlp"``).
            The linear layers are named ``{prefix}.gate_up_proj`` and
            ``{prefix}.c_proj`` so quantization configs can address them.

    Raises:
        ValueError: If ``hidden_act`` is not ``"silu"``.
    """

    def __init__(
        self,
        hidden_size: int,
        intermediate_size: int,
        hidden_act: str = "silu",
        quant_config: QuantizationConfig | None = None,
        prefix: str = "",
    ):
        super().__init__()
        # Reject unsupported activations before allocating projection weights.
        if hidden_act != "silu":
            raise ValueError(
                f"Unsupported activation: {hidden_act}. Only silu is supported for now."
            )
        # Two projections of width `intermediate_size` fused into one matmul.
        self.gate_up_proj = MergedColumnParallelLinear(
            hidden_size,
            [intermediate_size] * 2,
            bias=False,
            quant_config=quant_config,
            prefix=f"{prefix}.gate_up_proj",
        )
        self.c_proj = RowParallelLinear(
            intermediate_size,
            hidden_size,
            bias=False,
            quant_config=quant_config,
            prefix=f"{prefix}.c_proj",
        )
        self.act_fn = SiluAndMul()

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Apply fused projection, SiLU-and-multiply, and output projection."""
        gate_up, _ = self.gate_up_proj(x)
        x = self.act_fn(gate_up)
        x, _ = self.c_proj(x)
        return x


class QWenAttention(nn.Module):
    """Multi-head self-attention with rotary position embeddings.

    A fused ``c_attn`` projection (with bias) produces query, key and value;
    ``c_proj`` (no bias) projects the attention output back to
    ``hidden_size``.

    Args:
        hidden_size: Model hidden size; ``head_dim = hidden_size // num_heads``.
        num_heads: Total number of attention heads across all tensor-parallel
            ranks; must be divisible by the tensor-parallel world size.
        max_position_embeddings: Maximum position passed to ``get_rope``.
        rope_parameters: Optional RoPE configuration forwarded to ``get_rope``.
        cache_config: Cache configuration forwarded to ``Attention``.
        quant_config: Optional quantization config for the linears and the
            attention layer.
        prefix: Module path of this layer, used to name its sub-layers.
    """

    def __init__(
        self,
        hidden_size: int,
        num_heads: int,
        max_position_embeddings: int,
        rope_parameters: dict[str, Any] | None = None,
        cache_config: CacheConfig | None = None,
        quant_config: QuantizationConfig | None = None,
        prefix: str = "",
    ):
        super().__init__()
        self.hidden_size = hidden_size
        tp_size = get_tensor_model_parallel_world_size()
        self.total_num_heads = num_heads
        assert self.total_num_heads % tp_size == 0, (
            f"num_heads ({num_heads}) must be divisible by the tensor-parallel "
            f"world size ({tp_size})"
        )
        # Heads handled by this tensor-parallel rank.
        self.num_heads = self.total_num_heads // tp_size
        self.head_dim = hidden_size // self.total_num_heads
        self.c_attn = QKVParallelLinear(
            hidden_size,
            self.head_dim,
            self.total_num_heads,
            bias=True,
            quant_config=quant_config,
            prefix=f"{prefix}.c_attn",
        )
        self.c_proj = RowParallelLinear(
            self.total_num_heads * self.head_dim,
            hidden_size,
            bias=False,
            quant_config=quant_config,
            prefix=f"{prefix}.c_proj",
        )
        # Standard 1/sqrt(head_dim) softmax scaling.
        self.scaling = self.head_dim**-0.5

        self.rotary_emb = get_rope(
            self.head_dim,
            max_position=max_position_embeddings,
            rope_parameters=rope_parameters,
        )
        self.attn = Attention(
            self.num_heads,
            self.head_dim,
            self.scaling,
            cache_config=cache_config,
            quant_config=quant_config,
            prefix=f"{prefix}.attn",
        )

    def forward(
        self,
        positions: torch.Tensor,
        hidden_states: torch.Tensor,
    ) -> torch.Tensor:
        """Run self-attention over ``hidden_states`` at ``positions``.

        Returns:
            The attention output after the ``c_proj`` projection.
        """
        qkv, _ = self.c_attn(hidden_states)
        # c_attn is built without a separate KV-head count, so q, k and v are
        # taken as three equal-width slices of the fused output.
        q, k, v = qkv.chunk(chunks=3, dim=-1)
        q, k = self.rotary_emb(positions, q, k)
        attn_output = self.attn(q, k, v)
        output, _ = self.c_proj(attn_output)
        return output


class QWenBlock(nn.Module):
    """One Qwen decoder layer: pre-norm attention followed by a pre-norm MLP.

    Args:
        config: HF model config; reads ``hidden_size``, ``layer_norm_epsilon``,
            ``num_attention_heads``, ``max_position_embeddings``,
            ``rope_parameters`` and ``intermediate_size``.
        cache_config: Cache configuration forwarded to the attention layer.
        quant_config: Optional quantization config for the sub-layers.
        prefix: Module path of this layer (e.g. ``"transformer.h.0"``).
    """

    def __init__(
        self,
        config: PretrainedConfig,
        cache_config: CacheConfig | None = None,
        quant_config: QuantizationConfig | None = None,
        prefix: str = "",
    ):
        super().__init__()
        self.ln_1 = RMSNorm(config.hidden_size, eps=config.layer_norm_epsilon)

        self.attn = QWenAttention(
            config.hidden_size,
            config.num_attention_heads,
            config.max_position_embeddings,
            rope_parameters=config.rope_parameters,
            cache_config=cache_config,
            quant_config=quant_config,
            prefix=f"{prefix}.attn",
        )

        self.ln_2 = RMSNorm(config.hidden_size, eps=config.layer_norm_epsilon)

        # The config's intermediate_size is halved to get the width of each of
        # the two fused MLP projections.
        self.mlp = QWenMLP(
            config.hidden_size,
            config.intermediate_size // 2,
            quant_config=quant_config,
            prefix=f"{prefix}.mlp",
        )

    def forward(
        self,
        positions: torch.Tensor,
        hidden_states: torch.Tensor,
        residual: torch.Tensor | None,
    ) -> tuple[torch.Tensor, torch.Tensor]:
        """Apply the layer and return ``(hidden_states, residual)``.

        ``residual`` is ``None`` for the first layer. Later layers receive the
        running residual from the previous layer. When called with a residual,
        the norm layers return ``(normalized, updated_residual)``; the residual
        addition is assumed to happen inside ``RMSNorm``.
        """
        # Self Attention
        if residual is None:
            residual = hidden_states
            hidden_states = self.ln_1(hidden_states)
        else:
            hidden_states, residual = self.ln_1(hidden_states, residual)
        hidden_states = self.attn(
            positions=positions,
            hidden_states=hidden_states,
        )

        # Fully Connected
        hidden_states, residual = self.ln_2(hidden_states, residual)
        hidden_states = self.mlp(hidden_states)
        return hidden_states, residual


192
@support_torch_compile
class QWenModel(nn.Module):
    """Qwen decoder stack: token embedding, decoder layers, final norm.

    With pipeline parallelism, only the layers in
    ``[start_layer, end_layer)`` run on this rank. Non-last ranks return
    ``IntermediateTensors`` that carry ``hidden_states`` and ``residual``.
    """

    def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
        super().__init__()

        hf_config = vllm_config.model_config.hf_config
        cache_config = vllm_config.cache_config
        quant_config = vllm_config.quant_config

        self.config = hf_config
        self.vocab_size = hf_config.vocab_size
        self.wte = VocabParallelEmbedding(hf_config.vocab_size, hf_config.hidden_size)

        # make_layers invokes the builder with a `prefix=` keyword argument,
        # so the parameter must keep that exact name.
        def build_block(prefix):
            return QWenBlock(hf_config, cache_config, quant_config, prefix=prefix)

        self.start_layer, self.end_layer, self.h = make_layers(
            hf_config.num_hidden_layers,
            build_block,
            prefix=f"{prefix}.h",
        )
        self.ln_f = RMSNorm(hf_config.hidden_size, eps=hf_config.layer_norm_epsilon)
        self.make_empty_intermediate_tensors = make_empty_intermediate_tensors_factory(
            ["hidden_states", "residual"], hf_config.hidden_size
        )

    def embed_input_ids(self, input_ids: torch.Tensor) -> torch.Tensor:
        """Look up token embeddings for ``input_ids``."""
        return self.wte(input_ids)

    def forward(
        self,
        input_ids: torch.Tensor,
        positions: torch.Tensor,
        intermediate_tensors: IntermediateTensors | None,
        inputs_embeds: torch.Tensor | None = None,
    ) -> torch.Tensor | IntermediateTensors:
        """Run this rank's share of the decoder stack.

        The first rank starts from ``inputs_embeds`` (or embeds
        ``input_ids``). Other ranks resume from ``intermediate_tensors``.
        The last rank applies ``ln_f`` and returns a tensor. Every other rank
        returns the hand-off ``IntermediateTensors``.
        """
        pp_group = get_pp_group()
        if pp_group.is_first_rank:
            residual = None
            hidden_states = (
                self.embed_input_ids(input_ids)
                if inputs_embeds is None
                else inputs_embeds
            )
        else:
            assert intermediate_tensors is not None
            hidden_states = intermediate_tensors["hidden_states"]
            residual = intermediate_tensors["residual"]

        for block in islice(self.h, self.start_layer, self.end_layer):
            hidden_states, residual = block(positions, hidden_states, residual)

        if pp_group.is_last_rank:
            hidden_states, _ = self.ln_f(hidden_states, residual)
            return hidden_states
        return IntermediateTensors(
            {"hidden_states": hidden_states, "residual": residual}
        )


253
class QWenBaseModel(nn.Module):
    """Shared Qwen language-model wrapper: a transformer plus an LM head.

    ``transformer_type`` lets subclasses substitute the transformer
    implementation. It must expose ``wte`` and
    ``make_empty_intermediate_tensors`` as ``QWenModel`` does.
    """

    def __init__(
        self,
        *,
        vllm_config: VllmConfig,
        prefix: str = "",
        transformer_type: type[QWenModel] = QWenModel,
    ) -> None:
        super().__init__()
        config = vllm_config.model_config.hf_config
        quant_config = vllm_config.quant_config
        multimodal_config = vllm_config.model_config.multimodal_config
        self.config = config
        self.multimodal_config = multimodal_config
        self.quant_config = quant_config

        # Registered before lm_head; this ordering matters for tied weights
        # (see load_weights).
        self.transformer = transformer_type(
            vllm_config=vllm_config, prefix=maybe_prefix(prefix, "transformer")
        )
        self.lm_head = ParallelLMHead(
            config.vocab_size,
            config.hidden_size,
            quant_config=quant_config,
            prefix=maybe_prefix(prefix, "lm_head"),
        )
        if self.config.tie_word_embeddings:
            # The output projection reuses the token-embedding parameter.
            self.lm_head.weight = self.transformer.wte.weight
        self.logits_processor = LogitsProcessor(config.vocab_size)
        self.make_empty_intermediate_tensors = (
            self.transformer.make_empty_intermediate_tensors
        )

    def embed_input_ids(self, input_ids: torch.Tensor) -> torch.Tensor:
        """Look up token embeddings via the transformer's ``wte`` layer."""
        return self.transformer.wte(input_ids)

    def compute_logits(
        self,
        hidden_states: torch.Tensor,
    ) -> torch.Tensor | None:
        """Project final hidden states to vocabulary logits via ``lm_head``."""
        logits = self.logits_processor(self.lm_head, hidden_states)
        return logits

    def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]:
        """Load checkpoint tensors into this model's parameters.

        Checkpoint names containing ``w2`` / ``w1`` are renamed to
        ``gate_up_proj`` and written into shard 0 / shard 1 of that fused
        parameter. Every other tensor is loaded into the parameter with the
        same name.

        The following tensors are skipped:

        * ``rotary_emb.inv_freq`` buffers.
        * ``lm_head.weight`` when embeddings are tied.
        * Extra biases absent from the model (GPTQ).
        * Parameters owned by other pipeline stages.

        Args:
            weights: Iterable of ``(name, tensor)`` pairs from the checkpoint.

        Returns:
            The set of (renamed) parameter names that received data.
        """
        stacked_params_mapping = [
            # (param_name, shard_name, shard_id)
            ("gate_up_proj", "w2", 0),
            ("gate_up_proj", "w1", 1),
        ]
        params_dict = dict(self.named_parameters())
        loaded_params: set[str] = set()
        for name, loaded_weight in weights:
            if "rotary_emb.inv_freq" in name:
                continue
            if self.config.tie_word_embeddings and "lm_head.weight" in name:
                # lm_head.weight is the same Parameter as transformer.wte.weight.
                # named_parameters() lists shared parameters only once (under
                # the first-registered name, the transformer's), so this key is
                # absent from params_dict. The tensor is loaded via the
                # embedding name instead.
                continue

            # Resolve a fused (stacked) destination first; None means the
            # tensor maps one-to-one onto a parameter of the same name.
            shard_id: int | None = None
            for param_name, weight_name, stacked_shard_id in stacked_params_mapping:
                if weight_name in name:
                    name = name.replace(weight_name, param_name)
                    shard_id = stacked_shard_id
                    break

            # Skip loading extra bias for GPTQ models.
            if name.endswith(".bias") and name not in params_dict:
                continue
            # Skip layers on other devices.
            if is_pp_missing_parameter(name, self):
                continue

            param = params_dict[name]
            if shard_id is None:
                weight_loader = getattr(param, "weight_loader", default_weight_loader)
                weight_loader(param, loaded_weight)
            else:
                param.weight_loader(param, loaded_weight, shard_id)
            loaded_params.add(name)
        return loaded_params
331
332


333
class QWenLMHeadModel(QWenBaseModel, SupportsPP, SupportsLoRA):
    """Text-only Qwen causal language model.

    Raises ``RuntimeError`` at construction if the HF config carries a
    ``visual`` section. The message points the user to the vision-language
    architecture instead.
    """

    # Fused module name -> checkpoint sub-module names packed into it,
    # listed in shard order.
    packed_modules_mapping = {
        "c_attn": ["c_attn"],
        "gate_up_proj": ["w2", "w1"],
    }

    def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
        hf_config = vllm_config.model_config.hf_config
        if hasattr(hf_config, "visual"):
            overrides = {"architectures": ["QwenVLForConditionalGeneration"]}
            message = (
                "The configuration of this model indicates that it supports "
                "vision inputs, but you instantiated the text-only version "
                "of this model. Please use the vision model by setting "
                f"`--hf-overrides '{json.dumps(overrides)}'`"
            )
            raise RuntimeError(message)

        super().__init__(vllm_config=vllm_config, prefix=prefix)

    def forward(
        self,
        input_ids: torch.Tensor,
        positions: torch.Tensor,
        intermediate_tensors: IntermediateTensors | None = None,
        inputs_embeds: torch.Tensor | None = None,
    ) -> torch.Tensor | IntermediateTensors:
        """Delegate to the transformer and return its output unchanged."""
        return self.transformer(
            input_ids, positions, intermediate_tensors, inputs_embeds
        )