bailing_moe.py 22.6 KB
Newer Older
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project

# Adapted from
# https://github.com/inclusionAI/Ling/blob/master/models/modeling_bailing_moe.py
# Copyright 2023 The vLLM team.
# Copyright 2023 Antgroup and The HuggingFace Inc. team. All rights reserved.
#
# This code is based on EleutherAI's GPT-NeoX library and the GPT-NeoX
# and OPT implementations in this library. It has been modified from its
# original forms to accommodate minor architectural differences compared
# to GPT-NeoX and OPT used by the Meta AI team that trained the model.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Inference-only BailingMoE model compatible with HuggingFace weights."""
26

27
from collections.abc import Iterable
28
from itertools import islice
29
30
31
32
33
34
35

import torch
import torch.nn.functional as F
from torch import nn
from transformers.configuration_utils import PretrainedConfig

from vllm.attention import Attention
36
from vllm.compilation.decorators import support_torch_compile
37
from vllm.config import CacheConfig, VllmConfig
38
39
40
41
42
from vllm.distributed import (
    get_pp_group,
    get_tensor_model_parallel_rank,
    get_tensor_model_parallel_world_size,
)
43
from vllm.model_executor.layers.activation import SiluAndMul
44
from vllm.model_executor.layers.fused_moe import SharedFusedMoE
45
from vllm.model_executor.layers.layernorm import RMSNorm
46
47
48
49
50
from vllm.model_executor.layers.linear import (
    MergedColumnParallelLinear,
    QKVParallelLinear,
    RowParallelLinear,
)
51
from vllm.model_executor.layers.logits_processor import LogitsProcessor
52
from vllm.model_executor.layers.quantization import QuantizationConfig
53
54
from vllm.model_executor.layers.rotary_embedding import get_rope
from vllm.model_executor.layers.vocab_parallel_embedding import (
55
56
57
    ParallelLMHead,
    VocabParallelEmbedding,
)
58
59
60
from vllm.model_executor.model_loader.weight_utils import default_weight_loader
from vllm.sequence import IntermediateTensors

61
from .interfaces import SupportsLoRA, SupportsPP
62
63
64
65
66
67
68
69
from .utils import (
    AutoWeightsLoader,
    PPMissingLayer,
    is_pp_missing_parameter,
    make_empty_intermediate_tensors_factory,
    make_layers,
    maybe_prefix,
)
70
71
72
73
74
75


class BailingAttention(nn.Module):
    """Self-attention layer of a Bailing MoE decoder block.

    Projects hidden states to fused Q/K/V with a tensor-parallel linear layer,
    optionally normalizes Q and K per head, applies rotary position
    embeddings, runs vLLM's ``Attention`` layer, and projects the context back
    to ``hidden_size`` with a row-parallel output layer.
    """

    def __init__(
        self,
        config: PretrainedConfig,
        cache_config: CacheConfig | None = None,
        quant_config: QuantizationConfig | None = None,
        reduce_results: bool = True,
        prefix: str = "",
    ):
        """Build the QKV/output projections, optional QK norms and RoPE.

        Args:
            config: HF text config; reads head counts, ``head_dim``, bias
                flags, QK-norm flags and rotary settings.
            cache_config: Forwarded to the ``Attention`` layer.
            quant_config: Forwarded to the QKV and output projections.
            reduce_results: Forwarded to the output ``RowParallelLinear``.
            prefix: Module path used to name sub-layers.
        """
        super().__init__()
        self.hidden_size = config.hidden_size
        self.total_num_heads = config.num_attention_heads
        self.total_kv_heads = config.num_key_value_heads
        tp_size = get_tensor_model_parallel_world_size()

        assert self.total_num_heads % tp_size == 0
        assert self.total_num_heads >= self.total_kv_heads

        self.num_heads = self.total_num_heads // tp_size
        # `head_dim` may be None OR entirely absent from the config; in both
        # cases derive it from hidden_size / num_heads.
        self.head_dim = getattr(config, "head_dim", None) or (
            self.hidden_size // self.total_num_heads
        )
        self.q_size_per_rank = self.head_dim * self.num_heads
        # Never fewer than one KV head per rank, even when tp_size exceeds the
        # number of KV heads.
        self.num_kv_heads = max(1, self.total_kv_heads // tp_size)
        self.kv_size_per_rank = self.num_kv_heads * self.head_dim
        self.scale = self.head_dim**-0.5

        self.use_qk_norm = getattr(config, "use_qk_norm", False)
        self.use_rmsnorm = getattr(config, "use_rmsnorm", False)

        self.query_key_value = QKVParallelLinear(
            self.hidden_size,
            self.head_dim,
            self.total_num_heads,
            self.total_kv_heads,
            bias=(config.use_bias or getattr(config, "use_qkv_bias", False)),
            quant_config=quant_config,
            prefix=f"{prefix}.query_key_value",
        )

        if self.use_qk_norm:
            # Per-head normalization over head_dim; RMSNorm or LayerNorm
            # depending on the config flag.
            self.query_layernorm = (
                RMSNorm(self.head_dim, eps=config.rms_norm_eps)
                if self.use_rmsnorm
                else nn.LayerNorm(self.head_dim, eps=1e-6)
            )
            self.key_layernorm = (
                RMSNorm(self.head_dim, eps=config.rms_norm_eps)
                if self.use_rmsnorm
                else nn.LayerNorm(self.head_dim, eps=1e-6)
            )

        self.dense = RowParallelLinear(
            self.total_num_heads * self.head_dim,
            self.hidden_size,
            bias=config.use_bias,
            quant_config=quant_config,
            reduce_results=reduce_results,
            prefix=f"{prefix}.dense",
        )

        self.partial_rotary_factor = getattr(config, "partial_rotary_factor", 1.0)

        self.rotary_dim = getattr(config, "rotary_dim", self.head_dim)

        self.rotary_emb = get_rope(
            self.head_dim,
            rotary_dim=self.rotary_dim,
            max_position=config.max_position_embeddings,
            base=config.rope_theta,
            is_neox_style=True,
            rope_scaling=config.rope_scaling,
            partial_rotary_factor=self.partial_rotary_factor,
        )

        self.attn = Attention(
            self.num_heads,
            self.head_dim,
            self.scale,
            num_kv_heads=self.num_kv_heads,
            cache_config=cache_config,
            prefix=f"{prefix}.attn",
        )

    def forward(
        self,
        hidden_states: torch.Tensor,
        position_ids: torch.Tensor,
    ) -> torch.Tensor:
        """Run self-attention for ``hidden_states`` at ``position_ids``.

        Args:
            hidden_states: Token hidden states; assumed
                ``(num_tokens, hidden_size)`` — TODO confirm against caller.
            position_ids: Positions fed to the rotary embedding.

        Returns:
            The output projection of the attention context.
        """
        qkv, _ = self.query_key_value(hidden_states)
        q, k, v = qkv.split(
            [self.q_size_per_rank, self.kv_size_per_rank, self.kv_size_per_rank], dim=-1
        )

        if self.use_qk_norm:
            # Reshape to (tokens, heads, head_dim) so each head is normalized
            # independently, then flatten back.
            q = q.view(-1, self.num_heads, self.head_dim)
            k = k.view(-1, self.num_kv_heads, self.head_dim)
            q = self.query_layernorm(q)
            k = self.key_layernorm(k)
            q = q.view(-1, self.q_size_per_rank)
            k = k.view(-1, self.kv_size_per_rank)

        q, k = self.rotary_emb(position_ids, q, k)

        context_layer = self.attn(q, k, v)

        attn_output, _ = self.dense(context_layer)
        return attn_output


class BailingMLP(nn.Module):
    """Gated feed-forward network used for dense layers and shared experts.

    A single fused column-parallel projection produces the gate and up halves
    (each ``intermediate_size`` wide); ``SiluAndMul`` combines them and a
    row-parallel projection maps the result back to ``hidden_size``.
    """

    def __init__(
        self,
        intermediate_size: int,
        config: PretrainedConfig,
        quant_config: QuantizationConfig | None = None,
        reduce_results: bool | None = True,
        prefix: str = "",
    ) -> None:
        """Create the fused gate/up projection, activation and down projection.

        Args:
            intermediate_size: Width of each of the gate and up projections.
            config: HF text config; reads ``hidden_size`` and ``use_bias``.
            quant_config: Forwarded to both projections.
            reduce_results: Forwarded to the down projection.
            prefix: Module path used to name sub-layers.
        """
        super().__init__()
        model_width = config.hidden_size
        has_bias = config.use_bias

        self.gate_up_proj = MergedColumnParallelLinear(
            model_width,
            [intermediate_size, intermediate_size],
            bias=has_bias,
            quant_config=quant_config,
            prefix=f"{prefix}.gate_up_proj",
        )
        self.down_proj = RowParallelLinear(
            intermediate_size,
            model_width,
            bias=has_bias,
            quant_config=quant_config,
            reduce_results=reduce_results,
            prefix=f"{prefix}.down_proj",
        )
        self.act_fn = SiluAndMul()

    def forward(self, x):
        """Apply gate/up projection, gated activation, then down projection."""
        fused_gate_up, _unused_bias = self.gate_up_proj(x)
        gated = self.act_fn(fused_gate_up)
        projected, _unused_bias = self.down_proj(gated)
        return projected


class BailingMoE(nn.Module):
    """Sparse mixture-of-experts FFN with optional shared experts.

    A linear router (``gate``) produces per-expert logits; ``SharedFusedMoE``
    performs top-k routing and the routed-expert computation, and also runs
    the shared-expert MLP when one is configured. The routed output is scaled
    by ``routed_scaling_factor``, summed with the shared-expert output, and
    then reduced across tensor-parallel ranks.
    """

    def __init__(
        self,
        intermediate_size: int,
        config: PretrainedConfig,
        quant_config: QuantizationConfig | None = None,
        reduce_results: bool | None = True,
        prefix: str = "",
    ):
        """Build the router, the optional shared experts and the fused experts.

        Args:
            intermediate_size: Not read; accepted so this class can be built
                with the same positional arguments as ``BailingMLP``. Expert
                widths come from ``config.moe_intermediate_size``.
            config: HF text config with the MoE/routing settings.
            quant_config: Forwarded to the shared and routed experts.
            reduce_results: Not read; the cross-rank reduction is always done
                in ``forward`` when tensor parallelism is enabled.
            prefix: Module path used to name sub-layers.

        Raises:
            ValueError: If ``config.score_function`` is set and is not paired
                as (softmax, no expert bias) or (sigmoid, expert bias).
        """
        super().__init__()

        self.tp_size = get_tensor_model_parallel_world_size()
        self.tp_rank = get_tensor_model_parallel_rank()
        self.num_experts = config.num_experts
        self.top_k = config.num_experts_per_tok
        self.norm_expert_prob = config.norm_topk_prob
        self.hidden_size = config.hidden_size
        self.quant_config = quant_config
        self.num_shared_experts = config.num_shared_experts

        self.score_function = getattr(config, "score_function", None)
        self.n_group = getattr(config, "n_group", None)
        self.topk_group = getattr(config, "topk_group", None)
        # Grouped top-k routing is used only when both group settings exist.
        self.use_grouped_topk = self.n_group is not None and self.topk_group is not None
        self.routed_scaling_factor = getattr(config, "routed_scaling_factor", 1.0)

        # Router precision: None -> torch default dtype, "fp32" -> float32,
        # any other value -> bfloat16.
        router_dtype = getattr(config, "router_dtype", None)
        if router_dtype is None:
            self.router_dtype = None
        elif router_dtype == "fp32":
            self.router_dtype = torch.float32
        else:
            self.router_dtype = torch.bfloat16

        self.gate = nn.Linear(
            self.hidden_size,
            self.num_experts,
            bias=False,
            dtype=self.router_dtype,
        )

        # Optional per-expert routing bias, registered on the gate module as
        # `gate.expert_bias` and kept in float32.
        if getattr(config, "moe_router_enable_expert_bias", False):
            self.gate.expert_bias = nn.Parameter(
                torch.empty((config.num_experts,), dtype=torch.float32)
            )
        else:
            self.gate.expert_bias = None

        self.correction_bias = (
            self.gate.expert_bias.data if self.gate.expert_bias is not None else None
        )

        if self.score_function is not None:
            # Explicit check rather than `assert`, so an invalid config is
            # still rejected when Python runs with -O.
            valid_combination = (
                self.score_function == "softmax" and self.correction_bias is None
            ) or (
                self.score_function == "sigmoid" and self.correction_bias is not None
            )
            if not valid_combination:
                raise ValueError(
                    "score_function and correction_bias should be in 2 combination (softmax, None) or (sigmoid, not None)"  # noqa: E501
                )
        else:
            # default value for scoring_func
            self.score_function = "softmax"

        if self.num_shared_experts > 0:
            # Fall back to the routed-expert width when the shared width is
            # missing or explicitly None.
            shared_intermediate_size = (
                getattr(config, "moe_shared_expert_intermediate_size", None)
                or config.moe_intermediate_size
            )
            # All shared experts are fused into one wider MLP.
            shared_intermediate_size *= config.num_shared_experts
            # reduce_results=False: the shared output is added to the routed
            # output first and the sum is reduced once in forward().
            self.shared_experts = BailingMLP(
                intermediate_size=shared_intermediate_size,
                config=config,
                quant_config=quant_config,
                reduce_results=False,
                prefix=f"{prefix}.shared_experts",
            )
        else:
            self.shared_experts = None

        self.experts = SharedFusedMoE(
            shared_experts=self.shared_experts,
            num_experts=self.num_experts,
            top_k=self.top_k,
            hidden_size=self.hidden_size,
            intermediate_size=config.moe_intermediate_size,
            reduce_results=False,
            renormalize=self.norm_expert_prob,
            quant_config=quant_config,
            prefix=f"{prefix}.experts",
            scoring_func=self.score_function,
            e_score_correction_bias=self.gate.expert_bias,
            num_expert_group=self.n_group,
            topk_group=self.topk_group,
            use_grouped_topk=self.use_grouped_topk,
        )

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        """Route tokens through the experts and combine with shared experts.

        Args:
            hidden_states: 2-D tensor ``(num_tokens, hidden_size)``.

        Returns:
            Tensor with the same ``(num_tokens, hidden_size)`` shape.
        """
        num_tokens, hidden_size = hidden_states.shape
        hidden_states = hidden_states.view(-1, hidden_size)

        # router_logits: (num_tokens, n_experts). Computed in the router's
        # dtype, then cast back to the activation dtype.
        router_logits = self.gate(hidden_states.to(self.router_dtype))
        router_logits = router_logits.to(hidden_states.dtype)

        final_hidden_states = self.experts(
            hidden_states=hidden_states, router_logits=router_logits
        )

        # With shared experts configured, the experts layer returns a
        # (shared_output, routed_output) pair.
        if self.shared_experts is not None:
            shared_output, final_hidden_states = final_hidden_states
        else:
            shared_output = None

        # Only the routed contribution is scaled.
        final_hidden_states *= self.routed_scaling_factor

        if shared_output is not None:
            final_hidden_states = final_hidden_states + shared_output

        # Both expert paths were built with reduce_results=False, so the
        # combined output is reduced once here.
        if self.tp_size > 1:
            final_hidden_states = self.experts.maybe_all_reduce_tensor_model_parallel(
                final_hidden_states
            )
        return final_hidden_states.view(num_tokens, hidden_size)


class BailingMoeBlock(nn.Module):
    """One Bailing decoder layer: attention followed by a dense MLP or MoE.

    Layers whose index is below ``config.first_k_dense_replace`` use a dense
    ``BailingMLP``; all later layers use ``BailingMoE``. Each sub-layer is
    preceded by an RMSNorm that is also given the running residual and
    returns the updated ``(hidden_states, residual)`` pair.
    """

    def __init__(
        self,
        config: PretrainedConfig,
        cache_config: CacheConfig | None = None,
        quant_config: QuantizationConfig | None = None,
        prefix: str = "",
    ):
        """Build norms, attention and the MLP/MoE for one layer.

        Args:
            config: HF text config.
            cache_config: Forwarded to the attention layer.
            quant_config: Forwarded to attention and the MLP/MoE.
            prefix: Module path whose last dot-separated component is the
                integer layer index (e.g. ``"model.layers.3"``).
        """
        super().__init__()
        layer_idx = int(prefix.split(".")[-1])
        self.config = config
        hidden_size = config.hidden_size
        intermediate_size = config.intermediate_size

        self.input_layernorm = RMSNorm(hidden_size, eps=config.rms_norm_eps)
        self.attention = BailingAttention(
            config, cache_config, quant_config, prefix=f"{prefix}.attention"
        )

        self.post_attention_layernorm = RMSNorm(hidden_size, eps=config.rms_norm_eps)

        # Choose MLP class based on the layer index. Configs without
        # `first_k_dense_replace` use MoE in every layer.
        first_k_dense_replace = getattr(config, "first_k_dense_replace", 0)
        if layer_idx < first_k_dense_replace:
            mlp_class = BailingMLP
        else:
            mlp_class = BailingMoE
        self.mlp = mlp_class(
            intermediate_size, config, quant_config, True, prefix=f"{prefix}.mlp"
        )

    def forward(
        self,
        hidden_states: torch.Tensor,
        position_ids: torch.Tensor,
        residual: torch.Tensor | None,
    ) -> tuple[torch.Tensor, torch.Tensor]:
        """Apply attention and MLP/MoE to one layer's input.

        Args:
            hidden_states: Output of the previous layer (or the embeddings).
            position_ids: Positions forwarded to attention.
            residual: Running residual, or None when no residual exists yet.

        Returns:
            ``(hidden_states, residual)`` for the next layer; the residual
            add for the MLP output is left to the next norm.
        """
        if residual is None:
            # No residual yet: start the residual stream from the input.
            residual = hidden_states
            hidden_states = self.input_layernorm(hidden_states)
        else:
            hidden_states, residual = self.input_layernorm(hidden_states, residual)

        hidden_states = self.attention(
            hidden_states=hidden_states,
            position_ids=position_ids,
        )

        hidden_states, residual = self.post_attention_layernorm(hidden_states, residual)
        hidden_states = self.mlp(hidden_states)
        return hidden_states, residual


@support_torch_compile
class BailingMoeModel(nn.Module):
    """Bailing MoE decoder stack: embeddings, decoder layers and final norm.

    Pipeline-parallel aware:
    - The embedding is built on the first PP rank, and also on the last rank
      when word embeddings are tied (``PPMissingLayer`` elsewhere).
    - Each rank builds only its ``[start_layer, end_layer)`` slice of layers.
    - The final norm lives on the last rank.
    - Non-last ranks return ``IntermediateTensors``.
    """

    def __init__(
        self,
        *,
        vllm_config: VllmConfig,
        prefix: str = "",
    ):
        """Build embeddings, this rank's decoder layers and the final norm.

        Args:
            vllm_config: Provides the HF config, cache config and quant config.
            prefix: Module path; sub-layer names are joined with
                ``maybe_prefix`` so an empty prefix yields no leading ".".
        """
        super().__init__()
        config = vllm_config.model_config.hf_config
        cache_config = vllm_config.cache_config
        quant_config = vllm_config.quant_config

        self.config = config
        self.vocab_size = config.vocab_size
        self.embed_dim = config.hidden_size
        self.tie_word_embeddings = getattr(config, "tie_word_embeddings", False)

        # The last rank also needs the embedding when the LM head reuses it.
        if get_pp_group().is_first_rank or (
            self.tie_word_embeddings and get_pp_group().is_last_rank
        ):
            self.word_embeddings = VocabParallelEmbedding(
                self.vocab_size,
                self.embed_dim,
                quant_config=quant_config,
                prefix=maybe_prefix(prefix, "word_embeddings"),
            )
        else:
            self.word_embeddings = PPMissingLayer()

        # NOTE(review): embedding_dropout is constructed but never applied in
        # forward() below.
        self.embedding_dropout = torch.nn.Dropout(config.embedding_dropout)

        self.start_layer, self.end_layer, self.layers = make_layers(
            config.num_hidden_layers,
            lambda prefix: BailingMoeBlock(
                config=config,
                cache_config=cache_config,
                quant_config=quant_config,
                prefix=prefix,
            ),
            prefix=maybe_prefix(prefix, "layers"),
        )

        self.make_empty_intermediate_tensors = make_empty_intermediate_tensors_factory(
            ["hidden_states", "residual"], config.hidden_size
        )

        if get_pp_group().is_last_rank:
            self.norm = RMSNorm(self.embed_dim, eps=config.rms_norm_eps)
        else:
            self.norm = PPMissingLayer()

    def embed_input_ids(self, input_ids: torch.Tensor) -> torch.Tensor:
        """Look up token embeddings for ``input_ids``."""
        return self.word_embeddings(input_ids)

    def forward(
        self,
        input_ids: torch.Tensor,
        position_ids: torch.Tensor,
        intermediate_tensors: IntermediateTensors | None,
        inputs_embeds: torch.Tensor | None = None,
    ) -> torch.Tensor | IntermediateTensors:
        """Run this rank's decoder layers.

        On the first PP rank the input comes from ``inputs_embeds`` (if given)
        or from embedding ``input_ids``. Other ranks read ``hidden_states`` and
        ``residual`` from ``intermediate_tensors``.

        Returns:
            ``IntermediateTensors`` on non-last ranks; otherwise the final
            normalized hidden states.
        """
        if get_pp_group().is_first_rank:
            if inputs_embeds is not None:
                hidden_states = inputs_embeds
            else:
                hidden_states = self.embed_input_ids(input_ids)
            residual = None
        else:
            assert intermediate_tensors is not None
            hidden_states = intermediate_tensors["hidden_states"]
            residual = intermediate_tensors["residual"]

        for layer in islice(self.layers, self.start_layer, self.end_layer):
            hidden_states, residual = layer(
                hidden_states,
                position_ids,
                residual,
            )

        if not get_pp_group().is_last_rank:
            return IntermediateTensors(
                {"hidden_states": hidden_states, "residual": residual}
            )
        else:
            # Residual is None only if this rank ran no layers at all.
            if residual is None:
                hidden_states = self.norm(hidden_states)
            else:
                hidden_states, _ = self.norm(hidden_states, residual)
        return hidden_states

    def get_expert_mapping(self) -> list[tuple[str, str, int, str]]:
        """Return the checkpoint-to-fused-expert parameter mapping."""
        return SharedFusedMoE.make_expert_params_mapping(
            ckpt_gate_proj_name="gate_proj",
            ckpt_down_proj_name="down_proj",
            ckpt_up_proj_name="up_proj",
            num_experts=self.config.num_experts,
        )

    def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]:
        """Load checkpoint tensors into this module's parameters.

        Each weight is handled by the first matching path:
        (1) dense ``gate_proj``/``up_proj`` shards stacked into
        ``gate_up_proj``; (2) per-expert weights matched through
        ``get_expert_mapping``; (3) everything else, loaded directly.
        Parameters flagged by ``is_pp_missing_parameter`` are skipped.

        Returns:
            The set of (possibly renamed) parameter names that were loaded.
        """
        stacked_params_mapping = [
            # (param_name, shard_name, shard_id)
            ("gate_up_proj", "gate_proj", 0),
            ("gate_up_proj", "up_proj", 1),
        ]

        params_dict = dict(self.named_parameters(remove_duplicate=False))
        loaded_params: set[str] = set()
        expert_params_mapping = self.get_expert_mapping()
        for name, loaded_weight in weights:
            # NOTE(review): this module owns no lm_head; confirm whether
            # "lm_head.weight" is ever routed to this loader, otherwise this
            # norm_head normalization never runs.
            if (
                hasattr(self.config, "norm_head")
                and self.config.norm_head
                and "lm_head.weight" in name
            ):
                loaded_weight = F.normalize(loaded_weight, dim=0, p=2, eps=1e-7)

            for param_name, weight_name, shard_id in stacked_params_mapping:
                if weight_name not in name:
                    continue
                # Expert weights are handled by expert_params_mapping below.
                if "mlp.experts" in name:
                    continue
                name = name.replace(weight_name, param_name)
                # Skip loading extra bias for GPTQ models.
                if name.endswith(".bias") and name not in params_dict:
                    continue
                if name not in params_dict:
                    continue

                if is_pp_missing_parameter(name, self):
                    continue

                param = params_dict[name]
                weight_loader = param.weight_loader
                weight_loader(param, loaded_weight, shard_id)
                break
            else:
                for mapping in expert_params_mapping:
                    param_name, weight_name, expert_id, shard_id = mapping
                    if weight_name not in name:
                        continue
                    name = name.replace(weight_name, param_name)

                    if is_pp_missing_parameter(name, self):
                        continue

                    if name not in params_dict:
                        continue

                    param = params_dict[name]
                    weight_loader = param.weight_loader
                    weight_loader(
                        param,
                        loaded_weight,
                        name,
                        shard_id=shard_id,
                        expert_id=expert_id,
                    )
                    break
                else:
                    if name.endswith(".bias") and name not in params_dict:
                        continue
                    if name not in params_dict:
                        continue

                    if is_pp_missing_parameter(name, self):
                        continue

                    param = params_dict[name]
                    weight_loader = getattr(
                        param, "weight_loader", default_weight_loader
                    )
                    weight_loader(param, loaded_weight)
            loaded_params.add(name)
        return loaded_params


class BailingMoeForCausalLM(nn.Module, SupportsPP, SupportsLoRA):
    """Causal language model: ``BailingMoeModel`` plus an LM head.

    The LM head is either a separate ``ParallelLMHead`` or, when
    ``tie_word_embeddings`` is set, the model's word embedding. Both the LM
    head and the logits processor exist only on the last PP rank.
    """

    packed_modules_mapping = {
        "query_key_value": ["query_key_value"],
        "gate_up_proj": [
            "gate_proj",
            "up_proj",
        ],
    }

    def __init__(
        self,
        *,
        vllm_config: VllmConfig,
        prefix: str = "",
    ) -> None:
        """Build the decoder model and, on the last PP rank, the LM head.

        Args:
            vllm_config: Engine config. Its ``model_config.hf_config`` is
                replaced in place with ``get_text_config()`` so the inner
                model reads the text config.
            prefix: Module path; sub-module names are joined with
                ``maybe_prefix`` so an empty prefix yields ``"lm_head"``
                rather than ``".lm_head"``.
        """
        super().__init__()

        config = vllm_config.model_config.hf_config.get_text_config()
        vllm_config.model_config.hf_config = config
        quant_config = vllm_config.quant_config

        self.config = config
        self.quant_config = quant_config
        self.max_position_embeddings = config.max_position_embeddings
        self.model = BailingMoeModel(
            vllm_config=vllm_config, prefix=maybe_prefix(prefix, "model")
        )
        self.tie_word_embeddings = getattr(config, "tie_word_embeddings", False)

        if get_pp_group().is_last_rank:
            if self.tie_word_embeddings:
                self.lm_head = self.model.word_embeddings
            else:
                self.lm_head = ParallelLMHead(
                    config.vocab_size,
                    config.hidden_size,
                    quant_config=quant_config,
                    prefix=maybe_prefix(prefix, "lm_head"),
                )
            self.logits_processor = LogitsProcessor(config.vocab_size)
        else:
            self.lm_head = PPMissingLayer()

        self.make_empty_intermediate_tensors = (
            self.model.make_empty_intermediate_tensors
        )

    def embed_input_ids(self, input_ids: torch.Tensor) -> torch.Tensor:
        """Look up token embeddings via the inner model."""
        return self.model.embed_input_ids(input_ids)

    def forward(
        self,
        input_ids: torch.Tensor,
        positions: torch.Tensor,
        intermediate_tensors: IntermediateTensors | None = None,
        inputs_embeds: torch.Tensor | None = None,
    ) -> torch.Tensor | IntermediateTensors:
        """Run the decoder; returns hidden states or PP intermediate tensors."""
        model_output = self.model(
            input_ids, positions, intermediate_tensors, inputs_embeds
        )
        return model_output

    def compute_logits(
        self,
        hidden_states: torch.Tensor,
    ) -> torch.Tensor | None:
        """Project final hidden states to vocabulary logits (last PP rank)."""
        logits = self.logits_processor(self.lm_head, hidden_states)
        return logits

    def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]:
        """Load checkpoint weights; skip ``lm_head.*`` when embeddings are tied."""
        loader = AutoWeightsLoader(
            self,
            skip_prefixes=(["lm_head."] if self.tie_word_embeddings else None),
        )
        return loader.load_weights(weights)

    def get_expert_mapping(self) -> list[tuple[str, str, int, str]]:
        """Return the inner model's checkpoint-to-expert parameter mapping."""
        return self.model.get_expert_mapping()


class BailingMoeV2ForCausalLM(BailingMoeForCausalLM):
    """Separate class name for Bailing MoE V2 models.

    Adds nothing of its own; the implementation is inherited unchanged from
    ``BailingMoeForCausalLM``.
    """