# coding=utf-8
# Copyright (c) 2020, NVIDIA CORPORATION.  All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Pretrain utilities."""

from datetime import datetime
import math
Mohammad's avatar
Mohammad committed
20
import sys
21
22
23
import time
# The earliest we can measure the start time.
_TRAIN_START_TIME = time.time()
24
25
26
import torch
from torch.nn.parallel.distributed import DistributedDataParallel as torchDDP

Neel Kant's avatar
Neel Kant committed
27
from megatron import get_args
28
from megatron import get_signal_handler
Mohammad's avatar
Mohammad committed
29
30
from megatron import get_timers
from megatron import get_tensorboard_writer
31
from megatron import get_current_global_batch_size
mohammad's avatar
mohammad committed
32
from megatron import get_num_microbatches
mohammad's avatar
mohammad committed
33
from megatron import is_last_rank
mohammad's avatar
mohammad committed
34
from megatron import update_num_microbatches
35
from megatron import mpu
Neel Kant's avatar
Neel Kant committed
36
from megatron import print_rank_0
37
from megatron import print_rank_last
Mohammad's avatar
Mohammad committed
38
39
from megatron.checkpointing import load_checkpoint
from megatron.checkpointing import save_checkpoint
Mohammad Shoeybi's avatar
Mohammad Shoeybi committed
40
from megatron.model import Float16Module
41
from megatron.model import ModelType
mohammad's avatar
mohammad committed
42
from megatron.optimizer import get_megatron_optimizer
Mohammad's avatar
Mohammad committed
43
from megatron.initialize import initialize_megatron
44
from megatron.initialize import write_args_to_tensorboard
45
from megatron.initialize import warmup_jit_function
46
from megatron.optimizer_param_scheduler import OptimizerParamScheduler
47
48
from megatron.model import DistributedDataParallel as LocalDDP
from megatron.utils import check_adlr_autoresume_termination
49
from megatron.utils import unwrap_model
Vijay Korthikanti's avatar
Vijay Korthikanti committed
50
from megatron.data.data_samplers import build_pretraining_data_loader
mohammad's avatar
mohammad committed
51
from megatron.utils import calc_params_l2_norm
52
from megatron.schedules import get_forward_backward_func
53
from megatron.utils import report_memory
54
from megatron.model.vision.knn_monitor import compute_feature_bank
55

Mohammad Shoeybi's avatar
Mohammad Shoeybi committed
56

57
58
59
60
61
62
63
def print_datetime(string):
    """Note that this call will sync across all ranks."""
    torch.distributed.barrier()
    time_str = datetime.now().strftime('%Y-%m-%d %H:%M:%S')
    print_rank_0('[' + string + '] datetime: {} '.format(time_str))


64
def pretrain(train_valid_test_dataset_provider,
65
             model_provider,
66
             model_type,
67
             forward_step_func,
68
             process_non_loss_data_func=None,
69
             extra_args_provider=None,
Vijay Korthikanti's avatar
Vijay Korthikanti committed
70
             args_defaults={}):
71
72
73
    """Main training program.

    This function will run the followings in the order provided:
Mohammad's avatar
Mohammad committed
74
75
        1) initialize Megatron.
        2) setup model, optimizer and lr schedule using the model_provider.
76
        3) call train_val_test_data_provider to get train/val/test datasets.
Mohammad's avatar
Mohammad committed
77
        4) train the modle using the forward_step_func.
78
79

    Arguments:
80
81
82
        train_valid_test_dataset_provider: a function that takes the size of
            train/valid/test dataset and returns `train, valid, test` datasets.
        model_provider: a function that returns a vanilla version of the
Mohammad's avatar
Mohammad committed
83
            model. By vanilla we mean a simple model on cpu with no fp16 or ddp.
84
        model_type: an enum that specifies the type of model being trained.
Mohammad's avatar
Mohammad committed
85
86
87
88
89
        forward_step_func: a function that takes a `data iterator` and `model`,
            and returns a `loss` scalar with a dictionary with key:values being
            the info we would like to monitor during training, for example
            `lm-loss: value`. We also require that this function add
            `batch generator` to the timers class.
Vijay Korthikanti's avatar
Vijay Korthikanti committed
90
91
92
93
        process_non_loss_data_func: a function to post process outputs of the
            network. It can be used for dumping output tensors (e.g images) to
            tensorboard. It takes `collected data`(list of tensors),
            `current iteration index` and `tensorboard writer` as arguments.
Mohammad's avatar
Mohammad committed
94
95
96
97
        extra_args_provider: a function that takes a parser and adds arguments
            to it. It is used for programs to add their own arguments.
        args_defaults: a dictionary from argument-name to argument-value. It
            to set already parse arguments.
98
99
    """

100
    # Initalize and get arguments, timers, and Tensorboard writer.
101
102
    initialize_megatron(extra_args_provider=extra_args_provider,
                        args_defaults=args_defaults)
103
    warmup_jit_function()
104

105
106
107
108
    # Adjust the startup time so it reflects the largest value.
    # This will be closer to what scheduler will see (outside of
    # image ... launches.
    global _TRAIN_START_TIME
109
    start_time_tensor = torch.cuda.DoubleTensor([_TRAIN_START_TIME])
110
111
112
    torch.distributed.all_reduce(start_time_tensor,
                                 op=torch.distributed.ReduceOp.MIN)
    _TRAIN_START_TIME = start_time_tensor.item()
mshoeybi's avatar
mshoeybi committed
113
    print_rank_0('time to initialize megatron (seconds): {:.3f}'.format(
114
115
116
        time.time() - _TRAIN_START_TIME))
    print_datetime('after megatron is initialized')

117
    args = get_args()
Mohammad's avatar
Mohammad committed
118
    timers = get_timers()
119
120

    # Model, optimizer, and learning rate.
121
    timers('model-and-optimizer-setup').start()
122
    model, optimizer, opt_param_scheduler = setup_model_and_optimizer(model_provider,
123
                                                               model_type)
124
    timers('model-and-optimizer-setup').stop()
125
126
    print_datetime('after model, optimizer, and learning rate '
                   'scheduler are built')
127
128

    # Data stuff.
129
130
    timers('train/valid/test-data-iterators-setup').start()
    if args.virtual_pipeline_model_parallel_size is not None:
131
        all_data_iterators = [
132
133
134
            build_train_valid_test_data_iterators(train_valid_test_dataset_provider)
            for _ in range(len(model))
        ]
135
136
137
        train_data_iterator = [data_iterators[0] for data_iterators in all_data_iterators]
        valid_data_iterator = [data_iterators[1] for data_iterators in all_data_iterators]
        test_data_iterator = [data_iterators[2] for data_iterators in all_data_iterators]
138
139
140
141
142
    else:
        train_data_iterator, valid_data_iterator, test_data_iterator \
            = build_train_valid_test_data_iterators(
                train_valid_test_dataset_provider)
    timers('train/valid/test-data-iterators-setup').stop()
mshoeybi's avatar
mshoeybi committed
143
    print_datetime('after dataloaders are built')
Mohammad's avatar
Mohammad committed
144
145

    # Print setup timing.
146
147
    print_rank_0('done with setup ...')
    timers.log(['model-and-optimizer-setup', 'train/valid/test-data-iterators-setup'])
Mohammad's avatar
Mohammad committed
148
    print_rank_0('training ...')
149
150

    iteration = 0
151
    if args.do_train and args.train_iters > 0:
mohammad's avatar
mohammad committed
152
        iteration = train(forward_step_func,
153
                          model, optimizer, opt_param_scheduler,
154
155
                          train_data_iterator, valid_data_iterator,
                          process_non_loss_data_func)
156
    print_datetime('after training is done')
Mohammad's avatar
Mohammad committed
157

158
159
160
    if args.do_valid:
        prefix = 'the end of training for val data'
        evaluate_and_print_results(prefix, forward_step_func,
161
                                   valid_data_iterator, model,
162
163
                                   iteration, process_non_loss_data_func,
                                   False)
164
165

    if args.save and iteration != 0:
166
        save_checkpoint(iteration, model, optimizer, opt_param_scheduler)
167
168
169
170
171
172

    if args.do_test:
        # Run on test data.
        prefix = 'the end of training for test data'
        evaluate_and_print_results(prefix, forward_step_func,
                                   test_data_iterator, model,
173
174
                                   0, process_non_loss_data_func,
                                   True)
175

176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
def update_train_iters(args):

    # For iteration-based training, we don't need to do anything
    if args.train_iters:
        return

    # Constant batch size with sample-based training.
    if args.rampup_batch_size is None:
        args.train_iters = args.train_samples // args.global_batch_size

    else:
        # Sample based training with rampup batch size.
        iterations = 0
        consumed_samples = 0
        # Rampup phase.
        while consumed_samples <= int(args.rampup_batch_size[2]):
192
193
            update_num_microbatches(consumed_samples, consistency_check=False)
            consumed_samples += get_current_global_batch_size()
194
195
            iterations += 1
        # Reset
196
        update_num_microbatches(0, consistency_check=False)
197
198
199
200
201
202
203
204
        # Constant phase
        # Note that we throw away any partial last batch.
        iterations += (args.train_samples - consumed_samples) // \
                      args.global_batch_size
        args.train_iters = iterations

    print_rank_0('setting training iterations to {}'.format(args.train_iters))

205

206
def get_model(model_provider_func, model_type=ModelType.encoder_or_decoder, wrap_with_ddp=True):
207
    """Build the model."""
Mohammad's avatar
Mohammad committed
208
    args = get_args()
209
    args.model_type = model_type
210

211
    # Build model.
212
213
    if mpu.get_pipeline_model_parallel_world_size() > 1 and \
       args.virtual_pipeline_model_parallel_size is not None:
214
215
        assert model_type != ModelType.encoder_and_decoder, \
            "Interleaved schedule not supported for model with both encoder and decoder"
216
217
218
        model = []
        for i in range(args.virtual_pipeline_model_parallel_size):
            mpu.set_virtual_pipeline_model_parallel_rank(i)
219
220
221
            # Set pre_process and post_process only after virtual rank is set.
            pre_process = mpu.is_pipeline_first_stage()
            post_process = mpu.is_pipeline_last_stage()
222
            this_model = model_provider_func(
223
224
225
                pre_process=pre_process,
                post_process=post_process
            )
226
            this_model.model_type = model_type
227
            model.append(this_model)
228
    else:
229
230
        pre_process = mpu.is_pipeline_first_stage()
        post_process = mpu.is_pipeline_last_stage()
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
        add_encoder = True
        add_decoder = True
        if model_type == ModelType.encoder_and_decoder:
            if mpu.get_pipeline_model_parallel_world_size() > 1:
                assert args.pipeline_model_parallel_split_rank is not None, \
                    "Split rank needs to be specified for model with both encoder and decoder"
                rank = mpu.get_pipeline_model_parallel_rank()
                split_rank = args.pipeline_model_parallel_split_rank
                world_size = mpu.get_pipeline_model_parallel_world_size()
                pre_process = rank == 0 or rank == split_rank
                post_process = (rank == (split_rank - 1)) or (
                        rank == (world_size - 1))
                add_encoder = mpu.is_pipeline_stage_before_split()
                add_decoder = mpu.is_pipeline_stage_after_split()
            model = model_provider_func(
                pre_process=pre_process,
                post_process=post_process,
                add_encoder=add_encoder,
                add_decoder=add_decoder)
        else:
            model = model_provider_func(
                pre_process=pre_process,
                post_process=post_process
            )
        model.model_type = model_type
256

257
258
    if not isinstance(model, list):
        model = [model]
259

260
    # Set tensor model parallel attributes if not set.
mohammad's avatar
mohammad committed
261
262
263
    # Only parameters that are already tensor model parallel have these
    # attributes set for them. We should make sure the default attributes
    # are set for all params so the optimizer can use them.
264
265
266
    for model_module in model:
        for param in model_module.parameters():
            mpu.set_defaults_if_not_set_tensor_model_parallel_attributes(param)
267

268
269
    # Print number of parameters.
    if mpu.get_data_parallel_rank() == 0:
270
        print(' > number of parameters on (tensor, pipeline) '
271
              'model parallel rank ({}, {}): {}'.format(
272
273
            mpu.get_tensor_model_parallel_rank(),
            mpu.get_pipeline_model_parallel_rank(),
274
275
            sum([sum([p.nelement() for p in model_module.parameters()])
                 for model_module in model])), flush=True)
276
277

    # GPU allocation.
278
279
    for model_module in model:
        model_module.cuda(torch.cuda.current_device())
280
281

    # Fp16 conversion.
Mohammad Shoeybi's avatar
Mohammad Shoeybi committed
282
283
    if args.fp16 or args.bf16:
        model = [Float16Module(model_module, args) for model_module in model]
284

285
286
287
288
289
290
    if wrap_with_ddp:
        if args.DDP_impl == 'torch':
            i = torch.cuda.current_device()
            model = [torchDDP(model_module, device_ids=[i], output_device=i,
                              process_group=mpu.get_data_parallel_group())
                     for model_module in model]
291

292
293
294
295
296
        elif args.DDP_impl == 'local':
            model = [LocalDDP(model_module,
                              args.accumulate_allreduce_grads_in_fp32,
                              args.use_contiguous_buffers_in_local_ddp)
                     for model_module in model]
297
298
299
300
            # broad cast params from data parallel src rank to other data parallel ranks
            if args.data_parallel_random_init:
                for model_module in model:
                    model_module.broadcast_params()
301
302
303
        else:
            raise NotImplementedError('Unknown DDP implementation specified: '
                                      '{}. Exiting.'.format(args.DDP_impl))
304

305
    return model
306
307


308
def get_optimizer_param_scheduler(optimizer):
309
    """Build the learning rate scheduler."""
Mohammad's avatar
Mohammad committed
310
    args = get_args()
311

312
313
314
315
    # Iteration-based training.
    if args.train_iters:
        if args.lr_decay_iters is None:
            args.lr_decay_iters = args.train_iters
Vijay Korthikanti's avatar
Vijay Korthikanti committed
316
317
        lr_decay_steps = args.lr_decay_iters * args.global_batch_size
        wd_incr_steps = args.train_iters * args.global_batch_size
318
        if args.lr_warmup_fraction is not None:
Vijay Korthikanti's avatar
Vijay Korthikanti committed
319
            lr_warmup_steps = args.lr_warmup_fraction * lr_decay_steps
320
        else:
Vijay Korthikanti's avatar
Vijay Korthikanti committed
321
            lr_warmup_steps = args.lr_warmup_iters * args.global_batch_size
322
323
324
325
326
    # Sample-based training.
    elif args.train_samples:
        # We need to set training iters for later use. Technically
        # we need to adjust the training samples too (due to last
        # batch being incomplete) but we leave it as is for now.
327
        update_train_iters(args)
328
329
        if args.lr_decay_samples is None:
            args.lr_decay_samples = args.train_samples
Vijay Korthikanti's avatar
Vijay Korthikanti committed
330
331
        lr_decay_steps = args.lr_decay_samples
        wd_incr_steps = args.train_samples
332
        if args.lr_warmup_fraction is not None:
Vijay Korthikanti's avatar
Vijay Korthikanti committed
333
            lr_warmup_steps = args.lr_warmup_fraction * lr_decay_steps
334
        else:
Vijay Korthikanti's avatar
Vijay Korthikanti committed
335
            lr_warmup_steps = args.lr_warmup_samples
336
    else:
337
338
339
        raise Exception(
            'either train-iters or train-samples should be provided.')

340
    opt_param_scheduler = OptimizerParamScheduler(
341
        optimizer,
342
        max_lr=args.lr,
343
        min_lr=args.min_lr,
Vijay Korthikanti's avatar
Vijay Korthikanti committed
344
345
346
        lr_warmup_steps=lr_warmup_steps,
        lr_decay_steps=lr_decay_steps,
        lr_decay_style=args.lr_decay_style,
Vijay Korthikanti's avatar
Vijay Korthikanti committed
347
348
        start_wd=args.start_weight_decay,
        end_wd=args.end_weight_decay,
Vijay Korthikanti's avatar
Vijay Korthikanti committed
349
        wd_incr_steps=wd_incr_steps,
Vijay Korthikanti's avatar
Vijay Korthikanti committed
350
        wd_incr_style=args.weight_decay_incr_style,
351
352
        use_checkpoint_opt_param_scheduler=args.use_checkpoint_opt_param_scheduler,
        override_opt_param_scheduler=args.override_opt_param_scheduler)
353

354
    return opt_param_scheduler
355
356


357
358
359
360
361
def setup_model_and_optimizer(model_provider_func,
                              model_type,
                              no_wd_decay_cond=None,
                              scale_lr_cond=None,
                              lr_mult=1.0):
362
    """Setup model and optimizer."""
Mohammad's avatar
Mohammad committed
363
    args = get_args()
364

365
    model = get_model(model_provider_func, model_type)
366

367
    unwrapped_model = unwrap_model(model,
Mohammad Shoeybi's avatar
Mohammad Shoeybi committed
368
                                   (torchDDP, LocalDDP, Float16Module))
369
370
    optimizer = get_megatron_optimizer(unwrapped_model, no_wd_decay_cond,
                                       scale_lr_cond, lr_mult)
371

372
    opt_param_scheduler = get_optimizer_param_scheduler(optimizer)
373
374

    if args.load is not None:
375
376
377
378
        timers = get_timers()
        # Extra barrier is added to make sure all ranks report the
        # max time.
        torch.distributed.barrier()
379
        timers('load-checkpoint').start()
380
        args.iteration = load_checkpoint(model, optimizer, opt_param_scheduler)
381
        torch.distributed.barrier()
382
383
        timers('load-checkpoint').stop()
        timers.log(['load-checkpoint'])
384
385
386
    else:
        args.iteration = 0

mohammad's avatar
mohammad committed
387
    # We only support local DDP with multiple micro-batches.
Mohammad Shoeybi's avatar
Mohammad Shoeybi committed
388
    if len(model) > 1 or mpu.get_pipeline_model_parallel_world_size() > 1:
mohammad's avatar
mohammad committed
389
390
        assert args.DDP_impl == 'local'

Neel Kant's avatar
Neel Kant committed
391
    # get model without FP16 and/or TorchDDP wrappers
Mostofa Patwary's avatar
Mostofa Patwary committed
392
393
    if args.iteration == 0 and len(unwrapped_model) == 1 \
        and hasattr(unwrapped_model[0], 'init_state_dict_from_bert'):
Mostofa Patwary's avatar
Mostofa Patwary committed
394
        print_rank_0("Initializing ICT from pretrained BERT model")
Mostofa Patwary's avatar
Mostofa Patwary committed
395
        unwrapped_model[0].init_state_dict_from_bert()
Mostofa Patwary's avatar
Mostofa Patwary committed
396
397
        if args.fp16:
            optimizer.reload_model_params()
Neel Kant's avatar
Neel Kant committed
398

399
    return model, optimizer, opt_param_scheduler
400
401


402
def train_step(forward_step_func, data_iterator,
403
               model, optimizer, opt_param_scheduler):
404
405
406
407
408
    """Single training step."""
    args = get_args()
    timers = get_timers()

    # Set grad to zero.
409
    if args.DDP_impl == 'local' and args.use_contiguous_buffers_in_local_ddp:
Mohammad Shoeybi's avatar
Mohammad Shoeybi committed
410
411
        for partition in model:
            partition.zero_grad_buffer()
412
    optimizer.zero_grad()
413

414
    forward_backward_func = get_forward_backward_func()
415
416
417
    losses_reduced = forward_backward_func(
        forward_step_func, data_iterator, model,
        optimizer, timers, forward_only=False)
418

419
    # Empty unused memory
Lawrence McAfee's avatar
Lawrence McAfee committed
420
    if args.empty_unused_memory_level >= 1:
421
422
        torch.cuda.empty_cache()

423
424
    # All-reduce if needed.
    if args.DDP_impl == 'local':
425
        timers('backward-params-all-reduce').start()
426
        for model_module in model:
Mohammad Shoeybi's avatar
Mohammad Shoeybi committed
427
            model_module.allreduce_gradients()
428
        timers('backward-params-all-reduce').stop()
429

430
431
432
433
    # All-reduce word_embeddings' grad across first and last stages to ensure
    # that word_embeddings parameters stay in sync.
    # This should only run for models that support pipelined model parallelism
    # (BERT and GPT-2).
434
    timers('backward-embedding-all-reduce').start()
435
    if mpu.is_rank_in_embedding_group(ignore_virtual=True) and \
436
            mpu.get_pipeline_model_parallel_world_size() > 1:
437
438
439
440
        if mpu.is_pipeline_first_stage(ignore_virtual=True):
            unwrapped_model = model[0]
        elif mpu.is_pipeline_last_stage(ignore_virtual=True):
            unwrapped_model = model[-1]
441
442
        else:  # We do not support the interleaved schedule for T5 yet.
            unwrapped_model = model[0]
443
        unwrapped_model = unwrap_model(
Mohammad Shoeybi's avatar
Mohammad Shoeybi committed
444
            unwrapped_model, (torchDDP, LocalDDP, Float16Module))
445

446
447
        if unwrapped_model.share_word_embeddings:
            word_embeddings_weight = unwrapped_model.word_embeddings_weight()
Mohammad Shoeybi's avatar
Mohammad Shoeybi committed
448
449
450
451
452
            if args.DDP_impl == 'local':
                grad = word_embeddings_weight.main_grad
            else:
                grad = word_embeddings_weight.grad
            torch.distributed.all_reduce(grad, group=mpu.get_embedding_group())
Vijay Korthikanti's avatar
Vijay Korthikanti committed
453

Vijay Korthikanti's avatar
Vijay Korthikanti committed
454
455
456
    # All-reduce position_embeddings grad across first (encoder) and split (decoder) 
    # stages to ensure that position embeddings parameters stay in sync.
    # This should only run for T5 models with pipeline parallelism
Vijay Korthikanti's avatar
Vijay Korthikanti committed
457
458
459
460
461
462
    if mpu.is_rank_in_position_embedding_group() and \
            mpu.get_pipeline_model_parallel_world_size() > 1 and \
            args.pipeline_model_parallel_split_rank is not None:
        unwrapped_model = model[0]
        unwrapped_model = unwrap_model(
            unwrapped_model, (torchDDP, LocalDDP, Float16Module))
Vijay Korthikanti's avatar
Vijay Korthikanti committed
463
464
        assert args.DDP_impl == 'local', \
            'T5 model is only supported with local DDP mode'
Vijay Korthikanti's avatar
Vijay Korthikanti committed
465
466
        grad = unwrapped_model.language_model.embedding.position_embeddings.weight.main_grad
        torch.distributed.all_reduce(grad, group=mpu.get_position_embedding_group())
467
    timers('backward-embedding-all-reduce').stop()
468

Vijay Korthikanti's avatar
Vijay Korthikanti committed
469
    if args.vision_pretraining and args.vision_pretraining_type == "dino":
470
471
472
473
474
        unwrapped_model = unwrap_model(model[0],
                                       (torchDDP, LocalDDP, Float16Module))
        unwrapped_model.cancel_gradients_last_layer(args.curr_iteration)


475
476
    # Update parameters.
    timers('optimizer').start()
477
    update_successful, grad_norm, num_zeros_in_grad = optimizer.step()
478
479
    timers('optimizer').stop()

Vijay Korthikanti's avatar
Vijay Korthikanti committed
480
    if args.vision_pretraining and args.vision_pretraining_type == "dino":
481
482
483
484
485
        unwrapped_model = unwrap_model(model[0],
                                       (torchDDP, LocalDDP, Float16Module))
        unwrapped_model.update_momentum(args.curr_iteration)


486
    # Update learning rate.
487
    if update_successful:
488
489
490
        increment = get_num_microbatches() * \
                    args.micro_batch_size * \
                    args.data_parallel_size
491
        opt_param_scheduler.step(increment=increment)
mohammad's avatar
mohammad committed
492
        skipped_iter = 0
493
494
495
    else:
        skipped_iter = 1

496
    # Empty unused memory
Lawrence McAfee's avatar
Lawrence McAfee committed
497
    if args.empty_unused_memory_level >= 2:
498
499
        torch.cuda.empty_cache()

500
    if mpu.is_pipeline_last_stage(ignore_virtual=True):
501
502
503
504
        # Average loss across microbatches.
        loss_reduced = {}
        for key in losses_reduced[0]:
            losses_reduced_for_key = [x[key] for x in losses_reduced]
505
            loss_reduced[key] = sum(losses_reduced_for_key) / len(losses_reduced_for_key)
506
507
        return loss_reduced, skipped_iter, grad_norm, num_zeros_in_grad
    return {}, skipped_iter, grad_norm, num_zeros_in_grad
508
509


Mohammad Shoeybi's avatar
Mohammad Shoeybi committed
510
def training_log(loss_dict, total_loss_dict, learning_rate, iteration,
mohammad's avatar
mohammad committed
511
                 loss_scale, report_memory_flag, skipped_iter,
512
                 grad_norm, params_norm, num_zeros_in_grad):
Mohammad's avatar
Mohammad committed
513
514
515
516
    """Log training information such as losses, timing, ...."""
    args = get_args()
    timers = get_timers()
    writer = get_tensorboard_writer()
Mohammad Shoeybi's avatar
Mohammad Shoeybi committed
517

mohammad's avatar
mohammad committed
518
519
    # Advanced, skipped, and Nan iterations.
    advanced_iters_key = 'advanced iterations'
mohammad's avatar
mohammad committed
520
    skipped_iters_key = 'skipped iterations'
mohammad's avatar
mohammad committed
521
522
523
524
525
526
527
528
529
    nan_iters_key = 'nan iterations'
    # Advanced iterations.
    if not skipped_iter:
        total_loss_dict[advanced_iters_key] = total_loss_dict.get(
            advanced_iters_key, 0) + 1
    else:
        if advanced_iters_key not in total_loss_dict:
            total_loss_dict[advanced_iters_key] = 0
    # Skipped iterations.
mohammad's avatar
mohammad committed
530
531
    total_loss_dict[skipped_iters_key] = total_loss_dict.get(
        skipped_iters_key, 0) + skipped_iter
mohammad's avatar
mohammad committed
532
    # Update losses and set nan iterations
mohammad's avatar
mohammad committed
533
    got_nan = False
Mohammad Shoeybi's avatar
Mohammad Shoeybi committed
534
    for key in loss_dict:
mohammad's avatar
mohammad committed
535
        if not skipped_iter:
536
537
            total_loss_dict[key] = total_loss_dict.get(
                key, torch.cuda.FloatTensor([0.0])) + loss_dict[key]
mohammad's avatar
mohammad committed
538
539
540
541
542
        else:
            value = loss_dict[key].float().sum().item()
            is_nan = value == float('inf') or \
                     value == -float('inf') or \
                     value != value
mohammad's avatar
mohammad committed
543
            got_nan = got_nan or is_nan
mohammad's avatar
mohammad committed
544
545
    total_loss_dict[nan_iters_key] = total_loss_dict.get(
        nan_iters_key, 0) + int(got_nan)
Mohammad Shoeybi's avatar
Mohammad Shoeybi committed
546
547
548

    # Logging.
    timers_to_log = []
Neel Kant's avatar
Neel Kant committed
549

Mohammad Shoeybi's avatar
Mohammad Shoeybi committed
550
551
552
    def add_to_logging(name):
        if name in timers.timers:
            timers_to_log.append(name)
553
554
555
    add_to_logging('forward-compute')
    add_to_logging('forward-recv')
    add_to_logging('forward-send')
556
    add_to_logging('forward-backward-send-forward-backward-recv')
557
558
559
    add_to_logging('backward-compute')
    add_to_logging('backward-recv')
    add_to_logging('backward-send')
Deepak Narayanan's avatar
Deepak Narayanan committed
560
    add_to_logging('backward-send-forward-recv')
561
    add_to_logging('backward-send-backward-recv')
562
    add_to_logging('backward-params-all-reduce')
563
    add_to_logging('backward-embedding-all-reduce')
564
    add_to_logging('optimizer-copy-to-main-grad')
mohammad's avatar
mohammad committed
565
    add_to_logging('optimizer-unscale-and-check-inf')
566
567
    add_to_logging('optimizer-clip-main-grad')
    add_to_logging('optimizer-copy-main-to-model-params')
Mohammad Shoeybi's avatar
Mohammad Shoeybi committed
568
    add_to_logging('optimizer')
mohammad's avatar
mohammad committed
569
    add_to_logging('batch-generator')
Mohammad Shoeybi's avatar
Mohammad Shoeybi committed
570

mohammad's avatar
mohammad committed
571
    # Calculate batch size.
mshoeybi's avatar
mshoeybi committed
572
573
574
    batch_size = args.micro_batch_size * args.data_parallel_size * \
        get_num_microbatches()

mohammad's avatar
mohammad committed
575
576
577
    total_iterations = total_loss_dict[advanced_iters_key] + \
                       total_loss_dict[skipped_iters_key]

Mohammad Shoeybi's avatar
Mohammad Shoeybi committed
578
    # Tensorboard values.
579
580
581
582
583
584
585
586
587
588
    if writer and (iteration % args.tensorboard_log_interval == 0 ) and \
       is_last_rank():
        if args.log_learning_rate_to_tensorboard:
            writer.add_scalar('learning-rate', learning_rate, iteration)
            writer.add_scalar('learning-rate vs samples', learning_rate,
                              args.consumed_train_samples)
        if args.log_batch_size_to_tensorboard:
            writer.add_scalar('batch-size', batch_size, iteration)
            writer.add_scalar('batch-size vs samples', batch_size,
                              args.consumed_train_samples)
Mohammad Shoeybi's avatar
Mohammad Shoeybi committed
589
        for key in loss_dict:
mohammad's avatar
mohammad committed
590
591
            writer.add_scalar(key , loss_dict[key], iteration)
            writer.add_scalar(key + ' vs samples', loss_dict[key],
592
                              args.consumed_train_samples)
593
594
595
596
        if args.log_loss_scale_to_tensorboard:
            writer.add_scalar('loss-scale', loss_scale, iteration)
            writer.add_scalar('loss-scale vs samples', loss_scale,
                              args.consumed_train_samples)
597
598
599
600
        if args.log_world_size_to_tensorboard:
            writer.add_scalar('world-size', args.world_size, iteration)
            writer.add_scalar('world-size vs samples', args.world_size,
                              args.consumed_train_samples)
601
602
603
604
        if grad_norm is not None:
            writer.add_scalar('grad-norm', grad_norm, iteration)
            writer.add_scalar('grad-norm vs samples', grad_norm,
                              args.consumed_train_samples)
605
606
607
        if num_zeros_in_grad is not None:
            writer.add_scalar('num-zeros', num_zeros_in_grad, iteration)
            writer.add_scalar('num-zeros vs samples', num_zeros_in_grad,
Rewon Child's avatar
Rewon Child committed
608
                              args.consumed_train_samples)
mohammad's avatar
mohammad committed
609
610
611
612
        if params_norm is not None:
            writer.add_scalar('params-norm', params_norm, iteration)
            writer.add_scalar('params-norm vs samples', params_norm,
                              args.consumed_train_samples)
613
614
615
        if args.log_timers_to_tensorboard:
            timers.write(timers_to_log, writer, iteration,
                         normalizer=total_iterations)
616
617
618
619
620
621
622
623
624
625
626
627
628
629
630
631
632
        if args.log_memory_to_tensorboard:
            mem_stats = torch.cuda.memory_stats()
            writer.add_scalar(
                "mem-reserved-bytes",
                mem_stats["reserved_bytes.all.current"],
                iteration,
            )
            writer.add_scalar(
                "mem-allocated-bytes",
                mem_stats["allocated_bytes.all.current"],
                iteration,
            )
            writer.add_scalar(
                "mem-allocated-count",
                mem_stats["allocation.all.current"],
                iteration,
            )
Mohammad Shoeybi's avatar
Mohammad Shoeybi committed
633
634

    if iteration % args.log_interval == 0:
635
        elapsed_time = timers('interval-time').elapsed()
mohammad's avatar
mohammad committed
636
        elapsed_time_per_iteration = elapsed_time / total_iterations
mshoeybi's avatar
mshoeybi committed
637
        if writer:
638
639
640
            if args.log_timers_to_tensorboard:
                writer.add_scalar('iteration-time',
                                  elapsed_time_per_iteration, iteration)
641
642
        log_string = ' iteration {:8d}/{:8d} |'.format(
            iteration, args.train_iters)
mshoeybi's avatar
mshoeybi committed
643
        log_string += ' consumed samples: {:12d} |'.format(
644
            args.consumed_train_samples)
Mohammad Shoeybi's avatar
Mohammad Shoeybi committed
645
        log_string += ' elapsed time per iteration (ms): {:.1f} |'.format(
mohammad's avatar
mohammad committed
646
            elapsed_time_per_iteration * 1000.0)
Mohammad Shoeybi's avatar
Mohammad Shoeybi committed
647
        log_string += ' learning rate: {:.3E} |'.format(learning_rate)
mohammad's avatar
mohammad committed
648
        log_string += ' global batch size: {:5d} |'.format(batch_size)
Mohammad Shoeybi's avatar
Mohammad Shoeybi committed
649
        for key in total_loss_dict:
mohammad's avatar
mohammad committed
650
651
652
653
            if key not in [advanced_iters_key, skipped_iters_key,
                           nan_iters_key]:
                avg = total_loss_dict[key].item() / \
                      float(max(1, total_loss_dict[advanced_iters_key]))
654
655
656
                if avg > 0.0:
                    log_string += ' {}: {:.6E} |'.format(key, avg)
                total_loss_dict[key] = torch.cuda.FloatTensor([0.0])
657
        log_string += ' loss scale: {:.1f} |'.format(loss_scale)
658
659
        if grad_norm is not None:
            log_string += ' grad norm: {:.3f} |'.format(grad_norm)
660
661
        if num_zeros_in_grad is not None:
            log_string += ' num zeros: {:.1f} |'.format(num_zeros_in_grad)
mohammad's avatar
mohammad committed
662
663
        if params_norm is not None:
            log_string += ' params norm: {:.3f} |'.format(params_norm)
mohammad's avatar
mohammad committed
664
665
        log_string += ' number of skipped iterations: {:3d} |'.format(
            total_loss_dict[skipped_iters_key])
mohammad's avatar
mohammad committed
666
        log_string += ' number of nan iterations: {:3d} |'.format(
mohammad's avatar
mohammad committed
667
668
            total_loss_dict[nan_iters_key])
        total_loss_dict[advanced_iters_key] = 0
mohammad's avatar
mohammad committed
669
        total_loss_dict[skipped_iters_key] = 0
mohammad's avatar
mohammad committed
670
        total_loss_dict[nan_iters_key] = 0
671
        print_rank_last(log_string)
672
673
674
        if report_memory_flag and learning_rate > 0.:
            # Report memory after optimizer state has been initialized.
            report_memory('(after {} iterations)'.format(iteration))
Mohammad Shoeybi's avatar
Mohammad Shoeybi committed
675
676
677
678
679
680
            report_memory_flag = False
        timers.log(timers_to_log, normalizer=args.log_interval)

    return report_memory_flag


681
def save_checkpoint_and_time(iteration, model, optimizer, opt_param_scheduler):
    """Save a checkpoint and log how long the save took.

    Barriers bracket the timed region so every rank reports the time of the
    slowest rank rather than its own skewed interval.
    """
    timers = get_timers()
    save_timer = timers('save-checkpoint')
    # Extra barrier is added to make sure all ranks report the max time.
    torch.distributed.barrier()
    save_timer.start()
    save_checkpoint(iteration, model, optimizer, opt_param_scheduler)
    torch.distributed.barrier()
    save_timer.stop()
    timers.log(['save-checkpoint'])
691
692


693
def train(forward_step_func, model, optimizer, opt_param_scheduler,
          train_data_iterator, valid_data_iterator,
          process_non_loss_data_func):
    """Train the model function.

    Runs the main training loop from ``args.iteration`` until
    ``args.train_iters``, interleaving logging, autoresume checks,
    periodic evaluation, checkpointing, and the various exit conditions
    (SIGTERM, wall-clock duration, iteration interval).

    Returns the iteration number reached when the loop ends normally
    (the exit paths below call sys.exit() and never return).
    """
    args = get_args()
    timers = get_timers()

    # Write args to tensorboard
    write_args_to_tensorboard()

    # Turn on training mode which enables dropout.
    # NOTE: `model` is a list of model chunks (virtual pipeline stages),
    # hence the loop.
    for model_module in model:
        model_module.train()

    # Tracking loss.
    total_loss_dict = {}

    # Iterations. Resumes from the checkpointed iteration on restart.
    iteration = args.iteration

    timers('interval-time').start()
    print_datetime('before the start of training step')
    report_memory_flag = True
    while iteration < args.train_iters:
        # Batch-size rampup: the number of microbatches may change as a
        # function of how many samples have been consumed so far.
        update_num_microbatches(args.consumed_train_samples)
        args.curr_iteration = iteration
        loss_dict, skipped_iter, grad_norm, num_zeros_in_grad = \
            train_step(forward_step_func,
                       train_data_iterator,
                       model,
                       optimizer,
                       opt_param_scheduler)
        iteration += 1
        # Global batch size = data-parallel size * micro batch * #microbatches.
        args.consumed_train_samples += mpu.get_data_parallel_world_size() * \
                                       args.micro_batch_size * \
                                       get_num_microbatches()

        # Logging.
        loss_scale = optimizer.get_loss_scale().item()
        params_norm = None
        if args.log_params_norm:
            params_norm = calc_params_l2_norm(model)
        # training_log returns the (possibly cleared) report_memory_flag so
        # memory is reported only once after optimizer state initialization.
        report_memory_flag = training_log(loss_dict, total_loss_dict,
                                          optimizer.param_groups[0]['lr'],
                                          iteration, loss_scale,
                                          report_memory_flag, skipped_iter,
                                          grad_norm, params_norm, num_zeros_in_grad)

        # Autoresume: periodically check whether the cluster scheduler has
        # requested termination; helper may checkpoint and exit.
        if args.adlr_autoresume and \
           (iteration % args.adlr_autoresume_interval == 0):
            check_adlr_autoresume_termination(iteration, model, optimizer,
                                              opt_param_scheduler)

        # Evaluation
        if args.eval_interval and iteration % args.eval_interval == 0 and \
           args.do_valid:
            prefix = 'iteration {}'.format(iteration)
            evaluate_and_print_results(prefix, forward_step_func,
                                       valid_data_iterator, model,
                                       iteration, process_non_loss_data_func,
                                       False)

        # Checkpointing
        saved_checkpoint = False
        if args.exit_signal_handler:
            signal_handler = get_signal_handler()
            # Exit (after saving) if any registered signal (e.g. SIGTERM)
            # has been received.
            if any(signal_handler.signals_received()):
                save_checkpoint_and_time(iteration, model, optimizer,
                                         opt_param_scheduler)
                print_datetime('exiting program after receiving SIGTERM.')
                sys.exit()

        if args.save and args.save_interval and \
           iteration % args.save_interval == 0:
            save_checkpoint_and_time(iteration, model, optimizer,
                                     opt_param_scheduler)
            # Remember we already saved so the exit paths below don't save twice.
            saved_checkpoint = True

        # Exiting based on duration
        if args.exit_duration_in_mins:
            train_time = (time.time() - _TRAIN_START_TIME) / 60.0
            # All-reduce with MAX so every rank exits as soon as ANY rank
            # has crossed the duration limit (clocks may be skewed).
            done_cuda = torch.cuda.IntTensor(
                [train_time > args.exit_duration_in_mins])
            torch.distributed.all_reduce(
                done_cuda, op=torch.distributed.ReduceOp.MAX)
            done = done_cuda.item()
            if done:
                if not saved_checkpoint:
                    save_checkpoint_and_time(iteration, model, optimizer,
                                             opt_param_scheduler)
                print_datetime('exiting program after {} minutes'.format(train_time))
                sys.exit()

        # Exiting based on iterations
        if args.exit_interval and iteration % args.exit_interval == 0:
            if not saved_checkpoint:
                save_checkpoint_and_time(iteration, model, optimizer,
                                         opt_param_scheduler)
            # Sync all ranks before exiting so no rank leaves early.
            torch.distributed.barrier()
            print_datetime('exiting program at iteration {}'.format(iteration))
            sys.exit()

    return iteration
798
799


800
801
802
803
804
def evaluate(forward_step_func,
             data_iterator,
             model,
             process_non_loss_data_func,
             verbose=False):
    """Evaluation.

    Runs ``args.eval_iters`` forward-only iterations over ``data_iterator``,
    accumulating the loss dicts produced on the last pipeline stage, and
    optionally runs one extra forward pass to collect non-loss data on the
    last rank.

    Returns (total_loss_dict, collected_non_loss_data) where losses are
    averaged over all evaluated microbatches and collected_non_loss_data is
    None unless process_non_loss_data_func was given and this is the last rank.
    """
    args = get_args()

    if args.vision_pretraining and args.vision_pretraining_type == "dino":
        compute_feature_bank(model)

    # Turn on evaluation mode which disables dropout.
    for model_module in model:
        model_module.eval()

    total_loss_dict = {}

    # The pipeline schedule does not change during evaluation, so resolve the
    # forward-backward function once instead of on every iteration. This also
    # guarantees the name is bound for the non-loss-data pass below even when
    # args.eval_iters == 0 (previously that was a NameError).
    forward_backward_func = get_forward_backward_func()

    with torch.no_grad():
        iteration = 0
        while iteration < args.eval_iters:
            iteration += 1
            if verbose and iteration % args.log_interval == 0:
                print_rank_0('Evaluating iter {}/{}'.format(iteration,
                                                            args.eval_iters))

            loss_dicts = forward_backward_func(
                forward_step_func, data_iterator, model, optimizer=None,
                timers=None, forward_only=True)

            # Empty unused memory
            if args.empty_unused_memory_level >= 1:
                torch.cuda.empty_cache()

            if mpu.is_pipeline_last_stage(ignore_virtual=True):
                # Reduce across processes: sum each key over microbatches.
                for loss_dict in loss_dicts:
                    for key in loss_dict:
                        total_loss_dict[key] = total_loss_dict.get(
                            key, torch.cuda.FloatTensor([0.0])) + loss_dict[key]

            args.consumed_valid_samples += mpu.get_data_parallel_world_size() \
                                           * args.micro_batch_size \
                                           * get_num_microbatches()

        # Optional extra forward pass to collect non-loss data (e.g. images,
        # metrics) for tensorboard on the last rank only.
        collected_non_loss_data = None
        if process_non_loss_data_func is not None and is_last_rank():
            collected_non_loss_data = forward_backward_func(
                forward_step_func, data_iterator, model, optimizer=None,
                timers=None, forward_only=True, collect_non_loss_data=True)

    # Move model back to the train mode.
    for model_module in model:
        model_module.train()

    # Average over every evaluated microbatch.
    for key in total_loss_dict:
        total_loss_dict[key] /= args.eval_iters * get_num_microbatches()

    return total_loss_dict, collected_non_loss_data
858
859
860

def evaluate_and_print_results(prefix, forward_step_func,
                               data_iterator, model,
                               iteration, process_non_loss_data_func,
                               verbose=False):
    """Helper function to evaluate and dump results on screen."""
    args = get_args()
    writer = get_tensorboard_writer()

    total_loss_dict, collected_non_loss_data = evaluate(
        forward_step_func, data_iterator, model,
        process_non_loss_data_func, verbose)

    # Build the one-line summary while mirroring each value to tensorboard.
    string = ' validation loss at {} | '.format(prefix)
    for key, loss in total_loss_dict.items():
        loss_value = loss.item()
        # Clamp the exponent so exp() cannot overflow on a diverging loss.
        ppl = math.exp(min(20, loss_value))
        string += f'{key} value: {loss_value:.6E} | '
        string += f'{key} PPL: {ppl:.6E} | '
        if writer:
            writer.add_scalar(f'{key} validation', loss_value, iteration)
            writer.add_scalar(f'{key} validation vs samples', loss_value,
                              args.consumed_train_samples)
            if args.log_validation_ppl_to_tensorboard:
                writer.add_scalar(f'{key} validation ppl', ppl, iteration)
                writer.add_scalar(f'{key} validation ppl vs samples',
                                  ppl, args.consumed_train_samples)

    if process_non_loss_data_func is not None and writer and is_last_rank():
        process_non_loss_data_func(collected_non_loss_data, iteration, writer)

    # Frame the summary line with dashes of matching width.
    separator = '-' * (len(string) + 1)
    print_rank_last(separator)
    print_rank_last(string)
    print_rank_last(separator)
895
896


Vijay Korthikanti's avatar
Vijay Korthikanti committed
897
def cyclic_iter(iterable):
    """Yield items from `iterable` forever, restarting a fresh pass each time
    it is exhausted.

    NOTE: deliberately not itertools.cycle — cycle caches the first pass and
    replays it, whereas re-calling iter() on a dataloader starts a new epoch
    (e.g. with a fresh shuffle). The parameter was renamed from `iter`, which
    shadowed the builtin.
    """
    while True:
        for item in iterable:
            yield item

902
903
904
def build_train_valid_test_data_iterators(
        build_train_valid_test_datasets_provider):
    """Build the train, validation, and test data iterators.

    Datasets and dataloaders are constructed only on tensor-model-parallel
    rank 0; the resulting do_train/do_valid/do_test flags are broadcast to
    the rest of the group and stored on args.

    Args:
        build_train_valid_test_datasets_provider: callable taking the list
            [train, valid, test] of target sample counts and returning the
            three datasets.

    Returns:
        (train_data_iterator, valid_data_iterator, test_data_iterator),
        where an entry is None if this rank has no dataloader for the split.
    """
    args = get_args()

    (train_dataloader, valid_dataloader, test_dataloader) = (None, None, None)

    print_rank_0('> building train, validation, and test datasets ...')

    # Backward compatibility, assume fixed batch size.
    if args.iteration > 0 and args.consumed_train_samples == 0:
        assert args.train_samples is None, \
            'only backward compatibility support for iteration-based training'
        args.consumed_train_samples = args.iteration * args.global_batch_size
    if args.iteration > 0 and args.consumed_valid_samples == 0:
        if args.train_samples is None:
            args.consumed_valid_samples = (args.iteration // args.eval_interval) * \
                args.eval_iters * args.global_batch_size

    # Data loader only on rank 0 of each model parallel group.
    if mpu.get_tensor_model_parallel_rank() == 0:

        # Number of train/valid/test samples.
        if args.train_samples:
            train_samples = args.train_samples
        else:
            train_samples = args.train_iters * args.global_batch_size
        # "+ 1" budgets one extra evaluation round for the final interval.
        eval_iters = (args.train_iters // args.eval_interval + 1) * \
                     args.eval_iters
        test_iters = args.eval_iters
        train_val_test_num_samples = [train_samples,
                                      eval_iters * args.global_batch_size,
                                      test_iters * args.global_batch_size]
        print_rank_0(' > datasets target sizes (minimum size):')
        print_rank_0('    train:      {}'.format(train_val_test_num_samples[0]))
        print_rank_0('    validation: {}'.format(train_val_test_num_samples[1]))
        print_rank_0('    test:       {}'.format(train_val_test_num_samples[2]))

        # Build the datasets.
        train_ds, valid_ds, test_ds = build_train_valid_test_datasets_provider(
            train_val_test_num_samples)

        # Build dataloaders, resuming from the already-consumed sample counts.
        train_dataloader = build_pretraining_data_loader(
            train_ds, args.consumed_train_samples)
        valid_dataloader = build_pretraining_data_loader(
            valid_ds, args.consumed_valid_samples)
        test_dataloader = build_pretraining_data_loader(test_ds, 0)

        # Flags to know if we need to do training/validation/testing.
        do_train = train_dataloader is not None and args.train_iters > 0
        do_valid = valid_dataloader is not None and args.eval_iters > 0
        do_test = test_dataloader is not None and args.eval_iters > 0
        flags = torch.cuda.LongTensor(
            [int(do_train), int(do_valid), int(do_test)])
    else:
        flags = torch.cuda.LongTensor([0, 0, 0])

    # Broadcast the flags so every rank in the tensor-model-parallel group
    # agrees on which splits are active.
    torch.distributed.broadcast(flags,
                                mpu.get_tensor_model_parallel_src_rank(),
                                group=mpu.get_tensor_model_parallel_group())
    args.do_train = flags[0].item()
    args.do_valid = flags[1].item()
    args.do_test = flags[2].item()

    # Build iterators.
    dl_type = args.dataloader_type
    assert dl_type in ['single', 'cyclic']

    def _make_iterator(dataloader):
        # 'single' iterates the loader exactly once; 'cyclic' restarts it
        # forever via cyclic_iter.
        if dataloader is None:
            return None
        if dl_type == 'single':
            return iter(dataloader)
        return iter(cyclic_iter(dataloader))

    train_data_iterator = _make_iterator(train_dataloader)
    valid_data_iterator = _make_iterator(valid_dataloader)
    test_data_iterator = _make_iterator(test_dataloader)

    return train_data_iterator, valid_data_iterator, test_data_iterator