training.py 41.1 KB
Newer Older
1
# coding=utf-8
Mohammad's avatar
Mohammad committed
2
# Copyright (c) 2020, NVIDIA CORPORATION.  All rights reserved.
3
4
5
6
7
8
9
10
11
12
13
14
15
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

Mohammad's avatar
Mohammad committed
16
"""Pretrain utilities."""
17
18
19

from datetime import datetime
import math
Mohammad's avatar
Mohammad committed
20
import sys
21
22
23
24
import time
# The earliest we can measure the start time.
_TRAIN_START_TIME = time.time()

25
26
27
28
import torch
from torch.nn.parallel.distributed import DistributedDataParallel as torchDDP
from apex.optimizers import FusedAdam as Adam

Neel Kant's avatar
Neel Kant committed
29
from megatron import get_args
Mohammad's avatar
Mohammad committed
30
31
from megatron import get_timers
from megatron import get_tensorboard_writer
32
from megatron import get_current_global_batch_size
mohammad's avatar
mohammad committed
33
from megatron import get_num_microbatches
mohammad's avatar
mohammad committed
34
from megatron import is_last_rank
mohammad's avatar
mohammad committed
35
from megatron import update_num_microbatches
36
from megatron import mpu
Neel Kant's avatar
Neel Kant committed
37
from megatron import print_rank_0
38
from megatron import print_rank_last
Mohammad's avatar
Mohammad committed
39
40
from megatron.checkpointing import load_checkpoint
from megatron.checkpointing import save_checkpoint
41
42
from megatron.fp16 import FP16_Module
from megatron.fp16 import FP16_Optimizer
Mohammad's avatar
Mohammad committed
43
from megatron.initialize import initialize_megatron
44
45
46
from megatron.learning_rates import AnnealingLR
from megatron.model import DistributedDataParallel as LocalDDP
from megatron.model import get_params_for_weight_decay_optimization
Neel Kant's avatar
Neel Kant committed
47
from megatron.model.realm_model import ICTBertModel
48
from megatron.utils import check_adlr_autoresume_termination
49
from megatron.data.data_loaders import build_pretraining_data_loader
50
from megatron.utils import report_memory
51
52


53
54
55
56
57
58
59
def print_datetime(string):
    """Print ``string`` together with the current local date/time on rank 0.

    Note that this call will sync across all ranks: it starts with a
    ``torch.distributed.barrier()``, so every rank must reach it.
    """
    torch.distributed.barrier()
    stamp = datetime.now().strftime('%Y-%m-%d %H:%M:%S')
    message = '[' + string + '] datetime: ' + stamp + ' '
    print_rank_0(message)


60
def pretrain(train_valid_test_dataset_provider, model_provider,
             forward_step_func, extra_args_provider=None, args_defaults=None):
    """Main training program.

    This function runs the following steps, in order:
        1) initialize Megatron.
        2) set up the model, optimizer and learning-rate schedule using
           `model_provider`.
        3) call `train_valid_test_dataset_provider` to get the train/valid/test
           datasets.
        4) train the model using `forward_step_func`.

    After training it optionally evaluates on the validation data
    (``args.do_valid``), saves a checkpoint (``args.save`` and at least one
    training iteration ran), and evaluates on the test data
    (``args.do_test``).

    Arguments:
        train_valid_test_dataset_provider: a function that takes the size of
            train/valid/test dataset and returns `train, valid, test` datasets.
        model_provider: a function that returns a vanilla version of the
            model. By vanilla we mean a simple model on cpu with no fp16 or ddp.
        forward_step_func: a function called as
            `forward_step_func(data_iterator, model, input_tensor)`, where
            `input_tensor` is None on the first pipeline stage. On the last
            pipeline stage it returns a `loss` scalar together with a
            dictionary with key:values being the info we would like to
            monitor during training, for example `lm-loss: value`; on other
            stages it returns the output tensor to send to the next stage.
            We also require that this function add `batch generator` to the
            timers class.
        extra_args_provider: a function that takes a parser and adds arguments
            to it. It is used for programs to add their own arguments.
        args_defaults: a dictionary from argument-name to argument-value. It
            is used to set already-parsed arguments. ``None`` (the default)
            is treated as an empty dictionary.
    """
    if args_defaults is None:
        # Avoid sharing one mutable default dict across calls.
        args_defaults = {}

    # Initialize Megatron.
    initialize_megatron(extra_args_provider=extra_args_provider,
                        args_defaults=args_defaults)

    # Use the earliest start time across all ranks (all-reduce with MIN), so
    # the reported initialization time is the largest one seen by any rank.
    # This is closer to what the scheduler sees (outside of image ...
    # launches).
    global _TRAIN_START_TIME
    start_time_tensor = torch.cuda.FloatTensor([_TRAIN_START_TIME])
    torch.distributed.all_reduce(start_time_tensor,
                                 op=torch.distributed.ReduceOp.MIN)
    _TRAIN_START_TIME = start_time_tensor.item()
    print_rank_0('time to initialize megatron (seconds): {:.3f}'.format(
        time.time() - _TRAIN_START_TIME))
    print_datetime('after megatron is initialized')

    args = get_args()
    timers = get_timers()

    # Model, optimizer, and learning rate.
    timers('model and optimizer').start()
    model, optimizer, lr_scheduler = setup_model_and_optimizer(model_provider)
    timers('model and optimizer').stop()
    print_datetime('after model, optimizer, and learning rate '
                   'scheduler are built')

    # Data stuff.
    timers('train/valid/test data iterators').start()
    train_data_iterator, valid_data_iterator, test_data_iterator \
        = build_train_valid_test_data_iterators(
            train_valid_test_dataset_provider)
    timers('train/valid/test data iterators').stop()
    print_datetime('after dataloaders are built')

    # Print setup timing.
    print_rank_0('done with setups ...')
    timers.log(['model and optimizer', 'train/valid/test data iterators'])
    print_rank_0('training ...')

    iteration = 0
    if args.do_train and args.train_iters > 0:
        iteration = train(forward_step_func,
                          model, optimizer, lr_scheduler,
                          train_data_iterator, valid_data_iterator)
    print_datetime('after training is done')

    if args.do_valid:
        prefix = 'the end of training for val data'
        evaluate_and_print_results(prefix, forward_step_func,
                                   valid_data_iterator, model,
                                   iteration, False)

    # Only save if training actually advanced at least one iteration.
    if args.save and iteration != 0:
        save_checkpoint(iteration, model, optimizer, lr_scheduler)

    if args.do_test:
        # Run on test data.
        # NOTE(review): iteration is reported as 0 here, unlike the
        # validation call above -- confirm this is intended.
        prefix = 'the end of training for test data'
        evaluate_and_print_results(prefix, forward_step_func,
                                   test_data_iterator, model,
                                   0, True)
147

148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
def update_train_iters(args):
    """Derive ``args.train_iters`` for sample-based training.

    If ``args.train_iters`` is already set (iteration-based training), this is
    a no-op. Otherwise the iteration count is computed from
    ``args.train_samples``:

    * constant batch size (``args.rampup_batch_size is None``):
      ``train_samples // global_batch_size``;
    * batch-size rampup: ramp-up iterations are counted by stepping the
      micro-batch calculator until more than ``rampup_batch_size[2]`` samples
      have been consumed. The calculator is then reset to 0 consumed samples,
      and the remaining samples are divided by the full
      ``global_batch_size``. Any partial final batch is dropped.

    The result is written to ``args.train_iters`` and printed on rank 0.
    """
    # For iteration-based training, we don't need to do anything
    if args.train_iters:
        return

    if args.rampup_batch_size is None:
        # Constant batch size with sample-based training.
        args.train_iters = args.train_samples // args.global_batch_size
    else:
        # Sample based training with rampup batch size.
        iterations = 0
        consumed_samples = 0
        # Rampup phase: simulate the batch-size schedule without the
        # consistency checks used during real training.
        while consumed_samples <= int(args.rampup_batch_size[2]):
            update_num_microbatches(consumed_samples, consistency_check=False)
            consumed_samples += get_current_global_batch_size()
            iterations += 1
        # Reset the micro-batch calculator, which the simulation above moved.
        update_num_microbatches(0, consistency_check=False)
        # Constant phase.
        # Note that we throw away any partial last batch.
        iterations += (args.train_samples - consumed_samples) // \
                      args.global_batch_size
        args.train_iters = iterations

    print_rank_0('setting training iterations to {}'.format(args.train_iters))

177

Mohammad's avatar
Mohammad committed
178
def get_model(model_provider_func):
    """Build the model and wrap it for the GPU, fp16 and data parallelism.

    The model is built on the CPU via ``model_provider_func``. Its parameter
    count is printed from data-parallel rank 0. It is then moved to the
    current CUDA device and, if ``args.fp16`` is set, wrapped in
    ``FP16_Module``. Finally it is wrapped in either torch DDP or local DDP,
    according to ``args.DDP_impl``.

    Raises:
        NotImplementedError: if ``args.DDP_impl`` is neither 'torch' nor
            'local'.
    """
    args = get_args()

    # Build model on cpu.
    model = model_provider_func()

    # Print number of parameters.
    if mpu.get_data_parallel_rank() == 0:
        print(' > number of parameters on (tensor, pipeline) '
              'model parallel rank ({}, {}): {}'.format(
                  mpu.get_tensor_model_parallel_rank(),
                  mpu.get_pipeline_model_parallel_rank(),
                  sum(p.nelement() for p in model.parameters())),
              flush=True)

    # GPU allocation.
    model.cuda(torch.cuda.current_device())

    # Fp16 conversion.
    if args.fp16:
        model = FP16_Module(model)

    if args.DDP_impl == 'torch':
        i = torch.cuda.current_device()
        # Gradients are reduced only within the data-parallel group.
        model = torchDDP(model, device_ids=[i], output_device=i,
                         process_group=mpu.get_data_parallel_group())
        return model
    if args.DDP_impl == 'local':
        model = LocalDDP(model)
        return model

    raise NotImplementedError('Unknown DDP implementation specified: {}. '
                              'Exiting.'.format(args.DDP_impl))
211
212


Mohammad's avatar
Mohammad committed
213
def get_optimizer(model):
    """Set up the optimizer.

    Parameter groups (weight decay and non-decay) are built from the bare
    model, with the DDP / fp16 wrappers stripped off. Parameters without a
    ``tensor_model_parallel`` attribute are marked as not tensor-parallel.
    A FusedAdam optimizer is then created. If ``args.fp16`` is set, it is
    wrapped in ``FP16_Optimizer`` using the static/dynamic loss-scale
    settings from ``args``.
    """
    args = get_args()

    # Build parameter groups (weight decay and non-decay) from the model
    # underneath any DDP / FP16_Module wrappers.
    while isinstance(model, (torchDDP, LocalDDP, FP16_Module)):
        model = model.module
    param_groups = get_params_for_weight_decay_optimization(model)

    # Add model parallel attribute if it is not set.
    for param_group in param_groups:
        for param in param_group['params']:
            if not hasattr(param, 'tensor_model_parallel'):
                param.tensor_model_parallel = False

    # Use Adam.
    optimizer = Adam(param_groups, lr=args.lr, weight_decay=args.weight_decay,
                     betas=(args.adam_beta1, args.adam_beta2),
                     eps=args.adam_eps)

    # Wrap into fp16 optimizer.
    if args.fp16:
        optimizer = FP16_Optimizer(optimizer,
                                   static_loss_scale=args.loss_scale,
                                   dynamic_loss_scale=args.dynamic_loss_scale,
                                   dynamic_loss_args={
                                       'scale_window': args.loss_scale_window,
                                       'min_scale': args.min_scale,
                                       'delayed_shift': args.hysteresis})

    return optimizer


Mohammad's avatar
Mohammad committed
245
def get_learning_rate_scheduler(optimizer):
    """Build the learning rate scheduler.

    Warmup and decay lengths passed to ``AnnealingLR`` are measured in
    samples:

    * iteration-based training (``args.train_iters``): iteration counts are
      multiplied by ``args.global_batch_size``. ``args.lr_decay_iters``
      defaults to ``args.train_iters``.
    * sample-based training (``args.train_samples``): sample counts are used
      directly. ``args.lr_decay_samples`` defaults to ``args.train_samples``,
      and ``args.train_iters`` is derived via ``update_train_iters``.

    In both modes ``args.lr_warmup_fraction``, when set, takes precedence
    over the explicit warmup length.

    Raises:
        Exception: if neither ``args.train_iters`` nor ``args.train_samples``
            is set.
    """
    args = get_args()

    if args.train_iters:
        # Iteration-based training.
        if args.lr_decay_iters is None:
            args.lr_decay_iters = args.train_iters
        decay_steps = args.lr_decay_iters * args.global_batch_size
        if args.lr_warmup_fraction is not None:
            warmup_steps = args.lr_warmup_fraction * decay_steps
        else:
            warmup_steps = args.lr_warmup_iters * args.global_batch_size
    elif args.train_samples:
        # Sample-based training.
        # We need to set training iters for later use. Technically
        # we need to adjust the training samples too (due to last
        # batch being incomplete) but we leave it as is for now.
        update_train_iters(args)
        if args.lr_decay_samples is None:
            args.lr_decay_samples = args.train_samples
        decay_steps = args.lr_decay_samples
        if args.lr_warmup_fraction is not None:
            warmup_steps = args.lr_warmup_fraction * decay_steps
        else:
            warmup_steps = args.lr_warmup_samples
    else:
        raise Exception(
            'either train-iters or train-samples should be provided.')

    lr_scheduler = AnnealingLR(
        optimizer,
        max_lr=args.lr,
        min_lr=args.min_lr,
        warmup_steps=warmup_steps,
        decay_steps=decay_steps,
        decay_style=args.lr_decay_style,
        use_checkpoint_lr_scheduler=args.use_checkpoint_lr_scheduler,
        override_lr_scheduler=args.override_lr_scheduler)

    return lr_scheduler


Mohammad's avatar
Mohammad committed
288
def setup_model_and_optimizer(model_provider_func):
    """Set up the model, optimizer and learning-rate scheduler.

    If ``args.load`` is set, a checkpoint is loaded (timed between two
    barriers) and ``args.iteration`` is set from it; otherwise
    ``args.iteration`` is 0. When starting from iteration 0, a model exposing
    ``init_state_dict_from_bert`` (e.g. ICT) is initialized from a
    pretrained BERT via that method.

    Returns:
        (model, optimizer, lr_scheduler)
    """
    args = get_args()

    model = get_model(model_provider_func)
    optimizer = get_optimizer(model)
    lr_scheduler = get_learning_rate_scheduler(optimizer)

    if args.load is not None:
        timers = get_timers()
        # Extra barrier is added to make sure all ranks report the
        # max time.
        torch.distributed.barrier()
        timers('load checkpoint').start()
        args.iteration = load_checkpoint(model, optimizer, lr_scheduler)
        torch.distributed.barrier()
        timers('load checkpoint').stop()
        timers.log(['load checkpoint'])
    else:
        args.iteration = 0

    # We only support local DDP with multiple micro-batches.
    if get_num_microbatches() > 1:
        assert args.DDP_impl == 'local', \
            'multiple micro-batches are only supported with local DDP ' \
            '(got DDP_impl={})'.format(args.DDP_impl)

    # get model without FP16 and/or TorchDDP wrappers
    unwrapped_model = model
    while hasattr(unwrapped_model, 'module'):
        unwrapped_model = unwrapped_model.module

    if args.iteration == 0 and hasattr(unwrapped_model,
                                       'init_state_dict_from_bert'):
        print("Initializing ICT from pretrained BERT model", flush=True)
        unwrapped_model.init_state_dict_from_bert()

    return model, optimizer, lr_scheduler


326
327
328
329
330
331
332
333
def communicate(tensor_send_next, tensor_send_prev, recv_forward, recv_backward):
    """Communicate tensors between stages using torch.distributed.ring_exchange(.) API.

    Arguments:
        tensor_send_next: tensor to send to the next pipeline stage, or None.
        tensor_send_prev: tensor to send to the previous pipeline stage, or
            None.
        recv_forward: if True, allocate a buffer and receive a tensor from
            the previous stage.
        recv_backward: if True, allocate a buffer and receive a tensor from
            the next stage.

    Received buffers have shape
    ``(args.seq_length, args.micro_batch_size, args.hidden_size)`` and dtype
    ``args.params_dtype``.

    Returns:
        (tensor_recv_prev, tensor_recv_next); each entry is None when the
        corresponding receive was not requested.

    NOTE(review): ``torch.distributed.ring_exchange`` does not appear to be
    part of stock PyTorch; this presumably relies on a patched build -- confirm.
    """
    args = get_args()
    tensor_shape = (args.seq_length, args.micro_batch_size, args.hidden_size)

    def _recv_buffer():
        # Placeholder filled in by ring_exchange. requires_grad=True so that
        # backward_step can retain and read the gradient of a received input.
        return torch.empty(tensor_shape,
                           requires_grad=True,
                           device=torch.cuda.current_device(),
                           dtype=args.params_dtype)

    # Create placeholder tensors for receive in forward and backward
    # directions if needed.
    tensor_recv_prev = _recv_buffer() if recv_forward else None
    tensor_recv_next = _recv_buffer() if recv_backward else None

    # Send tensors in both the forward and backward directions as appropriate.
    torch.distributed.ring_exchange(tensor_send_prev=tensor_send_prev,
                                    tensor_recv_prev=tensor_recv_prev,
                                    tensor_send_next=tensor_send_next,
                                    tensor_recv_next=tensor_recv_next,
                                    group=mpu.get_pipeline_model_parallel_group())

    return tensor_recv_prev, tensor_recv_next


def backward_step(optimizer, model, input_tensor, output_tensor, output_tensor_grad):
    """Backward step for one microbatch.

    Arguments:
        optimizer: when ``args.fp16`` is set, the backward pass is delegated
            to ``optimizer.backward`` (with ``update_master_grads=False``);
            otherwise unused.
        model: unused; kept for interface compatibility.
        input_tensor: activation received from the previous stage, or None on
            the first stage.
        output_tensor: output of this stage's forward pass (on the last stage,
            the loss divided by the number of microbatches).
        output_tensor_grad: gradient w.r.t. ``output_tensor`` received from
            the next stage, or None on the last stage.

    Returns:
        ``input_tensor.grad`` (to be sent to the previous stage), or None when
        ``input_tensor`` is None.
    """
    args = get_args()

    # Retain the grad on the input_tensor; it is not a leaf created by this
    # stage, so autograd would otherwise not keep it.
    if input_tensor is not None:
        input_tensor.retain_grad()

    # Backward pass.
    if args.fp16:
        optimizer.backward(output_tensor, update_master_grads=False,
                           output_tensor_grad=output_tensor_grad)
    else:
        torch.autograd.backward(output_tensor, grad_tensors=output_tensor_grad)

    # Collect the grad of the input_tensor.
    input_tensor_grad = None
    if input_tensor is not None:
        input_tensor_grad = input_tensor.grad

    return input_tensor_grad


380
381
382
def forward_step_with_communication(forward_step_func, data_iterator, model,
                                    input_tensors, output_tensors,
                                    losses_reduced, timers):
    """Run one forward step, receiving input from / sending output to neighbors.

    A non-first stage first receives its input activation from the previous
    stage. On the last stage, the loss returned by ``forward_step_func`` is
    divided by the number of microbatches, and the reduced loss dict is
    appended to ``losses_reduced``. Every other stage sends its output to
    the next stage instead. The (input, output) pair is appended to
    ``input_tensors`` / ``output_tensors`` for the matching backward step.
    """
    if mpu.is_pipeline_first_stage():
        input_tensor = None
    else:
        timers('forward-recv').start()
        input_tensor, _ = communicate(
            tensor_send_next=None,
            tensor_send_prev=None,
            recv_forward=True,
            recv_backward=False)
        timers('forward-recv').stop()

    # Forward model for one step.
    timers('forward-compute').start()
    output_tensor = forward_step_func(data_iterator, model, input_tensor)
    timers('forward-compute').stop()

    if mpu.is_pipeline_last_stage():
        loss, loss_reduced = output_tensor
        output_tensor = loss / get_num_microbatches()
        losses_reduced.append(loss_reduced)
    else:
        timers('forward-send').start()
        communicate(
            tensor_send_next=output_tensor,
            tensor_send_prev=None,
            recv_forward=False,
            recv_backward=False)
        timers('forward-send').stop()

    input_tensors.append(input_tensor)
    output_tensors.append(output_tensor)


def backward_step_with_communication(optimizer, model, input_tensors, output_tensors, timers):
    """Run one backward step for the oldest pending microbatch.

    Pops the oldest (input, output) pair from ``input_tensors`` /
    ``output_tensors``. A non-last stage first receives the output gradient
    from the next stage. The backward pass is then run, and a non-first
    stage sends the resulting input gradient to the previous stage.
    """
    input_tensor = input_tensors.pop(0)
    output_tensor = output_tensors.pop(0)

    if mpu.is_pipeline_last_stage():
        output_tensor_grad = None
    else:
        timers('backward-recv').start()
        _, output_tensor_grad = communicate(
            tensor_send_next=None,
            tensor_send_prev=None,
            recv_forward=False,
            recv_backward=True)
        timers('backward-recv').stop()

    # Backward pass for one step.
    timers('backward-compute').start()
    input_grad_tensor = \
        backward_step(optimizer, model, input_tensor, output_tensor, output_tensor_grad)
    timers('backward-compute').stop()

    if not mpu.is_pipeline_first_stage():
        timers('backward-send').start()
        communicate(
            tensor_send_next=None,
            tensor_send_prev=input_grad_tensor,
            recv_forward=False,
            recv_backward=False)
        timers('backward-send').stop()
447
448


449
450
451
452
453
def forward_and_backward_steps_with_communication(forward_step_func, data_iterator, model,
                                                  optimizer,
                                                  input_tensor, last_microbatch,
                                                  input_tensors, output_tensors,
                                                  losses_reduced, timers):
    """One steady-state 1F1B step: a forward pass followed by a backward pass.

    The forward pass uses ``input_tensor``. On non-last stages, sending the
    output and receiving the output gradient happen in a single
    ``communicate`` call. The backward pass then runs on the oldest pending
    microbatch. On non-first stages, sending the input gradient and receiving
    the next forward input also happen in one call; the receive is skipped
    when ``last_microbatch`` is True.

    Returns:
        The input tensor for the next forward step, or None on the first
        stage or after the last microbatch.
    """
    # Forward model for one step.
    timers('forward-compute').start()
    output_tensor = forward_step_func(data_iterator, model, input_tensor)
    timers('forward-compute').stop()

    if mpu.is_pipeline_last_stage():
        loss, loss_reduced = output_tensor
        output_tensor = loss / get_num_microbatches()
        output_tensor_grad = None
        losses_reduced.append(loss_reduced)
    else:
        timers('forward-send-backward-recv').start()
        _, output_tensor_grad = communicate(
            tensor_send_next=output_tensor,
            tensor_send_prev=None,
            recv_forward=False,
            recv_backward=True)
        timers('forward-send-backward-recv').stop()

    input_tensors.append(input_tensor)
    output_tensors.append(output_tensor)

    # The backward pass runs on the oldest pending microbatch, not on the one
    # just forwarded.
    input_tensor = input_tensors.pop(0)
    output_tensor = output_tensors.pop(0)

    # Backward pass for one step.
    timers('backward-compute').start()
    input_grad_tensor = \
        backward_step(optimizer, model, input_tensor, output_tensor, output_tensor_grad)
    timers('backward-compute').stop()

    if not mpu.is_pipeline_first_stage():
        timers('backward-send-forward-recv').start()
        input_tensor, _ = communicate(
            tensor_send_next=None,
            tensor_send_prev=input_grad_tensor,
            recv_forward=(not last_microbatch),
            recv_backward=False)
        timers('backward-send-forward-recv').stop()
    else:
        input_tensor = None

    return input_tensor


501
502
503
def forward_backward_no_pipelining(forward_step_func, data_iterator, model,
                                   optimizer, timers):
    """Run forward and backward passes without inter-stage communication.

    Each of the ``get_num_microbatches()`` microbatches is run forward and
    then immediately backward. Each loss is divided by the number of
    microbatches before its backward pass.

    Returns:
        List with one reduced-loss entry per microbatch.
    """
    num_microbatches = get_num_microbatches()
    losses_reduced = []
    for _ in range(num_microbatches):
        timers('forward-compute').start()
        loss, loss_reduced = forward_step_func(data_iterator, model, input_tensor=None)
        output_tensor = loss / num_microbatches
        losses_reduced.append(loss_reduced)
        timers('forward-compute').stop()

        timers('backward-compute').start()
        backward_step(optimizer, model, input_tensor=None,
                      output_tensor=output_tensor, output_tensor_grad=None)
        timers('backward-compute').stop()

    return losses_reduced
521

522
523
524
525
526
527
528

def forward_backward_pipelining(forward_step_func, data_iterator, model,
                                optimizer, timers):
    """Run 1F1B schedule, with communication and warmup + cooldown microbatches as needed.

    Stage ``r`` of a ``p``-stage pipeline proceeds in three phases:
    * warmup: ``min(p - r - 1, num_microbatches)`` forward passes;
    * steady state: alternating one forward and one backward pass for the
      remaining microbatches;
    * cooldown: the matching number of backward passes.

    Returns:
        List of reduced losses, one per microbatch. It is populated only on
        the last pipeline stage.
    """
    # Compute number of warmup microbatches.
    num_microbatches = get_num_microbatches()
    num_warmup_microbatches = min(
        mpu.get_pipeline_model_parallel_world_size() -
        mpu.get_pipeline_model_parallel_rank() - 1,
        num_microbatches)
    num_microbatches_remaining = num_microbatches - num_warmup_microbatches

    input_tensors = []
    output_tensors = []
    losses_reduced = []

    # Run warmup forward passes.
    for _ in range(num_warmup_microbatches):
        forward_step_with_communication(
            forward_step_func, data_iterator, model,
            input_tensors, output_tensors,
            losses_reduced, timers)

    # Before running 1F1B, need to receive first forward tensor.
    # If all microbatches are run in warmup / cooldown phase, then no need to
    # receive this tensor here.
    if num_microbatches_remaining > 0:
        if mpu.is_pipeline_first_stage():
            input_tensor = None
        else:
            timers('forward-recv').start()
            input_tensor, _ = communicate(tensor_send_next=None,
                                          tensor_send_prev=None,
                                          recv_forward=True,
                                          recv_backward=False)
            timers('forward-recv').stop()

    # Run 1F1B.
    for i in range(num_microbatches_remaining):
        last_iteration = (i == (num_microbatches_remaining - 1))
        input_tensor = \
            forward_and_backward_steps_with_communication(forward_step_func, data_iterator, model,
                                                          optimizer,
                                                          input_tensor, last_iteration,
                                                          input_tensors, output_tensors,
                                                          losses_reduced, timers)

    # Run cooldown backward passes.
    for _ in range(num_warmup_microbatches):
        backward_step_with_communication(
            optimizer, model, input_tensors, output_tensors, timers)

    return losses_reduced


def train_step(forward_step_func, data_iterator,
               model, optimizer, lr_scheduler):
    """Single training step.

    The step proceeds as follows:
    * zero the gradients;
    * run forward/backward over all microbatches, pipelined when the pipeline
      model-parallel world size is > 1;
    * all-reduce gradients (local DDP, plus the shared word embeddings
      between the first and last pipeline stages);
    * clip the gradients and step the optimizer;
    * step the learning-rate scheduler, unless the fp16 optimizer reported
      an overflow.

    Returns:
        (loss_reduced, skipped_iter). On the last pipeline stage,
        ``loss_reduced`` maps each loss key to its mean over microbatches; on
        other stages it is ``{}``. ``skipped_iter`` is 1 when the fp16
        optimizer overflowed (no scheduler step), else 0.
    """
    args = get_args()
    timers = get_timers()

    # Set grad to zero.
    if args.fp16:
        optimizer.zero_grad(set_grads_to_None=True)
    else:
        optimizer.zero_grad()

    if mpu.get_pipeline_model_parallel_world_size() > 1:
        losses_reduced = forward_backward_pipelining(
            forward_step_func, data_iterator, model, optimizer, timers)
    else:
        losses_reduced = forward_backward_no_pipelining(
            forward_step_func, data_iterator, model, optimizer, timers)

    # All-reduce if needed.
    if args.DDP_impl == 'local':
        timers('backward-params-all-reduce').start()
        model.allreduce_params(reduce_after=False,
                               fp32_allreduce=args.fp32_allreduce)
        timers('backward-params-all-reduce').stop()

    # All-reduce word_embeddings' grad across first and last stages to ensure
    # that word_embeddings parameters stay in sync.
    # This should only run for models that support pipelined model parallelism
    # (BERT and GPT-2).
    timers('backward-embedding-all-reduce').start()
    if (mpu.is_pipeline_first_stage() or mpu.is_pipeline_last_stage()) and \
            mpu.get_pipeline_model_parallel_world_size() > 1:
        unwrapped_model = model
        while isinstance(unwrapped_model, (torchDDP, LocalDDP, FP16_Module)):
            unwrapped_model = unwrapped_model.module

        if unwrapped_model.share_word_embeddings:
            word_embeddings_weight = unwrapped_model.word_embeddings_weight()
            torch.distributed.all_reduce(word_embeddings_weight.grad,
                                         group=mpu.get_embedding_group())
    timers('backward-embedding-all-reduce').stop()

    # Update master gradients.
    timers('backward-master-grad').start()
    if args.fp16:
        optimizer.update_master_grads()
    timers('backward-master-grad').stop()

    # Clipping gradients helps prevent the exploding gradient.
    timers('backward-clip-grad').start()
    if args.clip_grad > 0.:
        if not args.fp16:
            parameters = []
            parameter_names = []
            for parameter_name, parameter in model.named_parameters():
                parameters.append(parameter)
                parameter_names.append(parameter_name)
            mpu.clip_grad_norm(parameters, args.clip_grad,
                               parameter_names=parameter_names)
        else:
            optimizer.clip_master_grads(args.clip_grad)
    timers('backward-clip-grad').stop()

    # Update parameters.
    timers('optimizer').start()
    optimizer.step()
    timers('optimizer').stop()

    # Update learning rate.
    skipped_iter = 0
    if not (args.fp16 and optimizer.overflow):
        # Advance the scheduler by the number of samples consumed this step.
        increment = get_num_microbatches() * \
                    args.micro_batch_size * \
                    args.data_parallel_size
        lr_scheduler.step(increment=increment)
    else:
        skipped_iter = 1

    if mpu.is_pipeline_last_stage():
        # Average loss across microbatches.
        loss_reduced = {}
        for key in losses_reduced[0]:
            losses_reduced_for_key = [x[key] for x in losses_reduced]
            loss_reduced[key] = sum(losses_reduced_for_key) / len(losses_reduced_for_key)
        return loss_reduced, skipped_iter
    return {}, skipped_iter
670
671


Mohammad Shoeybi's avatar
Mohammad Shoeybi committed
672
def training_log(loss_dict, total_loss_dict, learning_rate, iteration,
                 loss_scale, report_memory_flag, skipped_iter):
    """Log training information such as losses, timing, ....

    Accumulates this iteration's losses and the advanced/skipped/nan
    iteration counters into ``total_loss_dict``. If a tensorboard writer is
    available on the last rank, it writes per-iteration scalars. Every
    ``args.log_interval`` iterations it also prints a summary line (on the
    last rank) and resets the accumulators.

    Args:
        loss_dict: losses for this iteration keyed by name. train_step
            returns an empty dict on ranks outside the last pipeline stage.
        total_loss_dict: accumulator dict, mutated in place across calls.
            Holds the summed losses plus the 'advanced iterations',
            'skipped iterations' and 'nan iterations' counters.
        learning_rate: current learning rate.
        iteration: current iteration number.
        loss_scale: fp16 loss scale (train passes None when not fp16).
        report_memory_flag: whether memory usage still needs reporting.
        skipped_iter: 1 if train_step skipped the optimizer step, else 0.

    Returns:
        The updated report_memory_flag. It becomes False once memory has
        been reported.
    """
    args = get_args()
    timers = get_timers()
    writer = get_tensorboard_writer()

    # Advanced, skipped, and Nan iterations.
    advanced_iters_key = 'advanced iterations'
    skipped_iters_key = 'skipped iterations'
    nan_iters_key = 'nan iterations'
    # Advanced iterations: count only iterations whose step was applied,
    # but make sure the key exists either way.
    if not skipped_iter:
        total_loss_dict[advanced_iters_key] = total_loss_dict.get(
            advanced_iters_key, 0) + 1
    else:
        if advanced_iters_key not in total_loss_dict:
            total_loss_dict[advanced_iters_key] = 0
    # Skipped iterations.
    total_loss_dict[skipped_iters_key] = total_loss_dict.get(
        skipped_iters_key, 0) + skipped_iter
    # Update losses and set nan iterations. Losses of skipped iterations are
    # not accumulated; they are only inspected for inf/nan.
    got_nan = False
    for key in loss_dict:
        if not skipped_iter:
            total_loss_dict[key] = total_loss_dict.get(
                key, torch.cuda.FloatTensor([0.0])) + loss_dict[key]
        else:
            value = loss_dict[key].float().sum().item()
            got_nan = got_nan or not math.isfinite(value)
    total_loss_dict[nan_iters_key] = total_loss_dict.get(
        nan_iters_key, 0) + int(got_nan)

    # Logging: only include timers that have actually been created.
    timers_to_log = [name for name in (
        'forward-compute',
        'forward-recv',
        'forward-send',
        'forward-send-backward-recv',
        'backward-compute',
        'backward-recv',
        'backward-send',
        'backward-send-forward-recv',
        'backward-master-grad',
        'backward-params-all-reduce',
        'backward-embedding-all-reduce',
        'backward-clip-grad',
        'optimizer',
        'batch-generator',
    ) if name in timers.timers]

    # Calculate batch size.
    batch_size = args.micro_batch_size * args.data_parallel_size * \
        get_num_microbatches()

    total_iterations = total_loss_dict[advanced_iters_key] + \
                       total_loss_dict[skipped_iters_key]

    # Tensorboard values.
    if writer and is_last_rank():
        writer.add_scalar('learning-rate', learning_rate, iteration)
        writer.add_scalar('learning-rate vs samples', learning_rate,
                          args.consumed_train_samples)
        writer.add_scalar('batch-size', batch_size, iteration)
        writer.add_scalar('batch-size vs samples', batch_size,
                          args.consumed_train_samples)
        for key in loss_dict:
            writer.add_scalar(key, loss_dict[key], iteration)
            writer.add_scalar(key + ' vs samples', loss_dict[key],
                              args.consumed_train_samples)
        if args.fp16:
            writer.add_scalar('loss-scale', loss_scale, iteration)
            writer.add_scalar('loss-scale vs samples', loss_scale,
                              args.consumed_train_samples)
        timers.write(timers_to_log, writer, iteration,
                     normalizer=total_iterations)

    if iteration % args.log_interval == 0:
        elapsed_time = timers('interval time').elapsed()
        elapsed_time_per_iteration = elapsed_time / total_iterations
        # Use the same rank guard as the other tensorboard writes above;
        # checking for rank 0 here meant this scalar was never written
        # when world_size > 1.
        if writer and is_last_rank():
            writer.add_scalar('iteration-time',
                              elapsed_time_per_iteration, iteration)
        log_string = ' iteration {:8d}/{:8d} |'.format(
            iteration, args.train_iters)
        log_string += ' consumed samples: {:12d} |'.format(
            args.consumed_train_samples)
        log_string += ' elapsed time per iteration (ms): {:.1f} |'.format(
            elapsed_time_per_iteration * 1000.0)
        log_string += ' learning rate: {:.3E} |'.format(learning_rate)
        log_string += ' global batch size: {:5d} |'.format(batch_size)
        for key in total_loss_dict:
            if key not in [advanced_iters_key, skipped_iters_key,
                           nan_iters_key]:
                avg = total_loss_dict[key].item() / \
                      float(max(1, total_loss_dict[advanced_iters_key]))
                if avg > 0.0:
                    log_string += ' {}: {:.6E} |'.format(key, avg)
                total_loss_dict[key] = torch.cuda.FloatTensor([0.0])
        if args.fp16:
            log_string += ' loss scale: {:.1f} |'.format(loss_scale)
        log_string += ' number of skipped iterations: {:3d} |'.format(
            total_loss_dict[skipped_iters_key])
        log_string += ' number of nan iterations: {:3d} |'.format(
            total_loss_dict[nan_iters_key])
        total_loss_dict[advanced_iters_key] = 0
        total_loss_dict[skipped_iters_key] = 0
        total_loss_dict[nan_iters_key] = 0
        print_rank_last(log_string)
        if report_memory_flag and learning_rate > 0.:
            # Report memory after optimizer state has been initialized.
            report_memory('(after {} iterations)'.format(iteration))
            report_memory_flag = False
        timers.log(timers_to_log, normalizer=args.log_interval)

    return report_memory_flag


796
797
798
799
800
801
802
803
804
805
806
807
def save_checkpoint_and_time(iteration, model, optimizer, lr_scheduler):
    """Save a checkpoint and log how long saving took.

    A barrier on each side of the save makes every rank wait for the
    slowest one, so the reported time is the maximum across ranks.
    """
    all_timers = get_timers()
    timer_name = 'save checkpoint'

    torch.distributed.barrier()
    all_timers(timer_name).start()

    save_checkpoint(iteration, model, optimizer, lr_scheduler)

    torch.distributed.barrier()
    all_timers(timer_name).stop()
    all_timers.log([timer_name])


808
def train(forward_step_func, model, optimizer, lr_scheduler,
          train_data_iterator, valid_data_iterator):
    """Train the model function.

    Runs train_step until ``args.train_iters`` is reached. After each step
    it does the following:

    - advances ``args.consumed_train_samples`` by one global batch;
    - logs through training_log;
    - optionally runs the ADLR autoresume check;
    - optionally saves a periodic checkpoint;
    - optionally runs validation;
    - optionally exits on a time or iteration budget. An exit saves a
      checkpoint first, unless one was just written, and then calls
      sys.exit().

    Returns:
        The iteration number reached when the loop ends.
    """
    args = get_args()
    timers = get_timers()

    # Training mode enables dropout.
    model.train()

    # Running loss sums and iteration counters, maintained by training_log.
    running_losses = {}

    step = args.iteration

    timers('interval time').start()
    print_datetime('before the start of training step')
    report_memory_flag = True

    while step < args.train_iters:
        update_num_microbatches(args.consumed_train_samples)
        loss_dict, skipped_iter = train_step(forward_step_func,
                                             train_data_iterator,
                                             model,
                                             optimizer,
                                             lr_scheduler)
        step += 1
        samples_this_step = (mpu.get_data_parallel_world_size()
                             * args.micro_batch_size
                             * get_num_microbatches())
        args.consumed_train_samples += samples_this_step

        # Logging.
        loss_scale = optimizer.loss_scale if args.fp16 else None
        report_memory_flag = training_log(loss_dict, running_losses,
                                          optimizer.param_groups[0]['lr'],
                                          step, loss_scale,
                                          report_memory_flag, skipped_iter)

        # Autoresume.
        if args.adlr_autoresume and \
                step % args.adlr_autoresume_interval == 0:
            check_adlr_autoresume_termination(step, model, optimizer,
                                              lr_scheduler)

        # Periodic checkpointing. Remember whether we just saved so the exit
        # paths below don't save the same iteration twice.
        checkpoint_written = False
        if args.save and args.save_interval and \
                step % args.save_interval == 0:
            save_checkpoint_and_time(step, model, optimizer, lr_scheduler)
            checkpoint_written = True

        # Evaluation.
        if args.eval_interval and step % args.eval_interval == 0 and \
                args.do_valid:
            evaluate_and_print_results('iteration {}'.format(step),
                                       forward_step_func,
                                       valid_data_iterator, model,
                                       step, False)

        # Exit on wall-clock budget. Reduce with MAX so every rank agrees
        # to stop if any rank has exceeded the budget.
        if args.exit_duration_in_mins:
            elapsed_mins = (time.time() - _TRAIN_START_TIME) / 60.0
            over_budget = torch.cuda.IntTensor(
                [elapsed_mins > args.exit_duration_in_mins])
            torch.distributed.all_reduce(
                over_budget, op=torch.distributed.ReduceOp.MAX)
            if over_budget.item():
                if not checkpoint_written:
                    save_checkpoint_and_time(step, model, optimizer,
                                             lr_scheduler)
                print_datetime(
                    'exiting program after {} minutes'.format(elapsed_mins))
                sys.exit()

        # Exit on iteration budget.
        if args.exit_interval and step % args.exit_interval == 0:
            if not checkpoint_written:
                save_checkpoint_and_time(step, model, optimizer,
                                         lr_scheduler)
            torch.distributed.barrier()
            print_datetime('exiting program at iteration {}'.format(step))
            sys.exit()

    return step
896
897


Mohammad's avatar
Mohammad committed
898
def evaluate(forward_step_func, data_iterator, model, verbose=False):
    """Evaluation.

    Runs ``args.eval_iters`` evaluation iterations under torch.no_grad().
    Each iteration performs get_num_microbatches() forward passes and
    exchanges activations between pipeline stages with communicate().
    Losses are summed on the last pipeline stage, then averaged over all
    evaluated microbatches. ``args.consumed_valid_samples`` is advanced by
    one global batch per iteration.

    Returns:
        Dict mapping loss name to its averaged loss tensor. It is empty on
        ranks that are not in the last pipeline stage.
    """
    args = get_args()

    # Evaluation mode disables dropout.
    model.eval()

    loss_sums = {}

    with torch.no_grad():
        for eval_step in range(1, args.eval_iters + 1):
            if verbose and eval_step % args.log_interval == 0:
                print_rank_0('Evaluating iter {}/{}'.format(eval_step,
                                                            args.eval_iters))

            for _ in range(get_num_microbatches()):
                # Non-first stages receive activations from the previous
                # stage; the first stage reads from data_iterator instead.
                input_tensor = None
                if not mpu.is_pipeline_first_stage():
                    input_tensor, _ = communicate(
                        tensor_send_next=None,
                        tensor_send_prev=None,
                        recv_forward=True,
                        recv_backward=False)

                output_tensor = forward_step_func(data_iterator, model,
                                                  input_tensor)

                if not mpu.is_pipeline_last_stage():
                    # Hand activations on to the next stage.
                    communicate(
                        tensor_send_next=output_tensor,
                        tensor_send_prev=None,
                        recv_forward=False,
                        recv_backward=False)
                    continue

                # Last stage: the second item of the output is the loss
                # dict. Accumulate it locally (no cross-process reduction).
                _, loss_dict = output_tensor
                for key in loss_dict:
                    loss_sums[key] = loss_sums.get(
                        key, torch.cuda.FloatTensor([0.0])) + loss_dict[key]

            args.consumed_valid_samples += mpu.get_data_parallel_world_size() \
                                           * args.micro_batch_size \
                                           * get_num_microbatches()

    # Restore training mode.
    model.train()

    for key in loss_sums:
        loss_sums[key] /= args.eval_iters * get_num_microbatches()

    return loss_sums

def evaluate_and_print_results(prefix, forward_step_func,
                               data_iterator, model,
                               iteration, verbose=False):
    """Helper function to evaluate and dump results on screen.

    Runs evaluate() and builds a one-line summary of every validation loss
    and its perplexity. The summary is printed on the last rank. When a
    tensorboard writer is available on the last rank, the values are also
    written there.

    Args:
        prefix: text inserted into the summary line (e.g. 'iteration 100').
        forward_step_func, data_iterator, model, verbose: forwarded to
            evaluate().
        iteration: step used for the tensorboard scalars.
    """
    writer = get_tensorboard_writer()

    total_loss_dict = evaluate(forward_step_func, data_iterator, model, verbose)
    string = ' validation loss at {} | '.format(prefix)
    for key in total_loss_dict:
        loss_value = total_loss_dict[key].item()
        # Clamp the exponent so the reported perplexity stays bounded.
        ppl = math.exp(min(20, loss_value))
        string += '{} value: {:.6E} | '.format(key, loss_value)
        string += '{} PPL: {:.6E} | '.format(key, ppl)
        # evaluate() only fills losses on the last pipeline stage, and
        # training_log guards its tensorboard writes with is_last_rank().
        # Checking for rank 0 here dropped these scalars under pipeline
        # parallelism, where rank 0's dict is empty.
        if writer and is_last_rank():
            writer.add_scalar('{} value'.format(key), loss_value, iteration)
            writer.add_scalar('{} ppl'.format(key), ppl, iteration)

    length = len(string) + 1
    print_rank_last('-' * length)
    print_rank_last(string)
    print_rank_last('-' * length)
974
975


976
977
978
def build_train_valid_test_data_iterators(
        build_train_valid_test_datasets_provider):
    """Build the train, validation, and test data iterators.

    Datasets and dataloaders are built only on tensor-model-parallel rank 0.
    That rank decides the do_train/do_valid/do_test flags and broadcasts them
    over the tensor-model-parallel group; every rank then stores them on
    ``args``. For checkpoints resumed without sample counters, the consumed
    train/valid samples are reconstructed from ``args.iteration`` (this is
    only supported for iteration-based training).

    Args:
        build_train_valid_test_datasets_provider: callable that takes the
            list of target sample counts [train, valid, test] and returns
            the (train, valid, test) datasets.

    Returns:
        Tuple (train_data_iterator, valid_data_iterator, test_data_iterator).
        Each entry is None when the corresponding dataloader is None, which
        is always the case on ranks other than tensor-model-parallel rank 0.
    """
    args = get_args()

    (train_dataloader, valid_dataloader, test_dataloader) = (None, None, None)

    print_rank_0('> building train, validation, and test datasets ...')

    # Backward compatibility, assume fixed batch size.
    if args.iteration > 0 and args.consumed_train_samples == 0:
        assert args.train_samples is None, \
            'only backward compatiblity support for iteration-based training'
        args.consumed_train_samples = args.iteration * args.global_batch_size
    if args.iteration > 0 and args.consumed_valid_samples == 0:
        assert args.train_samples is None, \
            'only backward compatiblity support for iteration-based training'
        args.consumed_valid_samples = (args.iteration // args.eval_interval) * \
            args.eval_iters * args.global_batch_size

    # Data loaders are built only on rank 0 of each tensor-model-parallel group.
    if mpu.get_tensor_model_parallel_rank() == 0:

        # Number of train/valid/test samples.
        if args.train_samples:
            train_samples = args.train_samples
        else:
            train_samples = args.train_iters * args.global_batch_size
        # One validation run per eval_interval, plus one extra.
        eval_iters = (args.train_iters // args.eval_interval + 1) * \
                     args.eval_iters
        test_iters = args.eval_iters
        train_val_test_num_samples = [train_samples,
                                      eval_iters * args.global_batch_size,
                                      test_iters * args.global_batch_size]
        print_rank_0(' > datasets target sizes (minimum size):')
        print_rank_0('    train:      {}'.format(train_val_test_num_samples[0]))
        print_rank_0('    validation: {}'.format(train_val_test_num_samples[1]))
        print_rank_0('    test:       {}'.format(train_val_test_num_samples[2]))

        # Build the datasets.
        train_ds, valid_ds, test_ds = build_train_valid_test_datasets_provider(
            train_val_test_num_samples)

        # Build dataloaders. Train/valid resume from the samples already
        # consumed; test always starts from the beginning.
        train_dataloader = build_pretraining_data_loader(
            train_ds, args.consumed_train_samples)
        valid_dataloader = build_pretraining_data_loader(
            valid_ds, args.consumed_valid_samples)
        test_dataloader = build_pretraining_data_loader(test_ds, 0)

        # Flags to know if we need to do training/validation/testing.
        do_train = train_dataloader is not None and args.train_iters > 0
        do_valid = valid_dataloader is not None and args.eval_iters > 0
        do_test = test_dataloader is not None and args.eval_iters > 0
        # Only this rank knows the flags; pack them for the broadcast below.
        flags = torch.cuda.LongTensor(
            [int(do_train), int(do_valid), int(do_test)])
    else:
        flags = torch.cuda.LongTensor([0, 0, 0])

    # Broadcast the do_train/do_valid/do_test flags from rank 0 of the
    # tensor-model-parallel group to the rest of the group.
    torch.distributed.broadcast(flags,
                                mpu.get_tensor_model_parallel_src_rank(),
                                group=mpu.get_tensor_model_parallel_group())
    args.do_train = flags[0].item()
    args.do_valid = flags[1].item()
    args.do_test = flags[2].item()

    # Build iterators (None wherever no dataloader was built).
    train_data_iterator = iter(train_dataloader) \
        if train_dataloader is not None else None
    valid_data_iterator = iter(valid_dataloader) \
        if valid_dataloader is not None else None
    test_data_iterator = iter(test_dataloader) \
        if test_dataloader is not None else None

    return train_data_iterator, valid_data_iterator, test_data_iterator