layers.py 7.38 KB
Newer Older
Rick Ho's avatar
Rick Ho committed
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
r"""
nn modules to replace Megatron's native ones
"""
import math
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F

from fmoe.transformer import FMoETransformerMLP
from .balance import reset_gate_hook
from .balance import generate_megatron_gate_hook


class _FakeMegatronMLP(nn.Module):
    r"""
    Plain two-layer MLP without any model parallelism, used as a reference
    implementation in correctness tests.

    Like `MegatronMLP.forward`, it returns an ``(output, bias)`` pair; here
    the bias is always a zero tensor shaped like the output.
    """

    def __init__(self, args, _):
        super().__init__()
        d_model = args.hidden_size
        d_inner = args.hidden_hidden_size
        self.fc1 = nn.Linear(d_model, d_inner)
        self.fc2 = nn.Linear(d_inner, d_model)

    def forward(self, x):
        r"""
        Compute ``fc2(gelu(fc1(x)))`` and pair it with an all-zero bias.
        """
        out = self.fc2(F.gelu(self.fc1(x)))
        return out, torch.zeros_like(out)


def _megatron_init_method(self, rng, sigma):
    r"""
    Re-initialize ``self.weight`` with samples from N(0, sigma) and zero
    ``self.bias`` (when it is not None).
    Copied from Megatron-LM.

    * `self` is any module exposing ``weight`` and ``bias`` attributes.
    * `rng` is a numpy random generator; it is used instead of torch's RNG
      so the draw is independent of torch's global seed.
    * `sigma` is the standard deviation of the normal distribution.
    """
    param = self.weight
    target_dtype, target_device = param.dtype, param.device
    draws = rng.normal(loc=0.0, scale=sigma, size=tuple(param.size()))
    # numpy yields float64; cast back to the parameter's dtype/device.
    param.data = torch.from_numpy(draws).to(dtype=target_dtype, device=target_device)

    if self.bias is None:
        return
    # Always initialize bias to zero.
    with torch.no_grad():
        self.bias.zero_()


def _random_init_weight(self, rng):
    r"""
    Kaiming-uniform initialization with ``a = sqrt(5)``, drawing samples from
    the numpy generator `rng`. The weight follows
    ``torch.nn.init.kaiming_uniform_`` and the bias follows the default
    ``nn.Linear`` bias init.

    Fan-in is computed from the first slice ``self.weight[0]``. torch needs
    at least 2 dimensions for that, so ``self.weight`` is expected to have a
    leading stacked dimension (presumably one slice per expert — TODO
    confirm against callers). `self` must also expose ``bias``, which may be
    None.
    """
    first_slice = self.weight[0]
    fan = nn.init._calculate_correct_fan(first_slice, "fan_in")
    gain = nn.init.calculate_gain("leaky_relu", math.sqrt(5))
    weight_bound = math.sqrt(3.0) * (gain / math.sqrt(fan))

    dtype, device = self.weight.dtype, self.weight.device

    def _to_param_tensor(array):
        # numpy samples are float64; cast to the parameter's dtype/device.
        return torch.from_numpy(array).to(dtype=dtype, device=device)

    self.weight.data = _to_param_tensor(
        rng.uniform(-weight_bound, weight_bound, size=tuple(self.weight.size()))
    )

    if self.bias is None:
        return
    fan_in, _ = nn.init._calculate_fan_in_and_fan_out(first_slice)
    bias_bound = 1 / math.sqrt(fan_in)
    self.bias.data = _to_param_tensor(
        rng.uniform(-bias_bound, bias_bound, size=tuple(self.bias.size()))
    )


class MegatronMLP(FMoETransformerMLP):
    r"""
    FMoETransformerMLP layer that replaces the original MLP layer in Megatron.

    When `args.distributed_experts` is set, experts are distributed across the
    Megatron data-parallel group (world size `args.data_parallel_size`).
    Otherwise each worker holds identical experts, which are synchronized
    data-parallel style (``expert_dp_comm="dp"``).
    """

    def __init__(self, args, layer_idx, gate=None):
        r"""
        * `args` is the Megatron argument namespace. Fields read here:
          `distributed_experts`, `data_parallel_size`, `balance_strategy`,
          `num_experts`, `top_k`, `hidden_size`, `hidden_hidden_size`,
          `rank`, `init_method_std` and `num_layers`.
        * `layer_idx` is forwarded to `generate_megatron_gate_hook` to
          identify this layer's gate statistics.
        * `gate` is a gate class that is used only when
          `args.balance_strategy` is not one of the built-in strategy names.

        Raises `ValueError` when the strategy is unknown and no `gate` is
        given.
        """
        if args.distributed_experts:
            from megatron.mpu import get_data_parallel_group

            world_size = args.data_parallel_size
            moe_group = get_data_parallel_group()
        else:
            world_size = 1
            moe_group = None

        gate = self._resolve_gate(args.balance_strategy, gate)

        super().__init__(
            args.num_experts,
            top_k=args.top_k,
            d_model=args.hidden_size,
            d_hidden=args.hidden_hidden_size,
            world_size=world_size,
            moe_group=moe_group,
            expert_dp_comm="none" if args.distributed_experts else "dp",
            gate_hook=generate_megatron_gate_hook(
                layer_idx, args.num_experts * world_size
            ),
            gate=gate,
        )
        self.hidden_size = args.hidden_size
        self.rank = args.rank if args.distributed_experts else 0
        self.sigma = args.init_method_std
        self.num_layers = args.num_layers
        self.reset_parameters()

    @staticmethod
    def _resolve_gate(balance_strategy, gate):
        r"""
        Return the gate class for `balance_strategy`.

        An empty/None strategy means "naive". A recognized built-in strategy
        takes precedence over an explicitly passed `gate`. `gate` is only
        used when the strategy name is not built in.
        """
        if not balance_strategy or balance_strategy == "naive":
            from fmoe.gates import NaiveGate

            return NaiveGate
        if balance_strategy == "noisy":
            from fmoe.gates import NoisyGate

            return NoisyGate
        if balance_strategy == "gshard":
            from fmoe.gates import GShardGate

            return GShardGate
        if balance_strategy == "switch":
            from fmoe.gates import SwitchGate

            return SwitchGate
        if balance_strategy == "swipe":
            from fmoe.gates import SwipeGate

            return SwipeGate
        if gate is None:
            # An explicit raise survives `python -O`, unlike `assert`.
            raise ValueError(f"Undefined balance strategy {balance_strategy}")
        return gate

    def reset_parameters(self):
        r"""
        Initialize the weight as linear layers.
        As megatron is using fixed random seed for some nasty stuff, an
        additional numpy rng is used. Its seed is offset by `self.rank`, so
        distributed experts on different ranks get different weights.
        """
        rng = np.random.default_rng(np.random.randint(2048) + self.rank)
        if isinstance(self.experts, nn.ModuleList):
            experts = list(self.experts)
        else:
            experts = [self.experts]

        # All input projections are drawn before any output projection. This
        # keeps the rng stream consumed in the same order as before.
        for expert in experts:
            _megatron_init_method(expert.htoh4, rng, self.sigma)

        # Output projections use the depth-scaled std
        # sigma / sqrt(2 * num_layers).
        std = self.sigma / math.sqrt(2.0 * self.num_layers)
        for expert in experts:
            _megatron_init_method(expert.h4toh, rng, std)

    def forward(self, inp):
        r"""
        Run the MoE layer on `inp`. The result is passed through
        `mpu.reduce_from_tensor_model_parallel_region` (presumably an
        all-reduce over the tensor-parallel group — TODO confirm).

        Returns ``(output, bias)``. The bias is a zero tensor of shape
        ``(hidden_size,)`` with `inp`'s dtype and device.
        """
        from megatron import mpu

        x = super().forward(inp)
        x = mpu.reduce_from_tensor_model_parallel_region(x)
        return (
            x,
            torch.zeros(self.hidden_size, dtype=inp.dtype, device=inp.device),
        )


def fmoefy(
    model,
    num_experts=None,
    distributed_experts=True,
    hidden_hidden_size=None,
    top_k=None,
    gate=None,
    megatron_version=None,
):
    r"""
    Replace MLP layers in a transformer-based model in Megatron by MoE.

    * `model` should be a standard Megatron model. For
      ``megatron_version == "v2.2"`` the transformer blocks are read from
      `model.language_model.transformer.layers`. For ``"v2.5"`` they are
      read from `model.language_model.encoder.layers` and, when a non-None
      `decoder` exists, `model.language_model.decoder.layers`. Every block
      must have an `mlp` member, which is replaced by a `MegatronMLP`.
    * `num_experts` overrides `args.num_experts`. One of the two must provide
      a value, otherwise `ValueError` is raised.
    * `distributed_experts` is set to True if different experts are located
      in different workers. Otherwise, the experts on the workers are
      identical, and they are trained in data-parallel mode. This can be
      useful when testing on small models that do not require high training
      throughput or large parameter capacity. Pass None to keep the value
      already in the Megatron args.
    * `hidden_hidden_size` is the inner size of each expert's MLP. When None,
      an existing `args.hidden_hidden_size` is kept. If that is also missing,
      `4 * args.hidden_size` is used.
    * `top_k` overrides `args.top_k`. It defaults to 2 if neither is set.
    * `gate` is a gate class. It is used only when `args.balance_strategy` is
      not a built-in strategy name.
    * `megatron_version` must be "v2.2" or "v2.5". Anything else raises
      `ValueError`.

    Returns `model`, modified in place.
    """
    from megatron import get_args

    # Reject an unsupported layout before touching the global Megatron args.
    if megatron_version not in ("v2.2", "v2.5"):
        raise ValueError(f"megatron_version {megatron_version} not known.")

    args = get_args()

    # Set distributed_experts to None to use default setting in args
    if distributed_experts is not None:
        args.distributed_experts = distributed_experts

    if num_experts is not None:
        args.num_experts = num_experts
    if getattr(args, "num_experts", None) is None:
        raise ValueError(
            "num_experts should be specified in arguments or fmoefy function"
        )

    if top_k is not None:
        args.top_k = top_k
    elif not hasattr(args, "top_k"):
        args.top_k = 2

    # Previously None was written through unconditionally, which clobbered
    # any configured value and handed d_hidden=None to the experts.
    if hidden_hidden_size is not None:
        args.hidden_hidden_size = hidden_hidden_size
    elif getattr(args, "hidden_hidden_size", None) is None:
        args.hidden_hidden_size = args.hidden_size * 4

    language_model = model.language_model
    if megatron_version == "v2.2":
        layer_stacks = [language_model.transformer.layers]
    else:  # "v2.5"
        layer_stacks = [language_model.encoder.layers]
        # Newer layouts may define `decoder` but leave it as None.
        decoder = getattr(language_model, "decoder", None)
        if decoder is not None:
            layer_stacks.append(decoder.layers)

    # reset_gate_hook below receives the total layer count across all stacks.
    # Number the layers consecutively so decoder layers do not restart at 0
    # and reuse the encoder's gate-hook indices.
    all_layers = [layer for layers in layer_stacks for layer in layers]
    for idx, layer in enumerate(all_layers):
        layer.mlp = MegatronMLP(args, idx, gate=gate)

    # initialize gate hook
    reset_gate_hook(len(all_layers))

    return model