qwen3.py 12.1 KB
Newer Older
1
# SPDX-License-Identifier: Apache-2.0
2
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24

# Copyright 2024 The Qwen team.
# Copyright 2023 The vLLM team.
# Copyright 2022 EleutherAI and the HuggingFace Inc. team. All rights reserved.
#
# This code is based on EleutherAI's GPT-NeoX library and the GPT-NeoX
# and OPT implementations in this library. It has been modified from its
# original forms to accommodate minor architectural differences compared
# to GPT-NeoX and OPT used by the Meta AI team that trained the model.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Inference-only Qwen3 model compatible with HuggingFace weights."""
25

26
from collections.abc import Iterable
27
from typing import Any
28
29
30
31
32

import torch
from torch import nn
from transformers import Qwen3Config

33
34
from vllm.attention.backends.abstract import AttentionType
from vllm.attention.layer import Attention
35
36
37
38
39
from vllm.compilation.decorators import support_torch_compile
from vllm.config import CacheConfig, VllmConfig
from vllm.distributed import get_pp_group, get_tensor_model_parallel_world_size
from vllm.logger import init_logger
from vllm.model_executor.layers.layernorm import RMSNorm
40
from vllm.model_executor.layers.linear import QKVParallelLinear, RowParallelLinear
41
42
43
44
from vllm.model_executor.layers.logits_processor import LogitsProcessor
from vllm.model_executor.layers.quantization import QuantizationConfig
from vllm.model_executor.layers.rotary_embedding import get_rope
from vllm.model_executor.layers.vocab_parallel_embedding import ParallelLMHead
45
from vllm.sequence import IntermediateTensors
46
from vllm.transformers_utils.config import set_default_rope_theta
47

48
from .interfaces import SupportsEagle3, SupportsLoRA, SupportsPP
49
50
from .qwen2 import Qwen2MLP as Qwen3MLP
from .qwen2 import Qwen2Model
51
from .utils import AutoWeightsLoader, PPMissingLayer, extract_layer_index, maybe_prefix
52
53
54
55
56

# Module-level logger for this file, named after the module.
logger = init_logger(__name__)


class Qwen3Attention(nn.Module):
    """Qwen3 self-attention with per-head RMSNorm on queries and keys (qk-norm).

    Q, K and V are produced by one fused ``QKVParallelLinear`` projection and
    sharded over the tensor-parallel group. Before the rotary embedding is
    applied, every query head and every key head is RMS-normalized over
    ``head_dim``.
    """

    def __init__(
        self,
        hidden_size: int,
        num_heads: int,
        num_kv_heads: int,
        rope_parameters: dict,
        max_position: int = 4096 * 32,
        head_dim: int | None = None,
        rms_norm_eps: float = 1e-06,
        qkv_bias: bool = False,
        cache_config: CacheConfig | None = None,
        quant_config: QuantizationConfig | None = None,
        prefix: str = "",
        attn_type: str = AttentionType.DECODER,
        dual_chunk_attention_config: dict[str, Any] | None = None,
    ) -> None:
        """Build the projections, rotary embedding, attention op and qk-norms.

        Args:
            hidden_size: Model hidden size (input width of ``qkv_proj`` and
                output width of ``o_proj``).
            num_heads: Total number of query heads across all TP ranks.
            num_kv_heads: Total number of key/value heads across all TP ranks.
            rope_parameters: Forwarded unchanged to ``get_rope``.
            max_position: Maximum position passed to ``get_rope``.
            head_dim: Per-head width; falls back to
                ``hidden_size // num_heads`` when falsy.
            rms_norm_eps: Epsilon for the query/key RMSNorms.
            qkv_bias: Whether the fused QKV projection has a bias.
            cache_config: Forwarded to ``Attention``.
            quant_config: Forwarded to the linear layers and ``Attention``.
            prefix: Module-name prefix used to name the sub-layers.
            attn_type: Attention type passed to ``Attention``.
            dual_chunk_attention_config: Optional config forwarded to
                ``get_rope``; when truthy it is also passed to ``Attention``
                together with the layer index parsed from ``prefix``.

        Raises:
            ValueError: If the head counts are incompatible with the tensor
                parallel world size, or ``num_kv_heads`` is not positive.
        """
        super().__init__()
        self.hidden_size = hidden_size
        tp_size = get_tensor_model_parallel_world_size()

        self.total_num_heads = num_heads
        # Query heads are always split evenly across TP ranks. Explicit raises
        # (instead of ``assert``) keep this check active under ``python -O``.
        if self.total_num_heads % tp_size != 0:
            raise ValueError(
                f"Total number of attention heads ({self.total_num_heads}) "
                f"must be divisible by tensor parallel size ({tp_size})."
            )
        self.num_heads = self.total_num_heads // tp_size

        self.total_num_kv_heads = num_kv_heads
        if self.total_num_kv_heads <= 0:
            raise ValueError(
                "Number of KV heads must be positive, "
                f"got {self.total_num_kv_heads}."
            )
        if self.total_num_kv_heads >= tp_size:
            # At least as many KV heads as TP ranks: partition them evenly.
            if self.total_num_kv_heads % tp_size != 0:
                raise ValueError(
                    f"Number of KV heads ({self.total_num_kv_heads}) must be "
                    f"divisible by tensor parallel size ({tp_size})."
                )
        elif tp_size % self.total_num_kv_heads != 0:
            # Fewer KV heads than TP ranks: each head is replicated on several
            # ranks, so the TP size must be a multiple of the KV head count.
            raise ValueError(
                f"Tensor parallel size ({tp_size}) must be divisible by the "
                f"number of KV heads ({self.total_num_kv_heads})."
            )
        # Each rank keeps at least one KV head (a replica when heads < ranks).
        self.num_kv_heads = max(1, self.total_num_kv_heads // tp_size)
        self.head_dim = head_dim or hidden_size // self.total_num_heads
        # Per-rank widths of the Q and K/V slices of the fused projection.
        self.q_size = self.num_heads * self.head_dim
        self.kv_size = self.num_kv_heads * self.head_dim
        self.scaling = self.head_dim**-0.5
        self.dual_chunk_attention_config = dual_chunk_attention_config

        self.qkv_proj = QKVParallelLinear(
            hidden_size,
            self.head_dim,
            self.total_num_heads,
            self.total_num_kv_heads,
            bias=qkv_bias,
            quant_config=quant_config,
            prefix=f"{prefix}.qkv_proj",
        )
        self.o_proj = RowParallelLinear(
            self.total_num_heads * self.head_dim,
            hidden_size,
            bias=False,
            quant_config=quant_config,
            prefix=f"{prefix}.o_proj",
        )

        self.rotary_emb = get_rope(
            self.head_dim,
            rotary_dim=self.head_dim,
            max_position=max_position,
            rope_parameters=rope_parameters,
            dual_chunk_attention_config=dual_chunk_attention_config,
        )
        self.attn = Attention(
            self.num_heads,
            self.head_dim,
            self.scaling,
            num_kv_heads=self.num_kv_heads,
            cache_config=cache_config,
            quant_config=quant_config,
            prefix=f"{prefix}.attn",
            attn_type=attn_type,
            # Dual-chunk attention additionally needs the layer index.
            **{
                "layer_idx": extract_layer_index(prefix),
                "dual_chunk_attention_config": dual_chunk_attention_config,
            }
            if dual_chunk_attention_config
            else {},
        )
        self.q_norm = RMSNorm(self.head_dim, eps=rms_norm_eps)
        self.k_norm = RMSNorm(self.head_dim, eps=rms_norm_eps)

    def forward(
        self,
        positions: torch.Tensor,
        hidden_states: torch.Tensor,
    ) -> torch.Tensor:
        """Run qk-normed attention over ``hidden_states``.

        Args:
            positions: Positions consumed by the rotary embedding.
            hidden_states: Input activations whose last dimension is
                ``hidden_size`` (leading dims not fixed here — TODO confirm
                against callers).

        Returns:
            The output of ``o_proj`` applied to the attention result.
        """
        qkv, _ = self.qkv_proj(hidden_states)
        q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1)
        # qk-norm: view the flat per-rank projections as (..., heads, head_dim)
        # so each head is normalized independently, then flatten back.
        q_by_head = q.view(*q.shape[:-1], q.shape[-1] // self.head_dim, self.head_dim)
        q_by_head = self.q_norm(q_by_head)
        q = q_by_head.view(q.shape)
        k_by_head = k.view(*k.shape[:-1], k.shape[-1] // self.head_dim, self.head_dim)
        k_by_head = self.k_norm(k_by_head)
        k = k_by_head.view(k.shape)
        q, k = self.rotary_emb(positions, q, k)
        attn_output = self.attn(q, k, v)
        output, _ = self.o_proj(attn_output)
        return output


class Qwen3DecoderLayer(nn.Module):
    """A single Qwen3 transformer block.

    Layout: RMSNorm -> qk-normed self-attention -> RMSNorm -> MLP, with the
    residual stream threaded through the two norm calls.
    """

    def __init__(
        self,
        config: Qwen3Config,
        cache_config: CacheConfig | None = None,
        quant_config: QuantizationConfig | None = None,
        prefix: str = "",
    ) -> None:
        super().__init__()
        self.hidden_size = config.hidden_size
        # Called before ``config.rope_parameters`` is read below; presumably it
        # fills in a rope theta of 1e6 when the config has none (TODO confirm).
        set_default_rope_theta(config, default_theta=1000000)

        dual_chunk_attention_config = getattr(
            config, "dual_chunk_attention_config", None
        )
        # Qwen3 is decoder-only and attends causally by default. An HF config
        # with ``is_causal=False`` switches to bidirectional attention, which
        # some embedding models use (e.g. Alibaba-NLP/gte-Qwen3-7B-instruct).
        causal = getattr(config, "is_causal", True)
        attn_type = AttentionType.DECODER if causal else AttentionType.ENCODER_ONLY

        self.self_attn = Qwen3Attention(
            hidden_size=self.hidden_size,
            num_heads=config.num_attention_heads,
            num_kv_heads=config.num_key_value_heads,
            rope_parameters=config.rope_parameters,
            max_position=config.max_position_embeddings,
            head_dim=getattr(config, "head_dim", None),
            rms_norm_eps=config.rms_norm_eps,
            qkv_bias=getattr(config, "attention_bias", False),
            cache_config=cache_config,
            quant_config=quant_config,
            prefix=f"{prefix}.self_attn",
            attn_type=attn_type,
            dual_chunk_attention_config=dual_chunk_attention_config,
        )
        self.mlp = Qwen3MLP(
            hidden_size=self.hidden_size,
            intermediate_size=config.intermediate_size,
            hidden_act=config.hidden_act,
            quant_config=quant_config,
            prefix=f"{prefix}.mlp",
        )
        norm_eps = config.rms_norm_eps
        self.input_layernorm = RMSNorm(config.hidden_size, eps=norm_eps)
        self.post_attention_layernorm = RMSNorm(config.hidden_size, eps=norm_eps)

    def forward(
        self,
        positions: torch.Tensor,
        hidden_states: torch.Tensor,
        residual: torch.Tensor | None,
    ) -> tuple[torch.Tensor, torch.Tensor]:
        """Apply the block and return ``(mlp_output, residual)``.

        When ``residual`` is None the raw input becomes the residual. Otherwise
        the two-argument RMSNorm call returns ``(normed, new_residual)`` —
        presumably a fused residual-add + norm (TODO confirm in RMSNorm).
        """
        if residual is not None:
            hidden_states, residual = self.input_layernorm(hidden_states, residual)
        else:
            residual = hidden_states
            hidden_states = self.input_layernorm(hidden_states)

        attn_out = self.self_attn(positions=positions, hidden_states=hidden_states)

        mlp_in, residual = self.post_attention_layernorm(attn_out, residual)
        return self.mlp(mlp_in), residual


# Maps a decoder layer type name to the class that implements it.
ALL_DECODER_LAYER_TYPES = dict(attention=Qwen3DecoderLayer)


@support_torch_compile(
    dynamic_arg_dims=dict(
        input_ids=0,
        # ``positions`` is (3, seq_len) when mrope is enabled for qwen2-vl and
        # (seq_len,) otherwise, so its last dimension is the dynamic one.
        positions=-1,
        intermediate_tensors=0,
        inputs_embeds=0,
    )
)
class Qwen3Model(Qwen2Model):
    """Qwen3 backbone: the shared Qwen2 model stack built from Qwen3 layers."""

    def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
        super().__init__(
            vllm_config=vllm_config,
            prefix=prefix,
            decoder_layer_type=Qwen3DecoderLayer,
        )
252
253


254
class Qwen3ForCausalLM(nn.Module, SupportsLoRA, SupportsPP, SupportsEagle3):
    """Qwen3 backbone plus a language-modelling head and logits processor."""

    # Fused modules and the separate projections packed into each of them.
    packed_modules_mapping = {
        "qkv_proj": ["q_proj", "k_proj", "v_proj"],
        "gate_up_proj": ["gate_proj", "up_proj"],
    }

    def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
        super().__init__()
        hf_config = vllm_config.model_config.hf_config
        quant_config = vllm_config.quant_config
        self.config = hf_config
        self.quant_config = quant_config

        self.model = Qwen3Model(
            vllm_config=vllm_config, prefix=maybe_prefix(prefix, "model")
        )

        if not get_pp_group().is_last_rank:
            # Only the last pipeline-parallel rank holds a real LM head.
            self.lm_head = PPMissingLayer()
        elif hf_config.tie_word_embeddings:
            # Tied weights: the input embedding doubles as the output head.
            self.lm_head = self.model.embed_tokens
        else:
            self.lm_head = ParallelLMHead(
                hf_config.vocab_size,
                hf_config.hidden_size,
                quant_config=quant_config,
                prefix=maybe_prefix(prefix, "lm_head"),
            )

        self.logits_processor = LogitsProcessor(hf_config.vocab_size)
        self.make_empty_intermediate_tensors = (
            self.model.make_empty_intermediate_tensors
        )

    def set_aux_hidden_state_layers(self, layers: tuple[int, ...]) -> None:
        """Store ``layers`` on the backbone as ``aux_hidden_state_layers``."""
        self.model.aux_hidden_state_layers = layers

    def get_eagle3_aux_hidden_state_layers(self) -> tuple[int, ...]:
        """Return layer indices 2, the middle layer, and the third-from-last."""
        depth = len(self.model.layers)
        return 2, depth // 2, depth - 3

    def embed_input_ids(self, input_ids: torch.Tensor) -> torch.Tensor:
        """Embed ``input_ids`` with the backbone's embedding layer."""
        return self.model.embed_input_ids(input_ids)

    def forward(
        self,
        input_ids: torch.Tensor,
        positions: torch.Tensor,
        intermediate_tensors: IntermediateTensors | None = None,
        inputs_embeds: torch.Tensor | None = None,
    ) -> torch.Tensor | IntermediateTensors:
        """Run the backbone; returns hidden states or pipeline intermediates."""
        return self.model(input_ids, positions, intermediate_tensors, inputs_embeds)

    def compute_logits(
        self,
        hidden_states: torch.Tensor,
    ) -> torch.Tensor | None:
        """Project ``hidden_states`` through ``lm_head`` via the logits processor."""
        return self.logits_processor(self.lm_head, hidden_states)

    def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]:
        """Load checkpoint weights; ``lm_head.*`` is skipped when embeddings are tied."""
        skip = ["lm_head."] if self.config.tie_word_embeddings else None
        return AutoWeightsLoader(self, skip_prefixes=skip).load_weights(weights)