training.py 38.2 KB
Newer Older
1
# coding=utf-8
Mohammad's avatar
Mohammad committed
2
# Copyright (c) 2020, NVIDIA CORPORATION.  All rights reserved.
3
4
5
6
7
8
9
10
11
12
13
14
15
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

Mohammad's avatar
Mohammad committed
16
"""Pretrain utilities."""
17
18
19

from datetime import datetime
import math
Mohammad's avatar
Mohammad committed
20
import sys
21
22
23
24
import time
# The earliest we can measure the start time.
_TRAIN_START_TIME = time.time()

25
26
27
import torch
from torch.nn.parallel.distributed import DistributedDataParallel as torchDDP

Neel Kant's avatar
Neel Kant committed
28
from megatron import get_args
Mohammad's avatar
Mohammad committed
29
30
from megatron import get_timers
from megatron import get_tensorboard_writer
31
from megatron import get_current_global_batch_size
mohammad's avatar
mohammad committed
32
from megatron import get_num_microbatches
mohammad's avatar
mohammad committed
33
from megatron import is_last_rank
mohammad's avatar
mohammad committed
34
from megatron import update_num_microbatches
35
from megatron import mpu
Neel Kant's avatar
Neel Kant committed
36
from megatron import print_rank_0
37
from megatron import print_rank_last
Mohammad's avatar
Mohammad committed
38
39
from megatron.checkpointing import load_checkpoint
from megatron.checkpointing import save_checkpoint
Mohammad Shoeybi's avatar
Mohammad Shoeybi committed
40
from megatron.model import Float16Module
mohammad's avatar
mohammad committed
41
from megatron.optimizer import get_megatron_optimizer
Mohammad's avatar
Mohammad committed
42
from megatron.initialize import initialize_megatron
43
from megatron.initialize import write_args_to_tensorboard
44
45
46
from megatron.learning_rates import AnnealingLR
from megatron.model import DistributedDataParallel as LocalDDP
from megatron.utils import check_adlr_autoresume_termination
47
from megatron.utils import unwrap_model
Vijay Korthikanti's avatar
Vijay Korthikanti committed
48
from megatron.data.data_samplers import build_pretraining_data_loader
mohammad's avatar
mohammad committed
49
from megatron.utils import calc_params_l2_norm
50
from megatron.schedules import forward_backward_no_pipelining
51
from megatron.schedules import forward_backward_pipelining_without_interleaving
52
from megatron.schedules import forward_backward_pipelining_with_interleaving
53
from megatron.utils import report_memory
54
55


56
57
58
59
60
61
62
def print_datetime(string):
    """Print `string` tagged with the current local date and time (rank 0).

    Every rank must call this: it starts with torch.distributed.barrier(),
    so all ranks are synchronized before the message is printed.
    """
    torch.distributed.barrier()
    stamp = datetime.now().strftime('%Y-%m-%d %H:%M:%S')
    message = '[' + string + '] datetime: {} '.format(stamp)
    print_rank_0(message)


63
def pretrain(train_valid_test_dataset_provider,
             model_provider,
             forward_step_func,
             extra_args_provider=None,
             args_defaults=None):
    """Main training program.

    This function runs the following steps in order:
        1) initialize Megatron.
        2) set up the model, optimizer and learning-rate schedule using
           `model_provider`.
        3) call `train_valid_test_dataset_provider` to get the
           train/valid/test datasets and build data iterators over them.
        4) train the model using `forward_step_func`.

    Arguments:
        train_valid_test_dataset_provider: a function that takes the size of
            train/valid/test dataset and returns `train, valid, test` datasets.
        model_provider: a function that returns a vanilla version of the
            model. By vanilla we mean a simple model on cpu with no fp16 or ddp.
        forward_step_func: a function that takes a `data iterator` and `model`,
            and returns a `loss` scalar with a dictionary with key:values being
            the info we would like to monitor during training, for example
            `lm-loss: value`. We also require that this function add
            `batch generator` to the timers class.
        extra_args_provider: a function that takes a parser and adds arguments
            to it. It is used for programs to add their own arguments.
        args_defaults: a dictionary from argument-name to argument-value,
            forwarded to `initialize_megatron` to set values for arguments
            that are already parsed. `None` (the default) means an empty dict.
    """
    # A literal `{}` default would be one dict shared by every call; build a
    # fresh one instead.
    if args_defaults is None:
        args_defaults = {}

    # Initialize and get arguments, timers, and Tensorboard writer.
    initialize_megatron(extra_args_provider=extra_args_provider,
                        args_defaults=args_defaults)

    # Use the earliest start time across all ranks (ReduceOp.MIN), so the
    # reported initialization time is the largest one. This is closer to
    # what the job scheduler sees.
    global _TRAIN_START_TIME
    start_time_tensor = torch.cuda.FloatTensor([_TRAIN_START_TIME])
    torch.distributed.all_reduce(start_time_tensor,
                                 op=torch.distributed.ReduceOp.MIN)
    _TRAIN_START_TIME = start_time_tensor.item()
    print_rank_0('time to initialize megatron (seconds): {:.3f}'.format(
        time.time() - _TRAIN_START_TIME))
    print_datetime('after megatron is initialized')

    args = get_args()
    timers = get_timers()

    # Model, optimizer, and learning rate.
    timers('model-and-optimizer-setup').start()
    model, optimizer, lr_scheduler = setup_model_and_optimizer(model_provider)
    timers('model-and-optimizer-setup').stop()
    print_datetime('after model, optimizer, and learning rate '
                   'scheduler are built')

    # Data iterators.
    timers('train/valid/test-data-iterators-setup').start()
    if args.virtual_pipeline_model_parallel_size is not None:
        # One (train, valid, test) iterator triple per model chunk.
        all_data_iterators = [
            build_train_valid_test_data_iterators(train_valid_test_dataset_provider)
            for _ in range(len(model))
        ]
        train_data_iterator = [data_iterators[0] for data_iterators in all_data_iterators]
        valid_data_iterator = [data_iterators[1] for data_iterators in all_data_iterators]
        test_data_iterator = [data_iterators[2] for data_iterators in all_data_iterators]
    else:
        train_data_iterator, valid_data_iterator, test_data_iterator \
            = build_train_valid_test_data_iterators(
                train_valid_test_dataset_provider)
    timers('train/valid/test-data-iterators-setup').stop()
    print_datetime('after dataloaders are built')

    # Print setup timing.
    print_rank_0('done with setup ...')
    timers.log(['model-and-optimizer-setup', 'train/valid/test-data-iterators-setup'])
    print_rank_0('training ...')

    iteration = 0
    if not args.run_dialog:
        # Original pre-training for GPT.
        if args.do_train and args.train_iters > 0:
            iteration = train(forward_step_func,
                              model, optimizer, lr_scheduler,
                              train_data_iterator, valid_data_iterator)
        print_datetime('after training is done')

        if args.do_valid:
            prefix = 'the end of training for val data'
            evaluate_and_print_results(prefix, forward_step_func,
                                       valid_data_iterator, model,
                                       iteration, False)

        if args.save and iteration != 0:
            save_checkpoint(iteration, model, optimizer, lr_scheduler)

        if args.do_test:
            # Run on test data.
            prefix = 'the end of training for test data'
            evaluate_and_print_results(prefix, forward_step_func,
                                       test_data_iterator, model,
                                       0, True)

    else:
        # Training for the dialog/control model: train/validate/test once per
        # epoch for args.num_epoch epochs, accumulating the iteration count
        # returned by train().
        # 'interval-time' is started once here because train() only starts it
        # itself when run_dialog is off; starting it per epoch would restart it.
        timers('interval-time').start()
        for e in range(args.num_epoch):
            print_rank_0('> training on epoch %d' % (e+1))

            if args.do_train and args.train_iters > 0:
                iteration += train(forward_step_func,
                                   model, optimizer, lr_scheduler,
                                   train_data_iterator, valid_data_iterator)
            print_datetime('after training is done')

            if args.do_valid:
                prefix = 'the end of training for val data'
                evaluate_and_print_results(prefix, forward_step_func,
                                           valid_data_iterator, model,
                                           iteration, False)

            # Only the control module saves checkpoints here, and only after
            # epochs 5 through 9 (1-based).
            if args.train_module == "control":
                if 5 <= (e + 1) <= 9 and args.save and iteration != 0:
                    save_checkpoint(iteration, model, optimizer, lr_scheduler)

            if args.do_test:
                # Run on test data.
                prefix = 'the end of training for test data'
                evaluate_and_print_results(prefix, forward_step_func,
                                           test_data_iterator, model,
                                           0, True)
196

197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
def update_train_iters(args):
    """Derive `args.train_iters` for sample-based training.

    Returns immediately when `args.train_iters` is already truthy
    (iteration-based training). Otherwise:
      * without batch-size ramp-up, train_iters is
        train_samples // global_batch_size;
      * with ramp-up, the microbatch calculator is stepped one global batch
        at a time until more than `rampup_batch_size[2]` samples have been
        consumed. The calculator is then reset to 0 consumed samples, and
        the iterations for the remaining samples at the full global batch
        size are added. A trailing partial batch is dropped.
    """
    if args.train_iters:
        return

    if args.rampup_batch_size is not None:
        steps, seen = 0, 0
        # Ramp-up phase: simulate the growing global batch size.
        while seen <= int(args.rampup_batch_size[2]):
            update_num_microbatches(seen, consistency_check=False)
            seen += get_current_global_batch_size()
            steps += 1
        # Put the microbatch calculator back at the start of training.
        update_num_microbatches(0, consistency_check=False)
        # Constant phase; any partial last batch is discarded.
        leftover = args.train_samples - seen
        args.train_iters = steps + leftover // args.global_batch_size
    else:
        args.train_iters = args.train_samples // args.global_batch_size

    print_rank_0('setting training iterations to {}'.format(args.train_iters))

226

Mohammad's avatar
Mohammad committed
227
def get_model(model_provider_func):
    """Build the model chunk(s) for this rank and wrap them for training.

    With pipeline parallelism > 1 and a virtual pipeline size set, one chunk
    is built per virtual rank. Otherwise a single model is built. Every
    chunk is then:
      1. given default tensor-model-parallel attributes on its params,
      2. moved to the current CUDA device,
      3. wrapped in Float16Module when fp16 or bf16 is enabled,
      4. wrapped in torchDDP or the local DistributedDataParallel.

    Returns:
        list of wrapped model modules.

    Raises:
        NotImplementedError: if `args.DDP_impl` is neither 'torch' nor 'local'.
    """
    args = get_args()

    # Build model.
    use_virtual_pipeline = (
        mpu.get_pipeline_model_parallel_world_size() > 1
        and args.virtual_pipeline_model_parallel_size is not None)
    if use_virtual_pipeline:
        model = []
        for virtual_rank in range(args.virtual_pipeline_model_parallel_size):
            mpu.set_virtual_pipeline_model_parallel_rank(virtual_rank)
            # The first/last-stage checks depend on the virtual rank just set.
            model.append(model_provider_func(
                pre_process=mpu.is_pipeline_first_stage(),
                post_process=mpu.is_pipeline_last_stage()))
    else:
        model = model_provider_func(
            pre_process=mpu.is_pipeline_first_stage(),
            post_process=mpu.is_pipeline_last_stage())

    if not isinstance(model, list):
        model = [model]

    # Only parameters that are already tensor-model-parallel carry these
    # attributes; give every parameter the defaults so the optimizer can
    # rely on them.
    for chunk in model:
        for param in chunk.parameters():
            mpu.set_defaults_if_not_set_tensor_model_parallel_attributes(param)

    # Print number of parameters.
    if mpu.get_data_parallel_rank() == 0:
        tp_rank = mpu.get_tensor_model_parallel_rank()
        pp_rank = mpu.get_pipeline_model_parallel_rank()
        num_params = sum(sum(p.nelement() for p in chunk.parameters())
                         for chunk in model)
        print(' > number of parameters on (tensor, pipeline) '
              'model parallel rank ({}, {}): {}'.format(
                  tp_rank, pp_rank, num_params), flush=True)

    # GPU allocation.
    for chunk in model:
        chunk.cuda(torch.cuda.current_device())

    # Fp16 / bf16 conversion.
    if args.fp16 or args.bf16:
        model = [Float16Module(chunk, args) for chunk in model]

    if args.DDP_impl == 'torch':
        device = torch.cuda.current_device()
        return [torchDDP(chunk, device_ids=[device], output_device=device,
                         process_group=mpu.get_data_parallel_group())
                for chunk in model]

    if args.DDP_impl == 'local':
        return [LocalDDP(chunk,
                         args.accumulate_allreduce_grads_in_fp32,
                         args.use_contiguous_buffers_in_ddp)
                for chunk in model]

    raise NotImplementedError('Unknown DDP implementation specified: {}. '
                              'Exiting.'.format(args.DDP_impl))
297
298


Mohammad's avatar
Mohammad committed
299
def get_learning_rate_scheduler(optimizer):
    """Build the AnnealingLR scheduler that drives `optimizer`.

    Warmup and decay lengths are expressed in samples.

    * Iteration-based training (`args.train_iters` truthy):
      - decay = lr_decay_iters * global_batch_size. lr_decay_iters defaults
        to train_iters, and that default is written back to args.
      - warmup = lr_warmup_fraction * decay, or else
        lr_warmup_iters * global_batch_size.
    * Sample-based training (`args.train_samples` truthy):
      - first derives args.train_iters via update_train_iters.
      - decay = lr_decay_samples. It defaults to train_samples, and that
        default is written back to args.
      - warmup = lr_warmup_fraction * decay, or else lr_warmup_samples.

    Raises:
        Exception: if neither train_iters nor train_samples is provided.
    """
    args = get_args()

    if args.train_iters:
        # Iteration-based training.
        if args.lr_decay_iters is None:
            args.lr_decay_iters = args.train_iters
        decay_steps = args.lr_decay_iters * args.global_batch_size
        if args.lr_warmup_fraction is None:
            warmup_steps = args.lr_warmup_iters * args.global_batch_size
        else:
            warmup_steps = args.lr_warmup_fraction * decay_steps
    elif args.train_samples:
        # Sample-based training. train_iters is still needed later, so derive
        # it here; the trailing incomplete batch is not subtracted from
        # train_samples.
        update_train_iters(args)
        if args.lr_decay_samples is None:
            args.lr_decay_samples = args.train_samples
        decay_steps = args.lr_decay_samples
        if args.lr_warmup_fraction is None:
            warmup_steps = args.lr_warmup_samples
        else:
            warmup_steps = args.lr_warmup_fraction * decay_steps
    else:
        raise Exception(
            'either train-iters or train-samples should be provided.')

    return AnnealingLR(
        optimizer,
        max_lr=args.lr,
        min_lr=args.min_lr,
        warmup_steps=warmup_steps,
        decay_steps=decay_steps,
        decay_style=args.lr_decay_style,
        use_checkpoint_lr_scheduler=args.use_checkpoint_lr_scheduler,
        override_lr_scheduler=args.override_lr_scheduler)


Mohammad's avatar
Mohammad committed
342
def setup_model_and_optimizer(model_provider_func):
    """Setup model and optimizer.

    Steps:
      1. Build the wrapped model via get_model.
      2. Create the optimizer over the unwrapped modules.
      3. Create the learning-rate scheduler.
      4. If `args.load` is set, load a checkpoint into all three; this sets
         args.iteration to the loaded iteration and clears
         args.train_samples. Otherwise set args.iteration = 0.

    Returns:
        (model, optimizer, lr_scheduler), where model is the list of wrapped
        model modules returned by get_model.
    """
    args = get_args()

    model = get_model(model_provider_func)

    # Optimizer works on the modules without the DDP / Float16Module wrappers.
    unwrapped_model = unwrap_model(model,
                                   (torchDDP, LocalDDP, Float16Module))
    optimizer = get_megatron_optimizer(unwrapped_model)

    lr_scheduler = get_learning_rate_scheduler(optimizer)

    if args.load is not None:
        timers = get_timers()
        # Extra barrier is added to make sure all ranks report the
        # max time.
        torch.distributed.barrier()
        timers('load-checkpoint').start()
        args.iteration = load_checkpoint(model, optimizer, lr_scheduler)
        # Clear train_samples after loading.
        # NOTE(review): presumably so later code relies on train_iters rather
        # than sample-based accounting; confirm the intent.
        args.train_samples = None
        torch.distributed.barrier()
        timers('load-checkpoint').stop()
        timers.log(['load-checkpoint'])
    else:
        args.iteration = 0

    # We only support local DDP with multiple micro-batches.
    if len(model) > 1 or mpu.get_pipeline_model_parallel_world_size() > 1:
        assert args.DDP_impl == 'local'

    # ICT-style models can initialize themselves from a pretrained BERT when
    # training starts from scratch.
    if args.iteration == 0 and len(unwrapped_model) == 1 \
            and hasattr(unwrapped_model[0], 'init_state_dict_from_bert'):
        print_rank_0("Initializing ICT from pretrained BERT model")
        unwrapped_model[0].init_state_dict_from_bert()
        # get_model wraps the model in Float16Module for both fp16 and bf16.
        # The optimizer presumably keeps separate main-param copies in both
        # cases, so refresh them after the weights change.
        if args.fp16 or args.bf16:
            optimizer.reload_model_params()

    return model, optimizer, lr_scheduler


384
385
386
387
388
389
390
def train_step(forward_step_func, data_iterator,
               model, optimizer, lr_scheduler):
    """Run a single training iteration.

    Steps:
      1. Zero the gradients.
      2. Run the forward/backward schedule matching the pipeline config.
      3. All-reduce gradients for local DDP, and the shared word-embedding
         gradient between the first and last pipeline stages.
      4. Step the optimizer, then step the LR scheduler if the update
         succeeded.

    Returns:
        (loss_reduced, skipped_iter, grad_norm, num_zeros_in_grad).
        - loss_reduced maps each loss key to its mean over microbatches on
          the last pipeline stage, and is {} on every other stage.
        - skipped_iter is 0 when optimizer.step() reported success, else 1.
    """
    args = get_args()
    timers = get_timers()

    # Set grad to zero: contiguous local-DDP buffers are cleared per module,
    # otherwise the optimizer clears the grads.
    if args.DDP_impl == 'local' and args.use_contiguous_buffers_in_ddp:
        for partition in model:
            partition.zero_grad_buffer()
    else:
        optimizer.zero_grad()

    # Choose the forward/backward schedule.
    if mpu.get_pipeline_model_parallel_world_size() <= 1:
        forward_backward_func = forward_backward_no_pipelining
    elif args.virtual_pipeline_model_parallel_size is None:
        forward_backward_func = forward_backward_pipelining_without_interleaving
    else:
        forward_backward_func = forward_backward_pipelining_with_interleaving
        assert get_num_microbatches() % args.pipeline_model_parallel_size == 0, \
            'number of microbatches is not divisible by pipeline-parallel ' \
            'size when using interleaved schedule'

    losses_reduced = forward_backward_func(
        forward_step_func, data_iterator, model,
        optimizer, timers, forward_only=False)

    # All-reduce if needed.
    if args.DDP_impl == 'local':
        timers('backward-params-all-reduce').start()
        for chunk in model:
            chunk.allreduce_gradients()
        timers('backward-params-all-reduce').stop()

    # Keep word_embeddings in sync by all-reducing their grad across the
    # first and last pipeline stages. This applies only to models that
    # support pipelined model parallelism (BERT and GPT-2).
    timers('backward-embedding-all-reduce').start()
    on_boundary_stage = (mpu.is_pipeline_first_stage(ignore_virtual=True) or
                         mpu.is_pipeline_last_stage(ignore_virtual=True))
    if on_boundary_stage and mpu.get_pipeline_model_parallel_world_size() > 1:
        if mpu.is_pipeline_first_stage(ignore_virtual=True):
            embedding_owner = model[0]
        else:
            embedding_owner = model[-1]
        embedding_owner = unwrap_model(
            embedding_owner, (torchDDP, LocalDDP, Float16Module))

        if embedding_owner.share_word_embeddings:
            weight = embedding_owner.word_embeddings_weight()
            grad = weight.main_grad if args.DDP_impl == 'local' else weight.grad
            torch.distributed.all_reduce(grad, group=mpu.get_embedding_group())
    timers('backward-embedding-all-reduce').stop()

    # Update parameters.
    timers('optimizer').start()
    update_successful, grad_norm, num_zeros_in_grad = optimizer.step()
    timers('optimizer').stop()

    # Advance the learning rate by the number of samples in this step, but
    # only when the optimizer actually applied the update.
    if update_successful:
        samples_this_step = get_num_microbatches() * \
                            args.micro_batch_size * \
                            args.data_parallel_size
        lr_scheduler.step(increment=samples_this_step)
        skipped_iter = 0
    else:
        skipped_iter = 1

    if not mpu.is_pipeline_last_stage(ignore_virtual=True):
        return {}, skipped_iter, grad_norm, num_zeros_in_grad

    # Average each loss across microbatches.
    loss_reduced = {
        key: sum(mb[key] for mb in losses_reduced) / len(losses_reduced)
        for key in losses_reduced[0]
    }
    return loss_reduced, skipped_iter, grad_norm, num_zeros_in_grad
465
466


Mohammad Shoeybi's avatar
Mohammad Shoeybi committed
467
def training_log(loss_dict, total_loss_dict, learning_rate, iteration,
                 loss_scale, report_memory_flag, skipped_iter,
                 grad_norm, params_norm, num_zeros_in_grad):
    """Log training information such as losses, timing, ....

    What it does:
      - Accumulates this iteration's losses and the advanced/skipped/nan
        iteration counters into `total_loss_dict`, which is mutated in place.
      - Every `args.tensorboard_log_interval` iterations, on the last rank,
        writes TensorBoard scalars.
      - Every `args.log_interval` iterations, prints a summary line and
        resets the accumulated values.

    Arguments:
        loss_dict: this iteration's loss tensors keyed by name.
        total_loss_dict: running totals over the current log interval; also
            holds the 'advanced/skipped/nan iterations' counters.
        learning_rate: current learning rate (logged as-is).
        iteration: current iteration number.
        loss_scale: current loss scale (logged as-is).
        report_memory_flag: if True, memory usage is reported at the next
            console log whose learning rate is positive, then cleared.
        skipped_iter: 1 if the optimizer skipped this update, else 0.
        grad_norm, params_norm, num_zeros_in_grad: optional statistics;
            `None` values are not logged.

    Returns:
        The (possibly cleared) report_memory_flag.
    """
    args = get_args()
    timers = get_timers()
    writer = get_tensorboard_writer()

    # Advanced, skipped, and nan iteration counters share total_loss_dict
    # with the accumulated losses.
    advanced_iters_key = 'advanced iterations'
    skipped_iters_key = 'skipped iterations'
    nan_iters_key = 'nan iterations'
    # Advanced iterations: +1 only for applied updates, but the key is always
    # created.
    total_loss_dict[advanced_iters_key] = total_loss_dict.get(
        advanced_iters_key, 0) + (0 if skipped_iter else 1)
    # Skipped iterations.
    total_loss_dict[skipped_iters_key] = total_loss_dict.get(
        skipped_iters_key, 0) + skipped_iter
    # Accumulate losses of applied updates. For skipped updates, only record
    # whether any loss was inf/nan.
    got_nan = False
    for key in loss_dict:
        if not skipped_iter:
            total_loss_dict[key] = total_loss_dict.get(
                key, torch.cuda.FloatTensor([0.0])) + loss_dict[key]
        else:
            value = loss_dict[key].float().sum().item()
            got_nan = got_nan or math.isinf(value) or math.isnan(value)
    total_loss_dict[nan_iters_key] = total_loss_dict.get(
        nan_iters_key, 0) + int(got_nan)

    # Logging: only timers that have actually been created are reported.
    candidate_timers = (
        'forward-compute',
        'forward-recv',
        'forward-send',
        'forward-backward-send-forward-backward-recv',
        'backward-compute',
        'backward-recv',
        'backward-send',
        'backward-send-forward-recv',
        'backward-send-backward-recv',
        'backward-params-all-reduce',
        'backward-embedding-all-reduce',
        'optimizer-copy-to-main-grad',
        'optimizer-unscale-and-check-inf',
        'optimizer-clip-main-grad',
        'optimizer-copy-main-to-model-params',
        'optimizer',
        'batch-generator',
    )
    timers_to_log = [name for name in candidate_timers
                     if name in timers.timers]

    # Calculate batch size.
    batch_size = args.micro_batch_size * args.data_parallel_size * \
        get_num_microbatches()

    total_iterations = total_loss_dict[advanced_iters_key] + \
                       total_loss_dict[skipped_iters_key]

    # Tensorboard values.
    if writer and (iteration % args.tensorboard_log_interval == 0) and \
       is_last_rank():
        def add_scalar_pair(tag, value):
            # Log each value against both iteration and consumed samples.
            writer.add_scalar(tag, value, iteration)
            writer.add_scalar(tag + ' vs samples', value,
                              args.consumed_train_samples)

        if args.log_learning_rate_to_tensorboard:
            add_scalar_pair('learning-rate', learning_rate)
        if args.log_batch_size_to_tensorboard:
            add_scalar_pair('batch-size', batch_size)
        for key in loss_dict:
            add_scalar_pair(key, loss_dict[key])
        if args.log_loss_scale_to_tensorboard:
            add_scalar_pair('loss-scale', loss_scale)
        if grad_norm is not None:
            add_scalar_pair('grad-norm', grad_norm)
        if num_zeros_in_grad is not None:
            add_scalar_pair('num-zeros', num_zeros_in_grad)
        if params_norm is not None:
            add_scalar_pair('params-norm', params_norm)
        if args.log_timers_to_tensorboard:
            timers.write(timers_to_log, writer, iteration,
                         normalizer=total_iterations)

    if iteration % args.log_interval == 0:
        elapsed_time = timers('interval-time').elapsed()
        elapsed_time_per_iteration = elapsed_time / total_iterations
        # Gate on the same rank as every other TensorBoard write above.
        # NOTE(review): the writer appears to live on the last rank (the block
        # above requires is_last_rank()); confirm in get_tensorboard_writer.
        if writer and is_last_rank():
            if args.log_timers_to_tensorboard:
                writer.add_scalar('iteration-time',
                                  elapsed_time_per_iteration, iteration)
        log_string = ' iteration {:8d}/{:8d} |'.format(
            iteration, args.train_iters)
        log_string += ' consumed samples: {:12d} |'.format(
            args.consumed_train_samples)
        log_string += ' elapsed time per iteration (ms): {:.1f} |'.format(
            elapsed_time_per_iteration * 1000.0)
        log_string += ' learning rate: {:.3E} |'.format(learning_rate)
        log_string += ' global batch size: {:5d} |'.format(batch_size)
        for key in total_loss_dict:
            if key not in [advanced_iters_key, skipped_iters_key,
                           nan_iters_key]:
                # Average over applied updates only (at least 1 to avoid /0).
                avg = total_loss_dict[key].item() / \
                      float(max(1, total_loss_dict[advanced_iters_key]))
                if avg > 0.0:
                    log_string += ' {}: {:.6E} |'.format(key, avg)
                total_loss_dict[key] = torch.cuda.FloatTensor([0.0])
        log_string += ' loss scale: {:.1f} |'.format(loss_scale)
        if grad_norm is not None:
            log_string += ' grad norm: {:.3f} |'.format(grad_norm)
        if num_zeros_in_grad is not None:
            log_string += ' num zeros: {:.1f} |'.format(num_zeros_in_grad)
        if params_norm is not None:
            log_string += ' params norm: {:.3f} |'.format(params_norm)
        log_string += ' number of skipped iterations: {:3d} |'.format(
            total_loss_dict[skipped_iters_key])
        log_string += ' number of nan iterations: {:3d} |'.format(
            total_loss_dict[nan_iters_key])
        total_loss_dict[advanced_iters_key] = 0
        total_loss_dict[skipped_iters_key] = 0
        total_loss_dict[nan_iters_key] = 0
        print_rank_last(log_string)
        if report_memory_flag and learning_rate > 0.:
            # Report memory after optimizer state has been initialized.
            report_memory('(after {} iterations)'.format(iteration))
            report_memory_flag = False
        timers.log(timers_to_log, normalizer=args.log_interval)

    return report_memory_flag


617
618
619
620
621
def save_checkpoint_and_time(iteration, model, optimizer, lr_scheduler):
    """Save a checkpoint and log the time the save took.

    Args:
        iteration: iteration number passed to ``save_checkpoint``.
        model: model (list of modules) passed to ``save_checkpoint``.
        optimizer: optimizer passed to ``save_checkpoint``.
        lr_scheduler: learning-rate scheduler passed to ``save_checkpoint``.
    """
    timers = get_timers()
    # Extra barrier is added to make sure
    # all ranks report the max time.
    torch.distributed.barrier()
    timers('save-checkpoint').start()
    save_checkpoint(iteration, model, optimizer, lr_scheduler)
    # Second barrier: stop the timer only once every rank has finished saving.
    torch.distributed.barrier()
    timers('save-checkpoint').stop()
    timers.log(['save-checkpoint'])
627
628


629
def train(forward_step_func, model, optimizer, lr_scheduler,
          train_data_iterator, valid_data_iterator):
    """Train the model function.

    Calls ``train_step`` repeatedly, from ``args.iteration`` until
    ``args.train_iters``. After every step it logs. At the configured
    intervals it also does:

    * ADLR autoresume checks;
    * validation;
    * checkpoint saves.

    The process may call ``sys.exit()`` from inside the loop, either when
    ``args.exit_duration_in_mins`` has elapsed or at a multiple of
    ``args.exit_interval``. A checkpoint is written first unless one was
    already saved on that iteration.

    Args:
        forward_step_func: forward-step callable passed to ``train_step`` and
            to ``evaluate_and_print_results``.
        model: iterable of model modules; each is put in train mode.
        optimizer: passed to ``train_step``. It also supplies the loss scale
            and learning rate that are logged.
        lr_scheduler: passed to ``train_step`` and to checkpoint saving.
        train_data_iterator: iterator of training batches for ``train_step``.
        valid_data_iterator: iterator of validation batches.

    Returns:
        The iteration count reached when the loop finishes.
    """
    args = get_args()
    timers = get_timers()

    # Write args to tensorboard
    write_args_to_tensorboard()

    # Turn on training mode which enables dropout.
    for model_module in model:
        model_module.train()

    # Tracking loss.
    total_loss_dict = {}

    # Iterations.
    iteration = args.iteration

    # The 'interval-time' timer is not started in dialog mode.
    if not args.run_dialog:
        timers('interval-time').start()

    print_datetime('before the start of training step')
    report_memory_flag = True
    while iteration < args.train_iters:
        update_num_microbatches(args.consumed_train_samples)
        loss_dict, skipped_iter, grad_norm, num_zeros_in_grad = \
            train_step(forward_step_func,
                       train_data_iterator,
                       model,
                       optimizer,
                       lr_scheduler)
        iteration += 1
        # Samples consumed by this step, summed over all data-parallel ranks.
        args.consumed_train_samples += mpu.get_data_parallel_world_size() * \
                                       args.micro_batch_size * \
                                       get_num_microbatches()

        # Logging.
        loss_scale = optimizer.get_loss_scale().item()
        params_norm = None
        if args.log_params_norm:
            params_norm = calc_params_l2_norm(model)
        report_memory_flag = training_log(loss_dict, total_loss_dict,
                                          optimizer.param_groups[0]['lr'],
                                          iteration, loss_scale,
                                          report_memory_flag, skipped_iter,
                                          grad_norm, params_norm, num_zeros_in_grad)

        # Autoresume
        if args.adlr_autoresume and \
           (iteration % args.adlr_autoresume_interval == 0):
            check_adlr_autoresume_termination(iteration, model, optimizer,
                                              lr_scheduler)

        # Evaluation
        if args.eval_interval and iteration % args.eval_interval == 0 and \
           args.do_valid:
            prefix = 'iteration {}'.format(iteration)
            evaluate_and_print_results(prefix, forward_step_func,
                                       valid_data_iterator, model,
                                       iteration, False)

        # Checkpointing
        saved_checkpoint = False
        if args.save and args.save_interval and \
           iteration % args.save_interval == 0:
            save_checkpoint_and_time(iteration, model, optimizer,
                                     lr_scheduler)
            saved_checkpoint = True

        # Exiting based on duration
        if args.exit_duration_in_mins:
            train_time = (time.time() - _TRAIN_START_TIME) / 60.0
            # MAX-reduce the per-rank decision so that all ranks exit
            # together as soon as any rank is over the time limit.
            done_cuda = torch.cuda.IntTensor(
                [train_time > args.exit_duration_in_mins])
            torch.distributed.all_reduce(
                done_cuda, op=torch.distributed.ReduceOp.MAX)
            done = done_cuda.item()
            if done:
                # Avoid writing the same checkpoint twice in one iteration.
                if not saved_checkpoint:
                    save_checkpoint_and_time(iteration, model, optimizer,
                                             lr_scheduler)
                print_datetime('exiting program after {} minutes'.format(train_time))
                sys.exit()

        # Exiting based on iterations
        if args.exit_interval and iteration % args.exit_interval == 0:
            if not saved_checkpoint:
                save_checkpoint_and_time(iteration, model, optimizer,
                                         lr_scheduler)
            torch.distributed.barrier()
            print_datetime('exiting program at iteration {}'.format(iteration))
            sys.exit()

    return iteration
725
726


Mohammad's avatar
Mohammad committed
727
def evaluate(forward_step_func, data_iterator, model, verbose=False):
    """Evaluation.

    Runs ``args.eval_iters`` forward-only passes under ``torch.no_grad()``,
    with every model module in eval mode. Afterwards it restores train mode.
    It also advances ``args.consumed_valid_samples``.

    Args:
        forward_step_func: forward-step callable handed to the selected
            forward/backward schedule.
        data_iterator: iterator of validation batches.
        model: iterable of model modules.
        verbose: if True, print progress every ``args.log_interval`` iters.

    Returns:
        A dict mapping each loss key to its sum divided by
        ``args.eval_iters * get_num_microbatches()``. Only the last pipeline
        stage accumulates losses, so on every other stage the dict is empty.
    """
    args = get_args()

    # Turn on evaluation mode which disables dropout.
    for model_module in model:
        model_module.eval()

    # The schedule depends only on the parallel layout, which does not
    # change during evaluation, so choose it once outside the loop.
    if mpu.get_pipeline_model_parallel_world_size() > 1:
        if args.virtual_pipeline_model_parallel_size is not None:
            forward_backward_func = forward_backward_pipelining_with_interleaving
        else:
            forward_backward_func = forward_backward_pipelining_without_interleaving
    else:
        forward_backward_func = forward_backward_no_pipelining

    total_loss_dict = {}

    with torch.no_grad():
        iteration = 0
        while iteration < args.eval_iters:
            iteration += 1
            if verbose and iteration % args.log_interval == 0:
                print_rank_0('Evaluating iter {}/{}'.format(iteration,
                                                            args.eval_iters))

            loss_dicts = forward_backward_func(
                forward_step_func, data_iterator, model, optimizer=None,
                timers=None, forward_only=True)

            if mpu.is_pipeline_last_stage(ignore_virtual=True):
                # Accumulate per-microbatch losses locally; only the last
                # pipeline stage does this.
                for loss_dict in loss_dicts:
                    for key in loss_dict:
                        total_loss_dict[key] = total_loss_dict.get(
                            key, torch.cuda.FloatTensor([0.0])) + loss_dict[key]

            args.consumed_valid_samples += mpu.get_data_parallel_world_size() \
                                           * args.micro_batch_size \
                                           * get_num_microbatches()
    # Move model back to the train mode.
    for model_module in model:
        model_module.train()

    # Average over every microbatch processed in this evaluation.
    for key in total_loss_dict:
        total_loss_dict[key] /= args.eval_iters * get_num_microbatches()

    return total_loss_dict

def evaluate_and_print_results(prefix, forward_step_func,
                               data_iterator, model,
                               iteration, verbose=False):
    """Helper function to evaluate and dump results on screen.

    Calls ``evaluate`` and prints one framed line containing the loss value
    and perplexity for every loss key. On the last rank, when a tensorboard
    writer exists, it also writes the validation scalars. These are indexed
    both by ``iteration`` and by ``args.consumed_train_samples``.

    Args:
        prefix: label placed in the printed line (e.g. ``'iteration 100'``).
        forward_step_func: forwarded to ``evaluate``.
        data_iterator: validation batch iterator, forwarded to ``evaluate``.
        model: iterable of model modules, forwarded to ``evaluate``.
        iteration: step used as the x-axis for tensorboard scalars.
        verbose: forwarded to ``evaluate``.
    """
    args = get_args()
    writer = get_tensorboard_writer()

    total_loss_dict = evaluate(forward_step_func, data_iterator, model, verbose)
    log_string = ' validation loss at {} | '.format(prefix)
    for key in total_loss_dict:
        # Fetch the scalar once; each .item() copies the value off the device.
        value = total_loss_dict[key].item()
        log_string += '{} value: {:.6E} | '.format(key, value)
        # Cap the exponent so math.exp cannot overflow on a diverged loss.
        ppl = math.exp(min(20, value))
        log_string += '{} PPL: {:.6E} | '.format(key, ppl)
        if writer and is_last_rank():
            writer.add_scalar('{} validation'.format(key),
                              value,
                              iteration)
            writer.add_scalar('{} validation vs samples'.format(key),
                              value,
                              args.consumed_train_samples)
            if args.log_validation_ppl_to_tensorboard:
                writer.add_scalar('{} validation ppl'.format(key), ppl,
                                  iteration)
                writer.add_scalar('{} validation ppl vs samples'.format(key),
                                  ppl, args.consumed_train_samples)

    length = len(log_string) + 1
    print_rank_last('-' * length)
    print_rank_last(log_string)
    print_rank_last('-' * length)
805
806


Vijay Korthikanti's avatar
Vijay Korthikanti committed
807
def cyclic_iter(iter):
    """Yield the items of ``iter`` endlessly, starting over after each pass.

    Each pass re-iterates ``iter`` with a fresh ``for`` loop, so ``iter``
    should be re-iterable (a list, a DataLoader, ...). A pass that yields
    nothing ends the generator instead of spinning forever without yielding.
    That covers an empty iterable, and a one-shot iterator whose items were
    used up on an earlier pass.
    """
    # NOTE: the parameter name shadows the ``iter`` builtin; it is kept so
    # that keyword callers are unaffected.
    while True:
        produced = False
        for x in iter:
            produced = True
            yield x
        if not produced:
            return

812
813
814
def build_train_valid_test_data_iterators(
        build_train_valid_test_datasets_provider):
    """Build the train/validation/test data iterators.

    Datasets and dataloaders are built only on tensor-model-parallel rank 0
    of each group. That rank then broadcasts two things to the rest of the
    group:

    * which splits are usable, stored in ``args.do_train``,
      ``args.do_valid`` and ``args.do_test``;
    * in dialog mode, the iteration counts derived from the dataset sizes,
      so that every rank runs the same number of steps.

    Args:
        build_train_valid_test_datasets_provider: callable returning
            ``(train_ds, valid_ds, test_ds)``. In dialog mode
            (``args.run_dialog``) it is called with no arguments. Otherwise
            it receives ``[train_samples, valid_samples, test_samples]``.

    Returns:
        ``(train_data_iterator, valid_data_iterator, test_data_iterator)``.
        Each is ``None`` when the matching dataloader was not built on this
        rank.
    """
    args = get_args()

    (train_dataloader, valid_dataloader, test_dataloader) = (None, None, None)

    print_rank_0('> building train, validation, and test datasets ...')

    # Backward compatibility, assume fixed batch size.
    if args.iteration > 0 and args.consumed_train_samples == 0:
        assert args.train_samples is None, \
            'only backward compatiblity support for iteration-based training'
        args.consumed_train_samples = args.iteration * args.global_batch_size
    if args.iteration > 0 and args.consumed_valid_samples == 0:
        assert args.train_samples is None, \
            'only backward compatiblity support for iteration-based training'
        args.consumed_valid_samples = (args.iteration // args.eval_interval) * \
            args.eval_iters * args.global_batch_size

    # Dialog mode always starts the sample and iteration counters from zero.
    if args.run_dialog:
        args.consumed_train_samples = 0
        args.consumed_valid_samples = 0
        args.iteration = 0

    # Data loader only on rank 0 of each model parallel group.
    if mpu.get_tensor_model_parallel_rank() == 0:

        if args.run_dialog:
            # Build the datasets.
            train_ds, valid_ds, test_ds = build_train_valid_test_datasets_provider()

            print_rank_0(' > datasets target sizes:')
            train_size = len(train_ds)
            valid_size = len(valid_ds)
            test_size = len(test_ds)
            print_rank_0('    train:      {}'.format(train_size))
            print_rank_0('    validation: {}'.format(valid_size))
            print_rank_0('    test:       {}'.format(test_size))

            # Derive iteration counts from the dataset sizes.
            # NOTE(review): ``// batch_size + 1`` is always one more than the
            # number of full batches. Confirm that the dataloader built below
            # actually supplies that final (partial or extra) batch.
            batch_size = args.global_batch_size
            args.train_iters = train_size // batch_size + 1
            args.eval_iters = valid_size // batch_size + 1
            args.test_iters = test_size // batch_size + 1

        else:
            # Number of train/valid/test samples.
            if args.train_samples:
                train_samples = args.train_samples
            else:
                train_samples = args.train_iters * args.global_batch_size
            eval_iters = (args.train_iters // args.eval_interval + 1) * \
                        args.eval_iters
            test_iters = args.eval_iters
            train_val_test_num_samples = [train_samples,
                                        eval_iters * args.global_batch_size,
                                        test_iters * args.global_batch_size]
            print_rank_0(' > datasets target sizes (minimum size):')
            print_rank_0('    train:      {}'.format(train_val_test_num_samples[0]))
            print_rank_0('    validation: {}'.format(train_val_test_num_samples[1]))
            print_rank_0('    test:       {}'.format(train_val_test_num_samples[2]))

            # Build the datasets.
            train_ds, valid_ds, test_ds = build_train_valid_test_datasets_provider(
                train_val_test_num_samples)

        # Build dataloders.
        train_dataloader = build_pretraining_data_loader(
            train_ds, args.consumed_train_samples)
        valid_dataloader = build_pretraining_data_loader(
            valid_ds, args.consumed_valid_samples)
        test_dataloader = build_pretraining_data_loader(test_ds, 0)

        # Flags to know if we need to do training/validation/testing.
        do_train = train_dataloader is not None and args.train_iters > 0
        do_valid = valid_dataloader is not None and args.eval_iters > 0
        do_test = test_dataloader is not None and args.eval_iters > 0
        # Pack the flags into one tensor so a single broadcast shares them.
        flags = torch.cuda.LongTensor(
            [int(do_train), int(do_valid), int(do_test)])
    else:
        flags = torch.cuda.LongTensor([0, 0, 0])

    # Broadcast the do_train/do_valid/do_test flags from tensor-parallel
    # rank 0 to the rest of its group.
    torch.distributed.broadcast(flags,
                                mpu.get_tensor_model_parallel_src_rank(),
                                group=mpu.get_tensor_model_parallel_group())
    args.do_train = flags[0].item()
    args.do_valid = flags[1].item()
    args.do_test = flags[2].item()

    # In dialog mode, only tensor-parallel rank 0 knows the dataset sizes and
    # therefore the iteration counts. Without this broadcast, the other ranks
    # keep stale counts, run a different number of steps and hang in
    # collectives. ``args.run_dialog`` is the same on every rank, so all
    # ranks of the group enter this broadcast together.
    if args.run_dialog:
        if mpu.get_tensor_model_parallel_rank() == 0:
            iters = torch.cuda.LongTensor(
                [args.train_iters, args.eval_iters, args.test_iters])
        else:
            iters = torch.cuda.LongTensor([0, 0, 0])
        torch.distributed.broadcast(iters,
                                    mpu.get_tensor_model_parallel_src_rank(),
                                    group=mpu.get_tensor_model_parallel_group())
        args.train_iters, args.eval_iters, args.test_iters = iters.tolist()

    # Build iterators.
    dl_type = args.dataloader_type
    assert dl_type in ['single', 'cyclic']

    def _to_iterator(dataloader):
        # 'single' makes one pass over the loader; 'cyclic' restarts it.
        if dataloader is None:
            return None
        if dl_type == 'single':
            return iter(dataloader)
        return iter(cyclic_iter(dataloader))

    train_data_iterator = _to_iterator(train_dataloader)
    valid_data_iterator = _to_iterator(valid_dataloader)
    test_data_iterator = _to_iterator(test_dataloader)

    return train_data_iterator, valid_data_iterator, test_data_iterator