# coding=utf-8
# Copyright (c) 2020, NVIDIA CORPORATION.  All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Pretrain utilities."""

from datetime import datetime
import math
import sys
import time
# The earliest we can measure the start time.
_TRAIN_START_TIME = time.time()

import torch
from torch.nn.parallel.distributed import DistributedDataParallel as torchDDP
from apex.optimizers import FusedAdam as Adam

from megatron import get_args
from megatron import get_timers
from megatron import get_tensorboard_writer
from megatron import get_current_global_batch_size
from megatron import get_num_microbatches
from megatron import update_num_microbatches
from megatron import mpu
from megatron import print_rank_0
from megatron import print_rank_last
from megatron.checkpointing import load_checkpoint
from megatron.checkpointing import save_checkpoint
from megatron.fp16 import FP16_Module
from megatron.fp16 import FP16_Optimizer
from megatron.initialize import initialize_megatron
from megatron.learning_rates import AnnealingLR
from megatron.model import DistributedDataParallel as LocalDDP
from megatron.model import get_params_for_weight_decay_optimization
from megatron.model.realm_model import ICTBertModel
from megatron.utils import check_adlr_autoresume_termination
from megatron.data.data_loaders import build_pretraining_data_loader
from megatron.utils import report_memory


def print_datetime(string):
    """Note that this call will sync across all ranks."""
    torch.distributed.barrier()
    time_str = datetime.now().strftime('%Y-%m-%d %H:%M:%S')
    print_rank_0('[' + string + '] datetime: {} '.format(time_str))


def pretrain(train_valid_test_dataset_provider, model_provider,
             forward_step_func, extra_args_provider=None, args_defaults={}):
    """Main training program.

    This function will run the following in the order provided:
        1) initialize Megatron.
        2) setup model, optimizer and lr schedule using the model_provider.
        3) call train_val_test_data_provider to get train/val/test datasets.
        4) train the model using the forward_step_func.

    Arguments:
        train_valid_test_dataset_provider: a function that takes the size of
            train/valid/test dataset and returns `train, valid, test` datasets.
        model_provider: a function that returns a vanilla version of the
            model. By vanilla we mean a simple model on cpu with no fp16 or ddp.
        forward_step_func: a function that takes a `data iterator` and `model`,
            and returns a `loss` scalar with a dictionary with key:values being
            the info we would like to monitor during training, for example
            `lm-loss: value`. We also require that this function add
            `batch generator` to the timers class.
        extra_args_provider: a function that takes a parser and adds arguments
            to it. It is used for programs to add their own arguments.
        args_defaults: a dictionary from argument-name to argument-value. It
            is used to set arguments that have already been parsed.
    """

    # Initialize and get arguments, timers, and Tensorboard writer.
    initialize_megatron(extra_args_provider=extra_args_provider,
                        args_defaults=args_defaults)

    # Adjust the startup time so it reflects the largest value.
    # This will be closer to what scheduler will see (outside of
    # image ... launches.
    global _TRAIN_START_TIME
    start_time_tensor = torch.cuda.FloatTensor([_TRAIN_START_TIME])
    torch.distributed.all_reduce(start_time_tensor,
                                 op=torch.distributed.ReduceOp.MIN)
    _TRAIN_START_TIME = start_time_tensor.item()
    print_rank_0('time to initialize megatron (seconds): {:.3f}'.format(
        time.time() - _TRAIN_START_TIME))
    print_datetime('after megatron is initialized')

    args = get_args()
    timers = get_timers()

    # Model, optimizer, and learning rate.
    timers('model and optimizer').start()
    model, optimizer, lr_scheduler = setup_model_and_optimizer(model_provider)
    timers('model and optimizer').stop()
    print_datetime('after model, optimizer, and learning rate '
                   'scheduler are built')

    # Data stuff.
    timers('train/valid/test data iterators').start()
    train_data_iterator, valid_data_iterator, test_data_iterator \
        = build_train_valid_test_data_iterators(
            train_valid_test_dataset_provider)
    timers('train/valid/test data iterators').stop()
    print_datetime('after dataloaders are built')

    # Print setup timing.
    print_rank_0('done with setups ...')
    timers.log(['model and optimizer', 'train/valid/test data iterators'])
    print_rank_0('training ...')

    iteration = 0
    if args.do_train and args.train_iters > 0:
        iteration = train(forward_step_func,
                          model, optimizer, lr_scheduler,
                          train_data_iterator, valid_data_iterator)
    print_datetime('after training is done')

    if args.do_valid:
        prefix = 'the end of training for val data'
        evaluate_and_print_results(prefix, forward_step_func,
                                   valid_data_iterator, model,
                                   iteration, False)

    if args.save and iteration != 0:
        save_checkpoint(iteration, model, optimizer, lr_scheduler)

    if args.do_test:
        # Run on test data.
        prefix = 'the end of training for test data'
        evaluate_and_print_results(prefix, forward_step_func,
                                   test_data_iterator, model,
                                   0, True)
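
# Illustrative usage (not part of the original file): a pretraining script is
# expected to define its own dataset provider, model provider and forward step
# and hand them to pretrain().  A minimal sketch, with all three names and the
# args_defaults value assumed:
#
#     if __name__ == "__main__":
#         pretrain(train_valid_test_datasets_provider, model_provider,
#                  forward_step,
#                  args_defaults={'tokenizer_type': 'GPT2BPETokenizer'})
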

def update_train_iters(args):

    # For iteration-based training, we don't need to do anything
    if args.train_iters:
        return

    # Constant batch size with sample-based training.
    if args.rampup_batch_size is None:
        args.train_iters = args.train_samples // args.global_batch_size

    else:
        # Sample based training with rampup batch size.
        iterations = 0
        consumed_samples = 0
        # Rampup phase.
        while consumed_samples <= int(args.rampup_batch_size[2]):
            update_num_microbatches(consumed_samples, consistency_check=False)
            consumed_samples += get_current_global_batch_size()
            iterations += 1
        # Reset
        update_num_microbatches(0, consistency_check=False)
        # Constant phase
        # Note that we throw away any partial last batch.
        iterations += (args.train_samples - consumed_samples) // \
                      args.global_batch_size
        args.train_iters = iterations

    print_rank_0('setting training iterations to {}'.format(args.train_iters))
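
# Worked example for the rampup branch above (illustrative values, not from
# the original file): with --rampup-batch-size 16 16 5000000,
# --global-batch-size 1536 and --train-samples 300000000, the loop steps the
# global batch size 16, 32, 48, ... while accumulating consumed_samples; once
# roughly 5,000,000 samples are consumed, the remaining
# (train_samples - consumed_samples) // 1536 iterations run at the full batch
# size, and the two counts together become args.train_iters.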


def get_model(model_provider_func):
    """Build the model."""
    args = get_args()

    # Build model on cpu.
    model = model_provider_func()

    # Print number of parameters.
    if mpu.get_data_parallel_rank() == 0:
        print(' > number of parameters on (tensor, pipeline) '
              'model parallel rank ({}, {}): {}'.format(
            mpu.get_tensor_model_parallel_rank(),
            mpu.get_pipeline_model_parallel_rank(),
            sum([p.nelement() for p in model.parameters()])), flush=True)

    # GPU allocation.
    model.cuda(torch.cuda.current_device())

    # Fp16 conversion.
    if args.fp16:
        model = FP16_Module(model)

    if args.DDP_impl == 'torch':
        i = torch.cuda.current_device()
        model = torchDDP(model, device_ids=[i], output_device=i,
                         process_group=mpu.get_data_parallel_group())
        return model
    if args.DDP_impl == 'local':
        model = LocalDDP(model)
        return model

    raise NotImplementedError('Unknown DDP implementation specified: {}. '
                              'Exiting.'.format(args.DDP_impl))


def get_optimizer(model):
    """Set up the optimizer."""
    args = get_args()

    # Build parameter groups (weight decay and non-decay).
    while isinstance(model, (torchDDP, LocalDDP, FP16_Module)):
        model = model.module
    param_groups = get_params_for_weight_decay_optimization(model)

    # Add model parallel attribute if it is not set.
    for param_group in param_groups:
        for param in param_group['params']:
            if not hasattr(param, 'tensor_model_parallel'):
                param.tensor_model_parallel = False

    # Use Adam.
    optimizer = Adam(param_groups, lr=args.lr, weight_decay=args.weight_decay,
        betas=(args.adam_beta1, args.adam_beta2), eps=args.adam_eps)

    # Wrap into fp16 optimizer.
    if args.fp16:
        optimizer = FP16_Optimizer(optimizer,
                                   static_loss_scale=args.loss_scale,
                                   dynamic_loss_scale=args.dynamic_loss_scale,
                                   dynamic_loss_args={
                                       'scale_window': args.loss_scale_window,
                                       'min_scale': args.min_scale,
                                       'delayed_shift': args.hysteresis})

    return optimizer


def get_learning_rate_scheduler(optimizer):
    """Build the learning rate scheduler."""
    args = get_args()

    # Iteration-based training.
    if args.train_iters:
        if args.lr_decay_iters is None:
            args.lr_decay_iters = args.train_iters
        decay_steps = args.lr_decay_iters * args.global_batch_size
        if args.lr_warmup_fraction is not None:
            warmup_steps = args.lr_warmup_fraction * decay_steps
        else:
            warmup_steps = args.lr_warmup_iters * args.global_batch_size
    # Sample-based training.
    elif args.train_samples:
        # We need to set training iters for later use. Technically
        # we need to adjust the training samples too (due to last
        # batch being incomplete) but we leave it as is for now.
        update_train_iters(args)
        if args.lr_decay_samples is None:
            args.lr_decay_samples = args.train_samples
        decay_steps = args.lr_decay_samples
        if args.lr_warmup_fraction is not None:
            warmup_steps = args.lr_warmup_fraction * decay_steps
        else:
            warmup_steps = args.lr_warmup_samples
    else:
        raise Exception(
            'either train-iters or train-samples should be provided.')

    lr_scheduler = AnnealingLR(
        optimizer,
        max_lr=args.lr,
        min_lr=args.min_lr,
        warmup_steps=warmup_steps,
        decay_steps=decay_steps,
        decay_style=args.lr_decay_style,
        use_checkpoint_lr_scheduler=args.use_checkpoint_lr_scheduler,
        override_lr_scheduler=args.override_lr_scheduler)

    return lr_scheduler
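
# Note on units (illustrative numbers, assumed): decay_steps and warmup_steps
# above are expressed in samples, matching lr_scheduler.step(increment=...) in
# train_step().  For iteration-based training with --train-iters 500000,
# --global-batch-size 256 and --lr-warmup-fraction 0.01 this gives
# decay_steps = 500000 * 256 = 128,000,000 samples and
# warmup_steps = 0.01 * 128,000,000 = 1,280,000 samples.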


def setup_model_and_optimizer(model_provider_func):
    """Setup model and optimizer."""
    args = get_args()

    model = get_model(model_provider_func)
    optimizer = get_optimizer(model)
    lr_scheduler = get_learning_rate_scheduler(optimizer)

    if args.load is not None:
        timers = get_timers()
        # Extra barrier is added to make sure all ranks report the
        # max time.
        torch.distributed.barrier()
        timers('load checkpoint').start()
        args.iteration = load_checkpoint(model, optimizer, lr_scheduler)
        torch.distributed.barrier()
        timers('load checkpoint').stop()
        timers.log(['load checkpoint'])
    else:
        args.iteration = 0

    # We only support local DDP with multiple micro-batches.
    if get_num_microbatches() > 1:
        assert args.DDP_impl == 'local'

    # get model without FP16 and/or TorchDDP wrappers
    unwrapped_model = model
    while hasattr(unwrapped_model, 'module'):
        unwrapped_model = unwrapped_model.module

    if args.iteration == 0 and hasattr(unwrapped_model,
                                       'init_state_dict_from_bert'):
        print("Initializing ICT from pretrained BERT model", flush=True)
        unwrapped_model.init_state_dict_from_bert()

    return model, optimizer, lr_scheduler


def communicate(tensor_send_next, tensor_send_prev, recv_forward, recv_backward):
    """Communicate tensors between stages using torch.distributed.ring_exchange(.) API."""
    args = get_args()

    # Create placeholder tensors for receive in forward and backward directions
    # if needed.
    tensor_recv_prev = None
    tensor_recv_next = None
    tensor_shape = (args.seq_length, args.micro_batch_size, args.hidden_size)
    if recv_forward:
        tensor_recv_prev = torch.empty(tensor_shape,
                                       requires_grad=True,
                                       device=torch.cuda.current_device(),
                                       dtype=args.params_dtype)
    if recv_backward:
        tensor_recv_next = torch.empty(tensor_shape,
                                       requires_grad=True,
                                       device=torch.cuda.current_device(),
                                       dtype=args.params_dtype)

    # Send tensors in both the forward and backward directions as appropriate.
    torch.distributed.ring_exchange(tensor_send_prev=tensor_send_prev,
                                    tensor_recv_prev=tensor_recv_prev,
                                    tensor_send_next=tensor_send_next,
                                    tensor_recv_next=tensor_recv_next,
                                    group=mpu.get_pipeline_model_parallel_group())

    return tensor_recv_prev, tensor_recv_next


def backward_step(optimizer, model, input_tensor, output_tensor, output_tensor_grad):
    """Backward step."""
    args = get_args()
    timers = get_timers()

    # Retain the grad on the input_tensor.
    if input_tensor is not None:
        input_tensor.retain_grad()

    # Backward pass.
    if args.fp16:
        optimizer.backward(output_tensor, update_master_grads=False,
                           output_tensor_grad=output_tensor_grad)
    else:
        torch.autograd.backward(output_tensor, grad_tensors=output_tensor_grad)

    # Collect the grad of the input_tensor.
    input_tensor_grad = None
    if input_tensor is not None:
        input_tensor_grad = input_tensor.grad

    return input_tensor_grad


def forward_step_with_communication(forward_step_func, data_iterator, model,
                                    input_tensors, output_tensors,
                                    losses_reduced, timers):
    args = get_args()

    if not mpu.is_pipeline_first_stage():
        timers('forward-recv').start()
        input_tensor, _ = communicate(
            tensor_send_next=None,
            tensor_send_prev=None,
            recv_forward=True,
            recv_backward=False)
        timers('forward-recv').stop()
    else:
        input_tensor = None

    # Forward model for one step.
    timers('forward-compute').start()
    output_tensor = forward_step_func(data_iterator, model, input_tensor)
    timers('forward-compute').stop()

    if mpu.is_pipeline_last_stage():
        loss, loss_reduced = output_tensor
        output_tensor = loss / get_num_microbatches()
        losses_reduced.append(loss_reduced)
    else:
        timers('forward-send').start()
        communicate(
            tensor_send_next=output_tensor,
            tensor_send_prev=None,
            recv_forward=False,
            recv_backward=False)
        timers('forward-send').stop()

    input_tensors.append(input_tensor)
    output_tensors.append(output_tensor)


def backward_step_with_communication(optimizer, model, input_tensors, output_tensors, timers):
    input_tensor = input_tensors.pop(0)
    output_tensor = output_tensors.pop(0)

    if mpu.is_pipeline_last_stage():
        output_tensor_grad = None
    else:
        timers('backward-recv').start()
        _, output_tensor_grad = communicate(
            tensor_send_next=None,
            tensor_send_prev=None,
            recv_forward=False,
            recv_backward=True)
        timers('backward-recv').stop()

    # Backward pass for one step.
    timers('backward-compute').start()
    input_grad_tensor = \
        backward_step(optimizer, model, input_tensor, output_tensor, output_tensor_grad)
    timers('backward-compute').stop()

    if not mpu.is_pipeline_first_stage():
        timers('backward-send').start()
        communicate(
            tensor_send_next=None,
            tensor_send_prev=input_grad_tensor,
            recv_forward=False,
            recv_backward=False)
        timers('backward-send').stop()


def forward_and_backward_steps_with_communication(forward_step_func, data_iterator, model,
                                                  optimizer,
                                                  input_tensor, last_microbatch,
                                                  input_tensors, output_tensors,
                                                  losses_reduced, timers):
    args = get_args()

    # Forward model for one step.
    timers('forward-compute').start()
    output_tensor = forward_step_func(data_iterator, model, input_tensor)
    timers('forward-compute').stop()

    if mpu.is_pipeline_last_stage():
        loss, loss_reduced = output_tensor
        output_tensor = loss / get_num_microbatches()
        output_tensor_grad = None
        losses_reduced.append(loss_reduced)
    else:
        timers('forward-send-backward-recv').start()
        _, output_tensor_grad = communicate(
            tensor_send_next=output_tensor,
            tensor_send_prev=None,
            recv_forward=False,
            recv_backward=True)
        timers('forward-send-backward-recv').stop()

    input_tensors.append(input_tensor)
    output_tensors.append(output_tensor)

    input_tensor = input_tensors.pop(0)
    output_tensor = output_tensors.pop(0)

    # Backward pass for one step.
    timers('backward-compute').start()
    input_grad_tensor = \
        backward_step(optimizer, model, input_tensor, output_tensor, output_tensor_grad)
    timers('backward-compute').stop()

    if not mpu.is_pipeline_first_stage():
        timers('backward-send-forward-recv').start()
        input_tensor, _ = communicate(
            tensor_send_next=None,
            tensor_send_prev=input_grad_tensor,
            recv_forward=(not last_microbatch),
            recv_backward=False)
        timers('backward-send-forward-recv').stop()
    else:
        input_tensor = None

    return input_tensor


def forward_backward_no_pipelining(forward_step_func, data_iterator, model,
                                   optimizer, timers):
    """Run forward and backward passes without inter-stage communication."""
    args = get_args()

    losses_reduced = []
    for i in range(get_num_microbatches()):
        timers('forward-compute').start()
        loss, loss_reduced = forward_step_func(data_iterator, model, input_tensor=None)
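        # Scale the loss by 1/num_microbatches so that the gradients
        # accumulated over all microbatches sum to the gradient of the
        # average loss over the global batch (with, say, 8 microbatches each
        # backward pass contributes its gradient scaled by 1/8).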
        output_tensor = loss / get_num_microbatches()
        losses_reduced.append(loss_reduced)
        timers('forward-compute').stop()

        timers('backward-compute').start()
        output_tensor_grad = None
        backward_step(optimizer, model, input_tensor=None,
                      output_tensor=output_tensor, output_tensor_grad=None)
        timers('backward-compute').stop()

    return losses_reduced


def forward_backward_pipelining(forward_step_func, data_iterator, model,
                                optimizer, timers):
    """Run 1F1B schedule, with communication and warmup + cooldown microbatches as needed."""
    args = get_args()

    # Compute number of warmup microbatches.
    num_microbatches = get_num_microbatches()
    num_warmup_microbatches = \
        (mpu.get_pipeline_model_parallel_world_size() -
         mpu.get_pipeline_model_parallel_rank() - 1)
    num_warmup_microbatches = min(
        num_warmup_microbatches,
        num_microbatches)
    num_microbatches_remaining = \
        num_microbatches - num_warmup_microbatches
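
    # Example (illustrative values): with a pipeline model parallel world size
    # of 4 and 8 microbatches, stages 0..3 run 3, 2, 1 and 0 warmup forward
    # passes respectively, then (8 - warmup) microbatches in the steady 1F1B
    # phase, and finally `warmup` cooldown backward passes.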

    input_tensors = []
    output_tensors = []
    losses_reduced = []

    # Run warmup forward passes.
    for i in range(num_warmup_microbatches):
        forward_step_with_communication(
            forward_step_func, data_iterator, model,
            input_tensors, output_tensors,
            losses_reduced, timers)

    # Before running 1F1B, need to receive first forward tensor.
    # If all microbatches are run in warmup / cooldown phase, then no need to
    # receive this tensor here.
    if num_microbatches_remaining > 0:
        if mpu.is_pipeline_first_stage():
            input_tensor = None
        else:
            timers('forward-recv').start()
            input_tensor, _ = communicate(tensor_send_next=None,
                                          tensor_send_prev=None,
                                          recv_forward=True,
                                          recv_backward=False)
            timers('forward-recv').stop()

    # Run 1F1B.
    for i in range(num_microbatches_remaining):
        last_iteration = (i == (num_microbatches_remaining - 1))
        input_tensor = \
            forward_and_backward_steps_with_communication(forward_step_func, data_iterator, model,
                                                          optimizer,
                                                          input_tensor, last_iteration,
                                                          input_tensors, output_tensors,
                                                          losses_reduced, timers)

    # Run cooldown backward passes.
    for i in range(num_warmup_microbatches):
        backward_step_with_communication(
            optimizer, model, input_tensors, output_tensors, timers)

    return losses_reduced


def train_step(forward_step_func, data_iterator,
               model, optimizer, lr_scheduler):
    """Single training step."""
    args = get_args()
    timers = get_timers()

    # Set grad to zero.
    if args.fp16:
        optimizer.zero_grad(set_grads_to_None=True)
    else:
        optimizer.zero_grad()

    if mpu.get_pipeline_model_parallel_world_size() > 1:
        losses_reduced = forward_backward_pipelining(
            forward_step_func, data_iterator, model, optimizer, timers)
    else:
        losses_reduced = forward_backward_no_pipelining(
            forward_step_func, data_iterator, model, optimizer, timers)

    # All-reduce if needed.
    if args.DDP_impl == 'local':
        timers('backward-params-all-reduce').start()
        model.allreduce_params(reduce_after=False,
                               fp32_allreduce=args.fp32_allreduce)
        timers('backward-params-all-reduce').stop()

    # All-reduce word_embeddings' grad across first and last stages to ensure
    # that word_embeddings parameters stay in sync.
    # This should only run for models that support pipelined model parallelism
    # (BERT and GPT-2).
    timers('backward-embedding-all-reduce').start()
    if (mpu.is_pipeline_first_stage() or mpu.is_pipeline_last_stage()) and \
            mpu.get_pipeline_model_parallel_world_size() > 1:
        unwrapped_model = model
        while isinstance(unwrapped_model, (torchDDP, LocalDDP, FP16_Module)):
            unwrapped_model = unwrapped_model.module

        if unwrapped_model.share_word_embeddings:
            word_embeddings_weight = unwrapped_model.word_embeddings_weight()
            torch.distributed.all_reduce(word_embeddings_weight.grad,
                                         group=mpu.get_embedding_group())
    timers('backward-embedding-all-reduce').stop()

    # Update master gradients.
    timers('backward-master-grad').start()
    if args.fp16:
        optimizer.update_master_grads()
    timers('backward-master-grad').stop()

    # Clipping gradients helps prevent exploding gradients.
    timers('backward-clip-grad').start()
    if args.clip_grad > 0.:
        if not args.fp16:
            named_parameters = model.named_parameters()
            parameters = []
            parameter_names = []
            for parameter_name, parameter in model.named_parameters():
                parameters.append(parameter)
                parameter_names.append(parameter_name)
            mpu.clip_grad_norm(parameters, args.clip_grad,
                               parameter_names=parameter_names)
        else:
            optimizer.clip_master_grads(args.clip_grad)
    timers('backward-clip-grad').stop()

    # Update parameters.
    timers('optimizer').start()
    optimizer.step()
    timers('optimizer').stop()

    # Update learning rate.
    skipped_iter = 0
    if not (args.fp16 and optimizer.overflow):
        increment = get_num_microbatches() * \
                    args.micro_batch_size * \
                    args.data_parallel_size
        lr_scheduler.step(increment=increment)
    else:
        skipped_iter = 1

    if mpu.is_pipeline_last_stage():
        # Average loss across microbatches.
        loss_reduced = {}
        for key in losses_reduced[0]:
            losses_reduced_for_key = [x[key] for x in losses_reduced]
            loss_reduced[key] = sum(losses_reduced_for_key) / len(losses_reduced_for_key)
        return loss_reduced, skipped_iter
    return {}, skipped_iter


def training_log(loss_dict, total_loss_dict, learning_rate, iteration,
                 loss_scale, report_memory_flag, skipped_iter):
    """Log training information such as losses, timing, ...."""
    args = get_args()
    timers = get_timers()
    writer = get_tensorboard_writer()

    # Update losses.
    skipped_iters_key = 'skipped iterations'
    total_loss_dict[skipped_iters_key] = total_loss_dict.get(
        skipped_iters_key, 0) + skipped_iter
    got_nan_key = 'got nan'

    got_nan = False
    for key in loss_dict:
        if not skipped_iter:
            total_loss_dict[key] = total_loss_dict.get(
                key, torch.cuda.FloatTensor([0.0])) + loss_dict[key]
        else:
            value = loss_dict[key].float().sum().item()
            is_nan = value == float('inf') or \
                     value == -float('inf') or \
                     value != value
            got_nan = got_nan or is_nan

    total_loss_dict[got_nan_key] = total_loss_dict.get(
        got_nan_key, 0) + int(got_nan)

    # Logging.
    timers_to_log = []

    def add_to_logging(name):
        if name in timers.timers:
            timers_to_log.append(name)
    add_to_logging('forward-compute')
    add_to_logging('forward-recv')
    add_to_logging('forward-send')
    add_to_logging('forward-send-backward-recv')
    add_to_logging('backward-compute')
    add_to_logging('backward-recv')
    add_to_logging('backward-send')
    add_to_logging('backward-send-forward-recv')
    add_to_logging('backward-master-grad')
    add_to_logging('backward-params-all-reduce')
    add_to_logging('backward-embedding-all-reduce')
    add_to_logging('backward-clip-grad')
    add_to_logging('optimizer')
    add_to_logging('batch generator')

    batch_size = args.micro_batch_size * args.data_parallel_size * \
        get_num_microbatches()
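    # For example (values assumed): micro_batch_size 4, data parallel size 8
    # and 16 microbatches give a global batch size of 4 * 8 * 16 = 512
    # samples per iteration.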

    # Tensorboard values.
    if writer and torch.distributed.get_rank() == 0:
        writer.add_scalar('learning_rate-iterations', learning_rate, iteration)
        writer.add_scalar('learning_rate-samples', learning_rate,
                          args.consumed_train_samples)
        writer.add_scalar('batch_size-iterations', batch_size, iteration)
        writer.add_scalar('batch_size-samples', batch_size,
                          args.consumed_train_samples)
        for key in loss_dict:
            writer.add_scalar(key + '-iterations', loss_dict[key], iteration)
            writer.add_scalar(key + '-samples', loss_dict[key],
                              args.consumed_train_samples)
        if args.fp16:
            writer.add_scalar('loss_scale-iterations', loss_scale, iteration)
            writer.add_scalar('loss_scale-samples', loss_scale,
                              args.consumed_train_samples)
        normalizer = iteration % args.log_interval
        if normalizer == 0:
            normalizer = args.log_interval
        timers.write(timers_to_log, writer, iteration,
                     normalizer=normalizer)

    if iteration % args.log_interval == 0:
        elapsed_time = timers('interval time').elapsed()
        if writer and torch.distributed.get_rank() == 0:
            writer.add_scalar('iteration_time',
                              elapsed_time / args.log_interval, iteration)
        log_string = ' iteration {:8d}/{:8d} |'.format(
            iteration, args.train_iters)
        log_string += ' consumed samples: {:12d} |'.format(
            args.consumed_train_samples)
        log_string += ' elapsed time per iteration (ms): {:.1f} |'.format(
            elapsed_time * 1000.0 / args.log_interval)
        log_string += ' learning rate: {:.3E} |'.format(learning_rate)
        log_string += ' global batch size: {:6d} |'.format(batch_size)
        num_iterations = max(
            1, args.log_interval - total_loss_dict[skipped_iters_key])
        for key in total_loss_dict:
            if key not in [skipped_iters_key, got_nan_key]:
                avg = total_loss_dict[key].item() / float(num_iterations)
                if avg > 0.0:
                    log_string += ' {}: {:.6E} |'.format(key, avg)
                total_loss_dict[key] = torch.cuda.FloatTensor([0.0])
        if args.fp16:
            log_string += ' loss scale: {:.1f} |'.format(loss_scale)
        log_string += ' number of skipped iterations: {:3d} |'.format(
            total_loss_dict[skipped_iters_key])
        log_string += ' number of nan iterations: {:3d} |'.format(
            total_loss_dict[got_nan_key])
        total_loss_dict[skipped_iters_key] = 0
        total_loss_dict[got_nan_key] = 0
        print_rank_last(log_string)
        if report_memory_flag and learning_rate > 0.:
            # Report memory after optimizer state has been initialized.
            report_memory('(after {} iterations)'.format(iteration))
            report_memory_flag = False
        timers.log(timers_to_log, normalizer=args.log_interval)

    return report_memory_flag


def save_checkpoint_and_time(iteration, model, optimizer, lr_scheduler):
    timers = get_timers()
    # Extra barrier is added to make sure
    # all ranks report the max time.
    torch.distributed.barrier()
    timers('save checkpoint').start()
    save_checkpoint(iteration, model, optimizer, lr_scheduler)
    torch.distributed.barrier()
    timers('save checkpoint').stop()
    timers.log(['save checkpoint'])


def train(forward_step_func, model, optimizer, lr_scheduler,
          train_data_iterator, valid_data_iterator):
    """Train the model function."""
    args = get_args()
    timers = get_timers()

    # Turn on training mode which enables dropout.
    model.train()

    # Tracking loss.
    total_loss_dict = {}

    # Iterations.
    iteration = args.iteration

    timers('interval time').start()
    print_datetime('before the start of training step')
    report_memory_flag = True
    while iteration < args.train_iters:
        update_num_microbatches(args.consumed_train_samples)
        loss_dict, skipped_iter = train_step(forward_step_func,
                                             train_data_iterator,
                                             model,
                                             optimizer,
                                             lr_scheduler)
        iteration += 1
        args.consumed_train_samples += mpu.get_data_parallel_world_size() * \
                                       args.micro_batch_size * \
                                       get_num_microbatches()

        # Logging.
        loss_scale = None
        if args.fp16:
            loss_scale = optimizer.loss_scale
        report_memory_flag = training_log(loss_dict, total_loss_dict,
                                          optimizer.param_groups[0]['lr'],
                                          iteration, loss_scale,
                                          report_memory_flag, skipped_iter)

        # Autoresume
        if args.adlr_autoresume and \
           (iteration % args.adlr_autoresume_interval == 0):
            check_adlr_autoresume_termination(iteration, model, optimizer,
                                              lr_scheduler)

        # Checkpointing
        saved_checkpoint = False
        if args.save and args.save_interval and \
           iteration % args.save_interval == 0:
            save_checkpoint_and_time(iteration, model, optimizer,
                                     lr_scheduler)
            saved_checkpoint = True


        # Evaluation
        if args.eval_interval and iteration % args.eval_interval == 0 and \
           args.do_valid:
            prefix = 'iteration {}'.format(iteration)
            evaluate_and_print_results(prefix, forward_step_func,
                                       valid_data_iterator, model,
                                       iteration, False)

        # Exiting based on duration
        if args.exit_duration_in_mins:
            train_time = (time.time() - _TRAIN_START_TIME) / 60.0
            done_cuda = torch.cuda.IntTensor(
                [train_time > args.exit_duration_in_mins])
            torch.distributed.all_reduce(
                done_cuda, op=torch.distributed.ReduceOp.MAX)
            done = done_cuda.item()
            if done:
                if not saved_checkpoint:
                    save_checkpoint_and_time(iteration, model, optimizer,
                                             lr_scheduler)
                print_datetime('exiting program after {} minutes'.format(train_time))                
                sys.exit()

        # Exiting based on iterations        
        if args.exit_interval and iteration % args.exit_interval == 0:
            if not saved_checkpoint:
                save_checkpoint_and_time(iteration, model, optimizer,
                                         lr_scheduler)
            torch.distributed.barrier()
            print_datetime('exiting program at iteration {}'.format(iteration))                
            sys.exit()


    return iteration


def evaluate(forward_step_func, data_iterator, model, verbose=False):
    """Evaluation."""
    args = get_args()

    # Turn on evaluation mode which disables dropout.
    model.eval()

    total_loss_dict = {}

    with torch.no_grad():
        iteration = 0
        while iteration < args.eval_iters:
            iteration += 1
            if verbose and iteration % args.log_interval == 0:
                print_rank_0('Evaluating iter {}/{}'.format(iteration,
                                                            args.eval_iters))

            for _ in range(get_num_microbatches()):
                if not mpu.is_pipeline_first_stage():
                    input_tensor, _ = communicate(
                        tensor_send_next=None,
                        tensor_send_prev=None,
                        recv_forward=True,
                        recv_backward=False)
                else:
                    input_tensor = None

                # Forward evaluation.
                output_tensor = forward_step_func(data_iterator, model, input_tensor)

                if mpu.is_pipeline_last_stage():
                    _, loss_dict = output_tensor
                    # Reduce across processes.
                    for key in loss_dict:
                        total_loss_dict[key] = total_loss_dict.get(key, torch.cuda.FloatTensor([0.0])) + \
                            loss_dict[key]
                else:
                    communicate(
                        tensor_send_next=output_tensor,
                        tensor_send_prev=None,
                        recv_forward=False,
                        recv_backward=False)

            args.consumed_valid_samples += mpu.get_data_parallel_world_size() \
                                           * args.micro_batch_size \
                                           * get_num_microbatches()
    # Move model back to the train mode.
    model.train()

    for key in total_loss_dict:
        total_loss_dict[key] /= args.eval_iters * get_num_microbatches()

    return total_loss_dict

def evaluate_and_print_results(prefix, forward_step_func,
                               data_iterator, model,
                               iteration, verbose=False):
    """Helper function to evaluate and dump results on screen."""
    writer = get_tensorboard_writer()

    total_loss_dict = evaluate(forward_step_func, data_iterator, model, verbose)
    string = ' validation loss at {} | '.format(prefix)
    for key in total_loss_dict:
        string += '{} value: {:.6E} | '.format(key, total_loss_dict[key].item())
        ppl = math.exp(min(20, total_loss_dict[key].item()))
        string += '{} PPL: {:.6E} | '.format(key, ppl)
        if writer and torch.distributed.get_rank() == 0:
            writer.add_scalar('{} value'.format(key),
                              total_loss_dict[key].item(),
                              iteration)
            writer.add_scalar('{} ppl'.format(key), ppl, iteration)

    length = len(string) + 1
    print_rank_last('-' * length)
    print_rank_last(string)
    print_rank_last('-' * length)


def build_train_valid_test_data_iterators(
        build_train_valid_test_datasets_provider):
    """XXX"""
    args = get_args()

    (train_dataloader, valid_dataloader, test_dataloader) = (None, None, None)

    print_rank_0('> building train, validation, and test datasets ...')

    # Backward compatibility, assume fixed batch size.
    if args.iteration > 0 and args.consumed_train_samples == 0:
        assert args.train_samples is None, \
            'only backward compatibility support for iteration-based training'
        args.consumed_train_samples = args.iteration * args.global_batch_size
    if args.iteration > 0 and args.consumed_valid_samples == 0:
        assert args.train_samples is None, \
            'only backward compatibility support for iteration-based training'
        args.consumed_valid_samples = (args.iteration // args.eval_interval) * \
            args.eval_iters * args.global_batch_size

    # Data loader only on rank 0 of each model parallel group.
    if mpu.get_tensor_model_parallel_rank() == 0:

        # Number of train/valid/test samples.
        if args.train_samples:
            train_samples = args.train_samples
        else:
            train_samples = args.train_iters * args.global_batch_size
        eval_iters = (args.train_iters // args.eval_interval + 1) * \
                     args.eval_iters
        test_iters = args.eval_iters
        train_val_test_num_samples = [train_samples,
                                      eval_iters * args.global_batch_size,
                                      test_iters * args.global_batch_size]
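
        # For example (values assumed): train_iters 500000, global_batch_size
        # 256, eval_interval 1000 and eval_iters 10 give minimum sizes of
        # 500000 * 256 = 128,000,000 train samples,
        # (500000 // 1000 + 1) * 10 * 256 = 1,282,560 validation samples and
        # 10 * 256 = 2,560 test samples.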
        print_rank_0(' > datasets target sizes (minimum size):')
        print_rank_0('    train:      {}'.format(train_val_test_num_samples[0]))
        print_rank_0('    validation: {}'.format(train_val_test_num_samples[1]))
        print_rank_0('    test:       {}'.format(train_val_test_num_samples[2]))

        # Build the datasets.
        train_ds, valid_ds, test_ds = build_train_valid_test_datasets_provider(
            train_val_test_num_samples)

        # Build dataloaders.
        train_dataloader = build_pretraining_data_loader(
            train_ds, args.consumed_train_samples)
        valid_dataloader = build_pretraining_data_loader(
            valid_ds, args.consumed_valid_samples)
        test_dataloader = build_pretraining_data_loader(test_ds, 0)

        # Flags to know if we need to do training/validation/testing.
        do_train = train_dataloader is not None and args.train_iters > 0
        do_valid = valid_dataloader is not None and args.eval_iters > 0
        do_test = test_dataloader is not None and args.eval_iters > 0
        # Pack the do_train/do_valid/do_test flags for broadcasting.
        flags = torch.cuda.LongTensor(
            [int(do_train), int(do_valid), int(do_test)])
    else:
        flags = torch.cuda.LongTensor([0, 0, 0])

    # Broadcast the flags to all tensor model parallel ranks.
    torch.distributed.broadcast(flags,
                                mpu.get_tensor_model_parallel_src_rank(),
                                group=mpu.get_tensor_model_parallel_group())
    args.do_train = flags[0].item()
    args.do_valid = flags[1].item()
    args.do_test = flags[2].item()

    # Build iterators.
    if train_dataloader is not None:
        train_data_iterator = iter(train_dataloader)
    else:
        train_data_iterator = None

    if valid_dataloader is not None:
        valid_data_iterator = iter(valid_dataloader)
    else:
        valid_data_iterator = None

    if test_dataloader is not None:
        test_data_iterator = iter(test_dataloader)
    else:
        test_data_iterator = None

    return train_data_iterator, valid_data_iterator, test_data_iterator