# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project

# Copyright 2024 The Qwen team.
# Copyright 2023 The vLLM team.
# Copyright 2022 EleutherAI and the HuggingFace Inc. team. All rights reserved.
#
# This code is based on EleutherAI's GPT-NeoX library and the GPT-NeoX
# and OPT implementations in this library. It has been modified from its
# original forms to accommodate minor architectural differences compared
# to GPT-NeoX and OPT used by the Meta AI team that trained the model.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Inference-only Qwen3MoE model compatible with HuggingFace weights."""

import typing
from collections.abc import Callable, Iterable
from itertools import islice
from typing import Any

import torch
import torch.nn.functional as F
from torch import nn

from vllm.compilation.decorators import support_torch_compile
from vllm.config import CacheConfig, VllmConfig, get_current_vllm_config
from vllm.distributed import (
    get_ep_group,
    get_pp_group,
    get_tensor_model_parallel_world_size,
    tensor_model_parallel_all_gather,
)
from vllm.logger import init_logger
from vllm.model_executor.layers.activation import SiluAndMul
from vllm.model_executor.layers.attention import Attention
from vllm.model_executor.layers.fused_moe import (
    FusedMoE,
    fused_moe_make_expert_params_mapping,
)
from vllm.model_executor.layers.layernorm import RMSNorm
from vllm.model_executor.layers.linear import (
    MergedColumnParallelLinear,
    QKVParallelLinear,
    ReplicatedLinear,
    RowParallelLinear,
)
from vllm.model_executor.layers.logits_processor import LogitsProcessor
from vllm.model_executor.layers.quantization import QuantizationConfig
from vllm.model_executor.layers.rotary_embedding import get_rope
from vllm.model_executor.layers.vocab_parallel_embedding import (
    ParallelLMHead,
    VocabParallelEmbedding,
)
from vllm.model_executor.model_loader.weight_utils import (
    default_weight_loader,
    maybe_remap_kv_scale_name,
)
from vllm.model_executor.models.utils import sequence_parallel_chunk
from vllm.sequence import IntermediateTensors

from .interfaces import (
    EagleModelMixin,
    MixtureOfExperts,
    SupportsEagle,
    SupportsEagle3,
    SupportsLoRA,
    SupportsPP,
)
from .utils import (
    AutoWeightsLoader,
    PPMissingLayer,
    extract_layer_index,
    is_pp_missing_parameter,
    make_empty_intermediate_tensors_factory,
    make_layers,
    maybe_prefix,
)

logger = init_logger(__name__)


class Qwen3MoeMLP(nn.Module):
    """Gated feed-forward block: merged gate/up projection, SiLU-and-mul, down
    projection.

    Built by ``Qwen3MoeDecoderLayer`` for dense layers and by
    ``Qwen3MoeSparseMoeBlock`` as its shared expert (with ``expert_gate`` set).
    """

    def __init__(
        self,
        hidden_size: int,
        intermediate_size: int,
        hidden_act: str,
        quant_config: QuantizationConfig | None = None,
        reduce_results: bool = True,
        expert_gate: torch.nn.Module | None = None,
        prefix: str = "",
    ) -> None:
        """Build the projections.

        Args:
            hidden_size: Input and output feature size.
            intermediate_size: Size of each of the gate and up projections.
            hidden_act: Activation name; only ``"silu"`` is accepted.
            quant_config: Quantization config forwarded to both projections.
            reduce_results: Forwarded to the down projection.
            expert_gate: Optional module called on the block input. Its call
                result is indexed with ``[0]``, so it must return a
                tuple-like whose first item is the gate logits.
            prefix: Parameter-name prefix for the sub-layers.

        Raises:
            ValueError: If ``hidden_act`` is not ``"silu"``.
        """
        super().__init__()
        # Validate before building any layer so an unsupported activation
        # fails without constructing the projections first.
        if hidden_act != "silu":
            raise ValueError(
                f"Unsupported activation: {hidden_act}. Only silu is supported for now."
            )
        self.gate_up_proj = MergedColumnParallelLinear(
            hidden_size,
            [intermediate_size] * 2,
            bias=False,
            quant_config=quant_config,
            prefix=f"{prefix}.gate_up_proj",
        )
        self.down_proj = RowParallelLinear(
            intermediate_size,
            hidden_size,
            bias=False,
            quant_config=quant_config,
            reduce_results=reduce_results,
            prefix=f"{prefix}.down_proj",
        )
        self.act_fn = SiluAndMul()
        self.expert_gate = expert_gate

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Run the MLP on ``x``; scale by ``sigmoid(expert_gate(x)[0])`` if gated."""
        # The projection layers return a pair; only the first item is used.
        gate_up, _ = self.gate_up_proj(x)
        out, _ = self.down_proj(self.act_fn(gate_up))

        if self.expert_gate is not None:
            out = F.sigmoid(self.expert_gate(x)[0]) * out

        return out


class Qwen3MoeSparseMoeBlock(nn.Module):
    """Sparse mixture-of-experts block: router gate, optional gated shared
    expert, and a ``FusedMoE`` layer holding the routed experts.
    """

    def __init__(
        self,
        vllm_config: VllmConfig,
        prefix: str = "",
    ):
        """Build the router, the optional shared expert and the fused experts.

        Args:
            vllm_config: Source of the HF text config, parallel config and
                quantization config.
            prefix: Parameter-name prefix for the sub-layers.

        Raises:
            ValueError: If the tensor-parallel world size exceeds
                ``config.num_experts``.
        """
        super().__init__()

        config = vllm_config.model_config.hf_text_config
        parallel_config = vllm_config.parallel_config
        quant_config = vllm_config.quant_config

        self.tp_size = get_tensor_model_parallel_world_size()

        ep_group = get_ep_group()
        self.ep_group = ep_group.device_group
        self.ep_rank = ep_group.rank_in_group
        self.ep_size = self.ep_group.size()
        self.n_routed_experts = config.num_experts

        self.is_sequence_parallel = parallel_config.use_sequence_parallel_moe

        if self.tp_size > config.num_experts:
            raise ValueError(
                f"Tensor parallel size {self.tp_size} is greater than "
                f"the number of experts {config.num_experts}."
            )

        # Load balancing settings. The EPLB config is read from the *current*
        # vLLM config, as before, but without rebinding the ``vllm_config``
        # parameter.
        eplb_config = get_current_vllm_config().parallel_config.eplb_config
        self.enable_eplb = parallel_config.enable_eplb

        self.n_logical_experts = self.n_routed_experts
        self.n_redundant_experts = eplb_config.num_redundant_experts
        self.n_physical_experts = self.n_logical_experts + self.n_redundant_experts
        self.n_local_physical_experts = self.n_physical_experts // self.ep_size

        # Half-open range [start, end) of physical expert ids for this EP rank.
        self.physical_expert_start = self.ep_rank * self.n_local_physical_experts
        self.physical_expert_end = (
            self.physical_expert_start + self.n_local_physical_experts
        )

        self.gate = ReplicatedLinear(
            config.hidden_size,
            config.num_experts,
            bias=False,
            quant_config=quant_config,
            prefix=f"{prefix}.gate",
        )

        shared_expert_intermediate_size = getattr(
            config, "shared_expert_intermediate_size", 0
        )
        if shared_expert_intermediate_size > 0:
            # The shared-expert gate is built with quant_config=None.
            self.shared_expert_gate = ReplicatedLinear(
                config.hidden_size,
                1,
                bias=False,
                quant_config=None,
                prefix=f"{prefix}.shared_expert_gate",
            )
            self.shared_expert = Qwen3MoeMLP(
                hidden_size=config.hidden_size,
                intermediate_size=shared_expert_intermediate_size,
                hidden_act=config.hidden_act,
                quant_config=quant_config,
                reduce_results=False,
                expert_gate=self.shared_expert_gate,
                prefix=f"{prefix}.shared_expert",
            )
        else:
            self.shared_expert_gate = None
            self.shared_expert = None

        self.experts = FusedMoE(
            shared_experts=self.shared_expert,
            gate=self.gate,
            num_experts=self.n_routed_experts,
            top_k=config.num_experts_per_tok,
            hidden_size=config.hidden_size,
            intermediate_size=config.moe_intermediate_size,
            renormalize=config.norm_topk_prob,
            quant_config=quant_config,
            prefix=f"{prefix}.experts",
            enable_eplb=self.enable_eplb,
            num_redundant_experts=self.n_redundant_experts,
            is_sequence_parallel=self.is_sequence_parallel,
        )

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        """Apply the MoE block to a 1D ``(hidden,)`` or 2D ``(tokens, hidden)`` input.

        The result has the same number of dimensions as the input.
        """
        assert hidden_states.dim() <= 2, (
            "Qwen3MoeSparseMoeBlock only supports 1D or 2D inputs"
        )
        is_input_1d = hidden_states.dim() == 1
        # Read the hidden size from the last axis and the token count from the
        # flattened view. Unpacking ``shape`` into two names would raise for
        # the 1D input that the assertion above explicitly allows.
        hidden_dim = hidden_states.shape[-1]
        hidden_states = hidden_states.view(-1, hidden_dim)
        num_tokens = hidden_states.shape[0]

        if self.is_sequence_parallel:
            hidden_states = sequence_parallel_chunk(hidden_states)

        if self.experts.is_internal_router:
            # In this case, the gate/router runs inside the FusedMoE class
            final_hidden_states = self.experts(
                hidden_states=hidden_states, router_logits=hidden_states
            )
        else:
            # Actually this will be dead code, since we always pass gate into
            # FusedMoE in the current implementation. But we keep this code
            # here for clarity and future flexibility.
            router_logits, _ = self.gate(hidden_states)
            final_hidden_states = self.experts(
                hidden_states=hidden_states, router_logits=router_logits
            )

        if self.is_sequence_parallel:
            # Gather along dim 0, then keep only the original token rows.
            final_hidden_states = tensor_model_parallel_all_gather(
                final_hidden_states, 0
            )
            final_hidden_states = final_hidden_states[:num_tokens]

        # return to 1d if input is 1d
        return final_hidden_states.squeeze(0) if is_input_1d else final_hidden_states


class Qwen3MoeAttention(nn.Module):
    """Self-attention with a fused QKV projection, per-head RMSNorm on q and k,
    rotary embedding, and an output projection.
    """

    def __init__(
        self,
        hidden_size: int,
        num_heads: int,
        num_kv_heads: int,
        rope_parameters: dict[str, Any],
        max_position_embeddings: int = 8192,
        head_dim: int | None = None,
        rms_norm_eps: float = 1e-06,
        qkv_bias: bool = False,
        cache_config: CacheConfig | None = None,
        quant_config: QuantizationConfig | None = None,
        prefix: str = "",
        dual_chunk_attention_config: dict[str, Any] | None = None,
    ) -> None:
        """Build the projections, rotary embedding, attention and q/k norms.

        Args:
            hidden_size: Model hidden size.
            num_heads: Total number of query heads (before TP partitioning).
            num_kv_heads: Total number of key/value heads.
            rope_parameters: Forwarded to ``get_rope``.
            max_position_embeddings: Forwarded to ``get_rope`` as
                ``max_position``.
            head_dim: Per-head size; falls back to
                ``hidden_size // num_heads`` when falsy.
            rms_norm_eps: Epsilon for the q/k RMSNorm layers.
            qkv_bias: Whether the QKV projection has a bias.
            cache_config: Forwarded to ``Attention``.
            quant_config: Forwarded to the projections and ``Attention``.
            prefix: Parameter-name prefix for the sub-layers.
            dual_chunk_attention_config: When truthy, forwarded to both
                ``get_rope`` and ``Attention`` (the latter also receives the
                layer index parsed from ``prefix``).
        """
        super().__init__()
        self.hidden_size = hidden_size
        tp_size = get_tensor_model_parallel_world_size()
        self.total_num_heads = num_heads
        assert self.total_num_heads % tp_size == 0
        self.num_heads = self.total_num_heads // tp_size
        self.total_num_kv_heads = num_kv_heads
        if self.total_num_kv_heads >= tp_size:
            # Number of KV heads is greater than TP size, so we partition
            # the KV heads across multiple tensor parallel GPUs.
            assert self.total_num_kv_heads % tp_size == 0
        else:
            # Number of KV heads is less than TP size, so we replicate
            # the KV heads across multiple tensor parallel GPUs.
            assert tp_size % self.total_num_kv_heads == 0
        self.num_kv_heads = max(1, self.total_num_kv_heads // tp_size)
        self.head_dim = head_dim or (hidden_size // self.total_num_heads)
        self.q_size = self.num_heads * self.head_dim
        self.kv_size = self.num_kv_heads * self.head_dim
        self.scaling = self.head_dim**-0.5
        self.max_position_embeddings = max_position_embeddings
        self.dual_chunk_attention_config = dual_chunk_attention_config

        self.qkv_proj = QKVParallelLinear(
            hidden_size,
            self.head_dim,
            self.total_num_heads,
            self.total_num_kv_heads,
            bias=qkv_bias,
            quant_config=quant_config,
            prefix=f"{prefix}.qkv_proj",
        )

        self.o_proj = RowParallelLinear(
            self.total_num_heads * self.head_dim,
            hidden_size,
            bias=False,
            quant_config=quant_config,
            prefix=f"{prefix}.o_proj",
        )

        self.rotary_emb = get_rope(
            self.head_dim,
            max_position=max_position_embeddings,
            rope_parameters=rope_parameters,
            dual_chunk_attention_config=dual_chunk_attention_config,
        )

        # The extra kwargs are passed only when dual-chunk attention is
        # configured; the layer index is parsed from the prefix only then.
        extra_attn_kwargs: dict[str, Any] = {}
        if dual_chunk_attention_config:
            extra_attn_kwargs = {
                "layer_idx": extract_layer_index(prefix),
                "dual_chunk_attention_config": dual_chunk_attention_config,
            }
        self.attn = Attention(
            self.num_heads,
            self.head_dim,
            self.scaling,
            num_kv_heads=self.num_kv_heads,
            cache_config=cache_config,
            quant_config=quant_config,
            prefix=f"{prefix}.attn",
            **extra_attn_kwargs,
        )

        self.q_norm = RMSNorm(self.head_dim, eps=rms_norm_eps)
        self.k_norm = RMSNorm(self.head_dim, eps=rms_norm_eps)

    def _norm_per_head(self, x: torch.Tensor, norm: nn.Module) -> torch.Tensor:
        """Apply ``norm`` to ``x`` viewed as ``(..., heads, head_dim)``; restore shape."""
        by_head = x.view(*x.shape[:-1], x.shape[-1] // self.head_dim, self.head_dim)
        return norm(by_head).view(x.shape)

    def forward(
        self,
        positions: torch.Tensor,
        hidden_states: torch.Tensor,
    ) -> torch.Tensor:
        """Compute the attention output for ``hidden_states`` at ``positions``."""
        qkv, _ = self.qkv_proj(hidden_states)
        q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1)
        # Add qk-norm
        q = self._norm_per_head(q, self.q_norm)
        k = self._norm_per_head(k, self.k_norm)
        q, k = self.rotary_emb(positions, q, k)
        attn_output = self.attn(q, k, v)
        output, _ = self.o_proj(attn_output)
        return output


class Qwen3MoeDecoderLayer(nn.Module):
    """One decoder layer: self-attention followed by a sparse MoE block or a
    dense MLP, with RMSNorm before each of the two sub-blocks.
    """

    def __init__(self, vllm_config: VllmConfig, prefix: str = "") -> None:
        """Build the layer.

        Args:
            vllm_config: Source of the HF text config, cache config and
                quantization config.
            prefix: Parameter-name prefix; the layer index is parsed from it.
        """
        super().__init__()

        config = vllm_config.model_config.hf_text_config
        cache_config = vllm_config.cache_config
        quant_config = vllm_config.quant_config

        self.hidden_size = config.hidden_size
        max_position_embeddings = getattr(config, "max_position_embeddings", 8192)
        dual_chunk_attention_config = getattr(
            config, "dual_chunk_attention_config", None
        )
        self.self_attn = Qwen3MoeAttention(
            hidden_size=self.hidden_size,
            num_heads=config.num_attention_heads,
            num_kv_heads=config.num_key_value_heads,
            rope_parameters=config.rope_parameters,
            max_position_embeddings=max_position_embeddings,
            rms_norm_eps=config.rms_norm_eps,
            qkv_bias=getattr(config, "attention_bias", False),
            head_dim=getattr(config, "head_dim", None),
            cache_config=cache_config,
            quant_config=quant_config,
            prefix=f"{prefix}.self_attn",
            dual_chunk_attention_config=dual_chunk_attention_config,
        )

        # A layer uses the sparse MoE block unless it is listed in
        # `mlp_only_layers` in the config; it must also fall on the
        # `decoder_sparse_step` stride and the model must have experts.
        layer_idx = extract_layer_index(prefix)
        mlp_only_layers = getattr(config, "mlp_only_layers", [])
        use_sparse_moe = (
            layer_idx not in mlp_only_layers
            and config.num_experts > 0
            and (layer_idx + 1) % config.decoder_sparse_step == 0
        )
        if use_sparse_moe:
            self.mlp = Qwen3MoeSparseMoeBlock(
                vllm_config=vllm_config, prefix=f"{prefix}.mlp"
            )
        else:
            self.mlp = Qwen3MoeMLP(
                hidden_size=config.hidden_size,
                intermediate_size=config.intermediate_size,
                hidden_act=config.hidden_act,
                quant_config=quant_config,
                prefix=f"{prefix}.mlp",
            )
        self.input_layernorm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
        self.post_attention_layernorm = RMSNorm(
            config.hidden_size, eps=config.rms_norm_eps
        )

    def forward(
        self,
        positions: torch.Tensor,
        hidden_states: torch.Tensor,
        residual: torch.Tensor | None,
    ) -> tuple[torch.Tensor, torch.Tensor]:
        """Run the layer and return ``(hidden_states, residual)``.

        ``residual`` is ``None`` for the first layer on a rank; in that case
        the incoming ``hidden_states`` becomes the residual.
        """
        # Self Attention
        if residual is None:
            residual = hidden_states
            hidden_states = self.input_layernorm(hidden_states)
        else:
            # Called with a residual, the norm returns an updated
            # (hidden_states, residual) pair.
            hidden_states, residual = self.input_layernorm(hidden_states, residual)
        hidden_states = self.self_attn(
            positions=positions,
            hidden_states=hidden_states,
        )

        # Fully Connected
        hidden_states, residual = self.post_attention_layernorm(hidden_states, residual)
        hidden_states = self.mlp(hidden_states)
        return hidden_states, residual


@support_torch_compile
class Qwen3MoeModel(nn.Module, EagleModelMixin):
    """Decoder stack: token embedding, this rank's decoder layers, final norm.

    Also owns checkpoint weight loading for the stack (``load_weights``).
    """

    def __init__(
        self,
        *,
        vllm_config: VllmConfig,
        prefix: str = "",
        decoder_layer_type: type[torch.nn.Module] = Qwen3MoeDecoderLayer,
    ):
        """Build the embedding, decoder layers and final norm.

        Args:
            vllm_config: Source of the HF text config, quantization config and
                parallel (EPLB) config.
            prefix: Parameter-name prefix for the sub-modules.
            decoder_layer_type: Layer class, called as
                ``decoder_layer_type(vllm_config=..., prefix=...)``.
        """
        super().__init__()

        config = vllm_config.model_config.hf_text_config
        quant_config = vllm_config.quant_config
        parallel_config = vllm_config.parallel_config
        eplb_config = parallel_config.eplb_config
        self.num_redundant_experts = eplb_config.num_redundant_experts

        self.vocab_size = config.vocab_size
        self.config = config
        self.quant_config = quant_config
        self.embed_tokens = VocabParallelEmbedding(
            config.vocab_size,
            config.hidden_size,
            quant_config=quant_config,
            prefix=f"{prefix}.embed_tokens",
        )
        self.start_layer, self.end_layer, self.layers = make_layers(
            config.num_hidden_layers,
            lambda prefix: decoder_layer_type(vllm_config=vllm_config, prefix=prefix),
            prefix=f"{prefix}.layers",
        )
        self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
        self.make_empty_intermediate_tensors = make_empty_intermediate_tensors_factory(
            ["hidden_states", "residual"], config.hidden_size
        )

    def embed_input_ids(self, input_ids: torch.Tensor) -> torch.Tensor:
        """Look up the embeddings for ``input_ids``."""
        return self.embed_tokens(input_ids)

    def forward(
        self,
        input_ids: torch.Tensor | None,
        positions: torch.Tensor,
        intermediate_tensors: IntermediateTensors | None = None,
        inputs_embeds: torch.Tensor | None = None,
    ) -> torch.Tensor | IntermediateTensors | tuple[torch.Tensor, list[torch.Tensor]]:
        """Run the layers in ``[start_layer, end_layer)``.

        Returns:
            ``IntermediateTensors`` holding ``hidden_states`` and ``residual``
            when this is not the last pipeline rank. Otherwise the normed
            hidden states, paired with the auxiliary hidden states when any
            were collected.
        """
        if get_pp_group().is_first_rank:
            # Precomputed embeddings take precedence over token ids.
            if inputs_embeds is not None:
                hidden_states = inputs_embeds
            else:
                hidden_states = self.embed_input_ids(input_ids)
            residual = None
        else:
            assert intermediate_tensors is not None
            hidden_states = intermediate_tensors["hidden_states"]
            residual = intermediate_tensors["residual"]

        aux_hidden_states = self._maybe_add_hidden_state(
            [], self.start_layer, hidden_states, residual
        )
        # ``layer_idx`` is the global layer index, not the position in the slice.
        for layer_idx, layer in enumerate(
            islice(self.layers, self.start_layer, self.end_layer),
            start=self.start_layer,
        ):
            hidden_states, residual = layer(positions, hidden_states, residual)
            self._maybe_add_hidden_state(
                aux_hidden_states, layer_idx + 1, hidden_states, residual
            )

        if not get_pp_group().is_last_rank:
            return IntermediateTensors(
                {"hidden_states": hidden_states, "residual": residual}
            )
        hidden_states, _ = self.norm(hidden_states, residual)

        # Return auxiliary hidden states if collected
        if len(aux_hidden_states) > 0:
            return hidden_states, aux_hidden_states
        return hidden_states

    def get_expert_mapping(self) -> list[tuple[str, str, int, str]]:
        """Return the expert parameter mapping used by ``load_weights``."""
        # Params for weights, fp8 weight scales, fp8 activation scales
        # (param_name, weight_name, expert_id, shard_id)
        return fused_moe_make_expert_params_mapping(
            self,
            ckpt_gate_proj_name="gate_proj",
            ckpt_down_proj_name="down_proj",
            ckpt_up_proj_name="up_proj",
            num_experts=self.config.num_experts,
            num_redundant_experts=self.num_redundant_experts,
        )

    def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]:
        """Load checkpoint tensors into this module's parameters.

        Each ``(name, tensor)`` pair is tried in order as: a KV-cache
        quantization scale, a shard of a stacked parameter (``qkv_proj`` /
        ``gate_up_proj``), an expert weight, and finally a plain parameter.

        Args:
            weights: Iterable of ``(checkpoint_name, tensor)`` pairs.

        Returns:
            The set of names recorded as loaded.
        """
        stacked_params_mapping = [
            # (param_name, shard_name, shard_id)
            ("qkv_proj", "q_proj", "q"),
            ("qkv_proj", "k_proj", "k"),
            ("qkv_proj", "v_proj", "v"),
            ("gate_up_proj", "gate_proj", 0),
            ("gate_up_proj", "up_proj", 1),
        ]

        # Skip loading extra parameters for GPTQ/modelopt models.
        ignore_suffixes = (
            ".bias",
            "_bias",
            ".weight_scale",
            "_weight_scale",
            ".input_scale",
            "_input_scale",
        )

        params_dict = dict(self.named_parameters())
        loaded_params: set[str] = set()
        expert_params_mapping = self.get_expert_mapping()
        for name, loaded_weight in weights:
            if self.quant_config is not None and (
                scale_name := self.quant_config.get_cache_scale(name)
            ):
                # Loading kv cache quantization scales
                param = params_dict[scale_name]
                weight_loader = getattr(param, "weight_loader", default_weight_loader)
                assert loaded_weight.numel() == 1, (
                    f"KV scale numel {loaded_weight.numel()} != 1"
                )
                loaded_weight = loaded_weight.squeeze()
                weight_loader(param, loaded_weight)
                loaded_params.add(scale_name)
                continue
            if "scale" in name or "zero_point" in name:
                name = maybe_remap_kv_scale_name(name, params_dict)
                if name is None:
                    continue
            # NOTE(review): this loop rewrites `name` in place, so any
            # `continue` after the rewrite carries the rewritten name into the
            # remaining iterations and into the `else` branch. If the remap
            # below ever returned None, the next `weight_name not in name`
            # would raise TypeError. Confirm this cannot happen before
            # restructuring.
            for param_name, weight_name, shard_id in stacked_params_mapping:
                # Skip non-stacked layers and experts (experts handled below).
                if weight_name not in name:
                    continue
                # We have mlp.experts[0].gate_proj in the checkpoint.
                # Since we handle the experts below in expert_params_mapping,
                # we need to skip here BEFORE we update the name, otherwise
                # name will be updated to mlp.experts[0].gate_up_proj, which
                # will then be updated below in expert_params_mapping
                # for mlp.experts[0].gate_gate_up_proj, which breaks load.
                if "mlp.experts" in name:
                    continue
                name = name.replace(weight_name, param_name)

                # Skip loading extra parameters for GPTQ/modelopt models.
                if name.endswith(ignore_suffixes) and name not in params_dict:
                    continue

                # Skip layers on other devices.
                if is_pp_missing_parameter(name, self):
                    continue
                if name.endswith("scale"):
                    # Remapping the name of FP8 kv-scale.
                    name = maybe_remap_kv_scale_name(name, params_dict)
                    if name is None:
                        continue
                if name not in params_dict:
                    continue

                param = params_dict[name]
                weight_loader = getattr(param, "weight_loader", default_weight_loader)
                # Only custom loaders take the shard id.
                if weight_loader == default_weight_loader:
                    weight_loader(param, loaded_weight)
                else:
                    weight_loader(param, loaded_weight, shard_id)
                break
            else:
                # No stacked mapping loaded the tensor: try the expert mappings.
                is_expert_weight = False
                for mapping in expert_params_mapping:
                    param_name, weight_name, expert_id, shard_id = mapping
                    if weight_name not in name:
                        continue

                    # Anyway, this is an expert weight and should not be
                    # attempted to load as other weights later
                    is_expert_weight = True

                    # Do not modify `name` since the loop may continue here
                    # Instead, create a new variable
                    name_mapped = name.replace(weight_name, param_name)

                    if is_pp_missing_parameter(name_mapped, self):
                        continue

                    # Skip loading extra parameters for GPTQ/modelopt models.
                    if (
                        name_mapped.endswith(ignore_suffixes)
                        and name_mapped not in params_dict
                    ):
                        continue

                    # NOTE(review): direct indexing raises KeyError if the
                    # mapped name is not a parameter of this module.
                    param = params_dict[name_mapped]
                    # We should ask the weight loader to return success or not
                    # here since otherwise we may skip experts with other
                    # available replicas.
                    weight_loader = typing.cast(
                        Callable[..., bool], param.weight_loader
                    )
                    success = weight_loader(
                        param,
                        loaded_weight,
                        name_mapped,
                        shard_id=shard_id,
                        expert_id=expert_id,
                        return_success=True,
                    )
                    if success:
                        name = name_mapped
                        break
                else:
                    if is_expert_weight:
                        # We've checked that this is an expert weight
                        # However it's not mapped locally to this rank
                        # So we simply skip it
                        continue

                    # Skip loading extra parameters for GPTQ/modelopt models.
                    if name.endswith(ignore_suffixes) and name not in params_dict:
                        continue
                    # Skip layers on other devices.
                    if is_pp_missing_parameter(name, self):
                        continue
                    if name not in params_dict:
                        continue
                    param = params_dict[name]
                    weight_loader = getattr(
                        param, "weight_loader", default_weight_loader
                    )
                    weight_loader(param, loaded_weight)
            # Reached after a successful load (stacked, expert or plain).
            loaded_params.add(name)
        return loaded_params


class Qwen3MoeForCausalLM(
    nn.Module, SupportsPP, SupportsLoRA, SupportsEagle, SupportsEagle3, MixtureOfExperts
):
    """Qwen3-MoE causal language model: ``Qwen3MoeModel`` plus LM head, with
    expert metadata collected from the local sparse MoE layers.
    """

    packed_modules_mapping = {
        "qkv_proj": [
            "q_proj",
            "k_proj",
            "v_proj",
        ]
    }

    embedding_modules = {
        "embed_tokens": "input_embeddings",
        "lm_head": "output_embeddings",
    }

    fall_back_to_pt_during_load = False

    def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
        """Build the model, LM head and logits processor; collect MoE metadata.

        Raises:
            RuntimeError: If no local layer uses ``Qwen3MoeSparseMoeBlock``.
        """
        super().__init__()
        config = vllm_config.model_config.hf_text_config
        quant_config = vllm_config.quant_config
        self.config = config
        self.quant_config = quant_config
        # Only perform the following mapping when Qwen3MoeMLP exists
        # NOTE(review): this item assignment mutates the class-level dict in
        # place, so it is visible to every holder of that dict. Confirm
        # whether other code relies on that before copying per instance.
        if getattr(config, "mlp_only_layers", []):
            self.packed_modules_mapping["gate_up_proj"] = ["gate_proj", "up_proj"]
        self.model = Qwen3MoeModel(
            vllm_config=vllm_config, prefix=maybe_prefix(prefix, "model")
        )
        self.lm_head = ParallelLMHead(
            config.vocab_size,
            config.hidden_size,
            quant_config=quant_config,
            prefix=maybe_prefix(prefix, "lm_head"),
        )
        if self.config.tie_word_embeddings:
            self.lm_head.weight = self.model.embed_tokens.weight
        self.logits_processor = LogitsProcessor(config.vocab_size)
        self.make_empty_intermediate_tensors = (
            self.model.make_empty_intermediate_tensors
        )

        # Set MoE hyperparameters
        self.expert_weights = []

        self.moe_layers = []
        example_layer = None
        for layer in self.model.layers:
            if isinstance(layer, PPMissingLayer):
                continue

            assert isinstance(layer, Qwen3MoeDecoderLayer)
            if isinstance(layer.mlp, Qwen3MoeSparseMoeBlock):
                example_layer = layer.mlp
                self.moe_layers.append(layer.mlp.experts)

        if example_layer is None:
            raise RuntimeError("No Qwen3MoE layer found in the model.layers.")

        # The expert counts are copied from the last sparse MoE block found.
        self.num_moe_layers = len(self.moe_layers)
        self.num_expert_groups = 1
        self.num_shared_experts = 0
        self.num_logical_experts = example_layer.n_logical_experts
        self.num_physical_experts = example_layer.n_physical_experts
        self.num_local_physical_experts = example_layer.n_local_physical_experts
        self.num_routed_experts = example_layer.n_routed_experts
        self.num_redundant_experts = example_layer.n_redundant_experts

    def update_physical_experts_metadata(
        self,
        num_physical_experts: int,
        num_local_physical_experts: int,
    ) -> None:
        """Propagate new physical-expert counts to every local sparse MoE block.

        Args:
            num_physical_experts: New total number of physical experts.
            num_local_physical_experts: Per-rank count; must equal the current
                value (asserted).
        """
        assert self.num_local_physical_experts == num_local_physical_experts
        self.num_physical_experts = num_physical_experts
        self.num_local_physical_experts = num_local_physical_experts
        self.num_redundant_experts = num_physical_experts - self.num_logical_experts
        for layer in self.model.layers:
            # Skip PPMissingLayer entries before touching ``layer.mlp``, the
            # same way ``__init__`` does when it walks the layers.
            if isinstance(layer, PPMissingLayer):
                continue
            if isinstance(layer.mlp, Qwen3MoeSparseMoeBlock):
                moe = layer.mlp
                moe.n_local_physical_experts = num_local_physical_experts
                moe.n_physical_experts = num_physical_experts
                moe.n_redundant_experts = self.num_redundant_experts
                moe.experts.update_expert_map()

    def embed_input_ids(self, input_ids: torch.Tensor) -> torch.Tensor:
        """Delegate the embedding lookup to the inner model."""
        return self.model.embed_input_ids(input_ids)

    def forward(
        self,
        input_ids: torch.Tensor | None,
        positions: torch.Tensor,
        intermediate_tensors: IntermediateTensors | None = None,
        inputs_embeds: torch.Tensor | None = None,
    ) -> torch.Tensor | IntermediateTensors:
        """Run the inner model and return its output unchanged."""
        hidden_states = self.model(
            input_ids, positions, intermediate_tensors, inputs_embeds
        )
        return hidden_states

    def compute_logits(
        self,
        hidden_states: torch.Tensor,
    ) -> torch.Tensor | None:
        """Compute logits from ``hidden_states`` via the LM head."""
        logits = self.logits_processor(self.lm_head, hidden_states)
        return logits

    def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]:
        """Load checkpoint weights through ``AutoWeightsLoader``."""
        loader = AutoWeightsLoader(self)
        return loader.load_weights(weights)

    def get_expert_mapping(self) -> list[tuple[str, str, int, str]]:
        """Return the inner model's expert parameter mapping."""
        return self.model.get_expert_mapping()