training.py 39.3 KB
Newer Older
1
# coding=utf-8
Mohammad's avatar
Mohammad committed
2
# Copyright (c) 2020, NVIDIA CORPORATION.  All rights reserved.
3
4
5
6
7
8
9
10
11
12
13
14
15
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

Mohammad's avatar
Mohammad committed
16
"""Pretrain utilities."""
17
18
19

from datetime import datetime
import math
Mohammad's avatar
Mohammad committed
20
import sys
21
22
23
24
import time
# The earliest we can measure the start time.
_TRAIN_START_TIME = time.time()

25
26
27
import torch
from torch.nn.parallel.distributed import DistributedDataParallel as torchDDP

Neel Kant's avatar
Neel Kant committed
28
from megatron import get_args
29
from megatron import get_signal_handler
Mohammad's avatar
Mohammad committed
30
31
from megatron import get_timers
from megatron import get_tensorboard_writer
32
from megatron import get_current_global_batch_size
mohammad's avatar
mohammad committed
33
from megatron import get_num_microbatches
mohammad's avatar
mohammad committed
34
from megatron import is_last_rank
mohammad's avatar
mohammad committed
35
from megatron import update_num_microbatches
36
from megatron import mpu
Neel Kant's avatar
Neel Kant committed
37
from megatron import print_rank_0
38
from megatron import print_rank_last
Mohammad's avatar
Mohammad committed
39
40
from megatron.checkpointing import load_checkpoint
from megatron.checkpointing import save_checkpoint
Mohammad Shoeybi's avatar
Mohammad Shoeybi committed
41
from megatron.model import Float16Module
42
from megatron.model import ModelType
mohammad's avatar
mohammad committed
43
from megatron.optimizer import get_megatron_optimizer
Mohammad's avatar
Mohammad committed
44
from megatron.initialize import initialize_megatron
45
from megatron.initialize import write_args_to_tensorboard
46
47
48
from megatron.learning_rates import AnnealingLR
from megatron.model import DistributedDataParallel as LocalDDP
from megatron.utils import check_adlr_autoresume_termination
49
from megatron.utils import unwrap_model
Vijay Korthikanti's avatar
Vijay Korthikanti committed
50
from megatron.data.data_samplers import build_pretraining_data_loader
mohammad's avatar
mohammad committed
51
from megatron.utils import calc_params_l2_norm
52
from megatron.schedules import get_forward_backward_func
53
from megatron.utils import report_memory
54
55


Mohammad Shoeybi's avatar
Mohammad Shoeybi committed
56

57
58
59
60
61
62
63
def print_datetime(string):
    """Print ``string`` tagged with the current local date/time on rank 0.

    All ranks must call this function: it starts with
    ``torch.distributed.barrier()``, so it also synchronizes every rank.
    """
    torch.distributed.barrier()
    stamp = datetime.now().strftime('%Y-%m-%d %H:%M:%S')
    suffix = '] datetime: {} '.format(stamp)
    print_rank_0('[' + string + suffix)


64
def pretrain(train_valid_test_dataset_provider,
             model_provider,
             model_type,
             forward_step_func,
             extra_args_provider=None,
             args_defaults=None):
    """Main training program.

    This function runs the following steps, in order:
        1) initialize Megatron.
        2) set up the model, optimizer and learning-rate schedule using
           `model_provider`.
        3) call `train_valid_test_dataset_provider` to get the
           train/valid/test datasets.
        4) train the model using `forward_step_func`.

    Arguments:
        train_valid_test_dataset_provider: a function that takes the size of
            train/valid/test dataset and returns `train, valid, test` datasets.
        model_provider: a function that returns a vanilla version of the
            model. By vanilla we mean a simple model on cpu with no fp16 or ddp.
        model_type: an enum that specifies the type of model being trained.
        forward_step_func: a function that takes a `data iterator` and `model`,
            and returns a `loss` scalar with a dictionary with key:values being
            the info we would like to monitor during training, for example
            `lm-loss: value`. We also require that this function add
            `batch generator` to the timers class.
        extra_args_provider: a function that takes a parser and adds arguments
            to it. It is used for programs to add their own arguments.
        args_defaults: a dictionary from argument-name to argument-value,
            forwarded to `initialize_megatron` to override argument defaults.
            ``None`` (the default) means no overrides.
    """
    # Avoid a shared mutable default argument.
    if args_defaults is None:
        args_defaults = {}

    # Initialize and get arguments, timers, and Tensorboard writer.
    initialize_megatron(extra_args_provider=extra_args_provider,
                        args_defaults=args_defaults)

    # Use the earliest start time across all ranks, so the reported
    # initialization time is the largest one. This is closer to what the
    # job scheduler observes for the whole launch.
    global _TRAIN_START_TIME
    start_time_tensor = torch.cuda.DoubleTensor([_TRAIN_START_TIME])
    torch.distributed.all_reduce(start_time_tensor,
                                 op=torch.distributed.ReduceOp.MIN)
    _TRAIN_START_TIME = start_time_tensor.item()
    print_rank_0('time to initialize megatron (seconds): {:.3f}'.format(
        time.time() - _TRAIN_START_TIME))
    print_datetime('after megatron is initialized')

    args = get_args()
    timers = get_timers()

    # Model, optimizer, and learning rate.
    timers('model-and-optimizer-setup').start()
    model, optimizer, lr_scheduler = setup_model_and_optimizer(model_provider,
                                                               model_type)
    timers('model-and-optimizer-setup').stop()
    print_datetime('after model, optimizer, and learning rate '
                   'scheduler are built')

    # Data stuff.
    timers('train/valid/test-data-iterators-setup').start()
    if args.virtual_pipeline_model_parallel_size is not None:
        # Build one set of train/valid/test iterators per entry in `model`.
        all_data_iterators = [
            build_train_valid_test_data_iterators(train_valid_test_dataset_provider)
            for _ in range(len(model))
        ]
        train_data_iterator = [data_iterators[0] for data_iterators in all_data_iterators]
        valid_data_iterator = [data_iterators[1] for data_iterators in all_data_iterators]
        test_data_iterator = [data_iterators[2] for data_iterators in all_data_iterators]
    else:
        train_data_iterator, valid_data_iterator, test_data_iterator \
            = build_train_valid_test_data_iterators(
                train_valid_test_dataset_provider)
    timers('train/valid/test-data-iterators-setup').stop()
    print_datetime('after dataloaders are built')

    # Print setup timing.
    print_rank_0('done with setup ...')
    timers.log(['model-and-optimizer-setup', 'train/valid/test-data-iterators-setup'])
    print_rank_0('training ...')

    iteration = 0
    if args.do_train and args.train_iters > 0:
        iteration = train(forward_step_func,
                          model, optimizer, lr_scheduler,
                          train_data_iterator, valid_data_iterator)
    print_datetime('after training is done')

    if args.do_valid:
        prefix = 'the end of training for val data'
        evaluate_and_print_results(prefix, forward_step_func,
                                   valid_data_iterator, model,
                                   iteration, False)

    if args.save and iteration != 0:
        save_checkpoint(iteration, model, optimizer, lr_scheduler)

    if args.do_test:
        # Run on test data.
        prefix = 'the end of training for test data'
        evaluate_and_print_results(prefix, forward_step_func,
                                   test_data_iterator, model,
                                   0, True)
174

175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
def update_train_iters(args):
    """Set ``args.train_iters`` from ``args.train_samples`` for sample-based training.

    If ``args.train_iters`` is already truthy, nothing is done.

    Otherwise the number of iterations is derived from
    ``args.train_samples``:
      * Without batch-size rampup, it is
        ``train_samples // global_batch_size``.
      * With rampup, the rampup phase is simulated via
        ``update_num_microbatches`` until more than
        ``int(args.rampup_batch_size[2])`` samples have been consumed. The
        remaining samples are then divided by ``global_batch_size``.
        Any partial last batch is dropped.
    """
    # For iteration-based training, we don't need to do anything.
    if args.train_iters:
        return

    # Constant batch size with sample-based training.
    if args.rampup_batch_size is None:
        args.train_iters = args.train_samples // args.global_batch_size

    else:
        # Sample-based training with rampup batch size.
        iterations = 0
        consumed_samples = 0
        # Rampup phase: step the microbatch calculator forward one
        # iteration at a time and count how many iterations it takes.
        while consumed_samples <= int(args.rampup_batch_size[2]):
            update_num_microbatches(consumed_samples, consistency_check=False)
            consumed_samples += get_current_global_batch_size()
            iterations += 1
        # Reset the global microbatch calculator to its starting state.
        update_num_microbatches(0, consistency_check=False)
        # Constant phase.
        # Note that we throw away any partial last batch.
        iterations += (args.train_samples - consumed_samples) // \
                      args.global_batch_size
        args.train_iters = iterations

    print_rank_0('setting training iterations to {}'.format(args.train_iters))

204

205
def get_model(model_provider_func, model_type=ModelType.encoder_or_decoder, wrap_with_ddp=True):
    """Build the model, move it to GPU, and optionally wrap it.

    Arguments:
        model_provider_func: callable invoked with ``pre_process`` and
            ``post_process`` keyword arguments. For
            ``ModelType.encoder_and_decoder`` outside the interleaved path,
            it also receives ``add_encoder`` and ``add_decoder``.
        model_type: a ``ModelType`` value. It is stored on ``args.model_type``
            and on every built module.
        wrap_with_ddp: when True, wrap each module in torch DDP or the local
            DDP implementation, selected by ``args.DDP_impl``.

    Returns:
        A list of modules. It has one entry per virtual pipeline stage when
        the interleaved schedule is used, otherwise a single entry. Each
        entry may be wrapped in ``Float16Module`` (fp16/bf16) and/or DDP.

    Raises:
        AssertionError: the interleaved schedule is requested for an
            encoder-and-decoder model, or the split rank is missing for an
            encoder-and-decoder model with pipeline parallelism.
        NotImplementedError: ``args.DDP_impl`` is unknown and
            ``wrap_with_ddp`` is True.
    """
    args = get_args()
    args.model_type = model_type

    # Build model.
    if mpu.get_pipeline_model_parallel_world_size() > 1 and \
       args.virtual_pipeline_model_parallel_size is not None:
        assert model_type != ModelType.encoder_and_decoder, \
            "Interleaved schedule not supported for model with both encoder and decoder"
        model = []
        for i in range(args.virtual_pipeline_model_parallel_size):
            mpu.set_virtual_pipeline_model_parallel_rank(i)
            # Set pre_process and post_process only after virtual rank is set.
            pre_process = mpu.is_pipeline_first_stage()
            post_process = mpu.is_pipeline_last_stage()
            this_model = model_provider_func(
                pre_process=pre_process,
                post_process=post_process
            )
            this_model.model_type = model_type
            model.append(this_model)
    else:
        pre_process = mpu.is_pipeline_first_stage()
        post_process = mpu.is_pipeline_last_stage()
        add_encoder = True
        add_decoder = True
        if model_type == ModelType.encoder_and_decoder:
            if mpu.get_pipeline_model_parallel_world_size() > 1:
                assert args.pipeline_model_parallel_split_rank is not None, \
                    "Split rank needs to be specified for model with both encoder and decoder"
                rank = mpu.get_pipeline_model_parallel_rank()
                split_rank = args.pipeline_model_parallel_split_rank
                world_size = mpu.get_pipeline_model_parallel_world_size()
                # Both the first stage and the split stage take raw inputs.
                # The stage just before the split and the final stage
                # produce outputs.
                pre_process = rank == 0 or rank == split_rank
                post_process = (rank == (split_rank - 1)) or (
                        rank == (world_size - 1))
                add_encoder = mpu.is_pipeline_stage_before_split()
                add_decoder = mpu.is_pipeline_stage_after_split()
            model = model_provider_func(
                pre_process=pre_process,
                post_process=post_process,
                add_encoder=add_encoder,
                add_decoder=add_decoder)
        else:
            model = model_provider_func(
                pre_process=pre_process,
                post_process=post_process
            )
        model.model_type = model_type

    if not isinstance(model, list):
        model = [model]

    # Set tensor model parallel attributes if not set.
    # Only parameters that are already tensor model parallel have these
    # attributes set for them. We should make sure the default attributes
    # are set for all params so the optimizer can use them.
    for model_module in model:
        for param in model_module.parameters():
            mpu.set_defaults_if_not_set_tensor_model_parallel_attributes(param)

    # Print number of parameters.
    if mpu.get_data_parallel_rank() == 0:
        print(' > number of parameters on (tensor, pipeline) '
              'model parallel rank ({}, {}): {}'.format(
            mpu.get_tensor_model_parallel_rank(),
            mpu.get_pipeline_model_parallel_rank(),
            sum([sum([p.nelement() for p in model_module.parameters()])
                 for model_module in model])), flush=True)

    # GPU allocation.
    for model_module in model:
        model_module.cuda(torch.cuda.current_device())

    # Fp16 conversion.
    if args.fp16 or args.bf16:
        model = [Float16Module(model_module, args) for model_module in model]

    if wrap_with_ddp:
        if args.DDP_impl == 'torch':
            i = torch.cuda.current_device()
            model = [torchDDP(model_module, device_ids=[i], output_device=i,
                              process_group=mpu.get_data_parallel_group())
                     for model_module in model]

        elif args.DDP_impl == 'local':
            model = [LocalDDP(model_module,
                              args.accumulate_allreduce_grads_in_fp32,
                              args.use_contiguous_buffers_in_local_ddp)
                     for model_module in model]
            # Broadcast params from the data-parallel src rank to the other
            # data-parallel ranks.
            if args.data_parallel_random_init:
                for model_module in model:
                    model_module.broadcast_params()
        else:
            raise NotImplementedError('Unknown DDP implementation specified: '
                                      '{}. Exiting.'.format(args.DDP_impl))

    return model
320
321


Mohammad's avatar
Mohammad committed
322
def get_learning_rate_scheduler(optimizer):
    """Build the ``AnnealingLR`` learning-rate scheduler for ``optimizer``.

    Warmup and decay lengths are expressed in samples:
      * Iteration-based training (``args.train_iters`` set): iteration
        counts are multiplied by ``args.global_batch_size``.
        ``args.lr_decay_iters`` defaults to ``args.train_iters``.
      * Sample-based training (``args.train_samples`` set): sample counts
        are used directly. ``args.lr_decay_samples`` defaults to
        ``args.train_samples``. ``args.train_iters`` is derived via
        ``update_train_iters``.
    In both modes ``args.lr_warmup_fraction``, when given, takes
    precedence over the explicit warmup count.

    Raises:
        Exception: neither ``train_iters`` nor ``train_samples`` is set.
    """
    args = get_args()

    # Iteration-based training.
    if args.train_iters:
        if args.lr_decay_iters is None:
            args.lr_decay_iters = args.train_iters
        decay_steps = args.lr_decay_iters * args.global_batch_size
        if args.lr_warmup_fraction is not None:
            warmup_steps = args.lr_warmup_fraction * decay_steps
        else:
            warmup_steps = args.lr_warmup_iters * args.global_batch_size
    # Sample-based training.
    elif args.train_samples:
        # We need to set training iters for later use. Technically
        # we need to adjust the training samples too (due to last
        # batch being incomplete) but we leave it as is for now.
        update_train_iters(args)
        if args.lr_decay_samples is None:
            args.lr_decay_samples = args.train_samples
        decay_steps = args.lr_decay_samples
        if args.lr_warmup_fraction is not None:
            warmup_steps = args.lr_warmup_fraction * decay_steps
        else:
            warmup_steps = args.lr_warmup_samples
    else:
        raise Exception(
            'either train-iters or train-samples should be provided.')

    lr_scheduler = AnnealingLR(
        optimizer,
        max_lr=args.lr,
        min_lr=args.min_lr,
        warmup_steps=warmup_steps,
        decay_steps=decay_steps,
        decay_style=args.lr_decay_style,
        use_checkpoint_lr_scheduler=args.use_checkpoint_lr_scheduler,
        override_lr_scheduler=args.override_lr_scheduler)

    return lr_scheduler


365
def setup_model_and_optimizer(model_provider_func, model_type):
    """Set up the model, optimizer, and learning-rate scheduler.

    Builds the model with ``get_model``. The optimizer is created over the
    unwrapped modules (DDP / Float16Module wrappers removed). If
    ``args.load`` is set, a checkpoint is loaded and ``args.iteration`` is
    set from it; otherwise ``args.iteration`` is 0.

    Returns:
        ``(model, optimizer, lr_scheduler)``, where ``model`` is the list
        returned by ``get_model``.
    """
    args = get_args()

    model = get_model(model_provider_func, model_type)

    unwrapped_model = unwrap_model(model,
                                   (torchDDP, LocalDDP, Float16Module))
    optimizer = get_megatron_optimizer(unwrapped_model)

    lr_scheduler = get_learning_rate_scheduler(optimizer)

    if args.load is not None:
        timers = get_timers()
        # Extra barrier is added to make sure all ranks report the
        # max time.
        torch.distributed.barrier()
        timers('load-checkpoint').start()
        args.iteration = load_checkpoint(model, optimizer, lr_scheduler)
        torch.distributed.barrier()
        timers('load-checkpoint').stop()
        timers.log(['load-checkpoint'])
    else:
        args.iteration = 0

    # We only support local DDP with multiple micro-batches.
    if len(model) > 1 or mpu.get_pipeline_model_parallel_world_size() > 1:
        assert args.DDP_impl == 'local'

    # Fresh start of a single-module model that knows how to initialize
    # itself from a pretrained BERT model (ICT): do that now.
    if args.iteration == 0 and len(unwrapped_model) == 1 \
        and hasattr(unwrapped_model[0], 'init_state_dict_from_bert'):
        print_rank_0("Initializing ICT from pretrained BERT model")
        unwrapped_model[0].init_state_dict_from_bert()
        # NOTE(review): only fp16 re-syncs the optimizer's copy of the
        # params here; confirm whether bf16 also needs
        # optimizer.reload_model_params().
        if args.fp16:
            optimizer.reload_model_params()

    return model, optimizer, lr_scheduler


410
411
412
413
414
415
416
def train_step(forward_step_func, data_iterator,
               model, optimizer, lr_scheduler):
    """Run a single training step: forward/backward, gradient reduction, update.

    Returns:
        ``(loss_reduced, skipped_iter, grad_norm, num_zeros_in_grad)``.
          * ``loss_reduced`` maps each loss key to its mean over the
            microbatches on the last pipeline stage, and is ``{}`` on other
            stages.
          * ``skipped_iter`` is 0 when the optimizer step succeeded and 1
            otherwise.
    """
    args = get_args()
    timers = get_timers()

    # Set grad to zero.
    if args.DDP_impl == 'local' and args.use_contiguous_buffers_in_local_ddp:
        for partition in model:
            partition.zero_grad_buffer()
    optimizer.zero_grad()

    forward_backward_func = get_forward_backward_func()
    losses_reduced = forward_backward_func(
        forward_step_func, data_iterator, model,
        optimizer, timers, forward_only=False)

    # Empty unused memory.
    if args.empty_unused_memory_level >= 1:
        torch.cuda.empty_cache()

    # All-reduce if needed.
    if args.DDP_impl == 'local':
        timers('backward-params-all-reduce').start()
        for model_module in model:
            model_module.allreduce_gradients()
        timers('backward-params-all-reduce').stop()

    # All-reduce word_embeddings' grad across first and last stages to ensure
    # that word_embeddings parameters stay in sync.
    # This should only run for models that support pipelined model parallelism
    # (BERT and GPT-2).
    timers('backward-embedding-all-reduce').start()
    if mpu.is_rank_in_embedding_group(ignore_virtual=True) and \
            mpu.get_pipeline_model_parallel_world_size() > 1:
        if mpu.is_pipeline_first_stage(ignore_virtual=True):
            unwrapped_model = model[0]
        elif mpu.is_pipeline_last_stage(ignore_virtual=True):
            unwrapped_model = model[-1]
        else:  # We do not support the interleaved schedule for T5 yet.
            unwrapped_model = model[0]
        unwrapped_model = unwrap_model(
            unwrapped_model, (torchDDP, LocalDDP, Float16Module))

        if unwrapped_model.share_word_embeddings:
            word_embeddings_weight = unwrapped_model.word_embeddings_weight()
            # Local DDP accumulates into `main_grad`; torch DDP uses `.grad`.
            if args.DDP_impl == 'local':
                grad = word_embeddings_weight.main_grad
            else:
                grad = word_embeddings_weight.grad
            torch.distributed.all_reduce(grad, group=mpu.get_embedding_group())

    # All-reduce position_embeddings grad across first (encoder) and split
    # (decoder) stages to ensure that position embeddings parameters stay in
    # sync. This should only run for T5 models with pipeline parallelism.
    if mpu.is_rank_in_position_embedding_group() and \
            mpu.get_pipeline_model_parallel_world_size() > 1 and \
            args.pipeline_model_parallel_split_rank is not None:
        unwrapped_model = model[0]
        unwrapped_model = unwrap_model(
            unwrapped_model, (torchDDP, LocalDDP, Float16Module))
        assert args.DDP_impl == 'local', \
            'T5 model is only supported with local DDP mode'
        grad = unwrapped_model.language_model.embedding.position_embeddings.weight.main_grad
        torch.distributed.all_reduce(grad, group=mpu.get_position_embedding_group())
    timers('backward-embedding-all-reduce').stop()

    # Update parameters.
    timers('optimizer').start()
    update_successful, grad_norm, num_zeros_in_grad = optimizer.step()
    timers('optimizer').stop()

    # Update learning rate. The scheduler is advanced by the number of
    # samples consumed this step.
    if update_successful:
        increment = get_num_microbatches() * \
                    args.micro_batch_size * \
                    args.data_parallel_size
        lr_scheduler.step(increment=increment)
        skipped_iter = 0
    else:
        skipped_iter = 1

    # Empty unused memory.
    if args.empty_unused_memory_level >= 2:
        torch.cuda.empty_cache()

    if mpu.is_pipeline_last_stage(ignore_virtual=True):
        # Average loss across microbatches.
        loss_reduced = {}
        for key in losses_reduced[0]:
            losses_reduced_for_key = [x[key] for x in losses_reduced]
            loss_reduced[key] = sum(losses_reduced_for_key) / len(losses_reduced_for_key)
        return loss_reduced, skipped_iter, grad_norm, num_zeros_in_grad
    return {}, skipped_iter, grad_norm, num_zeros_in_grad
504
505


Mohammad Shoeybi's avatar
Mohammad Shoeybi committed
506
def training_log(loss_dict, total_loss_dict, learning_rate, iteration,
                 loss_scale, report_memory_flag, skipped_iter,
                 grad_norm, params_norm, num_zeros_in_grad):
    """Log training information such as losses, timing, ....

    ``total_loss_dict`` is an accumulator that is updated in place.
      * On non-skipped iterations, each loss in ``loss_dict`` is added to
        it.
      * It also counts advanced, skipped, and nan/inf iterations.
      * Every ``args.log_interval`` iterations, the averages are printed and
        the accumulators are reset.

    Tensorboard scalars are written every ``args.tensorboard_log_interval``
    iterations, on the last rank only.

    Returns:
        The updated ``report_memory_flag``. It becomes False once memory
        has been reported.
    """
    args = get_args()
    timers = get_timers()
    writer = get_tensorboard_writer()

    # Advanced, skipped, and Nan iterations.
    advanced_iters_key = 'advanced iterations'
    skipped_iters_key = 'skipped iterations'
    nan_iters_key = 'nan iterations'
    # Advanced iterations.
    if not skipped_iter:
        total_loss_dict[advanced_iters_key] = total_loss_dict.get(
            advanced_iters_key, 0) + 1
    else:
        if advanced_iters_key not in total_loss_dict:
            total_loss_dict[advanced_iters_key] = 0
    # Skipped iterations.
    total_loss_dict[skipped_iters_key] = total_loss_dict.get(
        skipped_iters_key, 0) + skipped_iter
    # Update losses and set nan iterations.
    got_nan = False
    for key in loss_dict:
        if not skipped_iter:
            total_loss_dict[key] = total_loss_dict.get(
                key, torch.cuda.FloatTensor([0.0])) + loss_dict[key]
        else:
            value = loss_dict[key].float().sum().item()
            # inf, -inf and NaN are all non-finite.
            is_nan = not math.isfinite(value)
            got_nan = got_nan or is_nan
    total_loss_dict[nan_iters_key] = total_loss_dict.get(
        nan_iters_key, 0) + int(got_nan)

    # Logging.
    timers_to_log = []

    def add_to_logging(name):
        # Only log timers that were actually created during this run.
        if name in timers.timers:
            timers_to_log.append(name)
    add_to_logging('forward-compute')
    add_to_logging('forward-recv')
    add_to_logging('forward-send')
    add_to_logging('forward-backward-send-forward-backward-recv')
    add_to_logging('backward-compute')
    add_to_logging('backward-recv')
    add_to_logging('backward-send')
    add_to_logging('backward-send-forward-recv')
    add_to_logging('backward-send-backward-recv')
    add_to_logging('backward-params-all-reduce')
    add_to_logging('backward-embedding-all-reduce')
    add_to_logging('optimizer-copy-to-main-grad')
    add_to_logging('optimizer-unscale-and-check-inf')
    add_to_logging('optimizer-clip-main-grad')
    add_to_logging('optimizer-copy-main-to-model-params')
    add_to_logging('optimizer')
    add_to_logging('batch-generator')

    # Calculate batch size.
    batch_size = args.micro_batch_size * args.data_parallel_size * \
        get_num_microbatches()

    total_iterations = total_loss_dict[advanced_iters_key] + \
                       total_loss_dict[skipped_iters_key]

    # Tensorboard values.
    if writer and (iteration % args.tensorboard_log_interval == 0) and \
       is_last_rank():
        if args.log_learning_rate_to_tensorboard:
            writer.add_scalar('learning-rate', learning_rate, iteration)
            writer.add_scalar('learning-rate vs samples', learning_rate,
                              args.consumed_train_samples)
        if args.log_batch_size_to_tensorboard:
            writer.add_scalar('batch-size', batch_size, iteration)
            writer.add_scalar('batch-size vs samples', batch_size,
                              args.consumed_train_samples)
        for key in loss_dict:
            writer.add_scalar(key, loss_dict[key], iteration)
            writer.add_scalar(key + ' vs samples', loss_dict[key],
                              args.consumed_train_samples)
        if args.log_loss_scale_to_tensorboard:
            writer.add_scalar('loss-scale', loss_scale, iteration)
            writer.add_scalar('loss-scale vs samples', loss_scale,
                              args.consumed_train_samples)
        if args.log_world_size_to_tensorboard:
            writer.add_scalar('world-size', args.world_size, iteration)
            writer.add_scalar('world-size vs samples', args.world_size,
                              args.consumed_train_samples)
        if grad_norm is not None:
            writer.add_scalar('grad-norm', grad_norm, iteration)
            writer.add_scalar('grad-norm vs samples', grad_norm,
                              args.consumed_train_samples)
        if num_zeros_in_grad is not None:
            writer.add_scalar('num-zeros', num_zeros_in_grad, iteration)
            writer.add_scalar('num-zeros vs samples', num_zeros_in_grad,
                              args.consumed_train_samples)
        if params_norm is not None:
            writer.add_scalar('params-norm', params_norm, iteration)
            writer.add_scalar('params-norm vs samples', params_norm,
                              args.consumed_train_samples)
        if args.log_timers_to_tensorboard:
            timers.write(timers_to_log, writer, iteration,
                         normalizer=total_iterations)
        if args.log_memory_to_tensorboard:
            mem_stats = torch.cuda.memory_stats()
            writer.add_scalar(
                "mem-reserved-bytes",
                mem_stats["reserved_bytes.all.current"],
                iteration,
            )
            writer.add_scalar(
                "mem-allocated-bytes",
                mem_stats["allocated_bytes.all.current"],
                iteration,
            )
            writer.add_scalar(
                "mem-allocated-count",
                mem_stats["allocation.all.current"],
                iteration,
            )

    if iteration % args.log_interval == 0:
        elapsed_time = timers('interval-time').elapsed()
        elapsed_time_per_iteration = elapsed_time / total_iterations
        if writer:
            if args.log_timers_to_tensorboard:
                writer.add_scalar('iteration-time',
                                  elapsed_time_per_iteration, iteration)
        log_string = ' iteration {:8d}/{:8d} |'.format(
            iteration, args.train_iters)
        log_string += ' consumed samples: {:12d} |'.format(
            args.consumed_train_samples)
        log_string += ' elapsed time per iteration (ms): {:.1f} |'.format(
            elapsed_time_per_iteration * 1000.0)
        log_string += ' learning rate: {:.3E} |'.format(learning_rate)
        log_string += ' global batch size: {:5d} |'.format(batch_size)
        for key in total_loss_dict:
            if key not in [advanced_iters_key, skipped_iters_key,
                           nan_iters_key]:
                # Average only over iterations that actually contributed
                # a loss (non-skipped ones).
                avg = total_loss_dict[key].item() / \
                      float(max(1, total_loss_dict[advanced_iters_key]))
                if avg > 0.0:
                    log_string += ' {}: {:.6E} |'.format(key, avg)
                total_loss_dict[key] = torch.cuda.FloatTensor([0.0])
        log_string += ' loss scale: {:.1f} |'.format(loss_scale)
        if grad_norm is not None:
            log_string += ' grad norm: {:.3f} |'.format(grad_norm)
        if num_zeros_in_grad is not None:
            log_string += ' num zeros: {:.1f} |'.format(num_zeros_in_grad)
        if params_norm is not None:
            log_string += ' params norm: {:.3f} |'.format(params_norm)
        log_string += ' number of skipped iterations: {:3d} |'.format(
            total_loss_dict[skipped_iters_key])
        log_string += ' number of nan iterations: {:3d} |'.format(
            total_loss_dict[nan_iters_key])
        total_loss_dict[advanced_iters_key] = 0
        total_loss_dict[skipped_iters_key] = 0
        total_loss_dict[nan_iters_key] = 0
        print_rank_last(log_string)
        if report_memory_flag and learning_rate > 0.:
            # Report memory after optimizer state has been initialized.
            report_memory('(after {} iterations)'.format(iteration))
            report_memory_flag = False
        timers.log(timers_to_log, normalizer=args.log_interval)

    return report_memory_flag


677
678
679
680
681
def save_checkpoint_and_time(iteration, model, optimizer, lr_scheduler):
    """Save a checkpoint via ``save_checkpoint`` and log how long it took.

    The save is bracketed by two ``torch.distributed.barrier()`` calls. This
    lets all ranks start and stop the 'save-checkpoint' timer together, so
    the reported time is the max across ranks.
    """
    timer_name = 'save-checkpoint'
    timers = get_timers()

    # Synchronise before starting the clock so every rank measures the same
    # window.
    torch.distributed.barrier()
    timers(timer_name).start()

    save_checkpoint(iteration, model, optimizer, lr_scheduler)

    # Wait for the slowest rank to finish writing before stopping the clock.
    torch.distributed.barrier()
    timers(timer_name).stop()
    timers.log([timer_name])
687
688


689
def train(forward_step_func, model, optimizer, lr_scheduler,
          train_data_iterator, valid_data_iterator):
    """Run the main training loop and return the last iteration reached.

    Starting from ``args.iteration``, this calls ``train_step`` repeatedly
    until ``args.train_iters`` is reached. After every step it:

    - updates ``args.consumed_train_samples``;
    - logs progress through ``training_log``;
    - depending on ``args``, runs the ADLR autoresume check, validation,
      periodic checkpointing and the exit checks.

    The three exit checks are: a signal was received, the wall-clock budget
    ``args.exit_duration_in_mins`` was exceeded, or ``args.exit_interval``
    was reached. Each one saves a checkpoint unless one was already written
    this iteration, then calls ``sys.exit()``. In those cases this function
    never returns.

    Args:
        forward_step_func: forwarded to ``train_step`` and to
            ``evaluate_and_print_results``.
        model: iterable of model modules; each is switched to train mode.
        optimizer: optimizer stepped by ``train_step``; its loss scale and
            first param group's learning rate are logged.
        lr_scheduler: scheduler passed to ``train_step`` and to the
            checkpointing helpers.
        train_data_iterator: iterator consumed by ``train_step``.
        valid_data_iterator: iterator used for periodic validation.

    Returns:
        The iteration count once ``args.train_iters`` has been reached.
    """
    args = get_args()
    timers = get_timers()

    write_args_to_tensorboard()

    # Put every module into training mode (enables dropout).
    for module in model:
        module.train()

    total_loss_dict = {}          # accumulated losses; reset by training_log
    iteration = args.iteration    # starting iteration, taken from args

    timers('interval-time').start()
    print_datetime('before the start of training step')
    report_memory_flag = True

    while iteration < args.train_iters:
        update_num_microbatches(args.consumed_train_samples)
        step_outputs = train_step(forward_step_func, train_data_iterator,
                                  model, optimizer, lr_scheduler)
        loss_dict, skipped_iter, grad_norm, num_zeros_in_grad = step_outputs
        iteration += 1

        # Samples processed by this step across all data-parallel replicas.
        samples_this_step = (mpu.get_data_parallel_world_size() *
                             args.micro_batch_size *
                             get_num_microbatches())
        args.consumed_train_samples += samples_this_step

        # ---- Logging ----
        loss_scale = optimizer.get_loss_scale().item()
        params_norm = (calc_params_l2_norm(model)
                       if args.log_params_norm else None)
        report_memory_flag = training_log(
            loss_dict, total_loss_dict, optimizer.param_groups[0]['lr'],
            iteration, loss_scale, report_memory_flag, skipped_iter,
            grad_norm, params_norm, num_zeros_in_grad)

        # ---- ADLR autoresume ----
        if args.adlr_autoresume and \
                iteration % args.adlr_autoresume_interval == 0:
            check_adlr_autoresume_termination(iteration, model, optimizer,
                                              lr_scheduler)

        # ---- Validation ----
        run_validation = (args.eval_interval and
                          iteration % args.eval_interval == 0 and
                          args.do_valid)
        if run_validation:
            evaluate_and_print_results('iteration {}'.format(iteration),
                                       forward_step_func, valid_data_iterator,
                                       model, iteration, False)

        # ---- Exit on a received signal ----
        if args.exit_signal_handler and \
                any(get_signal_handler().signals_received()):
            save_checkpoint_and_time(iteration, model, optimizer,
                                     lr_scheduler)
            print_datetime('exiting program after receiving SIGTERM.')
            sys.exit()

        # ---- Periodic checkpoint ----
        # Remember whether a checkpoint was written this iteration so that the
        # exit paths below do not save a second time.
        saved_checkpoint = bool(args.save and args.save_interval and
                                iteration % args.save_interval == 0)
        if saved_checkpoint:
            save_checkpoint_and_time(iteration, model, optimizer,
                                     lr_scheduler)

        # ---- Exit after a wall-clock budget ----
        if args.exit_duration_in_mins:
            elapsed_mins = (time.time() - _TRAIN_START_TIME) / 60.0
            # MAX all-reduce: if any rank is over budget, every rank exits.
            over_budget = torch.cuda.IntTensor(
                [elapsed_mins > args.exit_duration_in_mins])
            torch.distributed.all_reduce(
                over_budget, op=torch.distributed.ReduceOp.MAX)
            if over_budget.item():
                if not saved_checkpoint:
                    save_checkpoint_and_time(iteration, model, optimizer,
                                             lr_scheduler)
                print_datetime(
                    'exiting program after {} minutes'.format(elapsed_mins))
                sys.exit()

        # ---- Exit after a fixed number of iterations ----
        if args.exit_interval and iteration % args.exit_interval == 0:
            if not saved_checkpoint:
                save_checkpoint_and_time(iteration, model, optimizer,
                                         lr_scheduler)
            torch.distributed.barrier()
            print_datetime('exiting program at iteration {}'.format(iteration))
            sys.exit()

    return iteration
791
792


Mohammad's avatar
Mohammad committed
793
def evaluate(forward_step_func, data_iterator, model, verbose=False):
    """Run ``args.eval_iters`` forward-only steps and average the losses.

    The modules in ``model`` are put in eval mode for the duration and set
    back to train mode afterwards. Each step advances
    ``args.consumed_valid_samples``.

    Args:
        forward_step_func: passed to the pipeline schedule returned by
            ``get_forward_backward_func()``.
        data_iterator: validation data iterator consumed by that schedule.
        model: iterable of model modules.
        verbose: if true, print progress every ``args.log_interval`` steps.

    Returns:
        Dict mapping each loss key to its summed value divided by
        ``args.eval_iters * get_num_microbatches()``. Losses are only
        accumulated where ``mpu.is_pipeline_last_stage(ignore_virtual=True)``
        holds; on other ranks the dict is empty.
    """
    args = get_args()

    # Eval mode disables dropout while we measure the loss.
    for module in model:
        module.eval()

    total_loss_dict = {}

    with torch.no_grad():
        for eval_step in range(1, args.eval_iters + 1):
            if verbose and eval_step % args.log_interval == 0:
                print_rank_0('Evaluating iter {}/{}'.format(eval_step,
                                                            args.eval_iters))

            forward_backward_func = get_forward_backward_func()
            loss_dicts = forward_backward_func(
                forward_step_func, data_iterator, model, optimizer=None,
                timers=None, forward_only=True)

            # Optionally release cached, unused GPU memory between steps.
            if args.empty_unused_memory_level >= 1:
                torch.cuda.empty_cache()

            # Only the last pipeline stage produces losses to accumulate.
            if mpu.is_pipeline_last_stage(ignore_virtual=True):
                for loss_dict in loss_dicts:
                    for key, value in loss_dict.items():
                        running = total_loss_dict.get(
                            key, torch.cuda.FloatTensor([0.0]))
                        total_loss_dict[key] = running + value

            args.consumed_valid_samples += (mpu.get_data_parallel_world_size()
                                            * args.micro_batch_size
                                            * get_num_microbatches())

    # Restore training mode for the caller.
    for module in model:
        module.train()

    # Turn the sums into per-microbatch averages.
    for key in total_loss_dict:
        total_loss_dict[key] /= args.eval_iters * get_num_microbatches()

    return total_loss_dict

def evaluate_and_print_results(prefix, forward_step_func,
                               data_iterator, model,
                               iteration, verbose=False):
    """Evaluate ``model`` and report each validation loss and its perplexity.

    Perplexity is ``exp(loss)`` with the exponent capped at 20. The summary
    line is printed on the last rank, framed by dashed lines. When a
    tensorboard writer is available, each value is also logged against
    ``iteration`` and against ``args.consumed_train_samples``. Perplexity is
    only logged to tensorboard when ``args.log_validation_ppl_to_tensorboard``
    is set.
    """
    args = get_args()
    writer = get_tensorboard_writer()

    total_loss_dict = evaluate(forward_step_func, data_iterator, model, verbose)

    pieces = [' validation loss at {} | '.format(prefix)]
    for key, loss_tensor in total_loss_dict.items():
        loss_value = loss_tensor.item()
        # Cap the exponent so a large loss cannot overflow math.exp.
        ppl = math.exp(min(20, loss_value))
        pieces.append('{} value: {:.6E} | '.format(key, loss_value))
        pieces.append('{} PPL: {:.6E} | '.format(key, ppl))

        if not writer:
            continue
        writer.add_scalar('{} validation'.format(key), loss_value, iteration)
        writer.add_scalar('{} validation vs samples'.format(key),
                          loss_value, args.consumed_train_samples)
        if args.log_validation_ppl_to_tensorboard:
            writer.add_scalar('{} validation ppl'.format(key), ppl, iteration)
            writer.add_scalar('{} validation ppl vs samples'.format(key),
                              ppl, args.consumed_train_samples)

    summary = ''.join(pieces)
    rule = '-' * (len(summary) + 1)
    print_rank_last(rule)
    print_rank_last(summary)
    print_rank_last(rule)
869
870


Vijay Korthikanti's avatar
Vijay Korthikanti committed
871
def cyclic_iter(iter):
    """Yield the items of ``iter`` endlessly, re-iterating it after each pass.

    Unlike ``itertools.cycle``, nothing is cached. Each pass starts a fresh
    ``for`` loop over ``iter``, so a re-iterable source that produces a new
    order per pass keeps doing so.

    If a full pass yields no items, the generator stops instead of spinning
    forever. This covers an empty iterable, or a one-shot iterator that is
    already exhausted. Without the check, the caller would hang inside
    ``next()`` with the CPU pegged.

    The parameter name shadows the ``iter`` builtin. It is kept for
    backward compatibility with keyword callers.
    """
    while True:
        yielded_any = False
        for x in iter:
            yielded_any = True
            yield x
        if not yielded_any:
            return

876
877
878
def build_train_valid_test_data_iterators(
        build_train_valid_test_datasets_provider):
    """Build the train, validation and test data iterators.

    Datasets and data loaders are only built on tensor-model-parallel rank 0.
    That rank decides whether training, validation and testing should run.
    It broadcasts those three flags to the rest of its tensor-model-parallel
    group, and every rank stores them as ``args.do_train``,
    ``args.do_valid`` and ``args.do_test``.

    For backward compatibility, if ``args.iteration > 0`` but the consumed
    sample counters are still 0, the counters are reconstructed assuming a
    fixed ``args.global_batch_size``.

    Args:
        build_train_valid_test_datasets_provider: callable that takes a list
            ``[train_samples, valid_samples, test_samples]`` of target dataset
            sizes and returns ``(train_ds, valid_ds, test_ds)``.

    Returns:
        ``(train_data_iterator, valid_data_iterator, test_data_iterator)``.
        An entry is ``None`` when its data loader is ``None``. This is always
        the case on tensor-model-parallel ranks other than 0, because they
        never build loaders. With ``args.dataloader_type == 'cyclic'``, the
        iterators are ``cyclic_iter`` generators that restart the loader each
        time it is exhausted.
    """
    args = get_args()

    (train_dataloader, valid_dataloader, test_dataloader) = (None, None, None)

    print_rank_0('> building train, validation, and test datasets ...')

    # Backward compatibility, assume fixed batch size.
    if args.iteration > 0 and args.consumed_train_samples == 0:
        assert args.train_samples is None, \
            'only backward compatiblity support for iteration-based training'
        args.consumed_train_samples = args.iteration * args.global_batch_size
    if args.iteration > 0 and args.consumed_valid_samples == 0:
        if args.train_samples is None:
            args.consumed_valid_samples = (args.iteration // args.eval_interval) * \
                args.eval_iters * args.global_batch_size

    # Data loader only on rank 0 of each model parallel group.
    if mpu.get_tensor_model_parallel_rank() == 0:

        # Number of train/valid/test samples.
        if args.train_samples:
            train_samples = args.train_samples
        else:
            train_samples = args.train_iters * args.global_batch_size
        eval_iters = (args.train_iters // args.eval_interval + 1) * \
                     args.eval_iters
        test_iters = args.eval_iters
        train_val_test_num_samples = [train_samples,
                                      eval_iters * args.global_batch_size,
                                      test_iters * args.global_batch_size]
        print_rank_0(' > datasets target sizes (minimum size):')
        print_rank_0('    train:      {}'.format(train_val_test_num_samples[0]))
        print_rank_0('    validation: {}'.format(train_val_test_num_samples[1]))
        print_rank_0('    test:       {}'.format(train_val_test_num_samples[2]))

        # Build the datasets.
        train_ds, valid_ds, test_ds = build_train_valid_test_datasets_provider(
            train_val_test_num_samples)

        # Build dataloaders, resuming train/valid from the consumed counts.
        train_dataloader = build_pretraining_data_loader(
            train_ds, args.consumed_train_samples)
        valid_dataloader = build_pretraining_data_loader(
            valid_ds, args.consumed_valid_samples)
        test_dataloader = build_pretraining_data_loader(test_ds, 0)

        # Flags to know if we need to do training/validation/testing.
        do_train = train_dataloader is not None and args.train_iters > 0
        do_valid = valid_dataloader is not None and args.eval_iters > 0
        do_test = test_dataloader is not None and args.eval_iters > 0
        # Pack the flags into a tensor so they can be broadcast.
        flags = torch.cuda.LongTensor(
            [int(do_train), int(do_valid), int(do_test)])
    else:
        # Placeholder that the broadcast below overwrites.
        flags = torch.cuda.LongTensor([0, 0, 0])

    # Share the do_train/do_valid/do_test flags from the group's source rank
    # with every rank in the tensor-model-parallel group.
    torch.distributed.broadcast(flags,
                                mpu.get_tensor_model_parallel_src_rank(),
                                group=mpu.get_tensor_model_parallel_group())
    args.do_train = flags[0].item()
    args.do_valid = flags[1].item()
    args.do_test = flags[2].item()

    # Build iterators.
    dl_type = args.dataloader_type
    assert dl_type in ['single', 'cyclic']

    def _make_iterator(dataloader):
        """Wrap one data loader in an iterator according to ``dl_type``."""
        if dataloader is None:
            return None
        if dl_type == 'single':
            return iter(dataloader)
        # A generator is already its own iterator; no extra iter() needed.
        return cyclic_iter(dataloader)

    train_data_iterator = _make_iterator(train_dataloader)
    valid_data_iterator = _make_iterator(valid_dataloader)
    test_data_iterator = _make_iterator(test_dataloader)

    return train_data_iterator, valid_data_iterator, test_data_iterator