# coding=utf-8
# Copyright (c) 2020, NVIDIA CORPORATION.  All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Input/output checkpointing."""

import os
import random
import sys
import numpy as np

import torch

25
26
27
from megatron import (get_args,
                      mpu,
                      print_rank_0,
28
29
                      update_num_microbatches,
                      utils)
30

# Version of the checkpoint being loaded. Set by load_checkpoint() from the
# checkpoint's 'checkpoint_version' entry and read back through
# get_checkpoint_version(); None until a checkpoint has been loaded.
_CHECKPOINT_VERSION = None


def set_checkpoint_version(value):
    """Record the version of the checkpoint being loaded.

    May be called repeatedly, but every call must pass the value that was
    recorded first.

    Raises:
        AssertionError: if a different version was already recorded.
    """
    global _CHECKPOINT_VERSION
    if _CHECKPOINT_VERSION is not None:
        assert _CHECKPOINT_VERSION == value, \
            "checkpoint versions do not match"
    _CHECKPOINT_VERSION = value


def get_checkpoint_version():
    """Return the recorded checkpoint version, or None if none was set."""
    # Read-only access to the module global needs no `global` statement.
    return _CHECKPOINT_VERSION

def check_checkpoint_args(checkpoint_args):
    """Ensure fixed arguments for a model are the same for the input
    arguments and the one retrieved from checkpoint.

    Args:
        checkpoint_args: the argument namespace stored in the checkpoint
            under 'args'.

    Raises:
        AssertionError: if any compared argument differs between the
            checkpoint and the current run.
    """
    args = get_args()

    def _compare(arg_name, old_arg_name=None):
        # `old_arg_name` lets the checkpoint side be looked up under a
        # legacy attribute name while the current args use `arg_name`.
        checkpoint_value = getattr(
            checkpoint_args,
            old_arg_name if old_arg_name is not None else arg_name)
        args_value = getattr(args, arg_name)
        error_message = '{} value from checkpoint ({}) is not equal to the ' \
                        'input argument value ({}).'.format(
                            arg_name, checkpoint_value, args_value)
        assert checkpoint_value == args_value, error_message

    _compare('num_layers')
    _compare('hidden_size')
    _compare('num_attention_heads')
    if args.vocab_file:
        _compare('max_position_embeddings')
        _compare('make_vocab_size_divisible_by')
        _compare('padded_vocab_size')
        _compare('tokenizer_type')
    if args.data_parallel_random_init:
        _compare('data_parallel_random_init')
    # Checkpoints older than version 3.0 store the tensor parallel size
    # under the legacy name 'model_parallel_size'.
    if get_checkpoint_version() < 3.0:
        _compare('tensor_model_parallel_size',
                 old_arg_name='model_parallel_size')
    else:
        _compare('tensor_model_parallel_size')
        _compare('pipeline_model_parallel_size')

def ensure_directory_exists(filename):
    """Build filename's path if it does not already exist.

    Safe to call when another process creates the same directory
    concurrently, and a no-op for a bare filename with no directory part.
    """
    dirname = os.path.dirname(filename)
    # An empty dirname means the current directory, which already exists;
    # os.makedirs('') would raise FileNotFoundError.
    if dirname:
        # exist_ok=True closes the race between checking for the directory
        # and creating it, where another process may create it in between.
        os.makedirs(dirname, exist_ok=True)


def get_checkpoint_name(checkpoints_path, iteration, release=False,
                        pipeline_parallel_size=None, tensor_rank=None,
                        pipeline_rank=None):
    """A unified checkpoint name.

    Layout:
        <checkpoints_path>/<'release' | 'iter_XXXXXXX'>/<rank dir>/model_optim_rng.pt
    where <rank dir> is 'mp_rank_TT' when the pipeline parallel size is 1
    and 'mp_rank_TT_PPP' otherwise.

    Args:
        checkpoints_path: root directory of the checkpoints.
        iteration: training iteration (unused when `release` is True).
        release: use the 'release' directory instead of 'iter_XXXXXXX'.
        pipeline_parallel_size, tensor_rank, pipeline_rank: override the
            values that are otherwise queried from mpu for this process.
    """
    directory = 'release' if release else 'iter_{:07d}'.format(iteration)

    # Use both the tensor and pipeline MP rank, defaulting to this process's.
    if pipeline_parallel_size is None:
        pipeline_parallel_size = mpu.get_pipeline_model_parallel_world_size()
    if tensor_rank is None:
        tensor_rank = mpu.get_tensor_model_parallel_rank()
    if pipeline_rank is None:
        pipeline_rank = mpu.get_pipeline_model_parallel_rank()

    rank_dir = f'mp_rank_{tensor_rank:02d}'
    if pipeline_parallel_size != 1:
        rank_dir += f'_{pipeline_rank:03d}'
    return os.path.join(checkpoints_path, directory, rank_dir,
                        'model_optim_rng.pt')

def find_checkpoint_rank_0(checkpoints_path, iteration, release=False):
    """Finds the checkpoint for rank 0 without knowing if we are using
    pipeline parallelism or not.

    Since the checkpoint naming scheme changes if pipeline parallelism
    is present, we need to look for both naming schemes if we don't
    know if the checkpoint has pipeline parallelism.

    Returns:
        The path of the first candidate file that exists, or None if
        neither naming scheme matches a file on disk.
    """
    # Probe the non-pipelined layout first (size 1 -> 'mp_rank_00'), then the
    # pipelined one (any size other than 1 -> 'mp_rank_00_000').
    for pipeline_parallel_size in (1, 2):
        filename = get_checkpoint_name(
            checkpoints_path, iteration, release,
            pipeline_parallel_size=pipeline_parallel_size,
            tensor_rank=0, pipeline_rank=0)
        if os.path.isfile(filename):
            return filename

    return None

def get_checkpoint_tracker_filename(checkpoints_path):
    """Return the path of the tracker file inside `checkpoints_path`.

    The tracker file records the latest checkpoint written during training,
    so that training can restart from it.
    """
    return os.path.join(checkpoints_path, 'latest_checkpointed_iteration.txt')


def read_metadata(tracker_filename):
    """Parse the tracker file and agree on an iteration across ranks.

    The tracker file holds either an integer iteration or the literal
    string 'release'.

    Returns:
        (max_iter, release): the maximum iteration over all ranks (or the
        local value when torch.distributed is not initialized) and whether
        the tracker marks a release checkpoint.

    Exits the process if the file content is neither an integer nor
    'release'.
    """
    # Read the tracker file and either set the iteration or
    # mark it as a release checkpoint.
    iteration = 0
    release = False
    with open(tracker_filename, 'r') as f:
        metastring = f.read().strip()
        try:
            iteration = int(metastring)
        except ValueError:
            release = metastring == 'release'
            if not release:
                print_rank_0('ERROR: Invalid metadata file {}. Exiting'.format(
                    tracker_filename))
                sys.exit()
    assert iteration > 0 or release, 'error parsing metadata file {}'.format(
        tracker_filename)

    if not torch.distributed.is_initialized():
        # When loading a checkpoint outside of training (for example,
        # when editing it), we might not have torch distributed
        # initialized, in this case, just assume we have the latest.
        return iteration, release

    # Get the max iteration retrieved across the ranks.
    iters_cuda = torch.cuda.LongTensor([iteration])
    torch.distributed.all_reduce(iters_cuda, op=torch.distributed.ReduceOp.MAX)
    max_iter = iters_cuda[0].item()

    # We should now have all the same iteration. If not, print a warning
    # and choose the maximum iteration across all ranks.
    if iteration != max_iter:
        # The rank must be queried explicitly; there is no `rank` in scope.
        rank = torch.distributed.get_rank()
        print('WARNING: on rank {} found iteration {} in the '
              'metadata while max iteration across the ranks '
              'is {}, replacing it with max iteration.'.format(
                  rank, iteration, max_iter), flush=True)
    return max_iter, release


def get_rng_state():
    """Collect this process's RNG states, gathered across data parallel ranks
    when required.

    Returns:
        A list of RNG-state dicts with keys 'random_rng_state',
        'np_rng_state', 'torch_rng_state', 'cuda_rng_state' and
        'rng_tracker_states'. The list has one entry per data-parallel rank
        when torch.distributed is initialized, the data-parallel world size
        is > 1 and args.data_parallel_random_init is set. Otherwise it holds
        only this process's state.
    """
    args = get_args()
    rng_state = {
        'random_rng_state': random.getstate(),
        'np_rng_state': np.random.get_state(),
        'torch_rng_state': torch.get_rng_state(),
        'cuda_rng_state': torch.cuda.get_rng_state(),
        'rng_tracker_states': mpu.get_cuda_rng_tracker().get_states()}

    gather_across_data_parallel = (
        torch.distributed.is_initialized()
        and mpu.get_data_parallel_world_size() > 1
        and args.data_parallel_random_init)
    if not gather_across_data_parallel:
        return [rng_state]

    # all_gather_object is a collective: every rank of the data parallel
    # group has to reach this call.
    rng_state_list = [None] * mpu.get_data_parallel_world_size()
    torch.distributed.all_gather_object(
        rng_state_list,
        rng_state,
        group=mpu.get_data_parallel_group())
    return rng_state_list


def save_checkpoint(iteration, model, optimizer, opt_param_scheduler):
    """Save a model checkpoint.

    Each data parallel rank 0 (or the only process, when torch.distributed
    is not initialized) writes a state dict to
    get_checkpoint_name(args.save, iteration). Global rank 0 then records
    `iteration` in the tracker file.

    Args:
        iteration: the training iteration being saved.
        model: the model. After utils.unwrap_model it is indexed as a list
            of model chunks.
        optimizer: optimizer whose state_dict() is saved, or None.
        opt_param_scheduler: scheduler whose state_dict() is saved, or None.
    """
    args = get_args()

    model = utils.unwrap_model(model)

    print_rank_0('saving checkpoint at iteration {:7d} to {}'.format(
        iteration, args.save))

    # collect rng state across data parallel ranks. This may run a
    # collective (all_gather_object), so every rank calls it.
    rng_state = get_rng_state()

    # Only rank zero of the data parallel writes to the disk.
    if not torch.distributed.is_initialized() or mpu.get_data_parallel_rank() == 0:

        # Arguments, iteration, and model.
        state_dict = {}
        state_dict['args'] = args
        state_dict['checkpoint_version'] = 3.0
        state_dict['iteration'] = iteration
        if len(model) == 1:
            state_dict['model'] = model[0].state_dict_for_save_checkpoint()
        else:
            for i, model_chunk in enumerate(model):
                mpu.set_virtual_pipeline_model_parallel_rank(i)
                state_dict['model%d' % i] = model_chunk.state_dict_for_save_checkpoint()

        # Optimizer stuff.
        if not args.no_save_optim:
            if optimizer is not None:
                state_dict['optimizer'] = optimizer.state_dict()
            if opt_param_scheduler is not None:
                state_dict['opt_param_scheduler'] = opt_param_scheduler.state_dict()

        # RNG states.
        if not args.no_save_rng:
            state_dict["rng_state"] = rng_state

        # Save.
        checkpoint_name = get_checkpoint_name(args.save, iteration)
        ensure_directory_exists(checkpoint_name)
        torch.save(state_dict, checkpoint_name)

    # Wait so everyone is done (necessary)
    if torch.distributed.is_initialized():
        torch.distributed.barrier()

    print_rank_0('  successfully saved checkpoint at iteration {:7d} to {}'.format(
        iteration, args.save))

    # And update the latest iteration
    if not torch.distributed.is_initialized() or torch.distributed.get_rank() == 0:
        tracker_filename = get_checkpoint_tracker_filename(args.save)
        with open(tracker_filename, 'w') as f:
            f.write(str(iteration))

    # Wait so everyone is done (not necessary)
    if torch.distributed.is_initialized():
        torch.distributed.barrier()
264

265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
def _transpose_first_dim(t, num_splits, num_splits_first, model):
    input_shape = t.size()
    # We use a self_attention module but the values extracted aren't
    # specific to self attention so should work for cross attention as well
    while hasattr(model, 'module'):
        model = model.module
    attention_module = model.language_model.encoder.layers[0].self_attention
    hidden_size_per_attention_head = attention_module.hidden_size_per_attention_head
    num_attention_heads_per_partition = attention_module.num_attention_heads_per_partition
    if num_splits_first:
        """[num_splits * np * hn, h]
        -->(view) [num_splits, np, hn, h]
        -->(tranpose) [np, num_splits, hn, h]
        -->(view) [np * num_splits * hn, h] """

        intermediate_shape = \
            (num_splits, num_attention_heads_per_partition,
             hidden_size_per_attention_head) + input_shape[1:]

        t = t.view(*intermediate_shape)
        t = t.transpose(0, 1).contiguous()
    else:
        """[np * hn * num_splits, h]
        -->(view) [np, hn, num_splits, h]
        -->(tranpose) [np, num_splits, hn, h]
        -->(view) [np * num_splits * hn, h] """

        intermediate_shape = \
            (num_attention_heads_per_partition,
             hidden_size_per_attention_head, num_splits) +\
             input_shape[1:]

        t = t.view(*intermediate_shape)
        t = t.transpose(1, 2).contiguous()
    t = t.view(*input_shape)

    return t
302

def fix_query_key_value_ordering(model, checkpoint_version):
    """Fix up query/key/value matrix ordering if checkpoint
    version is smaller than 2.0.

    Parameters whose names end in '.query_key_value.weight/.bias' (3 fused
    splits) or '.key_value.weight/.bias' (2 fused splits) are reordered in
    place. Version 0 checkpoints have the splits first and version 1.0
    checkpoints have them last. If such a parameter is found with any other
    version below 2.0, the process exits.

    Args:
        model: the model, or a single-element list holding it.
        checkpoint_version: the version read from the checkpoint.
    """
    if checkpoint_version >= 2.0:
        return

    if isinstance(model, list):
        assert len(model) == 1
        model = model[0]

    for name, param in model.named_parameters():
        if name.endswith(('.query_key_value.weight', '.query_key_value.bias')):
            num_splits = 3
        elif name.endswith(('.key_value.weight', '.key_value.bias')):
            num_splits = 2
        else:
            continue

        if checkpoint_version == 0:
            fixed_param = _transpose_first_dim(param.data, num_splits, True, model)
        elif checkpoint_version == 1.0:
            fixed_param = _transpose_first_dim(param.data, num_splits, False, model)
        else:
            print_rank_0(f"Invalid checkpoint version {checkpoint_version}.")
            sys.exit()
        param.data.copy_(fixed_param)
    print_rank_0(" succesfully fixed query-key-values ordering for"
                 " checkpoint version {}".format(checkpoint_version))

def _load_base_checkpoint(load_dir, rank0=False):
    """ Load the base state_dict from the given directory

    If rank0 is true, just loads rank 0 checkpoint, ignoring arguments.

    Returns:
        (state_dict, release). Returns (None, False) if `load_dir` has no
        tracker file. Exits the process if the checkpoint file cannot be
        loaded.
    """

    # Read the tracker file and set the iteration.
    tracker_filename = get_checkpoint_tracker_filename(load_dir)

    # If no tracker file, return nothing
    if not os.path.isfile(tracker_filename):
        if not rank0:
            print_rank_0('WARNING: could not find the metadata file {} '.format(
                tracker_filename))
            print_rank_0('    will not load any checkpoints and will start from '
                         'random')
        return None, False

    # Otherwise, read the tracker file and either set the iteration or
    # mark it as a release checkpoint.
    iteration, release = read_metadata(tracker_filename)

    # Checkpoint.
    if rank0:
        checkpoint_name = find_checkpoint_rank_0(load_dir, iteration, release)
    else:
        checkpoint_name = get_checkpoint_name(load_dir, iteration, release)
        if release:
            print_rank_0(f' loading release checkpoint from {load_dir}')
        else:
            print_rank_0(f' loading checkpoint from {load_dir} at iteration {iteration}')

    # Load the checkpoint.
    try:
        state_dict = torch.load(checkpoint_name, map_location='cpu')
    except ModuleNotFoundError:
        # Importing registers 'megatron.fp16_deprecated.loss_scaler' in
        # sys.modules so it can be aliased below.
        from megatron.fp16_deprecated import loss_scaler
        # For backward compatibility.
        if not rank0:
            print_rank_0(' > deserializing using the old code structure ...')
        sys.modules['fp16.loss_scaler'] = sys.modules[
            'megatron.fp16_deprecated.loss_scaler']
        sys.modules['megatron.fp16.loss_scaler'] = sys.modules[
            'megatron.fp16_deprecated.loss_scaler']
        try:
            state_dict = torch.load(checkpoint_name, map_location='cpu')
        finally:
            # Remove the temporary aliases even if this fallback load fails.
            sys.modules.pop('fp16.loss_scaler', None)
            sys.modules.pop('megatron.fp16.loss_scaler', None)
    except Exception as e:
        # Exception rather than BaseException, so that KeyboardInterrupt and
        # SystemExit are not swallowed.
        print_rank_0('could not load the checkpoint')
        print_rank_0(e)
        sys.exit()

    return state_dict, release

def load_args_from_checkpoint(args, load_arg='load'):
    """Set any arguments that are not currently set from the checkpoint
    specified in the arguments.

    Returns the same args NameSpace with the new values added/updated.

    The arguments are returned unmodified in three cases: `getattr(args,
    load_arg)` is None, the checkpoint directory has no tracker file, or the
    loaded state dict has no 'args' entry.

    Args:
        args: argument namespace to update in place.
        load_arg: name of the attribute of `args` holding the checkpoint
            directory.
    """
    load_dir = getattr(args, load_arg)

    if load_dir is None:
        return args

    state_dict, release = _load_base_checkpoint(load_dir, True)

    if not state_dict:
        return args

    if 'args' not in state_dict:
        return args

    checkpoint_args = state_dict['args']
    checkpoint_version = state_dict.get('checkpoint_version', 0)
    args.iteration = state_dict['iteration']

    def _set_arg(arg_name, old_arg_name=None, force=False):
        # Keep values already set in `args` unless `force` is given.
        if not force and getattr(args, arg_name, None) is not None:
            return

        # `old_arg_name` reads the checkpoint under a legacy attribute name.
        checkpoint_value = getattr(
            checkpoint_args,
            old_arg_name if old_arg_name is not None else arg_name,
            None)

        if checkpoint_value is not None:
            print(f"Setting {arg_name} to {checkpoint_value} from checkpoint")
            setattr(args, arg_name, checkpoint_value)

    _set_arg('num_layers')
    _set_arg('hidden_size')
    _set_arg('ffn_hidden_size')
    _set_arg('seq_length')
    _set_arg('num_attention_heads')
    _set_arg('kv_channels')
    _set_arg('max_position_embeddings')
    _set_arg('tokenizer_type')
    _set_arg('padded_vocab_size')
    if checkpoint_version < 3.0:
        # Pre-3.0 checkpoints store the tensor parallel size as
        # 'model_parallel_size'.
        _set_arg('tensor_model_parallel_size',
                 'model_parallel_size')
    else:
        _set_arg('tensor_model_parallel_size', force=True)
        _set_arg('pipeline_model_parallel_size', force=True)
        _set_arg('num_layers_per_virtual_pipeline_stage')
    return args

def load_checkpoint(model, optimizer, opt_param_scheduler, load_arg='load', strict=True):
    """Load a model checkpoint and return the iteration.

    Args:
        model: the model. After utils.unwrap_model it is indexed as a list
            of model chunks.
        optimizer: optimizer to restore, or None.
        opt_param_scheduler: scheduler to restore, or None.
        load_arg: name of the global-args attribute holding the checkpoint
            directory.
        strict (bool): whether to strictly enforce that the keys in
            :attr:`state_dict` of the checkpoint match the names of
            parameters and buffers in model.

    Returns:
        The iteration to resume from. It is 0 when no tracker file exists,
        when finetuning, or for a release checkpoint.
    """
    args = get_args()
    load_dir = getattr(args, load_arg)

    model = utils.unwrap_model(model)

    state_dict, release = _load_base_checkpoint(load_dir, False)

    # No tracker file: _load_base_checkpoint has already warned that the
    # run starts from random weights, so there is nothing to restore.
    if state_dict is None:
        return 0

    # set checkpoint version
    set_checkpoint_version(state_dict.get('checkpoint_version', 0))

    # Set iteration.
    if args.finetune or release:
        iteration = 0
    else:
        try:
            iteration = state_dict['iteration']
        except KeyError:
            try:  # Backward compatible with older checkpoints
                iteration = state_dict['total_iters']
            except KeyError:
                print_rank_0('A metadata file exists but unable to load '
                             'iteration from checkpoint {}, exiting'.format(
                                 load_dir))
                sys.exit()

    # Check arguments.
    assert args.consumed_train_samples == 0
    assert args.consumed_valid_samples == 0
    if 'args' in state_dict:
        checkpoint_args = state_dict['args']
        check_checkpoint_args(checkpoint_args)
        args.consumed_train_samples = getattr(checkpoint_args,
                                              'consumed_train_samples', 0)
        update_num_microbatches(consumed_samples=args.consumed_train_samples)
        args.consumed_valid_samples = getattr(checkpoint_args,
                                              'consumed_valid_samples', 0)
    else:
        print_rank_0('could not find arguments in the checkpoint ...')

    # Model.
    if len(model) == 1:
        model[0].load_state_dict(state_dict['model'], strict=strict)
    else:
        for i, model_chunk in enumerate(model):
            mpu.set_virtual_pipeline_model_parallel_rank(i)
            model_chunk.load_state_dict(state_dict['model%d' % i], strict=strict)

    # Fix up query/key/value matrix ordering if needed
    checkpoint_version = get_checkpoint_version()
    print_rank_0(f' checkpoint version {checkpoint_version}')
    fix_query_key_value_ordering(model, checkpoint_version)

    # Optimizer.
    if not release and not args.finetune and not args.no_load_optim:
        try:
            if optimizer is not None:
                optimizer.load_state_dict(state_dict['optimizer'])
            if opt_param_scheduler is not None:
                if 'lr_scheduler' in state_dict:  # backward compatibility
                    opt_param_scheduler.load_state_dict(state_dict['lr_scheduler'])
                else:
                    opt_param_scheduler.load_state_dict(state_dict['opt_param_scheduler'])
        except KeyError:
            print_rank_0('Unable to load optimizer from checkpoint {}. '
                         'Specify --no-load-optim or --finetune to prevent '
                         'attempting to load the optimizer state, '
                         'exiting ...'.format(load_dir))
            sys.exit()

    # rng states.
    if not release and not args.finetune and not args.no_load_rng:
        try:
            if 'rng_state' in state_dict:
                # access rng_state for data parallel rank
                if args.data_parallel_random_init:
                    rng_state = state_dict['rng_state'][mpu.get_data_parallel_rank()]
                else:
                    rng_state = state_dict['rng_state'][0]
            else:
                # backward compatibility: older checkpoints keep the same
                # RNG keys at the top level of the state dict.
                rng_state = state_dict
            random.setstate(rng_state['random_rng_state'])
            np.random.set_state(rng_state['np_rng_state'])
            torch.set_rng_state(rng_state['torch_rng_state'])
            torch.cuda.set_rng_state(rng_state['cuda_rng_state'])
            # Check for empty states array
            if not rng_state['rng_tracker_states']:
                raise KeyError
            mpu.get_cuda_rng_tracker().set_states(
                rng_state['rng_tracker_states'])
        except KeyError:
            print_rank_0('Unable to load rng state from checkpoint {}. '
                         'Specify --no-load-rng or --finetune to prevent '
                         'attempting to load the rng state, '
                         'exiting ...'.format(load_dir))
            sys.exit()

    # Some utilities want to load a checkpoint without distributed being initialized
    if torch.distributed.is_initialized():
        torch.distributed.barrier()

    print_rank_0(f'  successfully loaded checkpoint from {load_dir} '
                 f'at iteration {iteration}')

    return iteration
def load_biencoder_checkpoint(model, only_query_model=False,
                              only_context_model=False, custom_load_path=None):
    """Selectively load retrieval models for indexing/retrieving from saved
    checkpoints.

    Args:
        model: the model. After utils.unwrap_model it must be a
            single-element list.
        only_query_model: drop the 'context_model' entry before loading.
        only_context_model: drop the 'query_model' entry before loading.
        custom_load_path: checkpoint directory; defaults to args.load.

    Returns:
        The unwrapped model with the checkpoint weights loaded.
    """
    args = get_args()

    model = utils.unwrap_model(model)

    load_path = custom_load_path if custom_load_path is not None else args.load

    # NOTE(review): the tracker is parsed as a plain integer here, so a
    # 'release' tracker raises ValueError; read_metadata() accepts both.
    tracker_filename = get_checkpoint_tracker_filename(load_path)
    with open(tracker_filename, 'r') as f:
        iteration = int(f.read().strip())

    checkpoint_name = get_checkpoint_name(load_path, iteration, False)
    if mpu.get_data_parallel_rank() == 0:
        print('global rank {} is loading checkpoint {}'.format(
            torch.distributed.get_rank(), checkpoint_name))

    state_dict = torch.load(checkpoint_name, map_location='cpu')
    ret_state_dict = state_dict['model']

    if only_query_model:
        ret_state_dict.pop('context_model')
    if only_context_model:
        ret_state_dict.pop('query_model')

    assert len(model) == 1
    model[0].load_state_dict(ret_state_dict)
    torch.distributed.barrier()

    if mpu.get_data_parallel_rank() == 0:
        print(' successfully loaded {}'.format(checkpoint_name))

    return model