training.py 40.8 KB
Newer Older
1
# coding=utf-8
Mohammad's avatar
Mohammad committed
2
# Copyright (c) 2020, NVIDIA CORPORATION.  All rights reserved.
3
4
5
6
7
8
9
10
11
12
13
14
15
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

Mohammad's avatar
Mohammad committed
16
"""Pretrain utilities."""
17
18
19

from datetime import datetime
import math
Mohammad's avatar
Mohammad committed
20
import sys
21
22
23
24
import time
# The earliest we can measure the start time.
_TRAIN_START_TIME = time.time()

25
26
27
import torch
from torch.nn.parallel.distributed import DistributedDataParallel as torchDDP

Neel Kant's avatar
Neel Kant committed
28
from megatron import get_args
Mohammad's avatar
Mohammad committed
29
30
from megatron import get_timers
from megatron import get_tensorboard_writer
31
from megatron import get_current_global_batch_size
mohammad's avatar
mohammad committed
32
from megatron import get_num_microbatches
mohammad's avatar
mohammad committed
33
from megatron import is_last_rank
mohammad's avatar
mohammad committed
34
from megatron import update_num_microbatches
35
from megatron import mpu
Neel Kant's avatar
Neel Kant committed
36
from megatron import print_rank_0
37
from megatron import print_rank_last
Mohammad's avatar
Mohammad committed
38
39
from megatron.checkpointing import load_checkpoint
from megatron.checkpointing import save_checkpoint
40
from megatron.model import FP16Module
mohammad's avatar
mohammad committed
41
from megatron.optimizer import get_megatron_optimizer
mohammad's avatar
mohammad committed
42

Mohammad's avatar
Mohammad committed
43
from megatron.initialize import initialize_megatron
44
from megatron.initialize import write_args_to_tensorboard
45
46
from megatron.learning_rates import AnnealingLR
from megatron.model import DistributedDataParallel as LocalDDP
Neel Kant's avatar
Neel Kant committed
47
from megatron.model.realm_model import ICTBertModel
48
from megatron.utils import check_adlr_autoresume_termination
49
from megatron.data.data_loaders import build_pretraining_data_loader
50
from megatron.utils import report_memory
51
52


53
54
55
56
57
58
59
def print_datetime(string):
    """Print ``string`` with the current wall-clock time via ``print_rank_0``.

    A ``torch.distributed.barrier()`` is entered first, so every rank must
    call this function and the call synchronizes all ranks.
    """
    torch.distributed.barrier()
    stamp = datetime.now().strftime('%Y-%m-%d %H:%M:%S')
    message = '[' + string + '] datetime: {} '.format(stamp)
    print_rank_0(message)


60
61
62
63
64
65
def pretrain(train_valid_test_dataset_provider,
             model_provider,
             forward_step_func,
             extra_args_provider=None,
             args_defaults=None,
             random_sample=False):
    """Main training program.

    This function runs the following steps in order:
        1) initialize Megatron.
        2) set up the model, optimizer and learning-rate schedule using
           ``model_provider``.
        3) call ``train_valid_test_dataset_provider`` to get the
           train/valid/test datasets.
        4) train the model using ``forward_step_func``.
    Afterwards it evaluates on the validation data when ``args.do_valid``
    is set, saves a checkpoint when ``args.save`` is set and at least one
    iteration was run, and evaluates on the test data when ``args.do_test``
    is set.

    Arguments:
        train_valid_test_dataset_provider: a function that takes the size of
            the train/valid/test datasets and returns `train, valid, test`
            datasets.
        model_provider: a function that returns a vanilla version of the
            model. By vanilla we mean a simple model on cpu with no fp16 or
            ddp.
        forward_step_func: a function called as
            ``forward_step_func(data_iterator, model, input_tensor)``, where
            ``input_tensor`` is None unless it was received from a previous
            pipeline stage. On the last pipeline stage it must return a
            ``(loss, info)`` pair, where ``info`` is a dictionary of
            key:values we would like to monitor during training, for example
            `lm-loss: value`. We also require that this function add
            `batch generator` to the timers class.
        extra_args_provider: a function that takes a parser and adds
            arguments to it. It is used for programs to add their own
            arguments.
        args_defaults: a dictionary from argument-name to argument-value
            used to set defaults for the parsed arguments. ``None`` (the
            default) is treated as an empty dictionary.
        random_sample: passed through unchanged to
            ``build_train_valid_test_data_iterators``.
    """
    # Avoid a shared mutable default: build a fresh dict on every call.
    if args_defaults is None:
        args_defaults = {}

    # Initialize and get arguments, timers, and Tensorboard writer.
    initialize_megatron(extra_args_provider=extra_args_provider,
                        args_defaults=args_defaults)

    # Replace the local start time with the earliest start time across all
    # ranks (MIN all-reduce). The elapsed time printed below is then the
    # largest over ranks, which is closer to what the job scheduler sees.
    global _TRAIN_START_TIME
    start_time_tensor = torch.cuda.FloatTensor([_TRAIN_START_TIME])
    torch.distributed.all_reduce(start_time_tensor,
                                 op=torch.distributed.ReduceOp.MIN)
    _TRAIN_START_TIME = start_time_tensor.item()
    print_rank_0('time to initialize megatron (seconds): {:.3f}'.format(
        time.time() - _TRAIN_START_TIME))
    print_datetime('after megatron is initialized')

    args = get_args()
    timers = get_timers()

    # Model, optimizer, and learning rate.
    timers('model and optimizer').start()
    model, optimizer, lr_scheduler = setup_model_and_optimizer(model_provider)
    timers('model and optimizer').stop()
    print_datetime('after model, optimizer, and learning rate '
                   'scheduler are built')

    # Data stuff.
    timers('train/valid/test data iterators').start()
    train_data_iterator, valid_data_iterator, test_data_iterator \
        = build_train_valid_test_data_iterators(
            train_valid_test_dataset_provider,
            random_sample)
    timers('train/valid/test data iterators').stop()
    print_datetime('after dataloaders are built')

    # Print setup timing.
    print_rank_0('done with setups ...')
    timers.log(['model and optimizer', 'train/valid/test data iterators'])
    print_rank_0('training ...')

    iteration = 0
    if args.do_train and args.train_iters > 0:
        iteration = train(forward_step_func,
                          model, optimizer, lr_scheduler,
                          train_data_iterator, valid_data_iterator)
    print_datetime('after training is done')

    if args.do_valid:
        prefix = 'the end of training for val data'
        evaluate_and_print_results(prefix, forward_step_func,
                                   valid_data_iterator, model,
                                   iteration, False)

    # Only save when training actually ran at least one iteration.
    if args.save and iteration != 0:
        save_checkpoint(iteration, model, optimizer, lr_scheduler)

    if args.do_test:
        # Run on test data.
        prefix = 'the end of training for test data'
        evaluate_and_print_results(prefix, forward_step_func,
                                   test_data_iterator, model,
                                   0, True)
152

153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
def update_train_iters(args):
    """Fill in ``args.train_iters`` for sample-based training.

    If ``args.train_iters`` is already truthy (iteration-based training),
    nothing happens. Otherwise ``args.train_samples`` is converted into an
    iteration count and stored in ``args.train_iters``. When
    ``args.rampup_batch_size`` is set, the count walks through the rampup
    phase batch by batch using the microbatch calculator.
    """
    # Iteration-based training: nothing to derive.
    if args.train_iters:
        return

    if args.rampup_batch_size is not None:
        # Sample-based training with a ramped-up batch size.
        num_iters = 0
        samples_seen = 0
        # Rampup phase: advance one global batch at a time.
        while samples_seen <= int(args.rampup_batch_size[2]):
            update_num_microbatches(samples_seen, consistency_check=False)
            samples_seen += get_current_global_batch_size()
            num_iters += 1
        # Reset the microbatch calculator to the zero-consumed-samples state.
        update_num_microbatches(0, consistency_check=False)
        # Constant phase; any partial last batch is thrown away.
        num_iters += (args.train_samples - samples_seen) // \
            args.global_batch_size
        args.train_iters = num_iters
    else:
        # Constant batch size with sample-based training.
        args.train_iters = args.train_samples // args.global_batch_size

    print_rank_0('setting training iterations to {}'.format(args.train_iters))

182

Mohammad's avatar
Mohammad committed
183
def get_model(model_provider_func):
    """Build the model and wrap it for mixed precision and data parallelism.

    The model is created on CPU by ``model_provider_func``. Every parameter
    is given default tensor-model-parallel attributes. The model is then
    moved to the current CUDA device, optionally wrapped in ``FP16Module``
    (``args.fp16``), and finally wrapped in the DDP implementation named by
    ``args.DDP_impl``.

    Raises:
        NotImplementedError: if ``args.DDP_impl`` is neither ``'torch'`` nor
            ``'local'``.
    """
    args = get_args()

    # Build model on cpu.
    model = model_provider_func()

    # Only parameters that are already tensor-model-parallel carry these
    # attributes; set the defaults on every other parameter so the optimizer
    # can rely on them being present.
    for param in model.parameters():
        mpu.set_defaults_if_not_set_tensor_model_parallel_attributes(param)

    # Report the parameter count once per (tensor, pipeline) rank pair,
    # i.e. only from data-parallel rank 0.
    if mpu.get_data_parallel_rank() == 0:
        tp_rank = mpu.get_tensor_model_parallel_rank()
        pp_rank = mpu.get_pipeline_model_parallel_rank()
        num_params = sum([p.nelement() for p in model.parameters()])
        print(' > number of parameters on (tensor, pipeline) '
              'model parallel rank ({}, {}): {}'.format(
                  tp_rank, pp_rank, num_params), flush=True)

    # GPU allocation.
    device = torch.cuda.current_device()
    model.cuda(device)

    # Fp16 conversion.
    if args.fp16:
        model = FP16Module(model)

    if args.DDP_impl == 'torch':
        return torchDDP(model, device_ids=[device], output_device=device,
                        process_group=mpu.get_data_parallel_group())
    if args.DDP_impl == 'local':
        return LocalDDP(model)

    raise NotImplementedError('Unknown DDP implementation specified: {}. '
                              'Exiting.'.format(args.DDP_impl))
223
224


Mohammad's avatar
Mohammad committed
225
def get_learning_rate_scheduler(optimizer):
    """Build the ``AnnealingLR`` scheduler for ``optimizer``.

    Decay and warmup lengths are expressed in samples. In iteration-based
    training, the iteration counts are multiplied by
    ``args.global_batch_size``. ``lr_warmup_fraction``, when set, takes
    precedence over the explicit warmup length.

    Side effects: fills ``args.lr_decay_iters`` / ``args.lr_decay_samples``
    from the training length when they are None. In sample-based training it
    also calls ``update_train_iters(args)``.

    Raises:
        Exception: if neither ``args.train_iters`` nor ``args.train_samples``
            is set.
    """
    args = get_args()

    if args.train_iters:
        # Iteration-based training.
        if args.lr_decay_iters is None:
            args.lr_decay_iters = args.train_iters
        decay_steps = args.lr_decay_iters * args.global_batch_size
        if args.lr_warmup_fraction is None:
            warmup_steps = args.lr_warmup_iters * args.global_batch_size
        else:
            warmup_steps = args.lr_warmup_fraction * decay_steps
    elif args.train_samples:
        # Sample-based training. train_iters is still needed later.
        # Technically the sample count should also be adjusted for the
        # incomplete last batch, but it is left as is for now.
        update_train_iters(args)
        if args.lr_decay_samples is None:
            args.lr_decay_samples = args.train_samples
        decay_steps = args.lr_decay_samples
        if args.lr_warmup_fraction is None:
            warmup_steps = args.lr_warmup_samples
        else:
            warmup_steps = args.lr_warmup_fraction * decay_steps
    else:
        raise Exception(
            'either train-iters or train-samples should be provided.')

    return AnnealingLR(
        optimizer,
        max_lr=args.lr,
        min_lr=args.min_lr,
        warmup_steps=warmup_steps,
        decay_steps=decay_steps,
        decay_style=args.lr_decay_style,
        use_checkpoint_lr_scheduler=args.use_checkpoint_lr_scheduler,
        override_lr_scheduler=args.override_lr_scheduler)


Mohammad's avatar
Mohammad committed
268
def setup_model_and_optimizer(model_provider_func):
    """Set up the model, optimizer and learning-rate scheduler.

    When ``args.load`` is set, a checkpoint is loaded and the resumed
    iteration is stored in ``args.iteration``; otherwise ``args.iteration``
    is 0. On a fresh start (iteration 0), a model exposing
    ``init_state_dict_from_bert`` is initialized through that method.

    Returns:
        ``(model, optimizer, lr_scheduler)``, where ``model`` is still
        wrapped as returned by ``get_model``.
    """
    args = get_args()

    model = get_model(model_provider_func)

    # The optimizer is built over the bare module, without DDP/FP16 wrappers.
    core_model = model
    while isinstance(core_model, (torchDDP, LocalDDP, FP16Module)):
        core_model = core_model.module
    optimizer = get_megatron_optimizer(core_model)

    lr_scheduler = get_learning_rate_scheduler(optimizer)

    if args.load is None:
        args.iteration = 0
    else:
        timers = get_timers()
        # The extra barrier makes sure all ranks report the max time.
        torch.distributed.barrier()
        timers('load checkpoint').start()
        args.iteration = load_checkpoint(model, optimizer, lr_scheduler)
        torch.distributed.barrier()
        timers('load checkpoint').stop()
        timers.log(['load checkpoint'])

    # Multiple micro-batches are only supported with local DDP.
    if get_num_microbatches() > 1:
        assert args.DDP_impl == 'local'

    # Peel off every wrapper that exposes `.module` (FP16 and/or torch DDP).
    bare_model = model
    while hasattr(bare_model, 'module'):
        bare_model = bare_model.module

    if args.iteration == 0 and hasattr(bare_model,
                                       'init_state_dict_from_bert'):
        print("Initializing ICT from pretrained BERT model", flush=True)
        bare_model.init_state_dict_from_bert()

    return model, optimizer, lr_scheduler


311
def communicate(tensor_send_next, tensor_send_prev, recv_forward, recv_backward):
    """Communicate tensors between neighbouring pipeline stages.

    Arguments:
        tensor_send_next: tensor to send to the next pipeline rank, or None.
        tensor_send_prev: tensor to send to the previous pipeline rank, or
            None.
        recv_forward: if True, receive a tensor from the previous rank.
        recv_backward: if True, receive a tensor from the next rank.

    Returns:
        ``(tensor_recv_prev, tensor_recv_next)``. Each element is None when
        the corresponding receive was not requested; otherwise it is a new
        tensor of shape ``(args.seq_length, args.micro_batch_size,
        args.hidden_size)`` on the current CUDA device. Its dtype is
        ``args.params_dtype``, or ``torch.float`` when
        ``args.fp32_residual_connection`` is set.

    When there is nothing to send or receive, the call is a no-op.
    """
    args = get_args()

    tensor_shape = (args.seq_length, args.micro_batch_size, args.hidden_size)
    dtype = args.params_dtype
    if args.fp32_residual_connection:
        dtype = torch.float

    # Allocate receive buffers only for the requested directions. The
    # forward-direction buffer becomes this stage's input tensor, whose
    # .grad is read after the backward pass, hence requires_grad=True.
    tensor_recv_prev = None
    tensor_recv_next = None
    if recv_forward:
        tensor_recv_prev = torch.empty(tensor_shape,
                                       requires_grad=True,
                                       device=torch.cuda.current_device(),
                                       dtype=dtype)
    if recv_backward:
        tensor_recv_next = torch.empty(tensor_shape,
                                       requires_grad=True,
                                       device=torch.cuda.current_device(),
                                       dtype=dtype)

    # Build the point-to-point ops. The order (send prev, recv prev,
    # send next, recv next) is kept fixed on every rank.
    ops = []
    if tensor_send_prev is not None:
        ops.append(torch.distributed.P2POp(
            torch.distributed.isend, tensor_send_prev,
            mpu.get_pipeline_model_parallel_prev_rank()))
    if tensor_recv_prev is not None:
        ops.append(torch.distributed.P2POp(
            torch.distributed.irecv, tensor_recv_prev,
            mpu.get_pipeline_model_parallel_prev_rank()))
    if tensor_send_next is not None:
        ops.append(torch.distributed.P2POp(
            torch.distributed.isend, tensor_send_next,
            mpu.get_pipeline_model_parallel_next_rank()))
    if tensor_recv_next is not None:
        ops.append(torch.distributed.P2POp(
            torch.distributed.irecv, tensor_recv_next,
            mpu.get_pipeline_model_parallel_next_rank()))

    # batch_isend_irecv() does not accept an empty op list (it reads the
    # process group from the first op), so skip it when there is no work.
    if ops:
        reqs = torch.distributed.batch_isend_irecv(ops)
        for req in reqs:
            req.wait()

    return tensor_recv_prev, tensor_recv_next


def backward_step(optimizer, model, input_tensor, output_tensor, output_tensor_grad):
    """Run the backward pass for one microbatch.

    Arguments:
        optimizer: object providing ``scale_loss``. It is applied to
            ``output_tensor`` when ``output_tensor_grad`` is None, which is
            the case on the last pipeline stage, where ``output_tensor`` is
            the loss.
        model: unused; kept so the call signature stays unchanged.
        input_tensor: this stage's input tensor, or None on the first stage.
        output_tensor: tensor to backpropagate from.
        output_tensor_grad: gradient with respect to ``output_tensor``
            received from the next stage, or None.

    Returns:
        The gradient accumulated on ``input_tensor``, or None when
        ``input_tensor`` is None.
    """
    # Keep .grad populated on the input even when it is not a leaf, so the
    # gradient can be read below and sent to the previous stage.
    if input_tensor is not None:
        input_tensor.retain_grad()

    # Backward pass. Without an upstream gradient, output_tensor is the loss
    # and is scaled by the optimizer first.
    if output_tensor_grad is None:
        output_tensor = optimizer.scale_loss(output_tensor)
    torch.autograd.backward(output_tensor, grad_tensors=output_tensor_grad)

    # Collect the grad of the input_tensor.
    if input_tensor is None:
        return None
    return input_tensor.grad


381
382
383
def forward_step_with_communication(forward_step_func, data_iterator, model,
                                    input_tensors, output_tensors,
                                    losses_reduced, timers):
    """Run one forward-only (warmup) step with pipeline communication.

    Non-first stages receive their input from the previous stage; the first
    stage uses None. After the forward pass:
      * on the last stage, the loss is divided by the number of microbatches
        and the reduced-loss value is appended to ``losses_reduced``;
      * on other stages, the output is sent to the next stage.
    The input and output tensors are appended to ``input_tensors`` and
    ``output_tensors`` for the matching backward pass.
    """
    if mpu.is_pipeline_first_stage():
        input_tensor = None
    else:
        timers('forward-recv').start()
        input_tensor, _ = communicate(
            tensor_send_next=None,
            tensor_send_prev=None,
            recv_forward=True,
            recv_backward=False)
        timers('forward-recv').stop()

    # Forward model for one step.
    timers('forward-compute').start()
    output_tensor = forward_step_func(data_iterator, model, input_tensor)
    timers('forward-compute').stop()

    if mpu.is_pipeline_last_stage():
        # On the last stage forward_step_func returns (loss, reduced loss).
        loss, loss_reduced = output_tensor
        output_tensor = loss / get_num_microbatches()
        losses_reduced.append(loss_reduced)
    else:
        timers('forward-send').start()
        communicate(
            tensor_send_next=output_tensor,
            tensor_send_prev=None,
            recv_forward=False,
            recv_backward=False)
        timers('forward-send').stop()

    input_tensors.append(input_tensor)
    output_tensors.append(output_tensor)


def backward_step_with_communication(optimizer, model, input_tensors, output_tensors, timers):
    """Run one backward-only (cooldown) step with pipeline communication.

    The oldest queued input/output pair is popped from the front of
    ``input_tensors``/``output_tensors``. Unless this is the last stage, the
    output gradient is received from the next stage. After ``backward_step``,
    the input gradient is sent to the previous stage unless this is the first
    stage.
    """
    stage_input = input_tensors.pop(0)
    stage_output = output_tensors.pop(0)

    output_tensor_grad = None
    if not mpu.is_pipeline_last_stage():
        timers('backward-recv').start()
        received = communicate(
            tensor_send_next=None,
            tensor_send_prev=None,
            recv_forward=False,
            recv_backward=True)
        output_tensor_grad = received[1]
        timers('backward-recv').stop()

    # Backward pass for one step.
    timers('backward-compute').start()
    input_grad_tensor = backward_step(
        optimizer, model, stage_input, stage_output, output_tensor_grad)
    timers('backward-compute').stop()

    if mpu.is_pipeline_first_stage():
        return None

    timers('backward-send').start()
    communicate(
        tensor_send_next=None,
        tensor_send_prev=input_grad_tensor,
        recv_forward=False,
        recv_backward=False)
    timers('backward-send').stop()
    return None
448
449


450
451
452
453
454
def forward_and_backward_steps_with_communication(forward_step_func, data_iterator, model,
                                                  optimizer,
                                                  input_tensor, last_microbatch,
                                                  input_tensors, output_tensors,
                                                  losses_reduced, timers):
    """Run one steady-state 1F1B step: forward one microbatch, backward another.

    1. The forward pass consumes ``input_tensor``. On the last stage, the
       loss is divided by the number of microbatches and the reduced loss is
       appended to ``losses_reduced``. On other stages, the output is sent to
       the next stage and the output gradient is received in one
       ``communicate`` call.
    2. The new input/output pair is queued, and the oldest queued pair is
       popped and backpropagated.
    3. Non-first stages send the input gradient to the previous stage. In
       the same call they receive the next forward input, unless
       ``last_microbatch`` is True.

    Returns:
        The input tensor for the next forward pass. It is None on the first
        stage, or when ``last_microbatch`` is True.
    """
    # Forward model for one step.
    timers('forward-compute').start()
    output_tensor = forward_step_func(data_iterator, model, input_tensor)
    timers('forward-compute').stop()

    if mpu.is_pipeline_last_stage():
        loss, loss_reduced = output_tensor
        output_tensor = loss / get_num_microbatches()
        output_tensor_grad = None
        losses_reduced.append(loss_reduced)
    else:
        timers('forward-send-backward-recv').start()
        _, output_tensor_grad = communicate(
            tensor_send_next=output_tensor,
            tensor_send_prev=None,
            recv_forward=False,
            recv_backward=True)
        timers('forward-send-backward-recv').stop()

    input_tensors.append(input_tensor)
    output_tensors.append(output_tensor)

    # Backpropagate the oldest in-flight microbatch (FIFO order).
    input_tensor = input_tensors.pop(0)
    output_tensor = output_tensors.pop(0)

    # Backward pass for one step.
    timers('backward-compute').start()
    input_grad_tensor = \
        backward_step(optimizer, model, input_tensor, output_tensor, output_tensor_grad)
    timers('backward-compute').stop()

    if mpu.is_pipeline_first_stage():
        return None

    timers('backward-send-forward-recv').start()
    input_tensor, _ = communicate(
        tensor_send_next=None,
        tensor_send_prev=input_grad_tensor,
        recv_forward=(not last_microbatch),
        recv_backward=False)
    timers('backward-send-forward-recv').stop()

    return input_tensor


502
503
504
def forward_backward_no_pipelining(forward_step_func, data_iterator, model,
                                   optimizer, timers):
    """Run forward and backward passes without inter-stage communication.

    Each microbatch is run forward and then immediately backward. The loss
    is divided by the number of microbatches before the backward pass, so
    the accumulated gradients correspond to the average loss over
    microbatches.

    Returns:
        List of the reduced-loss values returned by ``forward_step_func``,
        one per microbatch.
    """
    # Loop-invariant: the microbatch count does not change within a step.
    num_microbatches = get_num_microbatches()

    losses_reduced = []
    for _ in range(num_microbatches):
        timers('forward-compute').start()
        loss, loss_reduced = forward_step_func(data_iterator, model, input_tensor=None)
        output_tensor = loss / num_microbatches
        losses_reduced.append(loss_reduced)
        timers('forward-compute').stop()

        timers('backward-compute').start()
        backward_step(optimizer, model, input_tensor=None,
                      output_tensor=output_tensor, output_tensor_grad=None)
        timers('backward-compute').stop()

    return losses_reduced
522

523
524
525
526
527
528
529

def forward_backward_pipelining(forward_step_func, data_iterator, model,
                                optimizer, timers):
    """Run the 1F1B schedule, with warmup and cooldown microbatches as needed.

    This stage runs ``num_warmup_microbatches`` forward-only passes, where
    the count is the pipeline world size minus this stage's rank minus one,
    capped at the number of microbatches. It then alternates one forward and
    one backward pass for the remaining microbatches. Finally it runs the
    same number of backward-only cooldown passes as warmup passes.

    Returns:
        List of reduced-loss values appended during the forward passes. Only
        the last pipeline stage appends to it.
    """
    # Compute number of warmup microbatches.
    num_microbatches = get_num_microbatches()
    num_warmup_microbatches = min(
        mpu.get_pipeline_model_parallel_world_size()
        - mpu.get_pipeline_model_parallel_rank() - 1,
        num_microbatches)
    num_microbatches_remaining = num_microbatches - num_warmup_microbatches

    input_tensors = []
    output_tensors = []
    losses_reduced = []

    # Run warmup forward passes.
    for _ in range(num_warmup_microbatches):
        forward_step_with_communication(
            forward_step_func, data_iterator, model,
            input_tensors, output_tensors,
            losses_reduced, timers)

    # Before running 1F1B, receive the first forward input tensor. If every
    # microbatch is handled in the warmup/cooldown phases, there is nothing
    # to receive here. The first stage has no previous stage to receive from.
    input_tensor = None
    if num_microbatches_remaining > 0 and not mpu.is_pipeline_first_stage():
        timers('forward-recv').start()
        input_tensor, _ = communicate(tensor_send_next=None,
                                      tensor_send_prev=None,
                                      recv_forward=True,
                                      recv_backward=False)
        timers('forward-recv').stop()

    # Run 1F1B.
    for i in range(num_microbatches_remaining):
        last_iteration = (i == (num_microbatches_remaining - 1))
        input_tensor = \
            forward_and_backward_steps_with_communication(forward_step_func, data_iterator, model,
                                                          optimizer,
                                                          input_tensor, last_iteration,
                                                          input_tensors, output_tensors,
                                                          losses_reduced, timers)

    # Run cooldown backward passes.
    for _ in range(num_warmup_microbatches):
        backward_step_with_communication(
            optimizer, model, input_tensors, output_tensors, timers)

    return losses_reduced


def train_step(forward_step_func, data_iterator,
               model, optimizer, lr_scheduler):
    """Run one training iteration.

    Steps:
      * zero the gradients;
      * run forward/backward over all microbatches, with the 1F1B schedule
        when the pipeline world size is > 1;
      * all-reduce gradients for the local DDP implementation;
      * all-reduce the shared word-embedding gradient between the first and
        last pipeline stages;
      * step the optimizer, and step the LR scheduler only if that
        succeeded.

    Returns:
        ``(loss_dict, skipped_iter)``. On the last pipeline stage,
        ``loss_dict`` maps each loss key to its mean over microbatches; on
        other stages it is ``{}``. ``skipped_iter`` is 0 if
        ``optimizer.step()`` returned a truthy value, else 1.
    """
    args = get_args()
    timers = get_timers()

    # Set grad to zero.
    optimizer.zero_grad()

    pipelined = mpu.get_pipeline_model_parallel_world_size() > 1
    if pipelined:
        schedule = forward_backward_pipelining
    else:
        schedule = forward_backward_no_pipelining
    losses_reduced = schedule(
        forward_step_func, data_iterator, model, optimizer, timers)

    # All-reduce if needed.
    if args.DDP_impl == 'local':
        timers('backward-params-all-reduce').start()
        model.allreduce_params(reduce_after=False,
                               fp32_allreduce=args.fp32_allreduce)
        timers('backward-params-all-reduce').stop()

    # All-reduce word_embeddings' grad across the first and last stages so
    # the word_embeddings parameters stay in sync. This only applies to
    # models that support pipelined model parallelism (BERT and GPT-2).
    timers('backward-embedding-all-reduce').start()
    at_pipeline_end = mpu.is_pipeline_first_stage() or \
        mpu.is_pipeline_last_stage()
    if at_pipeline_end and pipelined:
        inner = model
        while isinstance(inner, (torchDDP, LocalDDP, FP16Module)):
            inner = inner.module
        if inner.share_word_embeddings:
            torch.distributed.all_reduce(inner.word_embeddings_weight().grad,
                                         group=mpu.get_embedding_group())
    timers('backward-embedding-all-reduce').stop()

    # Update parameters.
    timers('optimizer').start()
    update_successful = optimizer.step()
    timers('optimizer').stop()

    # Update learning rate only when the optimizer step went through.
    if update_successful:
        samples_this_step = (get_num_microbatches() *
                             args.micro_batch_size *
                             args.data_parallel_size)
        lr_scheduler.step(increment=samples_this_step)
        skipped_iter = 0
    else:
        skipped_iter = 1

    if not mpu.is_pipeline_last_stage():
        return {}, skipped_iter

    # Average each reported loss across microbatches.
    num_reports = len(losses_reduced)
    loss_reduced = {
        key: sum(report[key] for report in losses_reduced) / num_reports
        for key in losses_reduced[0]
    }
    return loss_reduced, skipped_iter
646
647


Mohammad Shoeybi's avatar
Mohammad Shoeybi committed
648
def training_log(loss_dict, total_loss_dict, learning_rate, iteration,
                 loss_scale, report_memory_flag, skipped_iter):
    """Log training information such as losses, timing, ....

    Accumulates this iteration's losses and iteration counters into
    ``total_loss_dict`` (mutated in place). Tensorboard scalars are written
    every call. Every ``args.log_interval`` iterations an averaged summary
    is printed and the accumulators are reset.

    Args:
        loss_dict: losses for this iteration keyed by name. This is the dict
            returned by ``train_step``, which is empty on ranks that are not
            the last pipeline stage.
        total_loss_dict: running accumulator shared across calls. Besides
            the loss keys it holds the 'advanced iterations',
            'skipped iterations' and 'nan iterations' counters.
        learning_rate: current learning rate.
        iteration: iteration number after the step just taken.
        loss_scale: current loss scale.
        report_memory_flag: if True, report memory at the next log interval,
            but only once the learning rate is positive.
        skipped_iter: 1 if the optimizer step was not applied, else 0.

    Returns:
        The (possibly cleared) ``report_memory_flag``.
    """
    args = get_args()
    timers = get_timers()
    writer = get_tensorboard_writer()

    # Advanced, skipped, and Nan iterations.
    advanced_iters_key = 'advanced iterations'
    skipped_iters_key = 'skipped iterations'
    nan_iters_key = 'nan iterations'
    # Advanced iterations: count only applied steps. The key is always
    # created because the interval average below divides by it.
    total_loss_dict[advanced_iters_key] = total_loss_dict.get(
        advanced_iters_key, 0) + (0 if skipped_iter else 1)
    # Skipped iterations.
    total_loss_dict[skipped_iters_key] = total_loss_dict.get(
        skipped_iters_key, 0) + skipped_iter
    # Update losses and set nan iterations.
    got_nan = False
    for key, loss in loss_dict.items():
        if skipped_iter:
            # The optimizer step was not applied. Do not accumulate the loss;
            # only record whether it was inf/nan.
            got_nan = got_nan or not math.isfinite(loss.float().sum().item())
        elif key in total_loss_dict:
            total_loss_dict[key] = total_loss_dict[key] + loss
        else:
            # Allocate the zero tensor only when the key is new.
            total_loss_dict[key] = torch.cuda.FloatTensor([0.0]) + loss
    total_loss_dict[nan_iters_key] = total_loss_dict.get(
        nan_iters_key, 0) + int(got_nan)

    # Logging: only report timers that have been registered.
    timers_to_log = [name for name in (
        'forward-compute',
        'forward-recv',
        'forward-send',
        'forward-send-backward-recv',
        'backward-compute',
        'backward-recv',
        'backward-send',
        'backward-send-forward-recv',
        'backward-params-all-reduce',
        'backward-embedding-all-reduce',
        'optimizer-copy-to-main-grad',
        'optimizer-unscale-and-check-inf',
        'optimizer-clip-main-grad',
        'optimizer-copy-main-to-model-params',
        'optimizer',
        'batch-generator',
    ) if name in timers.timers]

    # Calculate batch size.
    batch_size = args.micro_batch_size * args.data_parallel_size * \
        get_num_microbatches()

    # Always >= 1: one of the two counters was incremented above.
    total_iterations = total_loss_dict[advanced_iters_key] + \
                       total_loss_dict[skipped_iters_key]

    # Tensorboard values.
    if writer and is_last_rank():
        writer.add_scalar('learning-rate', learning_rate, iteration)
        writer.add_scalar('learning-rate vs samples', learning_rate,
                          args.consumed_train_samples)
        writer.add_scalar('batch-size', batch_size, iteration)
        writer.add_scalar('batch-size vs samples', batch_size,
                          args.consumed_train_samples)
        for key in loss_dict:
            writer.add_scalar(key, loss_dict[key], iteration)
            writer.add_scalar(key + ' vs samples', loss_dict[key],
                              args.consumed_train_samples)
        writer.add_scalar('loss-scale', loss_scale, iteration)
        writer.add_scalar('loss-scale vs samples', loss_scale,
                          args.consumed_train_samples)
        timers.write(timers_to_log, writer, iteration,
                     normalizer=total_iterations)

    if iteration % args.log_interval == 0:
        elapsed_time = timers('interval time').elapsed()
        elapsed_time_per_iteration = elapsed_time / total_iterations
        # Same rank gate as every other tensorboard write in this function,
        # so the scalar is emitted from the rank that holds the writer.
        if writer and is_last_rank():
            writer.add_scalar('iteration-time',
                              elapsed_time_per_iteration, iteration)
        log_string = ' iteration {:8d}/{:8d} |'.format(
            iteration, args.train_iters)
        log_string += ' consumed samples: {:12d} |'.format(
            args.consumed_train_samples)
        log_string += ' elapsed time per iteration (ms): {:.1f} |'.format(
            elapsed_time_per_iteration * 1000.0)
        log_string += ' learning rate: {:.3E} |'.format(learning_rate)
        log_string += ' global batch size: {:5d} |'.format(batch_size)
        for key in total_loss_dict:
            if key not in [advanced_iters_key, skipped_iters_key,
                           nan_iters_key]:
                avg = total_loss_dict[key].item() / \
                      float(max(1, total_loss_dict[advanced_iters_key]))
                if avg > 0.0:
                    log_string += ' {}: {:.6E} |'.format(key, avg)
                total_loss_dict[key] = torch.cuda.FloatTensor([0.0])
        log_string += ' loss scale: {:.1f} |'.format(loss_scale)
        log_string += ' number of skipped iterations: {:3d} |'.format(
            total_loss_dict[skipped_iters_key])
        log_string += ' number of nan iterations: {:3d} |'.format(
            total_loss_dict[nan_iters_key])
        total_loss_dict[advanced_iters_key] = 0
        total_loss_dict[skipped_iters_key] = 0
        total_loss_dict[nan_iters_key] = 0
        print_rank_last(log_string)
        if report_memory_flag and learning_rate > 0.:
            # Report memory after optimizer state has been initialized.
            report_memory('(after {} iterations)'.format(iteration))
            report_memory_flag = False
        timers.log(timers_to_log, normalizer=args.log_interval)

    return report_memory_flag


772
773
774
775
776
777
778
779
780
781
782
783
def save_checkpoint_and_time(iteration, model, optimizer, lr_scheduler):
    """Save a checkpoint and log the time the save took.

    A barrier runs before the timer starts and again before it stops. Every
    rank therefore waits for the slowest one, and the logged
    'save checkpoint' time is the maximum across ranks.
    """
    timer_name = 'save checkpoint'
    timers = get_timers()

    torch.distributed.barrier()
    timers(timer_name).start()

    save_checkpoint(iteration, model, optimizer, lr_scheduler)

    torch.distributed.barrier()
    timers(timer_name).stop()
    timers.log([timer_name])


784
def train(forward_step_func, model, optimizer, lr_scheduler,
          train_data_iterator, valid_data_iterator):
    """Train the model function.

    Runs ``train_step`` until ``args.train_iters`` is reached. After each
    step it:

    - advances ``args.consumed_train_samples``;
    - logs via ``training_log``;
    - optionally checks ADLR autoresume;
    - runs validation and saves a checkpoint at their configured intervals;
    - exits early via ``sys.exit()`` when the wall-clock budget
      (``args.exit_duration_in_mins``) or ``args.exit_interval`` is hit.

    Before an early exit, a checkpoint is saved unless one was just saved.

    Returns:
        The iteration number reached when the loop ends.
    """
    args = get_args()
    timers = get_timers()

    # Write args to tensorboard
    write_args_to_tensorboard()

    # Training mode enables dropout.
    model.train()

    # Loss/iteration accumulators carried across calls to training_log.
    total_loss_dict = {}
    iteration = args.iteration

    timers('interval time').start()
    print_datetime('before the start of training step')
    report_memory_flag = True

    while iteration < args.train_iters:
        update_num_microbatches(args.consumed_train_samples)
        loss_dict, skipped_iter = train_step(forward_step_func,
                                             train_data_iterator,
                                             model,
                                             optimizer,
                                             lr_scheduler)
        iteration += 1
        samples_this_step = (mpu.get_data_parallel_world_size() *
                             args.micro_batch_size *
                             get_num_microbatches())
        args.consumed_train_samples += samples_this_step

        # Logging.
        loss_scale = optimizer.get_loss_scale().item()
        current_lr = optimizer.param_groups[0]['lr']
        report_memory_flag = training_log(loss_dict, total_loss_dict,
                                          current_lr, iteration, loss_scale,
                                          report_memory_flag, skipped_iter)

        # Autoresume
        autoresume_due = (args.adlr_autoresume and
                          iteration % args.adlr_autoresume_interval == 0)
        if autoresume_due:
            check_adlr_autoresume_termination(iteration, model, optimizer,
                                              lr_scheduler)

        # Evaluation
        validation_due = (args.eval_interval and
                          iteration % args.eval_interval == 0 and
                          args.do_valid)
        if validation_due:
            evaluate_and_print_results('iteration {}'.format(iteration),
                                       forward_step_func,
                                       valid_data_iterator, model,
                                       iteration, False)

        # Checkpointing
        saved_checkpoint = False
        save_due = (args.save and args.save_interval and
                    iteration % args.save_interval == 0)
        if save_due:
            save_checkpoint_and_time(iteration, model, optimizer,
                                     lr_scheduler)
            saved_checkpoint = True

        # Exiting based on duration: every rank must agree, so take the MAX
        # of the per-rank "time is up" flag.
        if args.exit_duration_in_mins:
            train_time = (time.time() - _TRAIN_START_TIME) / 60.0
            done_cuda = torch.cuda.IntTensor(
                [train_time > args.exit_duration_in_mins])
            torch.distributed.all_reduce(
                done_cuda, op=torch.distributed.ReduceOp.MAX)
            if done_cuda.item():
                if not saved_checkpoint:
                    save_checkpoint_and_time(iteration, model, optimizer,
                                             lr_scheduler)
                print_datetime(
                    'exiting program after {} minutes'.format(train_time))
                sys.exit()

        # Exiting based on iterations
        if args.exit_interval and iteration % args.exit_interval == 0:
            if not saved_checkpoint:
                save_checkpoint_and_time(iteration, model, optimizer,
                                         lr_scheduler)
            torch.distributed.barrier()
            print_datetime(
                'exiting program at iteration {}'.format(iteration))
            sys.exit()

    return iteration
872
873


Mohammad's avatar
Mohammad committed
874
def evaluate(forward_step_func, data_iterator, model, verbose=False):
    """Evaluation.

    Runs ``args.eval_iters`` iterations of ``get_num_microbatches()``
    forward passes under ``torch.no_grad()``. Activations are passed between
    pipeline stages with ``communicate``. ``args.consumed_valid_samples`` is
    advanced once per iteration, and the model is put back in training mode
    afterwards.

    Returns:
        Dict mapping each loss name to its sum divided by
        ``args.eval_iters * get_num_microbatches()``. Only the last pipeline
        stage fills it; other stages return an empty dict.
    """
    args = get_args()

    # Evaluation mode disables dropout.
    model.eval()

    total_loss_dict = {}

    with torch.no_grad():
        for iteration in range(1, args.eval_iters + 1):
            if verbose and iteration % args.log_interval == 0:
                print_rank_0('Evaluating iter {}/{}'.format(iteration,
                                                            args.eval_iters))

            for _ in range(get_num_microbatches()):
                # Non-first stages receive their input from the previous stage.
                if mpu.is_pipeline_first_stage():
                    input_tensor = None
                else:
                    input_tensor, _ = communicate(
                        tensor_send_next=None,
                        tensor_send_prev=None,
                        recv_forward=True,
                        recv_backward=False)

                # Forward evaluation.
                output_tensor = forward_step_func(data_iterator, model,
                                                  input_tensor)

                if not mpu.is_pipeline_last_stage():
                    # Intermediate stages only forward activations downstream.
                    communicate(
                        tensor_send_next=output_tensor,
                        tensor_send_prev=None,
                        recv_forward=False,
                        recv_backward=False)
                    continue

                # Last stage: sum losses across microbatches and iterations.
                _, loss_dict = output_tensor
                for key, loss in loss_dict.items():
                    if key in total_loss_dict:
                        total_loss_dict[key] = total_loss_dict[key] + loss
                    else:
                        total_loss_dict[key] = \
                            torch.cuda.FloatTensor([0.0]) + loss

            args.consumed_valid_samples += mpu.get_data_parallel_world_size() \
                                           * args.micro_batch_size \
                                           * get_num_microbatches()

    # Move model back to the train mode.
    model.train()

    num_batches = args.eval_iters * get_num_microbatches()
    for key in total_loss_dict:
        total_loss_dict[key] /= num_batches

    return total_loss_dict

def evaluate_and_print_results(prefix, forward_step_func,
                               data_iterator, model,
                               iteration, verbose=False):
    """Helper function to evaluate and dump results on screen.

    Calls ``evaluate``. For each loss it reports the mean value and the
    perplexity ``exp(min(20, value))``. The exponent is capped at 20 to
    avoid overflow. Results are printed on the last rank inside a dashed
    frame and written to tensorboard (per iteration and per consumed train
    sample) when a writer is available on the last rank.
    """
    args = get_args()
    writer = get_tensorboard_writer()

    total_loss_dict = evaluate(forward_step_func, data_iterator, model, verbose)

    string = ' validation loss at {} | '.format(prefix)
    for key, loss in total_loss_dict.items():
        value = loss.item()
        ppl = math.exp(min(20, value))
        string += '{} value: {:.6E} | '.format(key, value)
        string += '{} PPL: {:.6E} | '.format(key, ppl)

        if not (writer and is_last_rank()):
            continue
        samples = args.consumed_train_samples
        writer.add_scalar('{} value-validation'.format(key), value, iteration)
        writer.add_scalar('{} ppl-validation'.format(key), ppl, iteration)
        writer.add_scalar('{} value-validation vs samples'.format(key),
                          value, samples)
        writer.add_scalar('{} ppl-validation vs samples'.format(key), ppl,
                          samples)

    separator = '-' * (len(string) + 1)
    print_rank_last(separator)
    print_rank_last(string)
    print_rank_last(separator)
956
957


958
959
960
961
962
def cyclic_iterable(iterable):
    """Yield items from ``iterable`` forever, re-iterating it on every pass.

    This is not ``itertools.cycle``, which caches the first pass and replays
    that copy. Here ``iter(iterable)`` is called again for each pass, so an
    iterable that produces different items on every pass (for example a
    data loader that reshuffles per epoch) is respected.

    If a pass produces no items, the generator stops. Without this check an
    empty iterable would make ``next()`` spin forever in a busy loop.
    """
    while True:
        produced_any = False
        for item in iterable:
            produced_any = True
            yield item
        if not produced_any:
            return

963
def build_train_valid_test_data_iterators(
        build_train_valid_test_datasets_provider, random_sample=False):
    """Build the train, validation, and test data iterators.

    Datasets and data loaders are built only on tensor-model-parallel rank 0.
    That rank computes the do_train/do_valid/do_test flags and broadcasts
    them to its tensor model parallel group. Every rank then stores the
    flags on ``args``.

    When resuming from a checkpoint that has no consumed-sample counters
    (``args.iteration > 0`` with a zero counter), the counters are
    reconstructed assuming a fixed global batch size.

    Args:
        build_train_valid_test_datasets_provider: callable taking a list of
            [train, validation, test] target sample counts and returning a
            (train, valid, test) dataset triple.
        random_sample: forwarded to ``build_pretraining_data_loader``.

    Returns:
        (train, valid, test) iterators that loop over their data loaders
        indefinitely. An entry is None when its loader was not built, which
        is always the case on ranks other than tensor-model-parallel rank 0.
    """
    args = get_args()

    train_dataloader = valid_dataloader = test_dataloader = None

    print_rank_0('> building train, validation, and test datasets ...')

    # Backward compatibility, assume fixed batch size.
    resuming = args.iteration > 0
    if resuming and args.consumed_train_samples == 0:
        assert args.train_samples is None, \
            'only backward compatiblity support for iteration-based training'
        args.consumed_train_samples = args.iteration * args.global_batch_size
    if resuming and args.consumed_valid_samples == 0:
        assert args.train_samples is None, \
            'only backward compatiblity support for iteration-based training'
        num_evals_so_far = args.iteration // args.eval_interval
        args.consumed_valid_samples = num_evals_so_far * \
            args.eval_iters * args.global_batch_size

    # Data loader only on rank 0 of each tensor model parallel group.
    if mpu.get_tensor_model_parallel_rank() == 0:

        # Number of train/valid/test samples.
        train_samples = (args.train_samples or
                         args.train_iters * args.global_batch_size)
        # One validation run per eval_interval, plus one extra.
        eval_iters = (args.train_iters // args.eval_interval + 1) * \
                     args.eval_iters
        test_iters = args.eval_iters
        train_val_test_num_samples = [train_samples,
                                      eval_iters * args.global_batch_size,
                                      test_iters * args.global_batch_size]
        print_rank_0(' > datasets target sizes (minimum size):')
        print_rank_0('    train:      {}'.format(train_val_test_num_samples[0]))
        print_rank_0('    validation: {}'.format(train_val_test_num_samples[1]))
        print_rank_0('    test:       {}'.format(train_val_test_num_samples[2]))

        # Build the datasets.
        train_ds, valid_ds, test_ds = build_train_valid_test_datasets_provider(
            train_val_test_num_samples)

        # Build dataloaders. Train and validation resume from their consumed
        # sample counts; test always starts at sample 0.
        train_dataloader = build_pretraining_data_loader(
            train_ds, args.consumed_train_samples, random_sample)
        valid_dataloader = build_pretraining_data_loader(
            valid_ds, args.consumed_valid_samples, random_sample)
        test_dataloader = build_pretraining_data_loader(test_ds, 0, random_sample)

        # Flags to know if we need to do training/validation/testing.
        do_train = train_dataloader is not None and args.train_iters > 0
        do_valid = valid_dataloader is not None and args.eval_iters > 0
        do_test = test_dataloader is not None and args.eval_iters > 0
        flags = torch.cuda.LongTensor(
            [int(do_train), int(do_valid), int(do_test)])
    else:
        flags = torch.cuda.LongTensor([0, 0, 0])

    # Share rank 0's do_train/do_valid/do_test flags with the rest of its
    # tensor model parallel group.
    torch.distributed.broadcast(flags,
                                mpu.get_tensor_model_parallel_src_rank(),
                                group=mpu.get_tensor_model_parallel_group())
    args.do_train, args.do_valid, args.do_test = flags.tolist()

    def _cyclic_iter_or_none(dataloader):
        # Wrap a built loader in an endless iterator; pass None through.
        if dataloader is None:
            return None
        return iter(cyclic_iterable(dataloader))

    return (_cyclic_iter_or_none(train_dataloader),
            _cyclic_iter_or_none(valid_dataloader),
            _cyclic_iter_or_none(test_dataloader))