checkpointing.py 25.8 KB
Newer Older
Jared Casper's avatar
Jared Casper committed
1
# Copyright (c) 2022, NVIDIA CORPORATION. All rights reserved.
2
3
4
5
6
7
8
9
10
11

"""Input/output checkpointing."""

import os
import random
import sys
import numpy as np

import torch

12
13
from megatron import update_num_microbatches
from megatron.core import mpu, tensor_parallel
14
15
16
17
from .global_vars import get_args
from .utils import (unwrap_model,
                    print_rank_0)

18

Vijay Korthikanti's avatar
Vijay Korthikanti committed
19
20
21
22
_CHECKPOINT_VERSION = None

def set_checkpoint_version(value):
    global _CHECKPOINT_VERSION
Jared Casper's avatar
Jared Casper committed
23
24
25
    if _CHECKPOINT_VERSION is not None:
        assert _CHECKPOINT_VERSION == value, \
            "checkpoint versions do not match"
Vijay Korthikanti's avatar
Vijay Korthikanti committed
26
27
28
29
30
    _CHECKPOINT_VERSION = value

def get_checkpoint_version():
    global _CHECKPOINT_VERSION
    return _CHECKPOINT_VERSION
31
32
33

def check_checkpoint_args(checkpoint_args):
    """Ensure fixed arguments for a model are the same for the input
34
    arguments and the one retrieved from checkpoint."""
35
36
    args = get_args()

37
38
39
40
41
    def _compare(arg_name, old_arg_name=None):
        if old_arg_name is not None:
            checkpoint_value = getattr(checkpoint_args, old_arg_name)
        else:
            checkpoint_value = getattr(checkpoint_args, arg_name)
42
43
44
45
46
47
48
49
50
        args_value = getattr(args, arg_name)
        error_message = '{} value from checkpoint ({}) is not equal to the ' \
                        'input argument value ({}).'.format(
                            arg_name, checkpoint_value, args_value)
        assert checkpoint_value == args_value, error_message

    _compare('num_layers')
    _compare('hidden_size')
    _compare('num_attention_heads')
Vijay Korthikanti's avatar
Vijay Korthikanti committed
51
    if args.vocab_file:
Vijay Korthikanti's avatar
Vijay Korthikanti committed
52
        _compare('max_position_embeddings')
53
54
55
        _compare('make_vocab_size_divisible_by')
        _compare('padded_vocab_size')
        _compare('tokenizer_type')
Vijay Korthikanti's avatar
Vijay Korthikanti committed
56
57
    if args.data_parallel_random_init:
        _compare('data_parallel_random_init')
58
59
60
61
62
63
    if get_checkpoint_version() < 3.0:
        _compare('tensor_model_parallel_size',
                 old_arg_name='model_parallel_size')
    if get_checkpoint_version() >= 3.0:
        _compare('tensor_model_parallel_size')
        _compare('pipeline_model_parallel_size')
64
65
66
67
68
69
70
71

def ensure_directory_exists(filename):
    """Build filename's path if it does not already exists."""
    dirname = os.path.dirname(filename)
    if not os.path.exists(dirname):
        os.makedirs(dirname)


72
def get_checkpoint_names(checkpoints_path, iteration, use_distributed_optimizer, release=False,
Jared Casper's avatar
Jared Casper committed
73
74
                        pipeline_parallel=None, tensor_rank=None, pipeline_rank=None):
    """Determine the directory name for this rank's checkpoint."""
75
76
77
78
    if release:
        directory = 'release'
    else:
        directory = 'iter_{:07d}'.format(iteration)
79

80
    # Use both the tensor and pipeline MP rank.
Jared Casper's avatar
Jared Casper committed
81
    if pipeline_parallel is None:
82
        pipeline_parallel = (mpu.get_pipeline_model_parallel_world_size() > 1)
83
    if tensor_rank is None:
84
        tensor_rank = mpu.get_tensor_model_parallel_rank()
85
    if pipeline_rank is None:
86
        pipeline_rank = mpu.get_pipeline_model_parallel_rank()
87

88
89
90
    # Use both the tensor and pipeline MP rank. If using the distributed
    # optimizer, then the optimizer's path must additionally include the
    # data parallel rank.
Jared Casper's avatar
Jared Casper committed
91
    if not pipeline_parallel:
92
        common_path = os.path.join(checkpoints_path, directory,
93
                            f'mp_rank_{tensor_rank:02d}')
94
95
    else:
        common_path = os.path.join(checkpoints_path, directory,
96
                        f'mp_rank_{tensor_rank:02d}_{pipeline_rank:03d}')
97

98
    if use_distributed_optimizer:
99
        model_name = os.path.join(common_path, "model_rng.pt")
100
        optim_name = os.path.join(
101
            common_path + "_%03d" % mpu.get_data_parallel_rank(),
102
103
            "optim.pt")
    else:
104
        model_name = optim_name = os.path.join(common_path, "model_optim_rng.pt")
105
    return model_name, optim_name
106

Jared Casper's avatar
Jared Casper committed
107
def find_checkpoint_rank_0(checkpoints_path, iteration, use_distributed_optimizer, release=False):
108
109
110
111
112
113
114
115
116
117
    """Finds the checkpoint for rank 0 without knowing if we are using
    pipeline parallelism or not.

    Since the checkpoint naming scheme changes if pipeline parallelism
    is present, we need to look for both naming schemes if we don't
    know if the checkpoint has pipeline parallelism.

    """

    # Look for checkpoint with no pipelining
Jared Casper's avatar
Jared Casper committed
118
119
120
121
122
    filenames = get_checkpoint_names(checkpoints_path, iteration, use_distributed_optimizer, release,
                                     pipeline_parallel=False,
                                     tensor_rank=0, pipeline_rank=0)
    if os.path.isfile(filenames[0]):
        return filenames
123
124

    # Look for checkpoint with pipelining
Jared Casper's avatar
Jared Casper committed
125
126
127
128
129
    filenames = get_checkpoint_names(checkpoints_path, iteration, use_distributed_optimizer, release,
                                    pipeline_parallel=True,
                                    tensor_rank=0, pipeline_rank=0)
    if os.path.isfile(filenames[0]):
        return filenames
130

Jared Casper's avatar
Jared Casper committed
131
    return None, None
132
133

def get_checkpoint_tracker_filename(checkpoints_path):
134

135
136
137
138
139
    """Tracker file rescords the latest chckpoint during
    training to restart from."""
    return os.path.join(checkpoints_path, 'latest_checkpointed_iteration.txt')


140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
def read_metadata(tracker_filename):
    # Read the tracker file and either set the iteration or
    # mark it as a release checkpoint.
    iteration = 0
    release = False
    with open(tracker_filename, 'r') as f:
        metastring = f.read().strip()
        try:
            iteration = int(metastring)
        except ValueError:
            release = metastring == 'release'
            if not release:
                print_rank_0('ERROR: Invalid metadata file {}. Exiting'.format(
                    tracker_filename))
                sys.exit()
    assert iteration > 0 or release, 'error parsing metadata file {}'.format(
        tracker_filename)

158
    # Get the max iteration retrieved across the ranks.
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
    if torch.distributed.is_initialized():
        iters_cuda = torch.cuda.LongTensor([iteration])
        torch.distributed.all_reduce(iters_cuda, op=torch.distributed.ReduceOp.MAX)
        max_iter = iters_cuda[0].item()

        # We should now have all the same iteration.
        # If not, print a warning and chose the maximum
        # iteration across all ranks.
        if iteration != max_iter:
            print('WARNING: on rank {} found iteration {} in the '
                  'metadata while max iteration across the ranks '
                  'is {}, replacing it with max iteration.'.format(
                      rank, iteration, max_iter), flush=True)
    else:
        # When loading a checkpoint outside of training (for example,
        # when editing it), we might not have torch distributed
        # initialized, in this case, just assume we have the latest
        max_iter = iteration
177
178
179
    return max_iter, release


180
181
def get_rng_state():
    """ collect rng state across data parallel ranks """
182
    args = get_args()
183
184
185
186
187
    rng_state = {
        'random_rng_state': random.getstate(),
        'np_rng_state': np.random.get_state(),
        'torch_rng_state': torch.get_rng_state(),
        'cuda_rng_state': torch.cuda.get_rng_state(),
188
        'rng_tracker_states': tensor_parallel.get_cuda_rng_tracker().get_states()}
189
190
191

    rng_state_list = None
    if torch.distributed.is_initialized() and \
192
            mpu.get_data_parallel_world_size() > 1 and \
193
            args.data_parallel_random_init:
194
        rng_state_list = \
195
            [None for i in range(mpu.get_data_parallel_world_size())]
196
        torch.distributed.all_gather_object(
197
            rng_state_list,
198
            rng_state,
199
            group=mpu.get_data_parallel_group())
200
201
202
203
204
205
    else:
        rng_state_list = [rng_state]

    return rng_state_list


206
def save_checkpoint(iteration, model, optimizer, opt_param_scheduler):
207
208
209
210
    """Save a model checkpoint."""
    args = get_args()

    # Only rank zero of the data parallel writes to the disk.
211
    model = unwrap_model(model)
212

Jared Casper's avatar
Jared Casper committed
213
214
    print_rank_0('saving checkpoint at iteration {:7d} to {}'.format(
        iteration, args.save))
215

216
    # Collect rng state across data parallel ranks.
217
218
    rng_state = get_rng_state()

219
220
    # Checkpoint file names.
    model_checkpoint_name, optim_checkpoint_name = \
221
222
        get_checkpoint_names(args.save, iteration, args.use_distributed_optimizer)

Lawrence McAfee's avatar
Lawrence McAfee committed
223
224
    # Collect args, model, RNG.
    model_state_dict = {}
225
    if not torch.distributed.is_initialized() \
226
       or mpu.get_data_parallel_rank() == 0:
227
228

        # Arguments, iteration, and model.
Lawrence McAfee's avatar
Lawrence McAfee committed
229
230
231
        model_state_dict['args'] = args
        model_state_dict['checkpoint_version'] = 3.0
        model_state_dict['iteration'] = iteration
232
        if len(model) == 1:
Lawrence McAfee's avatar
Lawrence McAfee committed
233
            model_state_dict['model'] = model[0].state_dict_for_save_checkpoint()
234
235
        else:
            for i in range(len(model)):
236
                mpu.set_virtual_pipeline_model_parallel_rank(i)
Lawrence McAfee's avatar
Lawrence McAfee committed
237
238
                model_state_dict['model%d' % i] = \
                    model[i].state_dict_for_save_checkpoint()
239
240
241

        # RNG states.
        if not args.no_save_rng:
Lawrence McAfee's avatar
Lawrence McAfee committed
242
            model_state_dict["rng_state"] = rng_state
243

Lawrence McAfee's avatar
Lawrence McAfee committed
244
    # Collect optimizer state. (Optimizer is saved separately from the model, due
245
    # to the conflicting data pattern when using the distributed optimizer.)
Lawrence McAfee's avatar
Lawrence McAfee committed
246
    optim_state_dict = {}
247
248
    if not args.no_save_optim \
       and (not torch.distributed.is_initialized()
249
            or mpu.get_data_parallel_rank() == 0
250
            or args.use_distributed_optimizer):
251

252
253
        # Optimizer stuff.
        if optimizer is not None:
Lawrence McAfee's avatar
Lawrence McAfee committed
254
            optim_state_dict['optimizer'] = optimizer.state_dict()
255
        if opt_param_scheduler is not None:
Lawrence McAfee's avatar
Lawrence McAfee committed
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
            optim_state_dict['opt_param_scheduler'] = \
                opt_param_scheduler.state_dict()

    # Save.
    if args.use_distributed_optimizer:
        # Save model separate from optimizer.
        if model_state_dict:
            ensure_directory_exists(model_checkpoint_name)
            torch.save(model_state_dict, model_checkpoint_name)
        if optim_state_dict:
            ensure_directory_exists(optim_checkpoint_name)
            torch.save(optim_state_dict, optim_checkpoint_name)
    else:
        # Save model and optimizer together.
        state_dict = {**model_state_dict, **optim_state_dict}
        if state_dict: # only saves if populated (i.e., inherits conditions above)
            ensure_directory_exists(model_checkpoint_name)
            torch.save(state_dict, model_checkpoint_name)
274
275

    # Wait so everyone is done (necessary)
Jared Casper's avatar
Jared Casper committed
276
277
278
279
280
281
    if torch.distributed.is_initialized():
        torch.distributed.barrier()

    print_rank_0('  successfully saved checkpoint at iteration {:7d} to {}'.format(
        iteration, args.save))

282
    # And update the latest iteration
Jared Casper's avatar
Jared Casper committed
283
    if not torch.distributed.is_initialized() or torch.distributed.get_rank() == 0:
284
285
286
287
288
        tracker_filename = get_checkpoint_tracker_filename(args.save)
        with open(tracker_filename, 'w') as f:
            f.write(str(iteration))

    # Wait so everyone is done (not necessary)
Jared Casper's avatar
Jared Casper committed
289
290
    if torch.distributed.is_initialized():
        torch.distributed.barrier()
291

292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
def _transpose_first_dim(t, num_splits, num_splits_first, model):
    input_shape = t.size()
    # We use a self_attention module but the values extracted aren't
    # specific to self attention so should work for cross attention as well
    while hasattr(model, 'module'):
        model = model.module
    attention_module = model.language_model.encoder.layers[0].self_attention
    hidden_size_per_attention_head = attention_module.hidden_size_per_attention_head
    num_attention_heads_per_partition = attention_module.num_attention_heads_per_partition
    if num_splits_first:
        """[num_splits * np * hn, h]
        -->(view) [num_splits, np, hn, h]
        -->(tranpose) [np, num_splits, hn, h]
        -->(view) [np * num_splits * hn, h] """

        intermediate_shape = \
            (num_splits, num_attention_heads_per_partition,
             hidden_size_per_attention_head) + input_shape[1:]

        t = t.view(*intermediate_shape)
        t = t.transpose(0, 1).contiguous()
    else:
        """[np * hn * num_splits, h]
        -->(view) [np, hn, num_splits, h]
        -->(tranpose) [np, num_splits, hn, h]
        -->(view) [np * num_splits * hn, h] """

        intermediate_shape = \
            (num_attention_heads_per_partition,
             hidden_size_per_attention_head, num_splits) +\
             input_shape[1:]

        t = t.view(*intermediate_shape)
        t = t.transpose(1, 2).contiguous()
    t = t.view(*input_shape)

    return t
329

Mostofa Patwary's avatar
Mostofa Patwary committed
330
331
332
333
334
def fix_query_key_value_ordering(model, checkpoint_version):
    """Fix up query/key/value matrix ordering if checkpoint
    version is smaller than 2.0
    """
    if checkpoint_version < 2.0:
335
336
337
        if isinstance(model, list):
            assert len(model)==1
            model = model[0]
Mostofa Patwary's avatar
Mostofa Patwary committed
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
        for name, param in model.named_parameters():
            if name.endswith(('.query_key_value.weight', '.query_key_value.bias')):
                if checkpoint_version == 0:
                    fixed_param = _transpose_first_dim(param.data, 3, True, model)
                elif checkpoint_version == 1.0:
                    fixed_param = _transpose_first_dim(param.data, 3, False, model)
                else:
                    print_rank_0(f"Invalid checkpoint version {checkpoint_version}.")
                    sys.exit()
                param.data.copy_(fixed_param)
            if name.endswith(('.key_value.weight', '.key_value.bias')):
                if checkpoint_version == 0:
                    fixed_param = _transpose_first_dim(param.data, 2, True, model)
                elif checkpoint_version == 1.0:
                    fixed_param = _transpose_first_dim(param.data, 2, False, model)
                else:
                    print_rank_0(f"Invalid checkpoint version {checkpoint_version}.")
                    sys.exit()
                param.data.copy_(fixed_param)
        print_rank_0(" succesfully fixed query-key-values ordering for"
                    " checkpoint version {}".format(checkpoint_version))

Jared Casper's avatar
Jared Casper committed
360
def _load_base_checkpoint(load_dir, use_distributed_optimizer, rank0=False):
361
362
363
    """ Load the base state_dict from the given directory

    If rank0 is true, just loads rank 0 checkpoint, ignoring arguments.
364
    """
365

366

367
    # Read the tracker file and set the iteration.
368
    tracker_filename = get_checkpoint_tracker_filename(load_dir)
369

370
    # If no tracker file, return nothing
371
    if not os.path.isfile(tracker_filename):
372
373
374
375
376
        if not rank0:
            print_rank_0('WARNING: could not find the metadata file {} '.format(
                tracker_filename))
            print_rank_0('    will not load any checkpoints and will start from '
                         'random')
377
        return None, None, False
378
379
380

    # Otherwise, read the tracker file and either set the iteration or
    # mark it as a release checkpoint.
381
    iteration, release = read_metadata(tracker_filename)
382
383

    # Checkpoint.
384
    if rank0:
Jared Casper's avatar
Jared Casper committed
385
        checkpoint_names = find_checkpoint_rank_0(load_dir, iteration, use_distributed_optimizer,
386
                                                  release)
387
    else:
Jared Casper's avatar
Jared Casper committed
388
389
        checkpoint_names = get_checkpoint_names(load_dir, iteration, use_distributed_optimizer,
                                                release)
390
391
392
393
        if release:
            print_rank_0(f' loading release checkpoint from {load_dir}')
        else:
            print_rank_0(f' loading checkpoint from {load_dir} at iteration {iteration}')
394

395
    model_checkpoint_name, optim_checkpoint_name = checkpoint_names
396
397
398

    # Load the checkpoint.
    try:
399
        model_state_dict = torch.load(model_checkpoint_name, map_location='cpu')
Jared Casper's avatar
Jared Casper committed
400
        if use_distributed_optimizer:
Lawrence McAfee's avatar
Lawrence McAfee committed
401
402
403
            optim_state_dict = torch.load(optim_checkpoint_name, map_location='cpu')
        else:
            optim_state_dict = model_state_dict
404
    except ModuleNotFoundError:
mohammad's avatar
mohammad committed
405
        from megatron.fp16_deprecated import loss_scaler
406
        # For backward compatibility.
407
408
        if not rank0:
            print_rank_0(' > deserializing using the old code structure ...')
409
        sys.modules['fp16.loss_scaler'] = sys.modules[
mohammad's avatar
mohammad committed
410
411
412
            'megatron.fp16_deprecated.loss_scaler']
        sys.modules['megatron.fp16.loss_scaler'] = sys.modules[
            'megatron.fp16_deprecated.loss_scaler']
413
414
        model_state_dict = torch.load(model_checkpoint_name, map_location='cpu')
        optim_state_dict = torch.load(optim_checkpoint_name, map_location='cpu')
415
        sys.modules.pop('fp16.loss_scaler', None)
mohammad's avatar
mohammad committed
416
        sys.modules.pop('megatron.fp16.loss_scaler', None)
417
    except BaseException as e:
418
        print_rank_0('could not load the checkpoint')
419
        print_rank_0(e)
420
421
        sys.exit()

422
    return model_state_dict, optim_state_dict, release
423
424

def load_args_from_checkpoint(args, load_arg='load'):
425
426
427
428
429
    """Set required arguments from the checkpoint specified in the
    arguments.

    Will overwrite arguments that have a non-None default value, but
    will leave any arguments that default to None as set.
430
431
432
433
434
435
436
437
438
439

    Returns the same args NameSpace with the new values added/updated.

    If no checkpoint is specified in args, or if the checkpoint is
    there but invalid, the arguments will not be modified

    """
    load_dir = getattr(args, load_arg)

    if load_dir is None:
Jared Casper's avatar
Jared Casper committed
440
        print_rank_0('No load directory specified, using provided arguments.')
441
442
        return args

Jared Casper's avatar
Jared Casper committed
443
444
445
446
    model_state_dict, optim_state_dict, release = \
        _load_base_checkpoint(load_dir,
                              use_distributed_optimizer=args.use_distributed_optimizer,
                              rank0=True)
447

448
449
450
    # For args we only care about model state dict
    state_dict = model_state_dict
    
451
    if not state_dict:
Jared Casper's avatar
Jared Casper committed
452
        print_rank_0('Checkpoint not found to provide arguments, using provided arguments.')
453
454
455
        return args

    if 'args' not in state_dict:
Jared Casper's avatar
Jared Casper committed
456
        print_rank_0('Checkpoint provided does not have arguments saved, using provided arguments.')
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
        return args

    checkpoint_args = state_dict['args']
    checkpoint_version = state_dict.get('checkpoint_version', 0)
    args.iteration = state_dict['iteration']

    def _set_arg(arg_name, old_arg_name=None, force=False):
        if not force and getattr(args, arg_name, None) is not None:
            return

        if old_arg_name is not None:
            checkpoint_value = getattr(checkpoint_args, old_arg_name, None)
        else:
            checkpoint_value = getattr(checkpoint_args, arg_name, None)

        if checkpoint_value is not None:
473
            print_rank_0(f"Setting {arg_name} to {checkpoint_value} from checkpoint")
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
            setattr(args, arg_name, checkpoint_value)

    _set_arg('num_layers')
    _set_arg('hidden_size')
    _set_arg('ffn_hidden_size')
    _set_arg('seq_length')
    _set_arg('num_attention_heads')
    _set_arg('kv_channels')
    _set_arg('max_position_embeddings')
    _set_arg('tokenizer_type')
    _set_arg('padded_vocab_size')
    if checkpoint_version < 3.0:
        _set_arg('tensor_model_parallel_size',
                 'model_parallel_size')
    else:
        _set_arg('tensor_model_parallel_size', force=True)
        _set_arg('pipeline_model_parallel_size', force=True)
        _set_arg('num_layers_per_virtual_pipeline_stage')
    return args

494
495

def load_checkpoint(model, optimizer, opt_param_scheduler, load_arg='load', strict=True):
496
497
498
499
500
501
502
503
    """Load a model checkpoint and return the iteration.
    strict (bool): whether to strictly enforce that the keys in
        :attr:`state_dict` of the checkpoint match the names of
        parameters and buffers in model.
    """
    args = get_args()
    load_dir = getattr(args, load_arg)

504
    model = unwrap_model(model)
505

Jared Casper's avatar
Jared Casper committed
506
507
508
509
510
511
512
    model_state_dict, optim_state_dict, release = \
        _load_base_checkpoint(load_dir,
                              use_distributed_optimizer=args.use_distributed_optimizer,
                              rank0=False)

    if model_state_dict is None:
        return 0
513

Vijay Korthikanti's avatar
Vijay Korthikanti committed
514
    # set checkpoint version
515
    set_checkpoint_version(model_state_dict.get('checkpoint_version', 0))
Vijay Korthikanti's avatar
Vijay Korthikanti committed
516

517
518
519
520
521
    # Set iteration.
    if args.finetune or release:
        iteration = 0
    else:
        try:
522
            iteration = model_state_dict['iteration']
523
        except KeyError:
Neel Kant's avatar
Neel Kant committed
524
            try:  # Backward compatible with older checkpoints
525
                iteration = model_state_dict['total_iters']
526
527
528
529
530
531
532
            except KeyError:
                print_rank_0('A metadata file exists but unable to load '
                             'iteration from checkpoint {}, exiting'.format(
                                 checkpoint_name))
                sys.exit()

    # Check arguments.
mohammad's avatar
mohammad committed
533
534
    assert args.consumed_train_samples == 0
    assert args.consumed_valid_samples == 0
535
    if 'args' in model_state_dict and not args.finetune:
536
        checkpoint_args = model_state_dict['args']
537
        check_checkpoint_args(checkpoint_args)
538
539
        args.consumed_train_samples = getattr(checkpoint_args,
                                              'consumed_train_samples', 0)
mohammad's avatar
mohammad committed
540
        update_num_microbatches(consumed_samples=args.consumed_train_samples)
541
542
        args.consumed_valid_samples = getattr(checkpoint_args,
                                              'consumed_valid_samples', 0)
543
544
545
546
    else:
        print_rank_0('could not find arguments in the checkpoint ...')

    # Model.
547
    if len(model) == 1:
548
        model[0].load_state_dict(model_state_dict['model'], strict=strict)
549
550
    else:
        for i in range(len(model)):
551
            mpu.set_virtual_pipeline_model_parallel_rank(i)
552
            model[i].load_state_dict(model_state_dict['model%d' % i], strict=strict)
553

Mostofa Patwary's avatar
Mostofa Patwary committed
554
555
556
557
    # Fix up query/key/value matrix ordering if needed
    checkpoint_version = get_checkpoint_version()
    print_rank_0(f' checkpoint version {checkpoint_version}')
    fix_query_key_value_ordering(model, checkpoint_version)
558
559
560
561
562

    # Optimizer.
    if not release and not args.finetune and not args.no_load_optim:
        try:
            if optimizer is not None:
563
                optimizer.load_state_dict(optim_state_dict['optimizer'])
564
            if opt_param_scheduler is not None:
565
                if 'lr_scheduler' in optim_state_dict: # backward compatbility
566
                    opt_param_scheduler.load_state_dict(optim_state_dict['lr_scheduler'])
567
                else:
568
                    opt_param_scheduler.load_state_dict(optim_state_dict['opt_param_scheduler'])
569
570
571
572
573
574
        except KeyError:
            print_rank_0('Unable to load optimizer from checkpoint {}. '
                         'Specify --no-load-optim or --finetune to prevent '
                         'attempting to load the optimizer state, '
                         'exiting ...'.format(checkpoint_name))
            sys.exit()
575
576
577
    else:
        if args.fp16 and optimizer is not None:
            optimizer.reload_model_params()
578
579
580
581

    # rng states.
    if not release and not args.finetune and not args.no_load_rng:
        try:
582
            if 'rng_state' in model_state_dict:
583
                # access rng_state for data parallel rank
584
                if args.data_parallel_random_init:
Vijay Korthikanti's avatar
Vijay Korthikanti committed
585

586
                    rng_state = model_state_dict['rng_state'][mpu.get_data_parallel_rank()]
587
                else:
588
                    rng_state = model_state_dict['rng_state'][0]
589
590
591
592
593
594
595
                random.setstate(rng_state['random_rng_state'])
                np.random.set_state(rng_state['np_rng_state'])
                torch.set_rng_state(rng_state['torch_rng_state'])
                torch.cuda.set_rng_state(rng_state['cuda_rng_state'])
                # Check for empty states array
                if not rng_state['rng_tracker_states']:
                    raise KeyError
596
                tensor_parallel.get_cuda_rng_tracker().set_states(
Vijay Korthikanti's avatar
Vijay Korthikanti committed
597
                    rng_state['rng_tracker_states'])
598
            else:  # backward compatability
599
600
601
602
                random.setstate(model_state_dict['random_rng_state'])
                np.random.set_state(model_state_dict['np_rng_state'])
                torch.set_rng_state(model_state_dict['torch_rng_state'])
                torch.cuda.set_rng_state(model_state_dict['cuda_rng_state'])
603
                # Check for empty states array
604
                if not model_state_dict['rng_tracker_states']:
605
                    raise KeyError
606
                tensor_parallel.get_cuda_rng_tracker().set_states(
607
                    model_state_dict['rng_tracker_states'])
608
        except KeyError:
609
            print_rank_0('Unable to load rng state from checkpoint {}. '
610
                         'Specify --no-load-rng or --finetune to prevent '
611
                         'attempting to load the rng state, '
612
613
614
                         'exiting ...'.format(checkpoint_name))
            sys.exit()

Jared Casper's avatar
Jared Casper committed
615
616
617
618
619
620
    # Some utilities want to load a checkpoint without distributed being initialized
    if torch.distributed.is_initialized():
        torch.distributed.barrier()

    print_rank_0(f'  successfully loaded checkpoint from {args.load} '
                 f'at iteration {iteration}')
621
622

    return iteration
Neel Kant's avatar
Neel Kant committed
623
624


625
626
627
def load_biencoder_checkpoint(model, only_query_model=False,
        only_context_model=False, custom_load_path=None):
    """
628
    selectively load retrieval models for indexing/retrieving
629
630
    from saved checkpoints
    """
Neel Kant's avatar
Neel Kant committed
631
632
633

    args = get_args()

634
    model = unwrap_model(model)
Neel Kant's avatar
Neel Kant committed
635

636
    load_path = custom_load_path if custom_load_path is not None else args.load
Neel Kant's avatar
Neel Kant committed
637
638
639
640
641

    tracker_filename = get_checkpoint_tracker_filename(load_path)
    with open(tracker_filename, 'r') as f:
        iteration = int(f.read().strip())

642
643
    checkpoint_name, _ = get_checkpoint_names(load_path, iteration,
                                              args.use_distributed_optimizer,
644
645
                                              release=False)

646
    if mpu.get_data_parallel_rank() == 0:
Neel Kant's avatar
Neel Kant committed
647
648
649
        print('global rank {} is loading checkpoint {}'.format(
            torch.distributed.get_rank(), checkpoint_name))

650
    state_dict = torch.load(model_checkpoint_name, map_location='cpu')
651
    ret_state_dict = state_dict['model']
Neel Kant's avatar
Neel Kant committed
652
653

    if only_query_model:
654
        ret_state_dict.pop('context_model')
Mostofa Patwary's avatar
Mostofa Patwary committed
655
    if only_context_model:
656
        ret_state_dict.pop('query_model')
Neel Kant's avatar
Neel Kant committed
657

658
659
    assert len(model) == 1
    model[0].load_state_dict(ret_state_dict)
Neel Kant's avatar
Neel Kant committed
660
661
    torch.distributed.barrier()

662
    if mpu.get_data_parallel_rank() == 0:
Neel Kant's avatar
Neel Kant committed
663
664
        print(' successfully loaded {}'.format(checkpoint_name))

Neel Kant's avatar
Neel Kant committed
665
    return model