qwen3_moe.py 29.2 KB
Newer Older
1
# SPDX-License-Identifier: Apache-2.0
2
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24

# Copyright 2024 The Qwen team.
# Copyright 2023 The vLLM team.
# Copyright 2022 EleutherAI and the HuggingFace Inc. team. All rights reserved.
#
# This code is based on EleutherAI's GPT-NeoX library and the GPT-NeoX
# and OPT implementations in this library. It has been modified from its
# original forms to accommodate minor architectural differences compared
# to GPT-NeoX and OPT used by the Meta AI team that trained the model.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Inference-only Qwen3MoE model compatible with HuggingFace weights."""
25

26
27
import typing
from collections.abc import Callable, Iterable
28
from itertools import islice
29
from typing import Any, Optional, Union
30
31
32
33
34
35

import torch
from torch import nn

from vllm.attention import Attention
from vllm.compilation.decorators import support_torch_compile
36
from vllm.config import CacheConfig, VllmConfig, get_current_vllm_config
37
38
39
40
41
42
from vllm.distributed import (
    get_ep_group,
    get_pp_group,
    get_tensor_model_parallel_world_size,
    tensor_model_parallel_all_gather,
)
43
44
45
46
from vllm.logger import init_logger
from vllm.model_executor.layers.activation import SiluAndMul
from vllm.model_executor.layers.fused_moe import FusedMoE
from vllm.model_executor.layers.layernorm import RMSNorm
47
48
49
50
51
52
from vllm.model_executor.layers.linear import (
    MergedColumnParallelLinear,
    QKVParallelLinear,
    ReplicatedLinear,
    RowParallelLinear,
)
53
54
55
56
from vllm.model_executor.layers.logits_processor import LogitsProcessor
from vllm.model_executor.layers.quantization import QuantizationConfig
from vllm.model_executor.layers.rotary_embedding import get_rope
from vllm.model_executor.layers.vocab_parallel_embedding import (
57
58
59
    ParallelLMHead,
    VocabParallelEmbedding,
)
60
from vllm.model_executor.model_loader.weight_utils import (
61
62
63
    default_weight_loader,
    maybe_remap_kv_scale_name,
)
64
from vllm.model_executor.models.utils import sequence_parallel_chunk
65
66
from vllm.sequence import IntermediateTensors

67
from .interfaces import MixtureOfExperts, SupportsEagle3, SupportsLoRA, SupportsPP
68
69
70
71
72
73
74
75
76
from .utils import (
    AutoWeightsLoader,
    PPMissingLayer,
    extract_layer_index,
    is_pp_missing_parameter,
    make_empty_intermediate_tensors_factory,
    make_layers,
    maybe_prefix,
)
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92

# Module-level logger. Used by Qwen3MoeModel.load_weights to warn (once) when
# a checkpoint kv-scale name cannot be remapped to a parameter of this model.
logger = init_logger(__name__)


class Qwen3MoeMLP(nn.Module):
    """Dense gated MLP (SiLU-and-multiply) used for non-MoE Qwen3-MoE layers.

    ``gate_up_proj`` computes the gate and up projections in one fused
    column-parallel matmul, ``SiluAndMul`` combines them, and ``down_proj``
    (row-parallel) projects back to ``hidden_size``.

    Args:
        hidden_size: Model hidden dimension.
        intermediate_size: Width of each of the gate/up projections.
        hidden_act: Activation name; only ``"silu"`` is supported.
        quant_config: Optional quantization config forwarded to the linears.
        reduce_results: Forwarded to ``down_proj``'s ``reduce_results``.
        prefix: Parameter-name prefix for the sub-layers.

    Raises:
        ValueError: If ``hidden_act`` is not ``"silu"``.
    """

    def __init__(
        self,
        hidden_size: int,
        intermediate_size: int,
        hidden_act: str,
        quant_config: Optional[QuantizationConfig] = None,
        reduce_results: bool = True,
        prefix: str = "",
    ) -> None:
        super().__init__()
        # Validate before allocating any (potentially large) projection layers.
        if hidden_act != "silu":
            raise ValueError(
                f"Unsupported activation: {hidden_act}. Only silu is supported for now."
            )
        self.gate_up_proj = MergedColumnParallelLinear(
            hidden_size,
            [intermediate_size] * 2,
            bias=False,
            quant_config=quant_config,
            prefix=f"{prefix}.gate_up_proj",
        )
        self.down_proj = RowParallelLinear(
            intermediate_size,
            hidden_size,
            bias=False,
            quant_config=quant_config,
            reduce_results=reduce_results,
            prefix=f"{prefix}.down_proj",
        )
        self.act_fn = SiluAndMul()

    def forward(self, x):
        """Apply gate/up projection, SiLU-and-mul, then down projection."""
        gate_up, _ = self.gate_up_proj(x)
        x = self.act_fn(gate_up)
        x, _ = self.down_proj(x)
        return x


class Qwen3MoeSparseMoeBlock(nn.Module):
    """Sparse mixture-of-experts feed-forward block for a Qwen3-MoE layer.

    A replicated linear router (``self.gate``) produces one logit per expert
    for every token, and a ``FusedMoE`` layer (``self.experts``) runs the
    top-k routed experts. Expert-parallel load-balancing (EPLB) bookkeeping
    (logical / physical / redundant expert counts) is stored on the block so
    the owning model can read and update it.
    """

    def __init__(
        self,
        vllm_config: VllmConfig,
        prefix: str = "",
    ):
        super().__init__()

        config = vllm_config.model_config.hf_text_config
        parallel_config = vllm_config.parallel_config
        quant_config = vllm_config.quant_config

        self.tp_size = get_tensor_model_parallel_world_size()

        self.ep_group = get_ep_group().device_group
        self.ep_rank = self.ep_group.rank()
        self.ep_size = self.ep_group.size()
        self.n_routed_experts = config.num_experts

        self.is_sequence_parallel = parallel_config.use_sequence_parallel_moe

        if self.tp_size > config.num_experts:
            raise ValueError(
                f"Tensor parallel size {self.tp_size} is greater than "
                f"the number of experts {config.num_experts}."
            )

        # Load balancing settings. Read them from the same parallel config the
        # block was constructed with (as ``enable_eplb`` already is), rather
        # than re-fetching the global "current" config and shadowing the
        # ``vllm_config`` argument.
        eplb_config = parallel_config.eplb_config
        self.enable_eplb = parallel_config.enable_eplb

        self.n_logical_experts = self.n_routed_experts
        self.n_redundant_experts = eplb_config.num_redundant_experts
        self.n_physical_experts = self.n_logical_experts + self.n_redundant_experts
        self.n_local_physical_experts = self.n_physical_experts // self.ep_size

        # Half-open range [start, end) of physical expert ids owned by this
        # expert-parallel rank.
        self.physical_expert_start = self.ep_rank * self.n_local_physical_experts
        self.physical_expert_end = (
            self.physical_expert_start + self.n_local_physical_experts
        )

        self.experts = FusedMoE(
            num_experts=self.n_routed_experts,
            top_k=config.num_experts_per_tok,
            hidden_size=config.hidden_size,
            intermediate_size=config.moe_intermediate_size,
            reduce_results=True,
            renormalize=config.norm_topk_prob,
            quant_config=quant_config,
            prefix=f"{prefix}.experts",
            enable_eplb=self.enable_eplb,
            num_redundant_experts=self.n_redundant_experts,
            is_sequence_parallel=self.is_sequence_parallel,
        )

        self.gate = ReplicatedLinear(
            config.hidden_size,
            config.num_experts,
            bias=False,
            quant_config=quant_config,
            prefix=f"{prefix}.gate",
        )

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        """Route tokens through their top-k experts.

        Args:
            hidden_states: Either a single token vector ``(hidden_dim,)`` or a
                batch ``(num_tokens, hidden_dim)``.

        Returns:
            The experts' output. It is squeezed back to 1D when the input was
            1D. With sequence-parallel MoE, the per-rank outputs are
            all-gathered and trimmed to the original ``num_tokens`` rows.
        """
        assert hidden_states.dim() <= 2, (
            "Qwen3MoeSparseMoeBlock only supports 1D or 2D inputs"
        )
        is_input_1d = hidden_states.dim() == 1
        # Take hidden_dim from the last axis and count tokens only after
        # flattening to 2D. Unpacking ``hidden_states.shape`` into two names
        # raised ValueError for the 1D inputs the assert above admits.
        hidden_dim = hidden_states.shape[-1]
        hidden_states = hidden_states.view(-1, hidden_dim)
        num_tokens = hidden_states.shape[0]

        if self.is_sequence_parallel:
            hidden_states = sequence_parallel_chunk(hidden_states)

        # router_logits: (num_tokens, n_experts)
        router_logits, _ = self.gate(hidden_states)
        final_hidden_states = self.experts(
            hidden_states=hidden_states, router_logits=router_logits
        )

        if self.is_sequence_parallel:
            final_hidden_states = tensor_model_parallel_all_gather(
                final_hidden_states, 0
            )
            # Drop any rows added when chunking the sequence across ranks.
            final_hidden_states = final_hidden_states[:num_tokens]

        # return to 1d if input is 1d
        return final_hidden_states.squeeze(0) if is_input_1d else final_hidden_states
209
210
211
212
213
214
215
216
217


class Qwen3MoeAttention(nn.Module):
    """Self-attention for Qwen3-MoE with per-head RMSNorm on queries and keys.

    Q/K/V come from one fused ``QKVParallelLinear``. Each query head and key
    head is RMS-normalized individually (``q_norm`` / ``k_norm``) before the
    rotary embedding is applied. The attention output is projected back to
    ``hidden_size`` with a ``RowParallelLinear``. Query heads are split evenly
    across tensor-parallel ranks. KV heads are split when there are at least
    as many as ranks, and replicated otherwise.
    """

    def __init__(
        self,
        hidden_size: int,
        num_heads: int,
        num_kv_heads: int,
        rope_theta: float = 10000,
        rope_scaling: Optional[dict[str, Any]] = None,
        max_position_embeddings: int = 8192,
        head_dim: Optional[int] = None,
        rms_norm_eps: float = 1e-06,
        qkv_bias: bool = False,
        cache_config: Optional[CacheConfig] = None,
        quant_config: Optional[QuantizationConfig] = None,
        prefix: str = "",
        dual_chunk_attention_config: Optional[dict[str, Any]] = None,
    ) -> None:
        super().__init__()
        self.hidden_size = hidden_size
        tp_size = get_tensor_model_parallel_world_size()

        self.total_num_heads = num_heads
        assert self.total_num_heads % tp_size == 0
        self.num_heads = self.total_num_heads // tp_size

        self.total_num_kv_heads = num_kv_heads
        if self.total_num_kv_heads >= tp_size:
            # Number of KV heads is greater than TP size, so we partition
            # the KV heads across multiple tensor parallel GPUs.
            assert self.total_num_kv_heads % tp_size == 0
        else:
            # Number of KV heads is less than TP size, so we replicate
            # the KV heads across multiple tensor parallel GPUs.
            assert tp_size % self.total_num_kv_heads == 0
        self.num_kv_heads = max(1, self.total_num_kv_heads // tp_size)

        self.head_dim = head_dim or (hidden_size // self.total_num_heads)
        # Per-rank widths of the q and k/v slices of the fused qkv output.
        self.q_size = self.num_heads * self.head_dim
        self.kv_size = self.num_kv_heads * self.head_dim
        self.scaling = self.head_dim**-0.5
        self.rope_theta = rope_theta
        self.max_position_embeddings = max_position_embeddings
        self.dual_chunk_attention_config = dual_chunk_attention_config

        self.qkv_proj = QKVParallelLinear(
            hidden_size,
            self.head_dim,
            self.total_num_heads,
            self.total_num_kv_heads,
            bias=qkv_bias,
            quant_config=quant_config,
            prefix=f"{prefix}.qkv_proj",
        )

        self.o_proj = RowParallelLinear(
            self.total_num_heads * self.head_dim,
            hidden_size,
            bias=False,
            quant_config=quant_config,
            prefix=f"{prefix}.o_proj",
        )

        self.rotary_emb = get_rope(
            self.head_dim,
            rotary_dim=self.head_dim,
            max_position=max_position_embeddings,
            base=rope_theta,
            rope_scaling=rope_scaling,
            dual_chunk_attention_config=dual_chunk_attention_config,
        )

        # Dual-chunk attention additionally needs the layer index and its
        # config passed to the Attention layer; otherwise nothing extra is
        # passed.
        extra_attn_kwargs = (
            {
                "layer_idx": extract_layer_index(prefix),
                "dual_chunk_attention_config": dual_chunk_attention_config,
            }
            if dual_chunk_attention_config
            else {}
        )
        self.attn = Attention(
            self.num_heads,
            self.head_dim,
            self.scaling,
            num_kv_heads=self.num_kv_heads,
            cache_config=cache_config,
            quant_config=quant_config,
            prefix=f"{prefix}.attn",
            **extra_attn_kwargs,
        )

        self.q_norm = RMSNorm(self.head_dim, eps=rms_norm_eps)
        self.k_norm = RMSNorm(self.head_dim, eps=rms_norm_eps)

    def _norm_per_head(self, x: torch.Tensor, norm) -> torch.Tensor:
        """Apply ``norm`` to each ``head_dim``-sized head of ``x``.

        The last axis of ``x`` is reshaped to ``(n_heads, head_dim)``, the
        norm is applied, and the result is viewed back to ``x``'s shape.
        """
        by_head = x.view(*x.shape[:-1], x.shape[-1] // self.head_dim, self.head_dim)
        return norm(by_head).view(x.shape)

    def forward(
        self,
        positions: torch.Tensor,
        hidden_states: torch.Tensor,
    ) -> torch.Tensor:
        """Project to q/k/v, apply qk-norm and RoPE, attend, then project out."""
        qkv, _ = self.qkv_proj(hidden_states)
        q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1)
        # Add qk-norm
        q = self._norm_per_head(q, self.q_norm)
        k = self._norm_per_head(k, self.k_norm)
        q, k = self.rotary_emb(positions, q, k)
        attn_output = self.attn(q, k, v)
        output, _ = self.o_proj(attn_output)
        return output


class Qwen3MoeDecoderLayer(nn.Module):
    """One Qwen3-MoE transformer layer: self-attention followed by an FFN.

    The FFN is a ``Qwen3MoeSparseMoeBlock`` unless the layer index is listed
    in ``config.mlp_only_layers``, the model has no experts, or the layer
    falls between sparse steps (``(layer_idx + 1) % decoder_sparse_step``).
    In those cases it is a dense ``Qwen3MoeMLP``.
    """

    def __init__(self, vllm_config: VllmConfig, prefix: str = "") -> None:
        super().__init__()

        config = vllm_config.model_config.hf_text_config
        cache_config = vllm_config.cache_config
        quant_config = vllm_config.quant_config

        self.hidden_size = config.hidden_size
        rope_theta = getattr(config, "rope_theta", 10000)
        rope_scaling = getattr(config, "rope_scaling", None)
        max_position_embeddings = getattr(config, "max_position_embeddings", 8192)
        dual_chunk_attention_config = getattr(
            config, "dual_chunk_attention_config", None
        )
        self.self_attn = Qwen3MoeAttention(
            hidden_size=self.hidden_size,
            num_heads=config.num_attention_heads,
            num_kv_heads=config.num_key_value_heads,
            rope_theta=rope_theta,
            rope_scaling=rope_scaling,
            max_position_embeddings=max_position_embeddings,
            rms_norm_eps=config.rms_norm_eps,
            qkv_bias=getattr(config, "attention_bias", False),
            head_dim=getattr(config, "head_dim", None),
            cache_config=cache_config,
            quant_config=quant_config,
            prefix=f"{prefix}.self_attn",
            dual_chunk_attention_config=dual_chunk_attention_config,
        )

        # `mlp_only_layers` in the config. A missing or None value means
        # "no forced-dense layers".
        layer_idx = extract_layer_index(prefix)
        mlp_only_layers = getattr(config, "mlp_only_layers", None) or []
        is_moe_layer = (
            layer_idx not in mlp_only_layers
            and config.num_experts > 0
            and (layer_idx + 1) % config.decoder_sparse_step == 0
        )
        if is_moe_layer:
            self.mlp = Qwen3MoeSparseMoeBlock(
                vllm_config=vllm_config, prefix=f"{prefix}.mlp"
            )
        else:
            self.mlp = Qwen3MoeMLP(
                hidden_size=config.hidden_size,
                intermediate_size=config.intermediate_size,
                hidden_act=config.hidden_act,
                quant_config=quant_config,
                prefix=f"{prefix}.mlp",
            )
        self.input_layernorm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
        self.post_attention_layernorm = RMSNorm(
            config.hidden_size, eps=config.rms_norm_eps
        )

    def forward(
        self,
        positions: torch.Tensor,
        hidden_states: torch.Tensor,
        residual: Optional[torch.Tensor],
    ) -> tuple[torch.Tensor, torch.Tensor]:
        """Run attention and the FFN, returning ``(hidden_states, residual)``.

        When ``residual`` is None (first layer), the input itself becomes the
        residual. Otherwise the layer norms are called with the residual and
        return ``(normed, new_residual)``; presumably this is a fused
        residual-add + norm.
        """
        # Self Attention
        if residual is None:
            residual = hidden_states
            hidden_states = self.input_layernorm(hidden_states)
        else:
            hidden_states, residual = self.input_layernorm(hidden_states, residual)
        hidden_states = self.self_attn(
            positions=positions,
            hidden_states=hidden_states,
        )

        # Fully Connected
        hidden_states, residual = self.post_attention_layernorm(hidden_states, residual)
        hidden_states = self.mlp(hidden_states)
        return hidden_states, residual


@support_torch_compile
class Qwen3MoeModel(nn.Module):
    """Qwen3-MoE decoder stack: embeddings, decoder layers and final norm.

    Supports pipeline parallelism: only layers ``[start_layer, end_layer)``
    are materialized on this rank, and intermediate ranks exchange
    ``hidden_states``/``residual`` through ``IntermediateTensors``. It can also
    collect auxiliary hidden states at selected layers (used by EAGLE3).
    """

    def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
        super().__init__()

        config = vllm_config.model_config.hf_text_config
        quant_config = vllm_config.quant_config
        parallel_config = vllm_config.parallel_config
        eplb_config = parallel_config.eplb_config
        self.num_redundant_experts = eplb_config.num_redundant_experts

        self.padding_idx = config.pad_token_id
        self.vocab_size = config.vocab_size
        self.config = config
        self.embed_tokens = VocabParallelEmbedding(
            config.vocab_size,
            config.hidden_size,
            quant_config=quant_config,
            prefix=f"{prefix}.embed_tokens",
        )
        self.start_layer, self.end_layer, self.layers = make_layers(
            config.num_hidden_layers,
            lambda prefix: Qwen3MoeDecoderLayer(vllm_config=vllm_config, prefix=prefix),
            prefix=f"{prefix}.layers",
        )
        self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
        self.make_empty_intermediate_tensors = make_empty_intermediate_tensors_factory(
            ["hidden_states", "residual"], config.hidden_size
        )
        # Track layers for auxiliary hidden state outputs (EAGLE3)
        self.aux_hidden_state_layers: tuple[int, ...] = ()

    def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor:
        """Look up token embeddings for ``input_ids``."""
        return self.embed_tokens(input_ids)

    def forward(
        self,
        input_ids: torch.Tensor,
        positions: torch.Tensor,
        intermediate_tensors: Optional[IntermediateTensors] = None,
        inputs_embeds: Optional[torch.Tensor] = None,
    ) -> Union[
        torch.Tensor, IntermediateTensors, tuple[torch.Tensor, list[torch.Tensor]]
    ]:
        """Run this rank's slice of the decoder stack.

        Returns:
            On non-last pipeline ranks, an ``IntermediateTensors`` carrying
            ``hidden_states`` and ``residual``. On the last rank, the
            normalized hidden states, or ``(hidden_states, aux_hidden_states)``
            when any layer in ``aux_hidden_state_layers`` was visited.
        """
        if get_pp_group().is_first_rank:
            if inputs_embeds is not None:
                hidden_states = inputs_embeds
            else:
                hidden_states = self.get_input_embeddings(input_ids)
            residual = None
        else:
            assert intermediate_tensors is not None
            hidden_states = intermediate_tensors["hidden_states"]
            residual = intermediate_tensors["residual"]

        aux_hidden_states = []
        for layer_idx, layer in enumerate(
            islice(self.layers, self.start_layer, self.end_layer),
            start=self.start_layer,
        ):
            # Collect auxiliary hidden states if specified. They are taken at
            # the layer input, with the pending residual added back in.
            if layer_idx in self.aux_hidden_state_layers:
                aux_hidden_state = (
                    hidden_states + residual if residual is not None else hidden_states
                )
                aux_hidden_states.append(aux_hidden_state)
            hidden_states, residual = layer(positions, hidden_states, residual)

        if not get_pp_group().is_last_rank:
            return IntermediateTensors(
                {"hidden_states": hidden_states, "residual": residual}
            )
        hidden_states, _ = self.norm(hidden_states, residual)

        # Return auxiliary hidden states if collected
        if aux_hidden_states:
            return hidden_states, aux_hidden_states
        return hidden_states

    def get_expert_mapping(self) -> list[tuple[str, str, int, str]]:
        """Return the checkpoint-to-FusedMoE parameter name mapping."""
        # Params for weights, fp8 weight scales, fp8 activation scales
        # (param_name, weight_name, expert_id, shard_id)
        return FusedMoE.make_expert_params_mapping(
            ckpt_gate_proj_name="gate_proj",
            ckpt_down_proj_name="down_proj",
            ckpt_up_proj_name="up_proj",
            num_experts=self.config.num_experts,
            num_redundant_experts=self.num_redundant_experts,
        )

    def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]:
        """Load checkpoint tensors into this model's parameters.

        Each ``(name, tensor)`` goes through one of three paths:

        1. Stacked q/k/v and gate/up weights, which are loaded into the fused
           ``qkv_proj`` / ``gate_up_proj`` shard.
        2. Routed-expert weights, which are loaded through ``FusedMoE``. An
           expert weight that no local replica accepts is skipped.
        3. Everything else, which is loaded directly. FP8 ``kv_scale`` names
           are remapped under ``.attn``.

        Extra GPTQ/modelopt tensors that have no matching parameter, and
        parameters that live on other pipeline ranks, are skipped.

        Returns:
            The set of parameter names that were loaded.
        """
        stacked_params_mapping = [
            # (param_name, shard_name, shard_id)
            ("qkv_proj", "q_proj", "q"),
            ("qkv_proj", "k_proj", "k"),
            ("qkv_proj", "v_proj", "v"),
            ("gate_up_proj", "gate_proj", 0),
            ("gate_up_proj", "up_proj", 1),
        ]

        # Skip loading extra parameters for GPTQ/modelopt models.
        ignore_suffixes = (
            ".bias",
            "_bias",
            ".k_scale",
            "_k_scale",
            ".v_scale",
            "_v_scale",
            ".weight_scale",
            "_weight_scale",
            ".input_scale",
            "_input_scale",
        )

        params_dict = dict(self.named_parameters())
        loaded_params: set[str] = set()
        expert_params_mapping = self.get_expert_mapping()
        for name, loaded_weight in weights:
            for param_name, weight_name, shard_id in stacked_params_mapping:
                # Skip non-stacked layers and experts (experts handled below).
                if weight_name not in name:
                    continue
                # We have mlp.experts[0].gate_proj in the checkpoint.
                # Since we handle the experts below in expert_params_mapping,
                # we need to skip here BEFORE we update the name, otherwise
                # name will be updated to mlp.experts[0].gate_up_proj, which
                # will then be updated below in expert_params_mapping
                # for mlp.experts[0].gate_gate_up_proj, which breaks load.
                if "mlp.experts" in name:
                    continue
                name = name.replace(weight_name, param_name)

                # Skip loading extra parameters for GPTQ/modelopt models.
                if name.endswith(ignore_suffixes) and name not in params_dict:
                    continue

                # Skip layers on other devices.
                if is_pp_missing_parameter(name, self):
                    continue
                if name.endswith("scale"):
                    # Remapping the name of FP8 kv-scale. Keep `name` intact
                    # when the remap fails: assigning None here made the next
                    # `weight_name not in name` check raise TypeError.
                    remapped_name = maybe_remap_kv_scale_name(name, params_dict)
                    if remapped_name is None:
                        continue
                    name = remapped_name
                if name not in params_dict:
                    continue

                param = params_dict[name]
                weight_loader = getattr(param, "weight_loader", default_weight_loader)
                if weight_loader == default_weight_loader:
                    weight_loader(param, loaded_weight)
                else:
                    weight_loader(param, loaded_weight, shard_id)
                break
            else:
                is_expert_weight = False
                for mapping in expert_params_mapping:
                    param_name, weight_name, expert_id, shard_id = mapping
                    if weight_name not in name:
                        continue

                    # Anyway, this is an expert weight and should not be
                    # attempted to load as other weights later
                    is_expert_weight = True

                    # Do not modify `name` since the loop may continue here
                    # Instead, create a new variable
                    name_mapped = name.replace(weight_name, param_name)

                    if is_pp_missing_parameter(name_mapped, self):
                        continue

                    # Skip loading extra parameters for GPTQ/modelopt models.
                    if (
                        name_mapped.endswith(ignore_suffixes)
                        and name_mapped not in params_dict
                    ):
                        continue

                    param = params_dict[name_mapped]
                    # We should ask the weight loader to return success or not
                    # here since otherwise we may skip experts with other
                    # available replicas.
                    weight_loader = typing.cast(
                        Callable[..., bool], param.weight_loader
                    )
                    success = weight_loader(
                        param,
                        loaded_weight,
                        name_mapped,
                        shard_id=shard_id,
                        expert_id=expert_id,
                        return_success=True,
                    )
                    if success:
                        name = name_mapped
                        break
                else:
                    if is_expert_weight:
                        # We've checked that this is an expert weight
                        # However it's not mapped locally to this rank
                        # So we simply skip it
                        continue

                    # Skip loading extra parameters for GPTQ/modelopt models.
                    if name.endswith(ignore_suffixes) and name not in params_dict:
                        continue
                    # Skip layers on other devices.
                    if is_pp_missing_parameter(name, self):
                        continue
                    # Remapping the name of FP8 kv-scale.
                    if name.endswith("kv_scale"):
                        remapped_kv_scale_name = name.replace(
                            ".kv_scale", ".attn.kv_scale"
                        )
                        if remapped_kv_scale_name not in params_dict:
                            logger.warning_once(
                                "Found kv scale in the checkpoint (e.g. %s), but not found the expected name in the model (e.g. %s). kv-scale is not loaded.",  # noqa: E501
                                name,
                                remapped_kv_scale_name,
                            )
                            continue
                        else:
                            name = remapped_kv_scale_name
                    param = params_dict[name]
                    weight_loader = getattr(
                        param, "weight_loader", default_weight_loader
                    )
                    weight_loader(param, loaded_weight)
            loaded_params.add(name)
        return loaded_params
627
628


629
630
631
class Qwen3MoeForCausalLM(
    nn.Module, SupportsPP, SupportsLoRA, SupportsEagle3, MixtureOfExperts
):
    """Qwen3-MoE causal language model: ``Qwen3MoeModel`` plus an LM head.

    Also implements the ``MixtureOfExperts`` hooks used for expert-parallel
    load balancing (EPLB) and the auxiliary-hidden-state hooks used by EAGLE3.
    """

    packed_modules_mapping = {
        "qkv_proj": [
            "q_proj",
            "k_proj",
            "v_proj",
        ]
    }

    fall_back_to_pt_during_load = False

    def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
        super().__init__()
        config = vllm_config.model_config.hf_text_config
        quant_config = vllm_config.quant_config
        self.config = config
        self.quant_config = quant_config

        # Only perform the following mapping when Qwen3MoeMLP exists. Dense
        # layers appear for indices in `mlp_only_layers`, and also whenever
        # decoder_sparse_step > 1 (e.g. layer 0 then fails the sparse-step
        # test).
        has_dense_mlp = bool(getattr(config, "mlp_only_layers", [])) or (
            getattr(config, "decoder_sparse_step", 1) > 1
        )
        if has_dense_mlp:
            # The value must be a list of shard names, like the "qkv_proj"
            # entry. A trailing comma previously wrapped it in a 1-tuple.
            # NOTE(review): this mutates the class-level dict shared by all
            # instances, as before; kept to avoid changing what other code
            # observes.
            self.packed_modules_mapping["gate_up_proj"] = ["gate_proj", "up_proj"]

        self.model = Qwen3MoeModel(
            vllm_config=vllm_config, prefix=maybe_prefix(prefix, "model")
        )
        self.lm_head = ParallelLMHead(
            config.vocab_size,
            config.hidden_size,
            quant_config=quant_config,
            prefix=maybe_prefix(prefix, "lm_head"),
        )
        if self.config.tie_word_embeddings:
            self.lm_head.weight = self.model.embed_tokens.weight
        self.logits_processor = LogitsProcessor(config.vocab_size)
        self.make_empty_intermediate_tensors = (
            self.model.make_empty_intermediate_tensors
        )

        # Set MoE hyperparameters
        # Filled by set_eplb_state with each MoE layer's expert weights.
        self.expert_weights = []

        self.moe_layers: list[FusedMoE] = []
        example_layer = None
        for layer in self.model.layers:
            if isinstance(layer, PPMissingLayer):
                continue

            assert isinstance(layer, Qwen3MoeDecoderLayer)
            if isinstance(layer.mlp, Qwen3MoeSparseMoeBlock):
                example_layer = layer.mlp
                self.moe_layers.append(layer.mlp.experts)

        if example_layer is None:
            raise RuntimeError("No Qwen3MoE layer found in the model.layers.")

        # Expert counts are read from one MoE block on this rank.
        self.num_moe_layers = len(self.moe_layers)
        self.num_expert_groups = 1
        self.num_shared_experts = 0
        self.num_logical_experts = example_layer.n_logical_experts
        self.num_physical_experts = example_layer.n_physical_experts
        self.num_local_physical_experts = example_layer.n_local_physical_experts
        self.num_routed_experts = example_layer.n_routed_experts
        self.num_redundant_experts = example_layer.n_redundant_experts

    def set_eplb_state(
        self,
        expert_load_view: torch.Tensor,
        logical_to_physical_map: torch.Tensor,
        logical_replica_count: torch.Tensor,
    ) -> None:
        """Register expert weights and hand EPLB state to every local MoE layer."""
        for layer_idx, layer in enumerate(self.moe_layers):
            # Register the expert weights.
            self.expert_weights.append(layer.get_expert_weights())
            layer.set_eplb_state(
                moe_layer_idx=layer_idx,
                expert_load_view=expert_load_view,
                logical_to_physical_map=logical_to_physical_map,
                logical_replica_count=logical_replica_count,
            )

    def update_physical_experts_metadata(
        self,
        num_physical_experts: int,
        num_local_physical_experts: int,
    ) -> None:
        """Update physical/redundant expert counts on the model and its MoE blocks.

        The per-rank local count must not change.
        """
        assert self.num_local_physical_experts == num_local_physical_experts
        self.num_physical_experts = num_physical_experts
        self.num_local_physical_experts = num_local_physical_experts
        self.num_redundant_experts = num_physical_experts - self.num_logical_experts
        for layer in self.model.layers:
            # Layers owned by other pipeline ranks are placeholders with no
            # `.mlp`. Accessing it raised AttributeError under PP.
            if isinstance(layer, PPMissingLayer):
                continue
            if isinstance(layer.mlp, Qwen3MoeSparseMoeBlock):
                moe = layer.mlp
                moe.n_local_physical_experts = num_local_physical_experts
                moe.n_physical_experts = num_physical_experts
                moe.n_redundant_experts = self.num_redundant_experts
                moe.experts.update_expert_map()

    def set_aux_hidden_state_layers(self, layers: tuple[int, ...]) -> None:
        """Select the layer indices whose inputs are returned as aux states."""
        self.model.aux_hidden_state_layers = layers

    def get_eagle3_aux_hidden_state_layers(self) -> tuple[int, ...]:
        """Default EAGLE3 aux layers: an early, a middle and a late layer."""
        num_layers = len(self.model.layers)
        return (2, num_layers // 2, num_layers - 3)

    def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor:
        """Look up token embeddings for ``input_ids``."""
        return self.model.get_input_embeddings(input_ids)

    def forward(
        self,
        input_ids: torch.Tensor,
        positions: torch.Tensor,
        intermediate_tensors: Optional[IntermediateTensors] = None,
        inputs_embeds: Optional[torch.Tensor] = None,
    ) -> Union[
        torch.Tensor, IntermediateTensors, tuple[torch.Tensor, list[torch.Tensor]]
    ]:
        """Run the decoder stack.

        The result is passed through unchanged from ``Qwen3MoeModel``, which
        may include aux hidden states.
        """
        hidden_states = self.model(
            input_ids, positions, intermediate_tensors, inputs_embeds
        )
        return hidden_states

    def compute_logits(
        self,
        hidden_states: torch.Tensor,
    ) -> Optional[torch.Tensor]:
        """Project hidden states to vocabulary logits via the LM head."""
        logits = self.logits_processor(self.lm_head, hidden_states)
        return logits

    def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]:
        """Load checkpoint weights via ``AutoWeightsLoader``."""
        loader = AutoWeightsLoader(self)
        return loader.load_weights(weights)

    def get_expert_mapping(self) -> list[tuple[str, str, int, str]]:
        """Return the inner model's expert parameter mapping."""
        return self.model.get_expert_mapping()