# coding=utf-8
# Copyright (c) 2020, NVIDIA CORPORATION.  All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Input/output checkpointing."""

import os
import random
import sys
import numpy as np

import torch

from megatron import (get_args,
                      mpu,
                      print_rank_0,
                      update_num_microbatches,
                      utils)

_CHECKPOINT_VERSION = None

def set_checkpoint_version(value):
    global _CHECKPOINT_VERSION
    if _CHECKPOINT_VERSION is not None:
        assert _CHECKPOINT_VERSION == value, \
            "checkpoint versions do not match"
    _CHECKPOINT_VERSION = value

def get_checkpoint_version():
    global _CHECKPOINT_VERSION
    return _CHECKPOINT_VERSION
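
# Illustrative behavior of the guard above (not called anywhere; repeated sets
# are allowed only with the same value):
#   set_checkpoint_version(3.0)   # records 3.0
#   set_checkpoint_version(3.0)   # OK, value matches
#   set_checkpoint_version(2.0)   # raises AssertionError: versions do not match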

def check_checkpoint_args(checkpoint_args):
    """Ensure fixed arguments for a model are the same for the input
    arguments and the ones retrieved from the checkpoint."""
    args = get_args()

    def _compare(arg_name, old_arg_name=None):
        if old_arg_name is not None:
            checkpoint_value = getattr(checkpoint_args, old_arg_name)
        else:
            checkpoint_value = getattr(checkpoint_args, arg_name)
        args_value = getattr(args, arg_name)
        error_message = '{} value from checkpoint ({}) is not equal to the ' \
                        'input argument value ({}).'.format(
                            arg_name, checkpoint_value, args_value)
        assert checkpoint_value == args_value, error_message

    _compare('num_layers')
    _compare('hidden_size')
    _compare('num_attention_heads')
    if args.vocab_file:
        _compare('max_position_embeddings')
        _compare('make_vocab_size_divisible_by')
        _compare('padded_vocab_size')
        _compare('tokenizer_type')
    if args.data_parallel_random_init:
        _compare('data_parallel_random_init')
    if get_checkpoint_version() < 3.0:
        _compare('tensor_model_parallel_size',
                 old_arg_name='model_parallel_size')
    if get_checkpoint_version() >= 3.0:
        _compare('tensor_model_parallel_size')
        _compare('pipeline_model_parallel_size')
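
# Example (illustrative): for a checkpoint written before version 3.0,
#   _compare('tensor_model_parallel_size', old_arg_name='model_parallel_size')
# reads the old 'model_parallel_size' attribute from the checkpoint's saved
# args and asserts it equals the current run's tensor_model_parallel_size, so
# renamed arguments keep validating against older checkpoints.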

def ensure_directory_exists(filename):
    """Build filename's path if it does not already exist."""
    dirname = os.path.dirname(filename)
    # exist_ok avoids a race when multiple ranks create the same directory.
    os.makedirs(dirname, exist_ok=True)


def get_checkpoint_names(checkpoints_path, iteration, use_distributed_optimizer,
                         release=False):
    """A unified checkpoint name."""
    if release:
        directory = 'release'
    else:
        directory = 'iter_{:07d}'.format(iteration)
    # Use both the tensor and pipeline MP rank. If using the distributed
    # optimizer, then the optimizer's path must additionally include the
    # data parallel rank.
    if mpu.get_pipeline_model_parallel_world_size() == 1:
        common_path = os.path.join(checkpoints_path, directory,
                                   'mp_rank_{:02d}'.format(
                                       mpu.get_tensor_model_parallel_rank()))
    else:
        common_path = os.path.join(checkpoints_path, directory,
                                   'mp_rank_{:02d}_{:03d}'.format(
                                       mpu.get_tensor_model_parallel_rank(),
                                       mpu.get_pipeline_model_parallel_rank()))

    if use_distributed_optimizer:
        model_name = os.path.join(common_path, "model_rng.pt")
        optim_name = os.path.join(
            common_path + "_%03d" % mpu.get_data_parallel_rank(),
            "optim.pt")
    else:
        model_name = optim_name = os.path.join(common_path, "model_optim_rng.pt")
    return model_name, optim_name
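
# Example layout (hypothetical paths): with checkpoints_path='/ckpt',
# iteration=1000, tensor MP rank 0, and pipeline MP world size 1:
#   use_distributed_optimizer=False:
#     model == optim -> /ckpt/iter_0001000/mp_rank_00/model_optim_rng.pt
#   use_distributed_optimizer=True (data parallel rank 0):
#     model -> /ckpt/iter_0001000/mp_rank_00/model_rng.pt
#     optim -> /ckpt/iter_0001000/mp_rank_00_000/optim.pt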


def get_checkpoint_tracker_filename(checkpoints_path):
    """Tracker file rescords the latest chckpoint during
    training to restart from."""
    return os.path.join(checkpoints_path, 'latest_checkpointed_iteration.txt')


def read_metadata(tracker_filename):
    # Read the tracker file and either set the iteration or
    # mark it as a release checkpoint.
    iteration = 0
    release = False
    with open(tracker_filename, 'r') as f:
        metastring = f.read().strip()
        try:
            iteration = int(metastring)
        except ValueError:
            release = metastring == 'release'
            if not release:
                print_rank_0('ERROR: Invalid metadata file {}. Exiting'.format(
                    tracker_filename))
                sys.exit()
    assert iteration > 0 or release, 'error parsing metadata file {}'.format(
        tracker_filename)

    # Get the max iteration retrieved across the ranks.
    iters_cuda = torch.cuda.LongTensor([iteration])
    torch.distributed.all_reduce(iters_cuda, op=torch.distributed.ReduceOp.MAX)
    max_iter = iters_cuda[0].item()

    # We should now have all the same iteration.
    # If not, print a warning and choose the maximum
    # iteration across all ranks.
    if iteration != max_iter:
        rank = torch.distributed.get_rank()
        print('WARNING: on rank {} found iteration {} in the '
              'metadata while max iteration across the ranks '
              'is {}, replacing it with max iteration.'.format(
                  rank, iteration, max_iter), flush=True)
    return max_iter, release
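
# Example (illustrative): a tracker file containing '1000' parses to
# (1000, False); one containing the literal string 'release' parses to
# (0, True); any other content exits with an error.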


def get_rng_state():
    """ collect rng state across data parallel ranks """
156
    args = get_args()
    rng_state = {
        'random_rng_state': random.getstate(),
        'np_rng_state': np.random.get_state(),
        'torch_rng_state': torch.get_rng_state(),
        'cuda_rng_state': torch.cuda.get_rng_state(),
        'rng_tracker_states': mpu.get_cuda_rng_tracker().get_states()}

    rng_state_list = None
    if torch.distributed.is_initialized() and \
            mpu.get_data_parallel_world_size() > 1 and \
            args.data_parallel_random_init:
        rng_state_list = \
            [None for i in range(mpu.get_data_parallel_world_size())]
        torch.distributed.all_gather_object(
            rng_state_list,
            rng_state,
            group=mpu.get_data_parallel_group())
    else:
        rng_state_list = [rng_state]

    return rng_state_list


def save_checkpoint(iteration, model, optimizer, opt_param_scheduler):
    """Save a model checkpoint."""
    args = get_args()

    # Only rank zero of each data parallel group writes to disk.
    model = utils.unwrap_model(model)

    print_rank_0('saving checkpoint at iteration {:7d} to {}'.format(
        iteration, args.save))

    # Collect rng state across data parallel ranks.
    rng_state = get_rng_state()

    # Checkpoint file names.
    model_checkpoint_name, optim_checkpoint_name = \
        get_checkpoint_names(args.save, iteration, args.use_distributed_optimizer)

    # Collect args, model, RNG.
    model_state_dict = {}
    if not torch.distributed.is_initialized() \
       or mpu.get_data_parallel_rank() == 0:

        # Arguments, iteration, and model.
        model_state_dict['args'] = args
        model_state_dict['checkpoint_version'] = 3.0
        model_state_dict['iteration'] = iteration
        if len(model) == 1:
            model_state_dict['model'] = model[0].state_dict_for_save_checkpoint()
        else:
            for i in range(len(model)):
                mpu.set_virtual_pipeline_model_parallel_rank(i)
                model_state_dict['model%d' % i] = \
                    model[i].state_dict_for_save_checkpoint()

        # RNG states.
        if not args.no_save_rng:
            model_state_dict["rng_state"] = rng_state

    # Collect optimizer state. (Optimizer is saved separately from the model, due
    # to the conflicting data pattern when using the distributed optimizer.)
    optim_state_dict = {}
    if not args.no_save_optim \
       and (not torch.distributed.is_initialized()
            or mpu.get_data_parallel_rank() == 0
            or args.use_distributed_optimizer):

        # Optimizer stuff.
        if optimizer is not None:
            optim_state_dict['optimizer'] = optimizer.state_dict()
        if opt_param_scheduler is not None:
            optim_state_dict['opt_param_scheduler'] = \
                opt_param_scheduler.state_dict()

    # Save.
    if args.use_distributed_optimizer:
        # Save model separate from optimizer.
        if model_state_dict:
            ensure_directory_exists(model_checkpoint_name)
            torch.save(model_state_dict, model_checkpoint_name)
        if optim_state_dict:
            ensure_directory_exists(optim_checkpoint_name)
            torch.save(optim_state_dict, optim_checkpoint_name)
    else:
        # Save model and optimizer together.
        state_dict = {**model_state_dict, **optim_state_dict}
        if state_dict:  # only saves if populated (i.e., inherits conditions above)
            ensure_directory_exists(model_checkpoint_name)
            torch.save(state_dict, model_checkpoint_name)

    # Wait so everyone is done (necessary).
    if torch.distributed.is_initialized():
        torch.distributed.barrier()

    print_rank_0('  successfully saved checkpoint at iteration {:7d} to {}'.format(
        iteration, args.save))

    # And update the latest iteration.
    if not torch.distributed.is_initialized() or torch.distributed.get_rank() == 0:
        tracker_filename = get_checkpoint_tracker_filename(args.save)
        with open(tracker_filename, 'w') as f:
            f.write(str(iteration))

    # Wait so everyone is done (not necessary).
    if torch.distributed.is_initialized():
        torch.distributed.barrier()
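
# Example call (illustrative), as it might appear in a training loop, where
# `model` is the list of (virtual pipeline) model chunks and args.save points
# at the checkpoint root directory:
#   save_checkpoint(iteration, model, optimizer, opt_param_scheduler)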

def _transpose_first_dim(t, num_splits, num_splits_first, model):
    input_shape = t.size()
    # We use a self_attention module but the values extracted aren't
    # specific to self attention, so this should work for cross attention as well.
    while hasattr(model, 'module'):
        model = model.module
    attention_module = model.language_model.encoder.layers[0].self_attention
    hidden_size_per_attention_head = attention_module.hidden_size_per_attention_head
    num_attention_heads_per_partition = attention_module.num_attention_heads_per_partition
    if num_splits_first:
        """[num_splits * np * hn, h]
        -->(view) [num_splits, np, hn, h]
        -->(tranpose) [np, num_splits, hn, h]
        -->(view) [np * num_splits * hn, h] """

        intermediate_shape = \
            (num_splits, num_attention_heads_per_partition,
             hidden_size_per_attention_head) + input_shape[1:]

        t = t.view(*intermediate_shape)
        t = t.transpose(0, 1).contiguous()
    else:
        """[np * hn * num_splits, h]
        -->(view) [np, hn, num_splits, h]
        -->(tranpose) [np, num_splits, hn, h]
        -->(view) [np * num_splits * hn, h] """

        intermediate_shape = \
            (num_attention_heads_per_partition,
             hidden_size_per_attention_head, num_splits) +\
             input_shape[1:]

        t = t.view(*intermediate_shape)
        t = t.transpose(1, 2).contiguous()
    t = t.view(*input_shape)

    return t
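
# Worked example (illustrative numbers): with num_splits=3 (query/key/value),
# np=2 heads per partition, hn=4 dims per head, and h=8, a version-0 weight of
# shape [3 * 2 * 4, 8] = [24, 8] is viewed as (3, 2, 4, 8), transposed to
# (2, 3, 4, 8), and flattened back to [24, 8], now in the
# [np * num_splits * hn, h] order the current code expects.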

def fix_query_key_value_ordering(model, checkpoint_version):
    """Fix up query/key/value matrix ordering if checkpoint
    version is smaller than 2.0
    """
    if checkpoint_version < 2.0:
        if isinstance(model, list):
            assert len(model) == 1
            model = model[0]
        for name, param in model.named_parameters():
            if name.endswith(('.query_key_value.weight', '.query_key_value.bias')):
                if checkpoint_version == 0:
                    fixed_param = _transpose_first_dim(param.data, 3, True, model)
                elif checkpoint_version == 1.0:
                    fixed_param = _transpose_first_dim(param.data, 3, False, model)
                else:
                    print_rank_0(f"Invalid checkpoint version {checkpoint_version}.")
                    sys.exit()
                param.data.copy_(fixed_param)
            if name.endswith(('.key_value.weight', '.key_value.bias')):
                if checkpoint_version == 0:
                    fixed_param = _transpose_first_dim(param.data, 2, True, model)
                elif checkpoint_version == 1.0:
                    fixed_param = _transpose_first_dim(param.data, 2, False, model)
                else:
                    print_rank_0(f"Invalid checkpoint version {checkpoint_version}.")
                    sys.exit()
                param.data.copy_(fixed_param)
        print_rank_0(" successfully fixed query-key-values ordering for"
                     " checkpoint version {}".format(checkpoint_version))

def load_checkpoint(model, optimizer, opt_param_scheduler, load_arg='load', strict=True):
    """Load a model checkpoint and return the iteration.
    strict (bool): whether to strictly enforce that the keys in
        :attr:`state_dict` of the checkpoint match the names of
        parameters and buffers in model.
    """
    args = get_args()
    load_dir = getattr(args, load_arg)

    model = utils.unwrap_model(model)

    # Read the tracker file and set the iteration.
    tracker_filename = get_checkpoint_tracker_filename(load_dir)

    # If no tracker file, return iteration zero.
    if not os.path.isfile(tracker_filename):
        print_rank_0('WARNING: could not find the metadata file {} '.format(
            tracker_filename))
        print_rank_0('    will not load any checkpoints and will start from '
                     'random')
        return 0

    # Otherwise, read the tracker file and either set the iteration or
    # mark it as a release checkpoint.
    iteration, release = read_metadata(tracker_filename)

    # Checkpoint.
    model_checkpoint_name, optim_checkpoint_name = \
        get_checkpoint_names(load_dir, iteration,
                             args.use_distributed_optimizer,
                             release)
    print_rank_0(f' loading checkpoint from {args.load} at iteration {iteration}')

    # Load the checkpoint.
    try:
        model_state_dict = torch.load(model_checkpoint_name, map_location='cpu')
        if args.use_distributed_optimizer:
            optim_state_dict = torch.load(optim_checkpoint_name, map_location='cpu')
        else:
            optim_state_dict = model_state_dict
    except ModuleNotFoundError:
        from megatron.fp16_deprecated import loss_scaler
        # For backward compatibility.
        print_rank_0(' > deserializing using the old code structure ...')
        sys.modules['fp16.loss_scaler'] = sys.modules[
            'megatron.fp16_deprecated.loss_scaler']
        sys.modules['megatron.fp16.loss_scaler'] = sys.modules[
            'megatron.fp16_deprecated.loss_scaler']
        model_state_dict = torch.load(model_checkpoint_name, map_location='cpu')
        optim_state_dict = torch.load(optim_checkpoint_name, map_location='cpu')
        sys.modules.pop('fp16.loss_scaler', None)
        sys.modules.pop('megatron.fp16.loss_scaler', None)
    except BaseException as e:
        print_rank_0('could not load the checkpoint')
        print_rank_0(e)
        sys.exit()

    # Set checkpoint version.
    set_checkpoint_version(model_state_dict.get('checkpoint_version', 0))

    # Set iteration.
    if args.finetune or release:
        iteration = 0
    else:
        try:
            iteration = model_state_dict['iteration']
        except KeyError:
            try:  # Backward compatible with older checkpoints
                iteration = model_state_dict['total_iters']
            except KeyError:
                print_rank_0('A metadata file exists but unable to load '
                             'iteration from checkpoint {}, exiting'.format(
                                 model_checkpoint_name))
                sys.exit()

    # Check arguments.
    assert args.consumed_train_samples == 0
    assert args.consumed_valid_samples == 0
    if 'args' in model_state_dict:
        checkpoint_args = model_state_dict['args']
        check_checkpoint_args(checkpoint_args)
        args.consumed_train_samples = getattr(checkpoint_args,
                                              'consumed_train_samples', 0)
        update_num_microbatches(consumed_samples=args.consumed_train_samples)
        args.consumed_valid_samples = getattr(checkpoint_args,
                                              'consumed_valid_samples', 0)
    else:
        print_rank_0('could not find arguments in the checkpoint ...')

    # Model.
    if len(model) == 1:
        model[0].load_state_dict(model_state_dict['model'], strict=strict)
    else:
        for i in range(len(model)):
            mpu.set_virtual_pipeline_model_parallel_rank(i)
            model[i].load_state_dict(model_state_dict['model%d' % i], strict=strict)

    # Fix up query/key/value matrix ordering if needed.
    checkpoint_version = get_checkpoint_version()
    print_rank_0(f' checkpoint version {checkpoint_version}')
    fix_query_key_value_ordering(model, checkpoint_version)

    # Optimizer.
    if not release and not args.finetune and not args.no_load_optim:
        try:
            if optimizer is not None:
                optimizer.load_state_dict(optim_state_dict['optimizer'])
            if opt_param_scheduler is not None:
                if 'lr_scheduler' in optim_state_dict:  # backward compatibility
                    opt_param_scheduler.load_state_dict(optim_state_dict['lr_scheduler'])
                else:
                    opt_param_scheduler.load_state_dict(optim_state_dict['opt_param_scheduler'])
        except KeyError:
            print_rank_0('Unable to load optimizer from checkpoint {}. '
                         'Specify --no-load-optim or --finetune to prevent '
                         'attempting to load the optimizer state, '
                         'exiting ...'.format(optim_checkpoint_name))
            sys.exit()

    # rng states.
    if not release and not args.finetune and not args.no_load_rng:
        try:
            if 'rng_state' in model_state_dict:
                # access rng_state for data parallel rank
                if args.data_parallel_random_init:
                    rng_state = model_state_dict['rng_state'][mpu.get_data_parallel_rank()]
                else:
                    rng_state = model_state_dict['rng_state'][0]
                random.setstate(rng_state['random_rng_state'])
                np.random.set_state(rng_state['np_rng_state'])
                torch.set_rng_state(rng_state['torch_rng_state'])
                torch.cuda.set_rng_state(rng_state['cuda_rng_state'])
                # Check for empty states array
                if not rng_state['rng_tracker_states']:
                    raise KeyError
                mpu.get_cuda_rng_tracker().set_states(
                    rng_state['rng_tracker_states'])
            else:  # backward compatibility
                random.setstate(model_state_dict['random_rng_state'])
                np.random.set_state(model_state_dict['np_rng_state'])
                torch.set_rng_state(model_state_dict['torch_rng_state'])
                torch.cuda.set_rng_state(model_state_dict['cuda_rng_state'])
                # Check for empty states array
                if not model_state_dict['rng_tracker_states']:
                    raise KeyError
                mpu.get_cuda_rng_tracker().set_states(
                    model_state_dict['rng_tracker_states'])
        except KeyError:
            print_rank_0('Unable to load rng state from checkpoint {}. '
                         'Specify --no-load-rng or --finetune to prevent '
                         'attempting to load the rng state, '
                         'exiting ...'.format(model_checkpoint_name))
            sys.exit()

    # Some utilities want to load a checkpoint without distributed being initialized.
    if torch.distributed.is_initialized():
        torch.distributed.barrier()

    print_rank_0(f'  successfully loaded checkpoint from {args.load} '
                 f'at iteration {iteration}')

    return iteration
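
# Example call (illustrative), before the training loop resumes:
#   iteration = load_checkpoint(model, optimizer, opt_param_scheduler)
# A return value of 0 means no tracker file was found, or a release/finetune
# load, and training starts from iteration zero.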


def load_biencoder_checkpoint(model, only_query_model=False,
        only_context_model=False, custom_load_path=None):
    """
    Selectively load retrieval models for indexing/retrieving
    from saved checkpoints.
    """

    args = get_args()

    model = utils.unwrap_model(model)

    load_path = custom_load_path if custom_load_path is not None else args.load

    tracker_filename = get_checkpoint_tracker_filename(load_path)
    with open(tracker_filename, 'r') as f:
        iteration = int(f.read().strip())

    checkpoint_name, _ = get_checkpoint_names(load_path, iteration,
                                              args.use_distributed_optimizer,
                                              False)
    if mpu.get_data_parallel_rank() == 0:
        print('global rank {} is loading checkpoint {}'.format(
            torch.distributed.get_rank(), checkpoint_name))

    state_dict = torch.load(checkpoint_name, map_location='cpu')
    ret_state_dict = state_dict['model']

    if only_query_model:
        ret_state_dict.pop('context_model')
    if only_context_model:
        ret_state_dict.pop('query_model')

    assert len(model) == 1
    model[0].load_state_dict(ret_state_dict)

    torch.distributed.barrier()

    if mpu.get_data_parallel_rank() == 0:
        print(' successfully loaded {}'.format(checkpoint_name))

    return model
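
# Example call (illustrative): load only the query tower for indexing:
#   model = load_biencoder_checkpoint(model, only_query_model=True)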