training.py 40.3 KB
Newer Older
1
# coding=utf-8
Mohammad's avatar
Mohammad committed
2
# Copyright (c) 2020, NVIDIA CORPORATION.  All rights reserved.
3
4
5
6
7
8
9
10
11
12
13
14
15
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

Mohammad's avatar
Mohammad committed
16
"""Pretrain utilities."""
17
18
19

from datetime import datetime
import math
Mohammad's avatar
Mohammad committed
20
import sys
21
22
23
24
import time
# The earliest we can measure the start time.
_TRAIN_START_TIME = time.time()

25
26
27
import torch
from torch.nn.parallel.distributed import DistributedDataParallel as torchDDP

Neel Kant's avatar
Neel Kant committed
28
from megatron import get_args
29
from megatron import get_signal_handler
Mohammad's avatar
Mohammad committed
30
31
from megatron import get_timers
from megatron import get_tensorboard_writer
32
from megatron import get_current_global_batch_size
mohammad's avatar
mohammad committed
33
from megatron import get_num_microbatches
mohammad's avatar
mohammad committed
34
from megatron import is_last_rank
mohammad's avatar
mohammad committed
35
from megatron import update_num_microbatches
36
from megatron import mpu
Neel Kant's avatar
Neel Kant committed
37
from megatron import print_rank_0
38
from megatron import print_rank_last
Mohammad's avatar
Mohammad committed
39
40
from megatron.checkpointing import load_checkpoint
from megatron.checkpointing import save_checkpoint
Mohammad Shoeybi's avatar
Mohammad Shoeybi committed
41
from megatron.model import Float16Module
42
from megatron.model import ModelType
mohammad's avatar
mohammad committed
43
from megatron.optimizer import get_megatron_optimizer
Mohammad's avatar
Mohammad committed
44
from megatron.initialize import initialize_megatron
45
from megatron.initialize import write_args_to_tensorboard
46
47
48
from megatron.learning_rates import AnnealingLR
from megatron.model import DistributedDataParallel as LocalDDP
from megatron.utils import check_adlr_autoresume_termination
49
from megatron.utils import unwrap_model
Vijay Korthikanti's avatar
Vijay Korthikanti committed
50
from megatron.data.data_samplers import build_pretraining_data_loader
mohammad's avatar
mohammad committed
51
from megatron.utils import calc_params_l2_norm
52
from megatron.schedules import get_forward_backward_func
53
from megatron.utils import report_memory
54
55


Mohammad Shoeybi's avatar
Mohammad Shoeybi committed
56

57
58
59
60
61
62
63
def print_datetime(string):
    """Print ``string`` tagged with the current local date and time.

    ``torch.distributed.barrier()`` is called first, so every rank must call
    this function (it synchronizes all ranks). The message is emitted through
    ``print_rank_0``.
    """
    torch.distributed.barrier()
    stamp = datetime.now().strftime('%Y-%m-%d %H:%M:%S')
    message = '[' + string + '] datetime: ' + stamp + ' '
    print_rank_0(message)


64
def pretrain(train_valid_test_dataset_provider,
             model_provider,
             model_type,
             forward_step_func,
             process_non_loss_data_func=None,
             extra_args_provider=None,
             args_defaults=None):
    """Main training program.

    This function runs the following steps in order:
        1) initialize Megatron.
        2) set up the model, optimizer and lr schedule using model_provider.
        3) call train_valid_test_dataset_provider to get train/valid/test
           datasets.
        4) train the model using forward_step_func.
    Afterwards it optionally evaluates on the validation data
    (``args.do_valid``), saves a checkpoint (``args.save`` set and training
    ran), and evaluates on the test data (``args.do_test``).

    Arguments:
        train_valid_test_dataset_provider: a function that takes the size of
            train/valid/test dataset and returns `train, valid, test` datasets.
        model_provider: a function that returns a vanilla version of the
            model. By vanilla we mean a simple model on cpu with no fp16 or ddp.
        model_type: an enum that specifies the type of model being trained.
        forward_step_func: a function that takes a `data iterator` and `model`,
            and returns a `loss` scalar with a dictionary with key:values being
            the info we would like to monitor during training, for example
            `lm-loss: value`. We also require that this function add
            `batch generator` to the timers class.
        process_non_loss_data_func: a function to post process outputs of the
            network. It can be used for dumping output tensors (e.g. images) to
            tensorboard. It takes `collected data` (list of tensors),
            `current iteration index` and `tensorboard writer` as arguments.
        extra_args_provider: a function that takes a parser and adds arguments
            to it. It is used for programs to add their own arguments.
        args_defaults: a dictionary from argument-name to argument-value,
            forwarded to ``initialize_megatron`` to set argument defaults.
            ``None`` (the default) is treated as an empty dictionary.
    """
    # Use a fresh dict per call instead of a shared mutable default argument.
    if args_defaults is None:
        args_defaults = {}

    # Initialize and get arguments, timers, and Tensorboard writer.
    initialize_megatron(extra_args_provider=extra_args_provider,
                        args_defaults=args_defaults)

    # Adjust the startup time so the reported initialization time is the
    # largest one: take the earliest start time over all ranks (MIN).
    # This will be closer to what the scheduler will see (outside of
    # image ... launches).
    global _TRAIN_START_TIME
    start_time_tensor = torch.cuda.DoubleTensor([_TRAIN_START_TIME])
    torch.distributed.all_reduce(start_time_tensor,
                                 op=torch.distributed.ReduceOp.MIN)
    _TRAIN_START_TIME = start_time_tensor.item()
    print_rank_0('time to initialize megatron (seconds): {:.3f}'.format(
        time.time() - _TRAIN_START_TIME))
    print_datetime('after megatron is initialized')

    args = get_args()
    timers = get_timers()

    # Model, optimizer, and learning rate.
    timers('model-and-optimizer-setup').start()
    model, optimizer, lr_scheduler = setup_model_and_optimizer(model_provider,
                                                               model_type)
    timers('model-and-optimizer-setup').stop()
    print_datetime('after model, optimizer, and learning rate '
                   'scheduler are built')

    # Data stuff.
    timers('train/valid/test-data-iterators-setup').start()
    if args.virtual_pipeline_model_parallel_size is not None:
        # One (train, valid, test) iterator triple per model chunk.
        all_data_iterators = [
            build_train_valid_test_data_iterators(train_valid_test_dataset_provider)
            for _ in range(len(model))
        ]
        train_data_iterator = [data_iterators[0] for data_iterators in all_data_iterators]
        valid_data_iterator = [data_iterators[1] for data_iterators in all_data_iterators]
        test_data_iterator = [data_iterators[2] for data_iterators in all_data_iterators]
    else:
        train_data_iterator, valid_data_iterator, test_data_iterator \
            = build_train_valid_test_data_iterators(
                train_valid_test_dataset_provider)
    timers('train/valid/test-data-iterators-setup').stop()
    print_datetime('after dataloaders are built')

    # Print setup timing.
    print_rank_0('done with setup ...')
    timers.log(['model-and-optimizer-setup', 'train/valid/test-data-iterators-setup'])
    print_rank_0('training ...')

    iteration = 0
    if args.do_train and args.train_iters > 0:
        iteration = train(forward_step_func,
                          model, optimizer, lr_scheduler,
                          train_data_iterator, valid_data_iterator,
                          process_non_loss_data_func)
    print_datetime('after training is done')

    if args.do_valid:
        prefix = 'the end of training for val data'
        evaluate_and_print_results(prefix, forward_step_func,
                                   valid_data_iterator, model,
                                   iteration, process_non_loss_data_func,
                                   False)

    # iteration stays 0 when training was skipped; nothing is saved then.
    if args.save and iteration != 0:
        save_checkpoint(iteration, model, optimizer, lr_scheduler)

    if args.do_test:
        # Run on test data.
        prefix = 'the end of training for test data'
        evaluate_and_print_results(prefix, forward_step_func,
                                   test_data_iterator, model,
                                   0, process_non_loss_data_func,
                                   True)
174

175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
def update_train_iters(args):
    """Set ``args.train_iters`` from ``args.train_samples`` when it is unset.

    If ``args.train_iters`` is already truthy (iteration-based training) the
    function returns immediately without printing anything. Otherwise:
      * without batch-size rampup, ``train_iters`` is
        ``train_samples // global_batch_size``;
      * with rampup, the rampup iterations are counted by stepping the
        microbatch calculator one global batch at a time (consistency checks
        disabled). The calculator is then reset to 0 consumed samples, and the
        remaining samples are divided by the full ``global_batch_size`` (a
        partial last batch is dropped).

    The chosen value is reported through ``print_rank_0``.
    """
    if args.train_iters:
        # Iteration-based training: nothing to derive.
        return

    if args.rampup_batch_size is not None:
        rampup_samples = int(args.rampup_batch_size[2])
        steps = 0
        samples_seen = 0
        # Rampup phase: the global batch size grows with consumed samples.
        while samples_seen <= rampup_samples:
            update_num_microbatches(samples_seen, consistency_check=False)
            samples_seen += get_current_global_batch_size()
            steps += 1
        # Put the microbatch calculator back to its starting state.
        update_num_microbatches(0, consistency_check=False)
        # Constant phase; any partial last batch is thrown away.
        remaining_samples = args.train_samples - samples_seen
        args.train_iters = steps + remaining_samples // args.global_batch_size
    else:
        # Constant batch size with sample-based training.
        args.train_iters = args.train_samples // args.global_batch_size

    print_rank_0('setting training iterations to {}'.format(args.train_iters))

204

205
def get_model(model_provider_func, model_type=ModelType.encoder_or_decoder, wrap_with_ddp=True):
    """Build this rank's model chunk(s) and wrap them for training.

    The raw module(s) are created with ``model_provider_func``. Every
    parameter then gets default tensor-model-parallel attributes, and the
    parameter count is printed on data-parallel rank 0. The modules are moved
    to the current CUDA device and wrapped in ``Float16Module`` when
    ``args.fp16`` or ``args.bf16`` is set. If ``wrap_with_ddp`` is true, they
    are finally wrapped in torch or local DDP according to ``args.DDP_impl``.

    Args:
        model_provider_func: callable taking keyword arguments ``pre_process``
            and ``post_process`` (plus ``add_encoder`` / ``add_decoder`` for
            ``ModelType.encoder_and_decoder``) and returning a module.
        model_type: a ``ModelType``; stored on ``args.model_type`` and on
            every built module.
        wrap_with_ddp: whether to apply the data-parallel wrapper.

    Returns:
        A list of modules: one per virtual pipeline rank when the interleaved
        schedule is active, otherwise a single-element list.

    Raises:
        NotImplementedError: if ``wrap_with_ddp`` is true and
            ``args.DDP_impl`` is neither ``'torch'`` nor ``'local'``.
    """
    args = get_args()
    args.model_type = model_type

    interleaved = (mpu.get_pipeline_model_parallel_world_size() > 1 and
                   args.virtual_pipeline_model_parallel_size is not None)

    if interleaved:
        assert model_type != ModelType.encoder_and_decoder, \
            "Interleaved schedule not supported for model with both encoder and decoder"
        model = []
        for virtual_rank in range(args.virtual_pipeline_model_parallel_size):
            # The stage predicates depend on the virtual rank, so set it first.
            mpu.set_virtual_pipeline_model_parallel_rank(virtual_rank)
            chunk = model_provider_func(
                pre_process=mpu.is_pipeline_first_stage(),
                post_process=mpu.is_pipeline_last_stage())
            chunk.model_type = model_type
            model.append(chunk)
    else:
        pre_process = mpu.is_pipeline_first_stage()
        post_process = mpu.is_pipeline_last_stage()
        if model_type == ModelType.encoder_and_decoder:
            add_encoder = add_decoder = True
            if mpu.get_pipeline_model_parallel_world_size() > 1:
                assert args.pipeline_model_parallel_split_rank is not None, \
                    "Split rank needs to be specified for model with both encoder and decoder"
                rank = mpu.get_pipeline_model_parallel_rank()
                split_rank = args.pipeline_model_parallel_split_rank
                world_size = mpu.get_pipeline_model_parallel_world_size()
                # Stages 0 and split_rank pre-process; stages split_rank - 1
                # and world_size - 1 post-process.
                pre_process = rank in (0, split_rank)
                post_process = rank in (split_rank - 1, world_size - 1)
                add_encoder = mpu.is_pipeline_stage_before_split()
                add_decoder = mpu.is_pipeline_stage_after_split()
            model = model_provider_func(
                pre_process=pre_process,
                post_process=post_process,
                add_encoder=add_encoder,
                add_decoder=add_decoder)
        else:
            model = model_provider_func(
                pre_process=pre_process,
                post_process=post_process)
        model.model_type = model_type

    if not isinstance(model, list):
        model = [model]

    # Only parameters that are already tensor-model-parallel carry these
    # attributes; give every other parameter the defaults so the optimizer
    # can rely on them.
    for module in model:
        for param in module.parameters():
            mpu.set_defaults_if_not_set_tensor_model_parallel_attributes(param)

    if mpu.get_data_parallel_rank() == 0:
        tp_rank = mpu.get_tensor_model_parallel_rank()
        pp_rank = mpu.get_pipeline_model_parallel_rank()
        num_params = sum(p.nelement()
                         for module in model
                         for p in module.parameters())
        print(' > number of parameters on (tensor, pipeline) '
              'model parallel rank ({}, {}): {}'.format(
                  tp_rank, pp_rank, num_params), flush=True)

    # GPU allocation.
    device = torch.cuda.current_device()
    for module in model:
        module.cuda(device)

    # Half-precision (fp16 / bf16) wrapper.
    if args.fp16 or args.bf16:
        model = [Float16Module(module, args) for module in model]

    if not wrap_with_ddp:
        return model

    if args.DDP_impl == 'torch':
        device_id = torch.cuda.current_device()
        model = [torchDDP(module, device_ids=[device_id], output_device=device_id,
                          process_group=mpu.get_data_parallel_group())
                 for module in model]
    elif args.DDP_impl == 'local':
        model = [LocalDDP(module,
                          args.accumulate_allreduce_grads_in_fp32,
                          args.use_contiguous_buffers_in_local_ddp)
                 for module in model]
        # Broadcast params from the data-parallel source rank to the other
        # data-parallel ranks.
        if args.data_parallel_random_init:
            for module in model:
                module.broadcast_params()
    else:
        raise NotImplementedError('Unknown DDP implementation specified: '
                                  '{}. Exiting.'.format(args.DDP_impl))

    return model
305
306


Mohammad's avatar
Mohammad committed
307
def get_learning_rate_scheduler(optimizer):
    """Create the ``AnnealingLR`` scheduler for ``optimizer`` from the args.

    Warmup and decay lengths are passed to the scheduler in samples.
      * Iteration-based training (``args.train_iters`` set): iteration counts
        are multiplied by ``args.global_batch_size``.
      * Sample-based training (``args.train_samples`` set):
        ``update_train_iters`` is called first and the sample counts are used
        as-is.

    ``args.lr_decay_iters`` / ``args.lr_decay_samples`` default to the total
    training length and are written back to ``args``. When
    ``args.lr_warmup_fraction`` is set, it takes precedence over the explicit
    warmup length.

    Raises:
        Exception: if neither ``args.train_iters`` nor ``args.train_samples``
            is set.
    """
    args = get_args()

    if not (args.train_iters or args.train_samples):
        raise Exception(
            'either train-iters or train-samples should be provided.')

    if args.train_iters:
        # Iteration-based training.
        if args.lr_decay_iters is None:
            args.lr_decay_iters = args.train_iters
        decay_steps = args.lr_decay_iters * args.global_batch_size
        if args.lr_warmup_fraction is None:
            warmup_steps = args.lr_warmup_iters * args.global_batch_size
        else:
            warmup_steps = args.lr_warmup_fraction * decay_steps
    else:
        # Sample-based training. args.train_iters is still needed later, so
        # derive it here; train_samples itself is left unadjusted for a
        # possibly incomplete last batch.
        update_train_iters(args)
        if args.lr_decay_samples is None:
            args.lr_decay_samples = args.train_samples
        decay_steps = args.lr_decay_samples
        if args.lr_warmup_fraction is None:
            warmup_steps = args.lr_warmup_samples
        else:
            warmup_steps = args.lr_warmup_fraction * decay_steps

    return AnnealingLR(
        optimizer,
        max_lr=args.lr,
        min_lr=args.min_lr,
        warmup_steps=warmup_steps,
        decay_steps=decay_steps,
        decay_style=args.lr_decay_style,
        start_wd=args.start_wd,
        end_wd=args.end_wd,
        wd_incr_style=args.wd_incr_style,
        use_checkpoint_lr_scheduler=args.use_checkpoint_lr_scheduler,
        override_lr_scheduler=args.override_lr_scheduler)


353
354
355
356
357
def setup_model_and_optimizer(model_provider_func,
                              model_type,
                              no_wd_decay_cond=None,
                              scale_lr_cond=None,
                              lr_mult=1.0):
    """Build the model, optimizer and learning-rate scheduler.

    The model is built with ``get_model``. The optimizer is built over the
    modules with their torchDDP / LocalDDP / Float16Module wrappers removed,
    and ``no_wd_decay_cond``, ``scale_lr_cond`` and ``lr_mult`` are forwarded
    to ``get_megatron_optimizer``. If ``args.load`` is set, a checkpoint is
    loaded (timed between barriers) and ``args.iteration`` is taken from it;
    otherwise ``args.iteration`` is set to 0.

    Returns:
        ``(model, optimizer, lr_scheduler)``, where ``model`` is the list of
        wrapped model chunks returned by ``get_model``.
    """
    args = get_args()

    model = get_model(model_provider_func, model_type)

    # Model chunks without FP16 and/or DDP wrappers.
    unwrapped_model = unwrap_model(model,
                                   (torchDDP, LocalDDP, Float16Module))
    optimizer = get_megatron_optimizer(unwrapped_model, no_wd_decay_cond,
                                       scale_lr_cond, lr_mult)

    lr_scheduler = get_learning_rate_scheduler(optimizer)

    if args.load is not None:
        timers = get_timers()
        # Extra barrier is added to make sure all ranks report the
        # max time.
        torch.distributed.barrier()
        timers('load-checkpoint').start()
        args.iteration = load_checkpoint(model, optimizer, lr_scheduler)
        torch.distributed.barrier()
        timers('load-checkpoint').stop()
        timers.log(['load-checkpoint'])
    else:
        args.iteration = 0

    # Multiple model chunks or pipeline parallelism are only supported with
    # the local DDP implementation.
    if len(model) > 1 or mpu.get_pipeline_model_parallel_world_size() > 1:
        assert args.DDP_impl == 'local'

    # When starting from iteration 0, a single-chunk model that provides
    # init_state_dict_from_bert (ICT) is initialized from pretrained BERT.
    if args.iteration == 0 and len(unwrapped_model) == 1 \
            and hasattr(unwrapped_model[0], 'init_state_dict_from_bert'):
        print_rank_0("Initializing ICT from pretrained BERT model")
        unwrapped_model[0].init_state_dict_from_bert()
        # The model weights were just overwritten, so the optimizer must
        # re-read them. get_model uses the same Float16Module path for fp16
        # and bf16, so refresh in both cases (previously only fp16 did).
        # NOTE(review): assumes the bf16 optimizer also keeps its own copy of
        # the parameters that reload_model_params refreshes — confirm.
        if args.fp16 or args.bf16:
            optimizer.reload_model_params()

    return model, optimizer, lr_scheduler


398
399
400
401
402
403
404
def train_step(forward_step_func, data_iterator,
               model, optimizer, lr_scheduler):
    """Run one training iteration.

    The iteration proceeds in this order:
      1. Zero the gradients.
      2. Run the forward/backward schedule from ``get_forward_backward_func()``.
      3. All-reduce gradients: data-parallel for the local DDP implementation,
         then the shared word / position embeddings across pipeline stages.
      4. Step the optimizer and, only if that succeeded, the LR scheduler.

    Returns:
        ``(loss_dict, skipped_iter, grad_norm, num_zeros_in_grad)``.
        ``loss_dict`` maps each loss name to its mean over the micro-batch
        results on the last pipeline stage and is ``{}`` on other stages.
        ``skipped_iter`` is 1 when ``optimizer.step()`` reported an
        unsuccessful update, else 0.
    """
    args = get_args()
    timers = get_timers()
    wrapper_types = (torchDDP, LocalDDP, Float16Module)
    uses_local_ddp = args.DDP_impl == 'local'

    # Zero the gradients (local DDP's contiguous grad buffers first).
    if uses_local_ddp and args.use_contiguous_buffers_in_local_ddp:
        for chunk in model:
            chunk.zero_grad_buffer()
    optimizer.zero_grad()

    losses_reduced = get_forward_backward_func()(
        forward_step_func, data_iterator, model,
        optimizer, timers, forward_only=False)

    # Empty unused memory.
    if args.empty_unused_memory_level >= 1:
        torch.cuda.empty_cache()

    # Data-parallel gradient all-reduce (local DDP only).
    if uses_local_ddp:
        timers('backward-params-all-reduce').start()
        for chunk in model:
            chunk.allreduce_gradients()
        timers('backward-params-all-reduce').stop()

    timers('backward-embedding-all-reduce').start()

    # Keep word_embeddings in sync across the first and last pipeline stages
    # by all-reducing their gradient (BERT and GPT-2 style models).
    if mpu.is_rank_in_embedding_group(ignore_virtual=True) and \
            mpu.get_pipeline_model_parallel_world_size() > 1:
        # A last (non-first) stage uses its final chunk; every other case,
        # including intermediate stages (no interleaved support for T5 yet),
        # uses the first chunk.
        if (not mpu.is_pipeline_first_stage(ignore_virtual=True)
                and mpu.is_pipeline_last_stage(ignore_virtual=True)):
            owner = model[-1]
        else:
            owner = model[0]
        owner = unwrap_model(owner, wrapper_types)
        if owner.share_word_embeddings:
            weight = owner.word_embeddings_weight()
            grad = weight.main_grad if uses_local_ddp else weight.grad
            torch.distributed.all_reduce(grad, group=mpu.get_embedding_group())

    # Keep position embeddings in sync between the encoder and split
    # (decoder) stages; only relevant for T5 with pipeline parallelism.
    if mpu.is_rank_in_position_embedding_group() and \
            mpu.get_pipeline_model_parallel_world_size() > 1 and \
            args.pipeline_model_parallel_split_rank is not None:
        owner = unwrap_model(model[0], wrapper_types)
        assert uses_local_ddp, \
            'T5 model is only supported with local DDP mode'
        position_weight = owner.language_model.embedding.position_embeddings.weight
        torch.distributed.all_reduce(position_weight.main_grad,
                                     group=mpu.get_position_embedding_group())

    timers('backward-embedding-all-reduce').stop()

    # Update parameters.
    timers('optimizer').start()
    update_successful, grad_norm, num_zeros_in_grad = optimizer.step()
    timers('optimizer').stop()

    if update_successful:
        # Advance the scheduler by micro-batches * micro-batch size * DP size.
        lr_scheduler.step(increment=get_num_microbatches() *
                          args.micro_batch_size *
                          args.data_parallel_size)
    skipped_iter = 0 if update_successful else 1

    # Empty unused memory.
    if args.empty_unused_memory_level >= 2:
        torch.cuda.empty_cache()

    if not mpu.is_pipeline_last_stage(ignore_virtual=True):
        return {}, skipped_iter, grad_norm, num_zeros_in_grad

    # Average each loss over the micro-batch results.
    num_results = len(losses_reduced)
    loss_reduced = {key: sum(entry[key] for entry in losses_reduced) / num_results
                    for key in losses_reduced[0]}
    return loss_reduced, skipped_iter, grad_norm, num_zeros_in_grad
492
493


Mohammad Shoeybi's avatar
Mohammad Shoeybi committed
494
def training_log(loss_dict, total_loss_dict, learning_rate, iteration,
                 loss_scale, report_memory_flag, skipped_iter,
                 grad_norm, params_norm, num_zeros_in_grad):
    """Accumulate per-iteration statistics and periodically report them.

    ``total_loss_dict`` is updated in place: it counts advanced, skipped and
    NaN/Inf iterations and sums the losses of advanced iterations.

    Every ``args.tensorboard_log_interval`` iterations, current scalars are
    written to TensorBoard (only when a writer exists and this is the last
    rank).

    Every ``args.log_interval`` iterations:
      * a summary line with losses averaged over the advanced iterations is
        printed via ``print_rank_last``;
      * the accumulators in ``total_loss_dict`` are reset;
      * the collected timers are logged.

    Returns:
        The updated ``report_memory_flag``. It turns False after memory has
        been reported, which happens on the first log interval where
        ``learning_rate > 0``.
    """
    args = get_args()
    timers = get_timers()
    writer = get_tensorboard_writer()

    advanced_key = 'advanced iterations'
    skipped_key = 'skipped iterations'
    nan_key = 'nan iterations'

    # Iteration counters. The advanced counter is created even on a skipped
    # iteration so later reads of it always succeed.
    if skipped_iter:
        total_loss_dict.setdefault(advanced_key, 0)
    else:
        total_loss_dict[advanced_key] = total_loss_dict.get(advanced_key, 0) + 1
    total_loss_dict[skipped_key] = total_loss_dict.get(skipped_key, 0) + skipped_iter

    # Sum losses of advanced iterations; for a skipped iteration only check
    # whether any loss is non-finite.
    got_nan = False
    for key in loss_dict:
        if skipped_iter:
            value = loss_dict[key].float().sum().item()
            if math.isinf(value) or math.isnan(value):
                got_nan = True
        else:
            total_loss_dict[key] = total_loss_dict.get(
                key, torch.cuda.FloatTensor([0.0])) + loss_dict[key]
    total_loss_dict[nan_key] = total_loss_dict.get(nan_key, 0) + int(got_nan)

    # Only report timers that exist in this run.
    timers_to_log = [name for name in (
        'forward-compute',
        'forward-recv',
        'forward-send',
        'forward-backward-send-forward-backward-recv',
        'backward-compute',
        'backward-recv',
        'backward-send',
        'backward-send-forward-recv',
        'backward-send-backward-recv',
        'backward-params-all-reduce',
        'backward-embedding-all-reduce',
        'optimizer-copy-to-main-grad',
        'optimizer-unscale-and-check-inf',
        'optimizer-clip-main-grad',
        'optimizer-copy-main-to-model-params',
        'optimizer',
        'batch-generator',
    ) if name in timers.timers]

    # Global batch size: micro-batch size * data-parallel size * micro-batches.
    batch_size = args.micro_batch_size * args.data_parallel_size * \
        get_num_microbatches()

    total_iterations = total_loss_dict[advanced_key] + \
        total_loss_dict[skipped_key]

    # Tensorboard values.
    if writer and (iteration % args.tensorboard_log_interval == 0) and \
            is_last_rank():

        def log_pair(tag, value):
            # Log the scalar against both iteration and consumed samples.
            writer.add_scalar(tag, value, iteration)
            writer.add_scalar(tag + ' vs samples', value,
                              args.consumed_train_samples)

        if args.log_learning_rate_to_tensorboard:
            log_pair('learning-rate', learning_rate)
        if args.log_batch_size_to_tensorboard:
            log_pair('batch-size', batch_size)
        for key in loss_dict:
            log_pair(key, loss_dict[key])
        if args.log_loss_scale_to_tensorboard:
            log_pair('loss-scale', loss_scale)
        if args.log_world_size_to_tensorboard:
            log_pair('world-size', args.world_size)
        if grad_norm is not None:
            log_pair('grad-norm', grad_norm)
        if num_zeros_in_grad is not None:
            log_pair('num-zeros', num_zeros_in_grad)
        if params_norm is not None:
            log_pair('params-norm', params_norm)
        if args.log_timers_to_tensorboard:
            timers.write(timers_to_log, writer, iteration,
                         normalizer=total_iterations)
        if args.log_memory_to_tensorboard:
            mem_stats = torch.cuda.memory_stats()
            for tag, stat_name in (
                    ("mem-reserved-bytes", "reserved_bytes.all.current"),
                    ("mem-allocated-bytes", "allocated_bytes.all.current"),
                    ("mem-allocated-count", "allocation.all.current")):
                writer.add_scalar(tag, mem_stats[stat_name], iteration)

    if iteration % args.log_interval == 0:
        elapsed_time = timers('interval-time').elapsed()
        elapsed_time_per_iteration = elapsed_time / total_iterations
        if writer and args.log_timers_to_tensorboard:
            writer.add_scalar('iteration-time',
                              elapsed_time_per_iteration, iteration)

        parts = [
            ' iteration {:8d}/{:8d} |'.format(iteration, args.train_iters),
            ' consumed samples: {:12d} |'.format(args.consumed_train_samples),
            ' elapsed time per iteration (ms): {:.1f} |'.format(
                elapsed_time_per_iteration * 1000.0),
            ' learning rate: {:.3E} |'.format(learning_rate),
            ' global batch size: {:5d} |'.format(batch_size),
        ]
        counter_keys = (advanced_key, skipped_key, nan_key)
        for key in total_loss_dict:
            if key in counter_keys:
                continue
            avg = total_loss_dict[key].item() / \
                float(max(1, total_loss_dict[advanced_key]))
            if avg > 0.0:
                parts.append(' {}: {:.6E} |'.format(key, avg))
            # Reset the loss accumulator for the next interval.
            total_loss_dict[key] = torch.cuda.FloatTensor([0.0])
        parts.append(' loss scale: {:.1f} |'.format(loss_scale))
        if grad_norm is not None:
            parts.append(' grad norm: {:.3f} |'.format(grad_norm))
        if num_zeros_in_grad is not None:
            parts.append(' num zeros: {:.1f} |'.format(num_zeros_in_grad))
        if params_norm is not None:
            parts.append(' params norm: {:.3f} |'.format(params_norm))
        parts.append(' number of skipped iterations: {:3d} |'.format(
            total_loss_dict[skipped_key]))
        parts.append(' number of nan iterations: {:3d} |'.format(
            total_loss_dict[nan_key]))
        for counter in counter_keys:
            total_loss_dict[counter] = 0
        print_rank_last(''.join(parts))

        if report_memory_flag and learning_rate > 0.:
            # Report memory after optimizer state has been initialized.
            report_memory('(after {} iterations)'.format(iteration))
            report_memory_flag = False
        timers.log(timers_to_log, normalizer=args.log_interval)

    return report_memory_flag


665
666
667
668
669
def save_checkpoint_and_time(iteration, model, optimizer, lr_scheduler):
    """Save a checkpoint for ``iteration`` and log how long the save took.

    The save is wrapped in ``torch.distributed.barrier()`` calls on both
    sides, so the timed region starts and ends with all ranks in step. The
    original comment's intent: every rank reports the max time.

    Args:
        iteration: iteration number passed through to ``save_checkpoint``.
        model: model passed through to ``save_checkpoint``.
        optimizer: optimizer passed through to ``save_checkpoint``.
        lr_scheduler: learning-rate scheduler passed through to
            ``save_checkpoint``.
    """
    timers = get_timers()
    timer_name = 'save-checkpoint'

    # Line all ranks up before starting the clock.
    torch.distributed.barrier()
    timers(timer_name).start()

    save_checkpoint(iteration, model, optimizer, lr_scheduler)

    # Wait for the slowest rank before stopping the clock.
    torch.distributed.barrier()
    timers(timer_name).stop()
    timers.log([timer_name])
675
676


677
def train(forward_step_func, model, optimizer, lr_scheduler,
          train_data_iterator, valid_data_iterator,
          process_non_loss_data_func):
    """Run the main training loop until ``args.train_iters`` is reached.

    Each step calls ``train_step`` and then, in this order:
      1. logs progress;
      2. checks ADLR autoresume;
      3. optionally evaluates on the validation iterator;
      4. handles checkpointing and the early-exit conditions (signal,
         wall-clock budget, iteration interval), each of which ends with
         ``sys.exit()``.

    Args:
        forward_step_func: forward-step callable passed to ``train_step``
            and to evaluation.
        model: iterable of model modules; each is put in train mode.
        optimizer: optimizer; supplies the loss scale and the learning rate
            of ``param_groups[0]`` for logging.
        lr_scheduler: learning-rate scheduler, stepped via ``train_step``
            and saved with checkpoints.
        train_data_iterator: iterator consumed by ``train_step``.
        valid_data_iterator: iterator used for periodic evaluation.
        process_non_loss_data_func: forwarded to
            ``evaluate_and_print_results``.

    Returns:
        The iteration counter after the last completed step.
    """
    args = get_args()
    timers = get_timers()

    # Write args to tensorboard
    write_args_to_tensorboard()

    # Training mode enables dropout on every model chunk.
    for chunk in model:
        chunk.train()

    # Loss accumulators shared with training_log across steps.
    running_losses = {}
    step = args.iteration

    def _save_now():
        # Reads ``step`` at call time, i.e. the iteration just completed.
        save_checkpoint_and_time(step, model, optimizer, lr_scheduler)

    timers('interval-time').start()
    print_datetime('before the start of training step')
    report_memory_flag = True

    while step < args.train_iters:
        update_num_microbatches(args.consumed_train_samples)
        loss_dict, skipped_iter, grad_norm, num_zeros_in_grad = train_step(
            forward_step_func, train_data_iterator, model, optimizer,
            lr_scheduler)
        step += 1
        samples_per_step = (mpu.get_data_parallel_world_size()
                            * args.micro_batch_size
                            * get_num_microbatches())
        args.consumed_train_samples += samples_per_step

        # Logging.
        loss_scale = optimizer.get_loss_scale().item()
        params_norm = (calc_params_l2_norm(model)
                       if args.log_params_norm else None)
        report_memory_flag = training_log(
            loss_dict, running_losses, optimizer.param_groups[0]['lr'],
            step, loss_scale, report_memory_flag, skipped_iter,
            grad_norm, params_norm, num_zeros_in_grad)

        # ADLR autoresume.
        if args.adlr_autoresume and \
                step % args.adlr_autoresume_interval == 0:
            check_adlr_autoresume_termination(step, model, optimizer,
                                              lr_scheduler)

        # Periodic validation.
        should_evaluate = (args.eval_interval
                           and step % args.eval_interval == 0
                           and args.do_valid)
        if should_evaluate:
            evaluate_and_print_results('iteration {}'.format(step),
                                       forward_step_func,
                                       valid_data_iterator, model,
                                       step, process_non_loss_data_func,
                                       False)

        # Checkpointing. The flag avoids saving twice in one step when an
        # exit condition fires right after an interval save.
        saved_checkpoint = False

        if args.exit_signal_handler:
            if any(get_signal_handler().signals_received()):
                _save_now()
                print_datetime('exiting program after receiving SIGTERM.')
                sys.exit()

        if args.save and args.save_interval and \
                step % args.save_interval == 0:
            _save_now()
            saved_checkpoint = True

        # Exit once the wall-clock budget is spent. MAX all-reduce: every
        # rank exits if any rank has run over.
        if args.exit_duration_in_mins:
            elapsed_mins = (time.time() - _TRAIN_START_TIME) / 60.0
            over_budget = torch.cuda.IntTensor(
                [elapsed_mins > args.exit_duration_in_mins])
            torch.distributed.all_reduce(
                over_budget, op=torch.distributed.ReduceOp.MAX)
            if over_budget.item():
                if not saved_checkpoint:
                    _save_now()
                print_datetime(
                    'exiting program after {} minutes'.format(elapsed_mins))
                sys.exit()

        # Exit on a fixed iteration interval.
        if args.exit_interval and step % args.exit_interval == 0:
            if not saved_checkpoint:
                _save_now()
            torch.distributed.barrier()
            print_datetime('exiting program at iteration {}'.format(step))
            sys.exit()

    return step
781
782


783
784
785
786
787
def evaluate(forward_step_func,
             data_iterator,
             model,
             process_non_loss_data_func,
             verbose=False):
    """Evaluate ``model`` for ``args.eval_iters`` forward-only iterations.

    The model modules are switched to eval mode for the duration of the
    call and back to train mode afterwards. Gradients are disabled. Each
    iteration adds the evaluated sample count to
    ``args.consumed_valid_samples``.

    Args:
        forward_step_func: forward-step callable handed to the
            forward/backward schedule.
        data_iterator: iterator providing evaluation batches.
        model: iterable of model modules.
        process_non_loss_data_func: if not None, one extra forward-only
            pass with ``collect_non_loss_data=True`` runs on the last rank,
            and its result is returned.
        verbose: print progress every ``args.log_interval`` iterations.

    Returns:
        A tuple ``(total_loss_dict, collected_non_loss_data)``:
          - ``total_loss_dict`` holds per-key losses summed over all
            iterations and micro-batches, then divided by
            ``args.eval_iters * get_num_microbatches()``. Losses are
            accumulated only on the last pipeline stage, so the dict is
            empty on other stages.
          - ``collected_non_loss_data`` is None unless the extra pass
            described above ran.
    """
    args = get_args()

    # Turn on evaluation mode which disables dropout.
    for model_module in model:
        model_module.eval()

    total_loss_dict = {}

    with torch.no_grad():
        # get_forward_backward_func() takes no arguments and nothing in the
        # loop changes its inputs, so fetch it once. Binding it here also
        # keeps it defined for the non-loss-data pass below when
        # args.eval_iters is 0.
        forward_backward_func = get_forward_backward_func()

        iteration = 0
        while iteration < args.eval_iters:
            iteration += 1
            if verbose and iteration % args.log_interval == 0:
                print_rank_0('Evaluating iter {}/{}'.format(iteration,
                                                            args.eval_iters))

            loss_dicts = forward_backward_func(
                forward_step_func, data_iterator, model, optimizer=None,
                timers=None, forward_only=True)

            # Empty unused memory
            if args.empty_unused_memory_level >= 1:
                torch.cuda.empty_cache()

            if mpu.is_pipeline_last_stage(ignore_virtual=True):
                # Sum losses over the returned micro-batch dicts. This is a
                # local accumulation; no cross-process communication here.
                for loss_dict in loss_dicts:
                    for key in loss_dict:
                        total_loss_dict[key] = total_loss_dict.get(
                            key, torch.cuda.FloatTensor([0.0])) + loss_dict[key]

            args.consumed_valid_samples += mpu.get_data_parallel_world_size() \
                                           * args.micro_batch_size \
                                           * get_num_microbatches()

        collected_non_loss_data = None
        if process_non_loss_data_func is not None and is_last_rank():
            collected_non_loss_data = forward_backward_func(
                forward_step_func, data_iterator, model, optimizer=None,
                timers=None, forward_only=True, collect_non_loss_data=True)

    # Move model back to the train mode.
    for model_module in model:
        model_module.train()

    # Turn the sums into per-micro-batch averages.
    for key in total_loss_dict:
        total_loss_dict[key] /= args.eval_iters * get_num_microbatches()

    return total_loss_dict, collected_non_loss_data
838
839
840

def evaluate_and_print_results(prefix, forward_step_func,
                               data_iterator, model,
                               iteration, process_non_loss_data_func,
                               verbose=False):
    """Run ``evaluate`` and report the results.

    A one-line summary of each loss value and its perplexity is printed via
    ``print_rank_last``, framed by dashed lines. Perplexity is computed as
    ``exp(min(20, loss))``; the clamp keeps ``math.exp`` finite. When a
    TensorBoard writer is available, the values are also written as
    scalars against both ``iteration`` and ``args.consumed_train_samples``.

    Args:
        prefix: label inserted into the printed summary.
        forward_step_func: forwarded to ``evaluate``.
        data_iterator: forwarded to ``evaluate``.
        model: forwarded to ``evaluate``.
        iteration: x-axis value for the TensorBoard scalars.
        process_non_loss_data_func: forwarded to ``evaluate``. If not None,
            it is also called with the collected data on the last rank when
            a writer exists.
        verbose: forwarded to ``evaluate``.
    """
    args = get_args()
    writer = get_tensorboard_writer()

    total_loss_dict, collected_non_loss_data = evaluate(
        forward_step_func, data_iterator, model,
        process_non_loss_data_func, verbose)
    string = ' validation loss at {} | '.format(prefix)
    for key in total_loss_dict:
        # The accumulated losses are CUDA tensors, and each .item() forces a
        # device->host sync, so read the scalar once and reuse it.
        loss_value = total_loss_dict[key].item()
        ppl = math.exp(min(20, loss_value))
        string += '{} value: {:.6E} | '.format(key, loss_value)
        string += '{} PPL: {:.6E} | '.format(key, ppl)
        if writer:
            writer.add_scalar('{} validation'.format(key),
                              loss_value,
                              iteration)
            writer.add_scalar('{} validation vs samples'.format(key),
                              loss_value,
                              args.consumed_train_samples)
            if args.log_validation_ppl_to_tensorboard:
                writer.add_scalar('{} validation ppl'.format(key), ppl,
                                  iteration)
                writer.add_scalar('{} validation ppl vs samples'.format(key),
                                  ppl, args.consumed_train_samples)

    if process_non_loss_data_func is not None and writer and is_last_rank():
        process_non_loss_data_func(collected_non_loss_data, iteration, writer)

    length = len(string) + 1
    print_rank_last('-' * length)
    print_rank_last(string)
    print_rank_last('-' * length)
875
876


Vijay Korthikanti's avatar
Vijay Korthikanti committed
877
def cyclic_iter(iter):
    """Yield the items of ``iter`` forever, re-iterating it after each pass.

    ``iter`` should be re-iterable (e.g. a list or a dataloader). A
    one-shot iterator produces its items only once. If a complete pass
    yields nothing, the generator stops instead of spinning forever without
    producing a value. This covers an empty iterable, or a one-shot
    iterator that is already exhausted.

    Args:
        iter: the iterable to cycle over. The name shadows the builtin
            ``iter``; it is kept for keyword-argument compatibility, and
            the builtin is not needed here.

    Yields:
        Successive items of ``iter``, restarting from the beginning each
        time a pass ends.
    """
    while True:
        produced_any = False
        for x in iter:
            produced_any = True
            yield x
        if not produced_any:
            # Another pass would also be empty; stop rather than busy-loop.
            return

882
883
884
def build_train_valid_test_data_iterators(
        build_train_valid_test_datasets_provider):
    """Build the train, validation and test data iterators.

    Datasets and dataloaders are built only on tensor-model-parallel rank 0.
    That rank computes do_train / do_valid / do_test flags and broadcasts
    them to the rest of its tensor-model-parallel group. Every rank stores
    them on ``args``. Any dataloader that is None gives a None iterator;
    this is always the case on tensor-model-parallel ranks other than 0.

    For backward compatibility with checkpoints that predate sample
    counting, ``args.consumed_train_samples`` and
    ``args.consumed_valid_samples`` are reconstructed from
    ``args.iteration`` when they are still 0.

    Args:
        build_train_valid_test_datasets_provider: callable taking a list of
            three target sample counts ``[train, valid, test]`` and
            returning ``(train_ds, valid_ds, test_ds)``.

    Returns:
        ``(train_data_iterator, valid_data_iterator, test_data_iterator)``;
        each entry may be None.
    """
    args = get_args()

    (train_dataloader, valid_dataloader, test_dataloader) = (None, None, None)

    print_rank_0('> building train, validation, and test datasets ...')

    # Backward compatibility, assume fixed batch size.
    if args.iteration > 0 and args.consumed_train_samples == 0:
        assert args.train_samples is None, \
            'only backward compatiblity support for iteration-based training'
        args.consumed_train_samples = args.iteration * args.global_batch_size
    if args.iteration > 0 and args.consumed_valid_samples == 0:
        if args.train_samples is None:
            args.consumed_valid_samples = (args.iteration // args.eval_interval) * \
                args.eval_iters * args.global_batch_size

    # Data loader only on rank 0 of each model parallel group.
    if mpu.get_tensor_model_parallel_rank() == 0:

        # Number of train/valid/test samples.
        if args.train_samples:
            train_samples = args.train_samples
        else:
            train_samples = args.train_iters * args.global_batch_size
        # One validation round of args.eval_iters batches per eval_interval.
        # The +1 presumably covers a final round after training.
        eval_iters = (args.train_iters // args.eval_interval + 1) * \
                     args.eval_iters
        test_iters = args.eval_iters
        train_val_test_num_samples = [train_samples,
                                      eval_iters * args.global_batch_size,
                                      test_iters * args.global_batch_size]
        print_rank_0(' > datasets target sizes (minimum size):')
        print_rank_0('    train:      {}'.format(train_val_test_num_samples[0]))
        print_rank_0('    validation: {}'.format(train_val_test_num_samples[1]))
        print_rank_0('    test:       {}'.format(train_val_test_num_samples[2]))

        # Build the datasets.
        train_ds, valid_ds, test_ds = build_train_valid_test_datasets_provider(
            train_val_test_num_samples)

        # Build dataloaders. Train/valid resume from their consumed sample
        # counts; test always starts from 0.
        train_dataloader = build_pretraining_data_loader(
            train_ds, args.consumed_train_samples)
        valid_dataloader = build_pretraining_data_loader(
            valid_ds, args.consumed_valid_samples)
        test_dataloader = build_pretraining_data_loader(test_ds, 0)

        # Flags to know if we need to do training/validation/testing.
        do_train = train_dataloader is not None and args.train_iters > 0
        do_valid = valid_dataloader is not None and args.eval_iters > 0
        do_test = test_dataloader is not None and args.eval_iters > 0
        flags = torch.cuda.LongTensor(
            [int(do_train), int(do_valid), int(do_test)])
    else:
        # Placeholder; overwritten by the broadcast from rank 0 below.
        flags = torch.cuda.LongTensor([0, 0, 0])

    # Broadcast the do_train/do_valid/do_test flags from rank 0 to the rest
    # of the tensor-model-parallel group.
    torch.distributed.broadcast(flags,
                                mpu.get_tensor_model_parallel_src_rank(),
                                group=mpu.get_tensor_model_parallel_group())
    args.do_train = flags[0].item()
    args.do_valid = flags[1].item()
    args.do_test = flags[2].item()

    # Build iterators.
    dl_type = args.dataloader_type
    assert dl_type in ['single', 'cyclic']

    train_data_iterator = _build_data_iterator(train_dataloader, dl_type)
    valid_data_iterator = _build_data_iterator(valid_dataloader, dl_type)
    test_data_iterator = _build_data_iterator(test_dataloader, dl_type)

    return train_data_iterator, valid_data_iterator, test_data_iterator


def _build_data_iterator(dataloader, dl_type):
    """Wrap ``dataloader`` in an iterator according to ``dl_type``.

    'single' iterates the dataloader once. Any other value (callers pass
    only 'cyclic') wraps it in ``cyclic_iter`` so it restarts after each
    pass. A None dataloader gives None.
    """
    if dataloader is None:
        return None
    if dl_type == 'single':
        return iter(dataloader)
    return iter(cyclic_iter(dataloader))