training.py 20.5 KB
Newer Older
1
# coding=utf-8
Mohammad's avatar
Mohammad committed
2
# Copyright (c) 2020, NVIDIA CORPORATION.  All rights reserved.
3
4
5
6
7
8
9
10
11
12
13
14
15
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

Mohammad's avatar
Mohammad committed
16
"""Pretrain utilities."""
17
18
19

from datetime import datetime
import math
Mohammad's avatar
Mohammad committed
20
import sys
21
22
23
24
import torch
from torch.nn.parallel.distributed import DistributedDataParallel as torchDDP
from apex.optimizers import FusedAdam as Adam

Neel Kant's avatar
Neel Kant committed
25
from megatron import get_args
Mohammad's avatar
Mohammad committed
26
27
from megatron import get_timers
from megatron import get_tensorboard_writer
28
from megatron import mpu
Neel Kant's avatar
Neel Kant committed
29
from megatron import print_rank_0
Mohammad's avatar
Mohammad committed
30
31
from megatron.checkpointing import load_checkpoint
from megatron.checkpointing import save_checkpoint
32
33
from megatron.fp16 import FP16_Module
from megatron.fp16 import FP16_Optimizer
Mohammad's avatar
Mohammad committed
34
from megatron.initialize import initialize_megatron
35
36
37
from megatron.learning_rates import AnnealingLR
from megatron.model import DistributedDataParallel as LocalDDP
from megatron.model import get_params_for_weight_decay_optimization
Neel Kant's avatar
Neel Kant committed
38
from megatron.model.realm_model import ICTBertModel
39
from megatron.utils import check_adlr_autoresume_termination
40
from megatron.utils import make_data_loader
41
from megatron.utils import report_memory
42
43


44
def pretrain(train_valid_test_dataset_provider, model_provider,
             forward_step_func, extra_args_provider=None, args_defaults=None):
    """Main training program.

    This function will run the following steps in the order provided:
        1) initialize Megatron.
        2) setup model, optimizer and lr schedule using the model_provider.
        3) call train_valid_test_dataset_provider to get train/val/test
           datasets.
        4) train the model using the forward_step_func.

    Arguments:
        train_valid_test_dataset_provider: a function that takes the size of
            train/valid/test dataset and returns `train, valid, test` datasets.
        model_provider: a function that returns a vanilla version of the
            model. By vanilla we mean a simple model on cpu with no fp16 or ddp.
        forward_step_func: a function that takes a `data iterator` and `model`,
            and returns a `loss` scalar with a dictionary with key:values being
            the info we would like to monitor during training, for example
            `lm-loss: value`. We also require that this function add
            `batch generator` to the timers class.
        extra_args_provider: a function that takes a parser and adds arguments
            to it. It is used for programs to add their own arguments.
        args_defaults: a dictionary from argument-name to argument-value. It
            is used to set already parsed arguments. When omitted, an empty
            dictionary is used.
    """

    # Use a fresh dictionary per call instead of a shared mutable default.
    if args_defaults is None:
        args_defaults = {}

    # Initialize Megatron, then get the arguments and timers.
    initialize_megatron(extra_args_provider=extra_args_provider,
                        args_defaults=args_defaults)

    args = get_args()
    timers = get_timers()

    # Model, optimizer, and learning rate.
    timers('model and optimizer').start()
    model, optimizer, lr_scheduler = setup_model_and_optimizer(model_provider)
    timers('model and optimizer').stop()

    # Data stuff.
    timers('train/valid/test data iterators').start()
    train_data_iterator, valid_data_iterator, test_data_iterator \
        = build_train_valid_test_data_iterators(
            train_valid_test_dataset_provider)
    timers('train/valid/test data iterators').stop()

    # Print setup timing.
    print_rank_0('done with setups ...')
    timers.log(['model and optimizer', 'train/valid/test data iterators'])
    print_rank_0('training ...')

    # `iteration` stays 0 when no training happens; it is used below both as
    # the logging step for validation and to decide whether to checkpoint.
    iteration = 0
    if args.do_train and args.train_iters > 0:
        iteration, _ = train(forward_step_func,
                             model, optimizer, lr_scheduler,
                             train_data_iterator, valid_data_iterator)

    if args.do_valid:
        prefix = 'the end of training for val data'
        evaluate_and_print_results(prefix, forward_step_func,
                                   valid_data_iterator, model,
                                   iteration, False)

    # Only save when at least one iteration has been recorded.
    if args.save and iteration != 0:
        save_checkpoint(iteration, model, optimizer, lr_scheduler)

    if args.do_test:
        # Run on test data.
        prefix = 'the end of training for test data'
        evaluate_and_print_results(prefix, forward_step_func,
                                   test_data_iterator, model,
                                   0, True)
115
116


Mohammad's avatar
Mohammad committed
117
def get_model(model_provider_func):
    """Build the model, move it to the GPU and wrap it for data parallelism.

    Arguments:
        model_provider_func: a function taking no arguments that returns the
            model to train.

    Returns:
        The model wrapped in either the torch or the local
        DistributedDataParallel implementation (and in FP16_Module first
        when `args.fp16` is set).

    Raises:
        NotImplementedError: if `args.DDP_impl` is neither 'torch' nor
            'local'.
    """
    args = get_args()

    # Ask the provider for the bare model.
    model = model_provider_func()

    # Report the parameter count from data parallel rank 0 only.
    if mpu.get_data_parallel_rank() == 0:
        model_parallel_rank = mpu.get_model_parallel_rank()
        num_params = sum([p.nelement() for p in model.parameters()])
        print(' > number of parameters on model parallel rank {}: {}'.format(
            model_parallel_rank, num_params), flush=True)

    # Place the model on the current CUDA device.
    model.cuda(torch.cuda.current_device())

    # Optional fp16 wrapper.
    if args.fp16:
        model = FP16_Module(model)

    # Wrap the model for distributed training.
    ddp_impl = args.DDP_impl
    if ddp_impl == 'torch':
        device = torch.cuda.current_device()
        model = torchDDP(model, device_ids=[device], output_device=device,
                         process_group=mpu.get_data_parallel_group())
    elif ddp_impl == 'local':
        model = LocalDDP(model)
    else:
        raise NotImplementedError('Unknown DDP implementation specified: {}. '
                                  'Exiting.'.format(args.DDP_impl))

    return model
149
150


Mohammad's avatar
Mohammad committed
151
def get_optimizer(model):
    """Create the Adam optimizer (fp16-wrapped when requested) for `model`.

    Arguments:
        model: the model, possibly wrapped in torchDDP, LocalDDP and/or
            FP16_Module.

    Returns:
        An Adam optimizer over the weight-decay / no-weight-decay parameter
        groups, wrapped in FP16_Optimizer when `args.fp16` is set.
    """
    args = get_args()

    # Strip the DDP / fp16 wrappers to reach the underlying module.
    wrapper_types = (torchDDP, LocalDDP, FP16_Module)
    core_model = model
    while isinstance(core_model, wrapper_types):
        core_model = core_model.module

    # Parameter groups (weight decay and non-decay).
    param_groups = get_params_for_weight_decay_optimization(core_model)

    # Give every parameter a `model_parallel` attribute, defaulting to False.
    for group in param_groups:
        for weight in group['params']:
            if not hasattr(weight, 'model_parallel'):
                weight.model_parallel = False

    optimizer = Adam(param_groups, lr=args.lr, weight_decay=args.weight_decay)
    if not args.fp16:
        return optimizer

    # fp16 training: wrap with loss scaling.
    return FP16_Optimizer(
        optimizer,
        static_loss_scale=args.loss_scale,
        dynamic_loss_scale=args.dynamic_loss_scale,
        dynamic_loss_args={
            'scale_window': args.loss_scale_window,
            'min_scale': args.min_scale,
            'delayed_shift': args.hysteresis})


Mohammad's avatar
Mohammad committed
182
def get_learning_rate_scheduler(optimizer):
    """Create the AnnealingLR schedule for `optimizer`.

    The schedule runs over `args.lr_decay_iters` iterations, or over
    `args.train_iters` when that is None (at least 1 either way). The warmup
    length is `args.warmup` times that number.

    Arguments:
        optimizer: the optimizer whose learning rate is scheduled.

    Returns:
        The AnnealingLR instance.
    """
    args = get_args()

    # Length of the schedule, never less than one iteration.
    decay_iters = args.lr_decay_iters
    if decay_iters is None:
        decay_iters = args.train_iters
    decay_iters = max(1, decay_iters)

    return AnnealingLR(
        optimizer,
        start_lr=args.lr,
        warmup_iter=args.warmup * decay_iters,
        total_iters=decay_iters,
        decay_style=args.lr_decay_style,
        last_iter=0,
        min_lr=args.min_lr,
        use_checkpoint_lr_scheduler=args.use_checkpoint_lr_scheduler,
        override_lr_scheduler=args.override_lr_scheduler)


Mohammad's avatar
Mohammad committed
208
def setup_model_and_optimizer(model_provider_func):
    """Build model, optimizer and lr scheduler, then restore or initialize.

    Sets `args.iteration` to the value returned by `load_checkpoint` when
    `args.load` is given, and to 0 otherwise.

    Arguments:
        model_provider_func: a function that returns the model to train.

    Returns:
        The tuple `(model, optimizer, lr_scheduler)`.
    """
    args = get_args()

    model = get_model(model_provider_func)
    optimizer = get_optimizer(model)
    lr_scheduler = get_learning_rate_scheduler(optimizer)

    # Resume from a checkpoint when one is given, otherwise start at zero.
    if args.load is None:
        args.iteration = 0
    else:
        args.iteration = load_checkpoint(model, optimizer, lr_scheduler)

    # Follow the `.module` chain down to the innermost model.
    bare_model = model
    while hasattr(bare_model, 'module'):
        bare_model = bare_model.module

    # On a fresh run, models exposing `init_state_dict_from_bert` use it to
    # initialize themselves.
    if args.iteration == 0:
        if hasattr(bare_model, 'init_state_dict_from_bert'):
            print("Initializing ICT from pretrained BERT model", flush=True)
            bare_model.init_state_dict_from_bert()

    return model, optimizer, lr_scheduler


Mohammad's avatar
Mohammad committed
233
def backward_step(optimizer, model, loss):
    """Compute gradients for `loss`, all-reduce and clip them.

    Arguments:
        optimizer: the optimizer (an FP16_Optimizer when `args.fp16` is set).
        model: the (wrapped) model being trained.
        loss: the loss to back-propagate.
    """
    args = get_args()
    timers = get_timers()
    use_fp16 = args.fp16

    # Clear old gradients, then run the backward pass.
    optimizer.zero_grad(set_grads_to_None=True)
    if use_fp16:
        optimizer.backward(loss, update_master_grads=False)
    else:
        loss.backward()

    # The local DDP implementation needs an explicit gradient all-reduce.
    if args.DDP_impl == 'local':
        timers('allreduce').start()
        model.allreduce_params(reduce_after=False,
                               fp32_allreduce=args.fp32_allreduce)
        timers('allreduce').stop()

    # With fp16, update the master gradients after the all-reduce.
    if use_fp16:
        optimizer.update_master_grads()

    # Gradient clipping guards against exploding gradients.
    if args.clip_grad > 0:
        if use_fp16:
            optimizer.clip_master_grads(args.clip_grad)
        else:
            mpu.clip_grad_norm(model.parameters(), args.clip_grad)


Mohammad's avatar
Mohammad committed
264
265
def train_step(forward_step_func, data_iterator,
               model, optimizer, lr_scheduler):
    """Run one forward / backward / optimizer update.

    Arguments:
        forward_step_func: called as `forward_step_func(data_iterator, model)`
            and expected to return `(loss, loss_reduced)`.
        data_iterator: the training data iterator.
        model: the model being trained.
        optimizer: the optimizer.
        lr_scheduler: the learning rate scheduler.

    Returns:
        `(loss_reduced, skipped_iter)` where `skipped_iter` is 1 when the
        scheduler step was skipped because of an fp16 overflow, else 0.
    """
    args = get_args()
    timers = get_timers()

    # Forward pass.
    timers('forward').start()
    loss, loss_reduced = forward_step_func(data_iterator, model)
    timers('forward').stop()

    # Backward pass: gradients, cross-process reduction, clipping.
    timers('backward').start()
    backward_step(optimizer, model, loss)
    timers('backward').stop()

    # Parameter update.
    timers('optimizer').start()
    optimizer.step()
    timers('optimizer').stop()

    # Advance the learning rate unless the fp16 optimizer overflowed.
    if args.fp16 and optimizer.overflow:
        skipped_iter = 1
    else:
        lr_scheduler.step()
        skipped_iter = 0

    return loss_reduced, skipped_iter


Mohammad Shoeybi's avatar
Mohammad Shoeybi committed
295
def training_log(loss_dict, total_loss_dict, learning_rate, iteration,
                 loss_scale, report_memory_flag):
    """Accumulate losses and emit Tensorboard / console training logs.

    Arguments:
        loss_dict: losses of the current iteration, keyed by name.
        total_loss_dict: running sums of the losses; updated in place and
            reset to 0.0 after every console report.
        learning_rate: the current learning rate.
        iteration: the current iteration number.
        loss_scale: the current loss scale (only logged when `args.fp16`).
        report_memory_flag: whether memory usage should be reported at the
            next console report.

    Returns:
        The updated `report_memory_flag` (False once memory was reported).
    """
    args = get_args()
    timers = get_timers()
    writer = get_tensorboard_writer()
    log_interval = args.log_interval

    # Fold this iteration's losses into the running totals.
    for name, value in loss_dict.items():
        total_loss_dict[name] = total_loss_dict.get(name, 0.) + value

    # Only log the timers that were actually registered.
    candidate_timers = ('forward', 'backward', 'allreduce', 'optimizer',
                        'batch generator')
    timers_to_log = [name for name in candidate_timers
                     if name in timers.timers]

    # Tensorboard output happens on global rank 0 only.
    use_tensorboard = bool(writer) and torch.distributed.get_rank() == 0
    if use_tensorboard:
        writer.add_scalar('learning_rate', learning_rate, iteration)
        for name, value in loss_dict.items():
            writer.add_scalar(name, value, iteration)
        if args.fp16:
            writer.add_scalar('loss_scale', loss_scale, iteration)
        # Iterations since the last interval boundary (a full interval when
        # sitting exactly on one).
        normalizer = iteration % log_interval or log_interval
        timers.write(timers_to_log, writer, iteration,
                     normalizer=normalizer)

    # Console report only at interval boundaries.
    if iteration % log_interval != 0:
        return report_memory_flag

    elapsed_time = timers('interval time').elapsed()
    if use_tensorboard:
        writer.add_scalar('iteration_time',
                          elapsed_time / log_interval, iteration)

    pieces = [
        ' iteration {:8d}/{:8d} |'.format(iteration, args.train_iters),
        ' elapsed time per iteration (ms): {:.1f} |'.format(
            elapsed_time * 1000.0 / log_interval),
        ' learning rate: {:.3E} |'.format(learning_rate),
    ]
    for name in total_loss_dict:
        average = total_loss_dict[name].item() / log_interval
        pieces.append(' {}: {:.6E} |'.format(name, average))
        # Start a new accumulation window.
        total_loss_dict[name] = 0.0
    if args.fp16:
        pieces.append(' loss scale: {:.1f} |'.format(loss_scale))
    print_rank_0(''.join(pieces))

    if report_memory_flag:
        report_memory('after {} iterations'.format(iteration))
        report_memory_flag = False
    timers.log(timers_to_log, normalizer=log_interval)

    return report_memory_flag


356
def train(forward_step_func, model, optimizer, lr_scheduler,
          train_data_iterator, valid_data_iterator):
    """Run the training loop from `args.iteration` up to `args.train_iters`.

    Each iteration performs a training step, logs, and then, at the
    configured intervals, checks autoresume termination, saves a checkpoint,
    evaluates on validation data and possibly exits the program.

    Arguments:
        forward_step_func: the forward step function passed to `train_step`
            and to evaluation.
        model: the model being trained.
        optimizer: the optimizer.
        lr_scheduler: the learning rate scheduler.
        train_data_iterator: iterator over training data.
        valid_data_iterator: iterator over validation data.

    Returns:
        `(iteration, skipped_iters)`: the final iteration number and the
        number of iterations reported as skipped by `train_step`.
    """
    args = get_args()
    timers = get_timers()

    # Training mode (enables dropout).
    model.train()

    # Running loss totals shared with the logger.
    total_loss_dict = {}

    iteration = args.iteration
    skipped_iters = 0

    timers('interval time').start()
    report_memory_flag = True
    while iteration < args.train_iters:
        loss_dict, skipped_iter = train_step(
            forward_step_func, train_data_iterator,
            model, optimizer, lr_scheduler)
        skipped_iters += skipped_iter
        iteration += 1

        # Logging.
        loss_scale = optimizer.loss_scale if args.fp16 else None
        current_lr = optimizer.param_groups[0]['lr']
        report_memory_flag = training_log(
            loss_dict, total_loss_dict, current_lr,
            iteration, loss_scale, report_memory_flag)

        # Autoresume.
        if args.adlr_autoresume and \
           (iteration % args.adlr_autoresume_interval == 0):
            check_adlr_autoresume_termination(
                iteration, model, optimizer, lr_scheduler)

        # Periodic checkpoint.
        if args.save and args.save_interval and \
           iteration % args.save_interval == 0:
            save_checkpoint(iteration, model, optimizer, lr_scheduler)

        # Periodic validation.
        if args.do_valid and args.eval_interval and \
           iteration % args.eval_interval == 0:
            evaluate_and_print_results(
                'iteration {}'.format(iteration), forward_step_func,
                valid_data_iterator, model, iteration, False)

        # Scheduled exit: wait at the barrier, then leave the program.
        if args.exit_interval and iteration % args.exit_interval == 0:
            torch.distributed.barrier()
            time_str = datetime.now().strftime('%Y-%m-%d %H:%M:%S')
            rank = torch.distributed.get_rank()
            print_rank_0('rank: {} | time: {} | exiting the program at '
                         'iteration {}'.format(rank, time_str, iteration))
            sys.exit()

    return iteration, skipped_iters


Mohammad's avatar
Mohammad committed
422
def evaluate(forward_step_func, data_iterator, model, verbose=False):
    """Average the losses over `args.eval_iters` evaluation iterations.

    Arguments:
        forward_step_func: called as `forward_step_func(data_iterator, model)`;
            the second element of its result is the dictionary of losses.
        data_iterator: iterator over the evaluation data.
        model: the model to evaluate; it is put back into train mode before
            returning.
        verbose: when True, print progress every `args.log_interval`
            iterations.

    Returns:
        A dictionary mapping each loss name to its mean over the evaluation
        iterations.
    """
    args = get_args()
    num_eval_iters = args.eval_iters

    # Evaluation mode (disables dropout).
    model.eval()

    total_loss_dict = {}

    with torch.no_grad():
        for step in range(1, num_eval_iters + 1):
            if verbose and step % args.log_interval == 0:
                print_rank_0('Evaluating iter {}/{}'.format(step,
                                                            num_eval_iters))
            # Forward evaluation; only the loss dictionary is needed.
            _, loss_dict = forward_step_func(data_iterator, model)
            # Accumulate the losses across iterations.
            for name, value in loss_dict.items():
                total_loss_dict[name] = total_loss_dict.get(name, 0.) + value

    # Restore training mode.
    model.train()

    # Turn the sums into means.
    for name in total_loss_dict:
        total_loss_dict[name] /= num_eval_iters

    return total_loss_dict


def evaluate_and_print_results(prefix, forward_step_func,
                               data_iterator, model,
                               iteration, verbose=False):
    """Evaluate the model, then print and log each loss with its perplexity.

    Arguments:
        prefix: text describing when the evaluation happens; inserted into
            the printed header.
        forward_step_func: the forward step function handed to `evaluate`.
        data_iterator: iterator over the evaluation data.
        model: the model to evaluate.
        iteration: the step under which values are written to Tensorboard.
        verbose: forwarded to `evaluate`.
    """
    writer = get_tensorboard_writer()

    total_loss_dict = evaluate(forward_step_func, data_iterator, model, verbose)

    parts = [' validation loss at {} | '.format(prefix)]
    for name in total_loss_dict:
        loss_value = total_loss_dict[name].item()
        # The exponent is capped at 20 before computing the perplexity.
        ppl = math.exp(min(20, loss_value))
        parts.append('{} value: {:.6E} | '.format(name, loss_value))
        parts.append('{} PPL: {:.6E} | '.format(name, ppl))
        # Tensorboard output on global rank 0 only.
        if writer and torch.distributed.get_rank() == 0:
            writer.add_scalar('{} value'.format(name), loss_value, iteration)
            writer.add_scalar('{} ppl'.format(name), ppl, iteration)
    string = ''.join(parts)

    # Frame the report with dashed rules one character longer than the text.
    rule = '-' * (len(string) + 1)
    print_rank_0(rule)
    print_rank_0(string)
    print_rank_0(rule)


477
478
479
def build_train_valid_test_data_iterators(
        build_train_valid_test_datasets_provider):
    """Build the train, validation and test data iterators.

    Datasets and data loaders are only built on model parallel rank 0. The
    resulting do_train / do_valid / do_test flags are broadcast within the
    model parallel group and stored on `args.do_train`, `args.do_valid` and
    `args.do_test`.

    Arguments:
        build_train_valid_test_datasets_provider: a function that takes a
            list with the requested number of train, validation and test
            samples and returns the `train, valid, test` datasets.

    Returns:
        The tuple `(train_data_iterator, valid_data_iterator,
        test_data_iterator)`; an entry is None when the corresponding data
        loader was not built.
    """
    args = get_args()

    (train_dataloader, valid_dataloader, test_dataloader) = (None, None, None)

    # `train()` treats a falsy eval_interval as "no periodic evaluation", so
    # do the same here instead of dividing by zero / None.
    eval_interval = args.eval_interval

    print_rank_0('> building train, validation, and test datasets ...')
    # Data loader only on rank 0 of each model parallel group.
    if mpu.get_model_parallel_rank() == 0:
        # Rank, size, and global batch size.
        data_parallel_size = mpu.get_data_parallel_world_size()
        global_batch_size = args.batch_size * data_parallel_size

        # Number of train/valid/test samples. Validation runs once per
        # eval_interval plus once at the end of training.
        train_iters = args.train_iters
        if eval_interval:
            num_periodic_evals = train_iters // eval_interval
        else:
            num_periodic_evals = 0
        eval_iters = (num_periodic_evals + 1) * args.eval_iters
        test_iters = args.eval_iters
        train_val_test_num_samples = [train_iters * global_batch_size,
                                      eval_iters * global_batch_size,
                                      test_iters * global_batch_size]
        print_rank_0(' > datasets target sizes (minimum size):')
        print_rank_0('    train:      {}'.format(train_val_test_num_samples[0]))
        print_rank_0('    validation: {}'.format(train_val_test_num_samples[1]))
        print_rank_0('    test:       {}'.format(train_val_test_num_samples[2]))

        # Build the datasets.
        train_ds, valid_ds, test_ds = build_train_valid_test_datasets_provider(
            train_val_test_num_samples)

        # Build dataloaders.
        train_dataloader = make_data_loader(train_ds)
        valid_dataloader = make_data_loader(valid_ds)
        test_dataloader = make_data_loader(test_ds)

        # Flags to know if we need to do training/validation/testing.
        do_train = train_dataloader is not None and args.train_iters > 0
        do_valid = valid_dataloader is not None and args.eval_iters > 0
        do_test = test_dataloader is not None and args.eval_iters > 0
        # Pack the flags into a tensor so they can be broadcast.
        flags = torch.cuda.LongTensor(
            [int(do_train), int(do_valid), int(do_test)])
    else:
        # Placeholder values, overwritten by the broadcast below.
        flags = torch.cuda.LongTensor([0, 0, 0])

    # Broadcast the flags within the model parallel group.
    torch.distributed.broadcast(flags,
                                mpu.get_model_parallel_src_rank(),
                                group=mpu.get_model_parallel_group())
    args.do_train = flags[0].item()
    args.do_valid = flags[1].item()
    args.do_test = flags[2].item()

    # Shift the start iterations.
    if train_dataloader is not None:
        train_dataloader.batch_sampler.start_iter = args.iteration % \
            len(train_dataloader)
        print_rank_0('setting training data start iteration to {}'.
                     format(train_dataloader.batch_sampler.start_iter))
    if valid_dataloader is not None:
        # Validation iterations already consumed by periodic evaluations.
        if eval_interval:
            completed_evals = args.iteration // eval_interval
        else:
            completed_evals = 0
        start_iter_val = completed_evals * args.eval_iters
        valid_dataloader.batch_sampler.start_iter = start_iter_val % \
            len(valid_dataloader)
        print_rank_0('setting validation data start iteration to {}'.
                     format(valid_dataloader.batch_sampler.start_iter))

    # Build iterators.
    if train_dataloader is not None:
        train_data_iterator = iter(train_dataloader)
    else:
        train_data_iterator = None

    if valid_dataloader is not None:
        valid_data_iterator = iter(valid_dataloader)
    else:
        valid_data_iterator = None

    if test_dataloader is not None:
        test_data_iterator = iter(test_dataloader)
    else:
        test_data_iterator = None

    return train_data_iterator, valid_data_iterator, test_data_iterator