training.py 40 KB
Newer Older
1
# coding=utf-8
Mohammad's avatar
Mohammad committed
2
# Copyright (c) 2020, NVIDIA CORPORATION.  All rights reserved.
3
4
5
6
7
8
9
10
11
12
13
14
15
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

Mohammad's avatar
Mohammad committed
16
"""Pretrain utilities."""
17
18
19

from datetime import datetime
import math
Mohammad's avatar
Mohammad committed
20
import sys
21
22
23
24
import time
# The earliest we can measure the start time.
_TRAIN_START_TIME = time.time()

25
26
27
import torch
from torch.nn.parallel.distributed import DistributedDataParallel as torchDDP

Neel Kant's avatar
Neel Kant committed
28
from megatron import get_args
29
from megatron import get_signal_handler
Mohammad's avatar
Mohammad committed
30
31
from megatron import get_timers
from megatron import get_tensorboard_writer
32
from megatron import get_current_global_batch_size
mohammad's avatar
mohammad committed
33
from megatron import get_num_microbatches
mohammad's avatar
mohammad committed
34
from megatron import is_last_rank
mohammad's avatar
mohammad committed
35
from megatron import update_num_microbatches
36
from megatron import mpu
Neel Kant's avatar
Neel Kant committed
37
from megatron import print_rank_0
38
from megatron import print_rank_last
Mohammad's avatar
Mohammad committed
39
40
from megatron.checkpointing import load_checkpoint
from megatron.checkpointing import save_checkpoint
Mohammad Shoeybi's avatar
Mohammad Shoeybi committed
41
from megatron.model import Float16Module
42
from megatron.model import ModelType
mohammad's avatar
mohammad committed
43
from megatron.optimizer import get_megatron_optimizer
Mohammad's avatar
Mohammad committed
44
from megatron.initialize import initialize_megatron
45
from megatron.initialize import write_args_to_tensorboard
46
47
48
from megatron.learning_rates import AnnealingLR
from megatron.model import DistributedDataParallel as LocalDDP
from megatron.utils import check_adlr_autoresume_termination
49
from megatron.utils import unwrap_model
Vijay Korthikanti's avatar
Vijay Korthikanti committed
50
from megatron.data.data_samplers import build_pretraining_data_loader
mohammad's avatar
mohammad committed
51
from megatron.utils import calc_params_l2_norm
52
from megatron.schedules import get_forward_backward_func
53
from megatron.utils import report_memory
54
55


Mohammad Shoeybi's avatar
Mohammad Shoeybi committed
56

57
58
59
60
61
62
63
def print_datetime(string):
    """Emit `string` together with the current local date and time.

    Note that this call will sync across all ranks: a distributed barrier
    is issued before the timestamp is taken. The message itself is handed
    to `print_rank_0`.
    """
    torch.distributed.barrier()
    stamp = datetime.now().strftime('%Y-%m-%d %H:%M:%S')
    message = '[' + string + ']' + ' datetime: {} '.format(stamp)
    print_rank_0(message)


64
def pretrain(train_valid_test_dataset_provider,
             model_provider,
             model_type,
             forward_step_func,
             process_non_loss_data_func=None,
             extra_args_provider=None,
             args_defaults=None):
    """Main training program.

    This function will run the following steps in the order provided:
        1) initialize Megatron.
        2) setup model, optimizer and lr schedule using the model_provider.
        3) call train_valid_test_dataset_provider to get train/val/test
           datasets.
        4) train the model using the forward_step_func.

    Arguments:
        train_valid_test_dataset_provider: a function that takes the size of
            train/valid/test dataset and returns `train, valid, test` datasets.
        model_provider: a function that returns a vanilla version of the
            model. By vanilla we mean a simple model on cpu with no fp16 or ddp.
        model_type: an enum that specifies the type of model being trained.
        forward_step_func: a function that takes a `data iterator` and `model`,
            and returns a `loss` scalar with a dictionary with key:values being
            the info we would like to monitor during training, for example
            `lm-loss: value`. We also require that this function add
            `batch generator` to the timers class.
        process_non_loss_data_func: optional object that is passed through
            unchanged to `train` and to `evaluate_and_print_results`.
        extra_args_provider: a function that takes a parser and adds arguments
            to it. It is used for programs to add their own arguments.
        args_defaults: a dictionary from argument-name to argument-value. It
            is used to set already parsed arguments. When omitted, an empty
            dictionary is used.
    """
    # Use a fresh dictionary per call instead of a shared mutable default.
    if args_defaults is None:
        args_defaults = {}

    # Initialize and get arguments, timers, and Tensorboard writer.
    initialize_megatron(extra_args_provider=extra_args_provider,
                        args_defaults=args_defaults)

    # Adjust the startup time so the reported initialization time reflects
    # the largest value: the MIN reduction picks the earliest start time
    # recorded on any rank.
    global _TRAIN_START_TIME
    start_time_tensor = torch.cuda.DoubleTensor([_TRAIN_START_TIME])
    torch.distributed.all_reduce(start_time_tensor,
                                 op=torch.distributed.ReduceOp.MIN)
    _TRAIN_START_TIME = start_time_tensor.item()
    print_rank_0('time to initialize megatron (seconds): {:.3f}'.format(
        time.time() - _TRAIN_START_TIME))
    print_datetime('after megatron is initialized')

    args = get_args()
    timers = get_timers()

    # Model, optimizer, and learning rate.
    timers('model-and-optimizer-setup').start()
    model, optimizer, lr_scheduler = setup_model_and_optimizer(model_provider,
                                                               model_type)
    timers('model-and-optimizer-setup').stop()
    print_datetime('after model, optimizer, and learning rate '
                   'scheduler are built')

    # Data stuff.
    timers('train/valid/test-data-iterators-setup').start()
    if args.virtual_pipeline_model_parallel_size is not None:
        # One (train, valid, test) iterator triple per model chunk.
        all_data_iterators = [
            build_train_valid_test_data_iterators(train_valid_test_dataset_provider)
            for _ in range(len(model))
        ]
        train_data_iterator = [data_iterators[0] for data_iterators in all_data_iterators]
        valid_data_iterator = [data_iterators[1] for data_iterators in all_data_iterators]
        test_data_iterator = [data_iterators[2] for data_iterators in all_data_iterators]
    else:
        train_data_iterator, valid_data_iterator, test_data_iterator \
            = build_train_valid_test_data_iterators(
                train_valid_test_dataset_provider)
    timers('train/valid/test-data-iterators-setup').stop()
    print_datetime('after dataloaders are built')

    # Print setup timing.
    print_rank_0('done with setup ...')
    timers.log(['model-and-optimizer-setup', 'train/valid/test-data-iterators-setup'])
    print_rank_0('training ...')

    iteration = 0
    if args.do_train and args.train_iters > 0:
        iteration = train(forward_step_func,
                          model, optimizer, lr_scheduler,
                          train_data_iterator, valid_data_iterator,
                          process_non_loss_data_func)
    print_datetime('after training is done')

    if args.do_valid:
        prefix = 'the end of training for val data'
        evaluate_and_print_results(prefix, forward_step_func,
                                   valid_data_iterator, model,
                                   iteration, process_non_loss_data_func,
                                   False)

    # Only save when at least one iteration is reported by `train`.
    if args.save and iteration != 0:
        save_checkpoint(iteration, model, optimizer, lr_scheduler)

    if args.do_test:
        # Run on test data.
        prefix = 'the end of training for test data'
        evaluate_and_print_results(prefix, forward_step_func,
                                   test_data_iterator, model,
                                   0, process_non_loss_data_func,
                                   True)
170

171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
def update_train_iters(args):
    """Set `args.train_iters` from `args.train_samples` when it is not set.

    Arguments:
        args: the arguments object. Reads `train_iters`, `train_samples`,
            `rampup_batch_size` and `global_batch_size`; writes
            `train_iters`.

    When `args.train_iters` is already truthy the function returns without
    changing anything and without printing.
    """

    # For iteration-based training, we don't need to do anything
    if args.train_iters:
        return

    # Constant batch size with sample-based training.
    if args.rampup_batch_size is None:
        args.train_iters = args.train_samples // args.global_batch_size

    else:
        # Sample based training with rampup batch size.
        iterations = 0
        consumed_samples = 0
        # Rampup phase: step through the schedule one iteration at a time
        # until the number of samples given by rampup_batch_size[2] is
        # exceeded.
        while consumed_samples <= int(args.rampup_batch_size[2]):
            update_num_microbatches(consumed_samples, consistency_check=False)
            consumed_samples += get_current_global_batch_size()
            iterations += 1
        # Reset
        update_num_microbatches(0, consistency_check=False)
        # Constant phase
        # Note that we throw away any partial last batch.
        iterations += (args.train_samples - consumed_samples) // \
                      args.global_batch_size
        args.train_iters = iterations

    print_rank_0('setting training iterations to {}'.format(args.train_iters))

200

201
def get_model(model_provider_func, model_type=ModelType.encoder_or_decoder, wrap_with_ddp=True):
    """Build the model.

    Arguments:
        model_provider_func: callable invoked with `pre_process` and
            `post_process` keyword arguments (plus `add_encoder` and
            `add_decoder` when `model_type` is
            `ModelType.encoder_and_decoder`); returns one model.
        model_type: a `ModelType` value. It is stored on `args.model_type`
            and on every built model as `model_type`.
        wrap_with_ddp: when true, each model is wrapped according to
            `args.DDP_impl` ('torch' or 'local').

    Returns:
        A list of models: one entry per virtual pipeline rank when the
        interleaved schedule is active, otherwise a single-element list.

    Raises:
        NotImplementedError: if `wrap_with_ddp` is true and `args.DDP_impl`
            is neither 'torch' nor 'local'.
    """
    args = get_args()
    args.model_type = model_type

    # Build model.
    if mpu.get_pipeline_model_parallel_world_size() > 1 and \
       args.virtual_pipeline_model_parallel_size is not None:
        assert model_type != ModelType.encoder_and_decoder, \
            "Interleaved schedule not supported for model with both encoder and decoder"
        model = []
        for i in range(args.virtual_pipeline_model_parallel_size):
            mpu.set_virtual_pipeline_model_parallel_rank(i)
            # Set pre_process and post_process only after virtual rank is set.
            pre_process = mpu.is_pipeline_first_stage()
            post_process = mpu.is_pipeline_last_stage()
            this_model = model_provider_func(
                pre_process=pre_process,
                post_process=post_process
            )
            this_model.model_type = model_type
            model.append(this_model)
    else:
        pre_process = mpu.is_pipeline_first_stage()
        post_process = mpu.is_pipeline_last_stage()
        add_encoder = True
        add_decoder = True
        if model_type == ModelType.encoder_and_decoder:
            if mpu.get_pipeline_model_parallel_world_size() > 1:
                assert args.pipeline_model_parallel_split_rank is not None, \
                    "Split rank needs to be specified for model with both encoder and decoder"
                rank = mpu.get_pipeline_model_parallel_rank()
                split_rank = args.pipeline_model_parallel_split_rank
                world_size = mpu.get_pipeline_model_parallel_world_size()
                # Ranks 0 and split_rank pre-process; the ranks just before
                # the split and at the end of the pipeline post-process.
                pre_process = rank == 0 or rank == split_rank
                post_process = (rank == (split_rank - 1)) or (
                        rank == (world_size - 1))
                add_encoder = mpu.is_pipeline_stage_before_split()
                add_decoder = mpu.is_pipeline_stage_after_split()
            model = model_provider_func(
                pre_process=pre_process,
                post_process=post_process,
                add_encoder=add_encoder,
                add_decoder=add_decoder)
        else:
            model = model_provider_func(
                pre_process=pre_process,
                post_process=post_process
            )
        model.model_type = model_type

    # From here on the model is always handled as a list.
    if not isinstance(model, list):
        model = [model]

    # Set tensor model parallel attributes if not set.
    # Only parameters that are already tensor model parallel have these
    # attributes set for them. We should make sure the default attributes
    # are set for all params so the optimizer can use them.
    for model_module in model:
        for param in model_module.parameters():
            mpu.set_defaults_if_not_set_tensor_model_parallel_attributes(param)

    # Print number of parameters.
    if mpu.get_data_parallel_rank() == 0:
        num_parameters = sum(
            sum(p.nelement() for p in model_module.parameters())
            for model_module in model)
        print(' > number of parameters on (tensor, pipeline) '
              'model parallel rank ({}, {}): {}'.format(
                  mpu.get_tensor_model_parallel_rank(),
                  mpu.get_pipeline_model_parallel_rank(),
                  num_parameters), flush=True)

    # GPU allocation.
    for model_module in model:
        model_module.cuda(torch.cuda.current_device())

    # Fp16 conversion.
    if args.fp16 or args.bf16:
        model = [Float16Module(model_module, args) for model_module in model]

    if wrap_with_ddp:
        if args.DDP_impl == 'torch':
            i = torch.cuda.current_device()
            model = [torchDDP(model_module, device_ids=[i], output_device=i,
                              process_group=mpu.get_data_parallel_group())
                     for model_module in model]

        elif args.DDP_impl == 'local':
            model = [LocalDDP(model_module,
                              args.accumulate_allreduce_grads_in_fp32,
                              args.use_contiguous_buffers_in_local_ddp)
                     for model_module in model]
            # broadcast params from data parallel src rank to other data parallel ranks
            if args.data_parallel_random_init:
                for model_module in model:
                    model_module.broadcast_params()
        else:
            raise NotImplementedError('Unknown DDP implementation specified: '
                                      '{}. Exiting.'.format(args.DDP_impl))

    return model
301
302


Mohammad's avatar
Mohammad committed
303
def get_learning_rate_scheduler(optimizer):
    """Build the learning rate scheduler.

    Warmup and decay lengths are passed to `AnnealingLR` in the same unit in
    both modes: iteration counts are multiplied by `args.global_batch_size`,
    sample counts are used as they are.

    Arguments:
        optimizer: the optimizer handed to `AnnealingLR`.

    Returns:
        The constructed `AnnealingLR` instance.

    Raises:
        Exception: if neither `args.train_iters` nor `args.train_samples`
            is set.

    Side effects:
        May fill in `args.lr_decay_iters` or `args.lr_decay_samples` when
        they are None, and calls `update_train_iters(args)` for sample-based
        training.
    """
    args = get_args()

    # Iteration-based training.
    if args.train_iters:
        if args.lr_decay_iters is None:
            args.lr_decay_iters = args.train_iters
        decay_steps = args.lr_decay_iters * args.global_batch_size
        if args.lr_warmup_fraction is not None:
            warmup_steps = args.lr_warmup_fraction * decay_steps
        else:
            warmup_steps = args.lr_warmup_iters * args.global_batch_size
    # Sample-based training.
    elif args.train_samples:
        # We need to set training iters for later use. Technically
        # we need to adjust the training samples too (due to last
        # batch being incomplete) but we leave it as is for now.
        update_train_iters(args)
        if args.lr_decay_samples is None:
            args.lr_decay_samples = args.train_samples
        decay_steps = args.lr_decay_samples
        if args.lr_warmup_fraction is not None:
            warmup_steps = args.lr_warmup_fraction * decay_steps
        else:
            warmup_steps = args.lr_warmup_samples
    else:
        raise Exception(
            'either train-iters or train-samples should be provided.')

    lr_scheduler = AnnealingLR(
        optimizer,
        max_lr=args.lr,
        min_lr=args.min_lr,
        warmup_steps=warmup_steps,
        decay_steps=decay_steps,
        decay_style=args.lr_decay_style,
        start_wd=args.start_wd,
        end_wd=args.end_wd,
        wd_incr_style=args.wd_incr_style,
        use_checkpoint_lr_scheduler=args.use_checkpoint_lr_scheduler,
        override_lr_scheduler=args.override_lr_scheduler)

    return lr_scheduler


349
350
351
352
353
def setup_model_and_optimizer(model_provider_func,
                              model_type,
                              no_wd_decay_cond=None,
                              scale_lr_cond=None,
                              lr_mult=1.0):
    """Setup model and optimizer.

    Arguments:
        model_provider_func: passed to `get_model` to build the model.
        model_type: passed to `get_model`.
        no_wd_decay_cond: passed through to `get_megatron_optimizer`.
        scale_lr_cond: passed through to `get_megatron_optimizer`.
        lr_mult: passed through to `get_megatron_optimizer`.

    Returns:
        A `(model, optimizer, lr_scheduler)` tuple, where `model` is the
        list returned by `get_model`.

    Side effects:
        Sets `args.iteration` to the value returned by `load_checkpoint`
        when `args.load` is given, and to 0 otherwise.
    """
    args = get_args()

    model = get_model(model_provider_func, model_type)

    # The optimizer is built from the models with the wrappers removed.
    unwrapped_model = unwrap_model(model,
                                   (torchDDP, LocalDDP, Float16Module))
    optimizer = get_megatron_optimizer(unwrapped_model, no_wd_decay_cond,
                                       scale_lr_cond, lr_mult)

    lr_scheduler = get_learning_rate_scheduler(optimizer)

    if args.load is not None:
        timers = get_timers()
        # Extra barrier is added to make sure all ranks report the
        # max time.
        torch.distributed.barrier()
        timers('load-checkpoint').start()
        args.iteration = load_checkpoint(model, optimizer, lr_scheduler)
        torch.distributed.barrier()
        timers('load-checkpoint').stop()
        timers.log(['load-checkpoint'])
    else:
        args.iteration = 0

    # We only support local DDP with multiple micro-batches.
    if len(model) > 1 or mpu.get_pipeline_model_parallel_world_size() > 1:
        assert args.DDP_impl == 'local'

    # get model without FP16 and/or TorchDDP wrappers
    if args.iteration == 0 and len(unwrapped_model) == 1 \
        and hasattr(unwrapped_model[0], 'init_state_dict_from_bert'):
        print_rank_0("Initializing ICT from pretrained BERT model")
        unwrapped_model[0].init_state_dict_from_bert()
        if args.fp16:
            optimizer.reload_model_params()

    return model, optimizer, lr_scheduler


394
395
396
397
398
399
400
def train_step(forward_step_func, data_iterator,
               model, optimizer, lr_scheduler):
    """Single training step.

    Zeroes gradients, runs the forward/backward function returned by
    `get_forward_backward_func`, all-reduces gradients, steps the optimizer
    and, when the optimizer step succeeded, steps the learning rate
    scheduler.

    Arguments:
        forward_step_func: passed to the forward/backward function.
        data_iterator: passed to the forward/backward function.
        model: list of model modules.
        optimizer: the optimizer; its `step()` must return
            `(update_successful, grad_norm, num_zeros_in_grad)`.
        lr_scheduler: stepped with the number of samples in the batch.

    Returns:
        A `(loss_reduced, skipped_iter, grad_norm, num_zeros_in_grad)`
        tuple. `loss_reduced` holds each loss averaged over microbatches on
        the last pipeline stage and is an empty dict on other stages.
        `skipped_iter` is 1 when the optimizer step was not successful and
        0 otherwise.
    """
    args = get_args()
    timers = get_timers()

    # Set grad to zero.
    if args.DDP_impl == 'local' and args.use_contiguous_buffers_in_local_ddp:
        for partition in model:
            partition.zero_grad_buffer()
    optimizer.zero_grad()

    forward_backward_func = get_forward_backward_func()
    losses_reduced = forward_backward_func(
        forward_step_func, data_iterator, model,
        optimizer, timers, forward_only=False)

    # Empty unused memory
    if args.empty_unused_memory_level >= 1:
        torch.cuda.empty_cache()

    # All-reduce if needed.
    if args.DDP_impl == 'local':
        timers('backward-params-all-reduce').start()
        for model_module in model:
            model_module.allreduce_gradients()
        timers('backward-params-all-reduce').stop()

    # All-reduce word_embeddings' grad across first and last stages to ensure
    # that word_embeddings parameters stay in sync.
    # This should only run for models that support pipelined model parallelism
    # (BERT and GPT-2).
    timers('backward-embedding-all-reduce').start()
    if mpu.is_rank_in_embedding_group(ignore_virtual=True) and \
            mpu.get_pipeline_model_parallel_world_size() > 1:
        if mpu.is_pipeline_first_stage(ignore_virtual=True):
            unwrapped_model = model[0]
        elif mpu.is_pipeline_last_stage(ignore_virtual=True):
            unwrapped_model = model[-1]
        else:  # We do not support the interleaved schedule for T5 yet.
            unwrapped_model = model[0]
        unwrapped_model = unwrap_model(
            unwrapped_model, (torchDDP, LocalDDP, Float16Module))

        if unwrapped_model.share_word_embeddings:
            word_embeddings_weight = unwrapped_model.word_embeddings_weight()
            # Local DDP keeps the gradient in `main_grad`; otherwise use
            # the regular `.grad` attribute.
            if args.DDP_impl == 'local':
                grad = word_embeddings_weight.main_grad
            else:
                grad = word_embeddings_weight.grad
            torch.distributed.all_reduce(grad, group=mpu.get_embedding_group())

    # All-reduce position_embeddings grad across first (encoder) and split (decoder)
    # stages to ensure that position embeddings parameters stay in sync.
    # This should only run for T5 models with pipeline parallelism
    if mpu.is_rank_in_position_embedding_group() and \
            mpu.get_pipeline_model_parallel_world_size() > 1 and \
            args.pipeline_model_parallel_split_rank is not None:
        unwrapped_model = model[0]
        unwrapped_model = unwrap_model(
            unwrapped_model, (torchDDP, LocalDDP, Float16Module))
        assert args.DDP_impl == 'local', \
            'T5 model is only supported with local DDP mode'
        grad = unwrapped_model.language_model.embedding.position_embeddings.weight.main_grad
        torch.distributed.all_reduce(grad, group=mpu.get_position_embedding_group())
    timers('backward-embedding-all-reduce').stop()

    # Update parameters.
    timers('optimizer').start()
    update_successful, grad_norm, num_zeros_in_grad = optimizer.step()
    timers('optimizer').stop()

    # Update learning rate.
    if update_successful:
        # The scheduler is advanced by the number of samples in this batch.
        increment = get_num_microbatches() * \
                    args.micro_batch_size * \
                    args.data_parallel_size
        lr_scheduler.step(increment=increment)
        skipped_iter = 0
    else:
        skipped_iter = 1

    # Empty unused memory
    if args.empty_unused_memory_level >= 2:
        torch.cuda.empty_cache()

    if mpu.is_pipeline_last_stage(ignore_virtual=True):
        # Average loss across microbatches.
        loss_reduced = {}
        for key in losses_reduced[0]:
            losses_reduced_for_key = [x[key] for x in losses_reduced]
            loss_reduced[key] = sum(losses_reduced_for_key) / len(losses_reduced_for_key)
        return loss_reduced, skipped_iter, grad_norm, num_zeros_in_grad
    return {}, skipped_iter, grad_norm, num_zeros_in_grad
488
489


Mohammad Shoeybi's avatar
Mohammad Shoeybi committed
490
def training_log(loss_dict, total_loss_dict, learning_rate, iteration,
                 loss_scale, report_memory_flag, skipped_iter,
                 grad_norm, params_norm, num_zeros_in_grad):
    """Log training information such as losses, timing, ....

    Arguments:
        loss_dict: losses of the current iteration, keyed by name.
        total_loss_dict: running accumulator that is updated in place. It
            holds the summed losses plus the advanced / skipped / nan
            iteration counters, and is reset every `args.log_interval`
            iterations.
        learning_rate: current learning rate, used for logging only.
        iteration: current iteration number.
        loss_scale: current loss scale, used for logging only.
        report_memory_flag: whether memory usage should still be reported.
        skipped_iter: truthy when the current iteration was skipped.
        grad_norm: gradient norm, or None to leave it out of the logs.
        params_norm: parameter norm, or None to leave it out of the logs.
        num_zeros_in_grad: number of zeros in the gradient, or None to
            leave it out of the logs.

    Returns:
        The updated `report_memory_flag`: False once memory has been
        reported, otherwise the value that was passed in.
    """
    args = get_args()
    timers = get_timers()
    writer = get_tensorboard_writer()

    # Advanced, skipped, and Nan iterations.
    advanced_iters_key = 'advanced iterations'
    skipped_iters_key = 'skipped iterations'
    nan_iters_key = 'nan iterations'
    # Advanced iterations.
    if not skipped_iter:
        total_loss_dict[advanced_iters_key] = total_loss_dict.get(
            advanced_iters_key, 0) + 1
    else:
        if advanced_iters_key not in total_loss_dict:
            total_loss_dict[advanced_iters_key] = 0
    # Skipped iterations.
    total_loss_dict[skipped_iters_key] = total_loss_dict.get(
        skipped_iters_key, 0) + skipped_iter
    # Update losses and set nan iterations
    got_nan = False
    for key in loss_dict:
        if not skipped_iter:
            total_loss_dict[key] = total_loss_dict.get(
                key, torch.cuda.FloatTensor([0.0])) + loss_dict[key]
        else:
            # On a skipped iteration, only record whether any loss was
            # +inf, -inf or NaN.
            value = loss_dict[key].float().sum().item()
            is_nan = not math.isfinite(value)
            got_nan = got_nan or is_nan
    total_loss_dict[nan_iters_key] = total_loss_dict.get(
        nan_iters_key, 0) + int(got_nan)

    # Logging.
    timers_to_log = []

    def add_to_logging(name):
        # Only timers that have actually been created are logged.
        if name in timers.timers:
            timers_to_log.append(name)

    add_to_logging('forward-compute')
    add_to_logging('forward-recv')
    add_to_logging('forward-send')
    add_to_logging('forward-backward-send-forward-backward-recv')
    add_to_logging('backward-compute')
    add_to_logging('backward-recv')
    add_to_logging('backward-send')
    add_to_logging('backward-send-forward-recv')
    add_to_logging('backward-send-backward-recv')
    add_to_logging('backward-params-all-reduce')
    add_to_logging('backward-embedding-all-reduce')
    add_to_logging('optimizer-copy-to-main-grad')
    add_to_logging('optimizer-unscale-and-check-inf')
    add_to_logging('optimizer-clip-main-grad')
    add_to_logging('optimizer-copy-main-to-model-params')
    add_to_logging('optimizer')
    add_to_logging('batch-generator')

    # Calculate batch size.
    batch_size = args.micro_batch_size * args.data_parallel_size * \
        get_num_microbatches()

    total_iterations = total_loss_dict[advanced_iters_key] + \
                       total_loss_dict[skipped_iters_key]

    # Tensorboard values.
    if writer and (iteration % args.tensorboard_log_interval == 0) and \
       is_last_rank():
        if args.log_learning_rate_to_tensorboard:
            writer.add_scalar('learning-rate', learning_rate, iteration)
            writer.add_scalar('learning-rate vs samples', learning_rate,
                              args.consumed_train_samples)
        if args.log_batch_size_to_tensorboard:
            writer.add_scalar('batch-size', batch_size, iteration)
            writer.add_scalar('batch-size vs samples', batch_size,
                              args.consumed_train_samples)
        for key in loss_dict:
            writer.add_scalar(key, loss_dict[key], iteration)
            writer.add_scalar(key + ' vs samples', loss_dict[key],
                              args.consumed_train_samples)
        if args.log_loss_scale_to_tensorboard:
            writer.add_scalar('loss-scale', loss_scale, iteration)
            writer.add_scalar('loss-scale vs samples', loss_scale,
                              args.consumed_train_samples)
        if args.log_world_size_to_tensorboard:
            writer.add_scalar('world-size', args.world_size, iteration)
            writer.add_scalar('world-size vs samples', args.world_size,
                              args.consumed_train_samples)
        if grad_norm is not None:
            writer.add_scalar('grad-norm', grad_norm, iteration)
            writer.add_scalar('grad-norm vs samples', grad_norm,
                              args.consumed_train_samples)
        if num_zeros_in_grad is not None:
            writer.add_scalar('num-zeros', num_zeros_in_grad, iteration)
            writer.add_scalar('num-zeros vs samples', num_zeros_in_grad,
                              args.consumed_train_samples)
        if params_norm is not None:
            writer.add_scalar('params-norm', params_norm, iteration)
            writer.add_scalar('params-norm vs samples', params_norm,
                              args.consumed_train_samples)
        if args.log_timers_to_tensorboard:
            timers.write(timers_to_log, writer, iteration,
                         normalizer=total_iterations)
        if args.log_memory_to_tensorboard:
            mem_stats = torch.cuda.memory_stats()
            writer.add_scalar(
                "mem-reserved-bytes",
                mem_stats["reserved_bytes.all.current"],
                iteration,
            )
            writer.add_scalar(
                "mem-allocated-bytes",
                mem_stats["allocated_bytes.all.current"],
                iteration,
            )
            writer.add_scalar(
                "mem-allocated-count",
                mem_stats["allocation.all.current"],
                iteration,
            )

    if iteration % args.log_interval == 0:
        elapsed_time = timers('interval-time').elapsed()
        elapsed_time_per_iteration = elapsed_time / total_iterations
        if writer:
            if args.log_timers_to_tensorboard:
                writer.add_scalar('iteration-time',
                                  elapsed_time_per_iteration, iteration)
        log_string = ' iteration {:8d}/{:8d} |'.format(
            iteration, args.train_iters)
        log_string += ' consumed samples: {:12d} |'.format(
            args.consumed_train_samples)
        log_string += ' elapsed time per iteration (ms): {:.1f} |'.format(
            elapsed_time_per_iteration * 1000.0)
        log_string += ' learning rate: {:.3E} |'.format(learning_rate)
        log_string += ' global batch size: {:5d} |'.format(batch_size)
        for key in total_loss_dict:
            if key not in [advanced_iters_key, skipped_iters_key,
                           nan_iters_key]:
                # Average over the iterations that actually advanced.
                avg = total_loss_dict[key].item() / \
                      float(max(1, total_loss_dict[advanced_iters_key]))
                if avg > 0.0:
                    log_string += ' {}: {:.6E} |'.format(key, avg)
                total_loss_dict[key] = torch.cuda.FloatTensor([0.0])
        log_string += ' loss scale: {:.1f} |'.format(loss_scale)
        if grad_norm is not None:
            log_string += ' grad norm: {:.3f} |'.format(grad_norm)
        if num_zeros_in_grad is not None:
            log_string += ' num zeros: {:.1f} |'.format(num_zeros_in_grad)
        if params_norm is not None:
            log_string += ' params norm: {:.3f} |'.format(params_norm)
        log_string += ' number of skipped iterations: {:3d} |'.format(
            total_loss_dict[skipped_iters_key])
        log_string += ' number of nan iterations: {:3d} |'.format(
            total_loss_dict[nan_iters_key])
        # Start a new logging interval.
        total_loss_dict[advanced_iters_key] = 0
        total_loss_dict[skipped_iters_key] = 0
        total_loss_dict[nan_iters_key] = 0
        print_rank_last(log_string)
        if report_memory_flag and learning_rate > 0.:
            # Report memory after optimizer state has been initialized.
            report_memory('(after {} iterations)'.format(iteration))
            report_memory_flag = False
        timers.log(timers_to_log, normalizer=args.log_interval)

    return report_memory_flag


661
662
663
664
665
def save_checkpoint_and_time(iteration, model, optimizer, lr_scheduler):
    """Save a checkpoint and log how long the save took.

    The call to ``save_checkpoint`` is bracketed by
    ``torch.distributed.barrier()`` calls, and the elapsed time is recorded
    under the 'save-checkpoint' timer and logged afterwards.

    Arguments are forwarded unchanged to ``save_checkpoint``.
    """
    timer_name = 'save-checkpoint'
    timers = get_timers()

    # Extra barrier so that all ranks report the max time.
    torch.distributed.barrier()
    timers(timer_name).start()
    save_checkpoint(iteration, model, optimizer, lr_scheduler)

    # Wait for every rank to finish saving before stopping the timer.
    torch.distributed.barrier()
    timers(timer_name).stop()
    timers.log([timer_name])
671
672


673
def train(forward_step_func, model, optimizer, lr_scheduler,
          train_data_iterator, valid_data_iterator,
          process_non_loss_data_func):
    """Run the training loop and return the last completed iteration.

    Starting from ``args.iteration``, one ``train_step`` is executed per
    loop pass until ``args.train_iters`` is reached. After every step the
    loop, in this order: logs, checks for autoresume termination, runs
    periodic validation, handles a received exit signal, saves a periodic
    checkpoint, and checks the duration-based and iteration-based exit
    conditions.

    Note that the signal, duration and iteration exits call ``sys.exit()``
    after saving a checkpoint, so this function may not return.

    Arguments:
        forward_step_func: forwarded to ``train_step`` and to
            ``evaluate_and_print_results``.
        model: iterable of model modules; each is put in train mode.
        optimizer: optimizer; also queried for loss scale and learning rate.
        lr_scheduler: forwarded to ``train_step`` and to checkpointing.
        train_data_iterator: forwarded to ``train_step``.
        valid_data_iterator: forwarded to ``evaluate_and_print_results``.
        process_non_loss_data_func: forwarded to
            ``evaluate_and_print_results``.

    Returns:
        The iteration count reached when the loop ends.
    """
    args = get_args()
    timers = get_timers()

    # Write args to tensorboard
    write_args_to_tensorboard()

    # Turn on training mode which enables dropout.
    for module in model:
        module.train()

    # Running loss totals (handed to training_log) and iteration counter.
    total_loss_dict = {}
    iteration = args.iteration

    timers('interval-time').start()
    print_datetime('before the start of training step')
    report_memory_flag = True
    while iteration < args.train_iters:
        update_num_microbatches(args.consumed_train_samples)
        step_outputs = train_step(forward_step_func,
                                  train_data_iterator,
                                  model,
                                  optimizer,
                                  lr_scheduler)
        loss_dict, skipped_iter, grad_norm, num_zeros_in_grad = step_outputs
        iteration += 1
        samples_this_step = (mpu.get_data_parallel_world_size()
                             * args.micro_batch_size
                             * get_num_microbatches())
        args.consumed_train_samples += samples_this_step

        # Logging.
        loss_scale = optimizer.get_loss_scale().item()
        if args.log_params_norm:
            params_norm = calc_params_l2_norm(model)
        else:
            params_norm = None
        learning_rate = optimizer.param_groups[0]['lr']
        report_memory_flag = training_log(
            loss_dict, total_loss_dict, learning_rate, iteration,
            loss_scale, report_memory_flag, skipped_iter, grad_norm,
            params_norm, num_zeros_in_grad)

        # Autoresume
        if (args.adlr_autoresume
                and iteration % args.adlr_autoresume_interval == 0):
            check_adlr_autoresume_termination(iteration, model, optimizer,
                                              lr_scheduler)

        # Evaluation
        if (args.eval_interval
                and iteration % args.eval_interval == 0
                and args.do_valid):
            prefix = 'iteration {}'.format(iteration)
            evaluate_and_print_results(prefix, forward_step_func,
                                       valid_data_iterator, model,
                                       iteration, process_non_loss_data_func,
                                       False)

        # Checkpointing
        saved_checkpoint = False

        # A received signal triggers a checkpoint followed by exit.
        if args.exit_signal_handler:
            signal_handler = get_signal_handler()
            if any(signal_handler.signals_received()):
                save_checkpoint_and_time(iteration, model, optimizer,
                                         lr_scheduler)
                print_datetime('exiting program after receiving SIGTERM.')
                sys.exit()

        # Periodic checkpoint.
        if (args.save and args.save_interval
                and iteration % args.save_interval == 0):
            save_checkpoint_and_time(iteration, model, optimizer,
                                     lr_scheduler)
            saved_checkpoint = True

        # Exiting based on duration
        if args.exit_duration_in_mins:
            train_time = (time.time() - _TRAIN_START_TIME) / 60.0
            over_budget = train_time > args.exit_duration_in_mins
            done_cuda = torch.cuda.IntTensor([over_budget])
            # MAX-reduce so the flag is set if any rank is over budget.
            torch.distributed.all_reduce(
                done_cuda, op=torch.distributed.ReduceOp.MAX)
            if done_cuda.item():
                if not saved_checkpoint:
                    save_checkpoint_and_time(iteration, model, optimizer,
                                             lr_scheduler)
                print_datetime(
                    'exiting program after {} minutes'.format(train_time))
                sys.exit()

        # Exiting based on iterations
        if args.exit_interval and iteration % args.exit_interval == 0:
            if not saved_checkpoint:
                save_checkpoint_and_time(iteration, model, optimizer,
                                         lr_scheduler)
            torch.distributed.barrier()
            print_datetime('exiting program at iteration {}'.format(iteration))
            sys.exit()

    return iteration
777
778


779
780
781
782
783
def evaluate(forward_step_func,
             data_iterator,
             model,
             process_non_loss_data_func,
             verbose=False):
    """Run ``args.eval_iters`` forward-only passes and average the losses.

    Each module in ``model`` is switched to eval mode for the duration of
    the evaluation and back to train mode before returning. Losses are
    accumulated only on the last pipeline stage.

    Arguments:
        forward_step_func: forwarded to the forward/backward function.
        data_iterator: forwarded to the forward/backward function.
        model: iterable of model modules.
        process_non_loss_data_func: only compared against None here; when
            it is not None and this is the last rank, one additional
            forward-only pass is run with ``collect_non_loss_data=True``.
        verbose: if True, print progress every ``args.log_interval``
            evaluation iterations.

    Returns:
        A tuple ``(total_loss_dict, collected_non_loss_data)``.
        ``total_loss_dict`` maps each loss key to its accumulated value
        divided by ``args.eval_iters * get_num_microbatches()``;
        ``collected_non_loss_data`` is None unless the extra pass ran.
    """
    args = get_args()

    # Turn on evaluation mode which disables dropout.
    for model_module in model:
        model_module.eval()

    total_loss_dict = {}

    # Looked up once, before the loop: it is loop-invariant, and the
    # non-loss-data pass below needs it even when args.eval_iters is 0
    # (binding it inside the loop left the name undefined in that case).
    forward_backward_func = get_forward_backward_func()

    with torch.no_grad():
        iteration = 0
        while iteration < args.eval_iters:
            iteration += 1
            if verbose and iteration % args.log_interval == 0:
                print_rank_0('Evaluating iter {}/{}'.format(iteration,
                                                            args.eval_iters))

            loss_dicts = forward_backward_func(
                forward_step_func, data_iterator, model, optimizer=None,
                timers=None, forward_only=True)

            # Empty unused memory
            if args.empty_unused_memory_level >= 1:
                torch.cuda.empty_cache()

            if mpu.is_pipeline_last_stage(ignore_virtual=True):
                # Accumulate the per-microbatch losses of this iteration.
                for loss_dict in loss_dicts:
                    for key in loss_dict:
                        total_loss_dict[key] = total_loss_dict.get(
                            key, torch.cuda.FloatTensor([0.0])) + loss_dict[key]

            args.consumed_valid_samples += mpu.get_data_parallel_world_size() \
                                           * args.micro_batch_size \
                                           * get_num_microbatches()

        collected_non_loss_data = None
        if process_non_loss_data_func is not None and is_last_rank():
            collected_non_loss_data = forward_backward_func(
                forward_step_func, data_iterator, model, optimizer=None,
                timers=None, forward_only=True, collect_non_loss_data=True)

    # Move model back to the train mode.
    for model_module in model:
        model_module.train()

    # Turn the accumulated sums into per-microbatch averages. The dict is
    # empty when no iteration ran, so there is no division by zero.
    for key in total_loss_dict:
        total_loss_dict[key] /= args.eval_iters * get_num_microbatches()

    return total_loss_dict, collected_non_loss_data
834
835
836

def evaluate_and_print_results(prefix, forward_step_func,
                               data_iterator, model,
                               iteration, process_non_loss_data_func,
                               verbose=False):
    """Evaluate the model, then report the losses on screen and tensorboard.

    For every key returned by ``evaluate`` the loss value and its
    perplexity (``exp`` of the loss, with the exponent capped at 20) are
    appended to a summary line printed on the last rank. If a tensorboard
    writer is available the same values are logged against both
    ``iteration`` and ``args.consumed_train_samples``.

    Arguments:
        prefix: text inserted into the summary line header.
        forward_step_func, data_iterator, model, verbose: forwarded to
            ``evaluate``.
        iteration: step value used for the tensorboard scalars.
        process_non_loss_data_func: forwarded to ``evaluate``; when it is
            not None, a writer exists and this is the last rank, it is
            called with the collected non-loss data, ``iteration`` and the
            writer.
    """
    args = get_args()
    writer = get_tensorboard_writer()

    total_loss_dict, collected_non_loss_data = evaluate(
        forward_step_func, data_iterator, model,
        process_non_loss_data_func, verbose)

    string = ' validation loss at {} | '.format(prefix)
    for key, loss_tensor in total_loss_dict.items():
        loss_value = loss_tensor.item()
        # Cap the exponent so the perplexity cannot overflow.
        ppl = math.exp(min(20, loss_value))
        string += '{} value: {:.6E} | '.format(key, loss_value)
        string += '{} PPL: {:.6E} | '.format(key, ppl)

        if not writer:
            continue

        samples = args.consumed_train_samples
        writer.add_scalar('{} validation'.format(key), loss_value, iteration)
        writer.add_scalar('{} validation vs samples'.format(key),
                          loss_value, samples)
        if args.log_validation_ppl_to_tensorboard:
            writer.add_scalar('{} validation ppl'.format(key), ppl, iteration)
            writer.add_scalar('{} validation ppl vs samples'.format(key),
                              ppl, samples)

    if process_non_loss_data_func is not None and writer and is_last_rank():
        process_non_loss_data_func(collected_non_loss_data, iteration, writer)

    # Frame the summary line with dashed rules one character longer than it.
    rule = '-' * (len(string) + 1)
    print_rank_last(rule)
    print_rank_last(string)
    print_rank_last(rule)
871
872


Vijay Korthikanti's avatar
Vijay Korthikanti committed
873
def cyclic_iter(iter):
    """Yield the items of `iter` over and over, without ever stopping.

    `iter` is re-iterated from the start each time it is exhausted, so it
    must be a re-iterable object rather than a one-shot iterator.
    """
    while True:
        yield from iter

878
879
880
def _make_data_iterator(dataloader, dl_type):
    """Return an iterator over `dataloader`, or None if there is none.

    A 'single' dataloader is iterated once; any other type is wrapped in
    `cyclic_iter` so that it restarts whenever it is exhausted.
    """
    if dataloader is None:
        return None
    if dl_type == 'single':
        return iter(dataloader)
    return iter(cyclic_iter(dataloader))


def build_train_valid_test_data_iterators(
        build_train_valid_test_datasets_provider):
    """Build the train, validation and test data iterators.

    Datasets and dataloaders are built only on tensor-model-parallel
    rank 0; the resulting do_train/do_valid/do_test flags are then
    broadcast to the rest of the tensor-model-parallel group.

    Side effects on the global args:
        * ``consumed_train_samples`` / ``consumed_valid_samples`` are
          back-filled from ``args.iteration`` when resuming with both
          still at zero.
        * ``do_train``, ``do_valid`` and ``do_test`` are set from the
          broadcast flags.

    Arguments:
        build_train_valid_test_datasets_provider: callable that receives
            the list ``[train, validation, test]`` of target sample counts
            and returns the three datasets.

    Returns:
        A tuple ``(train_data_iterator, valid_data_iterator,
        test_data_iterator)``; an entry is None when no dataloader was
        built for it on this rank.
    """
    args = get_args()

    (train_dataloader, valid_dataloader, test_dataloader) = (None, None, None)

    print_rank_0('> building train, validation, and test datasets ...')

    # Backward compatibility, assume fixed batch size.
    if args.iteration > 0 and args.consumed_train_samples == 0:
        assert args.train_samples is None, \
            'only backward compatiblity support for iteration-based training'
        args.consumed_train_samples = args.iteration * args.global_batch_size
    if args.iteration > 0 and args.consumed_valid_samples == 0:
        if args.train_samples is None:
            # train() treats a falsy eval_interval as "no periodic
            # evaluation"; apply the same guard instead of dividing by it.
            past_evals = (args.iteration // args.eval_interval
                          if args.eval_interval else 0)
            args.consumed_valid_samples = past_evals * \
                args.eval_iters * args.global_batch_size

    # Data loader only on rank 0 of each model parallel group.
    if mpu.get_tensor_model_parallel_rank() == 0:

        # Number of train/valid/test samples.
        if args.train_samples:
            train_samples = args.train_samples
        else:
            train_samples = args.train_iters * args.global_batch_size
        # One evaluation per eval_interval, plus one extra pass
        # (presumably the end-of-training evaluation -- confirm with caller).
        planned_evals = (args.train_iters // args.eval_interval
                         if args.eval_interval else 0) + 1
        eval_iters = planned_evals * args.eval_iters
        test_iters = args.eval_iters
        train_val_test_num_samples = [train_samples,
                                      eval_iters * args.global_batch_size,
                                      test_iters * args.global_batch_size]
        print_rank_0(' > datasets target sizes (minimum size):')
        print_rank_0('    train:      {}'.format(train_val_test_num_samples[0]))
        print_rank_0('    validation: {}'.format(train_val_test_num_samples[1]))
        print_rank_0('    test:       {}'.format(train_val_test_num_samples[2]))

        # Build the datasets.
        train_ds, valid_ds, test_ds = build_train_valid_test_datasets_provider(
            train_val_test_num_samples)

        # Build dataloaders, resuming train/valid from the consumed samples.
        train_dataloader = build_pretraining_data_loader(
            train_ds, args.consumed_train_samples)
        valid_dataloader = build_pretraining_data_loader(
            valid_ds, args.consumed_valid_samples)
        test_dataloader = build_pretraining_data_loader(test_ds, 0)

        # Flags to know if we need to do training/validation/testing.
        do_train = train_dataloader is not None and args.train_iters > 0
        do_valid = valid_dataloader is not None and args.eval_iters > 0
        do_test = test_dataloader is not None and args.eval_iters > 0
        # Pack the flags into a tensor so they can be broadcast.
        flags = torch.cuda.LongTensor(
            [int(do_train), int(do_valid), int(do_test)])
    else:
        flags = torch.cuda.LongTensor([0, 0, 0])

    # Broadcast the do_train/do_valid/do_test flags within the tensor
    # model parallel group.
    torch.distributed.broadcast(flags,
                                mpu.get_tensor_model_parallel_src_rank(),
                                group=mpu.get_tensor_model_parallel_group())
    args.do_train = flags[0].item()
    args.do_valid = flags[1].item()
    args.do_test = flags[2].item()

    # Build iterators.
    dl_type = args.dataloader_type
    assert dl_type in ['single', 'cyclic']

    train_data_iterator = _make_data_iterator(train_dataloader, dl_type)
    valid_data_iterator = _make_data_iterator(valid_dataloader, dl_type)
    test_data_iterator = _make_data_iterator(test_dataloader, dl_type)

    return train_data_iterator, valid_data_iterator, test_data_iterator