checkpointing.py 25.9 KB
Newer Older
liangjing's avatar
v1  
liangjing committed
1
# Copyright (c) 2023, NVIDIA CORPORATION. All rights reserved.
2
3
4
5
6
7
8
9
10
11

"""Input/output checkpointing."""

import os
import random
import sys
import numpy as np

import torch

12
13
from megatron import update_num_microbatches
from megatron.core import mpu, tensor_parallel
14
15
16
17
from .global_vars import get_args
from .utils import (unwrap_model,
                    print_rank_0)

18

Vijay Korthikanti's avatar
Vijay Korthikanti committed
19
20
# Format version of the checkpoint this process is reading/writing.
# None until the first call to set_checkpoint_version().
_CHECKPOINT_VERSION = None


def set_checkpoint_version(value):
    """Record the checkpoint format version for this process.

    The version may be set repeatedly, but only to the same value; an
    attempt to change it once established trips the assertion.
    """
    global _CHECKPOINT_VERSION
    already_set = _CHECKPOINT_VERSION is not None
    if already_set:
        assert _CHECKPOINT_VERSION == value, \
            "checkpoint versions do not match"
    _CHECKPOINT_VERSION = value

liangjing's avatar
v1  
liangjing committed
29

Vijay Korthikanti's avatar
Vijay Korthikanti committed
30
31
32
def get_checkpoint_version():
    """Return the recorded checkpoint format version (None if unset)."""
    # Reading a module global needs no 'global' declaration.
    return _CHECKPOINT_VERSION
33

liangjing's avatar
v1  
liangjing committed
34

35
36
def check_checkpoint_args(checkpoint_args):
    """Ensure fixed arguments for a model are the same for the input
    arguments and the one retrieved from checkpoint."""
    args = get_args()

    def _compare(arg_name, old_arg_name=None, default=None):
        # Resolve the name the value was stored under in the checkpoint
        # (legacy checkpoints may use an older argument name).
        ckpt_arg_name = arg_name if old_arg_name is None else old_arg_name
        if default is None:
            # No fallback: a missing attribute is an error here.
            checkpoint_value = getattr(checkpoint_args, ckpt_arg_name)
        else:
            checkpoint_value = getattr(checkpoint_args, ckpt_arg_name, default)
        args_value = getattr(args, arg_name)
        error_message = '{} value from checkpoint ({}) is not equal to the ' \
                        'input argument value ({}).'.format(
                            arg_name, checkpoint_value, args_value)
        assert checkpoint_value == args_value, error_message

    # Architecture-defining arguments must always agree.
    for name in ('num_layers', 'hidden_size', 'num_attention_heads'):
        _compare(name)
    # Older checkpoints predate this flag; assume True when absent.
    _compare('add_position_embedding', default=True)

    # Vocabulary-related arguments only matter when a vocab file is used.
    if args.vocab_file:
        for name in ('max_position_embeddings',
                     'make_vocab_size_divisible_by',
                     'padded_vocab_size',
                     'tokenizer_type'):
            _compare(name)

    if args.data_parallel_random_init:
        _compare('data_parallel_random_init')

    # Parallelism arguments were renamed/split at checkpoint version 3.0.
    version = get_checkpoint_version()
    if version < 3.0:
        _compare('tensor_model_parallel_size',
                 old_arg_name='model_parallel_size')
    if version >= 3.0:
        _compare('tensor_model_parallel_size')
        _compare('pipeline_model_parallel_size')
72

liangjing's avatar
v1  
liangjing committed
73

74
75
76
def ensure_directory_exists(filename):
    """Create the parent directory of *filename* if it does not exist."""
    # exist_ok makes repeated calls (and races between ranks) harmless.
    os.makedirs(os.path.dirname(filename), exist_ok=True)
78
79


liangjing's avatar
v1  
liangjing committed
80
81
82
def get_checkpoint_name(checkpoints_path, iteration, release=False,
                        pipeline_parallel=None,
                        tensor_rank=None, pipeline_rank=None):
    """Determine the directory name for this rank's checkpoint."""
    # Release checkpoints live under 'release/'; training checkpoints
    # under a zero-padded iteration directory.
    directory = 'release' if release else 'iter_{:07d}'.format(iteration)

    # Fill in any rank information the caller did not supply from the
    # current model-parallel state.
    if pipeline_parallel is None:
        pipeline_parallel = mpu.get_pipeline_model_parallel_world_size() > 1
    if tensor_rank is None:
        tensor_rank = mpu.get_tensor_model_parallel_rank()
    if pipeline_rank is None:
        pipeline_rank = mpu.get_pipeline_model_parallel_rank()

    # Without pipeline parallelism only the tensor rank appears in the
    # per-rank directory name; with it, both ranks do.
    if pipeline_parallel:
        rank_dir = f'mp_rank_{tensor_rank:02d}_{pipeline_rank:03d}'
    else:
        rank_dir = f'mp_rank_{tensor_rank:02d}'

    return os.path.join(checkpoints_path, directory, rank_dir,
                        "model_optim_rng.pt")


def get_distributed_optimizer_checkpoint_name(model_checkpoint_name):
    """Distributed-optimizer state lives beside the model checkpoint file."""
    directory = os.path.dirname(model_checkpoint_name)
    return os.path.join(directory, "distrib_optim.pt")

114

liangjing's avatar
v1  
liangjing committed
115
def find_checkpoint_rank_0(checkpoints_path, iteration, release=False):
    """Finds the checkpoint for rank 0 without knowing if we are using
    pipeline parallelism or not.

    Since the checkpoint naming scheme changes if pipeline parallelism
    is present, we need to look for both naming schemes if we don't
    know if the checkpoint has pipeline parallelism.

    Returns the checkpoint path if found, otherwise None.
    """

    # Look for checkpoint with no pipelining
    filename = get_checkpoint_name(checkpoints_path, iteration, release,
                                   pipeline_parallel=False,
                                   tensor_rank=0, pipeline_rank=0)
    if os.path.isfile(filename):
        return filename

    # Look for checkpoint with pipelining
    filename = get_checkpoint_name(checkpoints_path, iteration, release,
                                   pipeline_parallel=True,
                                   tensor_rank=0, pipeline_rank=0)
    if os.path.isfile(filename):
        return filename

    # Bug fix: this previously returned the tuple (None, None) while the
    # success paths return a single path string; callers bind the result
    # to one name, so return a plain None for consistency.
    return None
139

liangjing's avatar
v1  
liangjing committed
140

141
def get_checkpoint_tracker_filename(checkpoints_path):
    """Tracker file records the latest checkpoint during
    training to restart from."""
    return os.path.join(checkpoints_path, 'latest_checkpointed_iteration.txt')


148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
def read_metadata(tracker_filename):
    """Parse the checkpoint tracker file.

    The file contains either an integer iteration number or the literal
    string 'release'.  Returns (iteration, release); when
    torch.distributed is initialized, the iteration is the maximum
    value found across ranks.
    """
    iteration = 0
    release = False
    with open(tracker_filename, 'r') as fh:
        contents = fh.read().strip()
        try:
            iteration = int(contents)
        except ValueError:
            # Not an integer -- the only other legal content is 'release'.
            release = contents == 'release'
            if not release:
                print_rank_0('ERROR: Invalid metadata file {}. Exiting'.format(
                    tracker_filename))
                sys.exit()
    assert iteration > 0 or release, 'error parsing metadata file {}'.format(
        tracker_filename)

    if not torch.distributed.is_initialized():
        # When loading a checkpoint outside of training (for example,
        # when editing it), torch distributed may not be initialized;
        # just trust the locally read iteration.
        return iteration, release

    # Get the max iteration retrieved across the ranks.
    iters_cuda = torch.cuda.LongTensor([iteration])
    torch.distributed.all_reduce(iters_cuda, op=torch.distributed.ReduceOp.MAX)
    max_iter = iters_cuda[0].item()

    # All ranks should agree; if not, warn and use the maximum.
    if iteration != max_iter:
        rank = torch.distributed.get_rank()
        print('WARNING: on rank {} found iteration {} in the '
              'metadata while max iteration across the ranks '
              'is {}, replacing it with max iteration.'.format(
                  rank, iteration, max_iter), flush=True)
    return max_iter, release


189
190
def get_rng_state():
    """ collect rng state across data parallel ranks """
    args = get_args()

    # Snapshot every RNG this process owns.
    local_state = {
        'random_rng_state': random.getstate(),
        'np_rng_state': np.random.get_state(),
        'torch_rng_state': torch.get_rng_state(),
        'cuda_rng_state': torch.cuda.get_rng_state(),
        'rng_tracker_states':
            tensor_parallel.get_cuda_rng_tracker().get_states(),
    }

    # With per-data-parallel-rank random init every rank carries a distinct
    # state, so gather all of them; otherwise a single copy suffices.
    gather_needed = (torch.distributed.is_initialized()
                     and mpu.get_data_parallel_world_size() > 1
                     and args.data_parallel_random_init)
    if gather_needed:
        state_list = [None] * mpu.get_data_parallel_world_size()
        torch.distributed.all_gather_object(
            state_list,
            local_state,
            group=mpu.get_data_parallel_group())
    else:
        state_list = [local_state]

    return state_list


215
def save_checkpoint(iteration, model, optimizer, opt_param_scheduler):
    """Save a model checkpoint.

    Args:
        iteration: training iteration; used both to name the checkpoint
            directory and as the value written to the tracker file.
        model: list of model chunks (more than one entry means virtual
            pipeline parallelism is in use).
        optimizer: optimizer whose state is saved unless
            args.no_save_optim is set; may be None.
        opt_param_scheduler: scheduler whose state is saved unless
            args.no_save_optim is set; may be None.
    """
    args = get_args()

    # Only rank zero of the data parallel writes to the disk.
    # Strip wrappers (e.g. DDP) so the bare module state is serialized.
    model = unwrap_model(model)

    print_rank_0('saving checkpoint at iteration {:7d} to {}'.format(
        iteration, args.save))

    # Collect rng state across data parallel ranks.  This may be a
    # collective (all_gather_object), so every rank must reach it.
    rng_state = get_rng_state()

    # Checkpoint name.
    checkpoint_name = get_checkpoint_name(args.save, iteration)

    # Save distributed optimizer's custom parameter state.  Note this
    # runs on all ranks, before the data-parallel-rank-0 gate below.
    if args.use_distributed_optimizer:
        optim_checkpoint_name = \
            get_distributed_optimizer_checkpoint_name(checkpoint_name)
        ensure_directory_exists(optim_checkpoint_name)
        optimizer.save_parameter_state(optim_checkpoint_name)

    # Collect args, model, RNG.  Only data-parallel rank 0 (or a
    # non-distributed run) writes the model checkpoint file.
    if not torch.distributed.is_initialized() \
       or mpu.get_data_parallel_rank() == 0:

        # Arguments, iteration, and model.
        state_dict = {}
        state_dict['args'] = args
        state_dict['checkpoint_version'] = 3.0
        state_dict['iteration'] = iteration
        if len(model) == 1:
            state_dict['model'] = model[0].state_dict_for_save_checkpoint()
        else:
            # Virtual pipeline parallelism: one 'model%d' entry per chunk.
            for i in range(len(model)):
                mpu.set_virtual_pipeline_model_parallel_rank(i)
                state_dict['model%d' % i] = \
                    model[i].state_dict_for_save_checkpoint()

        # Optimizer stuff.
        if not args.no_save_optim:
            if optimizer is not None:
                state_dict['optimizer'] = optimizer.state_dict()
            if opt_param_scheduler is not None:
                state_dict['opt_param_scheduler'] = \
                    opt_param_scheduler.state_dict()

        # RNG states.
        if not args.no_save_rng:
            state_dict["rng_state"] = rng_state

        # Save.
        ensure_directory_exists(checkpoint_name)
        torch.save(state_dict, checkpoint_name)

    # Wait so everyone is done (necessary): the tracker file below must
    # not be updated until every rank has finished writing its files.
    if torch.distributed.is_initialized():
        torch.distributed.barrier()

    print_rank_0('  successfully saved checkpoint at iteration {:7d} to {}' \
                 .format(iteration, args.save))

    # And update the latest iteration.  Written last, so the tracker never
    # points at a partially written checkpoint.
    if not torch.distributed.is_initialized() \
       or torch.distributed.get_rank() == 0:
        tracker_filename = get_checkpoint_tracker_filename(args.save)
        with open(tracker_filename, 'w') as f:
            f.write(str(iteration))

    # Wait so everyone is done (not necessary)
    if torch.distributed.is_initialized():
        torch.distributed.barrier()
288

liangjing's avatar
v1  
liangjing committed
289

290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
def _transpose_first_dim(t, num_splits, num_splits_first, model):
    input_shape = t.size()
    # We use a self_attention module but the values extracted aren't
    # specific to self attention so should work for cross attention as well
    while hasattr(model, 'module'):
        model = model.module
    attention_module = model.language_model.encoder.layers[0].self_attention
    hidden_size_per_attention_head = attention_module.hidden_size_per_attention_head
    num_attention_heads_per_partition = attention_module.num_attention_heads_per_partition
    if num_splits_first:
        """[num_splits * np * hn, h]
        -->(view) [num_splits, np, hn, h]
        -->(tranpose) [np, num_splits, hn, h]
        -->(view) [np * num_splits * hn, h] """

        intermediate_shape = \
            (num_splits, num_attention_heads_per_partition,
             hidden_size_per_attention_head) + input_shape[1:]

        t = t.view(*intermediate_shape)
        t = t.transpose(0, 1).contiguous()
    else:
        """[np * hn * num_splits, h]
        -->(view) [np, hn, num_splits, h]
        -->(tranpose) [np, num_splits, hn, h]
        -->(view) [np * num_splits * hn, h] """

        intermediate_shape = \
            (num_attention_heads_per_partition,
             hidden_size_per_attention_head, num_splits) +\
             input_shape[1:]

        t = t.view(*intermediate_shape)
        t = t.transpose(1, 2).contiguous()
    t = t.view(*input_shape)

    return t
327

liangjing's avatar
v1  
liangjing committed
328

Mostofa Patwary's avatar
Mostofa Patwary committed
329
330
331
332
333
def fix_query_key_value_ordering(model, checkpoint_version):
    """Fix up query/key/value matrix ordering if checkpoint
    version is smaller than 2.0
    """
    # Checkpoints from version 2.0 onward already use the current layout.
    if not checkpoint_version < 2.0:
        return

    if isinstance(model, list):
        assert len(model)==1
        model = model[0]

    def _fix(param, num_splits):
        # Version 0 stored the splits dimension first; 1.0 stored heads
        # first.  Anything else reaching here is invalid.
        if checkpoint_version == 0:
            return _transpose_first_dim(param.data, num_splits, True, model)
        elif checkpoint_version == 1.0:
            return _transpose_first_dim(param.data, num_splits, False, model)
        print_rank_0(f"Invalid checkpoint version {checkpoint_version}.")
        sys.exit()

    for name, param in model.named_parameters():
        # Fused QKV projections interleave 3 splits; fused KV, 2.
        if name.endswith(('.query_key_value.weight', '.query_key_value.bias')):
            param.data.copy_(_fix(param, 3))
        if name.endswith(('.key_value.weight', '.key_value.bias')):
            param.data.copy_(_fix(param, 2))

    print_rank_0(" succesfully fixed query-key-values ordering for"
                 " checkpoint version {}".format(checkpoint_version))

Mostofa Patwary's avatar
Mostofa Patwary committed
359

liangjing's avatar
v1  
liangjing committed
360
def _load_base_checkpoint(load_dir, rank0=False):
    """ Load the base state_dict from the given directory

    If rank0 is true, just loads rank 0 checkpoint, ignoring arguments.

    Returns a (state_dict, checkpoint_name, release) triple; when no
    tracker file exists the triple is (None, "", False).
    """

    # Read the tracker file and set the iteration.
    tracker_filename = get_checkpoint_tracker_filename(load_dir)

    # If no tracker file, return nothing
    if not os.path.isfile(tracker_filename):
        # Stay quiet in rank0 mode (used by load_args_from_checkpoint,
        # which prints its own message).
        if not rank0:
            print_rank_0('WARNING: could not find the metadata file {} '.format(
                tracker_filename))
            print_rank_0('    will not load any checkpoints and will start from '
                         'random')
        return None, "", False

    # Otherwise, read the tracker file and either set the iteration or
    # mark it as a release checkpoint.
    iteration, release = read_metadata(tracker_filename)

    # Checkpoint.
    if rank0:
        # Probe both naming schemes (with and without pipeline parallelism).
        checkpoint_name = find_checkpoint_rank_0(load_dir, iteration, release)
    else:
        checkpoint_name = get_checkpoint_name(load_dir, iteration, release)
        if release:
            print_rank_0(f' loading release checkpoint from {load_dir}')
        else:
            print_rank_0(f' loading checkpoint from {load_dir} at iteration {iteration}')

    # Load the checkpoint.
    try:
        state_dict = torch.load(checkpoint_name, map_location='cpu')
    except ModuleNotFoundError:
        # Very old checkpoints pickled references to module paths that no
        # longer exist; alias them to the deprecated loss_scaler module so
        # unpickling succeeds, then remove the aliases.
        from megatron.fp16_deprecated import loss_scaler
        # For backward compatibility.
        if not rank0:
            print_rank_0(' > deserializing using the old code structure ...')
        sys.modules['fp16.loss_scaler'] = sys.modules[
            'megatron.fp16_deprecated.loss_scaler']
        sys.modules['megatron.fp16.loss_scaler'] = sys.modules[
            'megatron.fp16_deprecated.loss_scaler']
        state_dict = torch.load(checkpoint_name, map_location='cpu')
        sys.modules.pop('fp16.loss_scaler', None)
        sys.modules.pop('megatron.fp16.loss_scaler', None)
    except BaseException as e:
        # Any other failure (corrupt file, wrong path, ...) is fatal.
        print_rank_0('could not load the checkpoint')
        print_rank_0(e)
        sys.exit()

    return state_dict, checkpoint_name, release

414
415

def load_args_from_checkpoint(args, load_arg='load'):
    """Set required arguments from the checkpoint specified in the
    arguments.

    Will overwrite arguments that have a non-None default value, but
    will leave any arguments that default to None as set.

    Returns the same args NameSpace with the new values added/updated.

    If no checkpoint is specified in args, or if the checkpoint is
    there but invalid, the arguments will not be modified

    NOTE(review): the early-exit paths return ``args`` alone while the
    success path returns ``(args, checkpoint_args)`` -- callers must
    handle both shapes.
    """
    load_dir = getattr(args, load_arg)

    if load_dir is None:
        print_rank_0('No load directory specified, using provided arguments.')
        return args

    # rank0=True loads the rank-0 checkpoint regardless of this process's
    # model-parallel rank (presumably the saved args are identical across
    # ranks -- confirm if that assumption ever changes).
    state_dict, checkpoint_name, release = _load_base_checkpoint(load_dir, rank0=True)

    # Args.
    if not state_dict:
        print_rank_0('Checkpoint not found to provide arguments, using provided arguments.')
        return args

    if 'args' not in state_dict:
        print_rank_0('Checkpoint provided does not have arguments saved, using provided arguments.')
        return args

    checkpoint_args = state_dict['args']
    # Checkpoints predating versioning are treated as version 0.
    checkpoint_version = state_dict.get('checkpoint_version', 0)
    args.iteration = state_dict['iteration']

    # One-off conversion for foundation models
    if hasattr(checkpoint_args, 'disable_bias_linear'):
        setattr(checkpoint_args, 'add_bias_linear', not getattr(checkpoint_args, 'disable_bias_linear'))

    def _set_arg(arg_name, old_arg_name=None, force=False):
        # Copy one argument from the checkpoint onto args.  Unless
        # force=True, a value the user already set (non-None) wins.
        if not force and getattr(args, arg_name, None) is not None:
            return

        # old_arg_name looks the value up under a legacy checkpoint name.
        if old_arg_name is not None:
            checkpoint_value = getattr(checkpoint_args, old_arg_name, None)
        else:
            checkpoint_value = getattr(checkpoint_args, arg_name, None)

        if checkpoint_value is not None:
            print_rank_0(f"Setting {arg_name} to {checkpoint_value} from checkpoint")
            setattr(args, arg_name, checkpoint_value)
        else:
            print_rank_0(f"Checkpoint did not provide arguments {arg_name}")

    # Model-shape arguments: user-provided values take precedence.
    _set_arg('num_layers')
    _set_arg('hidden_size')
    _set_arg('ffn_hidden_size')
    _set_arg('seq_length')
    _set_arg('num_attention_heads')
    _set_arg('kv_channels')
    _set_arg('max_position_embeddings')
    # Architecture flags: the checkpoint's values always win (force=True),
    # since the weights were trained with them.
    _set_arg('position_embedding_type', force=True)
    _set_arg('add_position_embedding', force=True)
    _set_arg('use_rotary_position_embeddings', force=True)
    _set_arg('rotary_percent', force=True)
    _set_arg('add_bias_linear', force=True)
    _set_arg('swiglu', force=True)
    _set_arg('untie_embeddings_and_output_weights', force=True)
    _set_arg('apply_layernorm_1p', force=True)
    _set_arg('tokenizer_type')
    _set_arg('padded_vocab_size')
    # Parallelism arguments were renamed/split at checkpoint version 3.0.
    if checkpoint_version < 3.0:
        _set_arg('tensor_model_parallel_size',
                 'model_parallel_size')
    else:
        _set_arg('tensor_model_parallel_size', force=True)
        _set_arg('pipeline_model_parallel_size', force=True)
        _set_arg('virtual_pipeline_model_parallel_size', force=True)
        _set_arg('num_layers_per_virtual_pipeline_stage')
    return args, checkpoint_args
494

495
496

def load_checkpoint(model, optimizer, opt_param_scheduler, load_arg='load', strict=True):
    """Load a model checkpoint and return the iteration.
    strict (bool): whether to strictly enforce that the keys in
        :attr:`state_dict` of the checkpoint match the names of
        parameters and buffers in model.

    Args:
        model: list of model chunks (more than one entry means virtual
            pipeline parallelism).
        optimizer: optimizer to restore; may be None.
        opt_param_scheduler: scheduler to restore; may be None.
        load_arg: name of the attribute on args holding the checkpoint
            directory.

    Returns 0 when no checkpoint was found, otherwise the restored
    iteration number (0 for release/finetune loads).
    """
    args = get_args()
    load_dir = getattr(args, load_arg)

    # Strip wrappers (e.g. DDP) before loading state dicts.
    model = unwrap_model(model)

    state_dict, checkpoint_name, release = _load_base_checkpoint(load_dir, rank0=False)

    # Checkpoint not loaded.
    if state_dict is None:

        # Conditionally exit at this point.
        if args.exit_on_missing_checkpoint:
            print_rank_0(">> '--exit-on-missing-checkpoint' set ... exiting. <<")
            # NOTE(review): this barrier is not guarded by
            # torch.distributed.is_initialized() like the others below --
            # confirm this path is only reached in distributed runs.
            torch.distributed.barrier()
            sys.exit()

        # Iteration defaults to 0.
        return 0

    # Set checkpoint version.
    set_checkpoint_version(state_dict.get('checkpoint_version', 0))

    # Set iteration.
    if args.finetune or release:
        iteration = 0
    else:
        try:
            iteration = state_dict['iteration']
        except KeyError:
            try:  # Backward compatible with older checkpoints
                iteration = state_dict['total_iters']
            except KeyError:
                print_rank_0('A metadata file exists but unable to load '
                             'iteration from checkpoint {}, exiting'.format(
                                 checkpoint_name))
                sys.exit()

    # Check arguments.  Sample counters must not have been advanced before
    # loading; they are restored from the checkpoint below.
    assert args.consumed_train_samples == 0
    assert args.consumed_valid_samples == 0
    if 'args' in state_dict and not args.finetune:
        checkpoint_args = state_dict['args']
        check_checkpoint_args(checkpoint_args)
        args.consumed_train_samples = getattr(checkpoint_args,
                                              'consumed_train_samples', 0)
        # Keep the microbatch calculator consistent with the restored count.
        update_num_microbatches(consumed_samples=args.consumed_train_samples)
        args.consumed_valid_samples = getattr(checkpoint_args,
                                              'consumed_valid_samples', 0)
    else:
        print_rank_0('could not find arguments in the checkpoint ...')

    # Model.  'model' key for a single chunk, 'model%d' per chunk under
    # virtual pipeline parallelism (mirrors save_checkpoint).
    if len(model) == 1:
        model[0].load_state_dict(state_dict['model'], strict=strict)
    else:
        for i in range(len(model)):
            mpu.set_virtual_pipeline_model_parallel_rank(i)
            model[i].load_state_dict(state_dict['model%d' % i], strict=strict)

    # Fix up query/key/value matrix ordering if needed.
    checkpoint_version = get_checkpoint_version()
    print_rank_0(f' checkpoint version {checkpoint_version}')
    fix_query_key_value_ordering(model, checkpoint_version)

    # Optimizer.
    if not release and not args.finetune and not args.no_load_optim:
        try:
            # Load state dict.
            if optimizer is not None:
                optimizer.load_state_dict(state_dict['optimizer'])

            # Load distributed optimizer's custom parameter state.
            if args.use_distributed_optimizer:
                tracker_filename = get_checkpoint_tracker_filename(load_dir)
                iteration, release = read_metadata(tracker_filename)
                model_checkpoint_name = \
                    get_checkpoint_name(load_dir, iteration, release)
                optim_checkpoint_name = \
                    get_distributed_optimizer_checkpoint_name(
                        model_checkpoint_name)
                optimizer.load_parameter_state(optim_checkpoint_name)

            # Load scheduler.
            if opt_param_scheduler is not None:
                if 'lr_scheduler' in state_dict: # backward compatibility
                    opt_param_scheduler.load_state_dict(state_dict['lr_scheduler'])
                else:
                    opt_param_scheduler.load_state_dict(state_dict['opt_param_scheduler'])
        except KeyError:
            print_rank_0('Unable to load optimizer from checkpoint {}. '
                         'Specify --no-load-optim or --finetune to prevent '
                         'attempting to load the optimizer state, '
                         'exiting ...'.format(checkpoint_name))
            sys.exit()
    else:
        # Optimizer state skipped: refresh fp16/bf16 master params from
        # the freshly loaded model weights.
        if (args.fp16 or args.bf16) and optimizer is not None:
            optimizer.reload_model_params()

    # rng states.
    if not release and not args.finetune and not args.no_load_rng:
        try:
            if 'rng_state' in state_dict:
                # access rng_state for data parallel rank
                if args.data_parallel_random_init:
                    rng_state = state_dict['rng_state'][mpu.get_data_parallel_rank()]
                else:
                    rng_state = state_dict['rng_state'][0]
                random.setstate(rng_state['random_rng_state'])
                np.random.set_state(rng_state['np_rng_state'])
                torch.set_rng_state(rng_state['torch_rng_state'])
                torch.cuda.set_rng_state(rng_state['cuda_rng_state'])
                # Check for empty states array
                if not rng_state['rng_tracker_states']:
                    raise KeyError
                tensor_parallel.get_cuda_rng_tracker().set_states(
                    rng_state['rng_tracker_states'])
            else:  # backward compatibility: flat RNG keys at top level
                random.setstate(state_dict['random_rng_state'])
                np.random.set_state(state_dict['np_rng_state'])
                torch.set_rng_state(state_dict['torch_rng_state'])
                torch.cuda.set_rng_state(state_dict['cuda_rng_state'])
                # Check for empty states array
                if not state_dict['rng_tracker_states']:
                    raise KeyError
                tensor_parallel.get_cuda_rng_tracker().set_states(
                    state_dict['rng_tracker_states'])
        except KeyError:
            print_rank_0('Unable to load rng state from checkpoint {}. '
                         'Specify --no-load-rng or --finetune to prevent '
                         'attempting to load the rng state, '
                         'exiting ...'.format(checkpoint_name))
            sys.exit()

    # Some utilities want to load a checkpoint without distributed being initialized
    if torch.distributed.is_initialized():
        torch.distributed.barrier()

    print_rank_0(f'  successfully loaded checkpoint from {args.load} '
                 f'at iteration {iteration}')

    return iteration
Neel Kant's avatar
Neel Kant committed
644
645


646
def load_biencoder_checkpoint(model, only_query_model=False,
                              only_context_model=False, custom_load_path=None):
    """
    selectively load retrieval models for indexing/retrieving
    from saved checkpoints

    Args:
        model: single-element list holding the biencoder model.
        only_query_model: drop the context model's weights before loading.
        only_context_model: drop the query model's weights before loading.
        custom_load_path: checkpoint directory; defaults to args.load.

    Returns the model list with weights loaded.
    """

    args = get_args()

    model = unwrap_model(model)

    load_path = custom_load_path if custom_load_path is not None else args.load

    tracker_filename = get_checkpoint_tracker_filename(load_path)
    with open(tracker_filename, 'r') as f:
        iteration = int(f.read().strip())

    # Bug fix: the previous call passed args.use_distributed_optimizer
    # positionally into get_checkpoint_name's 'release' slot while ALSO
    # passing release=False by keyword, which always raised
    # "TypeError: got multiple values for argument 'release'".
    checkpoint_name = get_checkpoint_name(load_path, iteration,
                                          release=False)

    if mpu.get_data_parallel_rank() == 0:
        print('global rank {} is loading checkpoint {}'.format(
            torch.distributed.get_rank(), checkpoint_name))

    state_dict = torch.load(checkpoint_name, map_location='cpu')
    ret_state_dict = state_dict['model']

    # Keep only the requested sub-model's weights.
    if only_query_model:
        ret_state_dict.pop('context_model')
    if only_context_model:
        ret_state_dict.pop('query_model')

    assert len(model) == 1
    model[0].load_state_dict(ret_state_dict)
    torch.distributed.barrier()

    if mpu.get_data_parallel_rank() == 0:
        print(' successfully loaded {}'.format(checkpoint_name))

    return model