r"""
The adaptor to seamlessly enable FastMoE in Megatron-LM v2.0 with at most two
lines of modification.
See `examples/megatron` for usage instructions.
"""
import os
import sys
import math
import random
from collections import OrderedDict
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F

from .transformer import FMoETransformerMLP
from .distributed import DistributedGroupedDataParallel


class _FakeMegatronMLP(nn.Module):
    r"""
    A fake mlp without model parallelism for correctness testing
    """

    def __init__(self, args, _):
        super().__init__()
        self.fc1 = nn.Linear(args.hidden_size, args.hidden_hidden_size)
        self.fc2 = nn.Linear(args.hidden_hidden_size, args.hidden_size)

    def forward(self, x):
        r"""
        Directly use GeLU
        """
        x = self.fc1(x)
        x = F.gelu(x)
        x = self.fc2(x)
        return x, torch.zeros_like(x)


def _megatron_init_method(self, rng, sigma):
    r"""
    Init method based on N(0, sigma).
    Copied from Megatron-LM
    """
    device = self.weight.device
    dtype = self.weight.dtype
    weight = rng.normal(loc=0.0, scale=sigma, size=tuple(self.weight.size()))
    self.weight.data = torch.tensor(weight, dtype=dtype, device=device)

    if self.bias is not None:
        # Always initialize bias to zero.
        with torch.no_grad():
            self.bias.zero_()


def _random_init_weight(self, rng):
    r"""
    Copied from torch.nn.init.kaiming_uniform_
    """
    fan = nn.init._calculate_correct_fan(self.weight[0], "fan_in")
    gain = nn.init.calculate_gain("leaky_relu", math.sqrt(5))
    std = gain / math.sqrt(fan)
    bound = math.sqrt(3.0) * std
    device = self.weight.device
    dtype = self.weight.dtype
    weight = rng.uniform(-bound, bound, size=tuple(self.weight.size()))
    self.weight.data = torch.tensor(weight, dtype=dtype, device=device)

    if self.bias is not None:
        fan_in, _ = nn.init._calculate_fan_in_and_fan_out(self.weight[0])
        bound = 1 / math.sqrt(fan_in)
        bias = rng.uniform(-bound, bound, size=tuple(self.bias.size()))
        self.bias.data = torch.tensor(bias, dtype=dtype, device=device)


class MegatronMLP(FMoETransformerMLP):
    r"""
    Make the FMoETransformerMLP layer that distributes experts across
    communication group `group` to replace the original MLP layer in Megatron.
    """

    def __init__(self, args, group):
        assert (
            args.seq_length * args.micro_batch_size % args.tensor_model_parallel_size
            == 0
        ), "Batch size x sequence length should be multiple of mp size"
        if not args.distributed_experts:
            world_size = 1
        else:
            world_size = args.world_size
        super().__init__(
            args.num_experts,
            top_k=args.top_k,
            d_model=args.hidden_size,
            d_hidden=args.hidden_hidden_size,
            world_size=world_size,
            mp_group=group,
            expert_dp_comm="none" if args.distributed_experts else "dp",
        )
        self.hidden_size = args.hidden_size
        if args.distributed_experts:
            self.rank = args.rank
        else:
            self.rank = 0
        self.sigma = args.init_method_std
        self.num_layers = args.num_layers
        self.reset_parameters()

    def reset_parameters(self):
        r"""
        Initialize the weights in the same way as ordinary linear layers.
        As Megatron uses a fixed random seed for some nasty stuff, an
        additional numpy rng is used here.
        """
        rng = np.random.default_rng(np.random.randint(2048) + self.rank)
        _megatron_init_method(self.experts.htoh4, rng, self.sigma)
        std = self.sigma / math.sqrt(2.0 * self.num_layers)
        _megatron_init_method(self.experts.h4toh, rng, std)

    def forward(self, inp):
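        # Mimic Megatron's MLP interface, which returns (output, output_bias);
        # a zero tensor is returned here as the bias placeholder.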
        return (
            super().forward(inp),
            torch.zeros(self.hidden_size, dtype=inp.dtype, device=inp.device),
        )


def fmoefy(
    model,
    num_experts=None,
    distributed_experts=True,
    hidden_hidden_size=None,
    top_k=None,
):
    r"""
    Replace the MLP layers in a Megatron transformer-based model with MoE layers.
    * `model` should be a standard Megatron model that has
    `model.language_model.transformer.layers` as its transformer layers, i.e. an
    array of transformer blocks, each containing an `mlp` member.
    * `distributed_experts` is set to True if different experts are located on
    different workers. Otherwise, the experts on all workers are identical and
    are trained in data-parallel mode. This can be useful when testing on
    small models that do not require high training throughput or large parameter
    capacity.
    Note that pipeline parallelism is not supported yet. When distributed experts
    are enabled, their communicator should be Megatron's
    tensor_model_parallel_comm x data_parallel_comm, which is not created.
    """
    from megatron import get_args
    from megatron import mpu

    args = get_args()
    if num_experts is not None:
        args.num_experts = num_experts
    assert (
        "num_experts" in args
    ), "num_experts should be specified in arguments or fmoefy function"

    if hidden_hidden_size is not None:
        args.hidden_hidden_size = hidden_hidden_size
    elif not hasattr(args, "hidden_hidden_size"):
        args.hidden_hidden_size = args.hidden_size * 4

    if top_k is not None:
        args.top_k = top_k
    elif not hasattr(args, "top_k"):
        args.top_k = 2

    # Set distributed_experts to None to use default setting in args
    if distributed_experts is not None:
        args.distributed_experts = distributed_experts

    for l in model.language_model.transformer.layers:
        l.mlp = MegatronMLP(args, mpu.get_model_parallel_group())
    return model
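
# For illustration (hypothetical values, assuming FastMoE's convention that
# `num_experts` counts experts per worker): with 8 workers and
# distributed_experts=True, `fmoefy(model, num_experts=4, top_k=2)` places 4
# experts on every worker (32 experts in total) and routes each token to its
# top-2 experts.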


class DistributedDataParallel(DistributedGroupedDataParallel):
    r"""
    A wrapper used to replace the DDP module provided by Megatron, adapted to
    enable the sophisticated parallel and reduction strategies in FastMoE.
    """

    def __init__(self, module):
        from megatron import mpu

        super().__init__(
            module,
            mp_group=mpu.get_model_parallel_group(),
            dp_group=mpu.get_data_parallel_group(),
        )

    def state_dict(self, *args, **kwargs):
        r"""
        Keep consistency with Megatron
        """
        return self.module.state_dict(*args, **kwargs)

    def state_dict_for_save_checkpoint(self, *args, **kwargs):
        r"""
        Keep consistency with Megatron
        """
        return self.module.state_dict_for_save_checkpoint(*args, **kwargs)

    def load_state_dict(self, *args, **kwargs):
        r"""
        Keep consistency with Megatron
        """
        return self.module.load_state_dict(*args, **kwargs)


def get_fmoe_checkpoint_name(
    checkpoints_path, iteration, release=False, data_parallel_rank=-1
):
    """A unified checkpoint name, allowing specifying a data parallel rank"""
    from megatron import mpu
    from megatron.checkpointing import get_checkpoint_name

    if data_parallel_rank == -1:
        data_parallel_rank = mpu.get_data_parallel_rank()
    if data_parallel_rank == 0:
        return get_checkpoint_name(checkpoints_path, iteration, release)

    if release:
        directory = "release"
    else:
        directory = "iter_{:07d}".format(iteration)
    # Use both the tensor and pipeline MP rank.
    if mpu.get_pipeline_model_parallel_world_size() == 1:
        return os.path.join(
            checkpoints_path,
            directory,
            "mp_rank_{:02d}_dp_rank_{:04d}".format(
                mpu.get_tensor_model_parallel_rank(), data_parallel_rank
            ),
            "model_optim_rng.pt",
        )
    return os.path.join(
        checkpoints_path,
        directory,
        "mp_rank_{:02d}_{:03d}_dp_rank_{:04d}".format(
            mpu.get_tensor_model_parallel_rank(),
            mpu.get_pipeline_model_parallel_rank(),
            data_parallel_rank,
        ),
        "model_optim_rng.pt",
    )
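
# For illustration only: with iteration=5000, tensor model parallel rank 0 and
# data parallel rank 3, the path built above looks like
#     <checkpoints_path>/iter_0005000/mp_rank_00_dp_rank_0003/model_optim_rng.pt
# (a pipeline rank segment is inserted only when the pipeline MP world size > 1).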


def save_checkpoint(iteration, model, optimizer, lr_scheduler):
    """Save a model checkpoint with expert parallel """
    # TODO: update patch
    from megatron import get_args
    from megatron import mpu
    from megatron import print_rank_last

    expert_dp_comm = "none"

    if mpu.get_data_parallel_rank() == 0:
        # at dp rank 0, we still follow the native save_checkpoint by megatron
        from megatron.checkpointing import save_checkpoint as save_checkpoint_native

        save_checkpoint_native(iteration, model, optimizer, lr_scheduler)
        return

    args = get_args()

    # Data parallel rank 0 has already returned above; the remaining ranks only
    # write their expert parameters to disk.
    if isinstance(model, DistributedDataParallel):
        model = model.module

    print_rank_last(
        "saving checkpoint at iteration {:7d} to {}".format(iteration, args.save)
    )

    # Arguments, iteration, and model.
    state_dict = {}
    state_dict["model"] = model.state_dict_for_save_checkpoint(
        keep_vars=(mpu.get_data_parallel_rank() > 0)
    )

    def extract_expert_param(state_dict, expert_dp_comm="none"):
        state_dict_new = state_dict.__class__()
        for k, v in state_dict.items():
            # megatron uses both dict and OrderedDict in its state_dict
            if isinstance(v, (OrderedDict, dict)):
                v_new = extract_expert_param(v, expert_dp_comm)
                if len(v_new) > 0:
                    state_dict_new[k] = v_new
            elif hasattr(v, "dp_comm") and v.dp_comm == expert_dp_comm:
                state_dict_new[k] = v.detach()
        return state_dict_new

    state_dict["model"] = extract_expert_param(state_dict["model"], expert_dp_comm)

    # Optimizer stuff.
    if not args.no_save_optim:
        if optimizer is not None:
            state_dict["optimizer"] = optimizer.state_dict()
            param_global_idx = 0
            for param_group in optimizer.optimizer.param_groups:
                for param in param_group["params"]:
                    if not (
                        hasattr(param, "dp_comm") and param.dp_comm == expert_dp_comm
                    ):
                        # this parameter is not an expert parameter,
                        # so there is no need to save its state on the current rank,
                        # since it has already been saved by data parallel rank 0
                        if args.fp16:
                            # fp16 optimizer may have empty state due to overflow
                            state_dict["optimizer"]["optimizer"]["state"].pop(
                                param_global_idx, None
                            )
                        else:
                            state_dict["optimizer"]["state"].pop(param_global_idx)
                    param_global_idx += 1
            if args.fp16:
                state_dict["optimizer"]["optimizer"].pop("param_groups")
                # fp32_from_fp16_params in state_dict is not a copy
                # but a reference to optimizer.fp32_from_fp16_params,
                # changing it in state_dict will change
                # optimizer.fp32_from_fp16_params as well
                # thus we create an empty fp32_from_fp16_params in state_dict
                # and only insert expert parameters.
                fp32_from_fp16_params = state_dict["optimizer"]["fp32_from_fp16_params"]
                state_dict["optimizer"]["fp32_from_fp16_params"] = []
                for param_group in fp32_from_fp16_params:
                    param_group_copy = []
                    for param in param_group:
                        param_copy = (
                            param
                            if hasattr(param, "dp_comm")
                            and param.dp_comm == expert_dp_comm
                            else None
                        )
                        param_group_copy.append(param_copy)
                    state_dict["optimizer"]["fp32_from_fp16_params"].append(
                        param_group_copy
                    )
            else:
                state_dict["optimizer"].pop("param_groups")

    # Save.
    checkpoint_name = get_fmoe_checkpoint_name(args.save, iteration)
    from megatron.checkpointing import ensure_directory_exists
    from megatron.checkpointing import get_checkpoint_tracker_filename

    ensure_directory_exists(checkpoint_name)
    torch.save(state_dict, checkpoint_name)

    # Wait so everyone is done (necessary)
    torch.distributed.barrier()
    if torch.distributed.get_rank() == 0:
        print(
            "  successfully saved checkpoint at iteration {:7d} to {}".format(
                iteration, args.save
            ),
            flush=True,
        )
    # And update the latest iteration
    if torch.distributed.get_rank() == 0:
        tracker_filename = get_checkpoint_tracker_filename(args.save)
        with open(tracker_filename, "w") as f:
            f.write(str(iteration))
    # Wait so everyone is done (not necessary)
    torch.distributed.barrier()
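
# A possible integration (an assumption about the caller, not part of this file):
# the training script imports save_checkpoint and load_checkpoint from this module
# instead of from megatron.checkpointing, so that data parallel rank 0 keeps the
# native full checkpoint while the other ranks persist only their own expert
# parameters and optimizer states.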


def merge_state_dict(state_dict_rank0, state_dict_local, fp16):
    """merge two state dicts, one from data parallel rank 0,
    another only contains expert states"""
    from megatron import print_rank_last

    def merge_model(state_dict_rank0, state_dict_local):
        for k, v in state_dict_local.items():
            # megatron uses both dict and OrderedDict in its state_dict
            if isinstance(v, (OrderedDict, dict)):
                merge_model(state_dict_rank0[k], v)
            else:
                state_dict_rank0[k] = v

    merge_model(state_dict_rank0["model"], state_dict_local["model"])

    optimizer_rank0 = (
        state_dict_rank0["optimizer"]["optimizer"]
        if fp16
        else state_dict_rank0["optimizer"]
    )
    optimizer_local = (
        state_dict_local["optimizer"]["optimizer"]
        if fp16
        else state_dict_local["optimizer"]
    )

    for k, v in optimizer_local["state"].items():
        optimizer_rank0["state"][k] = v

    if fp16:
        for group_idx, param_group in enumerate(
            state_dict_local["optimizer"]["fp32_from_fp16_params"]
        ):
            for param_in_group_idx, param in enumerate(param_group):
                if param is not None:
                    state_dict_rank0["optimizer"]["fp32_from_fp16_params"][group_idx][
                        param_in_group_idx
                    ] = param

    return state_dict_rank0


def load_checkpoint(model, optimizer, lr_scheduler, load_arg="load"):
    """Load a model checkpoint and return the iteration."""

    from megatron import get_args
    from megatron import mpu
    from megatron import print_rank_last
    from megatron.checkpointing import get_checkpoint_tracker_filename
    from megatron.checkpointing import set_checkpoint_version
    from megatron.checkpointing import check_checkpoint_args
    from megatron.checkpointing import update_num_microbatches

    if mpu.get_data_parallel_rank() == 0:
        # at dp rank 0, we still follow the native load_checkpoint by megatron
        from megatron.checkpointing import load_checkpoint as load_checkpoint_native

        return load_checkpoint_native(model, optimizer, lr_scheduler, load_arg)

    args = get_args()
    load_dir = getattr(args, load_arg)

    if isinstance(model, DistributedDataParallel):
        model = model.module
    # Read the tracker file and set the iteration.
    tracker_filename = get_checkpoint_tracker_filename(load_dir)

    # If no tracker file, return iteration zero.
    if not os.path.isfile(tracker_filename):
        print_rank_last(
            "WARNING: could not find the metadata file {} ".format(tracker_filename)
        )
        print_rank_last(
            "    will not load any checkpoints and will start from " "random"
        )
        return 0

    # Otherwise, read the tracker file and either set the iteration or
    # mark it as a release checkpoint.
    iteration = 0
    release = False
    with open(tracker_filename, "r") as f:
        metastring = f.read().strip()
        try:
            iteration = int(metastring)
        except ValueError:
            release = metastring == "release"
            if not release:
                print_rank_last(
                    "ERROR: Invalid metadata file {}. Exiting".format(tracker_filename)
                )
                sys.exit()

    assert iteration > 0 or release, "error parsing metadata file {}".format(
        tracker_filename
    )

    # Checkpoint.
    checkpoint_name_rank0 = get_fmoe_checkpoint_name(load_dir, iteration, release, 0)
    checkpoint_name_local = get_fmoe_checkpoint_name(
        load_dir, iteration, release, mpu.get_data_parallel_rank()
    )
    print_rank_last(
        " loading checkpoint at rank 0 from {} and rank {} from {} at iteration {}, will merge them later".format(
            checkpoint_name_rank0,
            mpu.get_data_parallel_rank(),
            checkpoint_name_local,
            iteration,
        )
    )

    # Load the checkpoint.
    def load_state_dict(checkpoint_name):
        try:
            state_dict = torch.load(checkpoint_name, map_location="cpu")
        except ModuleNotFoundError:
            from megatron.fp16_deprecated import loss_scaler

            # For backward compatibility.
            print_rank_last(" > deserializing using the old code structure ...")
            sys.modules["fp16.loss_scaler"] = sys.modules[
                "megatron.fp16_deprecated.loss_scaler"
            ]
            sys.modules["megatron.fp16.loss_scaler"] = sys.modules[
                "megatron.fp16_deprecated.loss_scaler"
            ]
            state_dict = torch.load(checkpoint_name, map_location="cpu")
            sys.modules.pop("fp16.loss_scaler", None)
            sys.modules.pop("megatron.fp16.loss_scaler", None)
        except BaseException:
            print_rank_last("could not load the checkpoint")
            sys.exit()
        return state_dict

    state_dict_rank0 = load_state_dict(checkpoint_name_rank0)
    state_dict_local = load_state_dict(checkpoint_name_local)

    state_dict = merge_state_dict(state_dict_rank0, state_dict_local, args.fp16)

    # set checkpoint version
    set_checkpoint_version(state_dict.get("checkpoint_version", 0))

    # Set iteration.
    if args.finetune or release:
        iteration = 0
    else:
        try:
            iteration = state_dict["iteration"]
        except KeyError:
            try:  # Backward compatible with older checkpoints
                iteration = state_dict["total_iters"]
            except KeyError:
                print_rank_last(
                    "A metadata file exists but unable to load "
                    "iteration from checkpoint {}, exiting".format(
                        checkpoint_name_local
                    )
                )
                sys.exit()

    # Check arguments.
    assert args.consumed_train_samples == 0
    assert args.consumed_valid_samples == 0
    if "args" in state_dict:
        checkpoint_args = state_dict["args"]
        check_checkpoint_args(checkpoint_args)
        args.consumed_train_samples = getattr(
            checkpoint_args, "consumed_train_samples", 0
        )
        update_num_microbatches(consumed_samples=args.consumed_train_samples)
        args.consumed_valid_samples = getattr(
            checkpoint_args, "consumed_valid_samples", 0
        )
    else:
        print_rank_last("could not find arguments in the checkpoint ...")

    # Model.
    model.load_state_dict(state_dict["model"])

    # Optimizer.
    if not release and not args.finetune and not args.no_load_optim:
        try:
            if optimizer is not None:
                optimizer.load_state_dict(state_dict["optimizer"])
            if lr_scheduler is not None:
                lr_scheduler.load_state_dict(state_dict["lr_scheduler"])
        except KeyError:
            print_rank_last(
                "Unable to load optimizer from checkpoint {}. "
                "Specify --no-load-optim or --finetune to prevent "
                "attempting to load the optimizer state, "
                "exiting ...".format(checkpoint_name_local)
            )
            sys.exit()

    # rng states.
    if not release and not args.finetune and not args.no_load_rng:
        try:
            random.setstate(state_dict["random_rng_state"])
            np.random.set_state(state_dict["np_rng_state"])
            torch.set_rng_state(state_dict["torch_rng_state"])
            torch.cuda.set_rng_state(state_dict["cuda_rng_state"])
            mpu.get_cuda_rng_tracker().set_states(state_dict["rng_tracker_states"])
        except KeyError:
            print_rank_last(
                "Unable to load rng state from checkpoint {}. "
                "Specify --no-load-rng or --finetune to prevent "
                "attempting to load the rng state, "
                "exiting ...".format(checkpoint_name_local)
            )
            sys.exit()

    torch.distributed.barrier()
    print_rank_last(
        "  successfully loaded checkpoint (with expert parameters updated) from {} at iteration {}".format(
            args.load, iteration
        )
    )

    return iteration