ernie45_moe.py 27.1 KB
Newer Older
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project

# Copyright 2025 The Baidu team.
# Copyright 2023 The vLLM team.
# Copyright 2022 EleutherAI and the HuggingFace Inc. team. All rights reserved.
#
# This code is based on EleutherAI's GPT-NeoX library and the GPT-NeoX
# and OPT implementations in this library. It has been modified from its
# original forms to accommodate minor architectural differences compared
# to GPT-NeoX and OPT used by the Meta AI team that trained the model.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Inference-only ErineMoE model compatible with HuggingFace weights."""
25

26
27
import typing
from collections.abc import Callable, Iterable
28
from itertools import islice
29
from typing import Any
30
31
32
33
34
35

import torch
from torch import nn
from transformers import PretrainedConfig

from vllm.compilation.decorators import support_torch_compile
36
37
38
39
40
41
from vllm.config import CacheConfig, VllmConfig, get_current_vllm_config
from vllm.distributed import (
    get_ep_group,
    get_pp_group,
    get_tensor_model_parallel_world_size,
)
42
43
from vllm.logger import init_logger
from vllm.model_executor.layers.activation import SiluAndMul
44
from vllm.model_executor.layers.attention import Attention
45
46
47
48
from vllm.model_executor.layers.fused_moe import (
    FusedMoE,
    fused_moe_make_expert_params_mapping,
)
49
from vllm.model_executor.layers.layernorm import RMSNorm
50
51
52
53
54
55
from vllm.model_executor.layers.linear import (
    MergedColumnParallelLinear,
    QKVParallelLinear,
    ReplicatedLinear,
    RowParallelLinear,
)
56
57
58
59
from vllm.model_executor.layers.logits_processor import LogitsProcessor
from vllm.model_executor.layers.quantization import QuantizationConfig
from vllm.model_executor.layers.rotary_embedding import get_rope
from vllm.model_executor.layers.vocab_parallel_embedding import (
60
61
62
    ParallelLMHead,
    VocabParallelEmbedding,
)
63
from vllm.model_executor.model_loader.weight_utils import (
64
65
66
    default_weight_loader,
    maybe_remap_kv_scale_name,
)
67
from vllm.sequence import IntermediateTensors
68
from vllm.transformers_utils.config import set_default_rope_theta
69

70
from .interfaces import MixtureOfExperts, SupportsLoRA, SupportsPP
71
72
73
74
75
76
77
78
79
from .utils import (
    AutoWeightsLoader,
    PPMissingLayer,
    extract_layer_index,
    is_pp_missing_parameter,
    make_empty_intermediate_tensors_factory,
    make_layers,
    maybe_prefix,
)
80
81
82
83
84
85
86
87
88
89
90

logger = init_logger(__name__)


class Ernie4_5_MoeMLP(nn.Module):
    """Dense gated feed-forward block: fused gate/up projection, SiluAndMul,
    then a row-parallel down projection.

    Used as the dense MLP of non-MoE decoder layers and as the shared-expert
    MLP inside ``Ernie4_5_MoeMoE``.

    Args:
        hidden_size: Model hidden width (input and output size).
        intermediate_size: Width of each of the gate / up projections.
        hidden_act: Activation name; only ``"silu"`` is supported.
        use_bias: Whether the projections carry a bias.
        quant_config: Optional quantization config for both projections.
        reduce_results: Forwarded to the down projection. The MoE block passes
            False for its shared experts (presumably so the tensor-parallel
            reduction happens once, downstream — confirm).
        prefix: Parameter-name prefix for this module.

    Raises:
        ValueError: If ``hidden_act`` is not ``"silu"``.
    """

    def __init__(
        self,
        hidden_size: int,
        intermediate_size: int,
        hidden_act: str,
        use_bias: bool = False,
        quant_config: QuantizationConfig | None = None,
        reduce_results: bool = True,
        prefix: str = "",
    ) -> None:
        super().__init__()
        # Validate before building the projection layers so an unsupported
        # config fails fast instead of after allocating weights.
        if hidden_act != "silu":
            raise ValueError(
                f"Unsupported activation: {hidden_act}. Only silu is supported for now."
            )
        # Gate and up projections fused into one column-parallel matmul whose
        # output is the two halves concatenated.
        self.gate_up_proj = MergedColumnParallelLinear(
            hidden_size,
            [intermediate_size] * 2,
            bias=use_bias,
            quant_config=quant_config,
            prefix=f"{prefix}.gate_up_proj",
        )
        self.down_proj = RowParallelLinear(
            intermediate_size,
            hidden_size,
            bias=use_bias,
            quant_config=quant_config,
            reduce_results=reduce_results,
            prefix=f"{prefix}.down_proj",
        )
        self.act_fn = SiluAndMul()

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Apply gate/up projection, activation and down projection to ``x``."""
        gate_up, _ = self.gate_up_proj(x)
        x = self.act_fn(gate_up)
        x, _ = self.down_proj(x)
        return x


class Ernie4_5_MoeMoE(nn.Module):
    """Sparse mixture-of-experts feed-forward block of an ERNIE 4.5 MoE layer.

    Routing logits come from an fp32, unquantized replicated linear gate.
    The routed experts, plus the optional shared-expert MLP, run inside a
    single ``FusedMoE`` module. The block also records the
    logical/physical expert counts used by expert-parallel load balancing
    (EPLB).

    Raises:
        ValueError: If the tensor-parallel size exceeds the number of experts.
    """

    def __init__(
        self,
        config: PretrainedConfig,
        quant_config: QuantizationConfig | None = None,
        prefix: str = "",
        enable_eplb: bool = False,
    ):
        super().__init__()

        self.layer_idx = extract_layer_index(prefix)
        self.tp_size = get_tensor_model_parallel_world_size()

        # Raw config value; None when the config does not define it.
        self.moe_num_shared_experts = getattr(config, "moe_num_shared_experts", None)
        # Normalize missing/None to 0 so the shared-expert count is always an
        # int and the `> 0` test below cannot raise on None.
        num_shared_experts: int = self.moe_num_shared_experts or 0

        ep_group = get_ep_group()
        self.ep_group = ep_group.device_group
        self.ep_rank = ep_group.rank_in_group
        self.ep_size = self.ep_group.size()
        self.n_routed_experts: int = config.moe_num_experts
        self.n_shared_experts: int = num_shared_experts

        # Load balancing settings.
        eplb_config = get_current_vllm_config().parallel_config.eplb_config
        self.enable_eplb = enable_eplb
        self.n_redundant_experts = eplb_config.num_redundant_experts

        # Physical experts = logical experts + redundant replicas, divided
        # evenly across the EP group; this rank owns one contiguous slice.
        self.n_logical_experts = self.n_routed_experts
        self.n_physical_experts = self.n_logical_experts + self.n_redundant_experts
        self.n_local_physical_experts = self.n_physical_experts // self.ep_size
        self.physical_expert_start = self.ep_rank * self.n_local_physical_experts
        self.physical_expert_end = (
            self.physical_expert_start + self.n_local_physical_experts
        )
        self.has_shared_experts = num_shared_experts > 0

        if self.tp_size > config.moe_num_experts:
            raise ValueError(
                f"Tensor parallel size {self.tp_size} is greater than "
                f"the number of experts {config.moe_num_experts}."
            )

        # Router: fp32 weights and deliberately unquantized.
        self.gate = ReplicatedLinear(
            config.hidden_size,
            config.moe_num_experts,
            bias=False,
            params_dtype=torch.float32,
            quant_config=None,
            prefix=f"{prefix}.gate",
        )

        # Per-expert score-correction bias, registered on the gate module and
        # handed to FusedMoE below.
        self.gate.e_score_correction_bias = nn.Parameter(
            torch.empty(config.moe_num_experts, dtype=torch.float32)
        )

        if self.has_shared_experts:
            # All shared experts are merged into one MLP whose hidden width is
            # moe_intermediate_size * moe_num_shared_experts.
            self.shared_experts = Ernie4_5_MoeMLP(
                hidden_size=config.hidden_size,
                intermediate_size=(
                    config.moe_intermediate_size * config.moe_num_shared_experts
                ),
                hidden_act=config.hidden_act,
                quant_config=quant_config,
                prefix=f"{prefix}.shared_experts",
                reduce_results=False,
            )
        else:
            self.shared_experts = None

        self.experts = FusedMoE(
            shared_experts=self.shared_experts,
            num_experts=config.moe_num_experts,
            top_k=config.moe_k,
            hidden_size=config.hidden_size,
            intermediate_size=config.moe_intermediate_size,
            renormalize=True,
            quant_config=quant_config,
            prefix=f"{prefix}.experts",
            e_score_correction_bias=self.gate.e_score_correction_bias,
            enable_eplb=self.enable_eplb,
            num_redundant_experts=self.n_redundant_experts,
            router_logits_dtype=torch.float32,
        )

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        """Route every token through the experts; output keeps the input shape."""
        orig_shape = hidden_states.shape
        # Collapse all leading dims into a flat token dimension.
        hidden_states = hidden_states.view(-1, orig_shape[-1])

        # The gate has fp32 weights, so feed it fp32 activations.
        router_logits, _ = self.gate(hidden_states.to(dtype=torch.float32))

        final_hidden_states = self.experts(
            hidden_states=hidden_states, router_logits=router_logits
        )
        return final_hidden_states.view(orig_shape)


class Ernie4_5_MoeAttention(nn.Module):
    """Self-attention with a fused QKV projection, rotary position embedding
    (``is_neox_style=False``) and support for fewer KV heads than query heads.

    Heads are sharded across tensor-parallel ranks. When there are fewer KV
    heads than ranks, KV heads are replicated so that each rank holds at
    least one.

    Args:
        hidden_size: Model hidden width.
        num_heads: Total number of query heads (before TP sharding).
        num_kv_heads: Total number of key/value heads (before TP sharding).
        rope_parameters: Passed through to ``get_rope``.
        head_dim: Per-head width; defaults to ``hidden_size // num_heads``.
        max_position_embeddings: Maximum position passed to the rotary embedding.
        rms_norm_eps: Accepted for interface compatibility; unused in this module.
        qkv_bias: Whether the QKV projection carries a bias.
        cache_config: KV-cache config forwarded to ``Attention``.
        quant_config: Optional quantization config for projections/attention.
        prefix: Parameter-name prefix for this module.
    """

    def __init__(
        self,
        hidden_size: int,
        num_heads: int,
        num_kv_heads: int,
        rope_parameters: dict[str, Any],
        head_dim: int | None = None,
        max_position_embeddings: int = 131072,
        rms_norm_eps: float = 1e-05,
        qkv_bias: bool = False,
        cache_config: CacheConfig | None = None,
        quant_config: QuantizationConfig | None = None,
        prefix: str = "",
    ) -> None:
        super().__init__()
        self.layer_idx = extract_layer_index(prefix) if len(prefix) > 0 else 0
        self.hidden_size = hidden_size
        tp_size = get_tensor_model_parallel_world_size()

        self.total_num_heads = num_heads
        assert self.total_num_heads % tp_size == 0
        self.num_heads = self.total_num_heads // tp_size

        self.total_num_kv_heads = num_kv_heads
        if self.total_num_kv_heads >= tp_size:
            # Number of KV heads is greater than TP size, so we partition
            # the KV heads across multiple tensor parallel GPUs.
            assert self.total_num_kv_heads % tp_size == 0
        else:
            # Number of KV heads is less than TP size, so we replicate
            # the KV heads across multiple tensor parallel GPUs.
            assert tp_size % self.total_num_kv_heads == 0
        self.num_kv_heads = max(1, self.total_num_kv_heads // tp_size)
        self.head_dim = head_dim or (hidden_size // self.total_num_heads)

        # Per-rank widths used to split the fused QKV output.
        self.q_size = self.num_heads * self.head_dim
        self.kv_size = self.num_kv_heads * self.head_dim
        self.scaling = self.head_dim**-0.5
        self.max_position_embeddings = max_position_embeddings

        self.qkv_proj = QKVParallelLinear(
            hidden_size,
            self.head_dim,
            self.total_num_heads,
            self.total_num_kv_heads,
            bias=qkv_bias,
            quant_config=quant_config,
            prefix=f"{prefix}.qkv_proj",
        )
        self.o_proj = RowParallelLinear(
            self.total_num_heads * self.head_dim,
            hidden_size,
            bias=False,
            quant_config=quant_config,
            prefix=f"{prefix}.o_proj",
        )
        self.rotary_emb = get_rope(
            self.head_dim,
            max_position=max_position_embeddings,
            rope_parameters=rope_parameters,
            is_neox_style=False,
        )
        self.attn = Attention(
            self.num_heads,
            self.head_dim,
            self.scaling,
            num_kv_heads=self.num_kv_heads,
            cache_config=cache_config,
            quant_config=quant_config,
            prefix=f"{prefix}.attn",
        )

    def forward(
        self,
        positions: torch.Tensor,
        hidden_states: torch.Tensor,
    ) -> torch.Tensor:
        """Project to Q/K/V, apply rotary embedding, attend, project back."""
        qkv, _ = self.qkv_proj(hidden_states)
        q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1)
        q, k = self.rotary_emb(positions, q, k)

        attn_output = self.attn(q, k, v)
        output, _ = self.o_proj(attn_output)
        return output


class Ernie4_5_MoeDecoderLayer(nn.Module):
    """One pre-norm transformer decoder layer: attention followed by either a
    sparse MoE block or a dense MLP.

    A layer gets the MoE block when MoE is enabled (``use_moe``, defaulting to
    ``moe_num_experts > 0``), its index lies in
    ``[moe_layer_start_index, moe_layer_end_index]`` and
    ``(layer_idx + 1) % moe_layer_interval == 0``. All other layers use the
    dense ``Ernie4_5_MoeMLP``.

    Note: ``set_default_rope_theta`` is called on ``config``, which may modify
    the shared config object.
    """

    def __init__(
        self,
        config: PretrainedConfig,
        cache_config: CacheConfig | None = None,
        quant_config: QuantizationConfig | None = None,
        prefix: str = "",
        enable_eplb: bool = False,
    ) -> None:
        super().__init__()
        self.hidden_size = config.hidden_size
        set_default_rope_theta(config, default_theta=500000)
        max_position_embeddings = getattr(config, "max_position_embeddings", 131072)

        self.self_attn = Ernie4_5_MoeAttention(
            hidden_size=self.hidden_size,
            num_heads=config.num_attention_heads,
            num_kv_heads=config.num_key_value_heads,
            head_dim=getattr(config, "head_dim", None),
            rope_parameters=config.rope_parameters,
            max_position_embeddings=max_position_embeddings,
            rms_norm_eps=config.rms_norm_eps,
            qkv_bias=getattr(config, "use_bias", False),
            cache_config=cache_config,
            quant_config=quant_config,
            prefix=f"{prefix}.self_attn",
        )

        layer_idx = extract_layer_index(prefix)
        self.layer_idx = layer_idx

        # MoE placement parameters (all optional in the HF config).
        moe_num_experts = getattr(config, "moe_num_experts", 0)
        moe_layer_start_index = getattr(config, "moe_layer_start_index", 0)
        moe_layer_end_index = getattr(
            config, "moe_layer_end_index", config.num_hidden_layers - 1
        )
        moe_layer_interval = getattr(config, "moe_layer_interval", 1)
        use_moe = getattr(config, "use_moe", moe_num_experts > 0)

        is_moe_layer = (
            use_moe
            and ((layer_idx + 1) % moe_layer_interval == 0)
            and layer_idx >= moe_layer_start_index
            and layer_idx <= moe_layer_end_index
        )
        if is_moe_layer:
            self.mlp = Ernie4_5_MoeMoE(
                config=config,
                quant_config=quant_config,
                prefix=f"{prefix}.mlp",
                enable_eplb=enable_eplb,
            )
        else:
            self.mlp = Ernie4_5_MoeMLP(
                hidden_size=config.hidden_size,
                intermediate_size=config.intermediate_size,
                hidden_act=config.hidden_act,
                use_bias=getattr(config, "use_bias", False),
                quant_config=quant_config,
                prefix=f"{prefix}.mlp",
            )

        self.input_layernorm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
        self.post_attention_layernorm = RMSNorm(
            config.hidden_size, eps=config.rms_norm_eps
        )

    def forward(
        self,
        positions: torch.Tensor,
        hidden_states: torch.Tensor,
        residual: torch.Tensor | None,
    ) -> tuple[torch.Tensor, torch.Tensor]:
        """Run the layer and return ``(hidden_states, residual)``.

        ``residual`` is None for the first layer on a pipeline stage; in that
        case the raw input becomes the residual. Otherwise the norms are
        called with the residual and return a ``(normalized, residual)``
        pair. The residual add for this layer's MLP output is deferred to the
        next norm call.
        """
        # Self Attention
        if residual is None:
            residual = hidden_states
            hidden_states = self.input_layernorm(hidden_states)
        else:
            hidden_states, residual = self.input_layernorm(hidden_states, residual)

        hidden_states = self.self_attn(
            positions=positions,
            hidden_states=hidden_states,
        )

        # Fully Connected
        hidden_states, residual = self.post_attention_layernorm(hidden_states, residual)
        hidden_states = self.mlp(hidden_states)

        return hidden_states, residual


@support_torch_compile
class Ernie4_5_MoeModel(nn.Module):
    """ERNIE 4.5 MoE transformer backbone: token embedding, decoder stack and
    final RMSNorm.

    Pipeline-parallel aware:
    - the embedding exists only on the first PP rank;
    - the final norm exists only on the last PP rank;
    - each rank runs its decoder-layer slice ``[start_layer, end_layer)``.
    """

    def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
        super().__init__()

        config = vllm_config.model_config.hf_config
        cache_config = vllm_config.cache_config
        quant_config = vllm_config.quant_config
        parallel_config = vllm_config.parallel_config
        enable_eplb = parallel_config.enable_eplb

        self.vocab_size = config.vocab_size
        self.config = config
        # Needed by get_expert_mapping() to cover redundant expert replicas.
        self.num_redundant_experts = parallel_config.eplb_config.num_redundant_experts

        if get_pp_group().is_first_rank:
            self.embed_tokens = VocabParallelEmbedding(
                config.vocab_size,
                config.hidden_size,
                quant_config=quant_config,
                prefix=f"{prefix}.embed_tokens",
            )
        else:
            self.embed_tokens = PPMissingLayer()

        self.start_layer, self.end_layer, self.layers = make_layers(
            config.num_hidden_layers,
            lambda prefix: Ernie4_5_MoeDecoderLayer(
                config=config,
                cache_config=cache_config,
                quant_config=quant_config,
                prefix=prefix,
                enable_eplb=enable_eplb,
            ),
            prefix=f"{prefix}.layers",
        )

        if get_pp_group().is_last_rank:
            self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
        else:
            self.norm = PPMissingLayer()

        self.make_empty_intermediate_tensors = make_empty_intermediate_tensors_factory(
            ["hidden_states", "residual"], config.hidden_size
        )

    def embed_input_ids(self, input_ids: torch.Tensor) -> torch.Tensor:
        """Look up token embeddings for ``input_ids``."""
        return self.embed_tokens(input_ids)

    def forward(
        self,
        input_ids: torch.Tensor | None,
        positions: torch.Tensor,
        intermediate_tensors: IntermediateTensors | None = None,
        inputs_embeds: torch.Tensor | None = None,
    ) -> torch.Tensor | IntermediateTensors:
        """Run this rank's slice of the decoder stack.

        On the first PP rank the input is ``inputs_embeds`` if given, else the
        embedding of ``input_ids``. On later ranks it is read from
        ``intermediate_tensors``. Non-last ranks return ``IntermediateTensors``
        carrying ``hidden_states`` and ``residual``. The last rank returns the
        final-normed hidden states.
        """
        if get_pp_group().is_first_rank:
            if inputs_embeds is not None:
                hidden_states = inputs_embeds
            else:
                hidden_states = self.embed_input_ids(input_ids)
            residual = None
        else:
            assert intermediate_tensors is not None
            hidden_states = intermediate_tensors["hidden_states"]
            residual = intermediate_tensors["residual"]

        for layer in islice(self.layers, self.start_layer, self.end_layer):
            hidden_states, residual = layer(positions, hidden_states, residual)

        if not get_pp_group().is_last_rank:
            return IntermediateTensors(
                {"hidden_states": hidden_states, "residual": residual}
            )

        hidden_states, _ = self.norm(hidden_states, residual)
        return hidden_states

    def get_expert_mapping(self) -> list[tuple[str, str, int, str]]:
        """Return the checkpoint-to-FusedMoE mapping for every expert weight.

        Each entry is ``(param_name, weight_name, expert_id, shard_id)`` and
        covers weights, fp8 weight scales and fp8 activation scales,
        including redundant (EPLB) expert replicas.
        """
        return fused_moe_make_expert_params_mapping(
            self,
            ckpt_gate_proj_name="gate_proj",
            ckpt_down_proj_name="down_proj",
            ckpt_up_proj_name="up_proj",
            num_experts=self.config.moe_num_experts,
            num_redundant_experts=self.num_redundant_experts,
        )

    def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]:
        """Load checkpoint tensors into this rank's parameters.

        Three routes are tried in order:
        1. Separate q/k/v and gate/up projections are stacked into the fused
           ``qkv_proj`` / ``gate_up_proj`` parameters.
        2. Per-expert weights are routed through ``get_expert_mapping``.
        3. Everything else is loaded 1:1, with FP8 kv-scale name remapping.

        Weights are skipped when they are:
        - ``lm_head.weight`` while embeddings are tied;
        - any ``mtp`` weight;
        - an extra GPTQ bias with no matching parameter;
        - owned by another pipeline stage;
        - an expert not mapped to this rank.

        Returns:
            The set of parameter names that were loaded.
        """
        stacked_params_mapping = [
            # (param_name, shard_name, shard_id)
            ("qkv_proj", "q_proj", "q"),
            ("qkv_proj", "k_proj", "k"),
            ("qkv_proj", "v_proj", "v"),
            ("gate_up_proj", "gate_proj", 0),
            ("gate_up_proj", "up_proj", 1),
        ]

        params_dict = dict(self.named_parameters())
        loaded_params: set[str] = set()
        expert_params_mapping = self.get_expert_mapping()
        for name, loaded_weight in weights:
            if self.config.tie_word_embeddings and name.endswith("lm_head.weight"):
                continue
            # MTP will be supported soon.
            if "mtp" in name:
                continue

            if "e_score_correction_bias" in name:
                # Checkpoint stores the routing bias under `moe_statics`; the
                # module registers it on `gate`. squeeze(0) drops a leading
                # size-1 dim if the checkpoint has one.
                name = name.replace("moe_statics", "gate")
                loaded_weight = loaded_weight.squeeze(0)

            for param_name, weight_name, shard_id in stacked_params_mapping:
                # Skip non-stacked layers and experts (experts handled below).
                if weight_name not in name:
                    continue
                if ("mlp.experts." in name) and name not in params_dict:
                    continue
                name = name.replace(weight_name, param_name)
                # Skip loading extra bias for GPTQ models.
                if name.endswith((".bias", "_bias")) and name not in params_dict:
                    continue
                # Skip layers on other devices.
                if is_pp_missing_parameter(name, self):
                    continue

                param = params_dict[name]
                weight_loader = param.weight_loader
                weight_loader(param, loaded_weight, shard_id)
                break
            else:
                is_expert_weight = False
                for mapping in expert_params_mapping:
                    param_name, weight_name, expert_id, shard_id = mapping
                    if weight_name not in name:
                        continue

                    # Anyway, this is an expert weight and should not be
                    # attempted to load as other weights later
                    is_expert_weight = True

                    # Do not modify `name` since the loop may continue here;
                    # use a separate variable instead.
                    name_mapped = name.replace(weight_name, param_name)
                    # Skip layers on other devices.
                    if is_pp_missing_parameter(name_mapped, self):
                        continue

                    # Skip loading extra bias for GPTQ models.
                    if (
                        name_mapped.endswith((".bias", "_bias"))
                        and name_mapped not in params_dict
                    ):
                        continue

                    param = params_dict[name_mapped]
                    # Ask the weight loader whether it succeeded; otherwise we
                    # could skip experts that have other available replicas.
                    weight_loader = typing.cast(
                        Callable[..., bool], param.weight_loader
                    )
                    success = weight_loader(
                        param,
                        loaded_weight,
                        name_mapped,
                        shard_id=shard_id,
                        expert_id=expert_id,
                        return_success=True,
                    )
                    if success:
                        name = name_mapped
                        break
                else:
                    if is_expert_weight:
                        # This is an expert weight, but it is not mapped
                        # locally to this rank, so we simply skip it.
                        continue

                    # Skip loading extra bias for GPTQ models.
                    if name.endswith((".bias", "_bias")) and name not in params_dict:
                        continue
                    # Skip layers on other devices.
                    if is_pp_missing_parameter(name, self):
                        continue
                    # Remapping the name of FP8 kv-scale.
                    name = maybe_remap_kv_scale_name(name, params_dict)
                    if name is None:
                        continue

                    param = params_dict[name]
                    weight_loader = getattr(
                        param, "weight_loader", default_weight_loader
                    )
                    weight_loader(param, loaded_weight)
            loaded_params.add(name)
        return loaded_params
614
615


616
class Ernie4_5_MoeForCausalLM(nn.Module, SupportsPP, SupportsLoRA, MixtureOfExperts):
    """ERNIE 4.5 MoE causal language model: backbone, LM head and logits
    processor.

    Also exposes the expert metadata required by the ``MixtureOfExperts``
    interface:
    - ``moe_layers``: this rank's ``FusedMoE`` modules;
    - logical/physical/redundant expert counts;
    - ``update_physical_experts_metadata``.
    """

    packed_modules_mapping = {
        "qkv_proj": [
            "q_proj",
            "k_proj",
            "v_proj",
        ],
        "gate_up_proj": [
            "gate_proj",
            "up_proj",
        ],
    }

    fall_back_to_pt_during_load = False

    def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
        super().__init__()
        config = vllm_config.model_config.hf_config
        quant_config = vllm_config.quant_config
        self.config = config
        self.quant_config = quant_config
        self.model = Ernie4_5_MoeModel(
            vllm_config=vllm_config, prefix=maybe_prefix(prefix, "model")
        )

        if get_pp_group().is_last_rank:
            self.lm_head = ParallelLMHead(
                config.vocab_size,
                config.hidden_size,
                quant_config=quant_config,
                prefix=maybe_prefix(prefix, "lm_head"),
            )
        else:
            self.lm_head = PPMissingLayer()

        if self.config.tie_word_embeddings:
            # NOTE(review): the backbone builds `embed_tokens` only on the
            # first PP rank, so on other ranks this reads `.weight` from a
            # PPMissingLayer placeholder. This looks like it breaks tied
            # checkpoints when PP > 1 — confirm.
            self.lm_head.weight = self.model.embed_tokens.weight
        self.logits_processor = LogitsProcessor(config.vocab_size)
        self.make_empty_intermediate_tensors = (
            self.model.make_empty_intermediate_tensors
        )

        self.expert_weights = []

        # Set MoE hyperparameters. Use the same optional-config defaults and
        # layer predicate as Ernie4_5_MoeDecoderLayer, so the count matches
        # the layers actually built as MoE and configs that omit these fields
        # do not raise AttributeError.
        moe_num_experts = getattr(config, "moe_num_experts", 0)
        use_moe = getattr(config, "use_moe", moe_num_experts > 0)
        moe_layer_start_index = getattr(config, "moe_layer_start_index", 0)
        moe_layer_end_index = getattr(
            config, "moe_layer_end_index", config.num_hidden_layers - 1
        )
        moe_layer_interval = getattr(config, "moe_layer_interval", 1)
        moe_layers_indices = [
            i
            for i in range(config.num_hidden_layers)
            if (
                use_moe
                and i >= moe_layer_start_index
                and i <= moe_layer_end_index
                and (i + 1) % moe_layer_interval == 0
            )
        ]
        self.num_moe_layers = len(moe_layers_indices)
        self.num_expert_groups = 1

        # Collect this rank's FusedMoE modules; layers owned by other pipeline
        # stages are PPMissingLayer placeholders.
        self.moe_layers: list[FusedMoE] = []
        example_moe = None
        for layer in self.model.layers:
            if isinstance(layer, PPMissingLayer):
                continue

            assert isinstance(layer, Ernie4_5_MoeDecoderLayer)
            if isinstance(layer.mlp, Ernie4_5_MoeMoE):
                example_moe = layer.mlp
                self.moe_layers.append(layer.mlp.experts)

        if example_moe is None:
            logger.warning("No Ernie4_5_MoeMoE layer found in model.layers.")
            self.num_logical_experts = 0
            self.num_physical_experts = 0
            self.num_local_physical_experts = 0
            self.num_routed_experts = 0
            self.num_shared_experts = 0
            self.num_redundant_experts = 0
        else:
            self.num_logical_experts = example_moe.n_logical_experts
            self.num_physical_experts = example_moe.n_physical_experts
            self.num_local_physical_experts = example_moe.n_local_physical_experts
            self.num_routed_experts = example_moe.n_routed_experts
            self.num_shared_experts = example_moe.n_shared_experts
            self.num_redundant_experts = example_moe.n_redundant_experts

    def update_physical_experts_metadata(
        self,
        num_physical_experts: int,
        num_local_physical_experts: int,
    ) -> None:
        """Store new physical/redundant expert counts and push them to every
        local MoE layer, then rebuild each layer's expert map.

        The per-rank local expert count must not change.
        """
        assert self.num_local_physical_experts == num_local_physical_experts
        self.num_physical_experts = num_physical_experts
        self.num_local_physical_experts = num_local_physical_experts
        self.num_redundant_experts = num_physical_experts - self.num_logical_experts
        for layer in self.model.layers:
            # Layers owned by other pipeline stages are PPMissingLayer
            # placeholders without an `mlp` attribute; skip them, as __init__
            # does.
            if isinstance(layer, PPMissingLayer):
                continue
            moe = layer.mlp
            if not isinstance(moe, Ernie4_5_MoeMoE):
                continue
            moe.n_local_physical_experts = num_local_physical_experts
            moe.n_physical_experts = num_physical_experts
            moe.n_redundant_experts = self.num_redundant_experts
            moe.experts.update_expert_map()

    def embed_input_ids(self, input_ids: torch.Tensor) -> torch.Tensor:
        """Look up token embeddings via the backbone."""
        return self.model.embed_input_ids(input_ids)

    def forward(
        self,
        input_ids: torch.Tensor | None,
        positions: torch.Tensor,
        intermediate_tensors: IntermediateTensors | None = None,
        inputs_embeds: torch.Tensor | None = None,
    ) -> torch.Tensor | IntermediateTensors:
        """Run the backbone.

        Returns hidden states on the last PP rank and ``IntermediateTensors``
        elsewhere.
        """
        hidden_states = self.model(
            input_ids, positions, intermediate_tensors, inputs_embeds
        )
        return hidden_states

    def compute_logits(
        self,
        hidden_states: torch.Tensor,
    ) -> torch.Tensor | None:
        """Project hidden states to vocabulary logits through the LM head."""
        logits = self.logits_processor(self.lm_head, hidden_states)
        return logits

    def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]:
        """Load checkpoint weights, skipping ``lm_head.*`` when embeddings are
        tied. Returns the set of loaded parameter names."""
        loader = AutoWeightsLoader(
            self,
            skip_prefixes=(["lm_head."] if self.config.tie_word_embeddings else None),
        )
        return loader.load_weights(weights)

    def get_expert_mapping(self) -> list[tuple[str, str, int, str]]:
        """Return the backbone's checkpoint-to-expert parameter mapping."""
        return self.model.get_expert_mapping()