megatron.py
r"""
The adaptor to seamlessly enable FastMoE in Megatron-LM v2.0 with at most two
lines of modification.
See `examples/megatron` for usage instructions.
"""
import os
import math
import numpy as np
import random
from collections import OrderedDict
import torch
import torch.nn as nn
import torch.nn.functional as F

from .transformer import FMoETransformerMLP
from .distributed import DistributedGroupedDataParallel


class _FakeMegatronMLP(nn.Module):
    r"""
    A fake mlp without model parallelism for correctness testing
    """

    def __init__(self, args, _):
        super().__init__()
        self.fc1 = nn.Linear(args.hidden_size, args.hidden_hidden_size)
        self.fc2 = nn.Linear(args.hidden_hidden_size, args.hidden_size)

    def forward(self, x):
        r"""
        Directly use GeLU
        """
        x = self.fc1(x)
        x = F.gelu(x)
        x = self.fc2(x)
        return x, torch.zeros_like(x)


def _megatron_init_method(self, rng, sigma):
    r"""
    Init method based on N(0, sigma).
    Copied from Megatron-LM
    """
    device = self.weight.device
    dtype = self.weight.dtype
    weight = rng.normal(loc=0.0, scale=sigma, size=tuple(self.weight.size()))
    self.weight.data = torch.tensor(weight, dtype=dtype, device=device)

    if self.bias is not None:
        # Always initialize bias to zero.
        with torch.no_grad():
            self.bias.zero_()


def _random_init_weight(self, rng):
    r"""
    Copied from torch.nn.init.kaiming_uniform_
    """
    fan = nn.init._calculate_correct_fan(self.weight[0], "fan_in")
    gain = nn.init.calculate_gain("leaky_relu", math.sqrt(5))
    std = gain / math.sqrt(fan)
    bound = math.sqrt(3.0) * std
    device = self.weight.device
    dtype = self.weight.dtype
    weight = rng.uniform(-bound, bound, size=tuple(self.weight.size()))
    self.weight.data = torch.tensor(weight, dtype=dtype, device=device)

    if self.bias is not None:
        fan_in, _ = nn.init._calculate_fan_in_and_fan_out(self.weight[0])
        bound = 1 / math.sqrt(fan_in)
        bias = rng.uniform(-bound, bound, size=tuple(self.bias.size()))
        self.bias.data = torch.tensor(bias, dtype=dtype, device=device)


class MegatronMLP(FMoETransformerMLP):
    r"""
    An FMoETransformerMLP layer that distributes experts across the
    communication group `group`, used to replace the original MLP layer in
    Megatron.
    """

    def __init__(self, args, group):
        assert (
            args.seq_length * args.micro_batch_size % args.tensor_model_parallel_size
            == 0
        ), "Batch size x sequence length should be multiple of mp size"
        if not args.distributed_experts:
            world_size = 1
        else:
            world_size = args.world_size
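        # With distributed experts, expert parameters are tagged with
        # dp_comm="none" so that the FastMoE DDP wrapper can exclude them from
        # the ordinary data-parallel gradient all-reduce; otherwise they are
        # treated as regular data-parallel ("dp") weights.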
        super().__init__(
            args.num_experts,
            top_k=args.top_k,
            d_model=args.hidden_size,
            d_hidden=args.hidden_hidden_size,
            world_size=world_size,
            mp_group=group,
            expert_dp_comm="none" if args.distributed_experts else "dp",
        )
        self.hidden_size = args.hidden_size
        if args.distributed_experts:
            self.rank = args.rank
        else:
            self.rank = 0
        self.sigma = args.init_method_std
        self.num_layers = args.num_layers
        self.reset_parameters()

    def reset_parameters(self):
        r"""
        Initialize the weights in the same way as Megatron's linear layers.
        Since Megatron uses a fixed random seed for its own initialization, an
        additional numpy RNG is used here.
        """
        rng = np.random.default_rng(np.random.randint(2048) + self.rank)
        _megatron_init_method(self.experts.htoh4, rng, self.sigma)
        std = self.sigma / math.sqrt(2.0 * self.num_layers)
        _megatron_init_method(self.experts.h4toh, rng, std)

    def forward(self, inp):
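        r"""
        Return the MoE output together with an all-zero "bias" tensor, so the
        layer keeps the same (output, bias) interface as Megatron's MLP.
        """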
        return (
            super().forward(inp),
            torch.zeros(self.hidden_size, dtype=inp.dtype, device=inp.device),
        )


def fmoefy(
    model,
    num_experts=None,
    distributed_experts=True,
    hidden_hidden_size=None,
    top_k=None,
):
    r"""
    Replace the MLP layers of a Megatron transformer-based model with MoE layers.
    * `model` should be a standard Megatron model that has
    `model.language_model.transformer.layers` as its transformer layers, i.e. an
    array of transformer blocks that each contain an `mlp` member.
    * `distributed_experts` is set to True if different experts are located on
    different workers. Otherwise, the experts on all workers are identical, and
    they are trained in data-parallel mode. This can be useful when testing on
    small models that do not require high training throughput or large parameter
    capacity.
    Note that pipeline parallelism is not supported yet. When distributed experts
    are enabled, their communicator should be Megatron's
    tensor_model_parallel_comm x data_parallel_comm, which is not created.
    """
    from megatron import get_args
    from megatron import mpu

    args = get_args()
    if num_experts is not None:
        args.num_experts = num_experts
    assert (
        "num_experts" in args
    ), "num_experts should be specified in arguments or fmoefy function"

    if hidden_hidden_size is not None:
        args.hidden_hidden_size = hidden_hidden_size
    elif not hasattr(args, "hidden_hidden_size"):
        args.hidden_hidden_size = args.hidden_size * 4

    if top_k is not None:
        args.top_k = top_k
    elif not hasattr(args, "top_k"):
        args.top_k = 2

    # Pass distributed_experts=None to keep the default setting in args
    if distributed_experts is not None:
        args.distributed_experts = distributed_experts

    for l in model.language_model.transformer.layers:
        l.mlp = MegatronMLP(args, mpu.get_model_parallel_group())
    return model
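
# A minimal usage sketch (an assumption following `examples/megatron`, not a
# verbatim recipe): inside a Megatron-LM v2.0 pretraining script, the "two
# lines of modification" would roughly be
#
#     from fmoe.megatron import fmoefy, DistributedDataParallel
#
#     model = fmoefy(model, num_experts=4)
#     model = DistributedDataParallel(model)
#
# where `model` is the object returned by Megatron's model provider, and the
# FastMoE DDP wrapper below replaces Megatron's own DistributedDataParallel.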


class DistributedDataParallel(DistributedGroupedDataParallel):
    r"""
    A wrapper used to replace the DDP module provided by Megatron, adapted to
    enable the sophisticated parallel and reduction strategies of FastMoE.
    """

    def __init__(self, module):
        from megatron import mpu

        super().__init__(
            module,
            mp_group=mpu.get_model_parallel_group(),
            dp_group=mpu.get_data_parallel_group(),
        )

    def state_dict(self, *args, **kwargs):
        r"""
        Keep consistency with Megatron
        """
        return self.module.state_dict(*args, **kwargs)

    def state_dict_for_save_checkpoint(self, *args, **kwargs):
        r"""
        Keep consistency with Megatron
        """
        return self.module.state_dict_for_save_checkpoint(*args, **kwargs)

    def load_state_dict(self, *args, **kwargs):
        r"""
        Keep consistency with Megatron
        """
        return self.module.load_state_dict(*args, **kwargs)

def get_checkpoint_name(checkpoints_path, iteration,
                        release=False):
    """A unified checkpoint name."""
    from megatron import mpu

    if release:
        directory = 'release'
    else:
        directory = 'iter_{:07d}'.format(iteration)
    # Use both the tensor and pipeline MP rank.
    if mpu.get_pipeline_model_parallel_world_size() == 1:
        return os.path.join(checkpoints_path, directory,
                            'mp_rank_{:02d}_dp_rank_{:04d}'.format(
                                mpu.get_tensor_model_parallel_rank(),
                                mpu.get_data_parallel_rank()
                                ),
                            'model_optim_rng.pt')
    return os.path.join(checkpoints_path, directory,
                        'mp_rank_{:02d}_{:03d}_dp_rank_{:04d}'.format(
                            mpu.get_tensor_model_parallel_rank(),
                            mpu.get_pipeline_model_parallel_rank(),
                            mpu.get_data_parallel_rank()
                            ),
                        'model_optim_rng.pt')
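
# For example, given the format strings above and with pipeline parallelism
# disabled, tensor-parallel rank 0 / data-parallel rank 3 at iteration 1000
# is saved to
#     <checkpoints_path>/iter_0001000/mp_rank_00_dp_rank_0003/model_optim_rng.pt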

def save_checkpoint(iteration, model, optimizer, lr_scheduler):
    """Save a model checkpoint with expert parallel """
    from megatron import get_args
    from megatron import mpu

    args = get_args()

    # Unlike stock Megatron, every data-parallel rank writes to disk here:
    # rank 0 of each data-parallel group saves the full model, while the
    # other ranks save only their expert parameters (see below).
    if isinstance(model, DistributedDataParallel):
        model = model.module

    if torch.distributed.get_rank() == 0:
        print('saving checkpoint at iteration {:7d} to {}'.format(
            iteration, args.save), flush=True)

    data_parallel_rank = mpu.get_data_parallel_rank()

    # Arguments, iteration, and model.
    state_dict = {}
    state_dict['args'] = args
    state_dict['checkpoint_version'] = 3.0
    state_dict['iteration'] = iteration
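    # On data-parallel ranks other than 0, keep the parameter objects in the
    # state dict (keep_vars=True) so that their `dp_comm` attribute is still
    # available for filtering out the expert parameters below.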
    keep_vars = data_parallel_rank != 0
    state_dict['model'] = model.state_dict_for_save_checkpoint(keep_vars=keep_vars)

    if data_parallel_rank != 0:

        def extract_expert_param(state_dict, expert_dp_comm='none'):
            state_dict_new = state_dict.__class__()
            for k, v in state_dict.items():
                # megatron uses both dict and OrderedDict in its state_dict
                if isinstance(v, OrderedDict) or isinstance(v, dict):
                    v_new = extract_expert_param(v, expert_dp_comm)
                    if len(v_new):
                        state_dict_new[k] = v_new
                elif hasattr(v, 'dp_comm') and v.dp_comm == expert_dp_comm:
                    state_dict_new[k] = v.detach()
            return state_dict_new

        state_dict['model'] = extract_expert_param(state_dict['model'], 'none')

    # Optimizer stuff.
    if not args.no_save_optim:
        if optimizer is not None:
            state_dict['optimizer'] = optimizer.state_dict()
        if lr_scheduler is not None:
            state_dict['lr_scheduler'] = lr_scheduler.state_dict()

    # RNG states.
    if not args.no_save_rng:
        state_dict['random_rng_state'] = random.getstate()
        state_dict['np_rng_state'] = np.random.get_state()
        state_dict['torch_rng_state'] = torch.get_rng_state()
        state_dict['cuda_rng_state'] = torch.cuda.get_rng_state()
        state_dict['rng_tracker_states'] \
            = mpu.get_cuda_rng_tracker().get_states()

    # Save.
    checkpoint_name = get_checkpoint_name(args.save, iteration)
    from megatron.checkpointing import ensure_directory_exists
    from megatron.checkpointing import get_checkpoint_tracker_filename
    ensure_directory_exists(checkpoint_name)
    torch.save(state_dict, checkpoint_name)

    # Wait so everyone is done (necessary)
    torch.distributed.barrier()
    if torch.distributed.get_rank() == 0:
        print('  successfully saved checkpoint at iteration {:7d} to {}'.format(
            iteration, args.save), flush=True)
    # And update the latest iteration
    if torch.distributed.get_rank() == 0:
        tracker_filename = get_checkpoint_tracker_filename(args.save)
        with open(tracker_filename, 'w') as f:
            f.write(str(iteration))
    # Wait so everyone is done (not necessary)
    torch.distributed.barrier()
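
# A hedged sketch of how the helpers above could be used (an assumption, not
# something this module wires up by itself): call this save_checkpoint in
# place of Megatron's own `megatron.checkpointing.save_checkpoint` in the
# training script, e.g.
#
#     from fmoe.megatron import save_checkpoint
#     save_checkpoint(iteration, model, optimizer, lr_scheduler)
#
# so that each data-parallel rank persists its own expert parameters.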