# coding=utf-8
# Copyright (c) 2020, NVIDIA CORPORATION.  All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Input/output checkpointing."""

import os
import random
import sys
import numpy as np

import torch

25
26
27
28
29
30
from megatron import (mpu,
                      update_num_microbatches)
from .global_vars import get_args
from .utils import (unwrap_model,
                    print_rank_0)

31

Vijay Korthikanti's avatar
Vijay Korthikanti committed
32
33
34
35
# Version of the checkpoint recorded for this process; stays None until
# set_checkpoint_version() is called.
_CHECKPOINT_VERSION = None


def set_checkpoint_version(value):
    """Record the checkpoint version for this process.

    The first call stores ``value``. Every later call must pass a value
    equal to the one already stored.

    Raises:
        AssertionError: if a different version was recorded earlier.
    """
    global _CHECKPOINT_VERSION
    if _CHECKPOINT_VERSION is not None:
        assert _CHECKPOINT_VERSION == value, \
            "checkpoint versions do not match"
    _CHECKPOINT_VERSION = value


def get_checkpoint_version():
    """Return the recorded checkpoint version, or None if none was set."""
    # Reading a module-level name does not need a ``global`` declaration.
    return _CHECKPOINT_VERSION
44
45
46

def check_checkpoint_args(checkpoint_args):
    """Ensure fixed arguments for a model are the same for the input
    arguments and the one retrieved from checkpoint.

    Args:
        checkpoint_args: argument object stored in the checkpoint; its
            attributes are compared against those of ``get_args()``.

    Raises:
        AssertionError: if any compared attribute differs.
    """
    args = get_args()

    def _compare(arg_name, old_arg_name=None):
        # A checkpoint may store the value under a different (older)
        # attribute name than the one used by the current arguments.
        source_name = arg_name if old_arg_name is None else old_arg_name
        checkpoint_value = getattr(checkpoint_args, source_name)
        args_value = getattr(args, arg_name)
        error_message = '{} value from checkpoint ({}) is not equal to the ' \
                        'input argument value ({}).'.format(
                            arg_name, checkpoint_value, args_value)
        assert checkpoint_value == args_value, error_message

    _compare('num_layers')
    _compare('hidden_size')
    _compare('num_attention_heads')
    if args.vocab_file:
        _compare('max_position_embeddings')
        _compare('make_vocab_size_divisible_by')
        _compare('padded_vocab_size')
        _compare('tokenizer_type')
    if args.data_parallel_random_init:
        _compare('data_parallel_random_init')
    # Checkpoints older than version 3.0 store the tensor parallel size
    # under 'model_parallel_size'; only version >= 3.0 checkpoints are
    # also compared on the pipeline parallel size.
    if get_checkpoint_version() < 3.0:
        _compare('tensor_model_parallel_size',
                 old_arg_name='model_parallel_size')
    else:
        _compare('tensor_model_parallel_size')
        _compare('pipeline_model_parallel_size')
77
78
79
80
81
82
83
84

def ensure_directory_exists(filename):
    """Create the parent directory of ``filename`` if it does not exist.

    Args:
        filename: path of a file that is about to be written. Only its
            directory part is created; the file itself is not touched.
    """
    dirname = os.path.dirname(filename)
    # An empty dirname means the file lives in the current directory;
    # os.makedirs('') would raise, so there is nothing to create.
    if dirname:
        # exist_ok avoids the check-then-create race when several
        # processes create the same checkpoint directory concurrently.
        os.makedirs(dirname, exist_ok=True)


85
def get_checkpoint_names(checkpoints_path, iteration, use_distributed_optimizer, release=False,
                         pipeline_parallel=None, tensor_rank=None, pipeline_rank=None):
    """Determine the model and optimizer checkpoint file names for a rank.

    Args:
        checkpoints_path: root directory holding all checkpoints.
        iteration: iteration number, used for the 'iter_XXXXXXX'
            directory (ignored when ``release`` is true).
        use_distributed_optimizer: when true, model and optimizer are
            stored in separate files and the optimizer path additionally
            includes the data parallel rank.
        release: use the 'release' directory instead of an iteration one.
        pipeline_parallel: whether the path includes the pipeline rank;
            when None it is taken from the pipeline world size in mpu.
        tensor_rank: tensor parallel rank; when None it is read from mpu.
        pipeline_rank: pipeline parallel rank; when None it is read
            from mpu.

    Returns:
        Tuple ``(model_name, optim_name)``. Both entries are the same
        path when the distributed optimizer is not used.
    """
    directory = 'release' if release else 'iter_{:07d}'.format(iteration)

    # Fill in whatever the caller did not specify from the mpu state.
    if pipeline_parallel is None:
        pipeline_parallel = (mpu.get_pipeline_model_parallel_world_size() > 1)
    if tensor_rank is None:
        tensor_rank = mpu.get_tensor_model_parallel_rank()
    if pipeline_rank is None:
        pipeline_rank = mpu.get_pipeline_model_parallel_rank()

    # Use both the tensor and pipeline MP rank. If using the distributed
    # optimizer, then the optimizer's path must additionally include the
    # data parallel rank.
    if not pipeline_parallel:
        common_path = os.path.join(checkpoints_path, directory,
                                   f'mp_rank_{tensor_rank:02d}')
    else:
        common_path = os.path.join(checkpoints_path, directory,
                                   f'mp_rank_{tensor_rank:02d}_{pipeline_rank:03d}')

    if use_distributed_optimizer:
        model_name = os.path.join(common_path, "model_rng.pt")
        optim_name = os.path.join(
            common_path + "_%03d" % mpu.get_data_parallel_rank(),
            "optim.pt")
    else:
        model_name = optim_name = os.path.join(common_path, "model_optim_rng.pt")
    return model_name, optim_name
119

Jared Casper's avatar
Jared Casper committed
120
def find_checkpoint_rank_0(checkpoints_path, iteration, use_distributed_optimizer, release=False):
    """Finds the checkpoint for rank 0 without knowing if we are using
    pipeline parallelism or not.

    Since the checkpoint naming scheme changes if pipeline parallelism
    is present, we need to look for both naming schemes if we don't
    know if the checkpoint has pipeline parallelism.

    Returns:
        The ``(model_name, optim_name)`` pair of the first naming scheme
        whose model file exists on disk, or ``(None, None)`` if neither
        exists.
    """
    # Try the name without pipelining first, then the one with it.
    for pipeline_parallel in (False, True):
        filenames = get_checkpoint_names(checkpoints_path, iteration,
                                         use_distributed_optimizer, release,
                                         pipeline_parallel=pipeline_parallel,
                                         tensor_rank=0, pipeline_rank=0)
        # Only the model file is checked for existence.
        if os.path.isfile(filenames[0]):
            return filenames

    return None, None
145
146

def get_checkpoint_tracker_filename(checkpoints_path):
    """Return the path of the tracker file inside ``checkpoints_path``.

    The tracker file records the latest checkpoint during training to
    restart from.
    """
    return os.path.join(checkpoints_path, 'latest_checkpointed_iteration.txt')


153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
def read_metadata(tracker_filename):
    """Read the tracker file and return ``(iteration, release)``.

    The file holds either an integer iteration or the word 'release'.
    Any other content prints an error and exits the process.

    When torch.distributed is initialized, the iteration is max-reduced
    across ranks and the maximum is returned; a rank whose local value
    differs prints a warning.

    Returns:
        Tuple ``(max_iter, release)``; the iteration is 0 for a release
        checkpoint.
    """
    iteration = 0
    release = False
    with open(tracker_filename, 'r') as f:
        metastring = f.read().strip()
    try:
        iteration = int(metastring)
    except ValueError:
        release = metastring == 'release'
        if not release:
            print_rank_0('ERROR: Invalid metadata file {}. Exiting'.format(
                tracker_filename))
            sys.exit()
    assert iteration > 0 or release, 'error parsing metadata file {}'.format(
        tracker_filename)

    # Get the max iteration retrieved across the ranks.
    if torch.distributed.is_initialized():
        iters_cuda = torch.cuda.LongTensor([iteration])
        torch.distributed.all_reduce(iters_cuda, op=torch.distributed.ReduceOp.MAX)
        max_iter = iters_cuda[0].item()

        # We should now have all the same iteration.
        # If not, print a warning and choose the maximum
        # iteration across all ranks.
        if iteration != max_iter:
            # The rank is looked up here; the original formatted an
            # undefined name and raised NameError on this path.
            print('WARNING: on rank {} found iteration {} in the '
                  'metadata while max iteration across the ranks '
                  'is {}, replacing it with max iteration.'.format(
                      torch.distributed.get_rank(), iteration, max_iter),
                  flush=True)
    else:
        # When loading a checkpoint outside of training (for example,
        # when editing it), we might not have torch distributed
        # initialized, in this case, just assume we have the latest
        max_iter = iteration
    return max_iter, release


193
194
def get_rng_state():
    """Collect rng state across data parallel ranks.

    Returns:
        A list of rng-state dicts. When torch.distributed is initialized,
        the data parallel world size is greater than 1 and
        ``args.data_parallel_random_init`` is set, the list is gathered
        over the data parallel group with one entry per rank; otherwise
        it holds only this process's state.
    """
    args = get_args()

    rng_state = {
        'random_rng_state': random.getstate(),
        'np_rng_state': np.random.get_state(),
        'torch_rng_state': torch.get_rng_state(),
        'cuda_rng_state': torch.cuda.get_rng_state(),
        'rng_tracker_states': mpu.get_cuda_rng_tracker().get_states()}

    if torch.distributed.is_initialized() and \
            mpu.get_data_parallel_world_size() > 1 and \
            args.data_parallel_random_init:
        # One slot per data parallel rank, filled by the gather.
        rng_state_list = [None] * mpu.get_data_parallel_world_size()
        torch.distributed.all_gather_object(
            rng_state_list,
            rng_state,
            group=mpu.get_data_parallel_group())
    else:
        rng_state_list = [rng_state]

    return rng_state_list


219
def save_checkpoint(iteration, model, optimizer, opt_param_scheduler):
    """Save a model checkpoint.

    Args:
        iteration: iteration number stored in the checkpoint and written
            to the tracker file.
        model: model (or model chunks); passed through ``unwrap_model``
            and then indexed like a list.
        optimizer: optimizer whose ``state_dict()`` is saved, or None.
        opt_param_scheduler: scheduler whose ``state_dict()`` is saved,
            or None.
    """
    args = get_args()

    # Only rank zero of the data parallel writes to the disk.
    model = unwrap_model(model)

    print_rank_0('saving checkpoint at iteration {:7d} to {}'.format(
        iteration, args.save))

    # Collect rng state across data parallel ranks.
    rng_state = get_rng_state()

    # Checkpoint file names.
    model_checkpoint_name, optim_checkpoint_name = \
        get_checkpoint_names(args.save, iteration, args.use_distributed_optimizer)

    # True when this process writes the model state: either there is no
    # distributed setup or this is data parallel rank 0.
    is_data_parallel_writer = (not torch.distributed.is_initialized()
                               or mpu.get_data_parallel_rank() == 0)

    # Collect args, model, RNG.
    model_state_dict = {}
    if is_data_parallel_writer:

        # Arguments, iteration, and model.
        model_state_dict['args'] = args
        model_state_dict['checkpoint_version'] = 3.0
        model_state_dict['iteration'] = iteration
        if len(model) == 1:
            model_state_dict['model'] = model[0].state_dict_for_save_checkpoint()
        else:
            for i, model_chunk in enumerate(model):
                mpu.set_virtual_pipeline_model_parallel_rank(i)
                model_state_dict['model%d' % i] = \
                    model_chunk.state_dict_for_save_checkpoint()

        # RNG states.
        if not args.no_save_rng:
            model_state_dict["rng_state"] = rng_state

    # Collect optimizer state. (Optimizer is saved separately from the model, due
    # to the conflicting data pattern when using the distributed optimizer.)
    optim_state_dict = {}
    if not args.no_save_optim \
       and (is_data_parallel_writer or args.use_distributed_optimizer):

        # Optimizer stuff.
        if optimizer is not None:
            optim_state_dict['optimizer'] = optimizer.state_dict()
        if opt_param_scheduler is not None:
            optim_state_dict['opt_param_scheduler'] = \
                opt_param_scheduler.state_dict()

    # Save.
    if args.use_distributed_optimizer:
        # Save model separate from optimizer.
        if model_state_dict:
            ensure_directory_exists(model_checkpoint_name)
            torch.save(model_state_dict, model_checkpoint_name)
        if optim_state_dict:
            ensure_directory_exists(optim_checkpoint_name)
            torch.save(optim_state_dict, optim_checkpoint_name)
    else:
        # Save model and optimizer together.
        state_dict = {**model_state_dict, **optim_state_dict}
        if state_dict:  # only saves if populated (i.e., inherits conditions above)
            ensure_directory_exists(model_checkpoint_name)
            torch.save(state_dict, model_checkpoint_name)

    # Wait so everyone is done (necessary)
    if torch.distributed.is_initialized():
        torch.distributed.barrier()

    print_rank_0('  successfully saved checkpoint at iteration {:7d} to {}'.format(
        iteration, args.save))

    # And update the latest iteration
    if not torch.distributed.is_initialized() or torch.distributed.get_rank() == 0:
        tracker_filename = get_checkpoint_tracker_filename(args.save)
        with open(tracker_filename, 'w') as f:
            f.write(str(iteration))

    # Wait so everyone is done (not necessary)
    if torch.distributed.is_initialized():
        torch.distributed.barrier()
304

305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
def _transpose_first_dim(t, num_splits, num_splits_first, model):
    input_shape = t.size()
    # We use a self_attention module but the values extracted aren't
    # specific to self attention so should work for cross attention as well
    while hasattr(model, 'module'):
        model = model.module
    attention_module = model.language_model.encoder.layers[0].self_attention
    hidden_size_per_attention_head = attention_module.hidden_size_per_attention_head
    num_attention_heads_per_partition = attention_module.num_attention_heads_per_partition
    if num_splits_first:
        """[num_splits * np * hn, h]
        -->(view) [num_splits, np, hn, h]
        -->(tranpose) [np, num_splits, hn, h]
        -->(view) [np * num_splits * hn, h] """

        intermediate_shape = \
            (num_splits, num_attention_heads_per_partition,
             hidden_size_per_attention_head) + input_shape[1:]

        t = t.view(*intermediate_shape)
        t = t.transpose(0, 1).contiguous()
    else:
        """[np * hn * num_splits, h]
        -->(view) [np, hn, num_splits, h]
        -->(tranpose) [np, num_splits, hn, h]
        -->(view) [np * num_splits * hn, h] """

        intermediate_shape = \
            (num_attention_heads_per_partition,
             hidden_size_per_attention_head, num_splits) +\
             input_shape[1:]

        t = t.view(*intermediate_shape)
        t = t.transpose(1, 2).contiguous()
    t = t.view(*input_shape)

    return t
342

Mostofa Patwary's avatar
Mostofa Patwary committed
343
344
345
346
347
def fix_query_key_value_ordering(model, checkpoint_version):
    """Fix up query/key/value matrix ordering if checkpoint
    version is smaller than 2.0

    Args:
        model: a model, or a list holding exactly one model, whose
            ``named_parameters()`` are rewritten in place.
        checkpoint_version: version of the loaded checkpoint. Nothing is
            done unless it is below 2.0; below 2.0 only 0 and 1.0 are
            accepted, any other value prints an error and exits.
    """
    if not checkpoint_version < 2.0:
        return

    if isinstance(model, list):
        assert len(model) == 1
        model = model[0]

    # Parameter-name suffixes paired with the number of splits passed to
    # _transpose_first_dim for those parameters.
    splits_by_suffix = (
        (('.query_key_value.weight', '.query_key_value.bias'), 3),
        (('.key_value.weight', '.key_value.bias'), 2),
    )

    for name, param in model.named_parameters():
        for suffixes, num_splits in splits_by_suffix:
            if not name.endswith(suffixes):
                continue
            if checkpoint_version == 0:
                fixed_param = _transpose_first_dim(param.data, num_splits, True, model)
            elif checkpoint_version == 1.0:
                fixed_param = _transpose_first_dim(param.data, num_splits, False, model)
            else:
                print_rank_0(f"Invalid checkpoint version {checkpoint_version}.")
                sys.exit()
            param.data.copy_(fixed_param)
    print_rank_0(" succesfully fixed query-key-values ordering for"
                 " checkpoint version {}".format(checkpoint_version))

Jared Casper's avatar
Jared Casper committed
373
def _load_base_checkpoint(load_dir, use_distributed_optimizer, rank0=False):
    """ Load the base state_dict from the given directory

    If rank0 is true, just loads rank 0 checkpoint, ignoring arguments.

    Args:
        load_dir: directory containing the tracker file and checkpoints.
        use_distributed_optimizer: when true the optimizer state is read
            from its own file; otherwise it is the same object as the
            model state dict.
        rank0: look up the rank 0 checkpoint via find_checkpoint_rank_0
            and suppress the informational messages.

    Returns:
        Tuple ``(model_state_dict, optim_state_dict, release)``, or
        ``(None, None, False)`` when there is no tracker file or no
        rank 0 checkpoint file could be found. Any other load failure
        prints the error and exits the process.
    """
    # Read the tracker file and set the iteration.
    tracker_filename = get_checkpoint_tracker_filename(load_dir)

    # If no tracker file, return nothing
    if not os.path.isfile(tracker_filename):
        if not rank0:
            print_rank_0('WARNING: could not find the metadata file {} '.format(
                tracker_filename))
            print_rank_0('    will not load any checkpoints and will start from '
                         'random')
        return None, None, False

    # Otherwise, read the tracker file and either set the iteration or
    # mark it as a release checkpoint.
    iteration, release = read_metadata(tracker_filename)

    # Checkpoint.
    if rank0:
        checkpoint_names = find_checkpoint_rank_0(load_dir, iteration,
                                                  use_distributed_optimizer,
                                                  release)
    else:
        checkpoint_names = get_checkpoint_names(load_dir, iteration,
                                                use_distributed_optimizer,
                                                release)
        if release:
            print_rank_0(f' loading release checkpoint from {load_dir}')
        else:
            print_rank_0(f' loading checkpoint from {load_dir} at iteration {iteration}')

    model_checkpoint_name, optim_checkpoint_name = checkpoint_names

    # find_checkpoint_rank_0 returns (None, None) when no file exists.
    # Report "nothing to load" instead of handing None to torch.load,
    # which would end in the generic failure path and exit.
    if model_checkpoint_name is None:
        return None, None, False

    def _load_state_dicts():
        # Load the model state and, only when it lives in a separate
        # file, the optimizer state.
        loaded_model = torch.load(model_checkpoint_name, map_location='cpu')
        if use_distributed_optimizer:
            loaded_optim = torch.load(optim_checkpoint_name, map_location='cpu')
        else:
            loaded_optim = loaded_model
        return loaded_model, loaded_optim

    # Load the checkpoint.
    try:
        model_state_dict, optim_state_dict = _load_state_dicts()
    except ModuleNotFoundError:
        from megatron.fp16_deprecated import loss_scaler
        # For backward compatibility.
        if not rank0:
            print_rank_0(' > deserializing using the old code structure ...')
        # Temporarily alias the old module paths so that unpickling
        # objects saved under them can resolve.
        sys.modules['fp16.loss_scaler'] = sys.modules[
            'megatron.fp16_deprecated.loss_scaler']
        sys.modules['megatron.fp16.loss_scaler'] = sys.modules[
            'megatron.fp16_deprecated.loss_scaler']
        try:
            model_state_dict, optim_state_dict = _load_state_dicts()
        finally:
            # Always drop the aliases, even if the retry fails.
            sys.modules.pop('fp16.loss_scaler', None)
            sys.modules.pop('megatron.fp16.loss_scaler', None)
    except BaseException as e:
        print_rank_0('could not load the checkpoint')
        print_rank_0(e)
        sys.exit()

    return model_state_dict, optim_state_dict, release
436
437

def load_args_from_checkpoint(args, load_arg='load'):
    """Set required arguments from the checkpoint specified in the
    arguments.

    Will overwrite arguments that have a non-None default value, but
    will leave any arguments that default to None as set.

    Returns the same args NameSpace with the new values added/updated.

    If no checkpoint is specified in args, or if the checkpoint is
    there but invalid, the arguments will not be modified

    Args:
        args: argument namespace to update in place.
        load_arg: name of the attribute on ``args`` that holds the
            checkpoint directory.
    """
    load_dir = getattr(args, load_arg)

    if load_dir is None:
        print_rank_0('No load directory specified, using provided arguments.')
        return args

    # For args we only care about model state dict
    state_dict, _, _ = _load_base_checkpoint(
        load_dir,
        use_distributed_optimizer=args.use_distributed_optimizer,
        rank0=True)

    if not state_dict:
        print_rank_0('Checkpoint not found to provide arguments, using provided arguments.')
        return args

    if 'args' not in state_dict:
        print_rank_0('Checkpoint provided does not have arguments saved, using provided arguments.')
        return args

    checkpoint_args = state_dict['args']
    checkpoint_version = state_dict.get('checkpoint_version', 0)
    args.iteration = state_dict['iteration']

    def _set_arg(arg_name, old_arg_name=None, force=False):
        # Values already set on args are kept unless force is given.
        if not force and getattr(args, arg_name, None) is not None:
            return

        source_name = arg_name if old_arg_name is None else old_arg_name
        checkpoint_value = getattr(checkpoint_args, source_name, None)

        if checkpoint_value is not None:
            print_rank_0(f"Setting {arg_name} to {checkpoint_value} from checkpoint")
            setattr(args, arg_name, checkpoint_value)

    _set_arg('num_layers')
    _set_arg('hidden_size')
    _set_arg('ffn_hidden_size')
    _set_arg('seq_length')
    _set_arg('num_attention_heads')
    _set_arg('kv_channels')
    _set_arg('max_position_embeddings')
    _set_arg('tokenizer_type')
    _set_arg('padded_vocab_size')
    if checkpoint_version < 3.0:
        # Older checkpoints store the tensor parallel size under
        # 'model_parallel_size'.
        _set_arg('tensor_model_parallel_size',
                 'model_parallel_size')
    else:
        _set_arg('tensor_model_parallel_size', force=True)
        _set_arg('pipeline_model_parallel_size', force=True)
        _set_arg('num_layers_per_virtual_pipeline_stage')
    return args

507
508

def load_checkpoint(model, optimizer, opt_param_scheduler, load_arg='load', strict=True):
    """Load a model checkpoint and return the iteration.

    Args:
        model: model (or model chunks); passed through ``unwrap_model``
            and then indexed like a list.
        optimizer: optimizer to restore, or None.
        opt_param_scheduler: scheduler to restore, or None.
        load_arg: name of the attribute on the global args that holds
            the checkpoint directory.
        strict (bool): whether to strictly enforce that the keys in
            :attr:`state_dict` of the checkpoint match the names of
            parameters and buffers in model.

    Returns:
        The iteration to resume from; 0 when no checkpoint was loaded,
        when finetuning, or when loading a release checkpoint.
    """
    args = get_args()
    load_dir = getattr(args, load_arg)

    model = unwrap_model(model)

    model_state_dict, optim_state_dict, release = \
        _load_base_checkpoint(load_dir,
                              use_distributed_optimizer=args.use_distributed_optimizer,
                              rank0=False)

    if model_state_dict is None:
        return 0

    # set checkpoint version
    set_checkpoint_version(model_state_dict.get('checkpoint_version', 0))

    # Set iteration.
    if args.finetune or release:
        iteration = 0
    else:
        try:
            iteration = model_state_dict['iteration']
        except KeyError:
            try:  # Backward compatible with older checkpoints
                iteration = model_state_dict['total_iters']
            except KeyError:
                # The checkpoint file name is not available here, so the
                # load directory is reported (the original formatted an
                # undefined name and raised NameError).
                print_rank_0('A metadata file exists but unable to load '
                             'iteration from checkpoint {}, exiting'.format(
                                 load_dir))
                sys.exit()

    # Check arguments.
    assert args.consumed_train_samples == 0
    assert args.consumed_valid_samples == 0
    if 'args' in model_state_dict:
        checkpoint_args = model_state_dict['args']
        check_checkpoint_args(checkpoint_args)
        args.consumed_train_samples = getattr(checkpoint_args,
                                              'consumed_train_samples', 0)
        update_num_microbatches(consumed_samples=args.consumed_train_samples)
        args.consumed_valid_samples = getattr(checkpoint_args,
                                              'consumed_valid_samples', 0)
    else:
        print_rank_0('could not find arguments in the checkpoint ...')

    # Model.
    if len(model) == 1:
        model[0].load_state_dict(model_state_dict['model'], strict=strict)
    else:
        for i, model_chunk in enumerate(model):
            mpu.set_virtual_pipeline_model_parallel_rank(i)
            model_chunk.load_state_dict(model_state_dict['model%d' % i], strict=strict)

    # Fix up query/key/value matrix ordering if needed
    checkpoint_version = get_checkpoint_version()
    print_rank_0(f' checkpoint version {checkpoint_version}')
    fix_query_key_value_ordering(model, checkpoint_version)

    # Optimizer.
    if not release and not args.finetune and not args.no_load_optim:
        try:
            if optimizer is not None:
                optimizer.load_state_dict(optim_state_dict['optimizer'])
            if opt_param_scheduler is not None:
                if 'lr_scheduler' in optim_state_dict:  # backward compatibility
                    opt_param_scheduler.load_state_dict(optim_state_dict['lr_scheduler'])
                else:
                    opt_param_scheduler.load_state_dict(optim_state_dict['opt_param_scheduler'])
        except KeyError:
            print_rank_0('Unable to load optimizer from checkpoint {}. '
                         'Specify --no-load-optim or --finetune to prevent '
                         'attempting to load the optimizer state, '
                         'exiting ...'.format(load_dir))
            sys.exit()

    # rng states.
    if not release and not args.finetune and not args.no_load_rng:
        try:
            if 'rng_state' in model_state_dict:
                # access rng_state for data parallel rank
                if args.data_parallel_random_init:
                    rng_state = model_state_dict['rng_state'][mpu.get_data_parallel_rank()]
                else:
                    rng_state = model_state_dict['rng_state'][0]
                random.setstate(rng_state['random_rng_state'])
                np.random.set_state(rng_state['np_rng_state'])
                torch.set_rng_state(rng_state['torch_rng_state'])
                torch.cuda.set_rng_state(rng_state['cuda_rng_state'])
                # Check for empty states array; KeyError routes an empty
                # tracker state to the same error path as a missing key.
                if not rng_state['rng_tracker_states']:
                    raise KeyError
                mpu.get_cuda_rng_tracker().set_states(
                    rng_state['rng_tracker_states'])
            else:  # backward compatibility
                random.setstate(model_state_dict['random_rng_state'])
                np.random.set_state(model_state_dict['np_rng_state'])
                torch.set_rng_state(model_state_dict['torch_rng_state'])
                torch.cuda.set_rng_state(model_state_dict['cuda_rng_state'])
                # Check for empty states array
                if not model_state_dict['rng_tracker_states']:
                    raise KeyError
                mpu.get_cuda_rng_tracker().set_states(
                    model_state_dict['rng_tracker_states'])
        except KeyError:
            print_rank_0('Unable to load rng state from checkpoint {}. '
                         'Specify --no-load-rng or --finetune to prevent '
                         'attempting to load the rng state, '
                         'exiting ...'.format(load_dir))
            sys.exit()

    # Some utilities want to load a checkpoint without distributed being initialized
    if torch.distributed.is_initialized():
        torch.distributed.barrier()

    # Report the directory actually used, which differs from args.load
    # when load_arg names another attribute.
    print_rank_0(f'  successfully loaded checkpoint from {load_dir} '
                 f'at iteration {iteration}')

    return iteration
Neel Kant's avatar
Neel Kant committed
633
634


635
636
637
def load_biencoder_checkpoint(model, only_query_model=False,
        only_context_model=False, custom_load_path=None):
    """
    selectively load retrieval models for indexing/retrieving
    from saved checkpoints

    Args:
        model: model; passed through ``unwrap_model`` and expected to
            then hold exactly one element.
        only_query_model: drop the 'context_model' entry before loading.
        only_context_model: drop the 'query_model' entry before loading.
        custom_load_path: checkpoint directory to use instead of
            ``args.load``.

    Returns:
        The unwrapped model after its state dict has been loaded.
    """

    args = get_args()

    model = unwrap_model(model)

    load_path = custom_load_path if custom_load_path is not None else args.load

    tracker_filename = get_checkpoint_tracker_filename(load_path)
    with open(tracker_filename, 'r') as f:
        iteration = int(f.read().strip())

    # Only the model file name is needed here.
    checkpoint_name, _ = get_checkpoint_names(load_path, iteration,
                                              args.use_distributed_optimizer,
                                              release=False)

    if mpu.get_data_parallel_rank() == 0:
        print('global rank {} is loading checkpoint {}'.format(
            torch.distributed.get_rank(), checkpoint_name))

    # The original passed an undefined name to torch.load here.
    state_dict = torch.load(checkpoint_name, map_location='cpu')
    ret_state_dict = state_dict['model']

    if only_query_model:
        ret_state_dict.pop('context_model')
    if only_context_model:
        ret_state_dict.pop('query_model')

    assert len(model) == 1
    model[0].load_state_dict(ret_state_dict)
    torch.distributed.barrier()

    if mpu.get_data_parallel_rank() == 0:
        print(' successfully loaded {}'.format(checkpoint_name))

    return model