"vscode:/vscode.git/clone" did not exist on "47e7eb8003440e247fb65a2c34d64ea97e42006a"
megatron.py 20.9 KB
Newer Older
Sengxian's avatar
Sengxian committed
1
r"""
Rick Ho's avatar
Rick Ho committed
2
3
4
The adaptor to seamlessly enable FastMoE in Megatron-LM v2.0 with at most two
lines of modification.
See `examples/megatron` for usage instructions.
Sengxian's avatar
Sengxian committed
5
"""
6
import os
7
import sys
Rick Ho's avatar
Rick Ho committed
8
import math
9
10
import random
from collections import OrderedDict
Jiezhong Qiu's avatar
Jiezhong Qiu committed
11
import numpy as np
Rick Ho's avatar
Rick Ho committed
12
import torch
Rick Ho's avatar
Rick Ho committed
13
import torch.nn as nn
Rick Ho's avatar
Rick Ho committed
14
import torch.nn.functional as F
Rick Ho's avatar
Rick Ho committed
15
16
17
18
19

from .transformer import FMoETransformerMLP
from .distributed import DistributedGroupedDataParallel


20
class _FakeMegatronMLP(nn.Module):
    r"""
    A fake MLP without model parallelism, for correctness testing
    """

    def __init__(self, args, _):
        super().__init__()
        self.fc1 = nn.Linear(args.hidden_size, args.hidden_hidden_size)
        self.fc2 = nn.Linear(args.hidden_hidden_size, args.hidden_size)

    def forward(self, x):
        r"""
        Directly use GeLU
        """
        x = self.fc1(x)
        x = F.gelu(x)
        x = self.fc2(x)
        return x, torch.zeros_like(x)


def _megatron_init_method(self, rng, sigma):
    r"""
    Init method based on N(0, sigma).
    Copied from Megatron-LM
    """
    device = self.weight.device
    dtype = self.weight.dtype
    weight = rng.normal(loc=0.0, scale=sigma, size=tuple(self.weight.size()))
    self.weight.data = torch.from_numpy(weight).to(dtype=dtype, device=device)

    if self.bias is not None:
        # Always initialize bias to zero.
        with torch.no_grad():
            self.bias.zero_()
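

# A minimal usage sketch of the helper above (illustrative only): it can be
# applied to any module that exposes `.weight` and `.bias`, e.g. an
# `nn.Linear`, together with a NumPy random generator.
#
#     layer = nn.Linear(1024, 4096)
#     rng = np.random.default_rng(1234)
#     _megatron_init_method(layer, rng, sigma=0.02)

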
def _random_init_weight(self, rng):
    r"""
    Copied from torch.nn.init.kaiming_uniform_
    """
    fan = nn.init._calculate_correct_fan(self.weight[0], "fan_in")
    gain = nn.init.calculate_gain("leaky_relu", math.sqrt(5))
    std = gain / math.sqrt(fan)
    bound = math.sqrt(3.0) * std
    device = self.weight.device
    dtype = self.weight.dtype
    weight = rng.uniform(-bound, bound, size=tuple(self.weight.size()))
    self.weight.data = torch.from_numpy(weight).to(dtype=dtype, device=device)

    if self.bias is not None:
        fan_in, _ = nn.init._calculate_fan_in_and_fan_out(self.weight[0])
        bound = 1 / math.sqrt(fan_in)
        bias = rng.uniform(-bound, bound, size=tuple(self.bias.size()))
        self.bias.data = torch.from_numpy(bias).to(dtype=dtype, device=device)


class MegatronMLP(FMoETransformerMLP):
    r"""
    An FMoETransformerMLP layer that distributes experts across the
    communication group `group`, replacing the original MLP layer in Megatron.
    """

    def __init__(self, args, group):
        assert (
            args.seq_length * args.micro_batch_size
            % args.tensor_model_parallel_size
            == 0
        ), "Batch size x sequence length should be a multiple of the model parallel size"
        if not args.distributed_experts:
            world_size = 1
        else:
            world_size = args.world_size
        super().__init__(
            args.num_experts,
            top_k=args.top_k,
            d_model=args.hidden_size,
            d_hidden=args.hidden_hidden_size,
            world_size=world_size,
            mp_group=group,
            expert_dp_comm="none" if args.distributed_experts else "dp",
        )
        self.hidden_size = args.hidden_size
        if args.distributed_experts:
            self.rank = args.rank
        else:
            self.rank = 0
        self.sigma = args.init_method_std
        self.num_layers = args.num_layers
        self.reset_parameters()

    def reset_parameters(self):
        r"""
        Initialize the weights as ordinary linear layers.
        As Megatron uses a fixed random seed for some nasty stuff, an
        additional numpy rng is used.
        """
        rng = np.random.default_rng(np.random.randint(2048) + self.rank)
        _megatron_init_method(self.experts.htoh4, rng, self.sigma)
        std = self.sigma / math.sqrt(2.0 * self.num_layers)
        _megatron_init_method(self.experts.h4toh, rng, std)

    def forward(self, inp):
        return (
            super().forward(inp),
            torch.zeros(self.hidden_size, dtype=inp.dtype, device=inp.device),
        )
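

# A construction sketch (illustrative): the attribute names below mirror what
# `__init__` reads from Megatron's argument namespace; the values are made up.
#
#     from types import SimpleNamespace
#     args = SimpleNamespace(
#         seq_length=1024, micro_batch_size=4, tensor_model_parallel_size=1,
#         distributed_experts=False, world_size=1, rank=0,
#         num_experts=4, top_k=2, hidden_size=1024, hidden_hidden_size=4096,
#         init_method_std=0.02, num_layers=24,
#     )
#     mlp = MegatronMLP(args, group=None)
#     # forward() returns an (output, bias) pair to match Megatron's ParallelMLP.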


def fmoefy(
    model,
    num_experts=None,
    distributed_experts=True,
    hidden_hidden_size=None,
    top_k=None,
):
    r"""
    Replace the MLP layers in a Megatron transformer-based model with MoE layers.
    * `model` should be a standard Megatron model that has
    `model.language_model.transformer.layers` as transformer layers, which is an
    array of transformer blocks that contain an `mlp` member.
    * `distributed_experts` is set to True if different experts are located in
    different workers. Otherwise, the experts on the workers are identical, and
    they are trained in data-parallel mode. This can be useful when testing on
    small models that do not require high training throughput or large parameter
    capacity.
    Note that pipeline parallelism is not supported yet. When distributed experts
    are enabled, their communicator should be Megatron's
    tensor_model_parallel_comm x data_parallel_comm, which is not created.
    """
    from megatron import get_args
    from megatron import mpu

    args = get_args()
    if num_experts is not None:
        args.num_experts = num_experts
    assert (
        "num_experts" in args
    ), "num_experts should be specified in arguments or the fmoefy function"

    if hidden_hidden_size is not None:
        args.hidden_hidden_size = hidden_hidden_size
    elif not hasattr(args, "hidden_hidden_size"):
        args.hidden_hidden_size = args.hidden_size * 4

    if top_k is not None:
        args.top_k = top_k
    elif not hasattr(args, "top_k"):
        args.top_k = 2

    # Set distributed_experts to None to use the default setting in args
    if distributed_experts is not None:
        args.distributed_experts = distributed_experts

    for l in model.language_model.transformer.layers:
        l.mlp = MegatronMLP(args, mpu.get_model_parallel_group())
    return model
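

# A minimal usage sketch (illustrative): in a Megatron-LM v2.0 pretraining
# script, enabling FastMoE amounts to wrapping the freshly built model, e.g.
#
#     from fmoe.megatron import fmoefy, DistributedDataParallel
#     model = model_provider()  # a standard Megatron model (placeholder name)
#     model = fmoefy(model, num_experts=4)
#     model = DistributedDataParallel(model)
#
# See `examples/megatron` for the exact patch against the official training loop.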


class DistributedDataParallel(DistributedGroupedDataParallel):
    r"""
    A wrapper that is used to replace the DDP module provided by Megatron, which
    is adapted to enable the sophisticated parallel and reduction strategies in
    FastMoE.
    """

    def __init__(self, module):
        from megatron import mpu

        super().__init__(
            module,
            mp_group=mpu.get_model_parallel_group(),
            dp_group=mpu.get_data_parallel_group(),
        )

    def state_dict(self, *args, **kwargs):
        r"""
        Keep consistency with Megatron
        """
        return self.module.state_dict(*args, **kwargs)

    def state_dict_for_save_checkpoint(self, *args, **kwargs):
        r"""
        Keep consistency with Megatron
        """
        return self.module.state_dict_for_save_checkpoint(*args, **kwargs)

    def load_state_dict(self, *args, **kwargs):
        r"""
        Keep consistency with Megatron
        """
        return self.module.load_state_dict(*args, **kwargs)


def get_fmoe_checkpoint_name(
    checkpoints_path, iteration, release=False, data_parallel_rank=-1
):
    """A unified checkpoint name that allows specifying a data parallel rank"""
    from megatron import mpu
    from megatron.checkpointing import get_checkpoint_name

    if data_parallel_rank == -1:
        data_parallel_rank = mpu.get_data_parallel_rank()
    if data_parallel_rank == 0:
        return get_checkpoint_name(checkpoints_path, iteration, release)

    if release:
        directory = "release"
    else:
        directory = "iter_{:07d}".format(iteration)
    # Use both the tensor and pipeline MP rank.
    if mpu.get_pipeline_model_parallel_world_size() == 1:
        return os.path.join(
            checkpoints_path,
            directory,
            "mp_rank_{:02d}_dp_rank_{:04d}".format(
                mpu.get_tensor_model_parallel_rank(), data_parallel_rank
            ),
            "model_optim_rng.pt",
        )
    return os.path.join(
        checkpoints_path,
        directory,
        "mp_rank_{:02d}_{:03d}_dp_rank_{:04d}".format(
            mpu.get_tensor_model_parallel_rank(),
            mpu.get_pipeline_model_parallel_rank(),
            data_parallel_rank,
        ),
        "model_optim_rng.pt",
    )
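
# Path layout sketch (derived from the format strings above): with pipeline
# parallelism disabled, tensor parallel rank 0 and data parallel rank 3 at
# iteration 1000 map to
#
#     <checkpoints_path>/iter_0001000/mp_rank_00_dp_rank_0003/model_optim_rng.pt
#
# while data parallel rank 0 keeps Megatron's native name from
# `get_checkpoint_name`.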


def save_checkpoint(iteration, model, optimizer, lr_scheduler):
    """Save a model checkpoint with expert parallelism."""
    # TODO: update patch
    from megatron import get_args
    from megatron import mpu
    from megatron import print_rank_last

    expert_dp_comm = "none"

    if mpu.get_data_parallel_rank() == 0:
        # at dp rank 0, we still follow the native save_checkpoint by megatron
        from megatron.checkpointing import save_checkpoint as save_checkpoint_native

        save_checkpoint_native(iteration, model, optimizer, lr_scheduler)
        return

    args = get_args()

    # Unwrap the DDP wrapper before collecting the state dict.
    if isinstance(model, DistributedDataParallel):
        model = model.module

    print_rank_last(
        "saving checkpoint at iteration {:7d} to {}".format(iteration, args.save)
    )

    # Arguments, iteration, and model.
    state_dict = {}
    state_dict["model"] = model.state_dict_for_save_checkpoint(
        keep_vars=(mpu.get_data_parallel_rank() > 0)
    )

    def extract_expert_param(state_dict, expert_dp_comm="none"):
        state_dict_new = state_dict.__class__()
        for k, v in state_dict.items():
            # megatron uses both dict and OrderedDict in its state_dict
            if isinstance(v, (OrderedDict, dict)):
                v_new = extract_expert_param(v, expert_dp_comm)
                if len(v_new) > 0:
                    state_dict_new[k] = v_new
            elif hasattr(v, "dp_comm") and v.dp_comm == expert_dp_comm:
                state_dict_new[k] = v.detach()
        return state_dict_new

    state_dict["model"] = extract_expert_param(state_dict["model"], expert_dp_comm)

    # Optimizer stuff.
    if not args.no_save_optim:
        if optimizer is not None:
            state_dict["optimizer"] = optimizer.state_dict()
            param_global_idx = 0
            for param_group in optimizer.optimizer.param_groups:
                for param in param_group["params"]:
                    if not (
                        hasattr(param, "dp_comm") and param.dp_comm == expert_dp_comm
                    ):
                        # this parameter is not an expert parameter
                        # thus there is no need to save its state in current rank
                        # since it has been saved by data parallel rank 0
                        if args.fp16:
                            # fp16 optimizer may have empty state due to overflow
                            state_dict["optimizer"]["optimizer"]["state"].pop(
                                param_global_idx, None
                            )
                        else:
                            state_dict["optimizer"]["state"].pop(param_global_idx)
                    param_global_idx += 1
            if args.fp16:
                state_dict["optimizer"]["optimizer"].pop("param_groups")
                # fp32_from_fp16_params in state_dict is not a copy
                # but a reference to optimizer.fp32_from_fp16_params,
                # changing it in state_dict will change
                # optimizer.fp32_from_fp16_params as well
                # thus we create an empty fp32_from_fp16_params in state_dict
                # and only insert expert parameters.
                fp32_from_fp16_params = state_dict["optimizer"]["fp32_from_fp16_params"]
                state_dict["optimizer"]["fp32_from_fp16_params"] = []
                for param_group in fp32_from_fp16_params:
                    param_group_copy = []
                    for param in param_group:
                        param_copy = (
                            param
                            if hasattr(param, "dp_comm")
                            and param.dp_comm == expert_dp_comm
                            else None
                        )
                        param_group_copy.append(param_copy)
                    state_dict["optimizer"]["fp32_from_fp16_params"].append(
                        param_group_copy
                    )
            else:
                state_dict["optimizer"].pop("param_groups")

    # Save.
    checkpoint_name = get_fmoe_checkpoint_name(args.save, iteration)
    from megatron.checkpointing import ensure_directory_exists
    from megatron.checkpointing import get_checkpoint_tracker_filename

    ensure_directory_exists(checkpoint_name)
    torch.save(state_dict, checkpoint_name)

    # Wait so everyone is done (necessary)
    torch.distributed.barrier()
    if torch.distributed.get_rank() == 0:
        print(
            "  successfully saved checkpoint at iteration {:7d} to {}".format(
                iteration, args.save
            ),
            flush=True,
        )
    # And update the latest iteration
    if torch.distributed.get_rank() == 0:
        tracker_filename = get_checkpoint_tracker_filename(args.save)
        with open(tracker_filename, "w") as f:
            f.write(str(iteration))
    # Wait so everyone is done (not necessary)
    torch.distributed.barrier()
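

# A usage sketch (illustrative): FastMoE's Megatron example patches the training
# loop so that Megatron's checkpoint helpers are replaced by the expert-aware
# versions in this file, roughly along the lines of
#
#     import megatron.training as training
#     from fmoe.megatron import save_checkpoint, load_checkpoint
#     training.save_checkpoint = save_checkpoint
#     training.load_checkpoint = load_checkpoint
#
# The exact patch point depends on the Megatron version; see `examples/megatron`
# for the maintained patch.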


def merge_state_dict(state_dict_rank0, state_dict_local, fp16):
    """Merge two state dicts: one from data parallel rank 0,
    the other containing only expert states."""
    from megatron import print_rank_last

    def merge_model(state_dict_rank0, state_dict_local):
        for k, v in state_dict_local.items():
            # megatron uses both dict and OrderedDict in its state_dict
            if isinstance(v, (OrderedDict, dict)):
                merge_model(state_dict_rank0[k], v)
            else:
                state_dict_rank0[k] = v

    merge_model(state_dict_rank0["model"], state_dict_local["model"])

    optimizer_rank0 = (
        state_dict_rank0["optimizer"]["optimizer"]
        if fp16
        else state_dict_rank0["optimizer"]
    )
    optimizer_local = (
        state_dict_local["optimizer"]["optimizer"]
        if fp16
        else state_dict_local["optimizer"]
    )

    for k, v in optimizer_local["state"].items():
        optimizer_rank0["state"][k] = v

    if fp16:
        for group_idx, param_group in enumerate(
            state_dict_local["optimizer"]["fp32_from_fp16_params"]
        ):
            for param_in_group_idx, param in enumerate(param_group):
                if param is not None:
                    state_dict_rank0["optimizer"]["fp32_from_fp16_params"][group_idx][
                        param_in_group_idx
                    ] = param

    return state_dict_rank0
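

# Merge semantics sketch (illustrative, with toy values instead of tensors):
# nested entries from the local, expert-only state dict overwrite the
# corresponding entries loaded by data parallel rank 0.
#
#     rank0 = {"model": {"mlp": {"w": "dense"}}, "optimizer": {"state": {}}}
#     local = {"model": {"mlp": {"w": "expert"}}, "optimizer": {"state": {}}}
#     merged = merge_state_dict(rank0, local, fp16=False)
#     assert merged["model"]["mlp"]["w"] == "expert"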


def load_checkpoint(model, optimizer, lr_scheduler, load_arg="load"):
    """Load a model checkpoint and return the iteration."""

    from megatron import get_args
    from megatron import mpu
    from megatron import print_rank_last
    from megatron.checkpointing import get_checkpoint_tracker_filename
    from megatron.checkpointing import set_checkpoint_version
    from megatron.checkpointing import check_checkpoint_args
    from megatron.checkpointing import update_num_microbatches

    if mpu.get_data_parallel_rank() == 0:
        # at dp rank 0, we still follow the native load_checkpoint by megatron
        from megatron.checkpointing import load_checkpoint as load_checkpoint_native

        return load_checkpoint_native(model, optimizer, lr_scheduler, load_arg)

    args = get_args()
    load_dir = getattr(args, load_arg)

    if isinstance(model, DistributedDataParallel):
        model = model.module
    # Read the tracker file and set the iteration.
    tracker_filename = get_checkpoint_tracker_filename(load_dir)

    # If no tracker file, return iteration zero.
    if not os.path.isfile(tracker_filename):
        print_rank_last(
            "WARNING: could not find the metadata file {} ".format(tracker_filename)
        )
        print_rank_last(
            "    will not load any checkpoints and will start from " "random"
        )
        return 0

    # Otherwise, read the tracker file and either set the iteration or
    # mark it as a release checkpoint.
    iteration = 0
    release = False
    with open(tracker_filename, "r") as f:
        metastring = f.read().strip()
        try:
            iteration = int(metastring)
        except ValueError:
            release = metastring == "release"
            if not release:
                print_rank_last(
                    "ERROR: Invalid metadata file {}. Exiting".format(tracker_filename)
                )
                sys.exit()

    assert iteration > 0 or release, "error parsing metadata file {}".format(
        tracker_filename
    )

    # Checkpoint.
    checkpoint_name_rank0 = get_fmoe_checkpoint_name(load_dir, iteration, release, 0)
    checkpoint_name_local = get_fmoe_checkpoint_name(
        load_dir, iteration, release, mpu.get_data_parallel_rank()
    )
    print_rank_last(
        " loading checkpoint at rank 0 from {} and rank {} from {} at iteration {}, will merge them later".format(
            checkpoint_name_rank0,
            mpu.get_data_parallel_rank(),
            checkpoint_name_local,
            iteration,
        )
    )

    # Load the checkpoint.
    def load_state_dict(checkpoint_name):
        try:
            state_dict = torch.load(checkpoint_name, map_location="cpu")
        except ModuleNotFoundError:
            from megatron.fp16_deprecated import loss_scaler

            # For backward compatibility.
            print_rank_last(" > deserializing using the old code structure ...")
            sys.modules["fp16.loss_scaler"] = sys.modules[
                "megatron.fp16_deprecated.loss_scaler"
            ]
            sys.modules["megatron.fp16.loss_scaler"] = sys.modules[
                "megatron.fp16_deprecated.loss_scaler"
            ]
            state_dict = torch.load(checkpoint_name, map_location="cpu")
            sys.modules.pop("fp16.loss_scaler", None)
            sys.modules.pop("megatron.fp16.loss_scaler", None)
        except BaseException:
            print_rank_last("could not load the checkpoint")
            sys.exit()
        return state_dict

    state_dict_rank0 = load_state_dict(checkpoint_name_rank0)
    state_dict_local = load_state_dict(checkpoint_name_local)

    state_dict = merge_state_dict(state_dict_rank0, state_dict_local, args.fp16)

    # set checkpoint version
    set_checkpoint_version(state_dict.get("checkpoint_version", 0))

    # Set iteration.
    if args.finetune or release:
        iteration = 0
    else:
        try:
            iteration = state_dict["iteration"]
        except KeyError:
            try:  # Backward compatible with older checkpoints
                iteration = state_dict["total_iters"]
            except KeyError:
                print_rank_last(
                    "A metadata file exists but unable to load "
                    "iteration from checkpoint {}, exiting".format(
                        checkpoint_name_local
                    )
                )
                sys.exit()

    # Check arguments.
    assert args.consumed_train_samples == 0
    assert args.consumed_valid_samples == 0
    if "args" in state_dict:
        checkpoint_args = state_dict["args"]
        check_checkpoint_args(checkpoint_args)
        args.consumed_train_samples = getattr(
            checkpoint_args, "consumed_train_samples", 0
        )
        update_num_microbatches(consumed_samples=args.consumed_train_samples)
        args.consumed_valid_samples = getattr(
            checkpoint_args, "consumed_valid_samples", 0
        )
    else:
        print_rank_last("could not find arguments in the checkpoint ...")

    # Model.
    model.load_state_dict(state_dict["model"])

    # Optimizer.
    if not release and not args.finetune and not args.no_load_optim:
        try:
            if optimizer is not None:
                optimizer.load_state_dict(state_dict["optimizer"])
            if lr_scheduler is not None:
                lr_scheduler.load_state_dict(state_dict["lr_scheduler"])
        except KeyError:
            print_rank_last(
                "Unable to load optimizer from checkpoint {}. "
                "Specify --no-load-optim or --finetune to prevent "
                "attempting to load the optimizer state, "
                "exiting ...".format(checkpoint_name_local)
            )
            sys.exit()

    # rng states.
    if not release and not args.finetune and not args.no_load_rng:
        try:
            random.setstate(state_dict["random_rng_state"])
            np.random.set_state(state_dict["np_rng_state"])
            torch.set_rng_state(state_dict["torch_rng_state"])
            torch.cuda.set_rng_state(state_dict["cuda_rng_state"])
            mpu.get_cuda_rng_tracker().set_states(state_dict["rng_tracker_states"])
        except KeyError:
            print_rank_last(
                "Unable to load rng state from checkpoint {}. "
                "Specify --no-load-rng or --finetune to prevent "
                "attempting to load the rng state, "
                "exiting ...".format(checkpoint_name_local)
            )
            sys.exit()

    torch.distributed.barrier()
    print_rank_last(
        "  successfully loaded checkpoint (with expert parameters updated) from {} at iteration {}".format(
            args.load, iteration
        )
    )

    return iteration