# coding=utf-8
# Copyright (c) 2020, NVIDIA CORPORATION.  All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

Mohammad's avatar
Mohammad committed
16
"""Pretrain utilities."""
17
18
19

from datetime import datetime
import math
Mohammad's avatar
Mohammad committed
20
import sys
21
22
23
24
import time
# The earliest we can measure the start time.
_TRAIN_START_TIME = time.time()

25
26
27
import torch
from torch.nn.parallel.distributed import DistributedDataParallel as torchDDP

Neel Kant's avatar
Neel Kant committed
28
from megatron import get_args
Mohammad's avatar
Mohammad committed
29
30
from megatron import get_timers
from megatron import get_tensorboard_writer
31
from megatron import get_current_global_batch_size
mohammad's avatar
mohammad committed
32
from megatron import get_num_microbatches
mohammad's avatar
mohammad committed
33
from megatron import is_last_rank
mohammad's avatar
mohammad committed
34
from megatron import update_num_microbatches
35
from megatron import mpu
Neel Kant's avatar
Neel Kant committed
36
from megatron import print_rank_0
37
from megatron import print_rank_last
Mohammad's avatar
Mohammad committed
38
39
from megatron.checkpointing import load_checkpoint
from megatron.checkpointing import save_checkpoint
40
from megatron.model import FP16Module
mohammad's avatar
mohammad committed
41
from megatron.optimizer import get_megatron_optimizer
mohammad's avatar
mohammad committed
42

Mohammad's avatar
Mohammad committed
43
from megatron.initialize import initialize_megatron
44
from megatron.initialize import write_args_to_tensorboard
45
46
from megatron.learning_rates import AnnealingLR
from megatron.model import DistributedDataParallel as LocalDDP
Neel Kant's avatar
Neel Kant committed
47
from megatron.model.realm_model import ICTBertModel
48
from megatron.utils import check_adlr_autoresume_termination
49
from megatron.utils import unwrap_model
Vijay Korthikanti's avatar
Vijay Korthikanti committed
50
from megatron.data.data_samplers import build_pretraining_data_loader
mohammad's avatar
mohammad committed
51
from megatron.utils import calc_params_l2_norm
52
from megatron.schedules import forward_backward_no_pipelining
53
from megatron.schedules import forward_backward_pipelining_without_interleaving
54
from megatron.schedules import forward_backward_pipelining_with_interleaving
Mostofa Patwary's avatar
Mostofa Patwary committed
55
from megatron.utils import report_memory
56
57


58
59
60
61
62
63
64
def print_datetime(string):
    """Print ``[string] datetime: <YYYY-mm-dd HH:MM:SS>`` on rank 0.

    Note that this call will sync across all ranks: every rank waits on a
    global barrier before the timestamp is taken.
    """
    torch.distributed.barrier()
    stamp = datetime.now().strftime('%Y-%m-%d %H:%M:%S')
    message = '[' + string + '] datetime: ' + stamp + ' '
    print_rank_0(message)


65
66
67
68
def pretrain(train_valid_test_dataset_provider,
             model_provider,
             forward_step_func,
             extra_args_provider=None,
             args_defaults=None):
    """Main training program.

    This function runs the following steps, in order:
        1) initialize Megatron.
        2) setup model, optimizer and lr schedule using the model_provider.
        3) call train_valid_test_dataset_provider to get train/val/test
           datasets.
        4) train the model using the forward_step_func.

    Arguments:
        train_valid_test_dataset_provider: a function that takes the size of
            train/valid/test dataset and returns `train, valid, test` datasets.
        model_provider: a function that returns a vanilla version of the
            model. By vanilla we mean a simple model on cpu with no fp16 or ddp.
        forward_step_func: a function that takes a `data iterator` and `model`,
            and returns a `loss` scalar with a dictionary with key:values being
            the info we would like to monitor during training, for example
            `lm-loss: value`. We also require that this function add
            `batch generator` to the timers class.
        extra_args_provider: a function that takes a parser and adds arguments
            to it. It is used for programs to add their own arguments.
        args_defaults: a dictionary from argument-name to argument-value. It
            is used to set already-parsed arguments. ``None`` (the default)
            is treated as an empty dictionary.
    """
    # A fresh dict per call instead of a shared mutable default argument.
    if args_defaults is None:
        args_defaults = {}

    # Initialize and get arguments, timers, and Tensorboard writer.
    initialize_megatron(extra_args_provider=extra_args_provider,
                        args_defaults=args_defaults)

    # Adjust the startup time so it reflects the largest value: take the
    # earliest start time across all ranks (MIN all-reduce). This will be
    # closer to what the scheduler will see (outside of image ... launches).
    global _TRAIN_START_TIME
    start_time_tensor = torch.cuda.FloatTensor([_TRAIN_START_TIME])
    torch.distributed.all_reduce(start_time_tensor,
                                 op=torch.distributed.ReduceOp.MIN)
    _TRAIN_START_TIME = start_time_tensor.item()
    print_rank_0('time to initialize megatron (seconds): {:.3f}'.format(
        time.time() - _TRAIN_START_TIME))
    print_datetime('after megatron is initialized')

    args = get_args()
    timers = get_timers()

    # Model, optimizer, and learning rate.
    timers('model-and-optimizer-setup').start()
    model, optimizer, lr_scheduler = setup_model_and_optimizer(model_provider)
    timers('model-and-optimizer-setup').stop()
    print_datetime('after model, optimizer, and learning rate '
                   'scheduler are built')

    # Data stuff.
    timers('train/valid/test-data-iterators-setup').start()
    if args.virtual_pipeline_model_parallel_size is not None:
        # One set of (train, valid, test) iterators per model chunk.
        all_data_iterators = [
            build_train_valid_test_data_iterators(train_valid_test_dataset_provider)
            for _ in range(len(model))
        ]
        train_data_iterator = [data_iterators[0] for data_iterators in all_data_iterators]
        valid_data_iterator = [data_iterators[1] for data_iterators in all_data_iterators]
        test_data_iterator = [data_iterators[2] for data_iterators in all_data_iterators]
    else:
        train_data_iterator, valid_data_iterator, test_data_iterator \
            = build_train_valid_test_data_iterators(
                train_valid_test_dataset_provider)
    timers('train/valid/test-data-iterators-setup').stop()
    print_datetime('after dataloaders are built')

    # Print setup timing.
    print_rank_0('done with setup ...')
    timers.log(['model-and-optimizer-setup', 'train/valid/test-data-iterators-setup'])
    print_rank_0('training ...')

    iteration = 0
    if args.do_train and args.train_iters > 0:
        iteration = train(forward_step_func,
                          model, optimizer, lr_scheduler,
                          train_data_iterator, valid_data_iterator)
    print_datetime('after training is done')

    if args.do_valid:
        prefix = 'the end of training for val data'
        evaluate_and_print_results(prefix, forward_step_func,
                                   valid_data_iterator, model,
                                   iteration, False)

    if args.save and iteration != 0:
        save_checkpoint(iteration, model, optimizer, lr_scheduler)

    if args.do_test:
        # Run on test data.
        prefix = 'the end of training for test data'
        evaluate_and_print_results(prefix, forward_step_func,
                                   test_data_iterator, model,
                                   0, True)
164

165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
def update_train_iters(args):
    """Derive ``args.train_iters`` for sample-based training.

    If ``args.train_iters`` is already set (iteration-based training) this is
    a no-op. Otherwise the iteration count is computed from
    ``args.train_samples``:

    * without batch-size rampup: ``train_samples // global_batch_size``;
    * with rampup: the iterations needed to step through the rampup schedule
      (``args.rampup_batch_size[2]`` samples, one global batch per step),
      plus the remaining samples divided by ``global_batch_size``. Any
      partial last batch is dropped.
    """
    # Iteration-based training: nothing to derive.
    if args.train_iters:
        return

    if args.rampup_batch_size is not None:
        # Sample based training with rampup batch size.
        rampup_samples = int(args.rampup_batch_size[2])
        num_iters = 0
        samples_seen = 0
        # Rampup phase: walk the microbatch schedule one step at a time.
        while samples_seen <= rampup_samples:
            update_num_microbatches(samples_seen, consistency_check=False)
            samples_seen += get_current_global_batch_size()
            num_iters += 1
        # Reset the microbatch schedule back to zero consumed samples.
        update_num_microbatches(0, consistency_check=False)
        # Constant phase; note that we throw away any partial last batch.
        num_iters += (args.train_samples - samples_seen) // \
            args.global_batch_size
        args.train_iters = num_iters
    else:
        # Constant batch size with sample-based training.
        args.train_iters = args.train_samples // args.global_batch_size

    print_rank_0('setting training iterations to {}'.format(args.train_iters))

194

Mohammad's avatar
Mohammad committed
195
def get_model(model_provider_func):
    """Build the model and wrap it for training.

    The model returned by ``model_provider_func`` (a module or a list of
    module chunks) is normalized to a list. Tensor-model-parallel parameter
    attributes are defaulted, and the parameter count is printed on
    data-parallel rank 0. Each chunk is then moved to the current CUDA
    device, wrapped in ``FP16Module`` when ``args.fp16`` is set, and finally
    wrapped in torch or local DDP according to ``args.DDP_impl``.

    Returns:
        list of wrapped model chunks.

    Raises:
        NotImplementedError: if ``args.DDP_impl`` is neither 'torch' nor
            'local'.
    """
    args = get_args()

    # Build model on cpu; always work with a list of chunks.
    built = model_provider_func()
    chunks = built if isinstance(built, list) else [built]

    # Only parameters that are already tensor model parallel have these
    # attributes set for them. Make sure the default attributes are set for
    # all params so the optimizer can use them.
    for chunk in chunks:
        for param in chunk.parameters():
            mpu.set_defaults_if_not_set_tensor_model_parallel_attributes(param)

    # Print number of parameters.
    if mpu.get_data_parallel_rank() == 0:
        num_params = sum(p.nelement()
                         for chunk in chunks
                         for p in chunk.parameters())
        print(' > number of parameters on (tensor, pipeline) '
              'model parallel rank ({}, {}): {}'.format(
                  mpu.get_tensor_model_parallel_rank(),
                  mpu.get_pipeline_model_parallel_rank(),
                  num_params), flush=True)

    # GPU allocation.
    device = torch.cuda.current_device()
    for chunk in chunks:
        chunk.cuda(device)

    # Fp16 conversion.
    if args.fp16:
        chunks = [FP16Module(chunk) for chunk in chunks]

    if args.DDP_impl == 'torch':
        return [torchDDP(chunk, device_ids=[device], output_device=device,
                         process_group=mpu.get_data_parallel_group())
                for chunk in chunks]
    if args.DDP_impl == 'local':
        return [LocalDDP(chunk) for chunk in chunks]

    raise NotImplementedError('Unknown DDP implementation specified: {}. '
                              'Exiting.'.format(args.DDP_impl))
241
242


Mohammad's avatar
Mohammad committed
243
def get_learning_rate_scheduler(optimizer):
    """Build the ``AnnealingLR`` learning-rate scheduler for ``optimizer``.

    Warmup and decay lengths are expressed in samples:

    * iteration-based training (``args.train_iters``): the decay length
      defaults to ``train_iters``. Iteration counts are multiplied by
      ``args.global_batch_size``.
    * sample-based training (``args.train_samples``): ``args.train_iters`` is
      derived via ``update_train_iters``, and the decay length defaults to
      ``train_samples``.

    In both cases ``args.lr_warmup_fraction``, when set, overrides the
    explicit warmup length with a fraction of the decay length.

    Raises:
        Exception: if neither train-iters nor train-samples is provided.
    """
    args = get_args()

    if args.train_iters:
        # Iteration-based training.
        if args.lr_decay_iters is None:
            args.lr_decay_iters = args.train_iters
        decay_steps = args.lr_decay_iters * args.global_batch_size
        if args.lr_warmup_fraction is None:
            warmup_steps = args.lr_warmup_iters * args.global_batch_size
        else:
            warmup_steps = args.lr_warmup_fraction * decay_steps
    elif args.train_samples:
        # Sample-based training. We need to set training iters for later
        # use. Technically the training samples should be adjusted too (the
        # last batch may be incomplete) but that is left as is for now.
        update_train_iters(args)
        if args.lr_decay_samples is None:
            args.lr_decay_samples = args.train_samples
        decay_steps = args.lr_decay_samples
        if args.lr_warmup_fraction is None:
            warmup_steps = args.lr_warmup_samples
        else:
            warmup_steps = args.lr_warmup_fraction * decay_steps
    else:
        raise Exception(
            'either train-iters or train-samples should be provided.')

    return AnnealingLR(
        optimizer,
        max_lr=args.lr,
        min_lr=args.min_lr,
        warmup_steps=warmup_steps,
        decay_steps=decay_steps,
        decay_style=args.lr_decay_style,
        use_checkpoint_lr_scheduler=args.use_checkpoint_lr_scheduler,
        override_lr_scheduler=args.override_lr_scheduler)


Mohammad's avatar
Mohammad committed
286
def setup_model_and_optimizer(model_provider_func):
    """Setup model, optimizer and learning-rate scheduler.

    Builds the wrapped model via ``get_model`` and creates the optimizer over
    the fully unwrapped modules. It then builds the LR scheduler and, if
    ``args.load`` is set, loads a checkpoint into ``args.iteration``;
    otherwise ``args.iteration`` is set to 0.

    Returns:
        ``(model, optimizer, lr_scheduler)``, where ``model`` is the list of
        wrapped (DDP and, with fp16, FP16Module) model chunks produced by
        ``get_model``.
    """
    args = get_args()

    model = get_model(model_provider_func)

    # The optimizer works on the raw modules, without DDP/FP16 wrappers.
    unwrapped_model = unwrap_model(model,
                                   (torchDDP, LocalDDP, FP16Module))
    optimizer = get_megatron_optimizer(unwrapped_model)

    lr_scheduler = get_learning_rate_scheduler(optimizer)

    if args.load is not None:
        timers = get_timers()
        # Extra barrier is added to make sure all ranks report the
        # max time.
        torch.distributed.barrier()
        timers('load-checkpoint').start()
        args.iteration = load_checkpoint(model, optimizer, lr_scheduler)
        torch.distributed.barrier()
        timers('load-checkpoint').stop()
        timers.log(['load-checkpoint'])
    else:
        args.iteration = 0

    # We only support local DDP with multiple micro-batches, with more than
    # one model chunk, or with pipeline model parallelism. A single chunk
    # without pipelining may use either DDP implementation. The previous
    # check (`len(model) == 1`) rejected torch DDP in exactly that case.
    if get_num_microbatches() > 1:
        assert args.DDP_impl == 'local'
    if len(model) > 1:
        assert args.DDP_impl == 'local'
    if mpu.get_pipeline_model_parallel_world_size() > 1:
        assert args.DDP_impl == 'local'

    # Initialize ICT from BERT on the modules without FP16/DDP wrappers.
    # `model` itself is NOT overwritten: the caller must get the wrapped
    # chunks back so the DDP wrappers stay in place during training.
    for module in unwrapped_model:
        if args.iteration == 0 and hasattr(module,
                                           'init_state_dict_from_bert'):
            print("Initializing ICT from pretrained BERT model", flush=True)
            module.init_state_dict_from_bert()
            if args.fp16:
                optimizer.reload_model_params()

    return model, optimizer, lr_scheduler


332
333
334
335
336
337
338
def train_step(forward_step_func, data_iterator,
               model, optimizer, lr_scheduler):
    """Single training step.

    Runs forward/backward with the schedule matching the pipeline
    configuration. It then all-reduces gradients (local DDP only), waits on
    the pipeline barrier, and all-reduces shared word-embedding gradients
    between the first and last pipeline stages. Finally it steps the
    optimizer and advances the LR scheduler when the step succeeded.

    Returns:
        ``(loss_reduced, skipped_iter, grad_norm)``. ``loss_reduced`` maps
        each loss key to its mean over microbatches on the last pipeline
        stage and is ``{}`` on other stages. ``skipped_iter`` is 0 if the
        optimizer update succeeded, else 1. ``grad_norm`` is whatever
        ``optimizer.step()`` returned.
    """
    args = get_args()
    timers = get_timers()

    # Set grad to zero.
    optimizer.zero_grad()

    # Pick the forward/backward schedule.
    if mpu.get_pipeline_model_parallel_world_size() <= 1:
        schedule = forward_backward_no_pipelining
    elif args.virtual_pipeline_model_parallel_size is None:
        schedule = forward_backward_pipelining_without_interleaving
    else:
        schedule = forward_backward_pipelining_with_interleaving
    losses_reduced = schedule(
        forward_step_func, data_iterator, model,
        optimizer, timers, forward_only=False)

    # All-reduce if needed.
    if args.DDP_impl == 'local':
        timers('backward-params-all-reduce').start()
        for chunk in model:
            chunk.allreduce_params(reduce_after=False,
                                   fp32_allreduce=args.fp32_allreduce)
        timers('backward-params-all-reduce').stop()

    # Barrier to measure backward stall.
    timers('backward-pipeline-stall').start()
    torch.distributed.barrier(group=mpu.get_pipeline_model_parallel_group())
    timers('backward-pipeline-stall').stop()

    # All-reduce word_embeddings' grad across first and last stages to ensure
    # that word_embeddings parameters stay in sync.
    # This should only run for models that support pipelined model parallelism
    # (BERT and GPT-2).
    timers('backward-embedding-all-reduce').start()
    on_first_stage = mpu.is_pipeline_first_stage(ignore_virtual=True)
    on_last_stage = mpu.is_pipeline_last_stage(ignore_virtual=True)
    if (on_first_stage or on_last_stage) and \
            mpu.get_pipeline_model_parallel_world_size() > 1:
        # First stage owns the first chunk, last stage the last chunk.
        stage_module = model[0] if on_first_stage else model[-1]
        stage_module = unwrap_model(
            stage_module, (torchDDP, LocalDDP, FP16Module))
        if stage_module.share_word_embeddings:
            embedding_weight = stage_module.word_embeddings_weight()
            torch.distributed.all_reduce(embedding_weight.grad,
                                         group=mpu.get_embedding_group())
    timers('backward-embedding-all-reduce').stop()

    # Update parameters.
    timers('optimizer').start()
    update_successful, grad_norm = optimizer.step()
    timers('optimizer').stop()

    # Update learning rate (only when the optimizer step was applied).
    if update_successful:
        lr_scheduler.step(increment=get_num_microbatches() *
                          args.micro_batch_size *
                          args.data_parallel_size)
        skipped_iter = 0
    else:
        skipped_iter = 1

    if not mpu.is_pipeline_last_stage(ignore_virtual=True):
        return {}, skipped_iter, grad_norm

    # Average loss across microbatches.
    loss_reduced = {
        key: sum(mb_losses[key] for mb_losses in losses_reduced) /
        len(losses_reduced)
        for key in losses_reduced[0]
    }
    return loss_reduced, skipped_iter, grad_norm
409
410


Mohammad Shoeybi's avatar
Mohammad Shoeybi committed
411
def training_log(loss_dict, total_loss_dict, learning_rate, iteration,
                 loss_scale, report_memory_flag, skipped_iter,
                 grad_norm, params_norm):
    """Log training information such as losses, timing, ....

    The function does three things:

    * Updates the running accumulators in ``total_loss_dict``: counts of
      advanced, skipped and nan iterations, plus summed losses of advanced
      iterations.
    * Writes TensorBoard scalars every ``args.tensorboard_log_interval``
      iterations.
    * Every ``args.log_interval`` iterations, prints a summary line, resets
      the accumulators and logs the timers.

    Arguments:
        loss_dict: this iteration's losses (key -> tensor).
        total_loss_dict: running accumulator; mutated in place.
        learning_rate: current learning rate.
        iteration: current iteration number.
        loss_scale: current loss scale.
        report_memory_flag: True while memory usage has not been reported.
        skipped_iter: 1 if this iteration's update was skipped, else 0.
        grad_norm: gradient norm, or None.
        params_norm: parameter norm, or None.

    Returns:
        The updated ``report_memory_flag``.
    """
    args = get_args()
    timers = get_timers()
    writer = get_tensorboard_writer()

    # Advanced, skipped, and Nan iterations.
    advanced_iters_key = 'advanced iterations'
    skipped_iters_key = 'skipped iterations'
    nan_iters_key = 'nan iterations'
    # Advanced iterations.
    if not skipped_iter:
        total_loss_dict[advanced_iters_key] = total_loss_dict.get(
            advanced_iters_key, 0) + 1
    else:
        if advanced_iters_key not in total_loss_dict:
            total_loss_dict[advanced_iters_key] = 0
    # Skipped iterations.
    total_loss_dict[skipped_iters_key] = total_loss_dict.get(
        skipped_iters_key, 0) + skipped_iter
    # Update losses (advanced iterations only); on a skipped iteration,
    # instead check whether any loss is inf/-inf/nan.
    got_nan = False
    for key in loss_dict:
        if not skipped_iter:
            total_loss_dict[key] = total_loss_dict.get(
                key, torch.cuda.FloatTensor([0.0])) + loss_dict[key]
        else:
            value = loss_dict[key].float().sum().item()
            is_nan = value == float('inf') or \
                     value == -float('inf') or \
                     value != value
            got_nan = got_nan or is_nan
    total_loss_dict[nan_iters_key] = total_loss_dict.get(
        nan_iters_key, 0) + int(got_nan)

    # Logging: keep only the timers that actually exist, in this order.
    timers_to_log = [name for name in (
        'forward-compute',
        'forward-pipeline-stall',
        'forward-recv',
        'forward-send',
        'forward-backward-send-forward-backward-recv',
        'backward-compute',
        'backward-pipeline-stall',
        'backward-recv',
        'backward-send',
        'backward-send-forward-recv',
        'backward-send-backward-recv',
        'backward-params-all-reduce',
        'backward-embedding-all-reduce',
        'optimizer-copy-to-main-grad',
        'optimizer-unscale-and-check-inf',
        'optimizer-clip-main-grad',
        'optimizer-copy-main-to-model-params',
        'optimizer',
        'batch-generator',
    ) if name in timers.timers]

    # Calculate batch size.
    batch_size = args.micro_batch_size * args.data_parallel_size * \
        get_num_microbatches()

    total_iterations = total_loss_dict[advanced_iters_key] + \
                       total_loss_dict[skipped_iters_key]

    # Tensorboard values.
    if writer and (iteration % args.tensorboard_log_interval == 0) and \
       is_last_rank():
        if args.log_learning_rate_to_tensorboard:
            writer.add_scalar('learning-rate', learning_rate, iteration)
            writer.add_scalar('learning-rate vs samples', learning_rate,
                              args.consumed_train_samples)
        if args.log_batch_size_to_tensorboard:
            writer.add_scalar('batch-size', batch_size, iteration)
            writer.add_scalar('batch-size vs samples', batch_size,
                              args.consumed_train_samples)
        for key in loss_dict:
            writer.add_scalar(key, loss_dict[key], iteration)
            writer.add_scalar(key + ' vs samples', loss_dict[key],
                              args.consumed_train_samples)
        if args.log_loss_scale_to_tensorboard:
            writer.add_scalar('loss-scale', loss_scale, iteration)
            writer.add_scalar('loss-scale vs samples', loss_scale,
                              args.consumed_train_samples)
        if grad_norm is not None:
            writer.add_scalar('grad-norm', grad_norm, iteration)
            writer.add_scalar('grad-norm vs samples', grad_norm,
                              args.consumed_train_samples)
        if params_norm is not None:
            writer.add_scalar('params-norm', params_norm, iteration)
            writer.add_scalar('params-norm vs samples', params_norm,
                              args.consumed_train_samples)
        if args.log_timers_to_tensorboard:
            timers.write(timers_to_log, writer, iteration,
                         normalizer=total_iterations)

    if iteration % args.log_interval == 0:
        elapsed_time = timers('interval-time').elapsed()
        elapsed_time_per_iteration = elapsed_time / total_iterations
        # Use the same rank condition as every other TensorBoard write in
        # this function (previously `torch.distributed.get_rank() == 0`).
        if writer and is_last_rank():
            if args.log_timers_to_tensorboard:
                writer.add_scalar('iteration-time',
                                  elapsed_time_per_iteration, iteration)
        log_string = ' iteration {:8d}/{:8d} |'.format(
            iteration, args.train_iters)
        log_string += ' consumed samples: {:12d} |'.format(
            args.consumed_train_samples)
        log_string += ' elapsed time per iteration (ms): {:.1f} |'.format(
            elapsed_time_per_iteration * 1000.0)
        log_string += ' learning rate: {:.3E} |'.format(learning_rate)
        log_string += ' global batch size: {:5d} |'.format(batch_size)
        for key in total_loss_dict:
            if key not in [advanced_iters_key, skipped_iters_key,
                           nan_iters_key]:
                # Mean loss over advanced iterations in this interval.
                avg = total_loss_dict[key].item() / \
                      float(max(1, total_loss_dict[advanced_iters_key]))
                if avg > 0.0:
                    log_string += ' {}: {:.6E} |'.format(key, avg)
                total_loss_dict[key] = torch.cuda.FloatTensor([0.0])
        log_string += ' loss scale: {:.1f} |'.format(loss_scale)
        if grad_norm is not None:
            log_string += ' grad norm: {:.3f} |'.format(grad_norm)
        if params_norm is not None:
            log_string += ' params norm: {:.3f} |'.format(params_norm)
        log_string += ' number of skipped iterations: {:3d} |'.format(
            total_loss_dict[skipped_iters_key])
        log_string += ' number of nan iterations: {:3d} |'.format(
            total_loss_dict[nan_iters_key])
        # Reset the interval counters.
        total_loss_dict[advanced_iters_key] = 0
        total_loss_dict[skipped_iters_key] = 0
        total_loss_dict[nan_iters_key] = 0
        print_rank_last(log_string)
        if report_memory_flag and learning_rate > 0.:
            # Report memory after optimizer state has been initialized.
            report_memory('(after {} iterations)'.format(iteration))
            report_memory_flag = False
        timers.log(timers_to_log, normalizer=args.log_interval)

    return report_memory_flag


557
558
559
560
561
def save_checkpoint_and_time(iteration, model, optimizer, lr_scheduler):
    """Save a checkpoint and log the time it took under 'save-checkpoint'.

    A barrier before starting and another before stopping the timer make
    all ranks report the max time.
    """
    timers = get_timers()
    timer_name = 'save-checkpoint'

    torch.distributed.barrier()
    timers(timer_name).start()

    save_checkpoint(iteration, model, optimizer, lr_scheduler)

    torch.distributed.barrier()
    timers(timer_name).stop()
    timers.log([timer_name])
567
568


569
def train(forward_step_func, model, optimizer, lr_scheduler,
          train_data_iterator, valid_data_iterator):
    """Train the model function.

    Runs training steps from ``args.iteration`` until ``args.train_iters``.
    After each step it logs, optionally auto-resumes, evaluates, checkpoints
    and exits at the configured intervals.

    Args:
        forward_step_func: forward-step callable handed to ``train_step`` and
            to evaluation.
        model: list of model modules (each switched to train mode here).
        optimizer: optimizer used by ``train_step``; also queried for the
            loss scale and learning rate when logging.
        lr_scheduler: learning-rate scheduler used by ``train_step`` and saved
            with checkpoints.
        train_data_iterator: iterator feeding training batches.
        valid_data_iterator: iterator feeding validation batches.

    Returns:
        The iteration reached when the loop ends. The exit-duration and
        exit-interval conditions instead terminate the process via
        ``sys.exit()``.
    """
    args = get_args()
    timers = get_timers()

    # Write args to tensorboard
    write_args_to_tensorboard()

    # Turn on training mode which enables dropout.
    for model_module in model:
        model_module.train()

    # Tracking loss.
    total_loss_dict = {}

    # Iterations.
    iteration = args.iteration

    timers('interval-time').start()
    print_datetime('before the start of training step')
    report_memory_flag = True
    while iteration < args.train_iters:
        update_num_microbatches(args.consumed_train_samples)
        loss_dict, skipped_iter, grad_norm = train_step(forward_step_func,
                                                        train_data_iterator,
                                                        model,
                                                        optimizer,
                                                        lr_scheduler)
        iteration += 1
        args.consumed_train_samples += mpu.get_data_parallel_world_size() * \
                                       args.micro_batch_size * \
                                       get_num_microbatches()

        # Logging.
        loss_scale = optimizer.get_loss_scale().item()
        params_norm = None
        if args.log_params_norm:
            params_norm = calc_params_l2_norm(model)
        report_memory_flag = training_log(loss_dict, total_loss_dict,
                                          optimizer.param_groups[0]['lr'],
                                          iteration, loss_scale,
                                          report_memory_flag, skipped_iter,
                                          grad_norm, params_norm)

        # Autoresume
        if args.adlr_autoresume and \
           (iteration % args.adlr_autoresume_interval == 0):
            check_adlr_autoresume_termination(iteration, model, optimizer,
                                              lr_scheduler)

        # Evaluation
        if args.eval_interval and iteration % args.eval_interval == 0 and \
           args.do_valid:
            prefix = 'iteration {}'.format(iteration)
            evaluate_and_print_results(prefix, forward_step_func,
                                       valid_data_iterator, model,
                                       iteration, False)

        # Checkpointing
        saved_checkpoint = False
        if args.save and args.save_interval and \
           iteration % args.save_interval == 0:
            save_checkpoint_and_time(iteration, model, optimizer,
                                     lr_scheduler)
            saved_checkpoint = True

        # Exiting based on duration
        if args.exit_duration_in_mins:
            train_time = (time.time() - _TRAIN_START_TIME) / 60.0
            # MAX-reduce so that all ranks exit together if any rank has
            # exceeded the time budget.
            done_cuda = torch.cuda.IntTensor(
                [train_time > args.exit_duration_in_mins])
            torch.distributed.all_reduce(
                done_cuda, op=torch.distributed.ReduceOp.MAX)
            done = done_cuda.item()
            if done:
                # Same `args.save` guard as the periodic checkpoint above:
                # without a save directory there is nowhere to write to.
                if args.save and not saved_checkpoint:
                    save_checkpoint_and_time(iteration, model, optimizer,
                                             lr_scheduler)
                print_datetime('exiting program after {} minutes'.format(train_time))
                sys.exit()

        # Exiting based on iterations
        if args.exit_interval and iteration % args.exit_interval == 0:
            # Same `args.save` guard as the periodic checkpoint above.
            if args.save and not saved_checkpoint:
                save_checkpoint_and_time(iteration, model, optimizer,
                                         lr_scheduler)
            torch.distributed.barrier()
            print_datetime('exiting program at iteration {}'.format(iteration))
            sys.exit()

    return iteration
662
663


Mohammad's avatar
Mohammad committed
664
def evaluate(forward_step_func, data_iterator, model, verbose=False):
    """Run forward-only evaluation for ``args.eval_iters`` iterations.

    The model modules are put in eval mode for the duration and switched back
    to train mode afterwards. Each iteration advances
    ``args.consumed_valid_samples``.

    Args:
        forward_step_func: forward-step callable passed to the
            forward/backward schedule.
        data_iterator: iterator supplying evaluation batches.
        model: list of model modules.
        verbose: when true, print progress every ``args.log_interval``
            iterations.

    Returns:
        Dict mapping each loss name to its sum over all evaluated
        microbatches, divided by ``args.eval_iters * get_num_microbatches()``.
        Losses are only accumulated on the last pipeline stage, so the dict
        is empty on other stages.
    """
    args = get_args()

    # Eval mode disables dropout.
    for chunk in model:
        chunk.eval()

    loss_sums = {}

    with torch.no_grad():
        for step in range(1, args.eval_iters + 1):
            if verbose and step % args.log_interval == 0:
                print_rank_0('Evaluating iter {}/{}'.format(step,
                                                            args.eval_iters))

            # Choose the schedule that matches the pipeline configuration.
            if mpu.get_pipeline_model_parallel_world_size() <= 1:
                schedule = forward_backward_no_pipelining
            elif args.virtual_pipeline_model_parallel_size is not None:
                schedule = forward_backward_pipelining_with_interleaving
            else:
                schedule = forward_backward_pipelining_without_interleaving

            microbatch_losses = schedule(
                forward_step_func, data_iterator, model, optimizer=None,
                timers=None, forward_only=True)

            # Accumulate per-microbatch losses locally; only the last
            # pipeline stage holds them.
            if mpu.is_pipeline_last_stage(ignore_virtual=True):
                for microbatch_loss in microbatch_losses:
                    for name in microbatch_loss:
                        running = loss_sums.get(
                            name, torch.cuda.FloatTensor([0.0]))
                        loss_sums[name] = running + microbatch_loss[name]

            samples_this_step = (mpu.get_data_parallel_world_size()
                                 * args.micro_batch_size
                                 * get_num_microbatches())
            args.consumed_valid_samples += samples_this_step

    # Back to train mode (re-enables dropout).
    for chunk in model:
        chunk.train()

    # Turn the accumulated sums into per-microbatch averages.
    for name in loss_sums:
        loss_sums[name] /= args.eval_iters * get_num_microbatches()

    return loss_sums

def evaluate_and_print_results(prefix, forward_step_func,
                               data_iterator, model,
                               iteration, verbose=False):
    """Evaluate the model and report the validation losses.

    Builds a one-line summary with each loss value and its perplexity
    (``exp(loss)``, with the exponent capped at 20). The summary is emitted
    through ``print_rank_last``, framed by dashed rules. When a tensorboard
    writer exists and this is the last rank, each loss is also logged against
    both the iteration and the number of consumed training samples. If
    ``args.log_validation_ppl_to_tensorboard`` is set, the perplexities are
    logged the same way.
    """
    args = get_args()
    writer = get_tensorboard_writer()

    losses = evaluate(forward_step_func, data_iterator, model, verbose)

    pieces = [' validation loss at {} | '.format(prefix)]
    for name, loss in losses.items():
        loss_value = loss.item()
        # Cap the exponent so math.exp cannot overflow on a very large loss.
        perplexity = math.exp(min(20, loss_value))
        pieces.append('{} value: {:.6E} | '.format(name, loss_value))
        pieces.append('{} PPL: {:.6E} | '.format(name, perplexity))

        if not (writer and is_last_rank()):
            continue
        writer.add_scalar('{} validation'.format(name),
                          loss_value, iteration)
        writer.add_scalar('{} validation vs samples'.format(name),
                          loss_value, args.consumed_train_samples)
        if args.log_validation_ppl_to_tensorboard:
            writer.add_scalar('{} validation ppl'.format(name), perplexity,
                              iteration)
            writer.add_scalar('{} validation ppl vs samples'.format(name),
                              perplexity, args.consumed_train_samples)

    summary = ''.join(pieces)
    rule = '-' * (len(summary) + 1)
    print_rank_last(rule)
    print_rank_last(summary)
    print_rank_last(rule)
742
743


Vijay Korthikanti's avatar
Vijay Korthikanti committed
744
def cyclic_iter(iter):
    """Yield the items of ``iter`` endlessly, restarting it after each pass.

    ``iter`` should be re-iterable (e.g. a data loader), because each pass
    iterates it again from the start. With a one-shot iterator, the generator
    keeps looping without yielding anything once the iterator is exhausted.
    """
    while True:
        yield from iter

749
750
751
def build_train_valid_test_data_iterators(
        build_train_valid_test_datasets_provider):
    """Build the train, validation and test data iterators.

    Datasets and data loaders are built only on tensor-model-parallel rank 0.
    That rank then broadcasts the do_train/do_valid/do_test flags across the
    tensor-model-parallel group, and every rank stores them on ``args``.

    Args:
        build_train_valid_test_datasets_provider: callable that takes a list of
            target sample counts ``[train, valid, test]`` and returns the
            ``(train_ds, valid_ds, test_ds)`` datasets.

    Returns:
        ``(train_data_iterator, valid_data_iterator, test_data_iterator)``.
        An entry is None when the corresponding data loader was not built on
        this rank. With ``args.dataloader_type == 'cyclic'``, the iterators
        restart their loader whenever it is exhausted.
    """
    args = get_args()

    (train_dataloader, valid_dataloader, test_dataloader) = (None, None, None)

    print_rank_0('> building train, validation, and test datasets ...')

    # Backward compatibility: when resuming (iteration > 0) without recorded
    # consumed-sample counts, reconstruct them assuming a fixed global batch
    # size. This is only supported for iteration-based training.
    if args.iteration > 0 and args.consumed_train_samples == 0:
        assert args.train_samples is None, \
            'only backward compatiblity support for iteration-based training'
        args.consumed_train_samples = args.iteration * args.global_batch_size
    if args.iteration > 0 and args.consumed_valid_samples == 0:
        assert args.train_samples is None, \
            'only backward compatiblity support for iteration-based training'
        # One validation run of eval_iters batches per completed
        # eval_interval.
        args.consumed_valid_samples = (args.iteration // args.eval_interval) * \
            args.eval_iters * args.global_batch_size

    # Data loader only on rank 0 of each model parallel group.
    if mpu.get_tensor_model_parallel_rank() == 0:

        # Number of train/valid/test samples.
        if args.train_samples:
            train_samples = args.train_samples
        else:
            train_samples = args.train_iters * args.global_batch_size
        # One validation run per eval_interval, plus one extra run.
        eval_iters = (args.train_iters // args.eval_interval + 1) * \
                     args.eval_iters
        test_iters = args.eval_iters
        train_val_test_num_samples = [train_samples,
                                      eval_iters * args.global_batch_size,
                                      test_iters * args.global_batch_size]
        print_rank_0(' > datasets target sizes (minimum size):')
        print_rank_0('    train:      {}'.format(train_val_test_num_samples[0]))
        print_rank_0('    validation: {}'.format(train_val_test_num_samples[1]))
        print_rank_0('    test:       {}'.format(train_val_test_num_samples[2]))

        # Build the datasets.
        train_ds, valid_ds, test_ds = build_train_valid_test_datasets_provider(
            train_val_test_num_samples)

        # Build dataloaders. Train/valid start from the consumed-sample counts
        # computed above; test always starts from 0.
        train_dataloader = build_pretraining_data_loader(
            train_ds, args.consumed_train_samples)
        valid_dataloader = build_pretraining_data_loader(
            valid_ds, args.consumed_valid_samples)
        test_dataloader = build_pretraining_data_loader(test_ds, 0)

        # Flags to know if we need to do training/validation/testing. Only
        # this rank built the loaders, so pack the flags into a tensor that
        # can be broadcast to the other ranks.
        do_train = train_dataloader is not None and args.train_iters > 0
        do_valid = valid_dataloader is not None and args.eval_iters > 0
        do_test = test_dataloader is not None and args.eval_iters > 0
        flags = torch.cuda.LongTensor(
            [int(do_train), int(do_valid), int(do_test)])
    else:
        flags = torch.cuda.LongTensor([0, 0, 0])

    # Broadcast the do_train/do_valid/do_test flags from the source rank of
    # the tensor-model-parallel group so all ranks in the group agree.
    torch.distributed.broadcast(flags,
                                mpu.get_tensor_model_parallel_src_rank(),
                                group=mpu.get_tensor_model_parallel_group())
    args.do_train = flags[0].item()
    args.do_valid = flags[1].item()
    args.do_test = flags[2].item()

    # Build iterators.
    dl_type = args.dataloader_type
    assert dl_type in ['single', 'cyclic']

    def _make_iterator(dataloader):
        """Wrap ``dataloader`` according to ``dl_type``; None passes through."""
        if dataloader is None:
            return None
        if dl_type == 'single':
            return iter(dataloader)
        # 'cyclic': restart the dataloader every time it is exhausted.
        return iter(cyclic_iter(dataloader))

    train_data_iterator = _make_iterator(train_dataloader)
    valid_data_iterator = _make_iterator(valid_dataloader)
    test_data_iterator = _make_iterator(test_dataloader)

    return train_data_iterator, valid_data_iterator, test_data_iterator