training.py 39.6 KB
Newer Older
1
# coding=utf-8
Mohammad's avatar
Mohammad committed
2
# Copyright (c) 2020, NVIDIA CORPORATION.  All rights reserved.
3
4
5
6
7
8
9
10
11
12
13
14
15
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

Mohammad's avatar
Mohammad committed
16
"""Pretrain utilities."""
17
18
19

from datetime import datetime
import math
Mohammad's avatar
Mohammad committed
20
import sys
21
22
23
24
import time
# The earliest we can measure the start time.
_TRAIN_START_TIME = time.time()

25
26
27
import torch
from torch.nn.parallel.distributed import DistributedDataParallel as torchDDP

Neel Kant's avatar
Neel Kant committed
28
from megatron import get_args
Mohammad's avatar
Mohammad committed
29
30
from megatron import get_timers
from megatron import get_tensorboard_writer
31
from megatron import get_current_global_batch_size
mohammad's avatar
mohammad committed
32
from megatron import get_num_microbatches
mohammad's avatar
mohammad committed
33
from megatron import is_last_rank
mohammad's avatar
mohammad committed
34
from megatron import update_num_microbatches
35
from megatron import mpu
Neel Kant's avatar
Neel Kant committed
36
from megatron import print_rank_0
37
from megatron import print_rank_last
Mohammad's avatar
Mohammad committed
38
39
from megatron.checkpointing import load_checkpoint
from megatron.checkpointing import save_checkpoint
40
from megatron.model import FP16Module
mohammad's avatar
mohammad committed
41
from megatron.optimizer import get_megatron_optimizer
mohammad's avatar
mohammad committed
42

Mohammad's avatar
Mohammad committed
43
from megatron.initialize import initialize_megatron
44
from megatron.initialize import write_args_to_tensorboard
45
46
from megatron.learning_rates import AnnealingLR
from megatron.model import DistributedDataParallel as LocalDDP
Neel Kant's avatar
Neel Kant committed
47
from megatron.model.realm_model import ICTBertModel
48
from megatron.utils import check_adlr_autoresume_termination
49
from megatron.data.data_loaders import build_pretraining_data_loader
50
from megatron.utils import report_memory
51
52


53
54
55
56
57
58
59
def print_datetime(string):
    """Print `string` tagged with the current wall-clock time on rank 0.

    Every rank enters a distributed barrier first, so calling this function
    synchronizes all processes.
    """
    torch.distributed.barrier()
    stamp = datetime.now().strftime('%Y-%m-%d %H:%M:%S')
    message = '[' + string + '] datetime: ' + stamp + ' '
    print_rank_0(message)


60
def pretrain(train_valid_test_dataset_provider, model_provider,
             forward_step_func, extra_args_provider=None, args_defaults=None):
    """Main training program.

    This function runs the following steps in order:
        1) initialize Megatron.
        2) set up the model, optimizer and lr schedule using the
           model_provider.
        3) call train_valid_test_dataset_provider to get the
           train/valid/test datasets.
        4) train the model using the forward_step_func.

    Arguments:
        train_valid_test_dataset_provider: a function that takes the size of
            train/valid/test dataset and returns `train, valid, test` datasets.
        model_provider: a function that returns a vanilla version of the
            model. By vanilla we mean a simple model on cpu with no fp16 or ddp.
        forward_step_func: a function called as
            `forward_step_func(data_iterator, model, input_tensor)`. On the
            last pipeline stage it returns a `(loss, loss_dict)` tuple, where
            loss_dict holds the values we would like to monitor during
            training, for example `lm-loss: value`; on other pipeline stages
            it returns the output activation to send to the next stage. We
            also require that this function add `batch generator` to the
            timers class.
        extra_args_provider: a function that takes a parser and adds arguments
            to it. It is used for programs to add their own arguments.
        args_defaults: a dictionary from argument-name to argument-value. It
            is used to set already parsed arguments. Defaults to a new empty
            dictionary on every call.
    """
    # Use a fresh dict per call rather than a shared mutable default.
    if args_defaults is None:
        args_defaults = {}

    # Initialize and get arguments, timers, and Tensorboard writer.
    initialize_megatron(extra_args_provider=extra_args_provider,
                        args_defaults=args_defaults)

    # Take the earliest start time across all ranks (MIN all-reduce) so the
    # reported initialization time is the largest one. This is closer to what
    # the job scheduler sees from launch.
    global _TRAIN_START_TIME
    start_time_tensor = torch.cuda.FloatTensor([_TRAIN_START_TIME])
    torch.distributed.all_reduce(start_time_tensor,
                                 op=torch.distributed.ReduceOp.MIN)
    _TRAIN_START_TIME = start_time_tensor.item()
    print_rank_0('time to initialize megatron (seconds): {:.3f}'.format(
        time.time() - _TRAIN_START_TIME))
    print_datetime('after megatron is initialized')

    args = get_args()
    timers = get_timers()

    # Model, optimizer, and learning rate.
    timers('model and optimizer').start()
    model, optimizer, lr_scheduler = setup_model_and_optimizer(model_provider)
    timers('model and optimizer').stop()
    print_datetime('after model, optimizer, and learning rate '
                   'scheduler are built')

    # Data stuff.
    timers('train/valid/test data iterators').start()
    train_data_iterator, valid_data_iterator, test_data_iterator \
        = build_train_valid_test_data_iterators(
            train_valid_test_dataset_provider)
    timers('train/valid/test data iterators').stop()
    print_datetime('after dataloaders are built')

    # Print setup timing.
    print_rank_0('done with setups ...')
    timers.log(['model and optimizer', 'train/valid/test data iterators'])
    print_rank_0('training ...')

    iteration = 0
    if args.do_train and args.train_iters > 0:
        iteration = train(forward_step_func,
                          model, optimizer, lr_scheduler,
                          train_data_iterator, valid_data_iterator)
    print_datetime('after training is done')

    if args.do_valid:
        prefix = 'the end of training for val data'
        evaluate_and_print_results(prefix, forward_step_func,
                                   valid_data_iterator, model,
                                   iteration, False)

    # Only save when training actually advanced past iteration 0.
    if args.save and iteration != 0:
        save_checkpoint(iteration, model, optimizer, lr_scheduler)

    if args.do_test:
        # Run on test data.
        prefix = 'the end of training for test data'
        evaluate_and_print_results(prefix, forward_step_func,
                                   test_data_iterator, model,
                                   0, True)
147

148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
def update_train_iters(args):
    """Derive `args.train_iters` for sample-based training.

    If `args.train_iters` is already set (iteration-based training), nothing
    changes. Otherwise the number of iterations needed to consume
    `args.train_samples` is computed, accounting for an optional batch-size
    rampup described by `args.rampup_batch_size`.
    """
    # Nothing to derive when the iteration count was given explicitly.
    if args.train_iters:
        return

    if args.rampup_batch_size is not None:
        # Step through the rampup phase one iteration at a time, asking the
        # microbatch calculator for the global batch size at each point.
        rampup_samples = int(args.rampup_batch_size[2])
        num_iters = 0
        samples_seen = 0
        while samples_seen <= rampup_samples:
            update_num_microbatches(samples_seen, consistency_check=False)
            samples_seen += get_current_global_batch_size()
            num_iters += 1
        # Reset the calculator back to the start.
        update_num_microbatches(0, consistency_check=False)
        # The rest runs at the full global batch size; a trailing partial
        # batch is dropped.
        remaining_samples = args.train_samples - samples_seen
        num_iters += remaining_samples // args.global_batch_size
        args.train_iters = num_iters
    else:
        # Constant batch size throughout sample-based training.
        args.train_iters = args.train_samples // args.global_batch_size

    print_rank_0('setting training iterations to {}'.format(args.train_iters))

177

Mohammad's avatar
Mohammad committed
178
def get_model(model_provider_func):
    """Build the model and wrap it for fp16 and data parallelism.

    The model returned by `model_provider_func` gets default tensor-parallel
    attributes on every parameter. It is then moved to the current CUDA
    device, wrapped in `FP16Module` when `args.fp16` is set, and finally
    wrapped in the DDP implementation selected by `args.DDP_impl`
    ('torch' or 'local').

    Raises:
        NotImplementedError: if `args.DDP_impl` is neither 'torch' nor 'local'.
    """
    args = get_args()

    # Build model on cpu.
    model = model_provider_func()

    # Only parameters that are already tensor-model-parallel carry the
    # tensor-parallel attributes; give every other parameter the defaults so
    # the optimizer can rely on them being present.
    for param in model.parameters():
        mpu.set_defaults_if_not_set_tensor_model_parallel_attributes(param)

    # Only data-parallel rank 0 reports, so each (tensor, pipeline)
    # model-parallel rank prints its parameter count once.
    if mpu.get_data_parallel_rank() == 0:
        tp_rank = mpu.get_tensor_model_parallel_rank()
        pp_rank = mpu.get_pipeline_model_parallel_rank()
        num_params = sum([p.nelement() for p in model.parameters()])
        print(' > number of parameters on (tensor, pipeline) '
              'model parallel rank ({}, {}): {}'.format(
                  tp_rank, pp_rank, num_params), flush=True)

    # GPU allocation.
    model.cuda(torch.cuda.current_device())

    # Fp16 conversion.
    if args.fp16:
        model = FP16Module(model)

    ddp_impl = args.DDP_impl
    if ddp_impl == 'torch':
        device = torch.cuda.current_device()
        return torchDDP(model, device_ids=[device], output_device=device,
                        process_group=mpu.get_data_parallel_group())
    if ddp_impl == 'local':
        return LocalDDP(model)

    raise NotImplementedError('Unknown DDP implementation specified: {}. '
                              'Exiting.'.format(args.DDP_impl))
218
219


Mohammad's avatar
Mohammad committed
220
def get_learning_rate_scheduler(optimizer):
    """Build the AnnealingLR learning-rate scheduler for `optimizer`.

    Warmup and decay lengths are expressed in samples.

    - Iteration-based training (`args.train_iters`): iteration counts are
      converted to samples by multiplying by `args.global_batch_size`.
    - Sample-based training (`args.train_samples`): sample counts are used
      directly, and `args.train_iters` is derived via `update_train_iters`.

    In both modes `args.lr_warmup_fraction`, when set, overrides the
    explicit warmup length.

    Raises:
        Exception: if neither train-iters nor train-samples is provided.
    """
    args = get_args()

    if args.train_iters:
        # Iteration-based training.
        if args.lr_decay_iters is None:
            args.lr_decay_iters = args.train_iters
        decay_steps = args.lr_decay_iters * args.global_batch_size
        if args.lr_warmup_fraction is None:
            warmup_steps = args.lr_warmup_iters * args.global_batch_size
        else:
            warmup_steps = args.lr_warmup_fraction * decay_steps
    elif args.train_samples:
        # Sample-based training. train_iters is still needed later, so it is
        # derived here; the train sample count itself is left as is, even
        # though a trailing partial batch is dropped.
        update_train_iters(args)
        if args.lr_decay_samples is None:
            args.lr_decay_samples = args.train_samples
        decay_steps = args.lr_decay_samples
        if args.lr_warmup_fraction is None:
            warmup_steps = args.lr_warmup_samples
        else:
            warmup_steps = args.lr_warmup_fraction * decay_steps
    else:
        raise Exception(
            'either train-iters or train-samples should be provided.')

    return AnnealingLR(
        optimizer,
        max_lr=args.lr,
        min_lr=args.min_lr,
        warmup_steps=warmup_steps,
        decay_steps=decay_steps,
        decay_style=args.lr_decay_style,
        use_checkpoint_lr_scheduler=args.use_checkpoint_lr_scheduler,
        override_lr_scheduler=args.override_lr_scheduler)


Mohammad's avatar
Mohammad committed
263
def setup_model_and_optimizer(model_provider_func):
    """Set up the model, optimizer and learning-rate scheduler.

    Builds the wrapped model with `get_model` and creates the optimizer on
    the unwrapped module. Builds the LR scheduler, then restores state from
    `args.load` when it is set. Sets `args.iteration` to the loaded
    iteration, or to 0 when nothing is loaded.

    Returns:
        (model, optimizer, lr_scheduler)

    Raises:
        NotImplementedError: if more than one micro-batch is configured while
            `args.DDP_impl` is not 'local'.
    """
    args = get_args()

    model = get_model(model_provider_func)

    # Strip the DDP / fp16 wrappers added by get_model. Both the optimizer
    # and the ICT initialization below operate on the bare module.
    unwrapped_model = model
    while isinstance(unwrapped_model, (torchDDP, LocalDDP, FP16Module)):
        unwrapped_model = unwrapped_model.module
    optimizer = get_megatron_optimizer(unwrapped_model)

    lr_scheduler = get_learning_rate_scheduler(optimizer)

    if args.load is not None:
        timers = get_timers()
        # Extra barrier is added to make sure all ranks report the
        # max time.
        torch.distributed.barrier()
        timers('load checkpoint').start()
        args.iteration = load_checkpoint(model, optimizer, lr_scheduler)
        torch.distributed.barrier()
        timers('load checkpoint').stop()
        timers.log(['load checkpoint'])
    else:
        args.iteration = 0

    # We only support local DDP with multiple micro-batches. Checked
    # explicitly (not with `assert`) so the guard still runs under -O.
    if get_num_microbatches() > 1 and args.DDP_impl != 'local':
        raise NotImplementedError(
            'multiple micro-batches require --DDP-impl local, got {}'.format(
                args.DDP_impl))

    # Reuse the wrapper-stripped model computed above.
    if args.iteration == 0 and hasattr(unwrapped_model,
                                       'init_state_dict_from_bert'):
        print("Initializing ICT from pretrained BERT model", flush=True)
        unwrapped_model.init_state_dict_from_bert()

    return model, optimizer, lr_scheduler


306
307
308
309
310
311
312
313
def communicate(tensor_send_next, tensor_send_prev, recv_forward, recv_backward):
    """Exchange tensors with neighbouring pipeline stages.

    Uses `torch.distributed.ring_exchange` over the pipeline-model-parallel
    group. A tensor passed as `tensor_send_next` / `tensor_send_prev` is sent
    to the next / previous stage. When `recv_forward` / `recv_backward` is
    truthy, an empty buffer is allocated and filled from the previous / next
    stage.

    Returns:
        (tensor_recv_prev, tensor_recv_next); an entry is None when the
        corresponding receive was not requested.
    """
    args = get_args()

    # All receive buffers share one shape: (seq_length, micro_batch_size,
    # hidden_size). They are fp32 when residual connections are kept in
    # fp32, otherwise args.params_dtype.
    tensor_shape = (args.seq_length, args.micro_batch_size, args.hidden_size)
    dtype = torch.float if args.fp32_residual_connection else args.params_dtype

    def _recv_buffer(wanted):
        # Placeholder for an incoming tensor, or None if nothing is expected.
        if not wanted:
            return None
        return torch.empty(tensor_shape,
                           requires_grad=True,
                           device=torch.cuda.current_device(),
                           dtype=dtype)

    tensor_recv_prev = _recv_buffer(recv_forward)
    tensor_recv_next = _recv_buffer(recv_backward)

    # Send/receive in both directions in one exchange.
    torch.distributed.ring_exchange(tensor_send_prev=tensor_send_prev,
                                    tensor_recv_prev=tensor_recv_prev,
                                    tensor_send_next=tensor_send_next,
                                    tensor_recv_next=tensor_recv_next,
                                    group=mpu.get_pipeline_model_parallel_group())

    return tensor_recv_prev, tensor_recv_next


def backward_step(optimizer, model, input_tensor, output_tensor, output_tensor_grad):
    """Run the backward pass for one micro-batch.

    Arguments:
        optimizer: only used to scale the loss via `optimizer.scale_loss`,
            which happens when `output_tensor_grad` is None.
        model: unused; kept for interface compatibility with callers.
        input_tensor: tensor received from the previous pipeline stage, or
            None when there is none. Its gradient is retained and returned.
        output_tensor: the micro-batch loss (when `output_tensor_grad` is
            None) or the activation that was sent to the next stage.
        output_tensor_grad: gradient w.r.t. `output_tensor` received from the
            next stage, or None when `output_tensor` is the loss.

    Returns:
        The gradient of `input_tensor`, or None when `input_tensor` is None.
    """
    # Retain the grad on the input_tensor so it can be sent backwards.
    if input_tensor is not None:
        input_tensor.retain_grad()

    # Backward pass. Only the loss itself is scaled; a grad received from the
    # next stage already derives from a scaled loss.
    if output_tensor_grad is None:
        output_tensor = optimizer.scale_loss(output_tensor)
    torch.autograd.backward(output_tensor, grad_tensors=output_tensor_grad)

    # Collect the grad of the input_tensor.
    if input_tensor is None:
        return None
    return input_tensor.grad


361
362
363
def forward_step_with_communication(forward_step_func, data_iterator, model,
                                    input_tensors, output_tensors,
                                    losses_reduced, timers):
    """Run one forward micro-batch step, receiving/sending activations.

    A non-first pipeline stage first receives its input activation from the
    previous stage. On the last stage, the loss is divided by the number of
    micro-batches and the reduced-loss entry is appended to `losses_reduced`.
    Other stages send their output to the next stage instead. The input and
    output tensors are queued on `input_tensors` / `output_tensors` for the
    matching backward step.
    """
    if not mpu.is_pipeline_first_stage():
        timers('forward-recv').start()
        input_tensor, _ = communicate(
            tensor_send_next=None,
            tensor_send_prev=None,
            recv_forward=True,
            recv_backward=False)
        timers('forward-recv').stop()
    else:
        input_tensor = None

    # Forward model for one step.
    timers('forward-compute').start()
    output_tensor = forward_step_func(data_iterator, model, input_tensor)
    timers('forward-compute').stop()

    if mpu.is_pipeline_last_stage():
        loss, loss_reduced = output_tensor
        # Scale so gradients accumulated over micro-batches match the mean loss.
        output_tensor = loss / get_num_microbatches()
        losses_reduced.append(loss_reduced)
    else:
        timers('forward-send').start()
        communicate(
            tensor_send_next=output_tensor,
            tensor_send_prev=None,
            recv_forward=False,
            recv_backward=False)
        timers('forward-send').stop()

    input_tensors.append(input_tensor)
    output_tensors.append(output_tensor)


def backward_step_with_communication(optimizer, model, input_tensors, output_tensors, timers):
    """Run one backward micro-batch step, receiving/sending gradients.

    Steps:
        1) Pop the oldest queued (input, output) pair.
        2) Receive the output grad from the next stage, unless this is the
           last stage.
        3) Run `backward_step` on the pair.
        4) Send the input grad to the previous stage, unless this is the
           first stage.
    """
    input_tensor = input_tensors.pop(0)
    output_tensor = output_tensors.pop(0)

    output_tensor_grad = None
    if not mpu.is_pipeline_last_stage():
        timers('backward-recv').start()
        _, output_tensor_grad = communicate(
            tensor_send_next=None,
            tensor_send_prev=None,
            recv_forward=False,
            recv_backward=True)
        timers('backward-recv').stop()

    # Backward pass for one step.
    timers('backward-compute').start()
    grad_for_prev_stage = backward_step(
        optimizer, model, input_tensor, output_tensor, output_tensor_grad)
    timers('backward-compute').stop()

    # The first stage has no previous stage to send to.
    if mpu.is_pipeline_first_stage():
        return

    timers('backward-send').start()
    communicate(
        tensor_send_next=None,
        tensor_send_prev=grad_for_prev_stage,
        recv_forward=False,
        recv_backward=False)
    timers('backward-send').stop()
428
429


430
431
432
433
434
def forward_and_backward_steps_with_communication(forward_step_func, data_iterator, model,
                                                  optimizer,
                                                  input_tensor, last_microbatch,
                                                  input_tensors, output_tensors,
                                                  losses_reduced, timers):
    """Run one steady-state 1F1B step: a forward pass followed by a backward pass.

    Forward: consumes `input_tensor`, which the caller has already received.
    On the last stage the output becomes a scaled micro-batch loss. On other
    stages the output is sent to the next stage, and the output grad for the
    oldest in-flight micro-batch is received in the same exchange.

    Backward: runs on the oldest queued micro-batch. On non-first stages its
    input grad is sent to the previous stage, and the next micro-batch's
    input activation is received in the same exchange unless
    `last_microbatch` is true.

    Returns:
        The input tensor for the next forward pass. It is None on the first
        pipeline stage, or when `last_microbatch` is true.
    """
    # Forward model for one step.
    timers('forward-compute').start()
    output_tensor = forward_step_func(data_iterator, model, input_tensor)
    timers('forward-compute').stop()

    if mpu.is_pipeline_last_stage():
        loss, loss_reduced = output_tensor
        # Scale so gradients accumulated over micro-batches match the mean loss.
        output_tensor = loss / get_num_microbatches()
        output_tensor_grad = None
        losses_reduced.append(loss_reduced)
    else:
        timers('forward-send-backward-recv').start()
        _, output_tensor_grad = communicate(
            tensor_send_next=output_tensor,
            tensor_send_prev=None,
            recv_forward=False,
            recv_backward=True)
        timers('forward-send-backward-recv').stop()

    input_tensors.append(input_tensor)
    output_tensors.append(output_tensor)

    # The backward pass always runs on the oldest queued micro-batch.
    input_tensor = input_tensors.pop(0)
    output_tensor = output_tensors.pop(0)

    # Backward pass for one step.
    timers('backward-compute').start()
    input_grad_tensor = \
        backward_step(optimizer, model, input_tensor, output_tensor, output_tensor_grad)
    timers('backward-compute').stop()

    if not mpu.is_pipeline_first_stage():
        timers('backward-send-forward-recv').start()
        input_tensor, _ = communicate(
            tensor_send_next=None,
            tensor_send_prev=input_grad_tensor,
            recv_forward=(not last_microbatch),
            recv_backward=False)
        timers('backward-send-forward-recv').stop()
    else:
        input_tensor = None

    return input_tensor


482
483
484
def forward_backward_no_pipelining(forward_step_func, data_iterator, model,
                                   optimizer, timers):
    """Run forward and backward passes without inter-stage communication.

    Each micro-batch loss is divided by the number of micro-batches, so the
    gradients accumulated across the loop correspond to the mean loss.

    Returns:
        List of the reduced-loss entries returned by `forward_step_func`, one
        per micro-batch.
    """
    num_microbatches = get_num_microbatches()
    losses_reduced = []
    for _ in range(num_microbatches):
        timers('forward-compute').start()
        loss, loss_reduced = forward_step_func(data_iterator, model, input_tensor=None)
        output_tensor = loss / num_microbatches
        losses_reduced.append(loss_reduced)
        timers('forward-compute').stop()

        timers('backward-compute').start()
        backward_step(optimizer, model, input_tensor=None,
                      output_tensor=output_tensor, output_tensor_grad=None)
        timers('backward-compute').stop()

    return losses_reduced
502

503
504
505
506
507
508
509

def forward_backward_pipelining(forward_step_func, data_iterator, model,
                                optimizer, timers):
    """Run the 1F1B schedule, with warmup and cooldown micro-batches as needed.

    Each stage runs in three phases:
        1) Warmup: `pipeline_world_size - pipeline_rank - 1` forward passes,
           capped at the number of micro-batches.
        2) 1F1B: alternate one forward and one backward pass for the
           remaining micro-batches.
        3) Cooldown: drain the queued micro-batches with backward passes.

    Returns:
        List of reduced-loss entries. Only the last pipeline stage appends
        to it.
    """
    # Compute number of warmup microbatches.
    num_microbatches = get_num_microbatches()
    num_warmup_microbatches = \
        (mpu.get_pipeline_model_parallel_world_size() -
         mpu.get_pipeline_model_parallel_rank() - 1)
    num_warmup_microbatches = min(
        num_warmup_microbatches,
        num_microbatches)
    num_microbatches_remaining = \
        num_microbatches - num_warmup_microbatches

    input_tensors = []
    output_tensors = []
    losses_reduced = []

    # Run warmup forward passes.
    for _ in range(num_warmup_microbatches):
        forward_step_with_communication(
            forward_step_func, data_iterator, model,
            input_tensors, output_tensors,
            losses_reduced, timers)

    # Before running 1F1B, need to receive first forward tensor.
    # If all microbatches are run in warmup / cooldown phase, then no need to
    # receive this tensor here.
    if num_microbatches_remaining > 0:
        if mpu.is_pipeline_first_stage():
            input_tensor = None
        else:
            timers('forward-recv').start()
            input_tensor, _ = communicate(tensor_send_next=None,
                                          tensor_send_prev=None,
                                          recv_forward=True,
                                          recv_backward=False)
            timers('forward-recv').stop()

    # Run 1F1B.
    for i in range(num_microbatches_remaining):
        # On the last step, skip receiving an input for a non-existent next micro-batch.
        last_iteration = (i == (num_microbatches_remaining - 1))
        input_tensor = \
            forward_and_backward_steps_with_communication(forward_step_func, data_iterator, model,
                                                          optimizer,
                                                          input_tensor, last_iteration,
                                                          input_tensors, output_tensors,
                                                          losses_reduced, timers)

    # Run cooldown backward passes.
    for _ in range(num_warmup_microbatches):
        backward_step_with_communication(
            optimizer, model, input_tensors, output_tensors, timers)

    return losses_reduced


def train_step(forward_step_func, data_iterator,
               model, optimizer, lr_scheduler):
    """Run a single training iteration over all micro-batches.

    Steps:
        1) Zero the grads.
        2) Run forward/backward, pipelined when the pipeline-model-parallel
           world size exceeds 1.
        3) All-reduce grads where required.
        4) Step the optimizer.
        5) If the optimizer step succeeded, advance the LR scheduler by the
           number of samples consumed.

    Returns:
        (loss_dict, skipped_iter). On the last pipeline stage, `loss_dict`
        maps each loss key to its mean over micro-batches; elsewhere it is
        empty. `skipped_iter` is 1 if the optimizer step reported failure,
        else 0.
    """
    args = get_args()
    timers = get_timers()

    # Set grad to zero.
    optimizer.zero_grad()

    if mpu.get_pipeline_model_parallel_world_size() > 1:
        run_schedule = forward_backward_pipelining
    else:
        run_schedule = forward_backward_no_pipelining
    losses_reduced = run_schedule(
        forward_step_func, data_iterator, model, optimizer, timers)

    # Data-parallel grad all-reduce for the local DDP implementation.
    if args.DDP_impl == 'local':
        timers('backward-params-all-reduce').start()
        model.allreduce_params(reduce_after=False,
                               fp32_allreduce=args.fp32_allreduce)
        timers('backward-params-all-reduce').stop()

    # Keep word_embeddings in sync between the first and last pipeline stages
    # by all-reducing their grad over the embedding group. Only relevant for
    # models that support pipelined model parallelism (BERT and GPT-2).
    timers('backward-embedding-all-reduce').start()
    on_end_stage = mpu.is_pipeline_first_stage() or mpu.is_pipeline_last_stage()
    if on_end_stage and mpu.get_pipeline_model_parallel_world_size() > 1:
        inner_model = model
        while isinstance(inner_model, (torchDDP, LocalDDP, FP16Module)):
            inner_model = inner_model.module
        if inner_model.share_word_embeddings:
            torch.distributed.all_reduce(inner_model.word_embeddings_weight().grad,
                                         group=mpu.get_embedding_group())
    timers('backward-embedding-all-reduce').stop()

    # Update parameters.
    timers('optimizer').start()
    update_successful = optimizer.step()
    timers('optimizer').stop()

    # Advance the LR schedule (in samples) only when the step was applied.
    if update_successful:
        lr_scheduler.step(increment=get_num_microbatches() *
                          args.micro_batch_size *
                          args.data_parallel_size)
        skipped_iter = 0
    else:
        skipped_iter = 1

    if not mpu.is_pipeline_last_stage():
        return {}, skipped_iter

    # Average loss across microbatches.
    num_entries = len(losses_reduced)
    loss_reduced = {
        key: sum(entry[key] for entry in losses_reduced) / num_entries
        for key in losses_reduced[0]}
    return loss_reduced, skipped_iter
626
627


Mohammad Shoeybi's avatar
Mohammad Shoeybi committed
628
def training_log(loss_dict, total_loss_dict, learning_rate, iteration,
                 loss_scale, report_memory_flag, skipped_iter):
    """Log training information such as losses, timing, ....

    Accumulates per-iteration statistics into ``total_loss_dict``.
    Tensorboard scalars are written every iteration. Every
    ``args.log_interval`` iterations, an averaged summary line is printed
    and the accumulators are reset.

    Args:
        loss_dict: mapping of loss name -> loss tensor for this iteration
            (``train_step`` returns an empty dict on ranks that are not on
            the last pipeline stage).
        total_loss_dict: running accumulator owned by the caller, mutated in
            place. Besides the loss keys it holds the 'advanced iterations',
            'skipped iterations' and 'nan iterations' counters.
        learning_rate: current learning rate.
        iteration: number of the iteration that just finished.
        loss_scale: current fp16 loss scale, or None when fp16 is off.
        report_memory_flag: True while memory usage still has to be reported.
        skipped_iter: 1 if the optimizer skipped this update, else 0.

    Returns:
        The updated ``report_memory_flag`` (False once memory was reported).
    """
    args = get_args()
    timers = get_timers()
    writer = get_tensorboard_writer()

    # Advanced, skipped, and nan iterations.
    advanced_iters_key = 'advanced iterations'
    skipped_iters_key = 'skipped iterations'
    nan_iters_key = 'nan iterations'
    # Advanced iterations.
    if not skipped_iter:
        total_loss_dict[advanced_iters_key] = total_loss_dict.get(
            advanced_iters_key, 0) + 1
    else:
        # Make sure the counter exists; it is read unconditionally below.
        total_loss_dict.setdefault(advanced_iters_key, 0)
    # Skipped iterations.
    total_loss_dict[skipped_iters_key] = total_loss_dict.get(
        skipped_iters_key, 0) + skipped_iter
    # Accumulate the losses of applied updates. For skipped updates the
    # losses are not accumulated; only record whether any is non-finite.
    got_nan = False
    for key in loss_dict:
        if not skipped_iter:
            total_loss_dict[key] = total_loss_dict.get(
                key, torch.cuda.FloatTensor([0.0])) + loss_dict[key]
        else:
            value = loss_dict[key].float().sum().item()
            # not isfinite() catches +inf, -inf and nan.
            got_nan = got_nan or not math.isfinite(value)
    total_loss_dict[nan_iters_key] = total_loss_dict.get(
        nan_iters_key, 0) + int(got_nan)

    # Logging: only report timers that were actually created in this run.
    timer_names = (
        'forward-compute',
        'forward-recv',
        'forward-send',
        'forward-send-backward-recv',
        'backward-compute',
        'backward-recv',
        'backward-send',
        'backward-send-forward-recv',
        'backward-params-all-reduce',
        'backward-embedding-all-reduce',
        'optimizer-copy-to-master-grad',
        'optimizer-unscale-and-check-inf',
        'optimizer-clip-master-grad',
        'optimizer-copy-master-to-model-params',
        'optimizer',
        'batch-generator',
    )
    timers_to_log = [name for name in timer_names if name in timers.timers]

    # Calculate batch size.
    batch_size = args.micro_batch_size * args.data_parallel_size * \
        get_num_microbatches()

    # Iterations accumulated since the counters were last reset. This is
    # always >= 1 because every call increments one of the two counters.
    total_iterations = total_loss_dict[advanced_iters_key] + \
                       total_loss_dict[skipped_iters_key]

    # Tensorboard values.
    if writer and is_last_rank():
        writer.add_scalar('learning-rate', learning_rate, iteration)
        writer.add_scalar('learning-rate vs samples', learning_rate,
                          args.consumed_train_samples)
        writer.add_scalar('batch-size', batch_size, iteration)
        writer.add_scalar('batch-size vs samples', batch_size,
                          args.consumed_train_samples)
        for key in loss_dict:
            writer.add_scalar(key, loss_dict[key], iteration)
            writer.add_scalar(key + ' vs samples', loss_dict[key],
                              args.consumed_train_samples)
        if args.fp16:
            writer.add_scalar('loss-scale', loss_scale, iteration)
            writer.add_scalar('loss-scale vs samples', loss_scale,
                              args.consumed_train_samples)
        timers.write(timers_to_log, writer, iteration,
                     normalizer=total_iterations)

    if iteration % args.log_interval == 0:
        elapsed_time = timers('interval time').elapsed()
        elapsed_time_per_iteration = elapsed_time / total_iterations
        # Write on the last rank, like every other tensorboard scalar above.
        if writer and is_last_rank():
            writer.add_scalar('iteration-time',
                              elapsed_time_per_iteration, iteration)
        log_string = ' iteration {:8d}/{:8d} |'.format(
            iteration, args.train_iters)
        log_string += ' consumed samples: {:12d} |'.format(
            args.consumed_train_samples)
        log_string += ' elapsed time per iteration (ms): {:.1f} |'.format(
            elapsed_time_per_iteration * 1000.0)
        log_string += ' learning rate: {:.3E} |'.format(learning_rate)
        log_string += ' global batch size: {:5d} |'.format(batch_size)
        for key in total_loss_dict:
            if key not in [advanced_iters_key, skipped_iters_key,
                           nan_iters_key]:
                avg = total_loss_dict[key].item() / \
                      float(max(1, total_loss_dict[advanced_iters_key]))
                if avg > 0.0:
                    log_string += ' {}: {:.6E} |'.format(key, avg)
                total_loss_dict[key] = torch.cuda.FloatTensor([0.0])
        if args.fp16:
            log_string += ' loss scale: {:.1f} |'.format(loss_scale)
        log_string += ' number of skipped iterations: {:3d} |'.format(
            total_loss_dict[skipped_iters_key])
        log_string += ' number of nan iterations: {:3d} |'.format(
            total_loss_dict[nan_iters_key])
        total_loss_dict[advanced_iters_key] = 0
        total_loss_dict[skipped_iters_key] = 0
        total_loss_dict[nan_iters_key] = 0
        print_rank_last(log_string)
        if report_memory_flag and learning_rate > 0.:
            # Report memory after optimizer state has been initialized.
            report_memory('(after {} iterations)'.format(iteration))
            report_memory_flag = False
        # Normalize by the iterations actually accumulated, matching
        # elapsed_time_per_iteration and timers.write above. After resuming
        # from a checkpoint, this can be fewer than args.log_interval.
        timers.log(timers_to_log, normalizer=total_iterations)

    return report_memory_flag


754
755
756
757
758
759
760
761
762
763
764
765
def save_checkpoint_and_time(iteration, model, optimizer, lr_scheduler):
    """Save a checkpoint and log how long the save took.

    The save is bracketed by distributed barriers. All ranks therefore
    start the timer together and stop it only once every rank is done, so
    each rank reports the time of the slowest one.
    """
    timer_name = 'save checkpoint'
    all_timers = get_timers()

    # Extra barrier so that all ranks report the max time.
    torch.distributed.barrier()
    all_timers(timer_name).start()
    save_checkpoint(iteration, model, optimizer, lr_scheduler)
    torch.distributed.barrier()
    all_timers(timer_name).stop()
    all_timers.log([timer_name])


766
def train(forward_step_func, model, optimizer, lr_scheduler,
          train_data_iterator, valid_data_iterator):
    """Train the model function.

    Runs ``train_step`` from ``args.iteration`` up to ``args.train_iters``.
    After every step it logs, and when configured it also evaluates,
    checkpoints, and exits early on a wall-clock or iteration limit.

    Args:
        forward_step_func: forward function passed to ``train_step`` and to
            ``evaluate_and_print_results``.
        model, optimizer, lr_scheduler: training state.
        train_data_iterator: iterator consumed by ``train_step``.
        valid_data_iterator: iterator used for periodic validation.

    Returns:
        The iteration number reached when the loop finishes normally. The
        early-exit paths call ``sys.exit()`` instead of returning.
    """
    args = get_args()
    timers = get_timers()

    # Write args to tensorboard
    write_args_to_tensorboard()

    # Turn on training mode which enables dropout.
    model.train()

    # Tracking loss.
    total_loss_dict = {}

    # Iterations.
    iteration = args.iteration

    timers('interval time').start()
    print_datetime('before the start of training step')
    report_memory_flag = True
    while iteration < args.train_iters:
        update_num_microbatches(args.consumed_train_samples)
        loss_dict, skipped_iter = train_step(forward_step_func,
                                             train_data_iterator,
                                             model,
                                             optimizer,
                                             lr_scheduler)
        iteration += 1
        args.consumed_train_samples += mpu.get_data_parallel_world_size() * \
                                       args.micro_batch_size * \
                                       get_num_microbatches()

        # Logging.
        loss_scale = None
        if args.fp16:
            loss_scale = optimizer.get_loss_scale().item()
        report_memory_flag = training_log(loss_dict, total_loss_dict,
                                          optimizer.param_groups[0]['lr'],
                                          iteration, loss_scale,
                                          report_memory_flag, skipped_iter)

        # Autoresume
        if args.adlr_autoresume and \
           (iteration % args.adlr_autoresume_interval == 0):
            check_adlr_autoresume_termination(iteration, model, optimizer,
                                              lr_scheduler)

        # Evaluation
        if args.eval_interval and iteration % args.eval_interval == 0 and \
           args.do_valid:
            prefix = 'iteration {}'.format(iteration)
            evaluate_and_print_results(prefix, forward_step_func,
                                       valid_data_iterator, model,
                                       iteration, False)

        # Checkpointing
        saved_checkpoint = False
        if args.save and args.save_interval and \
           iteration % args.save_interval == 0:
            save_checkpoint_and_time(iteration, model, optimizer,
                                     lr_scheduler)
            saved_checkpoint = True

        # Like the periodic save above, the exit paths only checkpoint when
        # a save directory is configured and this iteration was not saved.
        need_exit_checkpoint = bool(args.save) and not saved_checkpoint

        # Exiting based on duration
        if args.exit_duration_in_mins:
            train_time = (time.time() - _TRAIN_START_TIME) / 60.0
            # All ranks must agree on stopping: exit if any rank is over time.
            done_cuda = torch.cuda.IntTensor(
                [train_time > args.exit_duration_in_mins])
            torch.distributed.all_reduce(
                done_cuda, op=torch.distributed.ReduceOp.MAX)
            done = done_cuda.item()
            if done:
                if need_exit_checkpoint:
                    save_checkpoint_and_time(iteration, model, optimizer,
                                             lr_scheduler)
                print_datetime('exiting program after {} minutes'.format(train_time))
                sys.exit()

        # Exiting based on iterations
        if args.exit_interval and iteration % args.exit_interval == 0:
            if need_exit_checkpoint:
                save_checkpoint_and_time(iteration, model, optimizer,
                                         lr_scheduler)
            torch.distributed.barrier()
            print_datetime('exiting program at iteration {}'.format(iteration))
            sys.exit()

    return iteration
856
857


Mohammad's avatar
Mohammad committed
858
def evaluate(forward_step_func, data_iterator, model, verbose=False):
    """Evaluation.

    Runs ``args.eval_iters`` forward-only iterations under ``no_grad``.
    Each iteration covers ``get_num_microbatches()`` microbatches. With
    pipeline parallelism, non-first stages receive their input from the
    previous stage and non-last stages send their output to the next stage.
    Only the last stage accumulates losses.

    Args:
        forward_step_func: callable ``(data_iterator, model, input_tensor)``.
            On the last pipeline stage its output is unpacked as
            ``(_, loss_dict)``.
        data_iterator: passed through to ``forward_step_func``.
        model: switched to eval mode for the duration. It is always put
            back into train mode afterwards, even if evaluation raises.
        verbose: print progress every ``args.log_interval`` iterations.

    Returns:
        Dict of loss name -> loss summed over all microbatches and divided
        by ``args.eval_iters * get_num_microbatches()``. The dict is empty
        on ranks that are not on the last pipeline stage.

    Side effects:
        Advances ``args.consumed_valid_samples``.
    """
    args = get_args()

    # Turn on evaluation mode which disables dropout.
    model.eval()

    total_loss_dict = {}

    try:
        with torch.no_grad():
            for iteration in range(1, args.eval_iters + 1):
                if verbose and iteration % args.log_interval == 0:
                    print_rank_0('Evaluating iter {}/{}'.format(iteration,
                                                                args.eval_iters))

                for _ in range(get_num_microbatches()):
                    if not mpu.is_pipeline_first_stage():
                        input_tensor, _ = communicate(
                            tensor_send_next=None,
                            tensor_send_prev=None,
                            recv_forward=True,
                            recv_backward=False)
                    else:
                        input_tensor = None

                    # Forward evaluation.
                    output_tensor = forward_step_func(data_iterator, model,
                                                      input_tensor)

                    if mpu.is_pipeline_last_stage():
                        _, loss_dict = output_tensor
                        # Sum losses over microbatches and iterations on this
                        # rank; the average is taken after the loop.
                        for key in loss_dict:
                            total_loss_dict[key] = total_loss_dict.get(
                                key, torch.cuda.FloatTensor([0.0])) + \
                                loss_dict[key]
                    else:
                        communicate(
                            tensor_send_next=output_tensor,
                            tensor_send_prev=None,
                            recv_forward=False,
                            recv_backward=False)

                args.consumed_valid_samples += mpu.get_data_parallel_world_size() \
                                               * args.micro_batch_size \
                                               * get_num_microbatches()
    finally:
        # Move model back to the train mode, even if evaluation failed.
        model.train()

    for key in total_loss_dict:
        total_loss_dict[key] /= args.eval_iters * get_num_microbatches()

    return total_loss_dict

def evaluate_and_print_results(prefix, forward_step_func,
                               data_iterator, model,
                               iteration, verbose=False):
    """Helper function to evaluate and dump results on screen.

    Runs ``evaluate`` and prints each loss together with its perplexity.
    The perplexity is ``exp(loss)`` with the exponent capped at 20. Both
    values are also written to tensorboard.

    Args:
        prefix: label for the printed header (e.g. 'iteration 100').
        forward_step_func, data_iterator, model, verbose: forwarded to
            ``evaluate``.
        iteration: step used for the tensorboard scalars.
    """
    writer = get_tensorboard_writer()

    total_loss_dict = evaluate(forward_step_func, data_iterator, model, verbose)

    # evaluate() only fills the loss dict on the last pipeline stage, and
    # the results are printed with print_rank_last. Write tensorboard
    # scalars from the last rank too; on rank 0 the dict may be empty.
    write_to_tensorboard = writer and is_last_rank()

    string = ' validation loss at {} | '.format(prefix)
    for key in total_loss_dict:
        loss_value = total_loss_dict[key].item()
        ppl = math.exp(min(20, loss_value))
        string += '{} value: {:.6E} | '.format(key, loss_value)
        string += '{} PPL: {:.6E} | '.format(key, ppl)
        if write_to_tensorboard:
            writer.add_scalar('{} value'.format(key), loss_value, iteration)
            writer.add_scalar('{} ppl'.format(key), ppl, iteration)

    length = len(string) + 1
    print_rank_last('-' * length)
    print_rank_last(string)
    print_rank_last('-' * length)
934
935


936
937
938
def build_train_valid_test_data_iterators(
        build_train_valid_test_datasets_provider):
    """Build the train, validation and test data iterators.

    Datasets and dataloaders are built only on rank 0 of each
    tensor-model-parallel group. The resulting do_train / do_valid /
    do_test flags are broadcast across that group and stored on ``args``.

    Args:
        build_train_valid_test_datasets_provider: callable that takes
            ``[train_samples, valid_samples, test_samples]`` and returns
            ``(train_ds, valid_ds, test_ds)``.

    Returns:
        ``(train_data_iterator, valid_data_iterator, test_data_iterator)``.
        Each is None when the corresponding dataloader was not built, which
        is always the case on ranks other than tensor-model-parallel rank 0.
    """
    args = get_args()

    (train_dataloader, valid_dataloader, test_dataloader) = (None, None, None)

    print_rank_0('> building train, validation, and test datasets ...')

    # Backward compatibility, assume fixed batch size: checkpoints that do
    # not record consumed samples get them reconstructed from the iteration.
    if args.iteration > 0 and args.consumed_train_samples == 0:
        assert args.train_samples is None, \
            'only backward compatiblity support for iteration-based training'
        args.consumed_train_samples = args.iteration * args.global_batch_size
    if args.iteration > 0 and args.consumed_valid_samples == 0:
        assert args.train_samples is None, \
            'only backward compatiblity support for iteration-based training'
        # Assumes one validation pass of eval_iters batches per completed
        # eval interval.
        args.consumed_valid_samples = (args.iteration // args.eval_interval) * \
            args.eval_iters * args.global_batch_size

    # Data loaders are built only on rank 0 of each tensor-model-parallel group.
    if mpu.get_tensor_model_parallel_rank() == 0:

        # Number of train/valid/test samples.
        if args.train_samples:
            train_samples = args.train_samples
        else:
            train_samples = args.train_iters * args.global_batch_size
        # Enough validation batches for every periodic evaluation plus one.
        eval_iters = (args.train_iters // args.eval_interval + 1) * \
                     args.eval_iters
        test_iters = args.eval_iters
        train_val_test_num_samples = [train_samples,
                                      eval_iters * args.global_batch_size,
                                      test_iters * args.global_batch_size]
        print_rank_0(' > datasets target sizes (minimum size):')
        print_rank_0('    train:      {}'.format(train_val_test_num_samples[0]))
        print_rank_0('    validation: {}'.format(train_val_test_num_samples[1]))
        print_rank_0('    test:       {}'.format(train_val_test_num_samples[2]))

        # Build the datasets.
        train_ds, valid_ds, test_ds = build_train_valid_test_datasets_provider(
            train_val_test_num_samples)

        # Build dataloaders, resuming train/valid from the consumed samples.
        train_dataloader = build_pretraining_data_loader(
            train_ds, args.consumed_train_samples)
        valid_dataloader = build_pretraining_data_loader(
            valid_ds, args.consumed_valid_samples)
        test_dataloader = build_pretraining_data_loader(test_ds, 0)

        # Flags to know if we need to do training/validation/testing.
        do_train = train_dataloader is not None and args.train_iters > 0
        do_valid = valid_dataloader is not None and args.eval_iters > 0
        do_test = test_dataloader is not None and args.eval_iters > 0
        # Pack the flags into a tensor so they can be broadcast.
        flags = torch.cuda.LongTensor(
            [int(do_train), int(do_valid), int(do_test)])
    else:
        flags = torch.cuda.LongTensor([0, 0, 0])

    # Broadcast the flags from the group's source rank (presumably the
    # tensor-model-parallel rank 0 that built the dataloaders) to the rest
    # of the tensor-model-parallel group.
    torch.distributed.broadcast(flags,
                                mpu.get_tensor_model_parallel_src_rank(),
                                group=mpu.get_tensor_model_parallel_group())
    args.do_train = flags[0].item()
    args.do_valid = flags[1].item()
    args.do_test = flags[2].item()

    # Build iterators.
    train_data_iterator = iter(train_dataloader) \
        if train_dataloader is not None else None
    valid_data_iterator = iter(valid_dataloader) \
        if valid_dataloader is not None else None
    test_data_iterator = iter(test_dataloader) \
        if test_dataloader is not None else None

    return train_data_iterator, valid_data_iterator, test_data_iterator