# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project

# Copyright 2024 The Qwen team.
# Copyright 2023 The vLLM team.
# Copyright 2022 EleutherAI and the HuggingFace Inc. team. All rights reserved.
#
# This code is based on EleutherAI's GPT-NeoX library and the GPT-NeoX
# and OPT implementations in this library. It has been modified from its
# original forms to accommodate minor architectural differences compared
# to GPT-NeoX and OPT used by the Meta AI team that trained the model.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Inference-only Qwen3MoE model compatible with HuggingFace weights."""
25

26
27
import typing
from collections.abc import Callable, Iterable
28
from itertools import islice
29
from typing import Any
30
31

import torch
32
import torch.nn.functional as F
33
34
35
from torch import nn

from vllm.compilation.decorators import support_torch_compile
36
from vllm.config import CacheConfig, VllmConfig, get_current_vllm_config
37
38
39
40
41
42
from vllm.distributed import (
    get_ep_group,
    get_pp_group,
    get_tensor_model_parallel_world_size,
    tensor_model_parallel_all_gather,
)
43
44
from vllm.logger import init_logger
from vllm.model_executor.layers.activation import SiluAndMul
45
from vllm.model_executor.layers.attention import Attention
46
from vllm.model_executor.layers.fused_moe import SharedFusedMoE
47
from vllm.model_executor.layers.layernorm import RMSNorm
48
49
50
51
52
53
from vllm.model_executor.layers.linear import (
    MergedColumnParallelLinear,
    QKVParallelLinear,
    ReplicatedLinear,
    RowParallelLinear,
)
54
55
56
57
from vllm.model_executor.layers.logits_processor import LogitsProcessor
from vllm.model_executor.layers.quantization import QuantizationConfig
from vllm.model_executor.layers.rotary_embedding import get_rope
from vllm.model_executor.layers.vocab_parallel_embedding import (
58
59
60
    ParallelLMHead,
    VocabParallelEmbedding,
)
61
from vllm.model_executor.model_loader.weight_utils import (
62
63
64
    default_weight_loader,
    maybe_remap_kv_scale_name,
)
65
from vllm.model_executor.models.utils import sequence_parallel_chunk
66
67
from vllm.sequence import IntermediateTensors

68
69
70
71
72
73
74
75
from .interfaces import (
    EagleModelMixin,
    MixtureOfExperts,
    SupportsEagle,
    SupportsEagle3,
    SupportsLoRA,
    SupportsPP,
)
76
77
78
79
80
81
82
83
84
from .utils import (
    AutoWeightsLoader,
    PPMissingLayer,
    extract_layer_index,
    is_pp_missing_parameter,
    make_empty_intermediate_tensors_factory,
    make_layers,
    maybe_prefix,
)
85
86
87
88
89
90
91
92
93
94

# Module-level logger for this file.
# NOTE(review): not referenced by any code visible in this module; confirm it
# is still needed before relying on it or removing it.
logger = init_logger(__name__)


class Qwen3MoeMLP(nn.Module):
    """Dense gated MLP: fused gate/up projection, ``SiluAndMul``, down projection.

    Used as the MLP of non-sparse decoder layers and as the optional shared
    expert inside ``Qwen3MoeSparseMoeBlock``.

    Args:
        hidden_size: Input and output feature size.
        intermediate_size: Width of each of the two fused projections
            (gate and up).
        hidden_act: Activation name. Only ``"silu"`` is supported.
        quant_config: Optional quantization config for the linear layers.
        reduce_results: Forwarded to ``down_proj``. The shared expert passes
            ``False`` because the enclosing MoE block reduces the combined
            output itself.
        expert_gate: Optional gating module evaluated on the block *input*.
            Element ``[0]`` of its result goes through a sigmoid and scales the
            MLP output, so it must return a sequence whose first item is the
            gate value (e.g. a ``ReplicatedLinear``, which returns a 2-tuple).
        prefix: Parameter-name prefix for the sub-layers.

    Raises:
        ValueError: If ``hidden_act`` is not ``"silu"``.
    """

    def __init__(
        self,
        hidden_size: int,
        intermediate_size: int,
        hidden_act: str,
        quant_config: QuantizationConfig | None = None,
        reduce_results: bool = True,
        expert_gate: nn.Module | None = None,
        prefix: str = "",
    ) -> None:
        super().__init__()
        # Validate before allocating any (possibly large) projection weights.
        if hidden_act != "silu":
            raise ValueError(
                f"Unsupported activation: {hidden_act}. Only silu is supported for now."
            )
        # Gate and up projections are fused into one column-parallel linear.
        self.gate_up_proj = MergedColumnParallelLinear(
            hidden_size,
            [intermediate_size] * 2,
            bias=False,
            quant_config=quant_config,
            prefix=f"{prefix}.gate_up_proj",
        )
        self.down_proj = RowParallelLinear(
            intermediate_size,
            hidden_size,
            bias=False,
            quant_config=quant_config,
            reduce_results=reduce_results,
            prefix=f"{prefix}.down_proj",
        )
        self.act_fn = SiluAndMul()
        self.expert_gate = expert_gate

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Apply the gated MLP to ``x``, optionally scaled by ``expert_gate``."""
        gate_up, _ = self.gate_up_proj(x)
        out, _ = self.down_proj(self.act_fn(gate_up))

        if self.expert_gate is not None:
            # The gate value is computed from the block input, not the output.
            gate_value = self.expert_gate(x)[0]
            out = F.sigmoid(gate_value) * out

        return out


class Qwen3MoeSparseMoeBlock(nn.Module):
    """Routed mixture-of-experts block with an optional sigmoid-gated shared expert.

    Router logits come from a replicated ``gate`` linear. The routed experts,
    and the shared expert when ``shared_expert_intermediate_size > 0``, run
    inside ``SharedFusedMoE``. Their outputs are summed and then either
    all-gathered (sequence-parallel MoE) or passed to
    ``experts.maybe_all_reduce_tensor_model_parallel`` when TP > 1.

    The ``n_*_experts`` attributes are read by ``Qwen3MoeForCausalLM`` for
    expert-parallel load balancing (EPLB) and updated in place by its
    ``update_physical_experts_metadata``.

    Raises:
        ValueError: If the tensor-parallel size exceeds ``config.num_experts``.
    """

    def __init__(
        self,
        vllm_config: VllmConfig,
        prefix: str = "",
    ):
        super().__init__()

        config = vllm_config.model_config.hf_text_config
        parallel_config = vllm_config.parallel_config
        quant_config = vllm_config.quant_config

        self.tp_size = get_tensor_model_parallel_world_size()

        ep_group = get_ep_group()
        self.ep_group = ep_group.device_group
        self.ep_rank = ep_group.rank_in_group
        self.ep_size = self.ep_group.size()
        self.n_routed_experts = config.num_experts

        self.is_sequence_parallel = parallel_config.use_sequence_parallel_moe

        if self.tp_size > config.num_experts:
            raise ValueError(
                f"Tensor parallel size {self.tp_size} is greater than "
                f"the number of experts {config.num_experts}."
            )

        # Load balancing settings. Taken from the config this block was built
        # with (as Qwen3MoeModel does), instead of re-fetching the global
        # config and shadowing the ``vllm_config`` argument.
        eplb_config = parallel_config.eplb_config
        self.enable_eplb = parallel_config.enable_eplb

        self.n_logical_experts = self.n_routed_experts
        self.n_redundant_experts = eplb_config.num_redundant_experts
        self.n_physical_experts = self.n_logical_experts + self.n_redundant_experts
        # NOTE(review): floor division assumes n_physical_experts is divisible
        # by the EP size — confirm this is validated upstream.
        self.n_local_physical_experts = self.n_physical_experts // self.ep_size

        # Contiguous slice of physical experts owned by this EP rank.
        self.physical_expert_start = self.ep_rank * self.n_local_physical_experts
        self.physical_expert_end = (
            self.physical_expert_start + self.n_local_physical_experts
        )

        self.gate = ReplicatedLinear(
            config.hidden_size,
            config.num_experts,
            bias=False,
            quant_config=quant_config,
            prefix=f"{prefix}.gate",
        )

        shared_expert_intermediate_size = getattr(
            config, "shared_expert_intermediate_size", 0
        )
        if shared_expert_intermediate_size > 0:
            # Scalar per-token gate for the shared expert (kept unquantized).
            self.shared_expert_gate = ReplicatedLinear(
                config.hidden_size,
                1,
                bias=False,
                quant_config=None,
                prefix=f"{prefix}.shared_expert_gate",
            )
            # reduce_results=False: the combined output is reduced in forward.
            self.shared_expert = Qwen3MoeMLP(
                hidden_size=config.hidden_size,
                intermediate_size=shared_expert_intermediate_size,
                hidden_act=config.hidden_act,
                quant_config=quant_config,
                reduce_results=False,
                expert_gate=self.shared_expert_gate,
                prefix=f"{prefix}.shared_expert",
            )
        else:
            self.shared_expert_gate = None
            self.shared_expert = None

        self.experts = SharedFusedMoE(
            shared_experts=self.shared_expert,
            gate=self.gate,
            num_experts=self.n_routed_experts,
            top_k=config.num_experts_per_tok,
            hidden_size=config.hidden_size,
            intermediate_size=config.moe_intermediate_size,
            reduce_results=False,
            renormalize=config.norm_topk_prob,
            quant_config=quant_config,
            prefix=f"{prefix}.experts",
            enable_eplb=self.enable_eplb,
            num_redundant_experts=self.n_redundant_experts,
            is_sequence_parallel=self.is_sequence_parallel,
        )

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        """Route ``hidden_states`` through the experts.

        Args:
            hidden_states: A 1-D ``(hidden_dim,)`` or 2-D
                ``(num_tokens, hidden_dim)`` tensor.

        Returns:
            A tensor with the same rank as the input.
        """
        assert hidden_states.dim() <= 2, (
            "Qwen3MoeSparseMoeBlock only supports 1D or 2D inputs"
        )
        is_input_1d = hidden_states.dim() == 1
        hidden_dim = hidden_states.shape[-1]
        # Flatten first and read num_tokens from the 2-D view: unpacking
        # ``hidden_states.shape`` into two names fails for a 1-D input.
        hidden_states = hidden_states.view(-1, hidden_dim)
        num_tokens = hidden_states.shape[0]

        if self.is_sequence_parallel:
            hidden_states = sequence_parallel_chunk(hidden_states)

        # router_logits: (num_tokens, n_experts)
        router_logits, _ = self.gate(hidden_states)
        shared_out, fused_out = self.experts(
            hidden_states=hidden_states, router_logits=router_logits
        )
        final_hidden_states = (
            shared_out + fused_out if shared_out is not None else fused_out
        )

        if self.is_sequence_parallel:
            final_hidden_states = tensor_model_parallel_all_gather(
                final_hidden_states, 0
            )
            # Drop any rows added when chunking the sequence across ranks.
            final_hidden_states = final_hidden_states[:num_tokens]
        elif self.tp_size > 1:
            final_hidden_states = self.experts.maybe_all_reduce_tensor_model_parallel(  # noqa E501
                final_hidden_states
            )

        # return to 1d if input is 1d
        return final_hidden_states.squeeze(0) if is_input_1d else final_hidden_states


class Qwen3MoeAttention(nn.Module):
    """Tensor-parallel self-attention with per-head RMSNorm on Q and K.

    Q/K/V come from one fused ``QKVParallelLinear``. Q and K are normalized
    head by head (``q_norm`` / ``k_norm``) before the rotary embedding, and the
    attention output is projected back to ``hidden_size`` by ``o_proj``.
    """

    def __init__(
        self,
        hidden_size: int,
        num_heads: int,
        num_kv_heads: int,
        rope_parameters: dict[str, Any],
        max_position_embeddings: int = 8192,
        head_dim: int | None = None,
        rms_norm_eps: float = 1e-06,
        qkv_bias: bool = False,
        cache_config: CacheConfig | None = None,
        quant_config: QuantizationConfig | None = None,
        prefix: str = "",
        dual_chunk_attention_config: dict[str, Any] | None = None,
    ) -> None:
        super().__init__()
        self.hidden_size = hidden_size
        tp_size = get_tensor_model_parallel_world_size()
        self.total_num_heads = num_heads
        assert self.total_num_heads % tp_size == 0
        self.num_heads = self.total_num_heads // tp_size
        self.total_num_kv_heads = num_kv_heads
        if self.total_num_kv_heads >= tp_size:
            # Number of KV heads is greater than TP size, so we partition
            # the KV heads across multiple tensor parallel GPUs.
            assert self.total_num_kv_heads % tp_size == 0
        else:
            # Number of KV heads is less than TP size, so we replicate
            # the KV heads across multiple tensor parallel GPUs.
            assert tp_size % self.total_num_kv_heads == 0
        self.num_kv_heads = max(1, self.total_num_kv_heads // tp_size)
        self.head_dim = head_dim or (hidden_size // self.total_num_heads)
        # Per-rank widths of the Q and K/V slices of the fused QKV output.
        self.q_size = self.num_heads * self.head_dim
        self.kv_size = self.num_kv_heads * self.head_dim
        self.scaling = self.head_dim**-0.5
        self.max_position_embeddings = max_position_embeddings
        self.dual_chunk_attention_config = dual_chunk_attention_config

        self.qkv_proj = QKVParallelLinear(
            hidden_size,
            self.head_dim,
            self.total_num_heads,
            self.total_num_kv_heads,
            bias=qkv_bias,
            quant_config=quant_config,
            prefix=f"{prefix}.qkv_proj",
        )

        self.o_proj = RowParallelLinear(
            self.total_num_heads * self.head_dim,
            hidden_size,
            bias=False,
            quant_config=quant_config,
            prefix=f"{prefix}.o_proj",
        )

        self.rotary_emb = get_rope(
            self.head_dim,
            max_position=max_position_embeddings,
            rope_parameters=rope_parameters,
            dual_chunk_attention_config=dual_chunk_attention_config,
        )

        # Dual-chunk attention needs the layer index and its config; other
        # configurations pass no extra keyword arguments to Attention.
        attn_extra_kwargs: dict[str, Any] = (
            {
                "layer_idx": extract_layer_index(prefix),
                "dual_chunk_attention_config": dual_chunk_attention_config,
            }
            if dual_chunk_attention_config
            else {}
        )
        self.attn = Attention(
            self.num_heads,
            self.head_dim,
            self.scaling,
            num_kv_heads=self.num_kv_heads,
            cache_config=cache_config,
            quant_config=quant_config,
            prefix=f"{prefix}.attn",
            **attn_extra_kwargs,
        )

        self.q_norm = RMSNorm(self.head_dim, eps=rms_norm_eps)
        self.k_norm = RMSNorm(self.head_dim, eps=rms_norm_eps)

    def _norm_per_head(self, x: torch.Tensor, norm: nn.Module) -> torch.Tensor:
        """Apply ``norm`` to each ``head_dim``-wide head of ``x``, keeping its shape."""
        by_head = x.view(*x.shape[:-1], x.shape[-1] // self.head_dim, self.head_dim)
        return norm(by_head).view(x.shape)

    def forward(
        self,
        positions: torch.Tensor,
        hidden_states: torch.Tensor,
    ) -> torch.Tensor:
        """Compute attention over ``hidden_states`` at ``positions``."""
        qkv, _ = self.qkv_proj(hidden_states)
        q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1)
        # QK-norm: normalize every Q and K head before the rotary embedding.
        q = self._norm_per_head(q, self.q_norm)
        k = self._norm_per_head(k, self.k_norm)
        q, k = self.rotary_emb(positions, q, k)
        attn_output = self.attn(q, k, v)
        output, _ = self.o_proj(attn_output)
        return output


class Qwen3MoeDecoderLayer(nn.Module):
    """One decoder layer: pre-norm self-attention followed by a pre-norm MLP.

    The MLP is a ``Qwen3MoeSparseMoeBlock`` when all of these hold:
    the layer index is not in ``config.mlp_only_layers``,
    ``config.num_experts > 0``, and ``(layer_idx + 1)`` is a multiple of
    ``config.decoder_sparse_step``. Otherwise it is a dense ``Qwen3MoeMLP``.
    """

    def __init__(self, vllm_config: VllmConfig, prefix: str = "") -> None:
        super().__init__()

        config = vllm_config.model_config.hf_text_config
        cache_config = vllm_config.cache_config
        quant_config = vllm_config.quant_config

        self.hidden_size = config.hidden_size
        max_position_embeddings = getattr(config, "max_position_embeddings", 8192)
        dual_chunk_attention_config = getattr(
            config, "dual_chunk_attention_config", None
        )
        self.self_attn = Qwen3MoeAttention(
            hidden_size=self.hidden_size,
            num_heads=config.num_attention_heads,
            num_kv_heads=config.num_key_value_heads,
            rope_parameters=config.rope_parameters,
            max_position_embeddings=max_position_embeddings,
            rms_norm_eps=config.rms_norm_eps,
            qkv_bias=getattr(config, "attention_bias", False),
            head_dim=getattr(config, "head_dim", None),
            cache_config=cache_config,
            quant_config=quant_config,
            prefix=f"{prefix}.self_attn",
            dual_chunk_attention_config=dual_chunk_attention_config,
        )

        # Layers listed in `mlp_only_layers` always get a dense MLP. A missing
        # or None attribute is treated as an empty list.
        layer_idx = extract_layer_index(prefix)
        mlp_only_layers = getattr(config, "mlp_only_layers", None) or []
        use_sparse_moe = layer_idx not in mlp_only_layers and (
            config.num_experts > 0 and (layer_idx + 1) % config.decoder_sparse_step == 0
        )
        if use_sparse_moe:
            self.mlp = Qwen3MoeSparseMoeBlock(
                vllm_config=vllm_config, prefix=f"{prefix}.mlp"
            )
        else:
            self.mlp = Qwen3MoeMLP(
                hidden_size=config.hidden_size,
                intermediate_size=config.intermediate_size,
                hidden_act=config.hidden_act,
                quant_config=quant_config,
                prefix=f"{prefix}.mlp",
            )
        self.input_layernorm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
        self.post_attention_layernorm = RMSNorm(
            config.hidden_size, eps=config.rms_norm_eps
        )

    def forward(
        self,
        positions: torch.Tensor,
        hidden_states: torch.Tensor,
        residual: torch.Tensor | None,
    ) -> tuple[torch.Tensor, torch.Tensor]:
        """Run attention and MLP.

        Args:
            positions: Token positions forwarded to the attention's rotary
                embedding.
            hidden_states: Output of the previous layer (or the embeddings).
            residual: Running residual from the previous layer, or ``None``
                for the first layer processed by this call.

        Returns:
            ``(hidden_states, residual)`` for the next layer.
        """
        # Self Attention
        if residual is None:
            # No residual yet: the un-normalized input starts the residual.
            residual = hidden_states
            hidden_states = self.input_layernorm(hidden_states)
        else:
            # The two-argument RMSNorm call returns both the normalized
            # activations and the updated residual.
            hidden_states, residual = self.input_layernorm(hidden_states, residual)
        hidden_states = self.self_attn(
            positions=positions,
            hidden_states=hidden_states,
        )

        # Fully Connected
        hidden_states, residual = self.post_attention_layernorm(hidden_states, residual)
        hidden_states = self.mlp(hidden_states)
        return hidden_states, residual


@support_torch_compile
class Qwen3MoeModel(nn.Module, EagleModelMixin):
    """Qwen3-MoE decoder stack: token embeddings, decoder layers, final norm.

    Under pipeline parallelism only layers ``[start_layer, end_layer)`` run
    on this rank, and non-last ranks return ``IntermediateTensors``. Auxiliary
    hidden states can be collected per layer via
    ``EagleModelMixin._maybe_add_hidden_state``.

    Args:
        vllm_config: Global vLLM configuration.
        prefix: Parameter-name prefix for this module.
        decoder_layer_type: Class used to build each decoder layer; called as
            ``decoder_layer_type(vllm_config=..., prefix=...)``.
    """

    def __init__(
        self,
        *,
        vllm_config: VllmConfig,
        prefix: str = "",
        decoder_layer_type: type[torch.nn.Module] = Qwen3MoeDecoderLayer,
    ):
        super().__init__()

        config = vllm_config.model_config.hf_text_config
        quant_config = vllm_config.quant_config
        parallel_config = vllm_config.parallel_config
        eplb_config = parallel_config.eplb_config
        self.num_redundant_experts = eplb_config.num_redundant_experts

        self.vocab_size = config.vocab_size
        self.config = config
        self.quant_config = quant_config
        self.embed_tokens = VocabParallelEmbedding(
            config.vocab_size,
            config.hidden_size,
            quant_config=quant_config,
            prefix=f"{prefix}.embed_tokens",
        )
        self.start_layer, self.end_layer, self.layers = make_layers(
            config.num_hidden_layers,
            lambda prefix: decoder_layer_type(vllm_config=vllm_config, prefix=prefix),
            prefix=f"{prefix}.layers",
        )
        self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
        self.make_empty_intermediate_tensors = make_empty_intermediate_tensors_factory(
            ["hidden_states", "residual"], config.hidden_size
        )

    def embed_input_ids(self, input_ids: torch.Tensor) -> torch.Tensor:
        """Look up token embeddings for ``input_ids``."""
        return self.embed_tokens(input_ids)

    def forward(
        self,
        input_ids: torch.Tensor | None,
        positions: torch.Tensor,
        intermediate_tensors: IntermediateTensors | None = None,
        inputs_embeds: torch.Tensor | None = None,
    ) -> torch.Tensor | IntermediateTensors | tuple[torch.Tensor, list[torch.Tensor]]:
        """Run this rank's slice of the decoder stack.

        The first pipeline rank starts from ``inputs_embeds`` (if given) or
        the embeddings of ``input_ids``. Later ranks resume from
        ``intermediate_tensors``.

        Returns:
            ``IntermediateTensors`` on non-last pipeline ranks. On the last
            rank, the final-normed hidden states, or
            ``(hidden_states, aux_hidden_states)`` if any auxiliary hidden
            states were collected.
        """
        if get_pp_group().is_first_rank:
            if inputs_embeds is not None:
                hidden_states = inputs_embeds
            else:
                hidden_states = self.embed_input_ids(input_ids)
            residual = None
        else:
            assert intermediate_tensors is not None
            hidden_states = intermediate_tensors["hidden_states"]
            residual = intermediate_tensors["residual"]

        aux_hidden_states = self._maybe_add_hidden_state(
            [], self.start_layer, hidden_states, residual
        )
        for layer_idx, layer in enumerate(
            islice(self.layers, self.start_layer, self.end_layer),
            start=self.start_layer,
        ):
            hidden_states, residual = layer(positions, hidden_states, residual)
            self._maybe_add_hidden_state(
                aux_hidden_states, layer_idx + 1, hidden_states, residual
            )

        if not get_pp_group().is_last_rank:
            return IntermediateTensors(
                {"hidden_states": hidden_states, "residual": residual}
            )
        hidden_states, _ = self.norm(hidden_states, residual)

        # Return auxiliary hidden states if collected
        if len(aux_hidden_states) > 0:
            return hidden_states, aux_hidden_states
        return hidden_states

    def get_expert_mapping(self) -> list[tuple[str, str, int, str]]:
        """Return ``(param_name, weight_name, expert_id, shard_id)`` tuples.

        The tuples map per-expert checkpoint tensors (gate/up/down projections
        and their scales) to the fused-MoE parameters.
        """
        # Params for weights, fp8 weight scales, fp8 activation scales
        # (param_name, weight_name, expert_id, shard_id)
        return SharedFusedMoE.make_expert_params_mapping(
            self,
            ckpt_gate_proj_name="gate_proj",
            ckpt_down_proj_name="down_proj",
            ckpt_up_proj_name="up_proj",
            num_experts=self.config.num_experts,
            num_redundant_experts=self.num_redundant_experts,
        )

    def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]:
        """Load checkpoint tensors into the parameters present on this rank.

        Each tensor is handled by the first matching case:

        1. KV-cache quantization scales reported by
           ``quant_config.get_cache_scale``.
        2. Stacked projections: q/k/v into ``qkv_proj`` and gate/up into
           ``gate_up_proj`` (routed-expert tensors are excluded here).
        3. Per-expert tensors, via ``get_expert_mapping``.
        4. Everything else, by exact name.

        Tensors for parameters not owned by this rank are skipped. That
        covers pipeline-parallel layers elsewhere and experts mapped to other
        ranks.

        Returns:
            The set of parameter names that were loaded.
        """
        stacked_params_mapping = [
            # (param_name, shard_name, shard_id)
            ("qkv_proj", "q_proj", "q"),
            ("qkv_proj", "k_proj", "k"),
            ("qkv_proj", "v_proj", "v"),
            ("gate_up_proj", "gate_proj", 0),
            ("gate_up_proj", "up_proj", 1),
        ]

        # Skip loading extra parameters for GPTQ/modelopt models.
        ignore_suffixes = (
            ".bias",
            "_bias",
            ".weight_scale",
            "_weight_scale",
            ".input_scale",
            "_input_scale",
        )

        params_dict = dict(self.named_parameters())
        loaded_params: set[str] = set()
        expert_params_mapping = self.get_expert_mapping()
        for name, loaded_weight in weights:
            if self.quant_config is not None and (
                scale_name := self.quant_config.get_cache_scale(name)
            ):
                # Loading kv cache quantization scales
                param = params_dict[scale_name]
                weight_loader = getattr(param, "weight_loader", default_weight_loader)
                assert loaded_weight.numel() == 1, (
                    f"KV scale numel {loaded_weight.numel()} != 1"
                )
                loaded_weight = loaded_weight.squeeze()
                weight_loader(param, loaded_weight)
                loaded_params.add(scale_name)
                continue
            if "scale" in name or "zero_point" in name:
                name = maybe_remap_kv_scale_name(name, params_dict)
                if name is None:
                    continue
            for param_name, weight_name, shard_id in stacked_params_mapping:
                # Skip non-stacked layers and experts (experts handled below).
                if weight_name not in name:
                    continue
                # We have mlp.experts[0].gate_proj in the checkpoint.
                # Since we handle the experts below in expert_params_mapping,
                # we need to skip here BEFORE we update the name, otherwise
                # name will be updated to mlp.experts[0].gate_up_proj, which
                # will then be updated below in expert_params_mapping
                # for mlp.experts[0].gate_gate_up_proj, which breaks load.
                if "mlp.experts" in name:
                    continue
                name = name.replace(weight_name, param_name)

                # Skip loading extra parameters for GPTQ/modelopt models.
                if name.endswith(ignore_suffixes) and name not in params_dict:
                    continue

                # Skip layers on other devices.
                if is_pp_missing_parameter(name, self):
                    continue
                if name.endswith("scale"):
                    # Remapping the name of FP8 kv-scale.
                    name = maybe_remap_kv_scale_name(name, params_dict)
                    if name is None:
                        # Unmappable scale: drop this tensor. Break rather
                        # than continue, because continuing would run
                        # ``weight_name not in None`` (TypeError) on the
                        # next mapping or in the expert branch below.
                        break
                if name not in params_dict:
                    continue

                param = params_dict[name]
                weight_loader = getattr(param, "weight_loader", default_weight_loader)
                if weight_loader == default_weight_loader:
                    weight_loader(param, loaded_weight)
                else:
                    weight_loader(param, loaded_weight, shard_id)
                break
            else:
                is_expert_weight = False
                for mapping in expert_params_mapping:
                    param_name, weight_name, expert_id, shard_id = mapping
                    if weight_name not in name:
                        continue

                    # Anyway, this is an expert weight and should not be
                    # attempted to load as other weights later
                    is_expert_weight = True

                    # Do not modify `name` since the loop may continue here
                    # Instead, create a new variable
                    name_mapped = name.replace(weight_name, param_name)

                    if is_pp_missing_parameter(name_mapped, self):
                        continue

                    # Skip loading extra parameters for GPTQ/modelopt models.
                    if (
                        name_mapped.endswith(ignore_suffixes)
                        and name_mapped not in params_dict
                    ):
                        continue

                    param = params_dict[name_mapped]
                    # We should ask the weight loader to return success or not
                    # here since otherwise we may skip experts with other
                    # available replicas.
                    weight_loader = typing.cast(
                        Callable[..., bool], param.weight_loader
                    )
                    success = weight_loader(
                        param,
                        loaded_weight,
                        name_mapped,
                        shard_id=shard_id,
                        expert_id=expert_id,
                        return_success=True,
                    )
                    if success:
                        name = name_mapped
                        break
                else:
                    if is_expert_weight:
                        # We've checked that this is an expert weight
                        # However it's not mapped locally to this rank
                        # So we simply skip it
                        continue

                    # Skip loading extra parameters for GPTQ/modelopt models.
                    if name.endswith(ignore_suffixes) and name not in params_dict:
                        continue
                    # Skip layers on other devices.
                    if is_pp_missing_parameter(name, self):
                        continue
                    if name not in params_dict:
                        continue
                    param = params_dict[name]
                    weight_loader = getattr(
                        param, "weight_loader", default_weight_loader
                    )
                    weight_loader(param, loaded_weight)
            if name is None:
                # Reached only via the unmappable-scale ``break`` above.
                continue
            loaded_params.add(name)
        return loaded_params


class Qwen3MoeForCausalLM(
    nn.Module, SupportsPP, SupportsLoRA, SupportsEagle, SupportsEagle3, MixtureOfExperts
):
    """Qwen3-MoE causal language model: ``Qwen3MoeModel`` plus an LM head.

    Also exposes ``MixtureOfExperts`` metadata: expert counts, plus the list
    of ``SharedFusedMoE`` layers found among this rank's sparse MoE blocks.

    Raises:
        RuntimeError: From ``__init__``, if this rank holds no
            ``Qwen3MoeSparseMoeBlock``.
    """

    packed_modules_mapping = {
        "qkv_proj": [
            "q_proj",
            "k_proj",
            "v_proj",
        ]
    }

    embedding_modules = {
        "embed_tokens": "input_embeddings",
        "lm_head": "output_embeddings",
    }

    fall_back_to_pt_during_load = False

    def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
        super().__init__()
        config = vllm_config.model_config.hf_text_config
        quant_config = vllm_config.quant_config
        self.config = config
        self.quant_config = quant_config
        # Only perform the following mapping when Qwen3MoeMLP exists
        # NOTE(review): this writes into the *class-level* dict, so the entry
        # persists for every instance of this class; confirm that is intended
        # before switching to a per-instance copy. Qwen3MoeMLP is also built
        # for shared experts and for layers skipped by `decoder_sparse_step`,
        # which this condition does not check — verify whether those need the
        # mapping too.
        if getattr(config, "mlp_only_layers", []):
            self.packed_modules_mapping["gate_up_proj"] = ["gate_proj", "up_proj"]
        self.model = Qwen3MoeModel(
            vllm_config=vllm_config, prefix=maybe_prefix(prefix, "model")
        )
        self.lm_head = ParallelLMHead(
            config.vocab_size,
            config.hidden_size,
            quant_config=quant_config,
            prefix=maybe_prefix(prefix, "lm_head"),
        )
        if self.config.tie_word_embeddings:
            self.lm_head.weight = self.model.embed_tokens.weight
        self.logits_processor = LogitsProcessor(config.vocab_size)
        self.make_empty_intermediate_tensors = (
            self.model.make_empty_intermediate_tensors
        )

        # Set MoE hyperparameters
        self.expert_weights = []

        self.moe_layers = []
        example_layer = None
        for layer in self.model.layers:
            # Layers owned by other pipeline stages are placeholders.
            if isinstance(layer, PPMissingLayer):
                continue

            assert isinstance(layer, Qwen3MoeDecoderLayer)
            if isinstance(layer.mlp, Qwen3MoeSparseMoeBlock):
                example_layer = layer.mlp
                self.moe_layers.append(layer.mlp.experts)

        if example_layer is None:
            raise RuntimeError("No Qwen3MoE layer found in the model.layers.")

        self.num_moe_layers = len(self.moe_layers)
        self.num_expert_groups = 1
        # NOTE(review): a configured shared expert is not counted here; confirm
        # whether MixtureOfExperts consumers expect it to be.
        self.num_shared_experts = 0
        self.num_logical_experts = example_layer.n_logical_experts
        self.num_physical_experts = example_layer.n_physical_experts
        self.num_local_physical_experts = example_layer.n_local_physical_experts
        self.num_routed_experts = example_layer.n_routed_experts
        self.num_redundant_experts = example_layer.n_redundant_experts

    def update_physical_experts_metadata(
        self,
        num_physical_experts: int,
        num_local_physical_experts: int,
    ) -> None:
        """Propagate new physical-expert counts to this model and its MoE blocks.

        The number of local physical experts per rank must stay the same. The
        redundant-expert count is recomputed from the new physical total.
        """
        assert self.num_local_physical_experts == num_local_physical_experts
        self.num_physical_experts = num_physical_experts
        self.num_local_physical_experts = num_local_physical_experts
        self.num_redundant_experts = num_physical_experts - self.num_logical_experts
        for layer in self.model.layers:
            # PPMissingLayer placeholders (other pipeline stages) have no
            # `mlp` attribute; skip them exactly as __init__ does.
            if isinstance(layer, PPMissingLayer):
                continue
            if isinstance(layer.mlp, Qwen3MoeSparseMoeBlock):
                moe = layer.mlp
                moe.n_local_physical_experts = num_local_physical_experts
                moe.n_physical_experts = num_physical_experts
                moe.n_redundant_experts = self.num_redundant_experts
                moe.experts.update_expert_map()

    def embed_input_ids(self, input_ids: torch.Tensor) -> torch.Tensor:
        """Look up token embeddings for ``input_ids``."""
        return self.model.embed_input_ids(input_ids)

    def forward(
        self,
        input_ids: torch.Tensor | None,
        positions: torch.Tensor,
        intermediate_tensors: IntermediateTensors | None = None,
        inputs_embeds: torch.Tensor | None = None,
    ) -> torch.Tensor | IntermediateTensors | tuple[torch.Tensor, list[torch.Tensor]]:
        """Run the decoder stack and return whatever ``Qwen3MoeModel`` returns."""
        hidden_states = self.model(
            input_ids, positions, intermediate_tensors, inputs_embeds
        )
        return hidden_states

    def compute_logits(
        self,
        hidden_states: torch.Tensor,
    ) -> torch.Tensor | None:
        """Project ``hidden_states`` to vocabulary logits via ``lm_head``."""
        logits = self.logits_processor(self.lm_head, hidden_states)
        return logits

    def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]:
        """Load checkpoint tensors through ``AutoWeightsLoader``."""
        loader = AutoWeightsLoader(self)
        return loader.load_weights(weights)

    def get_expert_mapping(self) -> list[tuple[str, str, int, str]]:
        """Return the per-expert checkpoint-to-parameter mapping of the model."""
        return self.model.get_expert_mapping()