# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project

# Copyright 2024 The Qwen team.
# Copyright 2023 The vLLM team.
# Copyright 2022 EleutherAI and the HuggingFace Inc. team. All rights reserved.
#
# This code is based on EleutherAI's GPT-NeoX library and the GPT-NeoX
# and OPT implementations in this library. It has been modified from its
# original forms to accommodate minor architectural differences compared
# to GPT-NeoX and OPT used by the Meta AI team that trained the model.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Inference-only Qwen3MoE model compatible with HuggingFace weights."""

import typing
from collections.abc import Callable, Iterable
from itertools import islice
from typing import Any

import torch
from torch import nn

from vllm.attention import Attention
from vllm.compilation.decorators import support_torch_compile
from vllm.config import CacheConfig, VllmConfig, get_current_vllm_config
from vllm.distributed import (
    get_ep_group,
    get_pp_group,
    get_tensor_model_parallel_world_size,
    tensor_model_parallel_all_gather,
)
from vllm.logger import init_logger
from vllm.model_executor.layers.activation import SiluAndMul
from vllm.model_executor.layers.fused_moe import FusedMoE
from vllm.model_executor.layers.fused_moe.config import RoutingMethodType
from vllm.model_executor.layers.layernorm import RMSNorm
from vllm.model_executor.layers.linear import (
    MergedColumnParallelLinear,
    QKVParallelLinear,
    ReplicatedLinear,
    RowParallelLinear,
)
from vllm.model_executor.layers.logits_processor import LogitsProcessor
from vllm.model_executor.layers.quantization import QuantizationConfig
from vllm.model_executor.layers.rotary_embedding import get_rope
from vllm.model_executor.layers.vocab_parallel_embedding import (
    ParallelLMHead,
    VocabParallelEmbedding,
)
from vllm.model_executor.model_loader.weight_utils import (
    default_weight_loader,
    maybe_remap_kv_scale_name,
)
from vllm.model_executor.models.utils import sequence_parallel_chunk
from vllm.sequence import IntermediateTensors

from .interfaces import MixtureOfExperts, SupportsEagle3, SupportsLoRA, SupportsPP
from .utils import (
    AutoWeightsLoader,
    PPMissingLayer,
    extract_layer_index,
    is_pp_missing_parameter,
    make_empty_intermediate_tensors_factory,
    make_layers,
    maybe_prefix,
)

# Module-level logger; used during weight loading to emit one-time warnings.
logger = init_logger(__name__)


class Qwen3MoeMLP(nn.Module):
    """Dense gated feed-forward block: fused gate/up projection, SiLU-and-mul,
    then a down projection.

    Used by decoder layers that do not route through the sparse MoE block.
    """

    def __init__(
        self,
        hidden_size: int,
        intermediate_size: int,
        hidden_act: str,
        quant_config: QuantizationConfig | None = None,
        reduce_results: bool = True,
        prefix: str = "",
    ) -> None:
        """
        Args:
            hidden_size: Input and output width of the block.
            intermediate_size: Width of each of the gate and up projections.
            hidden_act: Activation name; only ``"silu"`` is supported.
            quant_config: Optional quantization config forwarded to both
                linear layers.
            reduce_results: Forwarded to the down projection's
                ``RowParallelLinear``.
            prefix: Parameter-name prefix for this module.

        Raises:
            ValueError: If ``hidden_act`` is not ``"silu"``.
        """
        super().__init__()
        # Validate the activation before allocating any (possibly large)
        # projection weights, so an unsupported config fails fast.
        if hidden_act != "silu":
            raise ValueError(
                f"Unsupported activation: {hidden_act}. Only silu is supported for now."
            )
        self.gate_up_proj = MergedColumnParallelLinear(
            hidden_size,
            [intermediate_size] * 2,
            bias=False,
            quant_config=quant_config,
            prefix=f"{prefix}.gate_up_proj",
        )
        self.down_proj = RowParallelLinear(
            intermediate_size,
            hidden_size,
            bias=False,
            quant_config=quant_config,
            reduce_results=reduce_results,
            prefix=f"{prefix}.down_proj",
        )
        self.act_fn = SiluAndMul()

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Apply gate/up projection, SiLU-and-mul, then down projection."""
        gate_up, _ = self.gate_up_proj(x)
        x = self.act_fn(gate_up)
        x, _ = self.down_proj(x)
        return x


class Qwen3MoeSparseMoeBlock(nn.Module):
    """Top-k routed mixture-of-experts feed-forward block.

    A replicated linear gate produces per-expert router logits, which
    ``FusedMoE`` uses to dispatch tokens to the routed experts. The block also
    records the expert-parallel / EPLB bookkeeping (logical, redundant,
    physical and per-rank expert counts) that ``Qwen3MoeForCausalLM`` reads.
    """

    def __init__(
        self,
        vllm_config: VllmConfig,
        prefix: str = "",
    ):
        """
        Args:
            vllm_config: Engine config; the model, parallel, EPLB and
                quantization settings are all read from it.
            prefix: Parameter-name prefix for this module.

        Raises:
            ValueError: If the tensor-parallel size exceeds the number of
                experts.
        """
        super().__init__()

        config = vllm_config.model_config.hf_text_config
        parallel_config = vllm_config.parallel_config
        quant_config = vllm_config.quant_config

        self.tp_size = get_tensor_model_parallel_world_size()

        self.ep_group = get_ep_group().device_group
        self.ep_rank = get_ep_group().rank_in_group
        self.ep_size = self.ep_group.size()
        self.n_routed_experts = config.num_experts

        self.is_sequence_parallel = parallel_config.use_sequence_parallel_moe

        if self.tp_size > config.num_experts:
            raise ValueError(
                f"Tensor parallel size {self.tp_size} is greater than "
                f"the number of experts {config.num_experts}."
            )

        # Load balancing settings. Read them from the same parallel config
        # used for everything else here (and by Qwen3MoeModel, which sizes its
        # expert weight mapping from it) rather than from the ambient
        # "current" config, so the redundant-expert counts always agree.
        eplb_config = parallel_config.eplb_config
        self.enable_eplb = parallel_config.enable_eplb

        self.n_logical_experts = self.n_routed_experts
        self.n_redundant_experts = eplb_config.num_redundant_experts
        self.n_physical_experts = self.n_logical_experts + self.n_redundant_experts
        self.n_local_physical_experts = self.n_physical_experts // self.ep_size

        # Contiguous range of physical expert ids owned by this EP rank.
        self.physical_expert_start = self.ep_rank * self.n_local_physical_experts
        self.physical_expert_end = (
            self.physical_expert_start + self.n_local_physical_experts
        )

        # NOTE(review): routing_method_type is always Renormalize even though
        # `renormalize` follows config.norm_topk_prob — confirm this is
        # intended for checkpoints with norm_topk_prob=False.
        self.experts = FusedMoE(
            num_experts=self.n_routed_experts,
            top_k=config.num_experts_per_tok,
            hidden_size=config.hidden_size,
            intermediate_size=config.moe_intermediate_size,
            reduce_results=True,
            renormalize=config.norm_topk_prob,
            quant_config=quant_config,
            prefix=f"{prefix}.experts",
            enable_eplb=self.enable_eplb,
            num_redundant_experts=self.n_redundant_experts,
            is_sequence_parallel=self.is_sequence_parallel,
            routing_method_type=RoutingMethodType.Renormalize,
        )

        self.gate = ReplicatedLinear(
            config.hidden_size,
            config.num_experts,
            bias=False,
            quant_config=quant_config,
            prefix=f"{prefix}.gate",
        )

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        """Route tokens through the experts.

        Args:
            hidden_states: A 1D ``(hidden_dim,)`` single token or a 2D
                ``(num_tokens, hidden_dim)`` batch.

        Returns:
            Expert output with the same rank as the input.
        """
        assert hidden_states.dim() <= 2, (
            "Qwen3MoeSparseMoeBlock only supports 1D or 2D inputs"
        )
        is_input_1d = hidden_states.dim() == 1
        # Take hidden_dim from the last axis: unpacking `.shape` into two
        # names would fail for the 1D case this block claims to support.
        hidden_dim = hidden_states.shape[-1]
        hidden_states = hidden_states.view(-1, hidden_dim)
        # Token count before any sequence-parallel chunking; used to trim the
        # all-gathered output back to this rank's original length.
        num_tokens = hidden_states.shape[0]

        if self.is_sequence_parallel:
            hidden_states = sequence_parallel_chunk(hidden_states)

        # router_logits: (num_tokens, n_experts)
        router_logits, _ = self.gate(hidden_states)
        final_hidden_states = self.experts(
            hidden_states=hidden_states, router_logits=router_logits
        )

        if self.is_sequence_parallel:
            final_hidden_states = tensor_model_parallel_all_gather(
                final_hidden_states, 0
            )
            final_hidden_states = final_hidden_states[:num_tokens]

        # return to 1d if input is 1d
        return final_hidden_states.squeeze(0) if is_input_1d else final_hidden_states


class Qwen3MoeAttention(nn.Module):
    """Self-attention with separate query and key/value head counts.

    Q, K and V come from a single fused tensor-parallel projection. Each query
    and key head is RMS-normalized on its own before the rotary embedding is
    applied, and the attention output passes through a row-parallel output
    projection.
    """

    def __init__(
        self,
        hidden_size: int,
        num_heads: int,
        num_kv_heads: int,
        rope_parameters: dict[str, Any],
        max_position_embeddings: int = 8192,
        head_dim: int | None = None,
        rms_norm_eps: float = 1e-06,
        qkv_bias: bool = False,
        cache_config: CacheConfig | None = None,
        quant_config: QuantizationConfig | None = None,
        prefix: str = "",
        dual_chunk_attention_config: dict[str, Any] | None = None,
    ) -> None:
        super().__init__()
        tp_world = get_tensor_model_parallel_world_size()
        self.hidden_size = hidden_size

        # Query heads are always split evenly across the TP ranks.
        self.total_num_heads = num_heads
        assert num_heads % tp_world == 0
        self.num_heads = num_heads // tp_world

        # KV heads are partitioned when there are at least as many of them as
        # TP ranks, and replicated otherwise; either way the counts must divide.
        self.total_num_kv_heads = num_kv_heads
        assert (
            (num_kv_heads % tp_world == 0)
            if num_kv_heads >= tp_world
            else (tp_world % num_kv_heads == 0)
        )
        self.num_kv_heads = max(1, num_kv_heads // tp_world)

        self.head_dim = head_dim if head_dim else hidden_size // num_heads
        self.q_size = self.num_heads * self.head_dim
        self.kv_size = self.num_kv_heads * self.head_dim
        self.scaling = self.head_dim**-0.5
        self.max_position_embeddings = max_position_embeddings
        self.dual_chunk_attention_config = dual_chunk_attention_config

        self.qkv_proj = QKVParallelLinear(
            hidden_size,
            self.head_dim,
            self.total_num_heads,
            self.total_num_kv_heads,
            bias=qkv_bias,
            quant_config=quant_config,
            prefix=f"{prefix}.qkv_proj",
        )
        self.o_proj = RowParallelLinear(
            self.total_num_heads * self.head_dim,
            hidden_size,
            bias=False,
            quant_config=quant_config,
            prefix=f"{prefix}.o_proj",
        )
        self.rotary_emb = get_rope(
            self.head_dim,
            rotary_dim=self.head_dim,
            max_position=max_position_embeddings,
            rope_parameters=rope_parameters,
            dual_chunk_attention_config=dual_chunk_attention_config,
        )

        # Dual-chunk attention needs the layer index and its config; otherwise
        # no extra keyword arguments are forwarded to Attention.
        extra_attn_kwargs: dict[str, Any] = {}
        if dual_chunk_attention_config:
            extra_attn_kwargs = {
                "layer_idx": extract_layer_index(prefix),
                "dual_chunk_attention_config": dual_chunk_attention_config,
            }
        self.attn = Attention(
            self.num_heads,
            self.head_dim,
            self.scaling,
            num_kv_heads=self.num_kv_heads,
            cache_config=cache_config,
            quant_config=quant_config,
            prefix=f"{prefix}.attn",
            **extra_attn_kwargs,
        )

        self.q_norm = RMSNorm(self.head_dim, eps=rms_norm_eps)
        self.k_norm = RMSNorm(self.head_dim, eps=rms_norm_eps)

    def _normalize_heads(self, flat: torch.Tensor, norm: RMSNorm) -> torch.Tensor:
        """Apply ``norm`` to each ``head_dim``-wide slice of the last axis."""
        per_head = flat.view(
            *flat.shape[:-1], flat.shape[-1] // self.head_dim, self.head_dim
        )
        return norm(per_head).view(flat.shape)

    def forward(
        self,
        positions: torch.Tensor,
        hidden_states: torch.Tensor,
    ) -> torch.Tensor:
        """Project to Q/K/V, apply qk-norm and RoPE, attend, project out."""
        fused, _ = self.qkv_proj(hidden_states)
        query, key, value = fused.split(
            [self.q_size, self.kv_size, self.kv_size], dim=-1
        )
        query = self._normalize_heads(query, self.q_norm)
        key = self._normalize_heads(key, self.k_norm)
        query, key = self.rotary_emb(positions, query, key)
        attended = self.attn(query, key, value)
        projected, _ = self.o_proj(attended)
        return projected


class Qwen3MoeDecoderLayer(nn.Module):
    """One decoder block: pre-norm self-attention, then a pre-norm
    feed-forward that is either the sparse MoE block or a dense MLP.
    """

    def __init__(self, vllm_config: VllmConfig, prefix: str = "") -> None:
        super().__init__()

        cfg = vllm_config.model_config.hf_text_config
        quant_config = vllm_config.quant_config
        self.hidden_size = cfg.hidden_size

        self.self_attn = Qwen3MoeAttention(
            hidden_size=self.hidden_size,
            num_heads=cfg.num_attention_heads,
            num_kv_heads=cfg.num_key_value_heads,
            rope_parameters=cfg.rope_parameters,
            max_position_embeddings=getattr(cfg, "max_position_embeddings", 8192),
            rms_norm_eps=cfg.rms_norm_eps,
            qkv_bias=getattr(cfg, "attention_bias", False),
            head_dim=getattr(cfg, "head_dim", None),
            cache_config=vllm_config.cache_config,
            quant_config=quant_config,
            prefix=f"{prefix}.self_attn",
            dual_chunk_attention_config=getattr(
                cfg, "dual_chunk_attention_config", None
            ),
        )

        # Layers listed in `mlp_only_layers` always use the dense MLP. Other
        # layers use the sparse MoE block when the model has experts and the
        # 1-based layer number is a multiple of `decoder_sparse_step`.
        layer_idx = extract_layer_index(prefix)
        dense_only_layers = getattr(cfg, "mlp_only_layers", [])
        use_sparse_moe = layer_idx not in dense_only_layers and (
            cfg.num_experts > 0 and (layer_idx + 1) % cfg.decoder_sparse_step == 0
        )
        if use_sparse_moe:
            self.mlp = Qwen3MoeSparseMoeBlock(
                vllm_config=vllm_config, prefix=f"{prefix}.mlp"
            )
        else:
            self.mlp = Qwen3MoeMLP(
                hidden_size=cfg.hidden_size,
                intermediate_size=cfg.intermediate_size,
                hidden_act=cfg.hidden_act,
                quant_config=quant_config,
                prefix=f"{prefix}.mlp",
            )

        self.input_layernorm = RMSNorm(cfg.hidden_size, eps=cfg.rms_norm_eps)
        self.post_attention_layernorm = RMSNorm(
            cfg.hidden_size, eps=cfg.rms_norm_eps
        )

    def forward(
        self,
        positions: torch.Tensor,
        hidden_states: torch.Tensor,
        residual: torch.Tensor | None,
    ) -> tuple[torch.Tensor, torch.Tensor]:
        """Run attention and feed-forward; return ``(hidden_states, residual)``.

        ``residual`` is None for the first layer on this rank. In that case the
        raw input becomes the residual. Otherwise, the two-argument norm call
        returns both the normalized states and the updated residual.
        """
        if residual is not None:
            hidden_states, residual = self.input_layernorm(hidden_states, residual)
        else:
            residual = hidden_states
            hidden_states = self.input_layernorm(hidden_states)

        hidden_states = self.self_attn(positions=positions, hidden_states=hidden_states)

        hidden_states, residual = self.post_attention_layernorm(hidden_states, residual)
        return self.mlp(hidden_states), residual


@support_torch_compile
class Qwen3MoeModel(nn.Module):
    """Decoder stack: token embedding, decoder layers, final norm.

    Under pipeline parallelism this rank runs only the layers in
    ``[start_layer, end_layer)`` and exchanges ``hidden_states``/``residual``
    with neighbouring ranks via ``IntermediateTensors``. It can also collect
    auxiliary hidden states at selected layers (for EAGLE3).
    """

    def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
        super().__init__()

        config = vllm_config.model_config.hf_text_config
        quant_config = vllm_config.quant_config
        parallel_config = vllm_config.parallel_config
        eplb_config = parallel_config.eplb_config
        self.num_redundant_experts = eplb_config.num_redundant_experts

        self.padding_idx = config.pad_token_id
        self.vocab_size = config.vocab_size
        self.config = config
        self.embed_tokens = VocabParallelEmbedding(
            config.vocab_size,
            config.hidden_size,
            quant_config=quant_config,
            prefix=f"{prefix}.embed_tokens",
        )
        self.start_layer, self.end_layer, self.layers = make_layers(
            config.num_hidden_layers,
            lambda prefix: Qwen3MoeDecoderLayer(vllm_config=vllm_config, prefix=prefix),
            prefix=f"{prefix}.layers",
        )
        self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
        self.make_empty_intermediate_tensors = make_empty_intermediate_tensors_factory(
            ["hidden_states", "residual"], config.hidden_size
        )
        # Track layers for auxiliary hidden state outputs (EAGLE3)
        self.aux_hidden_state_layers: tuple[int, ...] = ()

    def embed_input_ids(self, input_ids: torch.Tensor) -> torch.Tensor:
        """Look up token embeddings for ``input_ids``."""
        return self.embed_tokens(input_ids)

    def forward(
        self,
        input_ids: torch.Tensor,
        positions: torch.Tensor,
        intermediate_tensors: IntermediateTensors | None = None,
        inputs_embeds: torch.Tensor | None = None,
    ) -> torch.Tensor | IntermediateTensors | tuple[torch.Tensor, list[torch.Tensor]]:
        """Run this rank's slice of the decoder stack.

        Returns:
            On a non-last pipeline rank, ``IntermediateTensors`` with
            ``hidden_states`` and ``residual``. On the last rank, the final
            normalized hidden states, or ``(hidden_states, aux_hidden_states)``
            when auxiliary layers were collected.
        """
        if get_pp_group().is_first_rank:
            if inputs_embeds is not None:
                hidden_states = inputs_embeds
            else:
                hidden_states = self.embed_input_ids(input_ids)
            residual = None
        else:
            assert intermediate_tensors is not None
            hidden_states = intermediate_tensors["hidden_states"]
            residual = intermediate_tensors["residual"]

        aux_hidden_states = []
        for layer_idx, layer in enumerate(
            islice(self.layers, self.start_layer, self.end_layer),
            start=self.start_layer,
        ):
            # Collect auxiliary hidden states if specified
            if layer_idx in self.aux_hidden_state_layers:
                aux_hidden_state = (
                    hidden_states + residual if residual is not None else hidden_states
                )
                aux_hidden_states.append(aux_hidden_state)
            hidden_states, residual = layer(positions, hidden_states, residual)

        if not get_pp_group().is_last_rank:
            return IntermediateTensors(
                {"hidden_states": hidden_states, "residual": residual}
            )
        hidden_states, _ = self.norm(hidden_states, residual)

        # Return auxiliary hidden states if collected
        if len(aux_hidden_states) > 0:
            return hidden_states, aux_hidden_states
        return hidden_states

    def get_expert_mapping(self) -> list[tuple[str, str, int, str]]:
        """Return the ``(param_name, weight_name, expert_id, shard_id)``
        mapping from checkpoint expert tensors to fused-MoE parameters."""
        # Params for weights, fp8 weight scales, fp8 activation scales
        # (param_name, weight_name, expert_id, shard_id)
        return FusedMoE.make_expert_params_mapping(
            ckpt_gate_proj_name="gate_proj",
            ckpt_down_proj_name="down_proj",
            ckpt_up_proj_name="up_proj",
            num_experts=self.config.num_experts,
            num_redundant_experts=self.num_redundant_experts,
        )

    def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]:
        """Load checkpoint tensors into this model's parameters.

        Each ``(name, tensor)`` pair takes one of three paths:

        1. stacked projections: ``q/k/v_proj`` into ``qkv_proj`` and dense
           ``gate/up_proj`` into ``gate_up_proj`` (routed experts excluded);
        2. routed-expert tensors, via ``get_expert_mapping()``. The expert
           weight loader reports success, and expert tensors not mapped to
           this rank are skipped;
        3. everything else, loaded directly (``kv_scale`` names are remapped
           to ``.attn.kv_scale``).

        Tensors for parameters on other pipeline ranks are skipped. So are
        extra quantization tensors (``ignore_suffixes``) that have no
        matching parameter, and scale tensors whose remapped name does not
        exist.

        Args:
            weights: Iterable of ``(checkpoint_name, tensor)`` pairs.

        Returns:
            The set of parameter names that were loaded.
        """
        stacked_params_mapping = [
            # (param_name, shard_name, shard_id)
            ("qkv_proj", "q_proj", "q"),
            ("qkv_proj", "k_proj", "k"),
            ("qkv_proj", "v_proj", "v"),
            ("gate_up_proj", "gate_proj", 0),
            ("gate_up_proj", "up_proj", 1),
        ]

        # Skip loading extra parameters for GPTQ/modelopt models.
        ignore_suffixes = (
            ".bias",
            "_bias",
            ".k_scale",
            "_k_scale",
            ".v_scale",
            "_v_scale",
            ".weight_scale",
            "_weight_scale",
            ".input_scale",
            "_input_scale",
        )

        params_dict = dict(self.named_parameters())
        loaded_params: set[str] = set()
        expert_params_mapping = self.get_expert_mapping()
        for name, loaded_weight in weights:
            # Set when a tensor on the stacked path must be dropped entirely.
            skip_weight = False
            for param_name, weight_name, shard_id in stacked_params_mapping:
                # Skip non-stacked layers and experts (experts handled below).
                if weight_name not in name:
                    continue
                # We have mlp.experts[0].gate_proj in the checkpoint.
                # Since we handle the experts below in expert_params_mapping,
                # we need to skip here BEFORE we update the name, otherwise
                # name will be updated to mlp.experts[0].gate_up_proj, which
                # will then be updated below in expert_params_mapping
                # for mlp.experts[0].gate_gate_up_proj, which breaks load.
                if "mlp.experts" in name:
                    continue
                name = name.replace(weight_name, param_name)

                # Skip loading extra parameters for GPTQ/modelopt models.
                if name.endswith(ignore_suffixes) and name not in params_dict:
                    continue

                # Skip layers on other devices.
                if is_pp_missing_parameter(name, self):
                    continue
                if name.endswith("scale"):
                    # Remapping the name of FP8 kv-scale.
                    remapped_name = maybe_remap_kv_scale_name(name, params_dict)
                    if remapped_name is None:
                        # No matching parameter: drop this tensor. Continuing
                        # the inner loop with `name = None` would raise a
                        # TypeError on the next `weight_name not in name`.
                        skip_weight = True
                        break
                    name = remapped_name

                if name not in params_dict:
                    continue

                param = params_dict[name]
                weight_loader = getattr(param, "weight_loader", default_weight_loader)
                if weight_loader == default_weight_loader:
                    weight_loader(param, loaded_weight)
                else:
                    weight_loader(param, loaded_weight, shard_id)
                break
            else:
                is_expert_weight = False
                for mapping in expert_params_mapping:
                    param_name, weight_name, expert_id, shard_id = mapping
                    if weight_name not in name:
                        continue

                    # Anyway, this is an expert weight and should not be
                    # attempted to load as other weights later
                    is_expert_weight = True

                    # Do not modify `name` since the loop may continue here
                    # Instead, create a new variable
                    name_mapped = name.replace(weight_name, param_name)

                    if is_pp_missing_parameter(name_mapped, self):
                        continue

                    # Skip loading extra parameters for GPTQ/modelopt models.
                    if (
                        name_mapped.endswith(ignore_suffixes)
                        and name_mapped not in params_dict
                    ):
                        continue

                    param = params_dict[name_mapped]
                    # We should ask the weight loader to return success or not
                    # here since otherwise we may skip experts with other
                    # available replicas.
                    weight_loader = typing.cast(
                        Callable[..., bool], param.weight_loader
                    )
                    success = weight_loader(
                        param,
                        loaded_weight,
                        name_mapped,
                        shard_id=shard_id,
                        expert_id=expert_id,
                        return_success=True,
                    )
                    if success:
                        name = name_mapped
                        break
                else:
                    if is_expert_weight:
                        # We've checked that this is an expert weight
                        # However it's not mapped locally to this rank
                        # So we simply skip it
                        continue

                    # Skip loading extra parameters for GPTQ/modelopt models.
                    if name.endswith(ignore_suffixes) and name not in params_dict:
                        continue
                    # Skip layers on other devices.
                    if is_pp_missing_parameter(name, self):
                        continue
                    # Remapping the name of FP8 kv-scale.
                    if name.endswith("kv_scale"):
                        remapped_kv_scale_name = name.replace(
                            ".kv_scale", ".attn.kv_scale"
                        )
                        if remapped_kv_scale_name not in params_dict:
                            logger.warning_once(
                                "Found kv scale in the checkpoint (e.g. %s), but not found the expected name in the model (e.g. %s). kv-scale is not loaded.",  # noqa: E501
                                name,
                                remapped_kv_scale_name,
                            )
                            continue
                        else:
                            name = remapped_kv_scale_name
                    param = params_dict[name]
                    weight_loader = getattr(
                        param, "weight_loader", default_weight_loader
                    )
                    weight_loader(param, loaded_weight)
            if skip_weight:
                continue
            loaded_params.add(name)
        return loaded_params


class Qwen3MoeForCausalLM(
    nn.Module, SupportsPP, SupportsLoRA, SupportsEagle3, MixtureOfExperts
):
    """Qwen3-MoE causal language model: ``Qwen3MoeModel`` plus an LM head.

    Also implements the ``MixtureOfExperts`` hooks used for expert-parallel
    load balancing and the EAGLE3 auxiliary-hidden-state hooks.
    """

    packed_modules_mapping = {
        "qkv_proj": [
            "q_proj",
            "k_proj",
            "v_proj",
        ]
    }

    fall_back_to_pt_during_load = False

    def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
        super().__init__()
        config = vllm_config.model_config.hf_text_config
        quant_config = vllm_config.quant_config
        self.config = config
        self.quant_config = quant_config
        # Only perform the following mapping when Qwen3MoeMLP exists.
        # NOTE(review): this mutates the class-level dict in place rather than
        # an instance copy; other components may hold a reference to it, so
        # keep it in place unless that is confirmed otherwise.
        if self._has_dense_mlp_layers(config):
            self.packed_modules_mapping["gate_up_proj"] = ["gate_proj", "up_proj"]
        self.model = Qwen3MoeModel(
            vllm_config=vllm_config, prefix=maybe_prefix(prefix, "model")
        )
        self.lm_head = ParallelLMHead(
            config.vocab_size,
            config.hidden_size,
            quant_config=quant_config,
            prefix=maybe_prefix(prefix, "lm_head"),
        )
        if self.config.tie_word_embeddings:
            self.lm_head.weight = self.model.embed_tokens.weight
        self.logits_processor = LogitsProcessor(config.vocab_size)
        self.make_empty_intermediate_tensors = (
            self.model.make_empty_intermediate_tensors
        )

        # Set MoE hyperparameters
        self.expert_weights = []

        self.moe_layers = []
        example_layer = None
        for layer in self.model.layers:
            # Placeholders for layers owned by other pipeline ranks.
            if isinstance(layer, PPMissingLayer):
                continue

            assert isinstance(layer, Qwen3MoeDecoderLayer)
            if isinstance(layer.mlp, Qwen3MoeSparseMoeBlock):
                example_layer = layer.mlp
                self.moe_layers.append(layer.mlp.experts)

        if example_layer is None:
            raise RuntimeError("No Qwen3MoE layer found in the model.layers.")

        self.num_moe_layers = len(self.moe_layers)
        self.num_expert_groups = 1
        self.num_shared_experts = 0
        self.num_logical_experts = example_layer.n_logical_experts
        self.num_physical_experts = example_layer.n_physical_experts
        self.num_local_physical_experts = example_layer.n_local_physical_experts
        self.num_routed_experts = example_layer.n_routed_experts
        self.num_redundant_experts = example_layer.n_redundant_experts

    @staticmethod
    def _has_dense_mlp_layers(config) -> bool:
        """Return True if some decoder layer will be built as ``Qwen3MoeMLP``.

        Mirrors the layer-type choice in ``Qwen3MoeDecoderLayer``. A layer is
        dense when listed in ``mlp_only_layers``, or when its 1-based index is
        not a multiple of ``decoder_sparse_step``; the latter happens for layer
        0 whenever the step is greater than 1.
        """
        if getattr(config, "mlp_only_layers", []):
            return True
        sparse_step = getattr(config, "decoder_sparse_step", 1) or 1
        return sparse_step > 1 and config.num_hidden_layers > 0

    def update_physical_experts_metadata(
        self,
        num_physical_experts: int,
        num_local_physical_experts: int,
    ) -> None:
        """Propagate new physical/redundant expert counts to every local MoE
        block and rebuild their expert maps.

        The per-rank count of physical experts must not change.
        """
        assert self.num_local_physical_experts == num_local_physical_experts
        self.num_physical_experts = num_physical_experts
        self.num_local_physical_experts = num_local_physical_experts
        self.num_redundant_experts = num_physical_experts - self.num_logical_experts
        for layer in self.model.layers:
            # PPMissingLayer placeholders (other pipeline ranks) have no
            # `mlp`; accessing it would raise AttributeError.
            if isinstance(layer, PPMissingLayer):
                continue
            if isinstance(layer.mlp, Qwen3MoeSparseMoeBlock):
                moe = layer.mlp
                moe.n_local_physical_experts = num_local_physical_experts
                moe.n_physical_experts = num_physical_experts
                moe.n_redundant_experts = self.num_redundant_experts
                moe.experts.update_expert_map()

    def set_aux_hidden_state_layers(self, layers: tuple[int, ...]) -> None:
        """Select the layer indices whose inputs are returned as EAGLE3
        auxiliary hidden states."""
        self.model.aux_hidden_state_layers = layers

    def get_eagle3_aux_hidden_state_layers(self) -> tuple[int, ...]:
        """Default EAGLE3 layers: an early, a middle and a late layer."""
        num_layers = len(self.model.layers)
        return (2, num_layers // 2, num_layers - 3)

    def embed_input_ids(self, input_ids: torch.Tensor) -> torch.Tensor:
        """Look up token embeddings for ``input_ids``."""
        return self.model.embed_input_ids(input_ids)

    def forward(
        self,
        input_ids: torch.Tensor,
        positions: torch.Tensor,
        intermediate_tensors: IntermediateTensors | None = None,
        inputs_embeds: torch.Tensor | None = None,
    ) -> torch.Tensor | IntermediateTensors | tuple[torch.Tensor, list[torch.Tensor]]:
        """Run the decoder stack. The return value is whatever
        ``Qwen3MoeModel.forward`` returns, including the
        ``(hidden_states, aux_hidden_states)`` tuple when EAGLE3 layers are set.
        """
        hidden_states = self.model(
            input_ids, positions, intermediate_tensors, inputs_embeds
        )
        return hidden_states

    def compute_logits(
        self,
        hidden_states: torch.Tensor,
    ) -> torch.Tensor | None:
        """Project final hidden states to vocabulary logits."""
        logits = self.logits_processor(self.lm_head, hidden_states)
        return logits

    def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]:
        """Load checkpoint weights via ``AutoWeightsLoader``; returns the set
        of loaded parameter names."""
        loader = AutoWeightsLoader(self)
        return loader.load_weights(weights)

    def get_expert_mapping(self) -> list[tuple[str, str, int, str]]:
        """Return the expert parameter mapping of the underlying model."""
        return self.model.get_expert_mapping()