training.py 33.2 KB
Newer Older
1
# coding=utf-8
Mohammad's avatar
Mohammad committed
2
# Copyright (c) 2020, NVIDIA CORPORATION.  All rights reserved.
3
4
5
6
7
8
9
10
11
12
13
14
15
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

Mohammad's avatar
Mohammad committed
16
"""Pretrain utilities."""
17
18
19

from datetime import datetime
import math
Mohammad's avatar
Mohammad committed
20
import sys
21
22
23
24
import torch
from torch.nn.parallel.distributed import DistributedDataParallel as torchDDP
from apex.optimizers import FusedAdam as Adam

Neel Kant's avatar
Neel Kant committed
25
from megatron import get_args
Mohammad's avatar
Mohammad committed
26
27
from megatron import get_timers
from megatron import get_tensorboard_writer
28
from megatron import mpu
Neel Kant's avatar
Neel Kant committed
29
from megatron import print_rank_0
30
from megatron import print_rank_last
Mohammad's avatar
Mohammad committed
31
32
from megatron.checkpointing import load_checkpoint
from megatron.checkpointing import save_checkpoint
33
34
from megatron.fp16 import FP16_Module
from megatron.fp16 import FP16_Optimizer
Mohammad's avatar
Mohammad committed
35
from megatron.initialize import initialize_megatron
36
37
38
from megatron.learning_rates import AnnealingLR
from megatron.model import DistributedDataParallel as LocalDDP
from megatron.model import get_params_for_weight_decay_optimization
Neel Kant's avatar
Neel Kant committed
39
from megatron.model.realm_model import ICTBertModel
40
from megatron.utils import check_adlr_autoresume_termination
41
from megatron.utils import make_data_loader
42
from megatron.utils import report_memory
43
44


45
def pretrain(train_valid_test_dataset_provider, model_provider,
             forward_step_func, extra_args_provider=None, args_defaults=None):
    """Main training program.

    This function will run the following steps in order:
        1) initialize Megatron.
        2) setup model, optimizer and lr schedule using the model_provider.
        3) call train_valid_test_dataset_provider to get train/val/test
           datasets.
        4) train the model using the forward_step_func.

    Arguments:
        train_valid_test_dataset_provider: a function that takes the size of
            train/valid/test dataset and returns `train, valid, test` datasets.
        model_provider: a function that returns a vanilla version of the
            model. By vanilla we mean a simple model on cpu with no fp16 or ddp.
        forward_step_func: a function called as
            `forward_step_func(data_iterator, model, input_tensor)`, where
            `input_tensor` is the activation received from the previous
            pipeline stage (None on the first stage). On the last pipeline
            stage it returns `(loss, loss_dict)`, where `loss_dict` maps
            names to the values we would like to monitor during training,
            for example `lm-loss: value`; on other stages it returns the
            output tensor to send to the next stage. We also require that
            this function add `batch generator` to the timers class.
        extra_args_provider: a function that takes a parser and adds arguments
            to it. It is used for programs to add their own arguments.
        args_defaults: a dictionary from argument-name to argument-value. It
            is used to set already parsed arguments. Defaults to a new empty
            dictionary on every call.
    """
    # Avoid a shared mutable default argument: each call gets its own dict.
    if args_defaults is None:
        args_defaults = {}

    # Initialize Megatron and get arguments and timers.
    initialize_megatron(extra_args_provider=extra_args_provider,
                        args_defaults=args_defaults)

    args = get_args()
    timers = get_timers()

    # Model, optimizer, and learning rate.
    timers('model and optimizer').start()
    model, optimizer, lr_scheduler = setup_model_and_optimizer(model_provider)
    timers('model and optimizer').stop()

    # Data stuff.
    timers('train/valid/test data iterators').start()
    train_data_iterator, valid_data_iterator, test_data_iterator \
        = build_train_valid_test_data_iterators(
            train_valid_test_dataset_provider)
    timers('train/valid/test data iterators').stop()

    # Print setup timing.
    print_rank_0('done with setups ...')
    timers.log(['model and optimizer', 'train/valid/test data iterators'])
    print_rank_0('training ...')

    # Stays 0 when training is skipped; the final checkpoint save below is
    # then skipped as well.
    iteration = 0
    if args.do_train and args.train_iters > 0:
        iteration = train(forward_step_func,
                          model, optimizer, lr_scheduler,
                          train_data_iterator, valid_data_iterator)

    if args.do_valid:
        prefix = 'the end of training for val data'
        evaluate_and_print_results(prefix, forward_step_func,
                                   valid_data_iterator, model,
                                   iteration, False)

    if args.save and iteration != 0:
        save_checkpoint(iteration, model, optimizer, lr_scheduler)

    if args.do_test:
        # Run on test data.
        prefix = 'the end of training for test data'
        evaluate_and_print_results(prefix, forward_step_func,
                                   test_data_iterator, model,
                                   0, True)
116
117


Mohammad's avatar
Mohammad committed
118
def get_model(model_provider_func):
    """Build the model, move it to the GPU and wrap it for fp16 and DDP.

    Arguments:
        model_provider_func: zero-argument callable returning the vanilla
            (CPU, fp32, unwrapped) model.

    Returns:
        The model wrapped in torch DDP or the local DDP implementation,
        depending on `args.DDP_impl` (with an FP16_Module wrapper inside
        when `args.fp16` is set).

    Raises:
        NotImplementedError: if `args.DDP_impl` is neither 'torch' nor 'local'.
    """
    args = get_args()

    # Instantiate the vanilla model on the CPU.
    model = model_provider_func()

    # Only the first data-parallel replica reports the parameter count.
    if mpu.get_data_parallel_rank() == 0:
        num_params = sum([p.nelement() for p in model.parameters()])
        print(' > number of parameters on (tensor, pipeline) '
              'model parallel rank ({}, {}): {}'.format(
                  mpu.get_tensor_model_parallel_rank(),
                  mpu.get_pipeline_model_parallel_rank(),
                  num_params), flush=True)

    # Move parameters to the current GPU, then optionally convert to fp16.
    model.cuda(torch.cuda.current_device())
    if args.fp16:
        model = FP16_Module(model)

    # More than one microbatch per minibatch is only accepted together with
    # the local DDP implementation.
    if args.num_microbatches_in_minibatch > 1:
        assert args.DDP_impl == 'local'

    ddp_impl = args.DDP_impl
    if ddp_impl == 'torch':
        device = torch.cuda.current_device()
        return torchDDP(model, device_ids=[device], output_device=device,
                        process_group=mpu.get_data_parallel_group())
    if ddp_impl == 'local':
        return LocalDDP(model)

    raise NotImplementedError('Unknown DDP implementation specified: {}. '
                              'Exiting.'.format(ddp_impl))
155
156


Mohammad's avatar
Mohammad committed
157
def get_optimizer(model):
    """Create the Adam optimizer for `model`, wrapped for fp16 when enabled.

    Parameter groups (weight-decay vs. no-decay) are built from the model
    with all torchDDP / LocalDDP / FP16_Module wrappers stripped.

    Arguments:
        model: the (possibly wrapped) model returned by `get_model`.

    Returns:
        An Adam optimizer, or an FP16_Optimizer wrapping it when `args.fp16`.
    """
    args = get_args()

    # Strip wrappers so the parameter groups are derived from the bare model.
    unwrapped = model
    while isinstance(unwrapped, (torchDDP, LocalDDP, FP16_Module)):
        unwrapped = unwrapped.module
    param_groups = get_params_for_weight_decay_optimization(unwrapped)

    # Parameters not explicitly tagged default to tensor_model_parallel=False.
    for group in param_groups:
        for param in group['params']:
            if not hasattr(param, 'tensor_model_parallel'):
                param.tensor_model_parallel = False

    optimizer = Adam(param_groups,
                     lr=args.lr,
                     weight_decay=args.weight_decay,
                     betas=(args.adam_beta1, args.adam_beta2),
                     eps=args.adam_eps)

    if not args.fp16:
        return optimizer

    # Settings forwarded to the fp16 optimizer's dynamic loss scaler.
    loss_scaler_args = {
        'scale_window': args.loss_scale_window,
        'min_scale': args.min_scale,
        'delayed_shift': args.hysteresis,
    }
    return FP16_Optimizer(optimizer,
                          static_loss_scale=args.loss_scale,
                          dynamic_loss_scale=args.dynamic_loss_scale,
                          dynamic_loss_args=loss_scaler_args)


Mohammad's avatar
Mohammad committed
189
def get_learning_rate_scheduler(optimizer):
    """Build the AnnealingLR learning-rate scheduler for `optimizer`.

    The decay horizon is `args.lr_decay_iters` when set, otherwise
    `args.train_iters`, clamped to at least 1. The warmup length is
    `args.warmup * horizon` (so `args.warmup` is presumably a fraction —
    confirm against the argument definition).
    """
    args = get_args()

    decay_iters = args.train_iters if args.lr_decay_iters is None \
        else args.lr_decay_iters
    decay_iters = max(1, decay_iters)

    return AnnealingLR(
        optimizer,
        start_lr=args.lr,
        warmup_iter=args.warmup * decay_iters,
        total_iters=decay_iters,
        decay_style=args.lr_decay_style,
        last_iter=0,
        min_lr=args.min_lr,
        use_checkpoint_lr_scheduler=args.use_checkpoint_lr_scheduler,
        override_lr_scheduler=args.override_lr_scheduler)


Mohammad's avatar
Mohammad committed
215
def setup_model_and_optimizer(model_provider_func):
    """Build model, optimizer and lr scheduler, then restore or initialize state.

    Sets `args.iteration` to the iteration returned by `load_checkpoint`
    when `args.load` is given, else to 0. On a fresh start (iteration 0),
    a model exposing `init_state_dict_from_bert` is initialized from it.

    Returns:
        (model, optimizer, lr_scheduler)
    """
    args = get_args()

    model = get_model(model_provider_func)
    optimizer = get_optimizer(model)
    lr_scheduler = get_learning_rate_scheduler(optimizer)

    args.iteration = 0 if args.load is None \
        else load_checkpoint(model, optimizer, lr_scheduler)

    # Peel every `.module` wrapper (fp16 / DDP) to reach the underlying model.
    bare_model = model
    while hasattr(bare_model, 'module'):
        bare_model = bare_model.module

    fresh_start = args.iteration == 0
    if fresh_start and hasattr(bare_model, 'init_state_dict_from_bert'):
        print("Initializing ICT from pretrained BERT model", flush=True)
        bare_model.init_state_dict_from_bert()

    return model, optimizer, lr_scheduler


240
241
242
243
244
245
246
247
def communicate(tensor_send_next, tensor_send_prev, recv_forward, recv_backward):
    """Communicate tensors between stages using torch.distributed.ring_exchange(.) API.

    Arguments:
        tensor_send_next: tensor to send to the next pipeline stage, or None.
        tensor_send_prev: tensor to send to the previous pipeline stage, or None.
        recv_forward: if True, allocate and fill a buffer received from the
            previous stage.
        recv_backward: if True, allocate and fill a buffer received from the
            next stage.

    Returns:
        (tensor_recv_prev, tensor_recv_next); each entry is None unless the
        corresponding receive was requested. Receive buffers have shape
        (args.seq_length, args.batch_size, args.hidden_size), dtype
        args.params_dtype, and require grad.
    """
    args = get_args()

    tensor_shape = (args.seq_length, args.batch_size, args.hidden_size)

    def _make_recv_buffer():
        # Allocate directly on the current GPU. Building the buffer on the
        # CPU and calling .cuda() costs an extra host allocation plus a
        # host-to-device copy per receive. It also leaves a CPU leaf in the
        # autograd graph, so backward copies the gradient back to the host.
        return torch.empty(tensor_shape,
                           requires_grad=True,
                           device=torch.cuda.current_device(),
                           dtype=args.params_dtype)

    # Create placeholder tensors for receive in forward and backward
    # directions if needed.
    tensor_recv_prev = _make_recv_buffer() if recv_forward else None
    tensor_recv_next = _make_recv_buffer() if recv_backward else None

    # Send tensors in both the forward and backward directions as appropriate.
    torch.distributed.ring_exchange(tensor_send_prev=tensor_send_prev,
                                    tensor_recv_prev=tensor_recv_prev,
                                    tensor_send_next=tensor_send_next,
                                    tensor_recv_next=tensor_recv_next,
                                    group=mpu.get_pipeline_model_parallel_group())

    return tensor_recv_prev, tensor_recv_next


def backward_step(optimizer, model, input_tensor, output_tensor, output_tensor_grad):
    """Backward step for one microbatch.

    Arguments:
        optimizer: the optimizer. With `args.fp16` its `backward` method runs
            the backward pass (presumably applying loss scaling — see
            FP16_Optimizer); otherwise torch.autograd.backward is used.
        model: unused; kept so all step functions share one signature.
        input_tensor: the tensor this stage consumed in the forward pass
            (None on the first pipeline stage).
        output_tensor: this stage's forward output (the loss on the last stage).
        output_tensor_grad: gradient w.r.t. `output_tensor` received from the
            next stage, or None on the last stage.

    Returns:
        `input_tensor.grad`, or None when `input_tensor` is None.
    """
    args = get_args()

    # Ask autograd to keep the grad on input_tensor so it can be returned
    # (and sent to the previous stage) after backward.
    if input_tensor is not None:
        input_tensor.retain_grad()

    # Backward pass.
    if args.fp16:
        optimizer.backward(output_tensor, update_master_grads=False,
                           output_tensor_grad=output_tensor_grad)
    else:
        torch.autograd.backward(output_tensor, grad_tensors=output_tensor_grad)

    # Collect the grad of the input_tensor.
    if input_tensor is None:
        return None
    return input_tensor.grad


292
293
294
295
def forward_step_with_communication(forward_step_func, data_iterator, model,
                                    input_tensors, output_tensors,
                                    losses_reduced, timers):
    """Run one forward microbatch, exchanging activations with neighbour stages.

    Every stage except the first receives its input from the previous stage.
    Every stage except the last sends its output to the next stage. On the
    last stage, `forward_step_func` returns `(loss, loss_reduced)`; `loss`
    becomes the output tensor and `loss_reduced` is appended to
    `losses_reduced`. The (input, output) pair is appended to
    `input_tensors` / `output_tensors` for the matching backward step.
    """
    input_tensor = None
    if not mpu.is_pipeline_first_stage():
        timers('forward-recv').start()
        input_tensor, _ = communicate(tensor_send_next=None,
                                      tensor_send_prev=None,
                                      recv_forward=True,
                                      recv_backward=False)
        timers('forward-recv').stop()

    timers('forward-compute').start()
    output_tensor = forward_step_func(data_iterator, model, input_tensor)
    timers('forward-compute').stop()

    if not mpu.is_pipeline_last_stage():
        timers('forward-send').start()
        communicate(tensor_send_next=output_tensor,
                    tensor_send_prev=None,
                    recv_forward=False,
                    recv_backward=False)
        timers('forward-send').stop()
    else:
        output_tensor, loss_reduced = output_tensor
        losses_reduced.append(loss_reduced)

    input_tensors.append(input_tensor)
    output_tensors.append(output_tensor)


def backward_step_with_communication(optimizer, model, input_tensors, output_tensors, timers):
    """Run one backward microbatch, exchanging gradients with neighbour stages.

    Pops the oldest saved (input, output) pair. Every stage except the last
    receives the output gradient from the next stage. Every stage except the
    first sends the resulting input gradient to the previous stage.
    """
    input_tensor = input_tensors.pop(0)
    output_tensor = output_tensors.pop(0)

    output_tensor_grad = None
    if not mpu.is_pipeline_last_stage():
        timers('backward-recv').start()
        _, output_tensor_grad = communicate(tensor_send_next=None,
                                            tensor_send_prev=None,
                                            recv_forward=False,
                                            recv_backward=True)
        timers('backward-recv').stop()

    timers('backward-compute').start()
    input_grad_tensor = backward_step(optimizer, model, input_tensor,
                                      output_tensor, output_tensor_grad)
    timers('backward-compute').stop()

    if mpu.is_pipeline_first_stage():
        return

    timers('backward-send').start()
    communicate(tensor_send_next=None,
                tensor_send_prev=input_grad_tensor,
                recv_forward=False,
                recv_backward=False)
    timers('backward-send').stop()
358
359


360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
def forward_and_backward_steps_with_communication(forward_step_func, data_iterator, model,
                                                  optimizer,
                                                  input_tensor, last_microbatch,
                                                  input_tensors, output_tensors,
                                                  losses_reduced, timers):
    """Run one steady-state (1F1B) step: one forward, then one backward microbatch.

    Arguments:
        forward_step_func: called as forward_step_func(data_iterator, model,
            input_tensor); on the last pipeline stage its result is unpacked
            as (loss, loss_reduced).
        data_iterator: passed through to forward_step_func.
        model: passed to forward_step_func and backward_step.
        optimizer: passed to backward_step.
        input_tensor: input for this forward microbatch (None on the first
            stage; otherwise received by the caller or returned by the
            previous call of this function).
        last_microbatch: when True, no further forward input is received
            after the backward step.
        input_tensors, output_tensors: FIFO queues of tensors saved by earlier
            forward passes; this call appends the new pair and pops the oldest.
        losses_reduced: list to which the last stage appends loss_reduced.
        timers: timer registry used to time compute and communication.

    Returns:
        The next forward input received from the previous stage, or None on
        the first stage or when last_microbatch is True.
    """
    # Forward model for one step.
    timers('forward-compute').start()
    output_tensor = forward_step_func(data_iterator, model, input_tensor)
    timers('forward-compute').stop()

    if mpu.is_pipeline_last_stage():
        # Last stage: backprop directly from the loss (no incoming grad).
        loss, loss_reduced = output_tensor
        output_tensor = loss
        output_tensor_grad = None
        losses_reduced.append(loss_reduced)
    else:
        # One combined exchange sends this activation forward and receives the
        # output grad that is used below with the oldest queued output tensor;
        # both timers therefore cover the same exchange.
        timers('forward-send').start()
        timers('backward-recv').start()
        _, output_tensor_grad = communicate(
            tensor_send_next=output_tensor,
            tensor_send_prev=None,
            recv_forward=False,
            recv_backward=True)
        timers('forward-send').stop()
        timers('backward-recv').stop()

    # Queue this microbatch, then take the oldest in-flight one (FIFO) for
    # the backward pass.
    input_tensors.append(input_tensor)
    output_tensors.append(output_tensor)

    input_tensor = input_tensors.pop(0)
    output_tensor = output_tensors.pop(0)

    # Backward pass for one step.
    timers('backward-compute').start()
    input_grad_tensor = \
        backward_step(optimizer, model, input_tensor, output_tensor, output_tensor_grad)
    timers('backward-compute').stop()

    if not mpu.is_pipeline_first_stage():
        # One combined exchange sends the input grad back and, unless this is
        # the last microbatch, receives the next forward input.
        timers('backward-send').start()
        timers('forward-recv').start()
        input_tensor, _ = communicate(
            tensor_send_next=None,
            tensor_send_prev=input_grad_tensor,
            recv_forward=(not last_microbatch),
            recv_backward=False)
        timers('backward-send').stop()
        timers('forward-recv').stop()
    else:
        input_tensor = None

    return input_tensor


414
415
416
417
418
419
420
def train_step(forward_step_func, data_iterator,
               model, optimizer, lr_scheduler):
    """Single training step over one minibatch.

    Runs the pipelined schedule for all microbatches of the minibatch:
    forward-only warmup passes, the steady 1F1B phase (one forward followed
    by one backward per microbatch), and backward-only cooldown passes. Then
    it all-reduces gradients (local DDP, and tied word embeddings across the
    first/last pipeline stages), clips gradients, steps the optimizer and,
    unless the fp16 optimizer overflowed, steps the lr scheduler.

    Returns:
        (loss_reduced, skipped_iter). On the last pipeline stage,
        loss_reduced maps each key reported by forward_step_func to its mean
        over the microbatches; on other stages it is an empty dict.
        skipped_iter is 1 when args.fp16 and the optimizer reported an
        overflow (the lr scheduler is then not stepped), else 0.
    """
    args = get_args()
    timers = get_timers()

    # Set grad to zero.
    if args.fp16:
        optimizer.zero_grad(set_grads_to_None=True)
    else:
        optimizer.zero_grad()

    # Stage r of p runs min(p - r - 1, num_microbatches) forward-only warmup
    # microbatches; the remaining microbatches run in the 1F1B phase, and the
    # warmup ones are drained in the cooldown phase.
    num_microbatches_in_minibatch = args.num_microbatches_in_minibatch
    num_warmup_microbatches = \
        (mpu.get_pipeline_model_parallel_world_size() -
         mpu.get_pipeline_model_parallel_rank() - 1)
    num_warmup_microbatches = min(
        num_warmup_microbatches,
        num_microbatches_in_minibatch)
    num_microbatches_remaining = \
        num_microbatches_in_minibatch - num_warmup_microbatches

    input_tensors = []
    output_tensors = []
    losses_reduced = []

    # Run warmup forward passes.
    # NOTE: the 'forward' timer covers only this warmup phase, not 1F1B.
    timers('forward').start()
    for _ in range(num_warmup_microbatches):
        if args.pipeline_model_parallel_size > 1:
            forward_step_with_communication(
                forward_step_func, data_iterator, model,
                input_tensors, output_tensors,
                losses_reduced, timers)
        else:
            timers('forward-compute').start()
            input_tensor = None
            loss, loss_reduced = forward_step_func(data_iterator, model, input_tensor)
            output_tensor = loss
            losses_reduced.append(loss_reduced)
            input_tensors.append(input_tensor)
            output_tensors.append(output_tensor)
            timers('forward-compute').stop()
    timers('forward').stop()

    # Before running 1F1B, need to receive first forward tensor. Timed under
    # 'forward-recv' like every other forward receive.
    if num_microbatches_remaining > 0:
        if mpu.is_pipeline_first_stage():
            input_tensor = None
        else:
            timers('forward-recv').start()
            input_tensor, _ = communicate(tensor_send_next=None,
                                          tensor_send_prev=None,
                                          recv_forward=True,
                                          recv_backward=False)
            timers('forward-recv').stop()

    # Run 1F1B.
    for i in range(num_microbatches_remaining):
        last_iteration = (i == (num_microbatches_remaining - 1))
        input_tensor = \
            forward_and_backward_steps_with_communication(forward_step_func, data_iterator, model,
                                                          optimizer,
                                                          input_tensor, last_iteration,
                                                          input_tensors, output_tensors,
                                                          losses_reduced, timers)

    # Run cooldown backward passes.
    # NOTE: the 'backward' timer covers cooldown plus the reductions, clipping
    # below, but not the 1F1B phase.
    timers('backward').start()
    for _ in range(num_warmup_microbatches):
        if args.pipeline_model_parallel_size > 1:
            backward_step_with_communication(
                optimizer, model, input_tensors, output_tensors, timers)
        else:
            timers('backward-compute').start()
            input_tensor = input_tensors.pop(0)
            output_tensor = output_tensors.pop(0)
            output_tensor_grad = None
            backward_step(optimizer, model, input_tensor, output_tensor, output_tensor_grad)
            timers('backward-compute').stop()

    # All-reduce if needed.
    if args.DDP_impl == 'local':
        timers('backward-params-all-reduce').start()
        model.allreduce_params(reduce_after=False,
                               fp32_allreduce=args.fp32_allreduce)
        timers('backward-params-all-reduce').stop()

    # All-reduce word_embeddings' grad across first and last stages to ensure
    # that word_embeddings parameters stay in sync.
    # This should only run for models that support pipelined model parallelism
    # (BERT and GPT-2).
    timers('backward-embedding-all-reduce').start()
    if (mpu.is_pipeline_first_stage() or mpu.is_pipeline_last_stage()) and \
            args.pipeline_model_parallel_size > 1:
        unwrapped_model = model
        while isinstance(unwrapped_model, (torchDDP, LocalDDP, FP16_Module)):
            unwrapped_model = unwrapped_model.module

        word_embeddings_weight = unwrapped_model.word_embeddings_weight()
        torch.distributed.all_reduce(word_embeddings_weight.grad,
                                     group=mpu.get_embedding_group())
    timers('backward-embedding-all-reduce').stop()

    # Update master gradients.
    timers('backward-master-grad').start()
    if args.fp16:
        optimizer.update_master_grads()
    timers('backward-master-grad').stop()

    # Clipping gradients helps prevent the exploding gradient.
    timers('backward-clip-grad').start()
    if args.clip_grad > 0.:
        if not args.fp16:
            # A single pass over named_parameters(); the previous code also
            # created an extra, unused named_parameters() generator.
            parameters = []
            parameter_names = []
            for parameter_name, parameter in model.named_parameters():
                parameters.append(parameter)
                parameter_names.append(parameter_name)
            mpu.clip_grad_norm(parameters, args.clip_grad,
                               parameter_names=parameter_names)
        else:
            optimizer.clip_master_grads(args.clip_grad)
    timers('backward-clip-grad').stop()
    timers('backward').stop()

    # Update parameters.
    timers('optimizer').start()
    optimizer.step()
    timers('optimizer').stop()

    # Update learning rate, unless the fp16 optimizer overflowed.
    skipped_iter = 0
    if not (args.fp16 and optimizer.overflow):
        lr_scheduler.step()
    else:
        skipped_iter = 1

    if mpu.is_pipeline_last_stage():
        # Average loss across microbatches.
        loss_reduced = {}
        for key in losses_reduced[0]:
            losses_reduced_for_key = [x[key] for x in losses_reduced]
            loss_reduced[key] = sum(losses_reduced_for_key) / \
                len(losses_reduced_for_key)
        return loss_reduced, skipped_iter
    return {}, skipped_iter
559
560


Mohammad Shoeybi's avatar
Mohammad Shoeybi committed
561
def training_log(loss_dict, total_loss_dict, learning_rate, iteration,
                 loss_scale, report_memory_flag, skipped_iter):
    """Accumulate this iteration's losses and periodically log training info.

    Losses from non-skipped iterations are summed into `total_loss_dict`. For
    skipped iterations, the losses are instead checked for inf/nan. Counters for
    skipped and nan iterations are kept under the 'skipped iterations' and
    'got nan' keys. Every `args.log_interval` iterations the averages are
    printed (on the last rank) and the accumulators are reset.

    Returns:
        The updated `report_memory_flag` (cleared after memory is reported).
    """
    args = get_args()
    timers = get_timers()
    writer = get_tensorboard_writer()

    skipped_iters_key = 'skipped iterations'
    got_nan_key = 'got nan'

    # Update losses.
    total_loss_dict[skipped_iters_key] = \
        total_loss_dict.get(skipped_iters_key, 0) + skipped_iter
    got_nan = False
    for key in loss_dict:
        if skipped_iter:
            value = loss_dict[key].float().sum().item()
            got_nan = got_nan or not math.isfinite(value)
        else:
            total_loss_dict[key] = total_loss_dict.get(
                key, torch.cuda.FloatTensor([0.0])) + loss_dict[key]
    total_loss_dict[got_nan_key] = \
        total_loss_dict.get(got_nan_key, 0) + int(got_nan)

    # Log only the timers that have actually been created.
    candidate_timers = [
        'forward', 'forward-compute', 'forward-recv', 'forward-send',
        'backward', 'backward-compute', 'backward-recv', 'backward-send',
        'backward-master-grad', 'backward-params-all-reduce',
        'backward-embedding-all-reduce', 'backward-clip-grad',
        'optimizer', 'batch generator',
    ]
    timers_to_log = [name for name in candidate_timers
                     if name in timers.timers]

    # Tensorboard values.
    if writer and torch.distributed.get_rank() == 0:
        writer.add_scalar('learning_rate', learning_rate, iteration)
        for key in loss_dict:
            writer.add_scalar(key, loss_dict[key], iteration)
        if args.fp16:
            writer.add_scalar('loss_scale', loss_scale, iteration)
        # A remainder of 0 means a full log interval has elapsed.
        normalizer = iteration % args.log_interval or args.log_interval
        timers.write(timers_to_log, writer, iteration,
                     normalizer=normalizer)

    if iteration % args.log_interval != 0:
        return report_memory_flag

    elapsed_time = timers('interval time').elapsed()
    if writer and torch.distributed.get_rank() == 0:
        writer.add_scalar('iteration_time',
                          elapsed_time / args.log_interval, iteration)

    log_string = ' iteration {:8d}/{:8d} |'.format(iteration,
                                                   args.train_iters)
    log_string += ' elapsed time per iteration (ms): {:.1f} |'.format(
        elapsed_time * 1000.0 / args.log_interval)
    log_string += ' learning rate: {:.3E} |'.format(learning_rate)

    # Average over the iterations whose losses were actually accumulated.
    num_iterations = max(
        1, args.log_interval - total_loss_dict[skipped_iters_key])
    for key in total_loss_dict:
        if key in (skipped_iters_key, got_nan_key):
            continue
        avg = total_loss_dict[key].item() / float(num_iterations)
        if avg > 0.0:
            log_string += ' {}: {:.6E} |'.format(key, avg)
        total_loss_dict[key] = torch.cuda.FloatTensor([0.0])

    if args.fp16:
        log_string += ' loss scale: {:.1f} |'.format(loss_scale)
    log_string += ' number of skipped iterations: {:3d} |'.format(
        total_loss_dict[skipped_iters_key])
    log_string += ' number of nan iterations: {:3d} |'.format(
        total_loss_dict[got_nan_key])
    total_loss_dict[skipped_iters_key] = 0
    total_loss_dict[got_nan_key] = 0
    print_rank_last(log_string)

    if report_memory_flag:
        report_memory('after {} iterations'.format(iteration))
        report_memory_flag = False
    timers.log(timers_to_log, normalizer=args.log_interval)

    return report_memory_flag


658
def train(forward_step_func, model, optimizer, lr_scheduler,
          train_data_iterator, valid_data_iterator):
    """Training loop: run train steps from args.iteration to args.train_iters.

    After each step it logs, then optionally checks ADLR autoresume, saves a
    checkpoint, evaluates on the validation data, and exits the process at
    `args.exit_interval`, each on its own configured interval.

    Returns:
        The iteration count reached when the loop ends.
    """
    args = get_args()
    timers = get_timers()

    # Training mode (e.g. enables dropout).
    model.train()

    total_loss_dict = {}
    iteration = args.iteration
    report_memory_flag = True

    timers('interval time').start()
    while iteration < args.train_iters:
        loss_dict, skipped_iter = train_step(forward_step_func,
                                             train_data_iterator,
                                             model,
                                             optimizer,
                                             lr_scheduler)
        iteration += 1

        loss_scale = optimizer.loss_scale if args.fp16 else None
        report_memory_flag = training_log(loss_dict, total_loss_dict,
                                          optimizer.param_groups[0]['lr'],
                                          iteration, loss_scale,
                                          report_memory_flag, skipped_iter)

        if args.adlr_autoresume and \
                iteration % args.adlr_autoresume_interval == 0:
            check_adlr_autoresume_termination(iteration, model, optimizer,
                                              lr_scheduler)

        if args.save and args.save_interval and \
                iteration % args.save_interval == 0:
            save_checkpoint(iteration, model, optimizer, lr_scheduler)

        if args.eval_interval and iteration % args.eval_interval == 0 and \
                args.do_valid:
            evaluate_and_print_results('iteration {}'.format(iteration),
                                       forward_step_func,
                                       valid_data_iterator, model,
                                       iteration, False)

        if args.exit_interval and iteration % args.exit_interval == 0:
            torch.distributed.barrier()
            time_str = datetime.now().strftime('%Y-%m-%d %H:%M:%S')
            rank = torch.distributed.get_rank()
            print_rank_0('rank: {} | time: {} | exiting the program at '
                         'iteration {}'.format(rank, time_str, iteration))
            sys.exit()

    return iteration
720
721


Mohammad's avatar
Mohammad committed
722
def evaluate(forward_step_func, data_iterator, model, verbose=False):
    """Run ``args.eval_iters`` forward-only steps and average the losses.

    Every step runs under ``torch.no_grad()``. A pipeline stage that is not
    the first receives its input tensor from the previous stage via
    ``communicate``. A stage that is not the last sends its output tensor to
    the next stage. Only the last stage unpacks a loss dictionary, so on
    every other stage the returned dict is empty.

    Args:
        forward_step_func: callable ``(data_iterator, model, input_tensor)``.
            On the last pipeline stage its result is unpacked as
            ``(_, loss_dict)``.
        data_iterator: passed through unchanged to ``forward_step_func``.
        model: put in eval mode for the loop and always restored to train
            mode afterwards.
        verbose: if True, print progress every ``args.log_interval`` steps.

    Returns:
        dict mapping each loss name to its sum over the steps divided by
        ``args.eval_iters`` (empty on non-last pipeline stages).
    """
    args = get_args()

    # Turn on evaluation mode which disables dropout.
    model.eval()

    total_loss_dict = {}

    try:
        with torch.no_grad():
            for iteration in range(1, args.eval_iters + 1):
                if verbose and iteration % args.log_interval == 0:
                    print_rank_0('Evaluating iter {}/{}'.format(
                        iteration, args.eval_iters))

                # Non-first stages get their input from the previous stage.
                if not mpu.is_pipeline_first_stage():
                    input_tensor, _ = communicate(
                        tensor_send_next=None,
                        tensor_send_prev=None,
                        recv_forward=True,
                        recv_backward=False)
                else:
                    input_tensor = None

                # Forward evaluation.
                output_tensor = forward_step_func(data_iterator, model,
                                                  input_tensor)

                if mpu.is_pipeline_last_stage():
                    _, loss_dict = output_tensor
                    # Accumulate this step's losses locally. Any
                    # cross-process reduction presumably happens inside
                    # forward_step_func -- TODO confirm.
                    for key in loss_dict:
                        total_loss_dict[key] = \
                            total_loss_dict.get(key, 0.) + loss_dict[key]
                else:
                    # Hand the activations on to the next pipeline stage.
                    communicate(
                        tensor_send_next=output_tensor,
                        tensor_send_prev=None,
                        recv_forward=False,
                        recv_backward=False)
    finally:
        # Always move the model back to train mode, even if a step raised,
        # so a caller that recovers does not keep training with dropout off.
        model.train()

    for key in total_loss_dict:
        total_loss_dict[key] /= args.eval_iters

    return total_loss_dict


def evaluate_and_print_results(prefix, forward_step_func,
                               data_iterator, model,
                               iteration, verbose=False):
    """Evaluate the model and report the averaged losses.

    Calls ``evaluate`` and formats one summary line. For each loss the line
    holds its value and its perplexity, ``exp(min(20, loss))``; the cap keeps
    the exponential bounded. The line is printed between two dashed rules
    via ``print_rank_last``. When a tensorboard writer exists and this
    process is global rank 0, both numbers are also written as scalars at
    step ``iteration``.
    """
    writer = get_tensorboard_writer()

    total_loss_dict = evaluate(forward_step_func, data_iterator, model, verbose)

    pieces = [' validation loss at {} | '.format(prefix)]
    for name, loss in total_loss_dict.items():
        loss_value = loss.item()
        perplexity = math.exp(min(20, loss_value))
        pieces.append('{} value: {:.6E} | '.format(name, loss_value))
        pieces.append('{} PPL: {:.6E} | '.format(name, perplexity))
        # NOTE(review): evaluate() only fills total_loss_dict on the last
        # pipeline stage, but this gate checks global rank 0. With pipeline
        # parallelism those differ and nothing would be logged -- confirm
        # which rank owns the tensorboard writer.
        if writer and torch.distributed.get_rank() == 0:
            writer.add_scalar('{} value'.format(name), loss_value, iteration)
            writer.add_scalar('{} ppl'.format(name), perplexity, iteration)
    summary = ''.join(pieces)

    rule = '-' * (len(summary) + 1)
    for line in (rule, summary, rule):
        print_rank_last(line)
795
796


797
798
799
def build_train_valid_test_data_iterators(
        build_train_valid_test_datasets_provider):
    """Build the train, validation and test data iterators.

    Datasets and data loaders are built only on tensor-model-parallel rank 0.
    That rank computes do_train / do_valid / do_test flags and broadcasts
    them to the rest of its tensor-model-parallel group. Every rank stores
    the flags on ``args``. The train and validation samplers' ``start_iter``
    is derived from ``args.iteration``, presumably so a resumed run skips
    batches it already consumed -- TODO confirm.

    Args:
        build_train_valid_test_datasets_provider: callable that takes a list
            ``[num_train_samples, num_valid_samples, num_test_samples]`` and
            returns ``(train_ds, valid_ds, test_ds)``.

    Returns:
        ``(train_data_iterator, valid_data_iterator, test_data_iterator)``.
        An entry is None when the matching data loader is None, which is
        always the case on non-zero tensor-model-parallel ranks.
    """
    args = get_args()

    (train_dataloader, valid_dataloader, test_dataloader) = (None, None, None)

    print_rank_0('> building train, validation, and test datasets ...')
    # Data loaders are built only on rank 0 of each tensor model parallel
    # group; the other ranks learn the do_* flags from the broadcast below.
    if mpu.get_tensor_model_parallel_rank() == 0:

        # Data parallel size and global batch size.
        data_parallel_size = mpu.get_data_parallel_world_size()
        global_batch_size = (args.batch_size * data_parallel_size *
                             args.num_microbatches_in_minibatch)

        # Number of train/valid/test samples.
        train_iters = args.train_iters
        if args.eval_interval:
            # One validation pass per eval_interval training iterations,
            # plus one more (presumably the pass after training ends).
            eval_iters = (train_iters // args.eval_interval + 1) * \
                args.eval_iters
        else:
            # A falsy eval_interval disables periodic validation in the
            # training loop, so only a single pass needs to be sized for
            # (and we avoid dividing by zero/None).
            eval_iters = args.eval_iters
        test_iters = args.eval_iters
        train_val_test_num_samples = [train_iters * global_batch_size,
                                      eval_iters * global_batch_size,
                                      test_iters * global_batch_size]
        print_rank_0(' > datasets target sizes (minimum size):')
        print_rank_0('    train:      {}'.format(train_val_test_num_samples[0]))
        print_rank_0('    validation: {}'.format(train_val_test_num_samples[1]))
        print_rank_0('    test:       {}'.format(train_val_test_num_samples[2]))

        # Build the datasets.
        train_ds, valid_ds, test_ds = build_train_valid_test_datasets_provider(
            train_val_test_num_samples)

        # Build dataloaders.
        train_dataloader = make_data_loader(train_ds)
        valid_dataloader = make_data_loader(valid_ds)
        test_dataloader = make_data_loader(test_ds)

        # Flags to know if we need to do training/validation/testing.
        do_train = train_dataloader is not None and args.train_iters > 0
        do_valid = valid_dataloader is not None and args.eval_iters > 0
        do_test = test_dataloader is not None and args.eval_iters > 0
        # Pack the flags into a CUDA tensor so they can be broadcast.
        flags = torch.cuda.LongTensor(
            [int(do_train), int(do_valid), int(do_test)])
    else:
        flags = torch.cuda.LongTensor([0, 0, 0])

    # Broadcast the do_train/do_valid/do_test flags from tensor model
    # parallel rank 0 to every rank of its group.
    torch.distributed.broadcast(flags,
                                mpu.get_tensor_model_parallel_src_rank(),
                                group=mpu.get_tensor_model_parallel_group())
    args.do_train = flags[0].item()
    args.do_valid = flags[1].item()
    args.do_test = flags[2].item()

    # Shift the start iterations.
    if train_dataloader is not None:
        train_dataloader.batch_sampler.start_iter = args.iteration % \
            len(train_dataloader)
        print_rank_0('setting training data start iteration to {}'.
                     format(train_dataloader.batch_sampler.start_iter))
    if valid_dataloader is not None:
        if args.eval_interval:
            start_iter_val = (args.iteration // args.eval_interval) * \
                args.eval_iters
        else:
            # No periodic validation has consumed any validation batches.
            start_iter_val = 0
        valid_dataloader.batch_sampler.start_iter = start_iter_val % \
            len(valid_dataloader)
        print_rank_0('setting validation data start iteration to {}'.
                     format(valid_dataloader.batch_sampler.start_iter))

    # Build iterators.
    train_data_iterator = (iter(train_dataloader)
                           if train_dataloader is not None else None)
    valid_data_iterator = (iter(valid_dataloader)
                           if valid_dataloader is not None else None)
    test_data_iterator = (iter(test_dataloader)
                          if test_dataloader is not None else None)

    return train_data_iterator, valid_data_iterator, test_data_iterator