r"""
Support for Megatron to enable saving parameters of different experts on
different ranks.
"""
import os
import sys
import random
from collections import OrderedDict
import numpy as np
import torch


def get_fmoe_checkpoint_name(
    checkpoints_path, iteration, release=False, data_parallel_rank=-1
):
    """A unified checkpoint name, allowing specifying a data parallel rank"""
    from megatron import mpu
    from megatron.checkpointing import get_checkpoint_name

    if data_parallel_rank == -1:
        data_parallel_rank = mpu.get_data_parallel_rank()
    if data_parallel_rank == 0:
        return get_checkpoint_name(checkpoints_path, iteration, release)

    if release:
        directory = "release"
    else:
        directory = "iter_{:07d}".format(iteration)
    # Use both the tensor and pipeline MP rank.
    if mpu.get_pipeline_model_parallel_world_size() == 1:
        return os.path.join(
            checkpoints_path,
            directory,
            "mp_rank_{:02d}_dp_rank_{:04d}".format(
                mpu.get_tensor_model_parallel_rank(), data_parallel_rank
            ),
            "model_optim_rng.pt",
        )
    return os.path.join(
        checkpoints_path,
        directory,
        "mp_rank_{:02d}_{:03d}_dp_rank_{:04d}".format(
            mpu.get_tensor_model_parallel_rank(),
            mpu.get_pipeline_model_parallel_rank(),
            data_parallel_rank,
        ),
        "model_optim_rng.pt",
    )


def save_checkpoint(iteration, model, optimizer, lr_scheduler):
    """Save a model checkpoint with expert parallel """
    # TODO: update patch
    from megatron import get_args
    from megatron import mpu
    from megatron import print_rank_last
    from megatron import utils

    expert_dp_comm = "none"
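    # Expert parameters are tagged with dp_comm == "none"; each data parallel
    # rank holds its own experts, so every rank has to save them itself, while
    # non-expert parameters are saved by data parallel rank 0 only.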

    if mpu.get_data_parallel_rank() == 0:
        # at dp rank 0, we still follow the native save_checkpoint by megatron
        from megatron.checkpointing import save_checkpoint as save_checkpoint_native

        save_checkpoint_native(iteration, model, optimizer, lr_scheduler)
        return

    args = get_args()

    # Unwrap the model (DDP / fp16 wrappers) before collecting its expert parameters.
    try:
        model = utils.unwrap_model(model)
    except AttributeError:
        # fallback to the old way of unwrapping a model
        if hasattr(model, 'module'):
            model = model.module
        model = [model,]

    print_rank_last(
        "saving checkpoint at iteration {:7d} to {}".format(iteration, args.save)
    )

    # Arguments, iteration, and model.
    state_dict = {}
    assert len(model) == 1, "FMoE does not support interleaved pipelining, i.e., only supports len(model) == 1 for now."
    state_dict["model"] = model[0].state_dict_for_save_checkpoint(
        keep_vars=(mpu.get_data_parallel_rank() > 0)
    )
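    # keep_vars=True preserves the Parameter objects (rather than detached
    # tensors) in the state dict, so the dp_comm attribute remains visible to
    # extract_expert_param below.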

    def extract_expert_param(state_dict, expert_dp_comm="none"):
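        """Recursively build a state dict that keeps only the entries whose
        dp_comm attribute equals expert_dp_comm, i.e. the expert parameters."""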
        state_dict_new = state_dict.__class__()
        for k, v in state_dict.items():
            # megatron uses both dict and OrderedDict in its state_dict
            if isinstance(v, (OrderedDict, dict)):
                v_new = extract_expert_param(v, expert_dp_comm)
                if len(v_new) > 0:
                    state_dict_new[k] = v_new
            elif hasattr(v, "dp_comm") and v.dp_comm == expert_dp_comm:
                state_dict_new[k] = v.detach()
        return state_dict_new

    state_dict["model"] = extract_expert_param(state_dict["model"], expert_dp_comm)

    # Optimizer stuff.
    if not args.no_save_optim:
        if optimizer is not None:
            state_dict["optimizer"] = optimizer.state_dict()
            param_global_idx = 0
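            # torch.optim optimizers key the "state" section of their state_dict
            # by a flat parameter index across all param groups; param_global_idx
            # tracks that index.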
            for param_group in optimizer.optimizer.param_groups:
                for param in param_group["params"]:
                    if not (
                        hasattr(param, "dp_comm") and param.dp_comm == expert_dp_comm
                    ):
                        # this parameter is not an expert parameter;
                        # there is no need to save its state on the current rank,
                        # since it has already been saved by data parallel rank 0
                        if args.fp16:
                            # fp16 optimizer may have empty state due to overflow
                            state_dict["optimizer"]["optimizer"]["state"].pop(
                                param_global_idx, None
                            )
                        else:
                            state_dict["optimizer"]["state"].pop(param_global_idx)
                    param_global_idx += 1
            if args.fp16:
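                # param_groups is already saved by data parallel rank 0, so it
                # is dropped from this expert-only checkpoint.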
                state_dict["optimizer"]["optimizer"].pop("param_groups")
                # fp32_from_fp16_params in state_dict is not a copy but a
                # reference to optimizer.fp32_from_fp16_params; changing it in
                # state_dict would change optimizer.fp32_from_fp16_params as
                # well. Thus we create an empty fp32_from_fp16_params in
                # state_dict and only insert the expert parameters.
                fp32_from_fp16_params = state_dict["optimizer"]["fp32_from_fp16_params"]
                state_dict["optimizer"]["fp32_from_fp16_params"] = []
                for param_group in fp32_from_fp16_params:
                    param_group_copy = []
                    for param in param_group:
                        param_copy = (
                            param
                            if hasattr(param, "dp_comm")
                            and param.dp_comm == expert_dp_comm
                            else None
                        )
                        param_group_copy.append(param_copy)
                    state_dict["optimizer"]["fp32_from_fp16_params"].append(
                        param_group_copy
                    )
            else:
                state_dict["optimizer"].pop("param_groups")

    # Save.
    checkpoint_name = get_fmoe_checkpoint_name(args.save, iteration)
    from megatron.checkpointing import ensure_directory_exists
    from megatron.checkpointing import get_checkpoint_tracker_filename

    ensure_directory_exists(checkpoint_name)
    torch.save(state_dict, checkpoint_name)

    # Wait so everyone is done (necessary)
    torch.distributed.barrier()
    if torch.distributed.get_rank() == 0:
        print(
            "  successfully saved checkpoint at iteration {:7d} to {}".format(
                iteration, args.save
            ),
            flush=True,
        )
    # And update the latest iteration
    if torch.distributed.get_rank() == 0:
        tracker_filename = get_checkpoint_tracker_filename(args.save)
        with open(tracker_filename, "w") as f:
            f.write(str(iteration))
    # Wait so everyone is done (not necessary)
    torch.distributed.barrier()


def merge_state_dict(state_dict_rank0, state_dict_local, fp16):
    """merge two state dicts, one from data parallel rank 0,
    another only contains expert states"""
    # from megatron import print_rank_last

    def merge_model(state_dict_rank0, state_dict_local):
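        """Recursively overwrite entries of state_dict_rank0 with the
        (expert-only) entries present in state_dict_local."""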
        for k, v in state_dict_local.items():
            # megatron uses both dict and OrderedDict in its state_dict
            if isinstance(v, (OrderedDict, dict)):
                merge_model(state_dict_rank0[k], v)
            else:
                state_dict_rank0[k] = v

    merge_model(state_dict_rank0["model"], state_dict_local["model"])

    optimizer_rank0 = (
        state_dict_rank0["optimizer"]["optimizer"]
        if fp16
        else state_dict_rank0["optimizer"]
    )
    optimizer_local = (
        state_dict_local["optimizer"]["optimizer"]
        if fp16
        else state_dict_local["optimizer"]
    )

    for k, v in optimizer_local["state"].items():
        optimizer_rank0["state"][k] = v

    if fp16:
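        # In the expert-only checkpoint, fp32_from_fp16_params holds None at
        # every non-expert position (see save_checkpoint), so those slots are
        # skipped here.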
        for group_idx, param_group in enumerate(
            state_dict_local["optimizer"]["fp32_from_fp16_params"]
        ):
            for param_in_group_idx, param in enumerate(param_group):
                if param is not None:
                    state_dict_rank0["optimizer"]["fp32_from_fp16_params"][group_idx][
                        param_in_group_idx
                    ] = param

    return state_dict_rank0


def load_checkpoint(model, optimizer, lr_scheduler, load_arg="load"):
    """Load a model checkpoint and return the iteration."""

    from megatron import get_args
    from megatron import mpu
    from megatron import print_rank_last
    from megatron import utils
    from megatron.checkpointing import get_checkpoint_tracker_filename
    from megatron.checkpointing import set_checkpoint_version
    from megatron.checkpointing import check_checkpoint_args
    from megatron.checkpointing import update_num_microbatches

    if mpu.get_data_parallel_rank() == 0:
        # at dp rank 0, we still follow the native load_checkpoint by megatron
        from megatron.checkpointing import load_checkpoint as load_checkpoint_native

        return load_checkpoint_native(model, optimizer, lr_scheduler, load_arg)

    args = get_args()
    load_dir = getattr(args, load_arg)

    # Unwrap the model (DDP / fp16 wrappers) before loading parameters into it.
    try:
        model = utils.unwrap_model(model)
    except AttributeError:
        # fallback to the old way of unwrapping a model
        if hasattr(model, 'module'):
            model = model.module
        model = [model,]

    # Read the tracker file and set the iteration.
    tracker_filename = get_checkpoint_tracker_filename(load_dir)
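    # The tracker file contains either an integer iteration (e.g. "1000") or
    # the literal string "release"; both cases are parsed below.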

    # If no tracker file, return iteration zero.
    if not os.path.isfile(tracker_filename):
        print_rank_last(
            "WARNING: could not find the metadata file {} ".format(tracker_filename)
        )
        print_rank_last(
            "    will not load any checkpoints and will start from " "random"
        )
        return 0

    # Otherwise, read the tracker file and either set the iteration or
    # mark it as a release checkpoint.
    iteration = 0
    release = False
    with open(tracker_filename, "r") as f:
        metastring = f.read().strip()
        try:
            iteration = int(metastring)
        except ValueError:
            release = metastring == "release"
            if not release:
                print_rank_last(
                    "ERROR: Invalid metadata file {}. Exiting".format(tracker_filename)
                )
                sys.exit()

    assert iteration > 0 or release, "error parsing metadata file {}".format(
        tracker_filename
    )

    # Checkpoint.
    checkpoint_name_rank0 = get_fmoe_checkpoint_name(load_dir, iteration, release, 0)
    checkpoint_name_local = get_fmoe_checkpoint_name(
        load_dir, iteration, release, mpu.get_data_parallel_rank()
    )
    print_rank_last(
        " loading checkpoint at rank 0 from {} and rank {} from {} at iteration {}, will merge them later".format(
            checkpoint_name_rank0,
            mpu.get_data_parallel_rank(),
            checkpoint_name_local,
            iteration,
        )
    )

    # Load the checkpoint.
    def load_state_dict(checkpoint_name):
        try:
            state_dict = torch.load(checkpoint_name, map_location="cpu")
        except ModuleNotFoundError:
            from megatron.fp16_deprecated import loss_scaler

            # For backward compatibility.
            print_rank_last(" > deserializing using the old code structure ...")
            sys.modules["fp16.loss_scaler"] = sys.modules[
                "megatron.fp16_deprecated.loss_scaler"
            ]
            sys.modules["megatron.fp16.loss_scaler"] = sys.modules[
                "megatron.fp16_deprecated.loss_scaler"
            ]
            state_dict = torch.load(checkpoint_name, map_location="cpu")
            sys.modules.pop("fp16.loss_scaler", None)
            sys.modules.pop("megatron.fp16.loss_scaler", None)
        return state_dict

    state_dict_rank0 = load_state_dict(checkpoint_name_rank0)
    state_dict_local = load_state_dict(checkpoint_name_local)

    state_dict = merge_state_dict(state_dict_rank0, state_dict_local, args.fp16)

    # Set the checkpoint version.
    set_checkpoint_version(state_dict.get("checkpoint_version", 0))

    # Set iteration.
    if args.finetune or release:
        iteration = 0
    else:
        try:
            iteration = state_dict["iteration"]
        except KeyError:
            try:  # Backward compatible with older checkpoints
                iteration = state_dict["total_iters"]
            except KeyError:
                print_rank_last(
                    "A metadata file exists but unable to load "
                    "iteration from checkpoint {}, exiting".format(
                        checkpoint_name_local
                    )
                )
                sys.exit()

    # Check arguments.
    assert args.consumed_train_samples == 0
    assert args.consumed_valid_samples == 0
    if "args" in state_dict:
        checkpoint_args = state_dict["args"]
        check_checkpoint_args(checkpoint_args)
        args.consumed_train_samples = getattr(
            checkpoint_args, "consumed_train_samples", 0
        )
        update_num_microbatches(consumed_samples=args.consumed_train_samples)
        args.consumed_valid_samples = getattr(
            checkpoint_args, "consumed_valid_samples", 0
        )
    else:
        print_rank_last("could not find arguments in the checkpoint ...")

    # Model.
    assert len(model) == 1, "FMoE does not support interleaved pipelining, i.e., only supports len(model) == 1 for now."
    model[0].load_state_dict(state_dict["model"])

    # Optimizer.
    if not release and not args.finetune and not args.no_load_optim:
        try:
            if optimizer is not None:
                optimizer.load_state_dict(state_dict["optimizer"])
            if lr_scheduler is not None:
                lr_scheduler.load_state_dict(state_dict["lr_scheduler"])
        except KeyError as e:
            print_rank_last(
                "FMoE is unable to load optimizer from checkpoint {}. "
                "Specify --no-load-optim or --finetune to prevent "
                "attempting to load the optimizer state, "
                "exiting ...".format(checkpoint_name_local)
            )
            sys.exit()

    # rng states.
    if not release and not args.finetune and not args.no_load_rng:
        try:
            random.setstate(state_dict["random_rng_state"])
            np.random.set_state(state_dict["np_rng_state"])
            torch.set_rng_state(state_dict["torch_rng_state"])
            torch.cuda.set_rng_state(state_dict["cuda_rng_state"])
            mpu.get_cuda_rng_tracker().set_states(state_dict["rng_tracker_states"])
        except KeyError as e:
            print_rank_last(e)
            print_rank_last(
                "FMoE is unable to load rng state from checkpoint {}. "
                "Specify --no-load-rng or --finetune to prevent "
                "attempting to load the optimizer state, "
                "exiting ...".format(checkpoint_name_local)
            )
            sys.exit()

    torch.distributed.barrier()
    print_rank_last(
        "  successfully loaded checkpoint (with expert parametes updated) from {} at iteration {}".format(
            args.load, iteration
        )
    )

    return iteration