training.py 35.5 KB
Newer Older
1
# coding=utf-8
Mohammad's avatar
Mohammad committed
2
# Copyright (c) 2020, NVIDIA CORPORATION.  All rights reserved.
3
4
5
6
7
8
9
10
11
12
13
14
15
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

Mohammad's avatar
Mohammad committed
16
"""Pretrain utilities."""
17
18
19

from datetime import datetime
import math
Mohammad's avatar
Mohammad committed
20
import sys
21
22
23
24
import time
# The earliest we can measure the start time.
_TRAIN_START_TIME = time.time()

25
26
27
import torch
from torch.nn.parallel.distributed import DistributedDataParallel as torchDDP

Neel Kant's avatar
Neel Kant committed
28
from megatron import get_args
Mohammad's avatar
Mohammad committed
29
30
from megatron import get_timers
from megatron import get_tensorboard_writer
31
from megatron import get_current_global_batch_size
mohammad's avatar
mohammad committed
32
from megatron import get_num_microbatches
mohammad's avatar
mohammad committed
33
from megatron import is_last_rank
mohammad's avatar
mohammad committed
34
from megatron import update_num_microbatches
35
from megatron import mpu
Neel Kant's avatar
Neel Kant committed
36
from megatron import print_rank_0
37
from megatron import print_rank_last
Mohammad's avatar
Mohammad committed
38
39
from megatron.checkpointing import load_checkpoint
from megatron.checkpointing import save_checkpoint
Mohammad Shoeybi's avatar
Mohammad Shoeybi committed
40
from megatron.model import Float16Module
mohammad's avatar
mohammad committed
41
from megatron.optimizer import get_megatron_optimizer
Mohammad's avatar
Mohammad committed
42
from megatron.initialize import initialize_megatron
43
from megatron.initialize import write_args_to_tensorboard
44
45
46
from megatron.learning_rates import AnnealingLR
from megatron.model import DistributedDataParallel as LocalDDP
from megatron.utils import check_adlr_autoresume_termination
47
from megatron.utils import unwrap_model
Vijay Korthikanti's avatar
Vijay Korthikanti committed
48
from megatron.data.data_samplers import build_pretraining_data_loader
mohammad's avatar
mohammad committed
49
from megatron.utils import calc_params_l2_norm
50
from megatron.schedules import forward_backward_no_pipelining
51
from megatron.schedules import forward_backward_pipelining_without_interleaving
52
from megatron.schedules import forward_backward_pipelining_with_interleaving
53
from megatron.utils import report_memory
54
55


Mohammad Shoeybi's avatar
Mohammad Shoeybi committed
56

57
58
59
60
61
62
63
def print_datetime(string):
    """Print `string` tagged with the current local date/time on rank 0.

    All ranks enter a barrier first, so every rank must reach this call.
    """
    torch.distributed.barrier()
    stamp = datetime.now().strftime('%Y-%m-%d %H:%M:%S')
    suffix = '] datetime: {} '.format(stamp)
    print_rank_0('[' + string + suffix)


64
def pretrain(train_valid_test_dataset_provider,
             model_provider,
             forward_step_func,
             extra_args_provider=None,
             args_defaults=None):
    """Main training program.

    This function runs the following steps in order:
        1) initialize Megatron.
        2) set up the model, optimizer and learning-rate scheduler using
           `model_provider`.
        3) build the train/valid/test data iterators using
           `train_valid_test_dataset_provider`.
        4) train the model using `forward_step_func`, then optionally
           evaluate on the validation and test data and save a checkpoint.

    Arguments:
        train_valid_test_dataset_provider: a function that takes the size of
            train/valid/test dataset and returns `train, valid, test` datasets.
        model_provider: a function that returns a vanilla version of the
            model. By vanilla we mean a simple model on cpu with no fp16 or ddp.
        forward_step_func: a function that takes a `data iterator` and `model`,
            and returns a `loss` scalar with a dictionary with key:values being
            the info we would like to monitor during training, for example
            `lm-loss: value`. We also require that this function add
            `batch generator` to the timers class.
        extra_args_provider: a function that takes a parser and adds arguments
            to it. It is used for programs to add their own arguments.
        args_defaults: a dictionary from argument-name to argument-value. It
            is used to set already-parsed arguments. Defaults to an empty
            dictionary when None.
    """
    # Avoid sharing one mutable default dict between calls.
    if args_defaults is None:
        args_defaults = {}

    # Initialize and get arguments, timers, and Tensorboard writer.
    initialize_megatron(extra_args_provider=extra_args_provider,
                        args_defaults=args_defaults)

    # Use the earliest start time across all ranks (ReduceOp.MIN) so the
    # reported initialization time is the largest value, which is closer to
    # what an external scheduler would observe.
    # A double tensor is required: a Unix timestamp (~1e9 s) stored as
    # float32 is only representable to within ~128 s.
    global _TRAIN_START_TIME
    start_time_tensor = torch.cuda.DoubleTensor([_TRAIN_START_TIME])
    torch.distributed.all_reduce(start_time_tensor,
                                 op=torch.distributed.ReduceOp.MIN)
    _TRAIN_START_TIME = start_time_tensor.item()
    print_rank_0('time to initialize megatron (seconds): {:.3f}'.format(
        time.time() - _TRAIN_START_TIME))
    print_datetime('after megatron is initialized')

    args = get_args()
    timers = get_timers()

    # Model, optimizer, and learning rate.
    timers('model-and-optimizer-setup').start()
    model, optimizer, lr_scheduler = setup_model_and_optimizer(model_provider)
    timers('model-and-optimizer-setup').stop()
    print_datetime('after model, optimizer, and learning rate '
                   'scheduler are built')

    # Data stuff.
    timers('train/valid/test-data-iterators-setup').start()
    if args.virtual_pipeline_model_parallel_size is not None:
        # One set of iterators per model chunk (virtual pipeline stage).
        all_data_iterators = [
            build_train_valid_test_data_iterators(train_valid_test_dataset_provider)
            for _ in range(len(model))
        ]
        train_data_iterator = [data_iterators[0] for data_iterators in all_data_iterators]
        valid_data_iterator = [data_iterators[1] for data_iterators in all_data_iterators]
        test_data_iterator = [data_iterators[2] for data_iterators in all_data_iterators]
    else:
        train_data_iterator, valid_data_iterator, test_data_iterator \
            = build_train_valid_test_data_iterators(
                train_valid_test_dataset_provider)
    timers('train/valid/test-data-iterators-setup').stop()
    print_datetime('after dataloaders are built')

    # Print setup timing.
    print_rank_0('done with setup ...')
    timers.log(['model-and-optimizer-setup', 'train/valid/test-data-iterators-setup'])
    print_rank_0('training ...')

    iteration = 0
    if args.do_train and args.train_iters > 0:
        iteration = train(forward_step_func,
                          model, optimizer, lr_scheduler,
                          train_data_iterator, valid_data_iterator)
    print_datetime('after training is done')

    if args.do_valid:
        prefix = 'the end of training for val data'
        evaluate_and_print_results(prefix, forward_step_func,
                                   valid_data_iterator, model,
                                   iteration, False)

    if args.save and iteration != 0:
        save_checkpoint(iteration, model, optimizer, lr_scheduler)

    if args.do_test:
        # Run on test data.
        prefix = 'the end of training for test data'
        evaluate_and_print_results(prefix, forward_step_func,
                                   test_data_iterator, model,
                                   0, True)
163

164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
def update_train_iters(args):
    """Derive `args.train_iters` for sample-based training.

    Returns immediately when `args.train_iters` is already truthy
    (iteration-based training). Otherwise the number of iterations is
    computed from `args.train_samples`, stepping through the batch-size
    rampup first when `args.rampup_batch_size` is set. Any trailing partial
    batch is thrown away.
    """
    # Iteration-based training: nothing to compute.
    if args.train_iters:
        return

    if args.rampup_batch_size is not None:
        # Sample-based training with rampup batch size: walk the rampup
        # phase one iteration at a time, querying the current global batch
        # size after each update.
        samples_seen = 0
        num_iters = 0
        while samples_seen <= int(args.rampup_batch_size[2]):
            update_num_microbatches(samples_seen, consistency_check=False)
            samples_seen += get_current_global_batch_size()
            num_iters += 1
        # Reset the microbatch bookkeeping back to zero consumed samples.
        update_num_microbatches(0, consistency_check=False)

        # Constant phase; any partial last batch is dropped.
        remaining_samples = args.train_samples - samples_seen
        args.train_iters = num_iters + \
            remaining_samples // args.global_batch_size
    else:
        # Constant batch size with sample-based training.
        args.train_iters = args.train_samples // args.global_batch_size

    print_rank_0('setting training iterations to {}'.format(args.train_iters))

193

Mohammad's avatar
Mohammad committed
194
def get_model(model_provider_func):
    """Build the model and wrap it for mixed precision and data parallelism.

    Returns a list of model chunks. With pipeline parallelism and a virtual
    pipeline size configured, there is one chunk per virtual pipeline rank;
    otherwise the list has a single element. Every chunk has default tensor
    model-parallel attributes on its parameters and is moved to the current
    CUDA device. It is wrapped in Float16Module when fp16/bf16 is enabled,
    then in the DDP implementation named by `args.DDP_impl` ('torch' or
    'local').

    Raises:
        NotImplementedError: for any other `args.DDP_impl` value.
    """
    args = get_args()

    # Build model.
    interleaved = mpu.get_pipeline_model_parallel_world_size() > 1 and \
        args.virtual_pipeline_model_parallel_size is not None
    if interleaved:
        model = []
        for vp_rank in range(args.virtual_pipeline_model_parallel_size):
            mpu.set_virtual_pipeline_model_parallel_rank(vp_rank)
            # The first/last-stage checks depend on the virtual rank set above.
            model.append(model_provider_func(
                pre_process=mpu.is_pipeline_first_stage(),
                post_process=mpu.is_pipeline_last_stage()))
    else:
        model = model_provider_func(
            pre_process=mpu.is_pipeline_first_stage(),
            post_process=mpu.is_pipeline_last_stage())

    if not isinstance(model, list):
        model = [model]

    # Only parameters that are already tensor model parallel carry the
    # tensor-parallel attributes; give every other parameter the defaults so
    # the optimizer can rely on them.
    for chunk in model:
        for param in chunk.parameters():
            mpu.set_defaults_if_not_set_tensor_model_parallel_attributes(param)

    # Print number of parameters.
    if mpu.get_data_parallel_rank() == 0:
        tp_rank = mpu.get_tensor_model_parallel_rank()
        pp_rank = mpu.get_pipeline_model_parallel_rank()
        num_params = sum(sum(p.nelement() for p in chunk.parameters())
                         for chunk in model)
        print(' > number of parameters on (tensor, pipeline) '
              'model parallel rank ({}, {}): {}'.format(
                  tp_rank, pp_rank, num_params), flush=True)

    # GPU allocation.
    device = torch.cuda.current_device()
    for chunk in model:
        chunk.cuda(device)

    # Fp16 / bf16 conversion.
    if args.fp16 or args.bf16:
        model = [Float16Module(chunk, args) for chunk in model]

    if args.DDP_impl == 'torch':
        i = torch.cuda.current_device()
        return [torchDDP(chunk, device_ids=[i], output_device=i,
                         process_group=mpu.get_data_parallel_group())
                for chunk in model]

    if args.DDP_impl == 'local':
        return [LocalDDP(chunk,
                         args.accumulate_allreduce_grads_in_fp32,
                         args.use_contiguous_buffers_in_ddp)
                for chunk in model]

    raise NotImplementedError('Unknown DDP implementation specified: {}. '
                              'Exiting.'.format(args.DDP_impl))
264
265


Mohammad's avatar
Mohammad committed
266
def get_learning_rate_scheduler(optimizer):
    """Build the AnnealingLR learning-rate scheduler for `optimizer`.

    Warmup and decay lengths are passed to the scheduler in samples:
    - iteration-based training converts the iteration counts by multiplying
      by `args.global_batch_size`;
    - sample-based training uses the sample counts directly, and first calls
      `update_train_iters` so `args.train_iters` is populated.
    When `args.lr_warmup_fraction` is set it overrides the explicit warmup
    length. An unset `lr_decay_iters` / `lr_decay_samples` defaults to the
    full training length and is written back into `args`.

    Raises:
        Exception: if neither `train_iters` nor `train_samples` is provided.
    """
    args = get_args()

    if args.train_iters:
        # Iteration-based training.
        if args.lr_decay_iters is None:
            args.lr_decay_iters = args.train_iters
        decay_steps = args.lr_decay_iters * args.global_batch_size
        if args.lr_warmup_fraction is None:
            warmup_steps = args.lr_warmup_iters * args.global_batch_size
        else:
            warmup_steps = args.lr_warmup_fraction * decay_steps
    elif args.train_samples:
        # Sample-based training. Training iters are needed later, so set
        # them now. Strictly the training samples should also be adjusted
        # for the incomplete last batch, but that is left as is for now.
        update_train_iters(args)
        if args.lr_decay_samples is None:
            args.lr_decay_samples = args.train_samples
        decay_steps = args.lr_decay_samples
        if args.lr_warmup_fraction is None:
            warmup_steps = args.lr_warmup_samples
        else:
            warmup_steps = args.lr_warmup_fraction * decay_steps
    else:
        raise Exception(
            'either train-iters or train-samples should be provided.')

    return AnnealingLR(
        optimizer,
        max_lr=args.lr,
        min_lr=args.min_lr,
        warmup_steps=warmup_steps,
        decay_steps=decay_steps,
        decay_style=args.lr_decay_style,
        use_checkpoint_lr_scheduler=args.use_checkpoint_lr_scheduler,
        override_lr_scheduler=args.override_lr_scheduler)


Mohammad's avatar
Mohammad committed
309
def setup_model_and_optimizer(model_provider_func):
    """Set up model, optimizer and learning-rate scheduler.

    Builds the model with `get_model`, creates the optimizer over the
    unwrapped model chunks, and builds the LR scheduler. If `args.load` is
    set, it loads a checkpoint and stores the returned iteration in
    `args.iteration`; otherwise `args.iteration` is set to 0.

    Returns:
        (model, optimizer, lr_scheduler), where `model` is the list of
        wrapped model chunks produced by `get_model`.
    """
    args = get_args()

    model = get_model(model_provider_func)

    # The optimizer works on the modules without the DDP / Float16 wrappers.
    unwrapped_model = unwrap_model(model,
                                   (torchDDP, LocalDDP, Float16Module))
    optimizer = get_megatron_optimizer(unwrapped_model)

    lr_scheduler = get_learning_rate_scheduler(optimizer)

    if args.load is not None:
        timers = get_timers()
        # Extra barrier is added to make sure all ranks report the
        # max time.
        torch.distributed.barrier()
        timers('load-checkpoint').start()
        args.iteration = load_checkpoint(model, optimizer, lr_scheduler)
        torch.distributed.barrier()
        timers('load-checkpoint').stop()
        timers.log(['load-checkpoint'])
    else:
        args.iteration = 0

    # We only support local DDP with multiple micro-batches.
    if len(model) > 1 or mpu.get_pipeline_model_parallel_world_size() > 1:
        assert args.DDP_impl == 'local'

    # When starting from scratch, a model exposing init_state_dict_from_bert
    # (ICT) is initialized from a pretrained BERT model.
    if args.iteration == 0 and len(unwrapped_model) == 1 \
        and hasattr(unwrapped_model[0], 'init_state_dict_from_bert'):
        print_rank_0("Initializing ICT from pretrained BERT model")
        unwrapped_model[0].init_state_dict_from_bert()
        # The model weights just changed, so the optimizer's parameter copy
        # must be refreshed. bf16 is wrapped exactly like fp16 (Float16Module
        # in get_model), so it needs the same reload as fp16.
        # NOTE(review): assumes the bf16 optimizer exposes reload_model_params
        # like the fp16 one does — confirm in megatron.optimizer.
        if args.fp16 or args.bf16:
            optimizer.reload_model_params()

    return model, optimizer, lr_scheduler


349
350
351
352
353
354
355
def train_step(forward_step_func, data_iterator,
               model, optimizer, lr_scheduler):
    """Run a single training iteration.

    The iteration proceeds as follows:
    - zero the gradients;
    - run forward/backward over all microbatches, using the pipeline schedule
      that matches the parallel configuration;
    - all-reduce gradients, both for local DDP and for the word embeddings
      shared between the first and last pipeline stages;
    - step the optimizer, and step the LR scheduler if the update succeeded.

    Returns:
        (loss_dict, skipped_iter, grad_norm, num_zeros_in_grad).
        - `loss_dict` maps each loss key to its mean over microbatches on
          the last pipeline stage, and is empty on other stages.
        - `skipped_iter` is 1 when the optimizer update was unsuccessful,
          else 0.
    """
    args = get_args()
    timers = get_timers()

    # Set grad to zero.
    uses_grad_buffer = args.DDP_impl == 'local' and \
        args.use_contiguous_buffers_in_ddp
    if uses_grad_buffer:
        for partition in model:
            partition.zero_grad_buffer()
    else:
        optimizer.zero_grad()

    # Pick the forward/backward schedule.
    if mpu.get_pipeline_model_parallel_world_size() <= 1:
        forward_backward_func = forward_backward_no_pipelining
    elif args.virtual_pipeline_model_parallel_size is None:
        forward_backward_func = forward_backward_pipelining_without_interleaving
    else:
        forward_backward_func = forward_backward_pipelining_with_interleaving
        assert get_num_microbatches() % args.pipeline_model_parallel_size == 0, \
            'number of microbatches is not divisible by pipeline-parallel ' \
            'size when using interleaved schedule'

    losses_reduced = forward_backward_func(
        forward_step_func, data_iterator, model,
        optimizer, timers, forward_only=False)

    # All-reduce if needed.
    if args.DDP_impl == 'local':
        timers('backward-params-all-reduce').start()
        for chunk in model:
            chunk.allreduce_gradients()
        timers('backward-params-all-reduce').stop()

    # All-reduce word_embeddings' grad across first and last stages to ensure
    # that word_embeddings parameters stay in sync.
    # This should only run for models that support pipelined model parallelism
    # (BERT and GPT-2).
    timers('backward-embedding-all-reduce').start()
    on_first_stage = mpu.is_pipeline_first_stage(ignore_virtual=True)
    on_last_stage = mpu.is_pipeline_last_stage(ignore_virtual=True)
    if (on_first_stage or on_last_stage) and \
            mpu.get_pipeline_model_parallel_world_size() > 1:
        # The first stage owns the input embedding in its first chunk; the
        # last stage owns the output layer in its last chunk.
        embedding_chunk = model[0] if on_first_stage else model[-1]
        embedding_chunk = unwrap_model(
            embedding_chunk, (torchDDP, LocalDDP, Float16Module))
        if embedding_chunk.share_word_embeddings:
            weight = embedding_chunk.word_embeddings_weight()
            grad = weight.main_grad if args.DDP_impl == 'local' \
                else weight.grad
            torch.distributed.all_reduce(grad, group=mpu.get_embedding_group())
    timers('backward-embedding-all-reduce').stop()

    # Update parameters.
    timers('optimizer').start()
    update_successful, grad_norm, num_zeros_in_grad = optimizer.step()
    timers('optimizer').stop()

    # Update learning rate only when the optimizer step was applied.
    skipped_iter = 0 if update_successful else 1
    if update_successful:
        increment = get_num_microbatches() * \
                    args.micro_batch_size * \
                    args.data_parallel_size
        lr_scheduler.step(increment=increment)

    if not on_last_stage:
        return {}, skipped_iter, grad_norm, num_zeros_in_grad

    # Average loss across microbatches.
    num_microbatches = len(losses_reduced)
    loss_reduced = {
        key: sum(mb_losses[key] for mb_losses in losses_reduced) /
        num_microbatches
        for key in losses_reduced[0]
    }
    return loss_reduced, skipped_iter, grad_norm, num_zeros_in_grad
430
431


Mohammad Shoeybi's avatar
Mohammad Shoeybi committed
432
def training_log(loss_dict, total_loss_dict, learning_rate, iteration,
                 loss_scale, report_memory_flag, skipped_iter,
                 grad_norm, params_norm, num_zeros_in_grad):
    """Log training information such as losses, timing, etc.

    Accumulates this iteration's losses, plus advanced/skipped/nan iteration
    counters, into `total_loss_dict`, which is mutated in place. Writes
    scalars to Tensorboard on the last rank every
    `args.tensorboard_log_interval` iterations. Every `args.log_interval`
    iterations, prints a summary line and resets the accumulators.

    Returns:
        The updated `report_memory_flag`, which is set to False once memory
        usage has been reported.
    """
    args = get_args()
    timers = get_timers()
    writer = get_tensorboard_writer()

    # Advanced, skipped, and nan iteration counters.
    advanced_iters_key = 'advanced iterations'
    skipped_iters_key = 'skipped iterations'
    nan_iters_key = 'nan iterations'
    # Advanced iterations. The key is created (as 0) even on a skipped
    # iteration so it can always be read below.
    total_loss_dict[advanced_iters_key] = total_loss_dict.get(
        advanced_iters_key, 0) + (0 if skipped_iter else 1)
    # Skipped iterations.
    total_loss_dict[skipped_iters_key] = total_loss_dict.get(
        skipped_iters_key, 0) + skipped_iter
    # Update losses and set nan iterations.
    got_nan = False
    for key in loss_dict:
        if not skipped_iter:
            accumulated = total_loss_dict.get(key)
            if accumulated is None:
                # Allocate the zero tensor only for a new key instead of on
                # every call.
                accumulated = torch.cuda.FloatTensor([0.0])
            total_loss_dict[key] = accumulated + loss_dict[key]
        else:
            value = loss_dict[key].float().sum().item()
            # +inf, -inf and nan all count as a "nan" iteration.
            got_nan = got_nan or math.isinf(value) or math.isnan(value)
    total_loss_dict[nan_iters_key] = total_loss_dict.get(
        nan_iters_key, 0) + int(got_nan)

    # Logging: only report timers that have actually been created.
    timers_to_log = [
        name for name in (
            'forward-compute',
            'forward-recv',
            'forward-send',
            'forward-backward-send-forward-backward-recv',
            'backward-compute',
            'backward-recv',
            'backward-send',
            'backward-send-forward-recv',
            'backward-send-backward-recv',
            'backward-params-all-reduce',
            'backward-embedding-all-reduce',
            'optimizer-copy-to-main-grad',
            'optimizer-unscale-and-check-inf',
            'optimizer-clip-main-grad',
            'optimizer-copy-main-to-model-params',
            'optimizer',
            'batch-generator',
        ) if name in timers.timers
    ]

    # Calculate batch size.
    batch_size = args.micro_batch_size * args.data_parallel_size * \
        get_num_microbatches()

    total_iterations = total_loss_dict[advanced_iters_key] + \
                       total_loss_dict[skipped_iters_key]

    # Tensorboard values.
    if writer and (iteration % args.tensorboard_log_interval == 0) and \
       is_last_rank():
        if args.log_learning_rate_to_tensorboard:
            writer.add_scalar('learning-rate', learning_rate, iteration)
            writer.add_scalar('learning-rate vs samples', learning_rate,
                              args.consumed_train_samples)
        if args.log_batch_size_to_tensorboard:
            writer.add_scalar('batch-size', batch_size, iteration)
            writer.add_scalar('batch-size vs samples', batch_size,
                              args.consumed_train_samples)
        for key in loss_dict:
            writer.add_scalar(key, loss_dict[key], iteration)
            writer.add_scalar(key + ' vs samples', loss_dict[key],
                              args.consumed_train_samples)
        if args.log_loss_scale_to_tensorboard:
            writer.add_scalar('loss-scale', loss_scale, iteration)
            writer.add_scalar('loss-scale vs samples', loss_scale,
                              args.consumed_train_samples)
        if grad_norm is not None:
            writer.add_scalar('grad-norm', grad_norm, iteration)
            writer.add_scalar('grad-norm vs samples', grad_norm,
                              args.consumed_train_samples)
        if num_zeros_in_grad is not None:
            writer.add_scalar('num-zeros', num_zeros_in_grad, iteration)
            writer.add_scalar('num-zeros vs samples', num_zeros_in_grad,
                              args.consumed_train_samples)
        if params_norm is not None:
            writer.add_scalar('params-norm', params_norm, iteration)
            writer.add_scalar('params-norm vs samples', params_norm,
                              args.consumed_train_samples)
        if args.log_timers_to_tensorboard:
            timers.write(timers_to_log, writer, iteration,
                         normalizer=total_iterations)

    if iteration % args.log_interval == 0:
        elapsed_time = timers('interval-time').elapsed()
        elapsed_time_per_iteration = elapsed_time / total_iterations
        # Use the same rank gate as every other Tensorboard scalar in this
        # function. Gating on rank 0 dropped this scalar whenever the writer
        # is only available on the last rank.
        if writer and is_last_rank():
            if args.log_timers_to_tensorboard:
                writer.add_scalar('iteration-time',
                                  elapsed_time_per_iteration, iteration)
        log_string = ' iteration {:8d}/{:8d} |'.format(
            iteration, args.train_iters)
        log_string += ' consumed samples: {:12d} |'.format(
            args.consumed_train_samples)
        log_string += ' elapsed time per iteration (ms): {:.1f} |'.format(
            elapsed_time_per_iteration * 1000.0)
        log_string += ' learning rate: {:.3E} |'.format(learning_rate)
        log_string += ' global batch size: {:5d} |'.format(batch_size)
        for key in total_loss_dict:
            if key not in [advanced_iters_key, skipped_iters_key,
                           nan_iters_key]:
                # Average the accumulated loss over non-skipped iterations.
                avg = total_loss_dict[key].item() / \
                      float(max(1, total_loss_dict[advanced_iters_key]))
                if avg > 0.0:
                    log_string += ' {}: {:.6E} |'.format(key, avg)
                total_loss_dict[key] = torch.cuda.FloatTensor([0.0])
        log_string += ' loss scale: {:.1f} |'.format(loss_scale)
        if grad_norm is not None:
            log_string += ' grad norm: {:.3f} |'.format(grad_norm)
        if num_zeros_in_grad is not None:
            log_string += ' num zeros: {:.1f} |'.format(num_zeros_in_grad)
        if params_norm is not None:
            log_string += ' params norm: {:.3f} |'.format(params_norm)
        log_string += ' number of skipped iterations: {:3d} |'.format(
            total_loss_dict[skipped_iters_key])
        log_string += ' number of nan iterations: {:3d} |'.format(
            total_loss_dict[nan_iters_key])
        total_loss_dict[advanced_iters_key] = 0
        total_loss_dict[skipped_iters_key] = 0
        total_loss_dict[nan_iters_key] = 0
        print_rank_last(log_string)
        if report_memory_flag and learning_rate > 0.:
            # Report memory after optimizer state has been initialized.
            report_memory('(after {} iterations)'.format(iteration))
            report_memory_flag = False
        timers.log(timers_to_log, normalizer=args.log_interval)

    return report_memory_flag


582
583
584
585
586
def save_checkpoint_and_time(iteration, model, optimizer, lr_scheduler):
    """Save a checkpoint and log the time spent saving it.

    Barriers before and after the save ensure all ranks report the max time.
    """
    timers = get_timers()
    timer_name = 'save-checkpoint'

    torch.distributed.barrier()
    timers(timer_name).start()
    save_checkpoint(iteration, model, optimizer, lr_scheduler)
    torch.distributed.barrier()
    timers(timer_name).stop()

    timers.log([timer_name])
592
593


594
def train(forward_step_func, model, optimizer, lr_scheduler,
          train_data_iterator, valid_data_iterator):
    """Run the main training loop.

    Steps from ``args.iteration`` until ``args.train_iters``. After each
    step it logs, optionally checks ADLR autoresume, evaluates every
    ``args.eval_interval`` iterations (when ``args.do_valid``), and saves a
    checkpoint every ``args.save_interval`` iterations. The process exits
    via ``sys.exit()`` (after saving a checkpoint if one was not just
    written) once ``args.exit_duration_in_mins`` has elapsed on any rank, or
    every ``args.exit_interval`` iterations.

    Returns:
        The iteration count reached when the loop ends.
    """
    args = get_args()
    timers = get_timers()

    # Write args to tensorboard
    write_args_to_tensorboard()

    # Put every model chunk in training mode (enables dropout).
    for chunk in model:
        chunk.train()

    # Running loss accumulators, carried across iterations by training_log.
    total_loss_dict = {}
    step = args.iteration

    timers('interval-time').start()
    print_datetime('before the start of training step')

    report_memory_flag = True
    while step < args.train_iters:
        update_num_microbatches(args.consumed_train_samples)
        loss_dict, skipped_iter, grad_norm, num_zeros_in_grad = train_step(
            forward_step_func, train_data_iterator, model, optimizer,
            lr_scheduler)
        step += 1
        # Samples consumed by this step across all data-parallel ranks.
        args.consumed_train_samples += (mpu.get_data_parallel_world_size()
                                        * args.micro_batch_size
                                        * get_num_microbatches())

        # Logging.
        loss_scale = optimizer.get_loss_scale().item()
        params_norm = (calc_params_l2_norm(model)
                       if args.log_params_norm else None)
        report_memory_flag = training_log(
            loss_dict, total_loss_dict, optimizer.param_groups[0]['lr'],
            step, loss_scale, report_memory_flag, skipped_iter,
            grad_norm, params_norm, num_zeros_in_grad)

        # Autoresume
        if args.adlr_autoresume and \
           step % args.adlr_autoresume_interval == 0:
            check_adlr_autoresume_termination(step, model, optimizer,
                                              lr_scheduler)

        # Evaluation
        if args.eval_interval and step % args.eval_interval == 0 and \
           args.do_valid:
            evaluate_and_print_results('iteration {}'.format(step),
                                       forward_step_func,
                                       valid_data_iterator, model,
                                       step, False)

        # Checkpointing; remembered so the exit paths below do not save twice.
        checkpoint_written = False
        if args.save and args.save_interval and \
           step % args.save_interval == 0:
            save_checkpoint_and_time(step, model, optimizer, lr_scheduler)
            checkpoint_written = True

        # Exiting based on duration: a MAX all-reduce makes every rank exit
        # as soon as any one rank is over the time budget.
        if args.exit_duration_in_mins:
            elapsed_mins = (time.time() - _TRAIN_START_TIME) / 60.0
            over_budget = torch.cuda.IntTensor(
                [elapsed_mins > args.exit_duration_in_mins])
            torch.distributed.all_reduce(
                over_budget, op=torch.distributed.ReduceOp.MAX)
            if over_budget.item():
                if not checkpoint_written:
                    save_checkpoint_and_time(step, model, optimizer,
                                             lr_scheduler)
                print_datetime(
                    'exiting program after {} minutes'.format(elapsed_mins))
                sys.exit()

        # Exiting based on iterations
        if args.exit_interval and step % args.exit_interval == 0:
            if not checkpoint_written:
                save_checkpoint_and_time(step, model, optimizer,
                                         lr_scheduler)
            torch.distributed.barrier()
            print_datetime('exiting program at iteration {}'.format(step))
            sys.exit()

    return step
688
689


Mohammad's avatar
Mohammad committed
690
def evaluate(forward_step_func, data_iterator, model, verbose=False):
    """Compute averaged losses over ``args.eval_iters`` evaluation steps.

    Every model chunk is put in eval mode (disabling dropout) for the
    duration and back in train mode afterwards; the forward passes run
    inside ``torch.no_grad()``. ``args.consumed_valid_samples`` is advanced
    for every step.

    Returns:
        dict mapping each loss key to its summed value divided by
        ``args.eval_iters * get_num_microbatches()``. Losses are only
        accumulated on the last pipeline stage, so the dict is empty on
        other stages.
    """
    args = get_args()

    # Turn on evaluation mode which disables dropout.
    for chunk in model:
        chunk.eval()

    total_loss_dict = {}

    with torch.no_grad():
        for step in range(1, args.eval_iters + 1):
            if verbose and step % args.log_interval == 0:
                print_rank_0('Evaluating iter {}/{}'.format(step,
                                                            args.eval_iters))

            # Pick the forward/backward schedule matching the pipeline layout.
            if mpu.get_pipeline_model_parallel_world_size() <= 1:
                schedule = forward_backward_no_pipelining
            elif args.virtual_pipeline_model_parallel_size is None:
                schedule = forward_backward_pipelining_without_interleaving
            else:
                schedule = forward_backward_pipelining_with_interleaving
            loss_dicts = schedule(
                forward_step_func, data_iterator, model, optimizer=None,
                timers=None, forward_only=True)

            # Only the last pipeline stage accumulates the reported losses.
            if mpu.is_pipeline_last_stage(ignore_virtual=True):
                for loss_dict in loss_dicts:
                    for key in loss_dict:
                        running = total_loss_dict.get(
                            key, torch.cuda.FloatTensor([0.0]))
                        total_loss_dict[key] = running + loss_dict[key]

            args.consumed_valid_samples += (mpu.get_data_parallel_world_size()
                                            * args.micro_batch_size
                                            * get_num_microbatches())

    # Move model back to the train mode.
    for chunk in model:
        chunk.train()

    # Turn the sums into per-microbatch averages.
    for key in total_loss_dict:
        total_loss_dict[key] /= args.eval_iters * get_num_microbatches()

    return total_loss_dict

def evaluate_and_print_results(prefix, forward_step_func,
                               data_iterator, model,
                               iteration, verbose=False):
    """Evaluate the model, print the losses, and log them to tensorboard.

    Runs ``evaluate`` and prints a dashed-framed summary containing each
    averaged loss and its perplexity, ``exp(min(20, loss))``. On the last
    rank, when a tensorboard writer exists, the same values are logged
    against both ``iteration`` and ``args.consumed_train_samples``. The
    perplexity scalars are logged only if
    ``args.log_validation_ppl_to_tensorboard`` is set.

    Args:
        prefix: Text identifying this evaluation in the printed summary.
        forward_step_func, data_iterator, model, verbose: Forwarded to
            ``evaluate``.
        iteration: Step value used for the tensorboard scalars.
    """
    args = get_args()
    writer = get_tensorboard_writer()

    total_loss_dict = evaluate(forward_step_func, data_iterator, model, verbose)
    string = ' validation loss at {} | '.format(prefix)
    for key in total_loss_dict:
        # Read the loss to the host once per key. Each .item() on a device
        # tensor is a device-to-host copy, and the value is used up to five
        # times below.
        loss_value = total_loss_dict[key].item()
        # Clamp the exponent so a very large loss cannot overflow math.exp.
        ppl = math.exp(min(20, loss_value))
        string += '{} value: {:.6E} | '.format(key, loss_value)
        string += '{} PPL: {:.6E} | '.format(key, ppl)
        if writer and is_last_rank():
            writer.add_scalar('{} validation'.format(key),
                              loss_value,
                              iteration)
            writer.add_scalar('{} validation vs samples'.format(key),
                              loss_value,
                              args.consumed_train_samples)
            if args.log_validation_ppl_to_tensorboard:
                writer.add_scalar('{} validation ppl'.format(key), ppl,
                                  iteration)
                writer.add_scalar('{} validation ppl vs samples'.format(key),
                                  ppl, args.consumed_train_samples)

    length = len(string) + 1
    print_rank_last('-' * length)
    print_rank_last(string)
    print_rank_last('-' * length)
768
769


Vijay Korthikanti's avatar
Vijay Korthikanti committed
770
def cyclic_iter(iter):
    """Yield the items of ``iter`` forever, re-iterating it after each pass.

    Each pass runs a fresh ``for x in iter``. A re-iterable object such as a
    data loader therefore starts a new pass each time, instead of replaying
    stored items the way ``itertools.cycle`` does.

    If a whole pass yields nothing, the generator stops instead of spinning
    forever without producing a value. This happens when the iterable is
    empty, or when it is a one-shot iterator that is already exhausted.

    Note: the parameter name shadows the builtin ``iter``. It is kept so
    existing keyword callers keep working.
    """
    while True:
        produced_any = False
        for x in iter:
            produced_any = True
            yield x
        if not produced_any:
            # Nothing left to cycle over; stop rather than busy-loop.
            return

775
776
777
def build_train_valid_test_data_iterators(
        build_train_valid_test_datasets_provider):
    """Build the train, validation, and test data iterators.

    On rank 0 of each tensor-model-parallel group, the function does three
    things:

    * Calls ``build_train_valid_test_datasets_provider`` with the minimum
      ``[train, valid, test]`` sample counts.
    * Wraps the resulting datasets in pretraining data loaders.
    * Decides whether training, validation, and testing should run.

    Those three decisions are broadcast to the rest of the group and stored
    as ``args.do_train``, ``args.do_valid`` and ``args.do_test``.

    Args:
        build_train_valid_test_datasets_provider: Callable taking the list
            of target sample counts and returning
            ``(train_ds, valid_ds, test_ds)``.

    Returns:
        ``(train_data_iterator, valid_data_iterator, test_data_iterator)``.
        An entry is None when the matching data loader is None, which is
        always the case on ranks other than tensor-model-parallel rank 0.
    """
    args = get_args()

    (train_dataloader, valid_dataloader, test_dataloader) = (None, None, None)

    print_rank_0('> building train, validation, and test datasets ...')

    # Backward compatibility: when resuming (iteration > 0) without a
    # recorded consumed-sample count, reconstruct it assuming a fixed global
    # batch size. This is only supported for iteration-based training.
    if args.iteration > 0 and args.consumed_train_samples == 0:
        assert args.train_samples is None, \
            'only backward compatiblity support for iteration-based training'
        args.consumed_train_samples = args.iteration * args.global_batch_size
    if args.iteration > 0 and args.consumed_valid_samples == 0:
        assert args.train_samples is None, \
            'only backward compatiblity support for iteration-based training'
        args.consumed_valid_samples = (args.iteration // args.eval_interval) * \
            args.eval_iters * args.global_batch_size

    # Data loader only on rank 0 of each tensor-model-parallel group.
    if mpu.get_tensor_model_parallel_rank() == 0:

        # Number of train/valid/test samples.
        if args.train_samples:
            train_samples = args.train_samples
        else:
            train_samples = args.train_iters * args.global_batch_size
        # One evaluation run every eval_interval iterations plus one extra,
        # each run lasting eval_iters iterations.
        eval_iters = (args.train_iters // args.eval_interval + 1) * \
                     args.eval_iters
        test_iters = args.eval_iters
        train_val_test_num_samples = [train_samples,
                                      eval_iters * args.global_batch_size,
                                      test_iters * args.global_batch_size]
        print_rank_0(' > datasets target sizes (minimum size):')
        print_rank_0('    train:      {}'.format(train_val_test_num_samples[0]))
        print_rank_0('    validation: {}'.format(train_val_test_num_samples[1]))
        print_rank_0('    test:       {}'.format(train_val_test_num_samples[2]))

        # Build the datasets.
        train_ds, valid_ds, test_ds = build_train_valid_test_datasets_provider(
            train_val_test_num_samples)

        # Build dataloaders. Train/valid loaders are given the number of
        # samples already consumed (presumably so a resumed run skips them).
        train_dataloader = build_pretraining_data_loader(
            train_ds, args.consumed_train_samples)
        valid_dataloader = build_pretraining_data_loader(
            valid_ds, args.consumed_valid_samples)
        test_dataloader = build_pretraining_data_loader(test_ds, 0)

        # Flags to know if we need to do training/validation/testing.
        do_train = train_dataloader is not None and args.train_iters > 0
        do_valid = valid_dataloader is not None and args.eval_iters > 0
        do_test = test_dataloader is not None and args.eval_iters > 0
        # Pack the flags into a CUDA tensor so they can be broadcast.
        flags = torch.cuda.LongTensor(
            [int(do_train), int(do_valid), int(do_test)])
    else:
        flags = torch.cuda.LongTensor([0, 0, 0])

    # Broadcast the flags from tensor-model-parallel rank 0 so every rank in
    # the group agrees on which phases to run.
    torch.distributed.broadcast(flags,
                                mpu.get_tensor_model_parallel_src_rank(),
                                group=mpu.get_tensor_model_parallel_group())
    args.do_train = flags[0].item()
    args.do_valid = flags[1].item()
    args.do_test = flags[2].item()

    # Build iterators.
    dl_type = args.dataloader_type
    assert dl_type in ['single', 'cyclic']

    train_data_iterator = _make_data_iterator(train_dataloader, dl_type)
    valid_data_iterator = _make_data_iterator(valid_dataloader, dl_type)
    test_data_iterator = _make_data_iterator(test_dataloader, dl_type)

    return train_data_iterator, valid_data_iterator, test_data_iterator


def _make_data_iterator(dataloader, dl_type):
    """Return an iterator over ``dataloader`` for ``dl_type``, or None.

    A None loader is passed through as None. ``'single'`` gives one pass over
    the loader; any other value (the caller only allows ``'cyclic'``)
    repeats it through ``cyclic_iter``.
    """
    if dataloader is None:
        return None
    if dl_type == 'single':
        return iter(dataloader)
    # cyclic_iter is a generator function, so its result is already an
    # iterator; wrapping it in iter() would return the same object.
    return cyclic_iter(dataloader)