r"""
Support for Megatron to enable saving parameters of different experts on
different ranks.
"""
import os
import sys
import random
from collections import OrderedDict
import numpy as np
import torch
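
# Illustrative usage sketch (the import path and the surrounding training-loop
# names are assumptions): these helpers mirror the signatures of
# megatron.checkpointing.save_checkpoint / load_checkpoint and are meant to be
# called in their place when expert parameters are not replicated across
# data-parallel ranks, e.g.
#
#     from fmoe.megatron.checkpoint import load_checkpoint, save_checkpoint
#     iteration = load_checkpoint(model, optimizer, lr_scheduler)
#     ...  # training loop
#     save_checkpoint(iteration, model, optimizer, lr_scheduler)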


def get_fmoe_checkpoint_name(
    checkpoints_path, iteration, release=False, data_parallel_rank=-1
):
    """A unified checkpoint name, allowing specifying a data parallel rank"""
    from megatron import mpu
    from megatron.checkpointing import get_checkpoint_name

    if data_parallel_rank == -1:
        data_parallel_rank = mpu.get_data_parallel_rank()
    if data_parallel_rank == 0:
        return get_checkpoint_name(checkpoints_path, iteration, release)

    if release:
        directory = "release"
    else:
        directory = "iter_{:07d}".format(iteration)
    # Use both the tensor and pipeline MP rank.
    if mpu.get_pipeline_model_parallel_world_size() == 1:
        return os.path.join(
            checkpoints_path,
            directory,
            "mp_rank_{:02d}_dp_rank_{:04d}".format(
                mpu.get_tensor_model_parallel_rank(), data_parallel_rank
            ),
            "model_optim_rng.pt",
        )
    return os.path.join(
        checkpoints_path,
        directory,
        "mp_rank_{:02d}_{:03d}_dp_rank_{:04d}".format(
            mpu.get_tensor_model_parallel_rank(),
            mpu.get_pipeline_model_parallel_rank(),
            data_parallel_rank,
        ),
        "model_optim_rng.pt",
    )
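
# Example with assumed values: on tensor-model-parallel rank 1 with a pipeline
# world size of 1, get_fmoe_checkpoint_name(ckpt_dir, 5000, data_parallel_rank=3)
# resolves to
#   ckpt_dir/iter_0005000/mp_rank_01_dp_rank_0003/model_optim_rng.pt
# while data-parallel rank 0 falls back to Megatron's native checkpoint layout
# (no "_dp_rank_XXXX" suffix).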


def save_checkpoint(iteration, model, optimizer, lr_scheduler):
    """Save a model checkpoint with expert parallel """
    # TODO: update patch
    from megatron import get_args
    from megatron import mpu
    from megatron import print_rank_last

    expert_dp_comm = "none"
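    # Parameters whose .dp_comm attribute equals "none" are treated as expert
    # parameters below: they are not synchronized across the data-parallel
    # group, so every data-parallel rank saves its own copy of them.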

    if mpu.get_data_parallel_rank() == 0:
        # at dp rank 0, we still follow Megatron's native save_checkpoint
        from megatron.checkpointing import save_checkpoint as save_checkpoint_native

        save_checkpoint_native(iteration, model, optimizer, lr_scheduler)
        return

    args = get_args()

    # Unwrap the data parallel wrapper, if any, to reach the bare model.
    if hasattr(model, 'module'):
        model = model.module

    print_rank_last(
        "saving checkpoint at iteration {:7d} to {}".format(iteration, args.save)
    )

    # Model state; arguments and iteration are saved by data-parallel rank 0 only.
    state_dict = {}
    state_dict["model"] = model.state_dict_for_save_checkpoint(
        keep_vars=(mpu.get_data_parallel_rank() > 0)
    )
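    # On dp ranks > 0, keep_vars evaluates to True, which keeps the live
    # Parameter objects (rather than detached tensors) in the state dict, so
    # the per-parameter .dp_comm tags are still visible to extract_expert_param
    # below.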

    def extract_expert_param(state_dict, expert_dp_comm="none"):
        state_dict_new = state_dict.__class__()
        for k, v in state_dict.items():
            # megatron uses both dict and OrderedDict in its state_dict
            if isinstance(v, (OrderedDict, dict)):
                v_new = extract_expert_param(v, expert_dp_comm)
                if len(v_new) > 0:
                    state_dict_new[k] = v_new
            elif hasattr(v, "dp_comm") and v.dp_comm == expert_dp_comm:
                state_dict_new[k] = v.detach()
        return state_dict_new

    state_dict["model"] = extract_expert_param(state_dict["model"], expert_dp_comm)

    # Optimizer stuff.
    if not args.no_save_optim:
        if optimizer is not None:
            state_dict["optimizer"] = optimizer.state_dict()
            param_global_idx = 0
            for param_group in optimizer.optimizer.param_groups:
                for param in param_group["params"]:
                    if not (
                        hasattr(param, "dp_comm") and param.dp_comm == expert_dp_comm
                    ):
                        # this parameter is not an expert parameter,
                        # so there is no need to save its state on the current
                        # rank: it has already been saved by data-parallel rank 0
                        if args.fp16:
                            # fp16 optimizer may have empty state due to overflow
                            state_dict["optimizer"]["optimizer"]["state"].pop(
                                param_global_idx, None
                            )
                        else:
                            state_dict["optimizer"]["state"].pop(param_global_idx)
                    param_global_idx += 1
            if args.fp16:
                state_dict["optimizer"]["optimizer"].pop("param_groups")
                # fp32_from_fp16_params in state_dict is not a copy but a
                # reference to optimizer.fp32_from_fp16_params; changing it in
                # state_dict would change optimizer.fp32_from_fp16_params as
                # well. We therefore build a fresh fp32_from_fp16_params list
                # in state_dict and only insert the expert parameters, keeping
                # None placeholders for all other parameters.
                fp32_from_fp16_params = state_dict["optimizer"]["fp32_from_fp16_params"]
                state_dict["optimizer"]["fp32_from_fp16_params"] = []
                for param_group in fp32_from_fp16_params:
                    param_group_copy = []
                    for param in param_group:
                        param_copy = (
                            param
                            if hasattr(param, "dp_comm")
                            and param.dp_comm == expert_dp_comm
                            else None
                        )
                        param_group_copy.append(param_copy)
                    state_dict["optimizer"]["fp32_from_fp16_params"].append(
                        param_group_copy
                    )
            else:
                state_dict["optimizer"].pop("param_groups")

    # Save.
    checkpoint_name = get_fmoe_checkpoint_name(args.save, iteration)
    from megatron.checkpointing import ensure_directory_exists
    from megatron.checkpointing import get_checkpoint_tracker_filename

    ensure_directory_exists(checkpoint_name)
    torch.save(state_dict, checkpoint_name)

    # Wait so everyone is done (necessary)
    torch.distributed.barrier()
    if torch.distributed.get_rank() == 0:
        print(
            "  successfully saved checkpoint at iteration {:7d} to {}".format(
                iteration, args.save
            ),
            flush=True,
        )
    # And update the latest iteration
    if torch.distributed.get_rank() == 0:
        tracker_filename = get_checkpoint_tracker_filename(args.save)
        with open(tracker_filename, "w") as f:
            f.write(str(iteration))
    # Wait so everyone is done (not necessary)
    torch.distributed.barrier()


def merge_state_dict(state_dict_rank0, state_dict_local, fp16):
    """merge two state dicts, one from data parallel rank 0,
    another only contains expert states"""
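    # state_dict_rank0 carries everything data-parallel rank 0 saved (the
    # shared, replicated parameters and optimizer state); state_dict_local only
    # carries the expert entries of the current rank, which overwrite the
    # matching slots.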
    # from megatron import print_rank_last

    def merge_model(state_dict_rank0, state_dict_local):
        for k, v in state_dict_local.items():
            # megatron uses both dict and OrderedDict in its state_dict
            if isinstance(v, (OrderedDict, dict)):
                merge_model(state_dict_rank0[k], v)
            else:
                state_dict_rank0[k] = v

    merge_model(state_dict_rank0["model"], state_dict_local["model"])

    optimizer_rank0 = (
        state_dict_rank0["optimizer"]["optimizer"]
        if fp16
        else state_dict_rank0["optimizer"]
    )
    optimizer_local = (
        state_dict_local["optimizer"]["optimizer"]
        if fp16
        else state_dict_local["optimizer"]
    )

    for k, v in optimizer_local["state"].items():
        optimizer_rank0["state"][k] = v

    if fp16:
        for group_idx, param_group in enumerate(
            state_dict_local["optimizer"]["fp32_from_fp16_params"]
        ):
            for param_in_group_idx, param in enumerate(param_group):
                if param is not None:
                    state_dict_rank0["optimizer"]["fp32_from_fp16_params"][group_idx][
                        param_in_group_idx
                    ] = param

    return state_dict_rank0


def load_checkpoint(model, optimizer, lr_scheduler, load_arg="load"):
    """Load a model checkpoint and return the iteration."""

    from megatron import get_args
    from megatron import mpu
    from megatron import print_rank_last
    from megatron.checkpointing import get_checkpoint_tracker_filename
    from megatron.checkpointing import set_checkpoint_version
    from megatron.checkpointing import check_checkpoint_args
    from megatron.checkpointing import update_num_microbatches

    if mpu.get_data_parallel_rank() == 0:
        # at dp rank 0, we still follow Megatron's native load_checkpoint
        from megatron.checkpointing import load_checkpoint as load_checkpoint_native

        return load_checkpoint_native(model, optimizer, lr_scheduler, load_arg)

    args = get_args()
    load_dir = getattr(args, load_arg)

    if hasattr(model, 'module'):
        model = model.module
    # Read the tracker file and set the iteration.
    tracker_filename = get_checkpoint_tracker_filename(load_dir)

    # If no tracker file, return iteration zero.
    if not os.path.isfile(tracker_filename):
        print_rank_last(
            "WARNING: could not find the metadata file {} ".format(tracker_filename)
        )
        print_rank_last(
            "    will not load any checkpoints and will start from " "random"
        )
        return 0

    # Otherwise, read the tracker file and either set the iteration or
    # mark it as a release checkpoint.
    iteration = 0
    release = False
    with open(tracker_filename, "r") as f:
        metastring = f.read().strip()
        try:
            iteration = int(metastring)
        except ValueError:
            release = metastring == "release"
            if not release:
                print_rank_last(
                    "ERROR: Invalid metadata file {}. Exiting".format(tracker_filename)
                )
                sys.exit()

    assert iteration > 0 or release, "error parsing metadata file {}".format(
        tracker_filename
    )

    # Checkpoint.
    checkpoint_name_rank0 = get_fmoe_checkpoint_name(load_dir, iteration, release, 0)
    checkpoint_name_local = get_fmoe_checkpoint_name(
        load_dir, iteration, release, mpu.get_data_parallel_rank()
    )
    print_rank_last(
        " loading checkpoint at rank 0 from {} and rank {} from {} at iteration {}, will merge them later".format(
            checkpoint_name_rank0,
            mpu.get_data_parallel_rank(),
            checkpoint_name_local,
            iteration,
        )
    )

    # Load the checkpoint.
    def load_state_dict(checkpoint_name):
        try:
            state_dict = torch.load(checkpoint_name, map_location="cpu")
        except ModuleNotFoundError:
            from megatron.fp16_deprecated import loss_scaler

            # For backward compatibility.
            print_rank_last(" > deserializing using the old code structure ...")
            sys.modules["fp16.loss_scaler"] = sys.modules[
                "megatron.fp16_deprecated.loss_scaler"
            ]
            sys.modules["megatron.fp16.loss_scaler"] = sys.modules[
                "megatron.fp16_deprecated.loss_scaler"
            ]
            state_dict = torch.load(checkpoint_name, map_location="cpu")
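            # Drop the temporary module aliases so they do not leak into later imports.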
            sys.modules.pop("fp16.loss_scaler", None)
            sys.modules.pop("megatron.fp16.loss_scaler", None)
        return state_dict

    state_dict_rank0 = load_state_dict(checkpoint_name_rank0)
    state_dict_local = load_state_dict(checkpoint_name_local)

    state_dict = merge_state_dict(state_dict_rank0, state_dict_local, args.fp16)

    # set checkpoint version
    set_checkpoint_version(state_dict.get("checkpoint_version", 0))

    # Set iteration.
    if args.finetune or release:
        iteration = 0
    else:
        try:
            iteration = state_dict["iteration"]
        except KeyError:
            try:  # Backward compatible with older checkpoints
                iteration = state_dict["total_iters"]
            except KeyError:
                print_rank_last(
                    "A metadata file exists but unable to load "
                    "iteration from checkpoint {}, exiting".format(
                        checkpoint_name_local
                    )
                )
                sys.exit()

    # Check arguments.
    assert args.consumed_train_samples == 0
    assert args.consumed_valid_samples == 0
    if "args" in state_dict:
        checkpoint_args = state_dict["args"]
        check_checkpoint_args(checkpoint_args)
        args.consumed_train_samples = getattr(
            checkpoint_args, "consumed_train_samples", 0
        )
        update_num_microbatches(consumed_samples=args.consumed_train_samples)
        args.consumed_valid_samples = getattr(
            checkpoint_args, "consumed_valid_samples", 0
        )
    else:
        print_rank_last("could not find arguments in the checkpoint ...")

    # Model.
    model.load_state_dict(state_dict["model"])

    # Optimizer.
    if not release and not args.finetune and not args.no_load_optim:
        try:
            if optimizer is not None:
                optimizer.load_state_dict(state_dict["optimizer"])
            if lr_scheduler is not None:
                lr_scheduler.load_state_dict(state_dict["lr_scheduler"])
        except KeyError:
            print_rank_last(
                "Unable to load optimizer from checkpoint {}. "
                "Specify --no-load-optim or --finetune to prevent "
                "attempting to load the optimizer state, "
                "exiting ...".format(checkpoint_name_local)
            )
            sys.exit()

    # rng states.
    if not release and not args.finetune and not args.no_load_rng:
        try:
            random.setstate(state_dict["random_rng_state"])
            np.random.set_state(state_dict["np_rng_state"])
            torch.set_rng_state(state_dict["torch_rng_state"])
            torch.cuda.set_rng_state(state_dict["cuda_rng_state"])
            mpu.get_cuda_rng_tracker().set_states(state_dict["rng_tracker_states"])
        except KeyError:
            print_rank_last(
                "Unable to load optimizer from checkpoint {}. "
                "Specify --no-load-rng or --finetune to prevent "
                "attempting to load the optimizer state, "
                "exiting ...".format(checkpoint_name_local)
            )
            sys.exit()

    torch.distributed.barrier()
    print_rank_last(
        "  successfully loaded checkpoint (with expert parametes updated) from {} at iteration {}".format(
            args.load, iteration
        )
    )

    return iteration