r"""
The adaptor that seamlessly enables FastMoE in Megatron-LM v2.0 with at most
two lines of modification.
See `examples/megatron` for usage instructions.
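
A minimal usage sketch of the two modifications, assuming this module is
importable as `fmoe.megatron` and `model` is the Megatron model built by the
pretraining script (the names here are illustrative):

    from fmoe.megatron import fmoefy, DistributedDataParallel

    model = fmoefy(model, num_experts=4)    # replace each MLP with an MoE layer
    model = DistributedDataParallel(model)  # replace Megatron's DDP wrapper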
"""
import math
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F

from .transformer import FMoETransformerMLP
from .distributed import DistributedGroupedDataParallel


class _FakeMegatronMLP(nn.Module):
    r"""
    A fake MLP without model parallelism, used for correctness testing
    """

    def __init__(self, args, _):
        super().__init__()
        self.fc1 = nn.Linear(args.hidden_size, args.hidden_hidden_size)
        self.fc2 = nn.Linear(args.hidden_hidden_size, args.hidden_size)

    def forward(self, x):
        r"""
        Directly use GeLU without Megatron's tensor parallelism
        """
        x = self.fc1(x)
        x = F.gelu(x)
        x = self.fc2(x)
        return x, torch.zeros_like(x)


def _megatron_init_method(self, rng, sigma):
    r"""
    Init method based on N(0, sigma).
    Copied from Megatron-LM
    """
    device = self.weight.device
    dtype = self.weight.dtype
    weight = rng.normal(loc=0.0, scale=sigma, size=tuple(self.weight.size()))
    self.weight.data = torch.from_numpy(weight).to(dtype=dtype, device=device)

    if self.bias is not None:
        # Always initialize bias to zero.
        with torch.no_grad():
            self.bias.zero_()


def _random_init_weight(self, rng):
    r"""
    Copied from torch.nn.init.kaiming_uniform_
    """
    fan = nn.init._calculate_correct_fan(self.weight[0], "fan_in")
    gain = nn.init.calculate_gain("leaky_relu", math.sqrt(5))
    std = gain / math.sqrt(fan)
    bound = math.sqrt(3.0) * std
    device = self.weight.device
    dtype = self.weight.dtype
    weight = rng.uniform(-bound, bound, size=tuple(self.weight.size()))
    self.weight.data = torch.from_numpy(weight).to(dtype=dtype, device=device)

    if self.bias is not None:
        fan_in, _ = nn.init._calculate_fan_in_and_fan_out(self.weight[0])
        bound = 1 / math.sqrt(fan_in)
        bias = rng.uniform(-bound, bound, size=tuple(self.bias.size()))
        self.bias.data = torch.from_numpy(bias).to(dtype=dtype, device=device)


class MegatronMLP(FMoETransformerMLP):
    r"""
    An FMoETransformerMLP layer that distributes experts across the
    communication group `group`, replacing the original MLP layer in Megatron.
    """

    def __init__(self, args, group):
        assert (
            args.seq_length * args.micro_batch_size
            % args.tensor_model_parallel_size
            == 0
        ), "Batch size x sequence length should be multiple of mp size"
        if not args.distributed_experts:
            world_size = 1
        else:
            world_size = args.world_size
        super().__init__(
            args.num_experts,
            top_k=args.top_k,
            d_model=args.hidden_size,
            d_hidden=args.hidden_hidden_size,
            world_size=world_size,
            mp_group=group,
            expert_dp_comm="none" if args.distributed_experts else "dp",
        )
        self.hidden_size = args.hidden_size
        if args.distributed_experts:
            self.rank = args.rank
        else:
            self.rank = 0
        self.sigma = args.init_method_std
        self.num_layers = args.num_layers
        self.reset_parameters()

    def reset_parameters(self):
        r"""
        Initialize the weights as ordinary linear layers.
        Because Megatron fixes the random seed for its own initialization, an
        additional numpy rng is used here.
        """
        rng = np.random.default_rng(np.random.randint(2048) + self.rank)
        _megatron_init_method(self.experts.htoh4, rng, self.sigma)
        std = self.sigma / math.sqrt(2.0 * self.num_layers)
        _megatron_init_method(self.experts.h4toh, rng, std)

    def forward(self, inp):
        return (
            super().forward(inp),
            torch.zeros(self.hidden_size, dtype=inp.dtype, device=inp.device),
        )


def fmoefy(
    model,
    num_experts=None,
    distributed_experts=True,
    hidden_hidden_size=None,
    top_k=None,
):
    r"""
    Replace the MLP layers in a Megatron transformer-based model with MoE layers.
    * `model` should be a standard Megatron model that has
    `model.language_model.transformer.layers` as its transformer layers, i.e. an
    array of transformer blocks, each of which contains an `mlp` member.
    * `distributed_experts` is set to True if different experts are located on
    different workers. Otherwise, the experts on all workers are identical, and
    they are trained in data-parallel mode. This can be useful when testing on
    small models that do not require high training throughput or large parameter
    capacity.
    Note that pipeline parallelism is not supported yet. When distributed experts
    are enabled, their communicator should be Megatron's
    tensor_model_parallel_comm x data_parallel_comm, which is not created.
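
    A minimal call sketch; the argument values below are illustrative only::

        model = fmoefy(
            model,
            num_experts=8,  # must be given here or already present in args
            top_k=2,        # falls back to 2 when absent from args
        )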
    """
    from megatron import get_args
    from megatron import mpu

    args = get_args()
    if num_experts is not None:
        args.num_experts = num_experts
    assert (
        "num_experts" in args
    ), "num_experts should be specified in arguments or fmoefy function"

    if hidden_hidden_size is not None:
        args.hidden_hidden_size = hidden_hidden_size
    elif not hasattr(args, "hidden_hidden_size"):
        args.hidden_hidden_size = args.hidden_size * 4

    if top_k is not None:
        args.top_k = top_k
    elif not hasattr(args, "top_k"):
        args.top_k = 2

    # Set distributed_experts to None to use default setting in args
    if distributed_experts is not None:
        args.distributed_experts = distributed_experts

    for l in model.language_model.transformer.layers:
        l.mlp = MegatronMLP(args, mpu.get_model_parallel_group())
    return model


class DistributedDataParallel(DistributedGroupedDataParallel):
    r"""
    A wrapper that replaces the DDP module provided by Megatron, adapted to
    enable the sophisticated parallel and reduction strategies in FastMoE.
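
    A minimal sketch of replacing Megatron's wrapper, assuming `model` has
    already been processed by `fmoefy`::

        model = DistributedDataParallel(model)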
    """

    def __init__(self, module):
        from megatron import mpu

        super().__init__(
            module,
            mp_group=mpu.get_model_parallel_group(),
            dp_group=mpu.get_data_parallel_group(),
        )

    def state_dict(self, *args, **kwargs):
        r"""
        Keep consistency with Megatron
        """
        return self.module.state_dict(*args, **kwargs)

    def state_dict_for_save_checkpoint(self, *args, **kwargs):
        r"""
        Keep consistency with Megatron
        """
        return self.module.state_dict_for_save_checkpoint(*args, **kwargs)

    def load_state_dict(self, *args, **kwargs):
        r"""
        Keep consistency with Megatron
        """
        return self.module.load_state_dict(*args, **kwargs)