moe.py
# Copyright 2025 the LlamaFactory team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from typing import TYPE_CHECKING, Union

from transformers.integrations import is_deepspeed_zero3_enabled

from ...extras.misc import check_version


if TYPE_CHECKING:
    from torch import nn

    from transformers import PretrainedConfig, PreTrainedModel

    from ...hparams import ModelArguments


def _set_z3_leaf_modules(model: "PreTrainedModel", leaf_modules: list[Union["nn.Module", str]]) -> None:
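    r"""Mark the given modules (classes or class names) as ZeRO-3 leaf modules, gathered as a whole by DeepSpeed."""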
    check_version("deepspeed>=0.13.0")
    from deepspeed.utils import set_z3_leaf_modules  # type: ignore

    set_z3_leaf_modules(model, leaf_modules)


def add_z3_leaf_module(model: "PreTrainedModel") -> None:
    r"""Set module as a leaf module to skip partitioning in deepspeed zero3."""
    if not is_deepspeed_zero3_enabled():
        return

    model_type = getattr(model.config, "model_type", None)
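    # Each branch below flags the architecture-specific MoE block as a ZeRO-3 leaf module.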
    if model_type == "dbrx":
        from transformers.models.dbrx.modeling_dbrx import DbrxFFN

        _set_z3_leaf_modules(model, [DbrxFFN])

    if model_type == "deepseek_v2":
        # deepseek v2 uses custom code
        _set_z3_leaf_modules(model, ["DeepseekV2MoE"])

    if model_type == "deepseek_v3" or model_type == "kimi_vl":
        # deepseek v3 and kimi vl use custom code
        _set_z3_leaf_modules(model, ["DeepseekV3MoE"])

    if model_type == "granitemoe":
        from transformers.models.granitemoe.modeling_granitemoe import GraniteMoeMoE

        _set_z3_leaf_modules(model, [GraniteMoeMoE])

    if model_type == "jamba":
        from transformers.models.jamba.modeling_jamba import JambaSparseMoeBlock

        _set_z3_leaf_modules(model, [JambaSparseMoeBlock])

    if model_type == "jetmoe":
        from transformers.models.jetmoe.modeling_jetmoe import JetMoeMoA, JetMoeMoE

        _set_z3_leaf_modules(model, [JetMoeMoA, JetMoeMoE])

    if model_type == "llama4":
        from transformers.models.llama4.modeling_llama4 import Llama4TextMoe

        _set_z3_leaf_modules(model, [Llama4TextMoe])

    if model_type == "mixtral":
        from transformers.models.mixtral.modeling_mixtral import MixtralSparseMoeBlock

        _set_z3_leaf_modules(model, [MixtralSparseMoeBlock])

    if model_type == "olmoe":
        from transformers.models.olmoe.modeling_olmoe import OlmoeSparseMoeBlock

        _set_z3_leaf_modules(model, [OlmoeSparseMoeBlock])

    if model_type == "phimoe":
        from transformers.models.phimoe.modeling_phimoe import PhimoeSparseMoeBlock

        _set_z3_leaf_modules(model, [PhimoeSparseMoeBlock])

    if model_type == "qwen2_moe":
        from transformers.models.qwen2_moe.modeling_qwen2_moe import Qwen2MoeSparseMoeBlock

        _set_z3_leaf_modules(model, [Qwen2MoeSparseMoeBlock])

    if model_type == "qwen3_moe":
        from transformers.models.qwen3_moe.modeling_qwen3_moe import Qwen3MoeSparseMoeBlock

        _set_z3_leaf_modules(model, [Qwen3MoeSparseMoeBlock])


def configure_moe(config: "PretrainedConfig", model_args: "ModelArguments", is_trainable: bool) -> None:
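    r"""Configure MoE-related options (router logit output and auxiliary loss coefficient) on the model config."""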
    model_type = getattr(config, "model_type", None)
    if model_args.moe_aux_loss_coef is not None:
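        # Expose router logits during training and map the coefficient to the attribute
        # name that each architecture expects for its load-balancing auxiliary loss.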
        if model_type in [
            "dbrx",
            "granitemoe",
            "jamba",
            "jetmoe",
            "llama4",
            "mixtral",
            "olmoe",
            "phimoe",
            "qwen2_moe",
            "qwen3_moe",
        ]:
            setattr(config, "output_router_logits", is_trainable)

        if model_type in ["granitemoe", "jamba", "llama4", "mixtral", "olmoe", "phimoe", "qwen2_moe", "qwen3_moe"]:
            setattr(config, "router_aux_loss_coef", model_args.moe_aux_loss_coef)

        elif model_type == "deepseek":
            setattr(config, "aux_loss_alpha", model_args.moe_aux_loss_coef)

        elif model_type == "jetmoe":
            setattr(config, "aux_loss_coef", model_args.moe_aux_loss_coef)