training.py 34.7 KB
Newer Older
1
# coding=utf-8
Mohammad's avatar
Mohammad committed
2
# Copyright (c) 2020, NVIDIA CORPORATION.  All rights reserved.
3
4
5
6
7
8
9
10
11
12
13
14
15
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

Mohammad's avatar
Mohammad committed
16
"""Pretrain utilities."""
17
18
19

from datetime import datetime
import math
Mohammad's avatar
Mohammad committed
20
import sys
21
22
23
24
import time
# The earliest point at which we can measure the start time: it is taken
# right after `time` is imported, before the heavier imports below (torch,
# megatron). pretrain() later replaces it with the minimum across all ranks.
_TRAIN_START_TIME = time.time()

25
26
27
import torch
from torch.nn.parallel.distributed import DistributedDataParallel as torchDDP

Neel Kant's avatar
Neel Kant committed
28
from megatron import get_args
Mohammad's avatar
Mohammad committed
29
30
from megatron import get_timers
from megatron import get_tensorboard_writer
31
from megatron import get_current_global_batch_size
mohammad's avatar
mohammad committed
32
from megatron import get_num_microbatches
mohammad's avatar
mohammad committed
33
from megatron import is_last_rank
mohammad's avatar
mohammad committed
34
from megatron import update_num_microbatches
35
from megatron import mpu
Neel Kant's avatar
Neel Kant committed
36
from megatron import print_rank_0
37
from megatron import print_rank_last
Mohammad's avatar
Mohammad committed
38
39
from megatron.checkpointing import load_checkpoint
from megatron.checkpointing import save_checkpoint
Mohammad Shoeybi's avatar
Mohammad Shoeybi committed
40
from megatron.model import Float16Module
mohammad's avatar
mohammad committed
41
from megatron.optimizer import get_megatron_optimizer
Mohammad's avatar
Mohammad committed
42
from megatron.initialize import initialize_megatron
43
from megatron.initialize import write_args_to_tensorboard
44
45
46
from megatron.learning_rates import AnnealingLR
from megatron.model import DistributedDataParallel as LocalDDP
from megatron.utils import check_adlr_autoresume_termination
47
from megatron.utils import unwrap_model
Vijay Korthikanti's avatar
Vijay Korthikanti committed
48
from megatron.data.data_samplers import build_pretraining_data_loader
mohammad's avatar
mohammad committed
49
from megatron.utils import calc_params_l2_norm
50
from megatron.schedules import forward_backward_no_pipelining
51
from megatron.schedules import forward_backward_pipelining_without_interleaving
52
from megatron.schedules import forward_backward_pipelining_with_interleaving
53
from megatron.utils import report_memory
54
55


Mohammad Shoeybi's avatar
Mohammad Shoeybi committed
56

57
58
59
60
61
62
63
def print_datetime(string):
    """Print ``string`` with the current local date and time on rank 0.

    Note that this call will sync across all ranks: every rank waits on
    ``torch.distributed.barrier()`` before the message is printed.
    """
    torch.distributed.barrier()
    now_str = datetime.now().strftime('%Y-%m-%d %H:%M:%S')
    message = '[' + string + '] datetime: ' + now_str + ' '
    print_rank_0(message)


def pretrain(train_valid_test_dataset_provider,
             model_provider,
             forward_step_func,
             extra_args_provider=None,
             args_defaults=None):
    """Main training program.

    This function will run the following steps in the order provided:
        1) initialize Megatron.
        2) setup model, optimizer and lr schedule using the model_provider.
        3) call train_valid_test_dataset_provider to get train/val/test
           datasets.
        4) train the model using the forward_step_func.

    Arguments:
        train_valid_test_dataset_provider: a function that takes the size of
            train/valid/test dataset and returns `train, valid, test` datasets.
        model_provider: a function that returns a vanilla version of the
            model. By vanilla we mean a simple model on cpu with no fp16 or ddp.
        forward_step_func: a function that takes a `data iterator` and `model`,
            and returns a `loss` scalar with a dictionary with key:values being
            the info we would like to monitor during training, for example
            `lm-loss: value`. We also require that this function add
            `batch generator` to the timers class.
        extra_args_provider: a function that takes a parser and adds arguments
            to it. It is used for programs to add their own arguments.
        args_defaults: a dictionary from argument-name to argument-value. It
            is used to set already-parsed arguments. ``None`` (the default)
            means an empty dictionary, created fresh for every call.
    """
    # Use a fresh dict per call: a shared `{}` default would carry any
    # mutation made downstream over into later calls.
    if args_defaults is None:
        args_defaults = {}

    # Initialize and get arguments, timers, and Tensorboard writer.
    initialize_megatron(extra_args_provider=extra_args_provider,
                        args_defaults=args_defaults)

    # Adjust the startup time so it reflects the largest value: using the
    # earliest start time across all ranks makes the reported startup time
    # the largest per-rank value, which should be closer to what the job
    # scheduler observes.
    global _TRAIN_START_TIME
    start_time_tensor = torch.cuda.FloatTensor([_TRAIN_START_TIME])
    torch.distributed.all_reduce(start_time_tensor,
                                 op=torch.distributed.ReduceOp.MIN)
    _TRAIN_START_TIME = start_time_tensor.item()
    print_rank_0('time to initialize megatron (seconds): {:.3f}'.format(
        time.time() - _TRAIN_START_TIME))
    print_datetime('after megatron is initialized')

    args = get_args()
    timers = get_timers()

    # Model, optimizer, and learning rate.
    timers('model-and-optimizer-setup').start()
    model, optimizer, lr_scheduler = setup_model_and_optimizer(model_provider)
    timers('model-and-optimizer-setup').stop()
    print_datetime('after model, optimizer, and learning rate '
                   'scheduler are built')

    # Data stuff.
    timers('train/valid/test-data-iterators-setup').start()
    if args.virtual_pipeline_model_parallel_size is not None:
        # One set of train/valid/test iterators per model chunk.
        all_data_iterators = [
            build_train_valid_test_data_iterators(train_valid_test_dataset_provider)
            for _ in range(len(model))
        ]
        train_data_iterator = [data_iterators[0] for data_iterators in all_data_iterators]
        valid_data_iterator = [data_iterators[1] for data_iterators in all_data_iterators]
        test_data_iterator = [data_iterators[2] for data_iterators in all_data_iterators]
    else:
        train_data_iterator, valid_data_iterator, test_data_iterator \
            = build_train_valid_test_data_iterators(
                train_valid_test_dataset_provider)
    timers('train/valid/test-data-iterators-setup').stop()
    print_datetime('after dataloaders are built')

    # Print setup timing.
    print_rank_0('done with setup ...')
    timers.log(['model-and-optimizer-setup', 'train/valid/test-data-iterators-setup'])
    print_rank_0('training ...')

    iteration = 0
    if args.do_train and args.train_iters > 0:
        iteration = train(forward_step_func,
                          model, optimizer, lr_scheduler,
                          train_data_iterator, valid_data_iterator)
    print_datetime('after training is done')

    if args.do_valid:
        prefix = 'the end of training for val data'
        evaluate_and_print_results(prefix, forward_step_func,
                                   valid_data_iterator, model,
                                   iteration, False)

    # Only save if training actually ran (iteration is still 0 otherwise).
    if args.save and iteration != 0:
        save_checkpoint(iteration, model, optimizer, lr_scheduler)

    if args.do_test:
        # Run on test data.
        prefix = 'the end of training for test data'
        evaluate_and_print_results(prefix, forward_step_func,
                                   test_data_iterator, model,
                                   0, True)


def update_train_iters(args):
    """Derive ``args.train_iters`` for sample-based training.

    Does nothing when ``args.train_iters`` is already set (iteration-based
    training). Otherwise the number of iterations is computed from
    ``args.train_samples``:

    * constant batch size: ``train_samples // global_batch_size``;
    * batch-size rampup: the iterations spent in the rampup phase (replayed
      with ``update_num_microbatches`` until more than
      ``rampup_batch_size[2]`` samples have been consumed) plus the
      iterations needed for the remaining samples at ``global_batch_size``.

    Any partial last batch is dropped.
    """
    # For iteration-based training, we don't need to do anything.
    if args.train_iters:
        return

    if args.rampup_batch_size is None:
        # Constant batch size with sample-based training.
        args.train_iters = args.train_samples // args.global_batch_size
    else:
        # Sample based training with rampup batch size.
        rampup_samples = int(args.rampup_batch_size[2])
        iterations = 0
        consumed_samples = 0
        # Rampup phase: replay the microbatch calculator, counting one
        # iteration per (current) global batch.
        while consumed_samples <= rampup_samples:
            update_num_microbatches(consumed_samples, consistency_check=False)
            consumed_samples += get_current_global_batch_size()
            iterations += 1
        # Reset the calculator to its state for zero consumed samples.
        update_num_microbatches(0, consistency_check=False)
        # Constant phase.
        # Note that we throw away any partial last batch.
        iterations += (args.train_samples - consumed_samples) // \
                      args.global_batch_size
        args.train_iters = iterations

    print_rank_0('setting training iterations to {}'.format(args.train_iters))


def get_model(model_provider_func):
    """Build the model.

    Arguments:
        model_provider_func: a function returning a vanilla model on cpu, or
            a list of such model chunks.

    Returns:
        A list of model modules, each moved to the current cuda device,
        wrapped in ``Float16Module`` when ``args.fp16`` or ``args.bf16`` is
        set, and then wrapped in torch's ``DistributedDataParallel``
        (``args.DDP_impl == 'torch'``) or the local ``DistributedDataParallel``
        (``args.DDP_impl == 'local'``).

    Raises:
        NotImplementedError: if ``args.DDP_impl`` is neither 'torch' nor
            'local'.
    """
    args = get_args()

    # Build model on cpu. The provider may return one module or a list of
    # chunks; normalize to a list.
    model = model_provider_func()
    if not isinstance(model, list):
        model = [model]

    # Set tensor model parallel attributes if not set.
    # Only parameters that are already tensor model parallel have these
    # attributes set for them. We should make sure the default attributes
    # are set for all params so the optimizer can use them.
    for model_module in model:
        for param in model_module.parameters():
            mpu.set_defaults_if_not_set_tensor_model_parallel_attributes(param)

    # Print number of parameters (only on data-parallel rank 0).
    if mpu.get_data_parallel_rank() == 0:
        num_params = sum(p.nelement()
                         for model_module in model
                         for p in model_module.parameters())
        print(' > number of parameters on (tensor, pipeline) '
              'model parallel rank ({}, {}): {}'.format(
                  mpu.get_tensor_model_parallel_rank(),
                  mpu.get_pipeline_model_parallel_rank(),
                  num_params), flush=True)

    # GPU allocation.
    for model_module in model:
        model_module.cuda(torch.cuda.current_device())

    # Fp16 conversion.
    if args.fp16 or args.bf16:
        model = [Float16Module(model_module, args) for model_module in model]

    if args.DDP_impl == 'torch':
        i = torch.cuda.current_device()
        model = [torchDDP(model_module, device_ids=[i], output_device=i,
                          process_group=mpu.get_data_parallel_group())
                 for model_module in model]
        return model

    if args.DDP_impl == 'local':
        model = [LocalDDP(model_module,
                          args.accumulate_allreduce_grads_in_fp32,
                          args.use_contiguous_buffers_in_ddp)
                 for model_module in model]
        return model

    raise NotImplementedError('Unknown DDP implementation specified: {}. '
                              'Exiting.'.format(args.DDP_impl))


def get_learning_rate_scheduler(optimizer):
    """Build the learning rate scheduler.

    Warmup and decay lengths passed to ``AnnealingLR`` are expressed in
    samples. For iteration-based training (``args.train_iters``) the
    iteration counts are converted by multiplying by
    ``args.global_batch_size``. For sample-based training
    (``args.train_samples``) the sample counts are used directly, and
    ``args.train_iters`` is derived via ``update_train_iters``.

    Side effects: fills in ``args.lr_decay_iters`` / ``args.lr_decay_samples``
    when unset, and ``args.train_iters`` for sample-based runs.

    Arguments:
        optimizer: the optimizer whose learning rate is scheduled.

    Returns:
        An ``AnnealingLR`` instance.

    Raises:
        ValueError: if neither ``args.train_iters`` nor
            ``args.train_samples`` is set.
    """
    args = get_args()

    # Iteration-based training.
    if args.train_iters:
        if args.lr_decay_iters is None:
            args.lr_decay_iters = args.train_iters
        decay_steps = args.lr_decay_iters * args.global_batch_size
        if args.lr_warmup_fraction is not None:
            warmup_steps = args.lr_warmup_fraction * decay_steps
        else:
            warmup_steps = args.lr_warmup_iters * args.global_batch_size
    # Sample-based training.
    elif args.train_samples:
        # We need to set training iters for later use. Technically
        # we need to adjust the training samples too (due to last
        # batch being incomplete) but we leave it as is for now.
        update_train_iters(args)
        if args.lr_decay_samples is None:
            args.lr_decay_samples = args.train_samples
        decay_steps = args.lr_decay_samples
        if args.lr_warmup_fraction is not None:
            warmup_steps = args.lr_warmup_fraction * decay_steps
        else:
            warmup_steps = args.lr_warmup_samples
    else:
        # ValueError is still caught by any existing `except Exception`.
        raise ValueError(
            'either train-iters or train-samples should be provided.')

    lr_scheduler = AnnealingLR(
        optimizer,
        max_lr=args.lr,
        min_lr=args.min_lr,
        warmup_steps=warmup_steps,
        decay_steps=decay_steps,
        decay_style=args.lr_decay_style,
        use_checkpoint_lr_scheduler=args.use_checkpoint_lr_scheduler,
        override_lr_scheduler=args.override_lr_scheduler)

    return lr_scheduler


def setup_model_and_optimizer(model_provider_func):
    """Setup model, optimizer and learning rate scheduler.

    Builds the model with ``get_model``, creates the optimizer on the
    unwrapped modules, and builds the learning rate scheduler. When
    ``args.load`` is set, a checkpoint is loaded and its iteration is stored
    in ``args.iteration``; otherwise ``args.iteration`` is set to 0.

    Returns:
        ``(model, optimizer, lr_scheduler)`` where ``model`` is the list of
        wrapped model modules returned by ``get_model``.
    """
    args = get_args()

    model = get_model(model_provider_func)

    # Strip the DDP / Float16Module wrappers before building the optimizer.
    unwrapped_model = unwrap_model(model,
                                   (torchDDP, LocalDDP, Float16Module))
    optimizer = get_megatron_optimizer(unwrapped_model)

    lr_scheduler = get_learning_rate_scheduler(optimizer)

    if args.load is not None:
        timers = get_timers()
        # Extra barrier is added to make sure all ranks report the
        # max time.
        torch.distributed.barrier()
        timers('load-checkpoint').start()
        args.iteration = load_checkpoint(model, optimizer, lr_scheduler)
        torch.distributed.barrier()
        timers('load-checkpoint').stop()
        timers.log(['load-checkpoint'])
    else:
        args.iteration = 0

    # Multiple model chunks or pipeline parallelism are only supported with
    # the local DDP implementation.
    if len(model) > 1 or mpu.get_pipeline_model_parallel_world_size() > 1:
        assert args.DDP_impl == 'local'

    # When starting from scratch, a single (unwrapped, no FP16/DDP wrapper)
    # model that can initialize itself from a pretrained BERT does so now.
    if args.iteration == 0 and len(unwrapped_model) == 1 \
        and hasattr(unwrapped_model[0], 'init_state_dict_from_bert'):
        print_rank_0("Initializing ICT from pretrained BERT model")
        unwrapped_model[0].init_state_dict_from_bert()
        # The model weights were just overwritten, so the optimizer's copies
        # must be refreshed. get_model wraps in Float16Module for both fp16
        # and bf16, so both need the reload (previously only fp16 did).
        # NOTE(review): assumes the bf16 optimizer exposes
        # reload_model_params() like the fp16 one -- confirm in
        # megatron.optimizer.
        if args.fp16 or args.bf16:
            optimizer.reload_model_params()

    return model, optimizer, lr_scheduler


def train_step(forward_step_func, data_iterator,
               model, optimizer, lr_scheduler):
    """Single training step.

    Performs these steps in order:
    * Zeroes gradients.
    * Runs forward and backward over all microbatches using the schedule
      that matches the pipeline configuration.
    * All-reduces gradients: local DDP, plus the shared word embeddings
      across the first and last pipeline stages.
    * Takes an optimizer step and, if it succeeded, advances the learning
      rate scheduler.

    Returns:
        ``(loss_reduced, skipped_iter, grad_norm, num_zeros_in_grad)``.
        ``loss_reduced`` maps each loss key to its average over microbatches
        on the last pipeline stage and is ``{}`` on every other stage.
        ``skipped_iter`` is 1 when ``optimizer.step()`` reported failure and
        0 otherwise.
    """
    args = get_args()
    timers = get_timers()

    # Set grad to zero.
    if args.DDP_impl == 'local' and args.use_contiguous_buffers_in_ddp:
        for partition in model:
            partition.zero_grad_buffer()
    else:
        optimizer.zero_grad()

    # Choose the forward/backward schedule.
    if mpu.get_pipeline_model_parallel_world_size() > 1:
        if args.virtual_pipeline_model_parallel_size is not None:
            forward_backward_func = forward_backward_pipelining_with_interleaving
            assert get_num_microbatches() % args.pipeline_model_parallel_size == 0, \
                'number of microbatches is not divisible by pipeline-parallel ' \
                'size when using interleaved schedule'
        else:
            forward_backward_func = forward_backward_pipelining_without_interleaving
    else:
        forward_backward_func = forward_backward_no_pipelining
    losses_reduced = forward_backward_func(
        forward_step_func, data_iterator, model,
        optimizer, timers, forward_only=False)

    # All-reduce if needed.
    if args.DDP_impl == 'local':
        timers('backward-params-all-reduce').start()
        for model_module in model:
            model_module.allreduce_gradients()
        timers('backward-params-all-reduce').stop()

    # All-reduce word_embeddings' grad across first and last stages to ensure
    # that word_embeddings parameters stay in sync.
    # This should only run for models that support pipelined model parallelism
    # (BERT and GPT-2).
    timers('backward-embedding-all-reduce').start()
    if (mpu.is_pipeline_first_stage(ignore_virtual=True) or
        mpu.is_pipeline_last_stage(ignore_virtual=True)) and \
            mpu.get_pipeline_model_parallel_world_size() > 1:
        # The guard above ensures we are on the first or the last stage.
        if mpu.is_pipeline_first_stage(ignore_virtual=True):
            unwrapped_model = model[0]
        else:
            unwrapped_model = model[-1]
        unwrapped_model = unwrap_model(
            unwrapped_model, (torchDDP, LocalDDP, Float16Module))

        if unwrapped_model.share_word_embeddings:
            word_embeddings_weight = unwrapped_model.word_embeddings_weight()
            # Local DDP keeps gradients in `main_grad` rather than `grad`.
            if args.DDP_impl == 'local':
                grad = word_embeddings_weight.main_grad
            else:
                grad = word_embeddings_weight.grad
            torch.distributed.all_reduce(grad, group=mpu.get_embedding_group())
    timers('backward-embedding-all-reduce').stop()

    # Update parameters.
    timers('optimizer').start()
    update_successful, grad_norm, num_zeros_in_grad = optimizer.step()
    timers('optimizer').stop()

    # Update learning rate only when the optimizer step succeeded; the
    # scheduler is advanced by the number of samples in this global batch.
    if update_successful:
        increment = get_num_microbatches() * \
                    args.micro_batch_size * \
                    args.data_parallel_size
        lr_scheduler.step(increment=increment)
        skipped_iter = 0
    else:
        skipped_iter = 1

    if mpu.is_pipeline_last_stage(ignore_virtual=True):
        # Average loss across microbatches.
        loss_reduced = {}
        for key in losses_reduced[0]:
            losses_reduced_for_key = [x[key] for x in losses_reduced]
            loss_reduced[key] = sum(losses_reduced_for_key) / len(losses_reduced_for_key)
        return loss_reduced, skipped_iter, grad_norm, num_zeros_in_grad
    return {}, skipped_iter, grad_norm, num_zeros_in_grad


def training_log(loss_dict, total_loss_dict, learning_rate, iteration,
                 loss_scale, report_memory_flag, skipped_iter,
                 grad_norm, params_norm, num_zeros_in_grad):
    """Log training information such as losses, timing, ....

    Accumulates statistics into ``total_loss_dict`` in place:
    * losses of non-skipped iterations;
    * counts of advanced, skipped and nan iterations.

    On the last rank, writes scalars to tensorboard every
    ``args.tensorboard_log_interval`` iterations. Every
    ``args.log_interval`` iterations it prints a summary line, resets the
    accumulators and logs timers.

    Arguments:
        loss_dict: losses of the current iteration, keyed by name.
        total_loss_dict: running accumulator, updated in place.
        learning_rate, iteration, loss_scale: current training state.
        report_memory_flag: if true, memory is reported at the next log
            interval that has a positive learning rate.
        skipped_iter: 1 if the optimizer skipped this iteration, else 0.
        grad_norm, params_norm, num_zeros_in_grad: optional metrics;
            ``None`` values are not logged.

    Returns:
        The updated ``report_memory_flag`` (``False`` once memory has been
        reported).
    """
    args = get_args()
    timers = get_timers()
    writer = get_tensorboard_writer()

    # Advanced, skipped, and Nan iterations.
    advanced_iters_key = 'advanced iterations'
    skipped_iters_key = 'skipped iterations'
    nan_iters_key = 'nan iterations'
    # Advanced iterations.
    if not skipped_iter:
        total_loss_dict[advanced_iters_key] = total_loss_dict.get(
            advanced_iters_key, 0) + 1
    else:
        if advanced_iters_key not in total_loss_dict:
            total_loss_dict[advanced_iters_key] = 0
    # Skipped iterations.
    total_loss_dict[skipped_iters_key] = total_loss_dict.get(
        skipped_iters_key, 0) + skipped_iter
    # Update losses and set nan iterations. Losses of skipped iterations are
    # not accumulated; they are only inspected for inf/nan.
    got_nan = False
    for key in loss_dict:
        if not skipped_iter:
            total_loss_dict[key] = total_loss_dict.get(
                key, torch.cuda.FloatTensor([0.0])) + loss_dict[key]
        else:
            value = loss_dict[key].float().sum().item()
            is_nan = value == float('inf') or \
                     value == -float('inf') or \
                     value != value
            got_nan = got_nan or is_nan
    total_loss_dict[nan_iters_key] = total_loss_dict.get(
        nan_iters_key, 0) + int(got_nan)

    # Logging: keep only the timers that exist, in this fixed order.
    timer_names = (
        'forward-compute',
        'forward-recv',
        'forward-send',
        'forward-backward-send-forward-backward-recv',
        'backward-compute',
        'backward-recv',
        'backward-send',
        'backward-send-forward-recv',
        'backward-send-backward-recv',
        'backward-params-all-reduce',
        'backward-embedding-all-reduce',
        'optimizer-copy-to-main-grad',
        'optimizer-unscale-and-check-inf',
        'optimizer-clip-main-grad',
        'optimizer-copy-main-to-model-params',
        'optimizer',
        'batch-generator',
    )
    timers_to_log = [name for name in timer_names if name in timers.timers]

    # Calculate batch size.
    batch_size = args.micro_batch_size * args.data_parallel_size * \
        get_num_microbatches()

    total_iterations = total_loss_dict[advanced_iters_key] + \
                       total_loss_dict[skipped_iters_key]

    # Tensorboard values.
    if writer and (iteration % args.tensorboard_log_interval == 0) and \
       is_last_rank():
        if args.log_learning_rate_to_tensorboard:
            writer.add_scalar('learning-rate', learning_rate, iteration)
            writer.add_scalar('learning-rate vs samples', learning_rate,
                              args.consumed_train_samples)
        if args.log_batch_size_to_tensorboard:
            writer.add_scalar('batch-size', batch_size, iteration)
            writer.add_scalar('batch-size vs samples', batch_size,
                              args.consumed_train_samples)
        for key in loss_dict:
            writer.add_scalar(key, loss_dict[key], iteration)
            writer.add_scalar(key + ' vs samples', loss_dict[key],
                              args.consumed_train_samples)
        if args.log_loss_scale_to_tensorboard:
            writer.add_scalar('loss-scale', loss_scale, iteration)
            writer.add_scalar('loss-scale vs samples', loss_scale,
                              args.consumed_train_samples)
        if grad_norm is not None:
            writer.add_scalar('grad-norm', grad_norm, iteration)
            writer.add_scalar('grad-norm vs samples', grad_norm,
                              args.consumed_train_samples)
        if num_zeros_in_grad is not None:
            writer.add_scalar('num-zeros', num_zeros_in_grad, iteration)
            writer.add_scalar('num-zeros vs samples', num_zeros_in_grad,
                              args.consumed_train_samples)
        if params_norm is not None:
            writer.add_scalar('params-norm', params_norm, iteration)
            writer.add_scalar('params-norm vs samples', params_norm,
                              args.consumed_train_samples)
        if args.log_timers_to_tensorboard:
            timers.write(timers_to_log, writer, iteration,
                         normalizer=total_iterations)

    if iteration % args.log_interval == 0:
        elapsed_time = timers('interval-time').elapsed()
        elapsed_time_per_iteration = elapsed_time / total_iterations
        # Gate on the same rank as every other tensorboard write in this
        # function (previously rank 0, which is inconsistent with the
        # is_last_rank() gate above).
        if writer and is_last_rank():
            if args.log_timers_to_tensorboard:
                writer.add_scalar('iteration-time',
                                  elapsed_time_per_iteration, iteration)
        log_string = ' iteration {:8d}/{:8d} |'.format(
            iteration, args.train_iters)
        log_string += ' consumed samples: {:12d} |'.format(
            args.consumed_train_samples)
        log_string += ' elapsed time per iteration (ms): {:.1f} |'.format(
            elapsed_time_per_iteration * 1000.0)
        log_string += ' learning rate: {:.3E} |'.format(learning_rate)
        log_string += ' global batch size: {:5d} |'.format(batch_size)
        for key in total_loss_dict:
            if key not in [advanced_iters_key, skipped_iters_key,
                           nan_iters_key]:
                # Average over advanced iterations only (skipped ones were
                # not accumulated).
                avg = total_loss_dict[key].item() / \
                      float(max(1, total_loss_dict[advanced_iters_key]))
                if avg > 0.0:
                    log_string += ' {}: {:.6E} |'.format(key, avg)
                total_loss_dict[key] = torch.cuda.FloatTensor([0.0])
        log_string += ' loss scale: {:.1f} |'.format(loss_scale)
        if grad_norm is not None:
            log_string += ' grad norm: {:.3f} |'.format(grad_norm)
        if num_zeros_in_grad is not None:
            log_string += ' num zeros: {:.1f} |'.format(num_zeros_in_grad)
        if params_norm is not None:
            log_string += ' params norm: {:.3f} |'.format(params_norm)
        log_string += ' number of skipped iterations: {:3d} |'.format(
            total_loss_dict[skipped_iters_key])
        log_string += ' number of nan iterations: {:3d} |'.format(
            total_loss_dict[nan_iters_key])
        total_loss_dict[advanced_iters_key] = 0
        total_loss_dict[skipped_iters_key] = 0
        total_loss_dict[nan_iters_key] = 0
        print_rank_last(log_string)
        if report_memory_flag and learning_rate > 0.:
            # Report memory after optimizer state has been initialized.
            report_memory('(after {} iterations)'.format(iteration))
            report_memory_flag = False
        timers.log(timers_to_log, normalizer=args.log_interval)

    return report_memory_flag


def save_checkpoint_and_time(iteration, model, optimizer, lr_scheduler):
    """Save a checkpoint and log how long the save took.

    Both before and after the save, all ranks wait on
    ``torch.distributed.barrier()``, so the logged 'save-checkpoint' time
    covers the slowest rank.
    """
    timers = get_timers()
    # Extra barrier is added to make sure
    # all ranks report the max time.
    torch.distributed.barrier()
    timers('save-checkpoint').start()
    save_checkpoint(iteration, model, optimizer, lr_scheduler)
    torch.distributed.barrier()
    timers('save-checkpoint').stop()
    timers.log(['save-checkpoint'])


def train(forward_step_func, model, optimizer, lr_scheduler,
          train_data_iterator, valid_data_iterator):
    """Run the main training loop.

    Starting at ``args.iteration``, calls ``train_step`` until
    ``args.train_iters`` is reached. After each step it logs progress via
    ``training_log`` and, when the corresponding intervals are hit, runs the
    ADLR autoresume check, validation, and checkpoint saving. The process
    exits early via ``sys.exit()`` when the wall-clock budget
    ``args.exit_duration_in_mins`` is exceeded on any rank, or when the
    iteration is a multiple of ``args.exit_interval``. In both cases a
    checkpoint is saved first, unless one was already saved this iteration.

    Returns:
        The iteration count at which the loop finished.
    """
    args = get_args()
    timers = get_timers()

    # Write args to tensorboard
    write_args_to_tensorboard()

    # Put every module in training mode (enables dropout).
    for module in model:
        module.train()

    # Running loss/iteration bookkeeping shared with training_log.
    total_loss_dict = {}
    step = args.iteration

    timers('interval-time').start()
    print_datetime('before the start of training step')
    report_memory_flag = True

    while step < args.train_iters:
        update_num_microbatches(args.consumed_train_samples)
        loss_dict, skipped_iter, grad_norm, num_zeros_in_grad = train_step(
            forward_step_func, train_data_iterator, model, optimizer,
            lr_scheduler)
        step += 1
        samples_this_step = (mpu.get_data_parallel_world_size()
                             * args.micro_batch_size
                             * get_num_microbatches())
        args.consumed_train_samples += samples_this_step

        # Logging.
        loss_scale = optimizer.get_loss_scale().item()
        params_norm = calc_params_l2_norm(model) if args.log_params_norm \
            else None
        report_memory_flag = training_log(
            loss_dict, total_loss_dict, optimizer.param_groups[0]['lr'],
            step, loss_scale, report_memory_flag, skipped_iter,
            grad_norm, params_norm, num_zeros_in_grad)

        # Autoresume
        if args.adlr_autoresume and \
                step % args.adlr_autoresume_interval == 0:
            check_adlr_autoresume_termination(step, model, optimizer,
                                              lr_scheduler)

        # Periodic validation.
        if args.eval_interval and step % args.eval_interval == 0 and \
                args.do_valid:
            evaluate_and_print_results('iteration {}'.format(step),
                                       forward_step_func,
                                       valid_data_iterator, model,
                                       step, False)

        # Periodic checkpointing.
        saved_checkpoint = False
        if args.save and args.save_interval and \
                step % args.save_interval == 0:
            save_checkpoint_and_time(step, model, optimizer, lr_scheduler)
            saved_checkpoint = True

        # Exit on wall-clock budget. The MAX all-reduce makes every rank see
        # the same decision, so all ranks take the exit path together.
        if args.exit_duration_in_mins:
            train_time = (time.time() - _TRAIN_START_TIME) / 60.0
            done_cuda = torch.cuda.IntTensor(
                [train_time > args.exit_duration_in_mins])
            torch.distributed.all_reduce(
                done_cuda, op=torch.distributed.ReduceOp.MAX)
            if done_cuda.item():
                if not saved_checkpoint:
                    save_checkpoint_and_time(step, model, optimizer,
                                             lr_scheduler)
                print_datetime(
                    'exiting program after {} minutes'.format(train_time))
                sys.exit()

        # Exit at a fixed iteration interval.
        if args.exit_interval and step % args.exit_interval == 0:
            if not saved_checkpoint:
                save_checkpoint_and_time(step, model, optimizer,
                                         lr_scheduler)
            torch.distributed.barrier()
            print_datetime('exiting program at iteration {}'.format(step))
            sys.exit()

    return step
668
669


Mohammad's avatar
Mohammad committed
670
def evaluate(forward_step_func, data_iterator, model, verbose=False):
    """Run forward-only evaluation and return averaged losses.

    Runs ``args.eval_iters`` forward-only passes over ``data_iterator``
    under ``torch.no_grad()`` with every module in ``model`` in eval mode,
    then puts the modules back in train mode. Each pass advances
    ``args.consumed_valid_samples``.

    Losses are accumulated only on the last pipeline stage, so other stages
    return an empty dict.

    Args:
        forward_step_func: forwarded to the pipeline schedule.
        data_iterator: source of evaluation batches.
        model: iterable of modules to evaluate.
        verbose: if True, print progress every ``args.log_interval`` passes.

    Returns:
        Dict mapping each loss name to its summed value divided by
        ``args.eval_iters * get_num_microbatches()``.
    """
    args = get_args()

    # Turn on evaluation mode which disables dropout.
    for model_module in model:
        model_module.eval()

    # Choose the pipeline schedule once. The choice depends only on args and
    # the pipeline-parallel world size, and neither changes inside the loop.
    if mpu.get_pipeline_model_parallel_world_size() > 1:
        if args.virtual_pipeline_model_parallel_size is not None:
            forward_backward_func = \
                forward_backward_pipelining_with_interleaving
        else:
            forward_backward_func = \
                forward_backward_pipelining_without_interleaving
    else:
        forward_backward_func = forward_backward_no_pipelining

    total_loss_dict = {}

    with torch.no_grad():
        iteration = 0
        while iteration < args.eval_iters:
            iteration += 1
            if verbose and iteration % args.log_interval == 0:
                print_rank_0('Evaluating iter {}/{}'.format(iteration,
                                                            args.eval_iters))

            loss_dicts = forward_backward_func(
                forward_step_func, data_iterator, model, optimizer=None,
                timers=None, forward_only=True)

            # Sum losses locally on the last pipeline stage. No
            # cross-process reduction is performed here.
            if mpu.is_pipeline_last_stage(ignore_virtual=True):
                for loss_dict in loss_dicts:
                    for key in loss_dict:
                        total_loss_dict[key] = total_loss_dict.get(
                            key, torch.cuda.FloatTensor([0.0])) + loss_dict[key]

            args.consumed_valid_samples += mpu.get_data_parallel_world_size() \
                                           * args.micro_batch_size \
                                           * get_num_microbatches()

    # Move model back to the train mode.
    for model_module in model:
        model_module.train()

    # Divide by the total number of micro-batches run
    # (eval_iters passes x microbatches per pass).
    for key in total_loss_dict:
        total_loss_dict[key] /= args.eval_iters * get_num_microbatches()

    return total_loss_dict

def evaluate_and_print_results(prefix, forward_step_func,
                               data_iterator, model,
                               iteration, verbose=False):
    """Evaluate the model, then print and log the resulting losses.

    Calls ``evaluate`` and builds a one-line summary containing each loss
    value and its perplexity, ``exp(min(20, loss))``. The line is printed,
    framed by dashes, via ``print_rank_last``. When a TensorBoard writer is
    available and this is the last rank, each loss is also logged against
    ``iteration`` and against ``args.consumed_train_samples``. Perplexity is
    logged the same way when ``args.log_validation_ppl_to_tensorboard`` is
    set.

    Args:
        prefix: label embedded in the summary line.
        forward_step_func, data_iterator, model, verbose: forwarded to
            ``evaluate``.
        iteration: step value for the per-iteration TensorBoard scalars.
    """
    args = get_args()
    writer = get_tensorboard_writer()

    total_loss_dict = evaluate(forward_step_func, data_iterator, model, verbose)

    string = ' validation loss at {} | '.format(prefix)
    for key in total_loss_dict:
        # Read the scalar once and reuse it. `.item()` copies the value to
        # the host, which forces a synchronization for device tensors.
        loss_value = total_loss_dict[key].item()
        # Exponent capped at 20, so the reported perplexity is at most exp(20).
        ppl = math.exp(min(20, loss_value))
        string += '{} value: {:.6E} | '.format(key, loss_value)
        string += '{} PPL: {:.6E} | '.format(key, ppl)
        if writer and is_last_rank():
            writer.add_scalar('{} validation'.format(key),
                              loss_value,
                              iteration)
            writer.add_scalar('{} validation vs samples'.format(key),
                              loss_value,
                              args.consumed_train_samples)
            if args.log_validation_ppl_to_tensorboard:
                writer.add_scalar('{} validation ppl'.format(key), ppl,
                                  iteration)
                writer.add_scalar('{} validation ppl vs samples'.format(key),
                                  ppl, args.consumed_train_samples)

    length = len(string) + 1
    print_rank_last('-' * length)
    print_rank_last(string)
    print_rank_last('-' * length)
748
749


Vijay Korthikanti's avatar
Vijay Korthikanti committed
750
def cyclic_iter(iter):
    """Yield the items of ``iter`` forever by re-iterating it after each pass.

    Unlike ``itertools.cycle``, items are not cached: every pass calls
    ``iter`` again, so a re-iterable source (e.g. a data loader) produces
    fresh items each pass.

    If a whole pass yields nothing, the generator stops instead of looping
    forever. This happens when the source is empty, or when it is a one-shot
    iterator that has already been exhausted.
    """
    while True:
        produced_any = False
        for x in iter:
            produced_any = True
            yield x
        if not produced_any:
            # An empty pass can never be followed by a non-empty one for the
            # inputs described above, so spinning here would hang the caller.
            return

755
756
757
def build_train_valid_test_data_iterators(
        build_train_valid_test_datasets_provider):
    """Build the train, validation, and test data iterators.

    Only ranks whose tensor-model-parallel rank is 0 build datasets and data
    loaders. The do_train/do_valid/do_test flags they compute are broadcast
    across the tensor-model-parallel group and stored on ``args``.

    Args:
        build_train_valid_test_datasets_provider: callable that takes a list
            ``[train, valid, test]`` of target sample counts and returns the
            three datasets.

    Returns:
        ``(train_data_iterator, valid_data_iterator, test_data_iterator)``.
        An entry is None when the corresponding data loader is None, which
        is always the case on ranks other than tensor-model-parallel rank 0.
    """
    args = get_args()

    (train_dataloader, valid_dataloader, test_dataloader) = (None, None, None)

    print_rank_0('> building train, validation, and test datasets ...')

    # Validate the iterator type before any data is built or communicated.
    dl_type = args.dataloader_type
    assert dl_type in ['single', 'cyclic']

    def _num_periodic_evals(num_iters):
        """Number of every-`eval_interval` validation runs in `num_iters`."""
        # train() skips periodic validation when eval_interval is falsy, so
        # count zero runs rather than dividing by it.
        if not args.eval_interval:
            return 0
        return num_iters // args.eval_interval

    # Backward compatibility, assume fixed batch size.
    if args.iteration > 0 and args.consumed_train_samples == 0:
        assert args.train_samples is None, \
            'only backward compatiblity support for iteration-based training'
        args.consumed_train_samples = args.iteration * args.global_batch_size
    if args.iteration > 0 and args.consumed_valid_samples == 0:
        assert args.train_samples is None, \
            'only backward compatiblity support for iteration-based training'
        args.consumed_valid_samples = \
            _num_periodic_evals(args.iteration) * \
            args.eval_iters * args.global_batch_size

    # Data loader only on rank 0 of each tensor-model-parallel group.
    if mpu.get_tensor_model_parallel_rank() == 0:

        # Number of train/valid/test samples.
        if args.train_samples:
            train_samples = args.train_samples
        else:
            train_samples = args.train_iters * args.global_batch_size
        # One validation run per eval_interval, plus one extra run
        # (presumably the validation after training finishes).
        eval_iters = (_num_periodic_evals(args.train_iters) + 1) * \
                     args.eval_iters
        test_iters = args.eval_iters
        train_val_test_num_samples = [train_samples,
                                      eval_iters * args.global_batch_size,
                                      test_iters * args.global_batch_size]
        print_rank_0(' > datasets target sizes (minimum size):')
        print_rank_0('    train:      {}'.format(train_val_test_num_samples[0]))
        print_rank_0('    validation: {}'.format(train_val_test_num_samples[1]))
        print_rank_0('    test:       {}'.format(train_val_test_num_samples[2]))

        # Build the datasets.
        train_ds, valid_ds, test_ds = build_train_valid_test_datasets_provider(
            train_val_test_num_samples)

        # Build dataloaders, resuming train/valid from the consumed counts.
        train_dataloader = build_pretraining_data_loader(
            train_ds, args.consumed_train_samples)
        valid_dataloader = build_pretraining_data_loader(
            valid_ds, args.consumed_valid_samples)
        test_dataloader = build_pretraining_data_loader(test_ds, 0)

        # Flags to know if we need to do training/validation/testing.
        do_train = train_dataloader is not None and args.train_iters > 0
        do_valid = valid_dataloader is not None and args.eval_iters > 0
        do_test = test_dataloader is not None and args.eval_iters > 0
        flags = torch.cuda.LongTensor(
            [int(do_train), int(do_valid), int(do_test)])
    else:
        flags = torch.cuda.LongTensor([0, 0, 0])

    # Share the flags from the group's source rank with the rest of the
    # tensor-model-parallel group.
    torch.distributed.broadcast(flags,
                                mpu.get_tensor_model_parallel_src_rank(),
                                group=mpu.get_tensor_model_parallel_group())
    args.do_train = flags[0].item()
    args.do_valid = flags[1].item()
    args.do_test = flags[2].item()

    def _make_iterator(dataloader):
        """Wrap `dataloader` according to dl_type; None stays None."""
        if dataloader is None:
            return None
        if dl_type == 'single':
            return iter(dataloader)
        return iter(cyclic_iter(dataloader))

    # Build iterators.
    train_data_iterator = _make_iterator(train_dataloader)
    valid_data_iterator = _make_iterator(valid_dataloader)
    test_data_iterator = _make_iterator(test_dataloader)

    return train_data_iterator, valid_data_iterator, test_data_iterator