training.py 42 KB
Newer Older
1
# coding=utf-8
Mohammad's avatar
Mohammad committed
2
# Copyright (c) 2020, NVIDIA CORPORATION.  All rights reserved.
3
4
5
6
7
8
9
10
11
12
13
14
15
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

Mohammad's avatar
Mohammad committed
16
"""Pretrain utilities."""
17
18
19

from datetime import datetime
import math
Mohammad's avatar
Mohammad committed
20
import sys
21
22
23
24
import time
# The earliest we can measure the start time. This assignment sits right after
# `import time` and before the heavier imports (torch, megatron), so the
# elapsed "time to initialize megatron" reported later includes their import
# cost. pretrain() later overwrites it with the minimum over all ranks.
_TRAIN_START_TIME = time.time()

25
26
27
import torch
from torch.nn.parallel.distributed import DistributedDataParallel as torchDDP

Neel Kant's avatar
Neel Kant committed
28
from megatron import get_args
Mohammad's avatar
Mohammad committed
29
30
from megatron import get_timers
from megatron import get_tensorboard_writer
31
from megatron import get_current_global_batch_size
mohammad's avatar
mohammad committed
32
from megatron import get_num_microbatches
mohammad's avatar
mohammad committed
33
from megatron import is_last_rank
mohammad's avatar
mohammad committed
34
from megatron import update_num_microbatches
35
from megatron import mpu
Neel Kant's avatar
Neel Kant committed
36
from megatron import print_rank_0
37
from megatron import print_rank_last
Mohammad's avatar
Mohammad committed
38
39
from megatron.checkpointing import load_checkpoint
from megatron.checkpointing import save_checkpoint
40
from megatron.model import FP16Module
mohammad's avatar
mohammad committed
41
from megatron.optimizer import get_megatron_optimizer
mohammad's avatar
mohammad committed
42

Mohammad's avatar
Mohammad committed
43
from megatron.initialize import initialize_megatron
44
from megatron.initialize import write_args_to_tensorboard
45
46
from megatron.learning_rates import AnnealingLR
from megatron.model import DistributedDataParallel as LocalDDP
Neel Kant's avatar
Neel Kant committed
47
from megatron.model.realm_model import ICTBertModel
48
from megatron.utils import check_adlr_autoresume_termination
Vijay Korthikanti's avatar
Vijay Korthikanti committed
49
from megatron.data.data_samplers import build_pretraining_data_loader
mohammad's avatar
mohammad committed
50
from megatron.utils import calc_params_l2_norm
51
from megatron.utils import report_memory
52
53


54
55
56
57
58
59
60
def print_datetime(string):
    """Print ``string`` on rank 0, tagged with the current wall-clock time.

    Every rank must call this: it begins with a distributed barrier, so the
    printed timestamp is taken only once all ranks have reached this point.
    """
    torch.distributed.barrier()
    stamp = datetime.now().strftime('%Y-%m-%d %H:%M:%S')
    print_rank_0('[' + string + '] datetime: ' + stamp + ' ')


61
62
63
64
def pretrain(train_valid_test_dataset_provider,
             model_provider,
             forward_step_func,
             extra_args_provider=None,
             args_defaults=None):
    """Main training program.

    This function will run the following steps in order:
        1) initialize Megatron.
        2) setup model, optimizer and lr schedule using the model_provider.
        3) call train_valid_test_dataset_provider to get train/valid/test
           datasets.
        4) train the model using the forward_step_func.

    Arguments:
        train_valid_test_dataset_provider: a function that takes the size of
            train/valid/test dataset and returns `train, valid, test` datasets.
        model_provider: a function that returns a vanilla version of the
            model. By vanilla we mean a simple model on cpu with no fp16 or ddp.
        forward_step_func: a function that takes a `data iterator` and `model`,
            and returns a `loss` scalar with a dictionary with key:values being
            the info we would like to monitor during training, for example
            `lm-loss: value`. We also require that this function add
            `batch generator` to the timers class.
        extra_args_provider: a function that takes a parser and adds arguments
            to it. It is used for programs to add their own arguments.
        args_defaults: a dictionary from argument-name to argument-value,
            forwarded to initialize_megatron to set already-parsed arguments.
            ``None`` (the default) is treated as an empty dictionary.
    """
    # Use a fresh dict per call; a shared mutable default would carry any
    # downstream mutation over into later calls.
    if args_defaults is None:
        args_defaults = {}

    # Initialize and get arguments, timers, and Tensorboard writer.
    initialize_megatron(extra_args_provider=extra_args_provider,
                        args_defaults=args_defaults)

    # Adjust the startup time so it reflects the largest value: take the
    # earliest start time over all ranks, so the elapsed time printed below is
    # measured from the first rank to start.
    # float64 is required here: epoch seconds (~1.7e9) stored in float32 are
    # rounded to a multiple of 128 s, which would make the reported
    # initialization time wrong by up to about a minute.
    global _TRAIN_START_TIME
    start_time_tensor = torch.cuda.DoubleTensor([_TRAIN_START_TIME])
    torch.distributed.all_reduce(start_time_tensor,
                                 op=torch.distributed.ReduceOp.MIN)
    _TRAIN_START_TIME = start_time_tensor.item()
    print_rank_0('time to initialize megatron (seconds): {:.3f}'.format(
        time.time() - _TRAIN_START_TIME))
    print_datetime('after megatron is initialized')

    args = get_args()
    timers = get_timers()

    # Model, optimizer, and learning rate.
    timers('model and optimizer').start()
    model, optimizer, lr_scheduler = setup_model_and_optimizer(model_provider)
    timers('model and optimizer').stop()
    print_datetime('after model, optimizer, and learning rate '
                   'scheduler are built')

    # Data stuff.
    timers('train/valid/test data iterators').start()
    train_data_iterator, valid_data_iterator, test_data_iterator \
        = build_train_valid_test_data_iterators(
            train_valid_test_dataset_provider)
    timers('train/valid/test data iterators').stop()
    print_datetime('after dataloaders are built')

    # Print setup timing.
    print_rank_0('done with setups ...')
    timers.log(['model and optimizer', 'train/valid/test data iterators'])
    print_rank_0('training ...')

    iteration = 0
    if args.do_train and args.train_iters > 0:
        iteration = train(forward_step_func,
                          model, optimizer, lr_scheduler,
                          train_data_iterator, valid_data_iterator)
    print_datetime('after training is done')

    if args.do_valid:
        prefix = 'the end of training for val data'
        evaluate_and_print_results(prefix, forward_step_func,
                                   valid_data_iterator, model,
                                   iteration, False)

    # Only checkpoint if training actually advanced.
    if args.save and iteration != 0:
        save_checkpoint(iteration, model, optimizer, lr_scheduler)

    if args.do_test:
        # Run on test data.
        prefix = 'the end of training for test data'
        evaluate_and_print_results(prefix, forward_step_func,
                                   test_data_iterator, model,
                                   0, True)
151

152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
def update_train_iters(args):
    """Derive ``args.train_iters`` for sample-based training.

    Returns immediately when ``args.train_iters`` is already set
    (iteration-based training). Otherwise computes the iteration count from
    ``args.train_samples``:
      * without batch-size rampup: ``train_samples // global_batch_size``;
      * with rampup: steps the microbatch calculator through the rampup
        phase counting iterations, resets it, then adds the constant-phase
        iterations (a trailing partial batch is dropped).
    """
    # Iteration-based training: nothing to derive.
    if args.train_iters:
        return

    if args.rampup_batch_size is not None:
        steps = 0
        samples_seen = 0
        # Rampup phase: ask the calculator for the batch size at each step.
        while samples_seen <= int(args.rampup_batch_size[2]):
            update_num_microbatches(samples_seen, consistency_check=False)
            samples_seen += get_current_global_batch_size()
            steps += 1
        # Put the calculator back to its starting point.
        update_num_microbatches(0, consistency_check=False)
        # Constant phase; any partial last batch is thrown away.
        steps += (args.train_samples - samples_seen) // args.global_batch_size
        args.train_iters = steps
    else:
        # Constant batch size with sample-based training.
        args.train_iters = args.train_samples // args.global_batch_size

    print_rank_0('setting training iterations to {}'.format(args.train_iters))

181

Mohammad's avatar
Mohammad committed
182
def get_model(model_provider_func):
    """Build the model, move it to the current GPU and wrap it.

    The model from ``model_provider_func`` is built on cpu, given default
    tensor-model-parallel attributes where missing, moved to the current CUDA
    device, optionally wrapped in ``FP16Module`` (``args.fp16``), and finally
    wrapped in a data-parallel module chosen by ``args.DDP_impl``.

    Raises:
        NotImplementedError: if ``args.DDP_impl`` is neither 'torch' nor
            'local'.
    """
    args = get_args()

    # Build model on cpu.
    model = model_provider_func()

    # Only parameters that are already tensor model parallel carry these
    # attributes; give every other parameter the defaults so the optimizer
    # can rely on them being present.
    for weight in model.parameters():
        mpu.set_defaults_if_not_set_tensor_model_parallel_attributes(weight)

    # Report the parameter count once per model-parallel rank pair.
    if mpu.get_data_parallel_rank() == 0:
        tensor_rank = mpu.get_tensor_model_parallel_rank()
        pipeline_rank = mpu.get_pipeline_model_parallel_rank()
        param_count = sum(p.nelement() for p in model.parameters())
        print(' > number of parameters on (tensor, pipeline) '
              'model parallel rank ({}, {}): {}'.format(
                  tensor_rank, pipeline_rank, param_count), flush=True)

    # GPU allocation.
    device = torch.cuda.current_device()
    model.cuda(device)

    # Fp16 conversion.
    if args.fp16:
        model = FP16Module(model)

    if args.DDP_impl == 'torch':
        model = torchDDP(model, device_ids=[device], output_device=device,
                         process_group=mpu.get_data_parallel_group())
    elif args.DDP_impl == 'local':
        model = LocalDDP(model)
    else:
        raise NotImplementedError('Unknown DDP implementation specified: {}. '
                                  'Exiting.'.format(args.DDP_impl))
    return model
222
223


Mohammad's avatar
Mohammad committed
224
def get_learning_rate_scheduler(optimizer):
    """Build the AnnealingLR scheduler for ``optimizer``.

    Warmup and decay lengths are measured in samples: for iteration-based
    training the iteration counts are multiplied by ``args.global_batch_size``;
    for sample-based training the sample counts are used directly (and
    ``args.train_iters`` is derived via update_train_iters). Missing
    ``lr_decay_iters`` / ``lr_decay_samples`` default to the full training
    length, and ``lr_warmup_fraction``, when set, overrides the explicit
    warmup length.

    Raises:
        Exception: if neither ``args.train_iters`` nor ``args.train_samples``
            is set.
    """
    args = get_args()

    if args.train_iters:
        # Iteration-based training.
        if args.lr_decay_iters is None:
            args.lr_decay_iters = args.train_iters
        decay_steps = args.lr_decay_iters * args.global_batch_size
        if args.lr_warmup_fraction is None:
            warmup_steps = args.lr_warmup_iters * args.global_batch_size
        else:
            warmup_steps = args.lr_warmup_fraction * decay_steps
    elif args.train_samples:
        # Sample-based training. Training iters are needed later; the sample
        # count itself is not trimmed for an incomplete last batch.
        update_train_iters(args)
        if args.lr_decay_samples is None:
            args.lr_decay_samples = args.train_samples
        decay_steps = args.lr_decay_samples
        if args.lr_warmup_fraction is None:
            warmup_steps = args.lr_warmup_samples
        else:
            warmup_steps = args.lr_warmup_fraction * decay_steps
    else:
        raise Exception(
            'either train-iters or train-samples should be provided.')

    return AnnealingLR(
        optimizer,
        max_lr=args.lr,
        min_lr=args.min_lr,
        warmup_steps=warmup_steps,
        decay_steps=decay_steps,
        decay_style=args.lr_decay_style,
        use_checkpoint_lr_scheduler=args.use_checkpoint_lr_scheduler,
        override_lr_scheduler=args.override_lr_scheduler)


Mohammad's avatar
Mohammad committed
267
def setup_model_and_optimizer(model_provider_func):
    """Setup model, optimizer and learning-rate scheduler.

    Loads a checkpoint when ``args.load`` is set and records the resumed
    iteration in ``args.iteration`` (0 when starting fresh).

    Returns:
        (model, optimizer, lr_scheduler), where ``model`` is the wrapped
        model returned by get_model.
    """
    args = get_args()

    model = get_model(model_provider_func)

    # The optimizer is built on the module beneath the DDP / fp16 wrappers.
    inner = model
    while isinstance(inner, (torchDDP, LocalDDP, FP16Module)):
        inner = inner.module
    optimizer = get_megatron_optimizer(inner)

    lr_scheduler = get_learning_rate_scheduler(optimizer)

    if args.load is None:
        args.iteration = 0
    else:
        timers = get_timers()
        # Barriers around the load so every rank reports the max time.
        torch.distributed.barrier()
        timers('load checkpoint').start()
        args.iteration = load_checkpoint(model, optimizer, lr_scheduler)
        torch.distributed.barrier()
        timers('load checkpoint').stop()
        timers.log(['load checkpoint'])

    # We only support local DDP with multiple micro-batches.
    if get_num_microbatches() > 1:
        assert args.DDP_impl == 'local'

    # Strip every wrapper exposing `.module` (FP16 and/or DDP).
    bare = model
    while hasattr(bare, 'module'):
        bare = bare.module

    # Fresh ICT runs start from pretrained BERT weights.
    if args.iteration == 0 and hasattr(bare, 'init_state_dict_from_bert'):
        print("Initializing ICT from pretrained BERT model", flush=True)
        bare.init_state_dict_from_bert()

    return model, optimizer, lr_scheduler


310
def communicate(tensor_send_next, tensor_send_prev, recv_forward, recv_backward):
    """Communicate tensors between adjacent pipeline stages.

    Arguments:
        tensor_send_next: tensor to send to the next stage, or None.
        tensor_send_prev: tensor to send to the previous stage, or None.
        recv_forward: if True, receive a tensor from the previous stage.
        recv_backward: if True, receive a tensor from the next stage.

    Returns:
        (tensor_recv_prev, tensor_recv_next): the received tensors, each None
        when the corresponding receive was not requested. Receive buffers have
        shape (seq_length, micro_batch_size, hidden_size) and dtype
        ``args.params_dtype`` (``torch.float`` with fp32 residual connections).
    """
    args = get_args()

    # Create placeholder tensors for receive in forward and backward directions
    # if needed.
    tensor_shape = (args.seq_length, args.micro_batch_size, args.hidden_size)
    dtype = args.params_dtype
    if args.fp32_residual_connection:
        dtype = torch.float

    tensor_recv_prev = None
    tensor_recv_next = None
    if recv_forward:
        tensor_recv_prev = torch.empty(tensor_shape,
                                       requires_grad=True,
                                       device=torch.cuda.current_device(),
                                       dtype=dtype)
    if recv_backward:
        tensor_recv_next = torch.empty(tensor_shape,
                                       requires_grad=True,
                                       device=torch.cuda.current_device(),
                                       dtype=dtype)

    # Queue the point-to-point ops in a fixed order (send-prev, recv-prev,
    # send-next, recv-next) and launch them as one batch.
    ops = []
    if tensor_send_prev is not None:
        ops.append(torch.distributed.P2POp(
            torch.distributed.isend, tensor_send_prev,
            mpu.get_pipeline_model_parallel_prev_rank()))
    if tensor_recv_prev is not None:
        ops.append(torch.distributed.P2POp(
            torch.distributed.irecv, tensor_recv_prev,
            mpu.get_pipeline_model_parallel_prev_rank()))
    if tensor_send_next is not None:
        ops.append(torch.distributed.P2POp(
            torch.distributed.isend, tensor_send_next,
            mpu.get_pipeline_model_parallel_next_rank()))
    if tensor_recv_next is not None:
        ops.append(torch.distributed.P2POp(
            torch.distributed.irecv, tensor_recv_next,
            mpu.get_pipeline_model_parallel_next_rank()))

    # batch_isend_irecv does not handle an empty op list (it inspects the
    # first op), so a call that neither sends nor receives is a no-op.
    if ops:
        reqs = torch.distributed.batch_isend_irecv(ops)
        for req in reqs:
            req.wait()

    return tensor_recv_prev, tensor_recv_next


def backward_step(optimizer, model, input_tensor, output_tensor, output_tensor_grad):
    """Backward step for one microbatch on this stage.

    Arguments:
        optimizer: used only when ``output_tensor_grad`` is None (last
            pipeline stage or no pipelining), to scale the loss through
            ``optimizer.scale_loss`` before backpropagating.
        model: unused; kept for interface compatibility with callers.
        input_tensor: this stage's input tensor, or None (first stage / no
            pipelining). Its gradient is returned.
        output_tensor: the loss (when ``output_tensor_grad`` is None) or this
            stage's output tensor.
        output_tensor_grad: gradient w.r.t. ``output_tensor`` received from
            the next stage, or None.

    Returns:
        The gradient of ``input_tensor``, or None when ``input_tensor`` is None.
    """
    # retain_grad() makes autograd populate `.grad` on input_tensor even when
    # it is a non-leaf tensor.
    if input_tensor is not None:
        input_tensor.retain_grad()

    # Backward pass. Only a raw loss (no upstream gradient) is scaled.
    if output_tensor_grad is None:
        output_tensor = optimizer.scale_loss(output_tensor)
    torch.autograd.backward(output_tensor, grad_tensors=output_tensor_grad)

    # Collect the grad of the input_tensor.
    if input_tensor is None:
        return None
    return input_tensor.grad


380
381
382
def forward_step_with_communication(forward_step_func, data_iterator, model,
                                    input_tensors, output_tensors,
                                    losses_reduced, timers):
    """Run one forward microbatch on this pipeline stage, with communication.

    Non-first stages first receive their input tensor from the previous
    stage; the first stage passes ``None`` as the input tensor.

    On the last stage, ``forward_step_func`` returns ``(loss, loss_reduced)``.
    The loss is divided by the number of microbatches, and ``loss_reduced`` is
    appended to ``losses_reduced``. Every other stage sends its output tensor
    to the next stage.

    The input and output tensors are appended to ``input_tensors`` /
    ``output_tensors``. backward_step_with_communication later pops them from
    the front, in FIFO order.
    """
    if not mpu.is_pipeline_first_stage():
        timers('forward-recv').start()
        input_tensor, _ = communicate(
            tensor_send_next=None,
            tensor_send_prev=None,
            recv_forward=True,
            recv_backward=False)
        timers('forward-recv').stop()
    else:
        input_tensor = None

    # Forward model for one step.
    timers('forward-compute').start()
    output_tensor = forward_step_func(data_iterator, model, input_tensor)
    timers('forward-compute').stop()

    if mpu.is_pipeline_last_stage():
        loss, loss_reduced = output_tensor
        output_tensor = loss / get_num_microbatches()
        losses_reduced.append(loss_reduced)
    else:
        timers('forward-send').start()
        communicate(
            tensor_send_next=output_tensor,
            tensor_send_prev=None,
            recv_forward=False,
            recv_backward=False)
        timers('forward-send').stop()

    input_tensors.append(input_tensor)
    output_tensors.append(output_tensor)


def backward_step_with_communication(optimizer, model, input_tensors, output_tensors, timers):
    """Run backward for the oldest pending microbatch, with communication.

    Pops the front entries of ``input_tensors`` / ``output_tensors``. Receives
    the output gradient from the next stage unless this is the last stage,
    runs backward_step, and sends the resulting input gradient to the
    previous stage unless this is the first stage.
    """
    stage_input = input_tensors.pop(0)
    stage_output = output_tensors.pop(0)

    if not mpu.is_pipeline_last_stage():
        timers('backward-recv').start()
        _, upstream_grad = communicate(
            tensor_send_next=None,
            tensor_send_prev=None,
            recv_forward=False,
            recv_backward=True)
        timers('backward-recv').stop()
    else:
        upstream_grad = None

    # Backward pass for one step.
    timers('backward-compute').start()
    grad_for_prev_stage = backward_step(
        optimizer, model, stage_input, stage_output, upstream_grad)
    timers('backward-compute').stop()

    if mpu.is_pipeline_first_stage():
        return

    timers('backward-send').start()
    communicate(
        tensor_send_next=None,
        tensor_send_prev=grad_for_prev_stage,
        recv_forward=False,
        recv_backward=False)
    timers('backward-send').stop()
447
448


449
450
451
452
453
def forward_and_backward_steps_with_communication(forward_step_func, data_iterator, model,
                                                  optimizer,
                                                  input_tensor, last_microbatch,
                                                  input_tensors, output_tensors,
                                                  losses_reduced, timers):
    """One steady-state 1F1B step: forward one microbatch, backward the oldest.

    Forward:
      * Runs ``forward_step_func`` on ``input_tensor``.
      * On the last stage, the loss is divided by the number of microbatches
        and ``loss_reduced`` is recorded.
      * On other stages, the output is sent to the next stage in the same
        exchange that receives the output gradient.

    Backward:
      * The new tensors are queued, and the oldest queued pair is popped.
      * Backward runs on that pair.
      * Non-first stages send the input gradient to the previous stage. In the
        same exchange they receive the next input tensor, unless
        ``last_microbatch`` is True.

    Returns:
        The input tensor for the next forward step, or None on the first stage
        or when ``last_microbatch`` is True.
    """
    # Forward model for one step.
    timers('forward-compute').start()
    output_tensor = forward_step_func(data_iterator, model, input_tensor)
    timers('forward-compute').stop()

    if mpu.is_pipeline_last_stage():
        loss, loss_reduced = output_tensor
        output_tensor = loss / get_num_microbatches()
        output_tensor_grad = None
        losses_reduced.append(loss_reduced)
    else:
        timers('forward-send-backward-recv').start()
        _, output_tensor_grad = communicate(
            tensor_send_next=output_tensor,
            tensor_send_prev=None,
            recv_forward=False,
            recv_backward=True)
        timers('forward-send-backward-recv').stop()

    input_tensors.append(input_tensor)
    output_tensors.append(output_tensor)

    # The backward pass always works on the oldest pending microbatch.
    input_tensor = input_tensors.pop(0)
    output_tensor = output_tensors.pop(0)

    # Backward pass for one step.
    timers('backward-compute').start()
    input_grad_tensor = \
        backward_step(optimizer, model, input_tensor, output_tensor, output_tensor_grad)
    timers('backward-compute').stop()

    if not mpu.is_pipeline_first_stage():
        timers('backward-send-forward-recv').start()
        input_tensor, _ = communicate(
            tensor_send_next=None,
            tensor_send_prev=input_grad_tensor,
            recv_forward=(not last_microbatch),
            recv_backward=False)
        timers('backward-send-forward-recv').stop()
    else:
        input_tensor = None

    return input_tensor


501
502
503
def forward_backward_no_pipelining(forward_step_func, data_iterator, model,
                                   optimizer, timers):
    """Run forward and backward passes without inter-stage communication.

    For each microbatch, runs ``forward_step_func`` (with no input tensor),
    divides the loss by the number of microbatches, and immediately
    backpropagates it via backward_step.

    Returns:
        List with one ``loss_reduced`` entry per microbatch.
    """
    # Query once, as the pipelined schedule does.
    num_microbatches = get_num_microbatches()

    losses_reduced = []
    for _ in range(num_microbatches):
        timers('forward-compute').start()
        loss, loss_reduced = forward_step_func(data_iterator, model, input_tensor=None)
        output_tensor = loss / num_microbatches
        losses_reduced.append(loss_reduced)
        timers('forward-compute').stop()

        timers('backward-compute').start()
        backward_step(optimizer, model, input_tensor=None,
                      output_tensor=output_tensor, output_tensor_grad=None)
        timers('backward-compute').stop()

    return losses_reduced
521

522
523
524
525
526
527
528

def forward_backward_pipelining(forward_step_func, data_iterator, model,
                                optimizer, timers):
    """Run 1F1B schedule, with communication and warmup + cooldown microbatches as needed.

    Stage ``r`` of a ``p``-stage pipeline runs in three phases:
      * Warmup: ``min(p - r - 1, num_microbatches)`` forward-only microbatches.
      * 1F1B: one forward plus one backward for each remaining microbatch.
      * Cooldown: backward-only steps that drain the warmup microbatches.

    Returns:
        List of ``loss_reduced`` entries. Entries are appended only on the
        last pipeline stage.
    """
    # Compute number of warmup microbatches.
    num_microbatches = get_num_microbatches()
    num_warmup_microbatches = \
        (mpu.get_pipeline_model_parallel_world_size() -
         mpu.get_pipeline_model_parallel_rank() - 1)
    num_warmup_microbatches = min(
        num_warmup_microbatches,
        num_microbatches)
    num_microbatches_remaining = \
        num_microbatches - num_warmup_microbatches

    input_tensors = []
    output_tensors = []
    losses_reduced = []

    # Run warmup forward passes.
    for _ in range(num_warmup_microbatches):
        forward_step_with_communication(
            forward_step_func, data_iterator, model,
            input_tensors, output_tensors,
            losses_reduced, timers)

    # Before running 1F1B, need to receive first forward tensor.
    # If all microbatches are run in warmup / cooldown phase, then no need to
    # receive this tensor here.
    if num_microbatches_remaining > 0:
        if mpu.is_pipeline_first_stage():
            input_tensor = None
        else:
            timers('forward-recv').start()
            input_tensor, _ = communicate(tensor_send_next=None,
                                          tensor_send_prev=None,
                                          recv_forward=True,
                                          recv_backward=False)
            timers('forward-recv').stop()

    # Run 1F1B.
    for i in range(num_microbatches_remaining):
        last_iteration = (i == (num_microbatches_remaining - 1))
        input_tensor = \
            forward_and_backward_steps_with_communication(forward_step_func, data_iterator, model,
                                                          optimizer,
                                                          input_tensor, last_iteration,
                                                          input_tensors, output_tensors,
                                                          losses_reduced, timers)

    # Run cooldown backward passes.
    for _ in range(num_warmup_microbatches):
        backward_step_with_communication(
            optimizer, model, input_tensors, output_tensors, timers)

    return losses_reduced


def train_step(forward_step_func, data_iterator,
               model, optimizer, lr_scheduler):
    """Run a single training iteration.

    Steps, in order:
      * Zero the gradients.
      * Run forward/backward over all microbatches, pipelined when the
        pipeline world size is > 1.
      * All-reduce gradients for local DDP.
      * Synchronize shared word-embedding gradients between the first and
        last stages.
      * Step the optimizer, and step the LR scheduler only if the update
        succeeded.

    Returns:
        (loss_dict, skipped_iter, grad_norm)
          * ``loss_dict``: each key's mean over microbatches on the last
            pipeline stage, and empty on other stages.
          * ``skipped_iter``: 1 if the optimizer reported an unsuccessful
            update, else 0.
    """
    args = get_args()
    timers = get_timers()

    # Set grad to zero.
    optimizer.zero_grad()

    if mpu.get_pipeline_model_parallel_world_size() > 1:
        schedule = forward_backward_pipelining
    else:
        schedule = forward_backward_no_pipelining
    losses_reduced = schedule(
        forward_step_func, data_iterator, model, optimizer, timers)

    # The local DDP wrapper needs an explicit gradient all-reduce.
    if args.DDP_impl == 'local':
        timers('backward-params-all-reduce').start()
        model.allreduce_params(reduce_after=False,
                               fp32_allreduce=args.fp32_allreduce)
        timers('backward-params-all-reduce').stop()

    # All-reduce word_embeddings' grad across first and last stages to ensure
    # that word_embeddings parameters stay in sync.
    # This should only run for models that support pipelined model parallelism
    # (BERT and GPT-2).
    timers('backward-embedding-all-reduce').start()
    on_boundary_stage = (mpu.is_pipeline_first_stage() or
                         mpu.is_pipeline_last_stage())
    if on_boundary_stage and mpu.get_pipeline_model_parallel_world_size() > 1:
        core = model
        while isinstance(core, (torchDDP, LocalDDP, FP16Module)):
            core = core.module
        if core.share_word_embeddings:
            embedding_weight = core.word_embeddings_weight()
            torch.distributed.all_reduce(embedding_weight.grad,
                                         group=mpu.get_embedding_group())
    timers('backward-embedding-all-reduce').stop()

    # Update parameters.
    timers('optimizer').start()
    update_successful, grad_norm = optimizer.step()
    timers('optimizer').stop()

    # The scheduler advances by the number of samples consumed this step.
    if update_successful:
        lr_scheduler.step(increment=get_num_microbatches() *
                          args.micro_batch_size *
                          args.data_parallel_size)
        skipped_iter = 0
    else:
        skipped_iter = 1

    if not mpu.is_pipeline_last_stage():
        return {}, skipped_iter, grad_norm

    # Average each reported loss across microbatches.
    loss_reduced = {
        key: sum(entry[key] for entry in losses_reduced) / len(losses_reduced)
        for key in losses_reduced[0]
    }
    return loss_reduced, skipped_iter, grad_norm
645
646


Mohammad Shoeybi's avatar
Mohammad Shoeybi committed
647
def training_log(loss_dict, total_loss_dict, learning_rate, iteration,
                 loss_scale, report_memory_flag, skipped_iter,
                 grad_norm, params_norm):
    """Log training information such as losses, timing, ....

    Every call does two things:

    * It folds the current iteration into the running counters kept in
      ``total_loss_dict``.
    * It writes per-iteration scalars to TensorBoard.

    Every ``args.log_interval`` iterations it also prints a summary line and
    resets the running sums and counters. Losses in the summary are averaged
    over the advanced (non-skipped) iterations since the last summary.

    Args:
        loss_dict: losses of this iteration keyed by name. ``train_step``
            returns an empty dict on ranks that are not the last pipeline
            stage, so nothing is accumulated there.
        total_loss_dict: running state, mutated in place. Holds summed loss
            tensors plus the 'advanced iterations', 'skipped iterations' and
            'nan iterations' counters.
        learning_rate: current learning rate.
        iteration: number of the iteration that just finished.
        loss_scale: current loss scale (a Python float).
        report_memory_flag: when true, memory usage is reported at the next
            summary whose learning rate is positive.
        skipped_iter: 1 if the optimizer step was skipped, else 0.
        grad_norm: gradient norm returned by the optimizer step, or None.
        params_norm: L2 norm of the parameters, or None if not computed.

    Returns:
        The updated ``report_memory_flag`` (False once memory was reported).
    """
    args = get_args()
    timers = get_timers()
    writer = get_tensorboard_writer()

    # Advanced, skipped, and Nan iterations.
    advanced_iters_key = 'advanced iterations'
    skipped_iters_key = 'skipped iterations'
    nan_iters_key = 'nan iterations'
    # Advanced iterations: count only steps that were actually applied, but
    # make sure the key exists so later divisions/reads never KeyError.
    total_loss_dict[advanced_iters_key] = total_loss_dict.get(
        advanced_iters_key, 0) + (0 if skipped_iter else 1)
    # Skipped iterations.
    total_loss_dict[skipped_iters_key] = total_loss_dict.get(
        skipped_iters_key, 0) + skipped_iter
    # Update losses and set nan iterations.
    got_nan = False
    for key in loss_dict:
        if not skipped_iter:
            total_loss_dict[key] = total_loss_dict.get(
                key, torch.cuda.FloatTensor([0.0])) + loss_dict[key]
        else:
            # Skipped steps are not accumulated; only record whether any of
            # their losses is non-finite (inf, -inf or NaN).
            value = loss_dict[key].float().sum().item()
            got_nan = got_nan or not math.isfinite(value)
    total_loss_dict[nan_iters_key] = total_loss_dict.get(
        nan_iters_key, 0) + int(got_nan)

    # Logging: only timers that were actually created are reported.
    candidate_timers = (
        'forward-compute',
        'forward-recv',
        'forward-send',
        'forward-send-backward-recv',
        'backward-compute',
        'backward-recv',
        'backward-send',
        'backward-send-forward-recv',
        'backward-params-all-reduce',
        'backward-embedding-all-reduce',
        'optimizer-copy-to-main-grad',
        'optimizer-unscale-and-check-inf',
        'optimizer-clip-main-grad',
        'optimizer-copy-main-to-model-params',
        'optimizer',
        'batch-generator',
    )
    timers_to_log = [name for name in candidate_timers
                     if name in timers.timers]

    # Calculate batch size.
    batch_size = args.micro_batch_size * args.data_parallel_size * \
        get_num_microbatches()

    # At least 1 here, since this call incremented one of the two counters.
    total_iterations = total_loss_dict[advanced_iters_key] + \
                       total_loss_dict[skipped_iters_key]

    # Tensorboard values.
    if writer and is_last_rank():
        writer.add_scalar('learning-rate', learning_rate, iteration)
        writer.add_scalar('learning-rate vs samples', learning_rate,
                          args.consumed_train_samples)
        writer.add_scalar('batch-size', batch_size, iteration)
        writer.add_scalar('batch-size vs samples', batch_size,
                          args.consumed_train_samples)
        for key in loss_dict:
            writer.add_scalar(key, loss_dict[key], iteration)
            writer.add_scalar(key + ' vs samples', loss_dict[key],
                              args.consumed_train_samples)
        writer.add_scalar('loss-scale', loss_scale, iteration)
        writer.add_scalar('loss-scale vs samples', loss_scale,
                          args.consumed_train_samples)
        if grad_norm is not None:
            writer.add_scalar('grad-norm', grad_norm, iteration)
            writer.add_scalar('grad-norm vs samples', grad_norm,
                              args.consumed_train_samples)
        if params_norm is not None:
            writer.add_scalar('params-norm', params_norm, iteration)
            writer.add_scalar('params-norm vs samples', params_norm,
                              args.consumed_train_samples)
        timers.write(timers_to_log, writer, iteration,
                     normalizer=total_iterations)

    if iteration % args.log_interval == 0:
        elapsed_time = timers('interval time').elapsed()
        elapsed_time_per_iteration = elapsed_time / total_iterations
        # Gate on the last rank like every other TensorBoard write in this
        # function. Previously this used rank 0, which differs from the last
        # rank whenever pipeline parallelism is on.
        if writer and is_last_rank():
            writer.add_scalar('iteration-time',
                              elapsed_time_per_iteration, iteration)
        log_string = ' iteration {:8d}/{:8d} |'.format(
            iteration, args.train_iters)
        log_string += ' consumed samples: {:12d} |'.format(
            args.consumed_train_samples)
        log_string += ' elapsed time per iteration (ms): {:.1f} |'.format(
            elapsed_time_per_iteration * 1000.0)
        log_string += ' learning rate: {:.3E} |'.format(learning_rate)
        log_string += ' global batch size: {:5d} |'.format(batch_size)
        for key in total_loss_dict:
            if key not in [advanced_iters_key, skipped_iters_key,
                           nan_iters_key]:
                avg = total_loss_dict[key].item() / \
                      float(max(1, total_loss_dict[advanced_iters_key]))
                if avg > 0.0:
                    log_string += ' {}: {:.6E} |'.format(key, avg)
                # Reset the running sum for the next logging window.
                total_loss_dict[key] = torch.cuda.FloatTensor([0.0])
        log_string += ' loss scale: {:.1f} |'.format(loss_scale)
        if grad_norm is not None:
            log_string += ' grad norm: {:.3f} |'.format(grad_norm)
        if params_norm is not None:
            log_string += ' params norm: {:.3f} |'.format(params_norm)
        log_string += ' number of skipped iterations: {:3d} |'.format(
            total_loss_dict[skipped_iters_key])
        log_string += ' number of nan iterations: {:3d} |'.format(
            total_loss_dict[nan_iters_key])
        total_loss_dict[advanced_iters_key] = 0
        total_loss_dict[skipped_iters_key] = 0
        total_loss_dict[nan_iters_key] = 0
        print_rank_last(log_string)
        if report_memory_flag and learning_rate > 0.:
            # Report memory after optimizer state has been initialized.
            report_memory('(after {} iterations)'.format(iteration))
            report_memory_flag = False
        timers.log(timers_to_log, normalizer=args.log_interval)

    return report_memory_flag


784
785
786
787
788
789
790
791
792
793
794
795
def save_checkpoint_and_time(iteration, model, optimizer, lr_scheduler):
    """Save a checkpoint and log the wall time the save took.

    The save is bracketed by distributed barriers so that the logged time
    covers the slowest rank, i.e. all ranks report the max time.
    """
    timer_name = 'save checkpoint'
    timers = get_timers()

    # Line every rank up before the clock starts.
    torch.distributed.barrier()
    timers(timer_name).start()

    save_checkpoint(iteration, model, optimizer, lr_scheduler)

    # Wait for the slowest rank to finish before stopping the clock.
    torch.distributed.barrier()
    timers(timer_name).stop()

    timers.log([timer_name])


796
def train(forward_step_func, model, optimizer, lr_scheduler,
          train_data_iterator, valid_data_iterator):
    """Train the model function.

    Runs ``train_step`` repeatedly until ``args.train_iters`` is reached.
    After every step it does the following:

    * advances ``args.consumed_train_samples``;
    * logs progress through ``training_log``;
    * handles, as configured on ``args``, ADLR autoresume, periodic
      validation, periodic checkpointing, and the time-based and
      iteration-based exit conditions.

    Both exit paths save a checkpoint, unless one was just written, and then
    call ``sys.exit()``.

    Returns:
        The iteration number reached when the loop ends normally.
    """
    args = get_args()
    timers = get_timers()

    # Write args to tensorboard
    write_args_to_tensorboard()

    # Training mode enables dropout.
    model.train()

    # Running loss sums / iteration counters maintained by training_log.
    total_loss_dict = {}
    iteration = args.iteration

    timers('interval time').start()
    print_datetime('before the start of training step')
    report_memory_flag = True

    while iteration < args.train_iters:
        update_num_microbatches(args.consumed_train_samples)
        loss_dict, skipped_iter, grad_norm = train_step(
            forward_step_func, train_data_iterator, model, optimizer,
            lr_scheduler)
        iteration += 1
        samples_this_iteration = mpu.get_data_parallel_world_size() * \
            args.micro_batch_size * get_num_microbatches()
        args.consumed_train_samples += samples_this_iteration

        # Logging.
        loss_scale = optimizer.get_loss_scale().item()
        params_norm = calc_params_l2_norm(model) \
            if args.log_params_norm else None
        report_memory_flag = training_log(
            loss_dict, total_loss_dict, optimizer.param_groups[0]['lr'],
            iteration, loss_scale, report_memory_flag, skipped_iter,
            grad_norm, params_norm)

        # Autoresume (the interval is only consulted when autoresume is on).
        autoresume_due = args.adlr_autoresume and \
            iteration % args.adlr_autoresume_interval == 0
        if autoresume_due:
            check_adlr_autoresume_termination(iteration, model, optimizer,
                                              lr_scheduler)

        # Evaluation
        validation_due = args.eval_interval and \
            iteration % args.eval_interval == 0 and args.do_valid
        if validation_due:
            evaluate_and_print_results('iteration {}'.format(iteration),
                                       forward_step_func,
                                       valid_data_iterator, model,
                                       iteration, False)

        # Checkpointing; remembered so an exit below does not save twice.
        checkpoint_due = args.save and args.save_interval and \
            iteration % args.save_interval == 0
        saved_checkpoint = False
        if checkpoint_due:
            save_checkpoint_and_time(iteration, model, optimizer,
                                     lr_scheduler)
            saved_checkpoint = True

        # Exiting based on duration. MAX all-reduce so every rank exits
        # together as soon as any rank has run out of time.
        if args.exit_duration_in_mins:
            train_time = (time.time() - _TRAIN_START_TIME) / 60.0
            time_is_up = torch.cuda.IntTensor(
                [train_time > args.exit_duration_in_mins])
            torch.distributed.all_reduce(
                time_is_up, op=torch.distributed.ReduceOp.MAX)
            if time_is_up.item():
                if not saved_checkpoint:
                    save_checkpoint_and_time(iteration, model, optimizer,
                                             lr_scheduler)
                print_datetime(
                    'exiting program after {} minutes'.format(train_time))
                sys.exit()

        # Exiting based on iterations
        if args.exit_interval and iteration % args.exit_interval == 0:
            if not saved_checkpoint:
                save_checkpoint_and_time(iteration, model, optimizer,
                                         lr_scheduler)
            torch.distributed.barrier()
            print_datetime('exiting program at iteration {}'.format(iteration))
            sys.exit()

    return iteration
888
889


Mohammad's avatar
Mohammad committed
890
def evaluate(forward_step_func, data_iterator, model, verbose=False):
    """Evaluation.

    Runs ``args.eval_iters`` evaluation iterations, each consisting of
    ``get_num_microbatches()`` forward passes. Dropout is disabled and
    autograd is turned off. Activations are passed between pipeline stages
    with ``communicate``.

    Losses are summed on the last pipeline stage and divided by the total
    number of forward passes at the end. ``args.consumed_valid_samples`` is
    advanced after every evaluation iteration.

    Returns:
        dict mapping loss name to its mean over all forward passes. It is
        empty on ranks that are not the last pipeline stage.
    """
    args = get_args()

    # Evaluation mode disables dropout.
    model.eval()

    total_loss_dict = {}

    with torch.no_grad():
        for eval_step in range(1, args.eval_iters + 1):
            if verbose and eval_step % args.log_interval == 0:
                print_rank_0('Evaluating iter {}/{}'.format(eval_step,
                                                            args.eval_iters))

            for _ in range(get_num_microbatches()):
                # Only the first stage reads input from the data iterator;
                # later stages receive activations from the previous stage.
                if mpu.is_pipeline_first_stage():
                    input_tensor = None
                else:
                    input_tensor, _ = communicate(
                        tensor_send_next=None,
                        tensor_send_prev=None,
                        recv_forward=True,
                        recv_backward=False)

                # Forward evaluation.
                output_tensor = forward_step_func(data_iterator, model,
                                                  input_tensor)

                if not mpu.is_pipeline_last_stage():
                    # Hand activations to the next stage.
                    communicate(
                        tensor_send_next=output_tensor,
                        tensor_send_prev=None,
                        recv_forward=False,
                        recv_backward=False)
                    continue

                # Last stage: accumulate this microbatch's losses locally
                # (no cross-process communication happens here).
                _, loss_dict = output_tensor
                for key in loss_dict:
                    running = total_loss_dict.get(
                        key, torch.cuda.FloatTensor([0.0]))
                    total_loss_dict[key] = running + loss_dict[key]

            args.consumed_valid_samples += mpu.get_data_parallel_world_size() \
                                           * args.micro_batch_size \
                                           * get_num_microbatches()

    # Move model back to the train mode.
    model.train()

    for key in total_loss_dict:
        total_loss_dict[key] /= args.eval_iters * get_num_microbatches()

    return total_loss_dict

def evaluate_and_print_results(prefix, forward_step_func,
                               data_iterator, model,
                               iteration, verbose=False):
    """Helper function to evaluate and dump results on screen.

    Runs ``evaluate`` and prints one framed summary line with each loss and
    its perplexity. Perplexity is ``exp(loss)``, with the loss capped at 20
    to avoid overflow. On the last rank, when a TensorBoard writer exists,
    both values are also logged against the iteration and against
    ``args.consumed_train_samples``.
    """
    args = get_args()
    writer = get_tensorboard_writer()

    total_loss_dict = evaluate(forward_step_func, data_iterator, model,
                               verbose)

    pieces = [' validation loss at {} | '.format(prefix)]
    for key, loss in total_loss_dict.items():
        loss_value = loss.item()
        ppl = math.exp(min(20, loss_value))
        pieces.append('{} value: {:.6E} | '.format(key, loss_value))
        pieces.append('{} PPL: {:.6E} | '.format(key, ppl))
        if not (writer and is_last_rank()):
            continue
        writer.add_scalar('{} value-validation'.format(key),
                          loss_value, iteration)
        writer.add_scalar('{} ppl-validation'.format(key), ppl, iteration)
        writer.add_scalar('{} value-validation vs samples'.format(key),
                          loss_value, args.consumed_train_samples)
        writer.add_scalar('{} ppl-validation vs samples'.format(key), ppl,
                          args.consumed_train_samples)
    string = ''.join(pieces)

    rule = '-' * (len(string) + 1)
    print_rank_last(rule)
    print_rank_last(string)
    print_rank_last(rule)
972
973


Vijay Korthikanti's avatar
Vijay Korthikanti committed
974
def cyclic_iter(iter):
    """Yield the items of ``iter`` endlessly, re-iterating it after each pass.

    ``iter`` should be re-iterable, such as a DataLoader or a list, because
    every pass starts a fresh ``for`` loop over it.

    If a full pass produces nothing, the generator stops instead of looping
    forever without yielding. That happens with an empty iterable, or with a
    one-shot iterator once it is exhausted. Consumers then get
    ``StopIteration`` rather than a silent hang.
    """
    # NOTE: the parameter name shadows the builtin ``iter``; it is kept so
    # keyword callers are unaffected. The builtin is not needed here.
    while True:
        produced_any = False
        for x in iter:
            produced_any = True
            yield x
        if not produced_any:
            return

979
980
981
def build_train_valid_test_data_iterators(
        build_train_valid_test_datasets_provider):
    """Build the train, validation and test data iterators.

    Datasets and dataloaders are built only on tensor-model-parallel rank 0.
    The resulting do_train/do_valid/do_test flags are then broadcast to the
    rest of the tensor-model-parallel group and stored on ``args``.

    Args:
        build_train_valid_test_datasets_provider: callable that takes the
            list ``[num_train, num_valid, num_test]`` of target sample counts
            and returns ``(train_ds, valid_ds, test_ds)``.

    Returns:
        ``(train_data_iterator, valid_data_iterator, test_data_iterator)``.
        An entry is None when the corresponding dataloader is None, which is
        always the case on ranks other than tensor-model-parallel rank 0.
        With ``args.dataloader_type == 'cyclic'`` the iterators restart the
        dataloader whenever it is exhausted.
    """
    args = get_args()

    (train_dataloader, valid_dataloader, test_dataloader) = (None, None, None)

    print_rank_0('> building train, validation, and test datasets ...')

    # Backward compatibility: checkpoints from iteration-based runs have no
    # consumed-sample counters, so reconstruct them assuming a fixed batch
    # size.
    if args.iteration > 0 and args.consumed_train_samples == 0:
        assert args.train_samples is None, \
            'only backward compatiblity support for iteration-based training'
        args.consumed_train_samples = args.iteration * args.global_batch_size
    if args.iteration > 0 and args.consumed_valid_samples == 0:
        assert args.train_samples is None, \
            'only backward compatiblity support for iteration-based training'
        args.consumed_valid_samples = (args.iteration // args.eval_interval) * \
            args.eval_iters * args.global_batch_size

    # Data loader only on rank 0 of each model parallel group.
    if mpu.get_tensor_model_parallel_rank() == 0:

        # Number of train/valid/test samples.
        train_samples = args.train_samples or \
            args.train_iters * args.global_batch_size
        # One validation run per eval interval, plus one more at the end.
        eval_iters = (args.train_iters // args.eval_interval + 1) * \
                     args.eval_iters
        test_iters = args.eval_iters
        train_val_test_num_samples = [train_samples,
                                      eval_iters * args.global_batch_size,
                                      test_iters * args.global_batch_size]
        print_rank_0(' > datasets target sizes (minimum size):')
        print_rank_0('    train:      {}'.format(train_val_test_num_samples[0]))
        print_rank_0('    validation: {}'.format(train_val_test_num_samples[1]))
        print_rank_0('    test:       {}'.format(train_val_test_num_samples[2]))

        # Build the datasets.
        train_ds, valid_ds, test_ds = build_train_valid_test_datasets_provider(
            train_val_test_num_samples)

        # Build dataloaders, resuming past samples that were already consumed.
        train_dataloader = build_pretraining_data_loader(
            train_ds, args.consumed_train_samples)
        valid_dataloader = build_pretraining_data_loader(
            valid_ds, args.consumed_valid_samples)
        test_dataloader = build_pretraining_data_loader(test_ds, 0)

        # Flags to know if we need to do training/validation/testing.
        do_train = train_dataloader is not None and args.train_iters > 0
        do_valid = valid_dataloader is not None and args.eval_iters > 0
        do_test = test_dataloader is not None and args.eval_iters > 0
        flags = torch.cuda.LongTensor(
            [int(do_train), int(do_valid), int(do_test)])
    else:
        # Placeholder; overwritten by the broadcast from rank 0 below.
        flags = torch.cuda.LongTensor([0, 0, 0])

    # Share the do_train/do_valid/do_test flags across the tensor-model-
    # parallel group.
    torch.distributed.broadcast(flags,
                                mpu.get_tensor_model_parallel_src_rank(),
                                group=mpu.get_tensor_model_parallel_group())
    args.do_train = flags[0].item()
    args.do_valid = flags[1].item()
    args.do_test = flags[2].item()

    # Build iterators.
    dl_type = args.dataloader_type
    assert dl_type in ['single', 'cyclic']

    def _iterator_for(dataloader):
        # No dataloader on this rank -> no iterator.
        if dataloader is None:
            return None
        if dl_type == 'single':
            return iter(dataloader)
        return iter(cyclic_iter(dataloader))

    train_data_iterator = _iterator_for(train_dataloader)
    valid_data_iterator = _iterator_for(valid_dataloader)
    test_data_iterator = _iterator_for(test_dataloader)

    return train_data_iterator, valid_data_iterator, test_data_iterator