checkpointing.py 20.6 KB
Newer Older
1
# coding=utf-8
Mohammad's avatar
Mohammad committed
2
# Copyright (c) 2020, NVIDIA CORPORATION.  All rights reserved.
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Input/output checkpointing."""

import os
import random
import sys
import numpy as np

import torch

25
26
27
from megatron import (get_args,
                      mpu,
                      print_rank_0,
28
29
                      update_num_microbatches,
                      utils)
30

Vijay Korthikanti's avatar
Vijay Korthikanti committed
31
32
33
34
_CHECKPOINT_VERSION = None


def set_checkpoint_version(value):
    """Record the version of the checkpoint currently being loaded.

    The version may be recorded more than once, but every call after the
    first must pass the same value; a different value trips an assertion.
    """
    global _CHECKPOINT_VERSION
    assert _CHECKPOINT_VERSION is None or _CHECKPOINT_VERSION == value, \
        "checkpoint versions do not match"
    _CHECKPOINT_VERSION = value


def get_checkpoint_version():
    """Return the recorded checkpoint version, or None if none was set yet."""
    return _CHECKPOINT_VERSION
43
44
45

def check_checkpoint_args(checkpoint_args):
    """Assert that model-defining arguments match those saved in a checkpoint.

    Each compared attribute of ``checkpoint_args`` (the args namespace stored
    in the checkpoint) must equal the same attribute of the current run's
    ``get_args()``; the first mismatch raises an AssertionError naming it.
    Which fields are compared depends on the current args and on the
    checkpoint version, so the version must already be set.
    """
    args = get_args()

    def _compare(arg_name, old_arg_name=None):
        # Older checkpoints may store the value under a different name.
        stored_name = arg_name if old_arg_name is None else old_arg_name
        checkpoint_value = getattr(checkpoint_args, stored_name)
        args_value = getattr(args, arg_name)
        error_message = '{} value from checkpoint ({}) is not equal to the ' \
                        'input argument value ({}).'.format(
                            arg_name, checkpoint_value, args_value)
        assert checkpoint_value == args_value, error_message

    for name in ('num_layers', 'hidden_size', 'num_attention_heads'):
        _compare(name)

    if args.vocab_file:
        for name in ('max_position_embeddings',
                     'make_vocab_size_divisible_by',
                     'padded_vocab_size',
                     'tokenizer_type'):
            _compare(name)

    if args.data_parallel_random_init:
        _compare('data_parallel_random_init')

    version = get_checkpoint_version()
    if version < 3.0:
        # Pre-3.0 checkpoints only knew a single "model_parallel_size".
        _compare('tensor_model_parallel_size',
                 old_arg_name='model_parallel_size')
    else:
        _compare('tensor_model_parallel_size')
        _compare('pipeline_model_parallel_size')
76
77
78
79
80
81
82
83

def ensure_directory_exists(filename):
    """Create the parent directory of ``filename`` if it does not exist yet.

    Several ranks may call this concurrently for paths under the same
    directory. ``exist_ok=True`` prevents a race in which two processes
    both see the directory as missing and the slower ``makedirs`` fails
    with FileExistsError. A bare filename has no directory component, so
    nothing needs to be created for it.
    """
    dirname = os.path.dirname(filename)
    if dirname:
        os.makedirs(dirname, exist_ok=True)


84
85
def get_checkpoint_names(checkpoints_path, iteration, use_distributed_optimizer,
                         release=False):
    """Return the ``(model, optimizer)`` checkpoint file paths for this rank.

    Layout: ``<checkpoints_path>/<iter_XXXXXXX | release>/mp_rank_TT[_PPP]``,
    where ``TT`` is the tensor-parallel rank and the ``_PPP`` pipeline rank
    suffix appears only when the pipeline world size is larger than one.

    With the distributed optimizer, model and optimizer go to separate
    files, and the optimizer directory additionally carries the data
    parallel rank. Otherwise both live in one file and the two returned
    paths are identical.
    """
    directory = 'release' if release else 'iter_{:07d}'.format(iteration)

    tp_rank = mpu.get_tensor_model_parallel_rank()
    if mpu.get_pipeline_model_parallel_world_size() == 1:
        rank_dir = 'mp_rank_{:02d}'.format(tp_rank)
    else:
        rank_dir = 'mp_rank_{:02d}_{:03d}'.format(
            tp_rank, mpu.get_pipeline_model_parallel_rank())
    common_path = os.path.join(checkpoints_path, directory, rank_dir)

    if not use_distributed_optimizer:
        combined = os.path.join(common_path, "model_optim_rng.pt")
        return combined, combined

    model_name = os.path.join(common_path, "model_rng.pt")
    optim_dir = common_path + "_%03d" % mpu.get_data_parallel_rank()
    return model_name, os.path.join(optim_dir, "optim.pt")
112
113
114
115
116
117
118
119


def get_checkpoint_tracker_filename(checkpoints_path):
    """Return the path of the tracker file inside ``checkpoints_path``.

    The tracker file records the latest saved checkpoint (an iteration
    number or the string ``release``) so that training can restart from it.
    """
    tracker_basename = 'latest_checkpointed_iteration.txt'
    return os.path.join(checkpoints_path, tracker_basename)


120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
def read_metadata(tracker_filename):
    """Read the checkpoint tracker file and return ``(iteration, release)``.

    The file holds either a positive iteration number or the string
    ``release``. Any other content is reported via ``print_rank_0`` and
    the process exits.

    When torch.distributed is initialized, the iteration is MAX-reduced
    across all ranks so every rank resumes from the same iteration. A rank
    whose file disagreed prints a warning. Without an initialized process
    group (e.g. offline tools editing a checkpoint), the local value is
    returned unchanged and no CUDA tensor is created.
    """
    iteration = 0
    release = False
    with open(tracker_filename, 'r') as f:
        metastring = f.read().strip()
    try:
        iteration = int(metastring)
    except ValueError:
        release = metastring == 'release'
        if not release:
            print_rank_0('ERROR: Invalid metadata file {}. Exiting'.format(
                tracker_filename))
            sys.exit()
    assert iteration > 0 or release, 'error parsing metadata file {}'.format(
        tracker_filename)

    if not torch.distributed.is_initialized():
        return iteration, release

    # Get the max iteration retrieved across the ranks.
    iters_cuda = torch.cuda.LongTensor([iteration])
    torch.distributed.all_reduce(iters_cuda, op=torch.distributed.ReduceOp.MAX)
    max_iter = iters_cuda[0].item()

    # All ranks should have read the same iteration. If not, print a warning
    # and use the maximum iteration across all ranks.
    if iteration != max_iter:
        rank = torch.distributed.get_rank()
        print('WARNING: on rank {} found iteration {} in the '
              'metadata while max iteration across the ranks '
              'is {}, replacing it with max iteration.'.format(
                  rank, iteration, max_iter), flush=True)
    return max_iter, release


154
155
def get_rng_state():
    """Return RNG states, gathered across data-parallel ranks when needed.

    Each entry is a dict with the Python, NumPy, torch CPU, torch CUDA and
    model-parallel CUDA tracker states. The gather (a collective over the
    data-parallel group) is performed only when torch.distributed is
    initialized, the data-parallel world size exceeds one, and
    ``args.data_parallel_random_init`` is set. In that case the list has
    one entry per data-parallel rank; otherwise it holds only this rank's
    state.
    """
    args = get_args()

    local_state = dict(
        random_rng_state=random.getstate(),
        np_rng_state=np.random.get_state(),
        torch_rng_state=torch.get_rng_state(),
        cuda_rng_state=torch.cuda.get_rng_state(),
        rng_tracker_states=mpu.get_cuda_rng_tracker().get_states(),
    )

    should_gather = (torch.distributed.is_initialized()
                     and mpu.get_data_parallel_world_size() > 1
                     and args.data_parallel_random_init)
    if not should_gather:
        return [local_state]

    gathered = [None] * mpu.get_data_parallel_world_size()
    torch.distributed.all_gather_object(
        gathered, local_state, group=mpu.get_data_parallel_group())
    return gathered


180
def _optimizer_state_dict(optimizer, opt_param_scheduler):
    """Return a dict with the optimizer / scheduler states that are present."""
    state_dict = {}
    if optimizer is not None:
        state_dict['optimizer'] = optimizer.state_dict()
    if opt_param_scheduler is not None:
        state_dict['opt_param_scheduler'] = opt_param_scheduler.state_dict()
    return state_dict


def save_checkpoint(iteration, model, optimizer, opt_param_scheduler):
    """Save a model checkpoint.

    Files written (paths come from ``get_checkpoint_names``):
      * Without the distributed optimizer: data-parallel rank 0 writes ONE
        file containing args, iteration, model, RNG state and, unless
        ``--no-save-optim``, the optimizer/scheduler state.
      * With the distributed optimizer: data-parallel rank 0 writes the
        model/RNG file, and every rank writes its own optimizer file
        (whose path includes the data-parallel rank).
    Finally, global rank 0 updates the tracker file.

    Must be called on every rank: it contains barriers and, depending on
    args, an all-gather of RNG states.
    """
    args = get_args()

    # Strip wrapper modules so state dicts come from the underlying models.
    model = utils.unwrap_model(model)

    print_rank_0('saving checkpoint at iteration {:7d} to {}'.format(
        iteration, args.save))

    # Collect rng state across data parallel ranks. This may be a collective,
    # so every rank calls it even though only data-parallel rank 0 saves it.
    rng_state = get_rng_state()

    # Checkpoint file names.
    model_checkpoint_name, optim_checkpoint_name = \
        get_checkpoint_names(args.save, iteration, args.use_distributed_optimizer)

    save_optim = not args.no_save_optim

    # Save args, model, RNG. Only rank zero of the data parallel group writes.
    if not torch.distributed.is_initialized() \
       or mpu.get_data_parallel_rank() == 0:

        # Arguments, iteration, and model.
        state_dict = {}
        state_dict['args'] = args
        state_dict['checkpoint_version'] = 3.0
        state_dict['iteration'] = iteration
        if len(model) == 1:
            state_dict['model'] = model[0].state_dict_for_save_checkpoint()
        else:
            for i in range(len(model)):
                mpu.set_virtual_pipeline_model_parallel_rank(i)
                state_dict['model%d' % i] = model[i].state_dict_for_save_checkpoint()

        # RNG states.
        if not args.no_save_rng:
            state_dict["rng_state"] = rng_state

        # Without the distributed optimizer the model and optimizer paths are
        # the same file, so the optimizer state must be stored in this dict.
        # Writing it separately would overwrite the model weights just saved.
        if save_optim and not args.use_distributed_optimizer:
            state_dict.update(
                _optimizer_state_dict(optimizer, opt_param_scheduler))

        # Save.
        ensure_directory_exists(model_checkpoint_name)
        torch.save(state_dict, model_checkpoint_name)

    # With the distributed optimizer, each data-parallel rank writes its own
    # optimizer file (the path includes its data-parallel rank).
    if save_optim and args.use_distributed_optimizer:
        optim_state_dict = _optimizer_state_dict(optimizer, opt_param_scheduler)
        ensure_directory_exists(optim_checkpoint_name)
        torch.save(optim_state_dict, optim_checkpoint_name)

    # Wait so everyone is done (necessary)
    if torch.distributed.is_initialized():
        torch.distributed.barrier()

    print_rank_0('  successfully saved checkpoint at iteration {:7d} to {}'.format(
        iteration, args.save))

    # And update the latest iteration
    if not torch.distributed.is_initialized() or torch.distributed.get_rank() == 0:
        tracker_filename = get_checkpoint_tracker_filename(args.save)
        with open(tracker_filename, 'w') as f:
            f.write(str(iteration))

    # Wait so everyone is done (not necessary)
    if torch.distributed.is_initialized():
        torch.distributed.barrier()
255

256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
def _transpose_first_dim(t, num_splits, num_splits_first, model):
    input_shape = t.size()
    # We use a self_attention module but the values extracted aren't
    # specific to self attention so should work for cross attention as well
    while hasattr(model, 'module'):
        model = model.module
    attention_module = model.language_model.encoder.layers[0].self_attention
    hidden_size_per_attention_head = attention_module.hidden_size_per_attention_head
    num_attention_heads_per_partition = attention_module.num_attention_heads_per_partition
    if num_splits_first:
        """[num_splits * np * hn, h]
        -->(view) [num_splits, np, hn, h]
        -->(tranpose) [np, num_splits, hn, h]
        -->(view) [np * num_splits * hn, h] """

        intermediate_shape = \
            (num_splits, num_attention_heads_per_partition,
             hidden_size_per_attention_head) + input_shape[1:]

        t = t.view(*intermediate_shape)
        t = t.transpose(0, 1).contiguous()
    else:
        """[np * hn * num_splits, h]
        -->(view) [np, hn, num_splits, h]
        -->(tranpose) [np, num_splits, hn, h]
        -->(view) [np * num_splits * hn, h] """

        intermediate_shape = \
            (num_attention_heads_per_partition,
             hidden_size_per_attention_head, num_splits) +\
             input_shape[1:]

        t = t.view(*intermediate_shape)
        t = t.transpose(1, 2).contiguous()
    t = t.view(*input_shape)

    return t
293

Mostofa Patwary's avatar
Mostofa Patwary committed
294
295
296
297
298
def fix_query_key_value_ordering(model, checkpoint_version):
    """Fix up query/key/value matrix ordering for checkpoints older than 2.0.

    For version 0 and 1.0 checkpoints, every parameter whose name ends in
    ``.query_key_value.{weight,bias}`` (3 fused splits) or
    ``.key_value.{weight,bias}`` (2 fused splits) is permuted in place with
    ``_transpose_first_dim``. Any other version below 2.0 is reported and
    the process exits. Versions >= 2.0 leave the model untouched.
    """
    if not checkpoint_version < 2.0:
        return

    if isinstance(model, list):
        assert len(model)==1
        model = model[0]

    # (parameter-name suffixes, number of fused splits) for each fused layout.
    fused_layouts = (
        (('.query_key_value.weight', '.query_key_value.bias'), 3),
        (('.key_value.weight', '.key_value.bias'), 2),
    )
    for name, param in model.named_parameters():
        for suffixes, num_splits in fused_layouts:
            if not name.endswith(suffixes):
                continue
            if checkpoint_version == 0:
                splits_first = True
            elif checkpoint_version == 1.0:
                splits_first = False
            else:
                print_rank_0(f"Invalid checkpoint version {checkpoint_version}.")
                sys.exit()
            reordered = _transpose_first_dim(
                param.data, num_splits, splits_first, model)
            param.data.copy_(reordered)
    print_rank_0(" succesfully fixed query-key-values ordering for"
                " checkpoint version {}".format(checkpoint_version))

324
def _load_checkpoint_state_dicts(model_checkpoint_name, optim_checkpoint_name,
                                 load_optim):
    """Deserialize the model state dict and, if requested, the optimizer one.

    Without the distributed optimizer, ``get_checkpoint_names`` returns the
    same path for both. That file is then read once and the same dict is
    returned twice. When ``load_optim`` is False the optimizer file is not
    read at all (``None`` is returned for it), so checkpoints saved with
    --no-save-optim still load under --no-load-optim / --finetune.
    NOTE: torch.load unpickles the file -- only load trusted checkpoints.
    """
    model_state_dict = torch.load(model_checkpoint_name, map_location='cpu')
    if not load_optim:
        optim_state_dict = None
    elif optim_checkpoint_name == model_checkpoint_name:
        optim_state_dict = model_state_dict
    else:
        optim_state_dict = torch.load(optim_checkpoint_name, map_location='cpu')
    return model_state_dict, optim_state_dict


def load_checkpoint(model, optimizer, opt_param_scheduler, load_arg='load', strict=True):
    """Load a model checkpoint and return the iteration.

    Args:
        model: model or list of model chunks; wrappers are removed with
            ``utils.unwrap_model``.
        optimizer: optimizer whose state is restored, or None.
        opt_param_scheduler: scheduler whose state is restored, or None.
        load_arg: name of the ``args`` attribute holding the checkpoint
            directory (``args.load`` by default).
        strict (bool): whether to strictly enforce that the keys in
            :attr:`state_dict` of the checkpoint match the names of
            parameters and buffers in model.

    Returns:
        0 when no tracker file exists, for release checkpoints, or with
        --finetune; otherwise the iteration stored in the checkpoint.
    """
    args = get_args()
    load_dir = getattr(args, load_arg)

    model = utils.unwrap_model(model)

    # Read the tracker file and set the iteration.
    tracker_filename = get_checkpoint_tracker_filename(load_dir)

    # If no tracker file, return iteration zero.
    if not os.path.isfile(tracker_filename):
        print_rank_0('WARNING: could not find the metadata file {} '.format(
            tracker_filename))
        print_rank_0('    will not load any checkpoints and will start from '
                     'random')
        return 0

    # Otherwise, read the tracker file and either set the iteration or
    # mark it as a release checkpoint.
    iteration, release = read_metadata(tracker_filename)

    # Checkpoint.
    model_checkpoint_name, optim_checkpoint_name = \
        get_checkpoint_names(load_dir, iteration,
                             args.use_distributed_optimizer,
                             release)
    print_rank_0(f' loading checkpoint from {load_dir} at iteration {iteration}')

    # Optimizer state is only needed when it will actually be restored.
    load_optim = not release and not args.finetune and not args.no_load_optim

    # Load the checkpoint.
    try:
        model_state_dict, optim_state_dict = _load_checkpoint_state_dicts(
            model_checkpoint_name, optim_checkpoint_name, load_optim)
    except ModuleNotFoundError:
        # For backward compatibility: older checkpoints reference the
        # modules 'fp16.loss_scaler' / 'megatron.fp16.loss_scaler'. Alias
        # them to the deprecated implementation while loading, then remove
        # the aliases again.
        from megatron.fp16_deprecated import loss_scaler
        print_rank_0(' > deserializing using the old code structure ...')
        sys.modules['fp16.loss_scaler'] = sys.modules[
            'megatron.fp16_deprecated.loss_scaler']
        sys.modules['megatron.fp16.loss_scaler'] = sys.modules[
            'megatron.fp16_deprecated.loss_scaler']
        model_state_dict, optim_state_dict = _load_checkpoint_state_dicts(
            model_checkpoint_name, optim_checkpoint_name, load_optim)
        sys.modules.pop('fp16.loss_scaler', None)
        sys.modules.pop('megatron.fp16.loss_scaler', None)
    except BaseException as e:
        print_rank_0('could not load the checkpoint')
        print_rank_0(e)
        sys.exit()

    # Set checkpoint version (checkpoints without one are treated as 0).
    set_checkpoint_version(model_state_dict.get('checkpoint_version', 0))

    # Set iteration.
    if args.finetune or release:
        iteration = 0
    else:
        try:
            iteration = model_state_dict['iteration']
        except KeyError:
            try:  # Backward compatible with older checkpoints
                iteration = model_state_dict['total_iters']
            except KeyError:
                print_rank_0('A metadata file exists but unable to load '
                             'iteration from checkpoint {}, exiting'.format(
                                 model_checkpoint_name))
                sys.exit()

    # Check arguments.
    assert args.consumed_train_samples == 0
    assert args.consumed_valid_samples == 0
    if 'args' in model_state_dict:
        checkpoint_args = model_state_dict['args']
        check_checkpoint_args(checkpoint_args)
        args.consumed_train_samples = getattr(checkpoint_args,
                                              'consumed_train_samples', 0)
        update_num_microbatches(consumed_samples=args.consumed_train_samples)
        args.consumed_valid_samples = getattr(checkpoint_args,
                                              'consumed_valid_samples', 0)
    else:
        print_rank_0('could not find arguments in the checkpoint ...')

    # Model.
    if len(model) == 1:
        model[0].load_state_dict(model_state_dict['model'], strict=strict)
    else:
        for i in range(len(model)):
            mpu.set_virtual_pipeline_model_parallel_rank(i)
            model[i].load_state_dict(model_state_dict['model%d' % i], strict=strict)

    # Fix up query/key/value matrix ordering if needed
    checkpoint_version = get_checkpoint_version()
    print_rank_0(f' checkpoint version {checkpoint_version}')
    fix_query_key_value_ordering(model, checkpoint_version)

    # Optimizer.
    if load_optim:
        try:
            if optimizer is not None:
                optimizer.load_state_dict(optim_state_dict['optimizer'])
            if opt_param_scheduler is not None:
                if 'lr_scheduler' in optim_state_dict: # backward compatbility
                    opt_param_scheduler.load_state_dict(optim_state_dict['lr_scheduler'])
                else:
                    opt_param_scheduler.load_state_dict(optim_state_dict['opt_param_scheduler'])
        except KeyError:
            print_rank_0('Unable to load optimizer from checkpoint {}. '
                         'Specify --no-load-optim or --finetune to prevent '
                         'attempting to load the optimizer state, '
                         'exiting ...'.format(optim_checkpoint_name))
            sys.exit()

    # rng states.
    if not release and not args.finetune and not args.no_load_rng:
        try:
            if 'rng_state' in model_state_dict:
                # access rng_state for data parallel rank
                if args.data_parallel_random_init:
                    rng_state = model_state_dict['rng_state'][mpu.get_data_parallel_rank()]
                else:
                    rng_state = model_state_dict['rng_state'][0]
                random.setstate(rng_state['random_rng_state'])
                np.random.set_state(rng_state['np_rng_state'])
                torch.set_rng_state(rng_state['torch_rng_state'])
                torch.cuda.set_rng_state(rng_state['cuda_rng_state'])
                # Check for empty states array
                if not rng_state['rng_tracker_states']:
                    raise KeyError
                mpu.get_cuda_rng_tracker().set_states(
                    rng_state['rng_tracker_states'])
            else:  # backward compatability
                random.setstate(model_state_dict['random_rng_state'])
                np.random.set_state(model_state_dict['np_rng_state'])
                torch.set_rng_state(model_state_dict['torch_rng_state'])
                torch.cuda.set_rng_state(model_state_dict['cuda_rng_state'])
                # Check for empty states array
                if not model_state_dict['rng_tracker_states']:
                    raise KeyError
                mpu.get_cuda_rng_tracker().set_states(
                    model_state_dict['rng_tracker_states'])
        except KeyError:
            print_rank_0('Unable to load rng state from checkpoint {}. '
                         'Specify --no-load-rng or --finetune to prevent '
                         'attempting to load the rng state, '
                         'exiting ...'.format(model_checkpoint_name))
            sys.exit()

    # Some utilities want to load a checkpoint without distributed being initialized
    if torch.distributed.is_initialized():
        torch.distributed.barrier()

    print_rank_0(f'  successfully loaded checkpoint from {load_dir} '
                 f'at iteration {iteration}')

    return iteration
Neel Kant's avatar
Neel Kant committed
484
485


486
487
488
489
490
491
def load_biencoder_checkpoint(model, only_query_model=False,
        only_context_model=False, custom_load_path=None):
    """Selectively load retrieval models for indexing/retrieving
    from saved checkpoints.

    Args:
        model: model or list of model chunks; after ``utils.unwrap_model``
            it must contain exactly one chunk.
        only_query_model: drop the ``context_model`` weights before loading.
        only_context_model: drop the ``query_model`` weights before loading.
        custom_load_path: checkpoint directory to use instead of ``args.load``.

    Returns:
        The unwrapped model list.

    Calls ``torch.distributed.barrier()``, so torch.distributed must be
    initialized and every rank must call this function.
    """
    args = get_args()

    model = utils.unwrap_model(model)

    load_path = custom_load_path if custom_load_path is not None else args.load

    tracker_filename = get_checkpoint_tracker_filename(load_path)
    with open(tracker_filename, 'r') as f:
        iteration = int(f.read().strip())

    checkpoint_name, _ = get_checkpoint_names(load_path, iteration,
                                              args.use_distributed_optimizer,
                                              False)
    if mpu.get_data_parallel_rank() == 0:
        print('global rank {} is loading checkpoint {}'.format(
            torch.distributed.get_rank(), checkpoint_name))

    state_dict = torch.load(checkpoint_name, map_location='cpu')
    ret_state_dict = state_dict['model']

    if only_query_model:
        ret_state_dict.pop('context_model')
    if only_context_model:
        ret_state_dict.pop('query_model')

    assert len(model) == 1
    model[0].load_state_dict(ret_state_dict)
    torch.distributed.barrier()

    if mpu.get_data_parallel_rank() == 0:
        print(' successfully loaded {}'.format(checkpoint_name))

    return model
526