checkpoint.py 14.2 KB
Newer Older
Sengxian's avatar
Sengxian committed
1
r"""
Rick Ho's avatar
Rick Ho committed
2
3
Support for Megatron to enable saving parameters of different experts on
different ranks.
Sengxian's avatar
Sengxian committed
4
"""
5
import os
6
import sys
7
8
import random
from collections import OrderedDict
Rick Ho's avatar
Rick Ho committed
9
10
import numpy as np
import torch
11

Jiezhong Qiu's avatar
Jiezhong Qiu committed
12
13
14
15

def get_fmoe_checkpoint_name(
    checkpoints_path, iteration, release=False, data_parallel_rank=-1
):
    """Return the checkpoint file path for a given data parallel rank.

    Args:
        checkpoints_path: Root directory that holds the checkpoints.
        iteration: Iteration number; formatted as ``iter_{:07d}`` unless
            ``release`` is true.
        release: When true, the ``release`` directory is used instead of an
            iteration directory.
        data_parallel_rank: Rank to build the name for. ``-1`` (the default)
            means "ask ``mpu`` for the rank of the calling process".

    Returns:
        For data parallel rank 0, whatever
        ``megatron.checkpointing.get_checkpoint_name`` returns. For any other
        rank, a path ending in ``model_optim_rng.pt`` whose parent directory
        encodes the tensor (and, if pipeline parallelism is used, pipeline)
        model parallel rank plus the data parallel rank.
    """
    from megatron import mpu
    from megatron.checkpointing import get_checkpoint_name

    # Resolve the "current process" sentinel first.
    dp_rank = data_parallel_rank
    if dp_rank == -1:
        dp_rank = mpu.get_data_parallel_rank()

    # Rank 0 keeps the name chosen by Megatron itself.
    if dp_rank == 0:
        return get_checkpoint_name(checkpoints_path, iteration, release)

    directory = "release" if release else "iter_{:07d}".format(iteration)

    # The pipeline rank is only part of the name when there is more than one
    # pipeline stage.
    if mpu.get_pipeline_model_parallel_world_size() == 1:
        rank_dir = "mp_rank_{:02d}_dp_rank_{:04d}".format(
            mpu.get_tensor_model_parallel_rank(), dp_rank
        )
    else:
        rank_dir = "mp_rank_{:02d}_{:03d}_dp_rank_{:04d}".format(
            mpu.get_tensor_model_parallel_rank(),
            mpu.get_pipeline_model_parallel_rank(),
            dp_rank,
        )

    return os.path.join(checkpoints_path, directory, rank_dir, "model_optim_rng.pt")

50

Jiezhong Qiu's avatar
Jiezhong Qiu committed
51
def save_checkpoint(iteration, model, optimizer, lr_scheduler):
    """Save a model checkpoint with expert parallel.

    Data parallel rank 0 delegates to Megatron's native ``save_checkpoint``.
    Every other data parallel rank writes its own file (named by
    ``get_fmoe_checkpoint_name``) that contains only the entries whose
    ``dp_comm`` attribute equals ``"none"`` -- the expert parameters -- plus
    the optimizer state that belongs to them.

    Args:
        iteration: Current training iteration (an int; it is formatted with
            ``{:7d}`` and written to the tracker file).
        model: Model to save. If it exposes a ``module`` attribute, that inner
            module is used. It must provide ``state_dict_for_save_checkpoint``.
        optimizer: Optimizer wrapper exposing ``state_dict()`` and an inner
            ``optimizer`` with ``param_groups``; may be ``None``.
        lr_scheduler: Only forwarded to the native save on data parallel
            rank 0; it is not stored in the per-rank expert file.
    """
    # TODO: update patch
    from megatron import get_args
    from megatron import mpu
    from megatron import print_rank_last

    expert_dp_comm = "none"

    if mpu.get_data_parallel_rank() == 0:
        # at dp rank 0, we still follow the native save_checkpoint by megatron
        from megatron.checkpointing import save_checkpoint as save_checkpoint_native

        save_checkpoint_native(iteration, model, optimizer, lr_scheduler)
        return

    args = get_args()

    # From here on we are on a non-zero data parallel rank.
    # Unwrap the model if it is wrapped in a container exposing ``module``.
    if hasattr(model, 'module'):
        model = model.module

    print_rank_last(
        "saving checkpoint at iteration {:7d} to {}".format(iteration, args.save)
    )

    # Model.
    state_dict = {}
    # NOTE(review): keep_vars is presumably needed so the returned values still
    # carry the ``dp_comm`` attribute inspected below -- confirm.
    state_dict["model"] = model.state_dict_for_save_checkpoint(
        keep_vars=(mpu.get_data_parallel_rank() > 0)
    )

    def extract_expert_param(state_dict, expert_dp_comm="none"):
        """Recursively keep only entries whose ``dp_comm`` matches."""
        state_dict_new = state_dict.__class__()
        for k, v in state_dict.items():
            # megatron uses both dict and OrderedDict in its state_dict
            if isinstance(v, (OrderedDict, dict)):
                v_new = extract_expert_param(v, expert_dp_comm)
                # Drop sub-dicts that end up with no expert entries.
                if len(v_new) > 0:
                    state_dict_new[k] = v_new
            elif hasattr(v, "dp_comm") and v.dp_comm == expert_dp_comm:
                state_dict_new[k] = v.detach()
        return state_dict_new

    state_dict["model"] = extract_expert_param(state_dict["model"], expert_dp_comm)

    # Optimizer stuff.
    if not args.no_save_optim:
        if optimizer is not None:
            state_dict["optimizer"] = optimizer.state_dict()
            # With fp16 the inner optimizer's state dict is nested one level
            # deeper under the "optimizer" key.
            inner_state_dict = (
                state_dict["optimizer"]["optimizer"]
                if args.fp16
                else state_dict["optimizer"]
            )

            param_global_idx = 0
            for param_group in optimizer.optimizer.param_groups:
                for param in param_group["params"]:
                    if not (
                        hasattr(param, "dp_comm") and param.dp_comm == expert_dp_comm
                    ):
                        # this parameter is not an expert parameter
                        # thus there is no need to save its state in current rank
                        # since it has been saved by data parallel rank 0.
                        # The state entry may be absent (fp16 overflow, or no
                        # optimizer step taken yet), so never fail on a
                        # missing key.
                        inner_state_dict["state"].pop(param_global_idx, None)
                    param_global_idx += 1

            # Parameter groups are taken from the rank 0 checkpoint on load.
            inner_state_dict.pop("param_groups")

            if args.fp16:
                # fp32_from_fp16_params in state_dict is not a copy
                # but a reference to optimizer.fp32_from_fp16_params,
                # changing it in state_dict will change
                # optimizer.fp32_from_fp16_params as well
                # thus we create an empty fp32_from_fp16_params in state_dict
                # and only insert expert parameters.
                fp32_from_fp16_params = state_dict["optimizer"]["fp32_from_fp16_params"]
                state_dict["optimizer"]["fp32_from_fp16_params"] = []
                for param_group in fp32_from_fp16_params:
                    # Non-expert slots are kept as None so that positions
                    # stay aligned with the original groups.
                    param_group_copy = [
                        param
                        if hasattr(param, "dp_comm")
                        and param.dp_comm == expert_dp_comm
                        else None
                        for param in param_group
                    ]
                    state_dict["optimizer"]["fp32_from_fp16_params"].append(
                        param_group_copy
                    )

    # Save.
    checkpoint_name = get_fmoe_checkpoint_name(args.save, iteration)
    from megatron.checkpointing import ensure_directory_exists
    from megatron.checkpointing import get_checkpoint_tracker_filename

    ensure_directory_exists(checkpoint_name)
    torch.save(state_dict, checkpoint_name)

    # Wait so everyone is done (necessary)
    # NOTE(review): the two barriers below presumably pair with barriers
    # inside Megatron's native save_checkpoint running on dp rank 0 --
    # confirm before adding or removing one.
    torch.distributed.barrier()
    if torch.distributed.get_rank() == 0:
        print(
            "  successfully saved checkpoint at iteration {:7d} to {}".format(
                iteration, args.save
            ),
            flush=True,
        )
    # And update the latest iteration
    if torch.distributed.get_rank() == 0:
        tracker_filename = get_checkpoint_tracker_filename(args.save)
        with open(tracker_filename, "w") as f:
            f.write(str(iteration))
    # Wait so everyone is done (not necessary)
    torch.distributed.barrier()
168
169


170
def merge_state_dict(state_dict_rank0, state_dict_local, fp16):
    """Merge a rank-local expert state dict into the rank 0 state dict.

    ``state_dict_rank0`` is modified in place and returned.

    Args:
        state_dict_rank0: Full state dict loaded from data parallel rank 0.
            Must contain a ``"model"`` entry.
        state_dict_local: State dict that only contains expert states. Must
            contain a ``"model"`` entry; the ``"optimizer"`` entry is
            optional (it is absent when the checkpoint was written without
            optimizer state).
        fp16: Whether the optimizer state uses the fp16 layout, i.e. the inner
            optimizer state lives under ``["optimizer"]["optimizer"]`` and
            ``["optimizer"]["fp32_from_fp16_params"]`` holds per-group lists
            in which non-expert slots are ``None``.

    Returns:
        ``state_dict_rank0`` with every local entry written over it.

    Raises:
        KeyError: If the local model dict contains a nested dict under a key
            that is missing from the rank 0 model dict.
    """

    def merge_model(target, source):
        """Recursively copy leaf values of ``source`` into ``target``."""
        for k, v in source.items():
            # megatron uses both dict and OrderedDict in its state_dict
            if isinstance(v, (OrderedDict, dict)):
                merge_model(target[k], v)
            else:
                target[k] = v

    merge_model(state_dict_rank0["model"], state_dict_local["model"])

    # Checkpoints saved without optimizer state carry no "optimizer" entry;
    # in that case there is nothing more to merge.
    if "optimizer" not in state_dict_local or "optimizer" not in state_dict_rank0:
        return state_dict_rank0

    optimizer_rank0 = (
        state_dict_rank0["optimizer"]["optimizer"]
        if fp16
        else state_dict_rank0["optimizer"]
    )
    optimizer_local = (
        state_dict_local["optimizer"]["optimizer"]
        if fp16
        else state_dict_local["optimizer"]
    )

    # Local per-parameter states replace the rank 0 ones at the same index.
    optimizer_rank0["state"].update(optimizer_local["state"])

    if fp16:
        fp32_rank0 = state_dict_rank0["optimizer"]["fp32_from_fp16_params"]
        fp32_local = state_dict_local["optimizer"]["fp32_from_fp16_params"]
        for group_idx, param_group in enumerate(fp32_local):
            for param_in_group_idx, param in enumerate(param_group):
                # None marks a slot that holds no local (expert) parameter.
                if param is not None:
                    fp32_rank0[group_idx][param_in_group_idx] = param

    return state_dict_rank0

Jiezhong Qiu's avatar
Jiezhong Qiu committed
211
212

def load_checkpoint(model, optimizer, lr_scheduler, load_arg="load"):
    """Load a model checkpoint and return the iteration.

    Data parallel rank 0 delegates to Megatron's native ``load_checkpoint``.
    Every other data parallel rank loads both the rank 0 checkpoint and its
    own expert-only checkpoint, merges them with ``merge_state_dict`` and
    restores model, optimizer, LR scheduler and RNG states from the result.

    Args:
        model: Model to load into. If it exposes a ``module`` attribute, that
            inner module is used.
        optimizer: Optimizer to restore; may be ``None``.
        lr_scheduler: LR scheduler to restore; may be ``None``.
        load_arg: Name of the attribute on the Megatron args object that holds
            the checkpoint directory.

    Returns:
        The iteration to resume from. ``0`` when no tracker file exists, when
        fine-tuning, or when loading a release checkpoint.

    Raises:
        SystemExit: Via ``sys.exit()`` when the tracker file is invalid, or
            when the iteration, optimizer state or RNG state is missing from
            the checkpoint.
    """

    from megatron import get_args
    from megatron import mpu
    from megatron import print_rank_last
    from megatron.checkpointing import get_checkpoint_tracker_filename
    from megatron.checkpointing import set_checkpoint_version
    from megatron.checkpointing import check_checkpoint_args
    from megatron.checkpointing import update_num_microbatches

    if mpu.get_data_parallel_rank() == 0:
        # at dp rank 0, we still follow the native load_checkpoint by megatron
        from megatron.checkpointing import load_checkpoint as load_checkpoint_native

        return load_checkpoint_native(model, optimizer, lr_scheduler, load_arg)

    args = get_args()
    load_dir = getattr(args, load_arg)

    if hasattr(model, 'module'):
        model = model.module
    # Read the tracker file and set the iteration.
    tracker_filename = get_checkpoint_tracker_filename(load_dir)

    # If no tracker file, return iteration zero.
    if not os.path.isfile(tracker_filename):
        print_rank_last(
            "WARNING: could not find the metadata file {} ".format(tracker_filename)
        )
        print_rank_last(
            "    will not load any checkpoints and will start from " "random"
        )
        return 0

    # Otherwise, read the tracker file and either set the iteration or
    # mark it as a release checkpoint.
    iteration = 0
    release = False
    with open(tracker_filename, "r") as f:
        metastring = f.read().strip()
        try:
            iteration = int(metastring)
        except ValueError:
            release = metastring == "release"
            if not release:
                print_rank_last(
                    "ERROR: Invalid metadata file {}. Exiting".format(tracker_filename)
                )
                sys.exit()

    assert iteration > 0 or release, "error parsing metadata file {}".format(
        tracker_filename
    )

    # Checkpoint.
    checkpoint_name_rank0 = get_fmoe_checkpoint_name(load_dir, iteration, release, 0)
    checkpoint_name_local = get_fmoe_checkpoint_name(
        load_dir, iteration, release, mpu.get_data_parallel_rank()
    )
    print_rank_last(
        " loading checkpoint at rank 0 from {} and rank {} from {} at iteration {}, will merge them later".format(
            checkpoint_name_rank0,
            mpu.get_data_parallel_rank(),
            checkpoint_name_local,
            iteration,
        )
    )

    # Load the checkpoint.
    def load_state_dict(checkpoint_name):
        """Load one checkpoint file onto the CPU.

        NOTE: ``torch.load`` unpickles the file, so checkpoints must come from
        a trusted source.
        """
        try:
            state_dict = torch.load(checkpoint_name, map_location="cpu")
        except ModuleNotFoundError:
            from megatron.fp16_deprecated import loss_scaler

            # For backward compatibility: alias the old module paths while
            # unpickling, then remove the aliases again.
            print_rank_last(" > deserializing using the old code structure ...")
            sys.modules["fp16.loss_scaler"] = sys.modules[
                "megatron.fp16_deprecated.loss_scaler"
            ]
            sys.modules["megatron.fp16.loss_scaler"] = sys.modules[
                "megatron.fp16_deprecated.loss_scaler"
            ]
            state_dict = torch.load(checkpoint_name, map_location="cpu")
            sys.modules.pop("fp16.loss_scaler", None)
            sys.modules.pop("megatron.fp16.loss_scaler", None)
        return state_dict

    state_dict_rank0 = load_state_dict(checkpoint_name_rank0)
    state_dict_local = load_state_dict(checkpoint_name_local)

    state_dict = merge_state_dict(state_dict_rank0, state_dict_local, args.fp16)

    # set checkpoint version
    set_checkpoint_version(state_dict.get("checkpoint_version", 0))

    # Set iteration.
    if args.finetune or release:
        iteration = 0
    else:
        try:
            iteration = state_dict["iteration"]
        except KeyError:
            try:  # Backward compatible with older checkpoints
                iteration = state_dict["total_iters"]
            except KeyError:
                print_rank_last(
                    "A metadata file exists but unable to load "
                    "iteration from checkpoint {}, exiting".format(
                        checkpoint_name_local
                    )
                )
                sys.exit()

    # Check arguments.
    assert args.consumed_train_samples == 0
    assert args.consumed_valid_samples == 0
    if "args" in state_dict:
        checkpoint_args = state_dict["args"]
        check_checkpoint_args(checkpoint_args)
        args.consumed_train_samples = getattr(
            checkpoint_args, "consumed_train_samples", 0
        )
        update_num_microbatches(consumed_samples=args.consumed_train_samples)
        args.consumed_valid_samples = getattr(
            checkpoint_args, "consumed_valid_samples", 0
        )
    else:
        print_rank_last("could not find arguments in the checkpoint ...")

    # Model.
    model.load_state_dict(state_dict["model"])

    # Optimizer.
    if not release and not args.finetune and not args.no_load_optim:
        try:
            if optimizer is not None:
                optimizer.load_state_dict(state_dict["optimizer"])
            if lr_scheduler is not None:
                lr_scheduler.load_state_dict(state_dict["lr_scheduler"])
        except KeyError:
            print_rank_last(
                "Unable to load optimizer from checkpoint {}. "
                "Specify --no-load-optim or --finetune to prevent "
                "attempting to load the optimizer state, "
                "exiting ...".format(checkpoint_name_local)
            )
            sys.exit()

    # rng states.
    if not release and not args.finetune and not args.no_load_rng:
        try:
            random.setstate(state_dict["random_rng_state"])
            np.random.set_state(state_dict["np_rng_state"])
            torch.set_rng_state(state_dict["torch_rng_state"])
            torch.cuda.set_rng_state(state_dict["cuda_rng_state"])
            mpu.get_cuda_rng_tracker().set_states(state_dict["rng_tracker_states"])
        except KeyError:
            # This failure concerns the RNG state, not the optimizer.
            print_rank_last(
                "Unable to load rng state from checkpoint {}. "
                "Specify --no-load-rng or --finetune to prevent "
                "attempting to load the rng state, "
                "exiting ...".format(checkpoint_name_local)
            )
            sys.exit()

    torch.distributed.barrier()
    # Report the directory actually used, which differs from args.load when
    # load_arg names another attribute.
    print_rank_last(
        "  successfully loaded checkpoint (with expert parametes updated) from {} at iteration {}".format(
            load_dir, iteration
        )
    )

    return iteration