# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project

# Copyright 2024 The Qwen team.
# Copyright 2023 The vLLM team.
# Copyright 2022 EleutherAI and the HuggingFace Inc. team. All rights reserved.
#
# This code is based on EleutherAI's GPT-NeoX library and the GPT-NeoX
# and OPT implementations in this library. It has been modified from its
# original forms to accommodate minor architectural differences compared
# to GPT-NeoX and OPT used by the Meta AI team that trained the model.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Inference-only Qwen3MoE model compatible with HuggingFace weights."""

import typing
from collections.abc import Callable, Iterable
from itertools import islice
from typing import Any

import torch
import torch.nn.functional as F
from torch import nn

from vllm.compilation.decorators import support_torch_compile
from vllm.config import CacheConfig, VllmConfig, get_current_vllm_config
from vllm.distributed import (
    get_ep_group,
    get_pp_group,
    get_tensor_model_parallel_world_size,
    tensor_model_parallel_all_gather,
)
from vllm.logger import init_logger
from vllm.model_executor.layers.activation import SiluAndMul
from vllm.model_executor.layers.attention import Attention
from vllm.model_executor.layers.fused_moe import SharedFusedMoE
from vllm.model_executor.layers.layernorm import RMSNorm
from vllm.model_executor.layers.linear import (
    MergedColumnParallelLinear,
    QKVParallelLinear,
    ReplicatedLinear,
    RowParallelLinear,
)
from vllm.model_executor.layers.logits_processor import LogitsProcessor
from vllm.model_executor.layers.quantization import QuantizationConfig
from vllm.model_executor.layers.rotary_embedding import get_rope
from vllm.model_executor.layers.vocab_parallel_embedding import (
    ParallelLMHead,
    VocabParallelEmbedding,
)
from vllm.model_executor.model_loader.weight_utils import (
    default_weight_loader,
    maybe_remap_kv_scale_name,
)
from vllm.model_executor.models.utils import sequence_parallel_chunk
from vllm.sequence import IntermediateTensors

from .interfaces import MixtureOfExperts, SupportsEagle3, SupportsLoRA, SupportsPP
from .utils import (
    AutoWeightsLoader,
    PPMissingLayer,
    extract_layer_index,
    is_pp_missing_parameter,
    make_empty_intermediate_tensors_factory,
    make_layers,
    maybe_prefix,
)

# Module-level logger (used by the weight loader for kv-scale warnings).
logger = init_logger(__name__)


class Qwen3MoeMLP(nn.Module):
    """Gated SiLU feed-forward block: merged gate/up projection, SiLU-and-mul,
    then a down projection.

    Used for dense decoder layers and as the shared expert of
    ``Qwen3MoeSparseMoeBlock``. When ``expert_gate`` is given, the output is
    additionally scaled by ``sigmoid(expert_gate(x)[0])``.
    """

    def __init__(
        self,
        hidden_size: int,
        intermediate_size: int,
        hidden_act: str,
        quant_config: QuantizationConfig | None = None,
        reduce_results: bool = True,
        expert_gate: nn.Module | None = None,
        prefix: str = "",
    ) -> None:
        """Create the projections.

        Args:
            hidden_size: Input and output width of the block.
            intermediate_size: Width of each of the gate and up projections.
            hidden_act: Activation name; only ``"silu"`` is accepted.
            quant_config: Optional quantization config for the linear layers.
            reduce_results: Forwarded to the down projection. The sparse MoE
                block passes ``False`` because it reduces the combined output
                itself.
            expert_gate: Optional gating module. It must return a tuple whose
                first element is the gate logits (as ``ReplicatedLinear``
                does); a plain ``torch.nn.Linear`` would be indexed wrongly.
            prefix: Parameter-name prefix for this module.

        Raises:
            ValueError: If ``hidden_act`` is not ``"silu"``.
        """
        super().__init__()
        self.gate_up_proj = MergedColumnParallelLinear(
            hidden_size,
            [intermediate_size] * 2,
            bias=False,
            quant_config=quant_config,
            prefix=f"{prefix}.gate_up_proj",
        )
        self.down_proj = RowParallelLinear(
            intermediate_size,
            hidden_size,
            bias=False,
            quant_config=quant_config,
            reduce_results=reduce_results,
            prefix=f"{prefix}.down_proj",
        )
        if hidden_act != "silu":
            raise ValueError(
                f"Unsupported activation: {hidden_act}. Only silu is supported for now."
            )
        self.act_fn = SiluAndMul()
        self.expert_gate = expert_gate

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Apply the MLP to ``x`` and, if configured, the sigmoid expert gate."""
        gate_up, _ = self.gate_up_proj(x)
        out = self.act_fn(gate_up)
        out, _ = self.down_proj(out)

        if self.expert_gate is not None:
            # The gate sees the block input, not the MLP output.
            out = F.sigmoid(self.expert_gate(x)[0]) * out

        return out


class Qwen3MoeSparseMoeBlock(nn.Module):
    """Sparse mixture-of-experts feed-forward block.

    A replicated router (``self.gate``) produces per-expert logits that
    ``SharedFusedMoE`` uses to run the routed experts. When the config has a
    positive ``shared_expert_intermediate_size``, a sigmoid-gated shared
    expert (``Qwen3MoeMLP``) is added to the routed output. Cross-rank
    reduction of the combined output happens in ``forward``.
    """

    def __init__(
        self,
        vllm_config: VllmConfig,
        prefix: str = "",
    ):
        """Build router, optional shared expert and fused experts.

        Args:
            vllm_config: Engine config; model, parallel and quantization
                settings are all read from this one object.
            prefix: Parameter-name prefix for this module.

        Raises:
            ValueError: If the tensor-parallel size exceeds the number of
                experts.
        """
        super().__init__()

        config = vllm_config.model_config.hf_text_config
        parallel_config = vllm_config.parallel_config
        quant_config = vllm_config.quant_config

        self.tp_size = get_tensor_model_parallel_world_size()

        ep_group = get_ep_group()
        self.ep_group = ep_group.device_group
        self.ep_rank = ep_group.rank_in_group
        self.ep_size = self.ep_group.size()
        self.n_routed_experts = config.num_experts

        self.is_sequence_parallel = parallel_config.use_sequence_parallel_moe

        if self.tp_size > config.num_experts:
            raise ValueError(
                f"Tensor parallel size {self.tp_size} is greater than "
                f"the number of experts {config.num_experts}."
            )

        # Load balancing (EPLB) settings. Both values come from the config
        # passed in; previously the argument was shadowed by the global
        # "current" config, mixing settings from two objects.
        eplb_config = parallel_config.eplb_config
        self.enable_eplb = parallel_config.enable_eplb

        self.n_logical_experts = self.n_routed_experts
        self.n_redundant_experts = eplb_config.num_redundant_experts
        self.n_physical_experts = self.n_logical_experts + self.n_redundant_experts
        self.n_local_physical_experts = self.n_physical_experts // self.ep_size

        # Contiguous slice of physical experts owned by this EP rank.
        self.physical_expert_start = self.ep_rank * self.n_local_physical_experts
        self.physical_expert_end = (
            self.physical_expert_start + self.n_local_physical_experts
        )

        self.gate = ReplicatedLinear(
            config.hidden_size,
            config.num_experts,
            bias=False,
            quant_config=quant_config,
            prefix=f"{prefix}.gate",
        )

        shared_expert_intermediate_size = getattr(
            config, "shared_expert_intermediate_size", 0
        )
        if shared_expert_intermediate_size > 0:
            # The shared-expert gate is never quantized.
            self.shared_expert_gate = ReplicatedLinear(
                config.hidden_size,
                1,
                bias=False,
                quant_config=None,
                prefix=f"{prefix}.shared_expert_gate",
            )
            self.shared_expert = Qwen3MoeMLP(
                hidden_size=config.hidden_size,
                intermediate_size=shared_expert_intermediate_size,
                hidden_act=config.hidden_act,
                quant_config=quant_config,
                # The combined (shared + routed) output is reduced in forward.
                reduce_results=False,
                expert_gate=self.shared_expert_gate,
                prefix=f"{prefix}.shared_expert",
            )
        else:
            self.shared_expert_gate = None
            self.shared_expert = None

        self.experts = SharedFusedMoE(
            shared_experts=self.shared_expert,
            gate=self.gate,
            num_experts=self.n_routed_experts,
            top_k=config.num_experts_per_tok,
            hidden_size=config.hidden_size,
            intermediate_size=config.moe_intermediate_size,
            reduce_results=False,
            renormalize=config.norm_topk_prob,
            quant_config=quant_config,
            prefix=f"{prefix}.experts",
            enable_eplb=self.enable_eplb,
            num_redundant_experts=self.n_redundant_experts,
            is_sequence_parallel=self.is_sequence_parallel,
        )

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        """Route ``hidden_states`` through the experts.

        Args:
            hidden_states: ``(num_tokens, hidden)`` or a single ``(hidden,)``
                token.

        Returns:
            A tensor with the same rank as the input, reduced across
            tensor-parallel ranks (all-gathered when sequence-parallel MoE is
            enabled).
        """
        assert hidden_states.dim() <= 2, (
            "Qwen3MoeSparseMoeBlock only supports 1D or 2D inputs"
        )
        is_input_1d = hidden_states.dim() == 1
        # Take the width from the last dim and count tokens after the view, so
        # a 1D input becomes (1, hidden) instead of failing to unpack.
        hidden_dim = hidden_states.shape[-1]
        hidden_states = hidden_states.view(-1, hidden_dim)
        num_tokens = hidden_states.shape[0]

        if self.is_sequence_parallel:
            hidden_states = sequence_parallel_chunk(hidden_states)

        # router_logits: (num_tokens, n_experts)
        router_logits, _ = self.gate(hidden_states)
        shared_out, fused_out = self.experts(
            hidden_states=hidden_states, router_logits=router_logits
        )
        final_hidden_states = (
            shared_out + fused_out if shared_out is not None else fused_out
        )

        if self.is_sequence_parallel:
            final_hidden_states = tensor_model_parallel_all_gather(
                final_hidden_states, 0
            )
            # Drop any rows that do not correspond to this call's tokens
            # (assumed to be chunking padding — TODO confirm).
            final_hidden_states = final_hidden_states[:num_tokens]
        elif self.tp_size > 1:
            final_hidden_states = self.experts.maybe_all_reduce_tensor_model_parallel(  # noqa E501
                final_hidden_states
            )

        # Return to 1D if the input was 1D.
        return final_hidden_states.squeeze(0) if is_input_1d else final_hidden_states


class Qwen3MoeAttention(nn.Module):
    """Self-attention with separate (possibly fewer) KV heads, per-head
    RMSNorm on queries and keys, and rotary position embeddings.

    Query heads are split across tensor-parallel ranks. KV heads are split
    when there are at least as many as ranks, otherwise replicated.
    """

    def __init__(
        self,
        hidden_size: int,
        num_heads: int,
        num_kv_heads: int,
        rope_parameters: dict[str, Any],
        max_position_embeddings: int = 8192,
        head_dim: int | None = None,
        rms_norm_eps: float = 1e-06,
        qkv_bias: bool = False,
        cache_config: CacheConfig | None = None,
        quant_config: QuantizationConfig | None = None,
        prefix: str = "",
        dual_chunk_attention_config: dict[str, Any] | None = None,
    ) -> None:
        """Create projections, rotary embedding, attention and QK norms.

        Args:
            hidden_size: Model hidden width.
            num_heads: Total number of query heads (before TP split).
            num_kv_heads: Total number of key/value heads (before TP split).
            rope_parameters: Passed through to ``get_rope``.
            max_position_embeddings: Maximum position handled by the rotary
                embedding.
            head_dim: Per-head width; defaults to ``hidden_size // num_heads``.
            rms_norm_eps: Epsilon for the q/k RMSNorms.
            qkv_bias: Whether the fused QKV projection has a bias.
            cache_config: KV-cache config forwarded to ``Attention``.
            quant_config: Optional quantization config.
            prefix: Parameter-name prefix for this module.
            dual_chunk_attention_config: If set, forwarded to the rotary
                embedding and (with the layer index) to ``Attention``.
        """
        super().__init__()
        self.hidden_size = hidden_size
        tp_size = get_tensor_model_parallel_world_size()
        self.total_num_heads = num_heads
        assert self.total_num_heads % tp_size == 0
        self.num_heads = self.total_num_heads // tp_size
        self.total_num_kv_heads = num_kv_heads
        if self.total_num_kv_heads >= tp_size:
            # Number of KV heads is greater than TP size, so we partition
            # the KV heads across multiple tensor parallel GPUs.
            assert self.total_num_kv_heads % tp_size == 0
        else:
            # Number of KV heads is less than TP size, so we replicate
            # the KV heads across multiple tensor parallel GPUs.
            assert tp_size % self.total_num_kv_heads == 0
        self.num_kv_heads = max(1, self.total_num_kv_heads // tp_size)
        self.head_dim = head_dim or (hidden_size // self.total_num_heads)
        self.q_size = self.num_heads * self.head_dim
        self.kv_size = self.num_kv_heads * self.head_dim
        self.scaling = self.head_dim**-0.5
        self.max_position_embeddings = max_position_embeddings
        self.dual_chunk_attention_config = dual_chunk_attention_config

        self.qkv_proj = QKVParallelLinear(
            hidden_size,
            self.head_dim,
            self.total_num_heads,
            self.total_num_kv_heads,
            bias=qkv_bias,
            quant_config=quant_config,
            prefix=f"{prefix}.qkv_proj",
        )

        self.o_proj = RowParallelLinear(
            self.total_num_heads * self.head_dim,
            hidden_size,
            bias=False,
            quant_config=quant_config,
            prefix=f"{prefix}.o_proj",
        )

        self.rotary_emb = get_rope(
            self.head_dim,
            max_position=max_position_embeddings,
            rope_parameters=rope_parameters,
            dual_chunk_attention_config=dual_chunk_attention_config,
        )

        # Dual-chunk attention additionally needs the layer index; these
        # kwargs are passed only when that feature is configured.
        extra_attn_kwargs: dict[str, Any] = (
            {
                "layer_idx": extract_layer_index(prefix),
                "dual_chunk_attention_config": dual_chunk_attention_config,
            }
            if dual_chunk_attention_config
            else {}
        )
        self.attn = Attention(
            self.num_heads,
            self.head_dim,
            self.scaling,
            num_kv_heads=self.num_kv_heads,
            cache_config=cache_config,
            quant_config=quant_config,
            prefix=f"{prefix}.attn",
            **extra_attn_kwargs,
        )

        self.q_norm = RMSNorm(self.head_dim, eps=rms_norm_eps)
        self.k_norm = RMSNorm(self.head_dim, eps=rms_norm_eps)

    def forward(
        self,
        positions: torch.Tensor,
        hidden_states: torch.Tensor,
    ) -> torch.Tensor:
        """Project to q/k/v, apply QK-norm and RoPE, attend, project out."""
        qkv, _ = self.qkv_proj(hidden_states)
        q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1)

        # QK-norm: view the last dim as (heads, head_dim), normalize each
        # head, then restore the original flat layout.
        q_by_head = q.view(*q.shape[:-1], q.shape[-1] // self.head_dim, self.head_dim)
        q = self.q_norm(q_by_head).view(q.shape)
        k_by_head = k.view(*k.shape[:-1], k.shape[-1] // self.head_dim, self.head_dim)
        k = self.k_norm(k_by_head).view(k.shape)

        q, k = self.rotary_emb(positions, q, k)
        attn_output = self.attn(q, k, v)
        output, _ = self.o_proj(attn_output)
        return output


class Qwen3MoeDecoderLayer(nn.Module):
    """One transformer layer: pre-norm self-attention followed by a
    pre-norm feed-forward that is either sparse MoE or a dense MLP."""

    def __init__(self, vllm_config: VllmConfig, prefix: str = "") -> None:
        """Build attention, the feed-forward block and both layer norms.

        The feed-forward is a ``Qwen3MoeSparseMoeBlock`` unless the layer is
        listed in ``config.mlp_only_layers``, the model has no experts, or the
        layer does not fall on ``config.decoder_sparse_step``; otherwise a
        dense ``Qwen3MoeMLP`` is used.

        Args:
            vllm_config: Engine config.
            prefix: Parameter-name prefix; the layer index is parsed from it.
        """
        super().__init__()

        config = vllm_config.model_config.hf_text_config
        cache_config = vllm_config.cache_config
        quant_config = vllm_config.quant_config

        self.hidden_size = config.hidden_size
        max_position_embeddings = getattr(config, "max_position_embeddings", 8192)
        dual_chunk_attention_config = getattr(
            config, "dual_chunk_attention_config", None
        )
        self.self_attn = Qwen3MoeAttention(
            hidden_size=self.hidden_size,
            num_heads=config.num_attention_heads,
            num_kv_heads=config.num_key_value_heads,
            rope_parameters=config.rope_parameters,
            max_position_embeddings=max_position_embeddings,
            rms_norm_eps=config.rms_norm_eps,
            qkv_bias=getattr(config, "attention_bias", False),
            head_dim=getattr(config, "head_dim", None),
            cache_config=cache_config,
            quant_config=quant_config,
            prefix=f"{prefix}.self_attn",
            dual_chunk_attention_config=dual_chunk_attention_config,
        )

        # `mlp_only_layers` in the config; a missing or None value means none.
        layer_idx = extract_layer_index(prefix)
        mlp_only_layers = getattr(config, "mlp_only_layers", None) or []
        use_sparse_moe = (
            layer_idx not in mlp_only_layers
            and config.num_experts > 0
            and (layer_idx + 1) % config.decoder_sparse_step == 0
        )
        if use_sparse_moe:
            self.mlp = Qwen3MoeSparseMoeBlock(
                vllm_config=vllm_config, prefix=f"{prefix}.mlp"
            )
        else:
            self.mlp = Qwen3MoeMLP(
                hidden_size=config.hidden_size,
                intermediate_size=config.intermediate_size,
                hidden_act=config.hidden_act,
                quant_config=quant_config,
                prefix=f"{prefix}.mlp",
            )
        self.input_layernorm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
        self.post_attention_layernorm = RMSNorm(
            config.hidden_size, eps=config.rms_norm_eps
        )

    def forward(
        self,
        positions: torch.Tensor,
        hidden_states: torch.Tensor,
        residual: torch.Tensor | None,
    ) -> tuple[torch.Tensor, torch.Tensor]:
        """Run the layer.

        Args:
            positions: Token positions for the rotary embedding.
            hidden_states: Output of the previous layer (or the embeddings).
            residual: Running residual, or ``None`` for the first layer.

        Returns:
            ``(hidden_states, residual)`` for the next layer.
        """
        # Self attention. For the first layer the input itself becomes the
        # residual; afterwards the norm takes the residual and returns both
        # the normalized states and the updated residual.
        if residual is None:
            residual = hidden_states
            hidden_states = self.input_layernorm(hidden_states)
        else:
            hidden_states, residual = self.input_layernorm(hidden_states, residual)
        hidden_states = self.self_attn(
            positions=positions,
            hidden_states=hidden_states,
        )

        # Fully connected.
        hidden_states, residual = self.post_attention_layernorm(hidden_states, residual)
        hidden_states = self.mlp(hidden_states)
        return hidden_states, residual


@support_torch_compile
class Qwen3MoeModel(nn.Module):
    """Qwen3-MoE backbone: token embedding, decoder layers, final norm.

    Layers outside this pipeline-parallel rank's range are produced by
    ``make_layers`` as placeholders and are not executed. The model can also
    return, for EAGLE3, the hidden states entering selected layers.
    """

    def __init__(
        self,
        *,
        vllm_config: VllmConfig,
        prefix: str = "",
        decoder_layer_type: type[torch.nn.Module] = Qwen3MoeDecoderLayer,
    ):
        """Build embeddings, the decoder stack and the final norm.

        Args:
            vllm_config: Engine config.
            prefix: Parameter-name prefix for this module.
            decoder_layer_type: Layer class to instantiate for every decoder
                layer; it is called as ``cls(vllm_config=..., prefix=...)``.
        """
        super().__init__()

        config = vllm_config.model_config.hf_text_config
        quant_config = vllm_config.quant_config
        parallel_config = vllm_config.parallel_config
        eplb_config = parallel_config.eplb_config
        # Needed by get_expert_mapping to enumerate redundant expert slots.
        self.num_redundant_experts = eplb_config.num_redundant_experts

        self.padding_idx = config.pad_token_id
        self.vocab_size = config.vocab_size
        self.config = config
        self.quant_config = quant_config
        self.embed_tokens = VocabParallelEmbedding(
            config.vocab_size,
            config.hidden_size,
            quant_config=quant_config,
            prefix=f"{prefix}.embed_tokens",
        )
        # The lambda parameter must stay named `prefix`: it is presumably
        # invoked with `prefix=` as a keyword (TODO confirm in make_layers).
        self.start_layer, self.end_layer, self.layers = make_layers(
            config.num_hidden_layers,
            lambda prefix: decoder_layer_type(vllm_config=vllm_config, prefix=prefix),
            prefix=f"{prefix}.layers",
        )
        self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
        self.make_empty_intermediate_tensors = make_empty_intermediate_tensors_factory(
            ["hidden_states", "residual"], config.hidden_size
        )
        # Track layers for auxiliary hidden state outputs (EAGLE3).
        self.aux_hidden_state_layers: tuple[int, ...] = ()

    def embed_input_ids(self, input_ids: torch.Tensor) -> torch.Tensor:
        """Look up token embeddings for ``input_ids``."""
        return self.embed_tokens(input_ids)

    def forward(
        self,
        input_ids: torch.Tensor | None,
        positions: torch.Tensor,
        intermediate_tensors: IntermediateTensors | None = None,
        inputs_embeds: torch.Tensor | None = None,
    ) -> torch.Tensor | IntermediateTensors | tuple[torch.Tensor, list[torch.Tensor]]:
        """Run this rank's slice of the decoder stack.

        Returns ``IntermediateTensors`` on non-last pipeline ranks. On the
        last rank it returns the final normed hidden states. If any layer in
        ``aux_hidden_state_layers`` ran, it also returns the list of
        ``hidden_states + residual`` taken just before those layers.
        """
        if get_pp_group().is_first_rank:
            if inputs_embeds is not None:
                hidden_states = inputs_embeds
            else:
                hidden_states = self.embed_input_ids(input_ids)
            residual = None
        else:
            assert intermediate_tensors is not None
            hidden_states = intermediate_tensors["hidden_states"]
            residual = intermediate_tensors["residual"]

        aux_hidden_states = []
        for layer_idx, layer in enumerate(
            islice(self.layers, self.start_layer, self.end_layer),
            start=self.start_layer,
        ):
            # Collect auxiliary hidden states if specified.
            if layer_idx in self.aux_hidden_state_layers:
                aux_hidden_state = (
                    hidden_states + residual if residual is not None else hidden_states
                )
                aux_hidden_states.append(aux_hidden_state)
            hidden_states, residual = layer(positions, hidden_states, residual)

        if not get_pp_group().is_last_rank:
            return IntermediateTensors(
                {"hidden_states": hidden_states, "residual": residual}
            )
        hidden_states, _ = self.norm(hidden_states, residual)

        # Return auxiliary hidden states if collected.
        if len(aux_hidden_states) > 0:
            return hidden_states, aux_hidden_states
        return hidden_states

    def get_expert_mapping(self) -> list[tuple[str, str, int, str]]:
        """Return ``(param_name, weight_name, expert_id, shard_id)`` tuples
        mapping per-expert checkpoint names onto fused-MoE parameters
        (weights, fp8 weight scales and fp8 activation scales)."""
        return SharedFusedMoE.make_expert_params_mapping(
            self,
            ckpt_gate_proj_name="gate_proj",
            ckpt_down_proj_name="down_proj",
            ckpt_up_proj_name="up_proj",
            num_experts=self.config.num_experts,
            num_redundant_experts=self.num_redundant_experts,
        )

    def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]:
        """Load checkpoint tensors into this model's parameters.

        Each checkpoint name is handled by the first applicable rule:

        1. KV-cache quantization scales reported by ``quant_config``.
        2. Fused dense projections: ``q/k/v_proj`` into ``qkv_proj`` and
           ``gate/up_proj`` into ``gate_up_proj`` (non-expert weights only).
        3. Per-expert weights, via ``get_expert_mapping()``.
        4. Everything else, under its own (kv-scale remapped) name.

        Args:
            weights: Iterable of ``(checkpoint_name, tensor)`` pairs.

        Returns:
            Names of the parameters that received data.

        Raises:
            KeyError: If a rule-3 or rule-4 tensor maps to a parameter name
                that does not exist (and is not an ignorable quantization
                extra).
        """
        stacked_params_mapping = [
            # (param_name, shard_name, shard_id)
            ("qkv_proj", "q_proj", "q"),
            ("qkv_proj", "k_proj", "k"),
            ("qkv_proj", "v_proj", "v"),
            ("gate_up_proj", "gate_proj", 0),
            ("gate_up_proj", "up_proj", 1),
        ]

        # Skip loading extra parameters for GPTQ/modelopt models.
        ignore_suffixes = (
            ".bias",
            "_bias",
            ".k_scale",
            "_k_scale",
            ".v_scale",
            "_v_scale",
            ".weight_scale",
            "_weight_scale",
            ".input_scale",
            "_input_scale",
        )

        params_dict = dict(self.named_parameters())
        loaded_params: set[str] = set()
        expert_params_mapping = self.get_expert_mapping()
        for name, loaded_weight in weights:
            # 1. KV-cache quantization scales.
            if self.quant_config is not None and (
                scale_name := self.quant_config.get_cache_scale(name)
            ):
                param = params_dict[scale_name]
                weight_loader = getattr(param, "weight_loader", default_weight_loader)
                assert loaded_weight.numel() == 1, (
                    f"KV scale numel {loaded_weight.numel()} != 1"
                )
                loaded_weight = loaded_weight.squeeze()
                weight_loader(param, loaded_weight)
                loaded_params.add(scale_name)
                continue

            # 2. Stacked (fused) dense projections.
            is_stacked_weight = False
            for param_name, weight_name, shard_id in stacked_params_mapping:
                # Expert names such as mlp.experts.0.gate_proj also contain
                # "gate_proj"/"up_proj"; they are handled in step 3 and must
                # not be renamed here.
                if weight_name not in name or "mlp.experts" in name:
                    continue
                # At most one stacked mapping applies to a checkpoint name:
                # the tensor is now either loaded or skipped. It is never
                # retried against another mapping with an already rewritten
                # name (which mangled e.g. "qkv_proj" via its "v_proj"
                # substring, or crashed on a None kv-scale remap).
                is_stacked_weight = True
                stacked_name = name.replace(weight_name, param_name)

                # Skip loading extra parameters for GPTQ/modelopt models.
                if (
                    stacked_name.endswith(ignore_suffixes)
                    and stacked_name not in params_dict
                ):
                    break
                # Skip layers on other devices.
                if is_pp_missing_parameter(stacked_name, self):
                    break
                if stacked_name.endswith("scale"):
                    # Remapping the name of FP8 kv-scale.
                    stacked_name = maybe_remap_kv_scale_name(stacked_name, params_dict)
                    if stacked_name is None:
                        break
                if stacked_name not in params_dict:
                    break

                param = params_dict[stacked_name]
                weight_loader = getattr(param, "weight_loader", default_weight_loader)
                if weight_loader == default_weight_loader:
                    weight_loader(param, loaded_weight)
                else:
                    weight_loader(param, loaded_weight, shard_id)
                loaded_params.add(stacked_name)
                break
            if is_stacked_weight:
                continue

            # 3. Per-expert weights.
            is_expert_weight = False
            expert_param_name = None
            for param_name, weight_name, expert_id, shard_id in expert_params_mapping:
                if weight_name not in name:
                    continue

                # Anyway, this is an expert weight and should not be
                # attempted to load as other weights later.
                is_expert_weight = True

                # Do not modify `name` since the loop may continue here;
                # instead, create a new variable.
                name_mapped = name.replace(weight_name, param_name)

                if is_pp_missing_parameter(name_mapped, self):
                    continue

                # Skip loading extra parameters for GPTQ/modelopt models.
                if (
                    name_mapped.endswith(ignore_suffixes)
                    and name_mapped not in params_dict
                ):
                    continue

                param = params_dict[name_mapped]
                # Ask the weight loader whether it succeeded; otherwise we may
                # skip experts with other available replicas.
                weight_loader = typing.cast(Callable[..., bool], param.weight_loader)
                success = weight_loader(
                    param,
                    loaded_weight,
                    name_mapped,
                    shard_id=shard_id,
                    expert_id=expert_id,
                    return_success=True,
                )
                if success:
                    expert_param_name = name_mapped
                    break

            if expert_param_name is not None:
                loaded_params.add(expert_param_name)
                continue
            if is_expert_weight:
                # This is an expert weight, but it is not mapped locally to
                # this rank, so we simply skip it.
                continue

            # 4. Remaining (non-stacked, non-expert) weights.
            # Skip loading extra parameters for GPTQ/modelopt models.
            if name.endswith(ignore_suffixes) and name not in params_dict:
                continue
            # Skip layers on other devices.
            if is_pp_missing_parameter(name, self):
                continue
            # Remapping the name of FP8 kv-scale.
            if name.endswith("kv_scale"):
                remapped_kv_scale_name = name.replace(".kv_scale", ".attn.kv_scale")
                if remapped_kv_scale_name not in params_dict:
                    logger.warning_once(
                        "Found kv scale in the checkpoint (e.g. %s), but not found the expected name in the model (e.g. %s). kv-scale is not loaded.",  # noqa: E501
                        name,
                        remapped_kv_scale_name,
                    )
                    continue
                name = remapped_kv_scale_name
            param = params_dict[name]
            weight_loader = getattr(param, "weight_loader", default_weight_loader)
            weight_loader(param, loaded_weight)
            loaded_params.add(name)
        return loaded_params


class Qwen3MoeForCausalLM(
    nn.Module, SupportsPP, SupportsLoRA, SupportsEagle3, MixtureOfExperts
):
    """Qwen3-MoE causal LM: ``Qwen3MoeModel`` backbone plus LM head.

    Also exposes the mixture-of-experts metadata and hooks used for expert
    load balancing, and the EAGLE3 auxiliary-hidden-state hooks.
    """

    packed_modules_mapping = {
        "qkv_proj": [
            "q_proj",
            "k_proj",
            "v_proj",
        ]
    }

    fall_back_to_pt_during_load = False

    def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
        """Build the backbone and LM head and collect MoE layer metadata.

        Raises:
            RuntimeError: If no sparse MoE layer exists on this rank.
        """
        super().__init__()
        config = vllm_config.model_config.hf_text_config
        quant_config = vllm_config.quant_config
        self.config = config
        self.quant_config = quant_config
        # Only perform the following mapping when Qwen3MoeMLP exists.
        # NOTE(review): this writes into the class-level dict, so the entry is
        # shared by every instance (and by subclasses that inherit the
        # attribute). It is left as-is because code outside this file may
        # read the class attribute. Also, Qwen3MoeMLP is used for shared
        # experts and for layers off `decoder_sparse_step` too, which this
        # condition does not check. Confirm both before changing.
        if getattr(config, "mlp_only_layers", []):
            self.packed_modules_mapping["gate_up_proj"] = ["gate_proj", "up_proj"]
        self.model = Qwen3MoeModel(
            vllm_config=vllm_config, prefix=maybe_prefix(prefix, "model")
        )
        self.lm_head = ParallelLMHead(
            config.vocab_size,
            config.hidden_size,
            quant_config=quant_config,
            prefix=maybe_prefix(prefix, "lm_head"),
        )
        if self.config.tie_word_embeddings:
            self.lm_head.weight = self.model.embed_tokens.weight
        self.logits_processor = LogitsProcessor(config.vocab_size)
        self.make_empty_intermediate_tensors = (
            self.model.make_empty_intermediate_tensors
        )

        # Set MoE hyperparameters.
        self.expert_weights = []

        self.moe_layers = []
        example_layer = None
        for layer in self.model.layers:
            # Placeholders for layers owned by other pipeline-parallel ranks.
            if isinstance(layer, PPMissingLayer):
                continue

            assert isinstance(layer, Qwen3MoeDecoderLayer)
            if isinstance(layer.mlp, Qwen3MoeSparseMoeBlock):
                example_layer = layer.mlp
                self.moe_layers.append(layer.mlp.experts)

        if example_layer is None:
            raise RuntimeError("No Qwen3MoE layer found in the model.layers.")

        self.num_moe_layers = len(self.moe_layers)
        self.num_expert_groups = 1
        self.num_shared_experts = 0
        self.num_logical_experts = example_layer.n_logical_experts
        self.num_physical_experts = example_layer.n_physical_experts
        self.num_local_physical_experts = example_layer.n_local_physical_experts
        self.num_routed_experts = example_layer.n_routed_experts
        self.num_redundant_experts = example_layer.n_redundant_experts

    def update_physical_experts_metadata(
        self,
        num_physical_experts: int,
        num_local_physical_experts: int,
    ) -> None:
        """Record a new physical-expert count and refresh each local MoE
        layer's expert map. The per-rank count must not change."""
        assert self.num_local_physical_experts == num_local_physical_experts
        self.num_physical_experts = num_physical_experts
        self.num_local_physical_experts = num_local_physical_experts
        self.num_redundant_experts = num_physical_experts - self.num_logical_experts
        for layer in self.model.layers:
            # As in __init__, skip placeholders for layers owned by other
            # pipeline-parallel ranks; they are not decoder layers and have
            # no `mlp`, so reading `layer.mlp` on them would fail.
            if isinstance(layer, PPMissingLayer):
                continue
            if isinstance(layer.mlp, Qwen3MoeSparseMoeBlock):
                moe = layer.mlp
                moe.n_local_physical_experts = num_local_physical_experts
                moe.n_physical_experts = num_physical_experts
                moe.n_redundant_experts = self.num_redundant_experts
                moe.experts.update_expert_map()

    def set_aux_hidden_state_layers(self, layers: tuple[int, ...]) -> None:
        """Select the layer indices whose input hidden states are returned."""
        self.model.aux_hidden_state_layers = layers

    def get_eagle3_aux_hidden_state_layers(self) -> tuple[int, ...]:
        """Default EAGLE3 taps: layer 2, the middle layer, and the
        third-from-last layer."""
        num_layers = len(self.model.layers)
        return (2, num_layers // 2, num_layers - 3)

    def embed_input_ids(self, input_ids: torch.Tensor) -> torch.Tensor:
        """Look up token embeddings via the backbone."""
        return self.model.embed_input_ids(input_ids)

    def forward(
        self,
        input_ids: torch.Tensor | None,
        positions: torch.Tensor,
        intermediate_tensors: IntermediateTensors | None = None,
        inputs_embeds: torch.Tensor | None = None,
    ) -> torch.Tensor | IntermediateTensors:
        """Run the backbone; see ``Qwen3MoeModel.forward`` for outputs."""
        hidden_states = self.model(
            input_ids, positions, intermediate_tensors, inputs_embeds
        )
        return hidden_states

    def compute_logits(
        self,
        hidden_states: torch.Tensor,
    ) -> torch.Tensor | None:
        """Project hidden states to vocabulary logits via the LM head."""
        logits = self.logits_processor(self.lm_head, hidden_states)
        return logits

    def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]:
        """Load checkpoint weights via ``AutoWeightsLoader``."""
        loader = AutoWeightsLoader(self)
        return loader.load_weights(weights)

    def get_expert_mapping(self) -> list[tuple[str, str, int, str]]:
        """Return the backbone's checkpoint-to-expert parameter mapping."""
        return self.model.get_expert_mapping()