arguments.py 39.4 KB
Newer Older
Raul Puri's avatar
Raul Puri committed
1
# coding=utf-8
Mohammad's avatar
Mohammad committed
2
# Copyright (c) 2020, NVIDIA CORPORATION.  All rights reserved.
Raul Puri's avatar
Raul Puri committed
3
4
5
6
7
8
9
10
11
12
13
14
15
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

Mohammad's avatar
Mohammad committed
16
"""Megatron arguments."""
Raul Puri's avatar
Raul Puri committed
17
18
19
20

import argparse
import os

21
import torch
Raul Puri's avatar
Raul Puri committed
22

23
24
def parse_args(extra_args_provider=None, defaults=None,
               ignore_unknown_args=False):
    """Parse all Megatron arguments and derive the dependent settings.

    Args:
        extra_args_provider: optional callable that receives the
            ``argparse.ArgumentParser`` and returns it with extra
            (e.g. task-specific) arguments registered.
        defaults: optional mapping of argument name -> value. An entry is
            applied only when that argument is still ``None`` after command
            line parsing; otherwise the command-line value wins and a warning
            is printed on rank 0.
        ignore_unknown_args: if True, unrecognized command-line arguments are
            ignored (``parse_known_args``) instead of raising an error.

    Returns:
        The parsed ``argparse.Namespace``. It is augmented with the rank and
        world size (from the ``RANK`` / ``WORLD_SIZE`` environment
        variables), the parallel sizes, the batch sizes, ``params_dtype``,
        and the consumed-sample counters.

    Raises:
        AssertionError: when the arguments are inconsistent or a deprecated
            argument is used.
    """
    # A None sentinel avoids sharing one mutable dict across calls.
    if defaults is None:
        defaults = {}

    parser = argparse.ArgumentParser(description='Megatron-LM Arguments',
                                     allow_abbrev=False)

    # Standard arguments.
    parser = _add_network_size_args(parser)
    parser = _add_regularization_args(parser)
    parser = _add_training_args(parser)
    parser = _add_initialization_args(parser)
    parser = _add_learning_rate_args(parser)
    parser = _add_checkpointing_args(parser)
    parser = _add_mixed_precision_args(parser)
    parser = _add_distributed_args(parser)
    parser = _add_validation_args(parser)
    parser = _add_data_args(parser)
    parser = _add_autoresume_args(parser)
    parser = _add_biencoder_args(parser)
    parser = _add_vit_args(parser)
    parser = _add_logging_args(parser)

    # Custom arguments.
    if extra_args_provider is not None:
        parser = extra_args_provider(parser)

    # Parse.
    if ignore_unknown_args:
        args, _ = parser.parse_known_args()
    else:
        args = parser.parse_args()

    # Distributed args.
    args.rank = int(os.getenv('RANK', '0'))
    args.world_size = int(os.getenv("WORLD_SIZE", '1'))
    # Tensor model parallel size (clamped to the world size).
    args.tensor_model_parallel_size = min(
        args.tensor_model_parallel_size, args.world_size)
    assert args.world_size % args.tensor_model_parallel_size == 0, 'world size'\
        ' ({}) is not divisible by tensor model parallel size ({})'.format(
            args.world_size, args.tensor_model_parallel_size)
    # Pipeline model parallel size (clamped to what is left after tensor
    # parallelism).
    args.pipeline_model_parallel_size = min(
        args.pipeline_model_parallel_size,
        (args.world_size // args.tensor_model_parallel_size))
    # Checks.
    model_parallel_size = args.pipeline_model_parallel_size * \
                          args.tensor_model_parallel_size
    # The message has one placeholder per formatted value (world size,
    # tensor size, pipeline size).
    assert args.world_size % model_parallel_size == 0, 'world size ({}) is ' \
        'not divisible by tensor parallel size ({}) times pipeline parallel ' \
        'size ({})'.format(args.world_size, args.tensor_model_parallel_size,
                           args.pipeline_model_parallel_size)
    # Whatever ranks remain after model parallelism form the data-parallel
    # dimension.
    args.data_parallel_size = args.world_size // model_parallel_size
    if args.rank == 0:
        print('using world size: {}, data-parallel-size: {}, '
              'tensor-model-parallel size: {}, '
              'pipeline-model-parallel size: {} '.format(
                  args.world_size, args.data_parallel_size,
                  args.tensor_model_parallel_size,
                  args.pipeline_model_parallel_size), flush=True)

    # Deprecated arguments: reject them and remove them from the namespace.
    assert args.batch_size is None, '--batch-size argument is no longer ' \
        'valid, use --micro-batch-size instead'
    del args.batch_size
    assert args.warmup is None, '--warmup argument is no longer valid, use ' \
        '--lr-warmup-fraction instead'
    del args.warmup
    assert args.model_parallel_size is None, '--model-parallel-size is no ' \
        'longer valid, use --tensor-model-parallel-size instead'
    del args.model_parallel_size
    # --checkpoint-activations is mapped onto the newer method argument
    # rather than rejected.
    if args.checkpoint_activations:
        args.activations_checkpoint_method = 'uniform'
        if args.rank == 0:
            print('--checkpoint-activations is no longer valid, '
                  'use --activations-checkpoint-method instead. '
                  'Defaulting to activations-checkpoint-method=uniform.')
    del args.checkpoint_activations

    # Set input defaults.
    for key in defaults:
        # For default to be valid, it should not be provided in the
        # arguments that are passed to the program. We check this by
        # ensuring the arg is set to None.
        if getattr(args, key) is not None:
            if args.rank == 0:
                print('WARNING: overriding default arguments for {key}:{v} '
                      'with {key}:{v2}'.format(key=key, v=defaults[key],
                                               v2=getattr(args, key)),
                      flush=True)
        else:
            setattr(args, key, defaults[key])

    # Batch size.
    assert args.micro_batch_size is not None, \
        '--micro-batch-size must be provided'
    assert args.micro_batch_size > 0, '--micro-batch-size must be positive'
    if args.global_batch_size is None:
        args.global_batch_size = args.micro_batch_size * args.data_parallel_size
        if args.rank == 0:
            print('setting global batch size to {}'.format(
                args.global_batch_size), flush=True)
    assert args.global_batch_size > 0, '--global-batch-size must be positive'
    if args.num_layers_per_virtual_pipeline_stage is not None:
        assert args.pipeline_model_parallel_size > 2, \
            'pipeline-model-parallel size should be greater than 2 with ' \
            'interleaved schedule'
        assert args.num_layers % args.num_layers_per_virtual_pipeline_stage == 0, \
            'number of layers is not divisible by number of layers per virtual ' \
            'pipeline stage'
        args.virtual_pipeline_model_parallel_size = \
            (args.num_layers // args.pipeline_model_parallel_size) // \
            args.num_layers_per_virtual_pipeline_stage
    else:
        args.virtual_pipeline_model_parallel_size = None

    # Parameters dtype.
    args.params_dtype = torch.float
    if args.fp16:
        assert not args.bf16
        args.params_dtype = torch.half
    if args.bf16:
        assert not args.fp16
        args.params_dtype = torch.bfloat16
        # bfloat16 requires gradient accumulation and all-reduce to
        # be done in fp32.
        if not args.accumulate_allreduce_grads_in_fp32:
            args.accumulate_allreduce_grads_in_fp32 = True
            if args.rank == 0:
                print('accumulate and all-reduce gradients in fp32 for '
                      'bfloat16 data type.', flush=True)

    if args.rank == 0:
        print('using {} for parameters ...'.format(args.params_dtype),
              flush=True)

    # If we do accumulation and all-reduces in fp32, we need to have local DDP
    # and we should make sure use-contiguous-buffers-in-local-ddp is not off.
    if args.accumulate_allreduce_grads_in_fp32:
        assert args.DDP_impl == 'local'
        assert args.use_contiguous_buffers_in_local_ddp

    # For torch DDP, we do not use contiguous buffer
    if args.DDP_impl == 'torch':
        args.use_contiguous_buffers_in_local_ddp = False

    if args.dataloader_type is None:
        args.dataloader_type = 'single'

    # Consumed tokens.
    args.consumed_train_samples = 0
    args.consumed_valid_samples = 0

    # Iteration-based training.
    if args.train_iters:
        # If we use iteration-based training, make sure the
        # sample-based options are off.
        assert args.train_samples is None, \
            'expected iteration-based training'
        assert args.lr_decay_samples is None, \
            'expected iteration-based learning rate decay'
        assert args.lr_warmup_samples == 0, \
            'expected iteration-based learning rate warmup'
        assert args.rampup_batch_size is None, \
            'expected no batch-size rampup for iteration-based training'
        if args.lr_warmup_fraction is not None:
            assert args.lr_warmup_iters == 0, \
                'can only specify one of lr-warmup-fraction and lr-warmup-iters'

    # Sample-based training.
    if args.train_samples:
        # If we use sample-based training, make sure the
        # iteration-based options are off.
        assert args.train_iters is None, \
            'expected sample-based training'
        assert args.lr_decay_iters is None, \
            'expected sample-based learning rate decay'
        assert args.lr_warmup_iters == 0, \
            'expected sample-based learning rate warmup'
        if args.lr_warmup_fraction is not None:
            assert args.lr_warmup_samples == 0, \
                'can only specify one of lr-warmup-fraction ' \
                'and lr-warmup-samples'

    # Check required arguments.
    required_args = ['num_layers', 'hidden_size', 'num_attention_heads',
                     'max_position_embeddings']
    for req_arg in required_args:
        _check_arg_is_not_none(args, req_arg)

    # Checks.
    if args.ffn_hidden_size is None:
        args.ffn_hidden_size = 4 * args.hidden_size

    if args.kv_channels is None:
        assert args.hidden_size % args.num_attention_heads == 0
        args.kv_channels = args.hidden_size // args.num_attention_heads

    # seq_length and encoder_seq_length are aliases: exactly one may be given.
    if args.seq_length is not None:
        assert args.encoder_seq_length is None
        args.encoder_seq_length = args.seq_length
    else:
        assert args.encoder_seq_length is not None
        args.seq_length = args.encoder_seq_length

    if args.seq_length is not None:
        assert args.max_position_embeddings >= args.seq_length
    if args.decoder_seq_length is not None:
        assert args.max_position_embeddings >= args.decoder_seq_length
    if args.lr is not None:
        assert args.min_lr <= args.lr
    if args.save is not None:
        assert args.save_interval is not None
    # Mixed precision checks.
    if args.fp16_lm_cross_entropy:
        assert args.fp16, 'lm cross entropy in fp16 only support in fp16 mode.'
    if args.fp32_residual_connection:
        assert args.fp16 or args.bf16, \
            'residual connection in fp32 only supported when using fp16 or bf16.'
    # Activation checkpointing.
    if args.distribute_checkpointed_activations:
        assert args.tensor_model_parallel_size > 1, \
            'distribute-checkpointed-activations requires ' \
            'tensor-model-parallel-size > 1'
        assert args.activations_checkpoint_method is not None, \
            'for distribute-checkpointed-activations to work you '\
            'need to use a valid activations-checkpoint-method ' \
            '(\'uniform\' or \'block\')'

    _print_args(args)
    return args
Mohammad's avatar
Mohammad committed
250
251


Mohammad's avatar
Mohammad committed
252
253
254
def _print_args(args):
    """Dump every argument on rank 0 as a case-insensitively sorted table.

    Each entry is printed as ``'  <name> <dots> <value>'``. The dots pad the
    name out to a 48-character column. Ranks other than 0 print nothing.
    """
    if args.rank != 0:
        return

    print('------------------------ arguments ------------------------',
          flush=True)
    lines = [
        '  {} {} {}'.format(name, '.' * (48 - len(name)), getattr(args, name))
        for name in vars(args)
    ]
    lines.sort(key=str.lower)
    for line in lines:
        print(line, flush=True)
    print('-------------------- end of arguments ---------------------',
          flush=True)
Mohammad's avatar
Mohammad committed
265
266


267
268
269
270
def _check_arg_is_not_none(args, arg):
    """Assert that attribute ``arg`` of ``args`` has been set (is not None)."""
    failure_message = '{} argument is None'.format(arg)
    assert getattr(args, arg) is not None, failure_message


Mohammad's avatar
Mohammad committed
271
def _add_network_size_args(parser):
    """Register the model-size / architecture arguments on ``parser``.

    Args:
        parser: ``argparse.ArgumentParser`` to extend.

    Returns:
        The same parser, with a 'network size' argument group added.
    """

    def _str_to_bool(value):
        # argparse's ``type=bool`` turns every non-empty string (including
        # "False") into True, so parse common boolean spellings explicitly.
        normalized = str(value).strip().lower()
        if normalized in ('1', 'true', 't', 'yes', 'y', 'on'):
            return True
        if normalized in ('0', 'false', 'f', 'no', 'n', 'off'):
            return False
        raise argparse.ArgumentTypeError(
            'expected a boolean value, got {!r}'.format(value))

    group = parser.add_argument_group(title='network size')

    group.add_argument('--num-layers', type=int, default=None,
                       help='Number of transformer layers.')
    group.add_argument('--hidden-size', type=int, default=None,
                       help='Transformer hidden size.')
    group.add_argument('--ffn-hidden-size', type=int, default=None,
                       help='Transformer Feed-Forward Network hidden size. '
                       'This is set to 4*hidden-size if not provided')
    group.add_argument('--num-attention-heads', type=int, default=None,
                       help='Number of transformer attention heads.')
    group.add_argument('--kv-channels', type=int, default=None,
                       help='Projection weights dimension in multi-head '
                       'attention. This is set to '
                       '   args.hidden_size // args.num_attention_heads '
                       'if not provided.')
    group.add_argument('--max-position-embeddings', type=int, default=None,
                       help='Maximum number of position embeddings to use. '
                       'This is the size of position embedding.')
    group.add_argument('--make-vocab-size-divisible-by', type=int, default=128,
                       help='Pad the vocab size to be divisible by this value. '
                       'This is added for computational efficiency reasons.')
    group.add_argument('--layernorm-epsilon', type=float, default=1e-5,
                       help='Layer norm epsilon.')
    group.add_argument('--apply-residual-connection-post-layernorm',
                       action='store_true',
                       help='If set, use original BERT residual connection '
                       'ordering.')
    group.add_argument('--openai-gelu', action='store_true',
                       help='Use OpenAIs GeLU implementation. This option '
                       'should not be used unless for backward compatibility '
                       'reasons.')
    # Takes an explicit value (e.g. `--onnx-safe true`); stays None if absent.
    group.add_argument('--onnx-safe', type=_str_to_bool, required=False,
                       help='Use workarounds for known problems with '
                       'Torch ONNX exporter')
    # Stored inverted: the flag disables the head, so the default is True.
    group.add_argument('--bert-no-binary-head', action='store_false',
                       help='Disable BERT binary head.',
                       dest='bert_binary_head')

    return parser


314
315
316
317
318
def _add_logging_args(parser):
    """Register logging / TensorBoard arguments on ``parser``.

    Args:
        parser: ``argparse.ArgumentParser`` to extend.

    Returns:
        The same parser, with a 'logging' argument group added.
    """
    group = parser.add_argument_group(title='logging')

    group.add_argument('--log-params-norm', action='store_true',
                       help='If set, calculate and log parameters norm.')
    group.add_argument('--log-num-zeros-in-grad', action='store_true',
                       help='If set, calculate and log the number of zeros in gradient.')
    group.add_argument('--tensorboard-log-interval', type=int, default=1,
                       help='Report to tensorboard interval.')
    group.add_argument('--tensorboard-queue-size', type=int, default=1000,
                       help='Size of the tensorboard queue for pending events '
                       'and summaries before one of the ‘add’ calls forces a '
                       'flush to disk.')
    group.add_argument('--log-timers-to-tensorboard', action='store_true',
                       help='If set, write timers to tensorboard.')
    group.add_argument('--log-batch-size-to-tensorboard', action='store_true',
                       help='If set, write batch-size to tensorboard.')
    # The historical misspelling '--no-log-learnig-rate-to-tensorboard' is
    # kept as an alias so existing launch scripts continue to work.
    group.add_argument('--no-log-learning-rate-to-tensorboard',
                       '--no-log-learnig-rate-to-tensorboard',
                       action='store_false',
                       help='Disable learning rate logging to tensorboard.',
                       dest='log_learning_rate_to_tensorboard')
    group.add_argument('--no-log-loss-scale-to-tensorboard',
                       action='store_false',
                       help='Disable loss-scale logging to tensorboard.',
                       dest='log_loss_scale_to_tensorboard')
    group.add_argument('--log-validation-ppl-to-tensorboard',
                       action='store_true',
                       help='If set, write validation perplexity to '
                       'tensorboard.')
    group.add_argument('--log-memory-to-tensorboard',
                       action='store_true',
                       help='Enable memory logging to tensorboard.')

    return parser


Mohammad's avatar
Mohammad committed
350
def _add_regularization_args(parser):
    """Register dropout, weight-decay, clipping and optimizer-coefficient args.

    Args:
        parser: ``argparse.ArgumentParser`` to extend.

    Returns:
        The same parser, with a 'regularization' argument group added.
    """
    group = parser.add_argument_group(title='regularization')

    group.add_argument('--attention-dropout', type=float, default=0.1,
                       help='Post attention dropout probability.')
    group.add_argument('--hidden-dropout', type=float, default=0.1,
                       help='Dropout probability for hidden state transformer.')
    group.add_argument('--weight-decay', type=float, default=0.01,
                       help='Weight decay coefficient for L2 regularization.')
    group.add_argument('--clip-grad', type=float, default=1.0,
                       help='Gradient clipping based on global L2 norm.')
    group.add_argument('--adam-beta1', type=float, default=0.9,
                       help='First coefficient for computing running averages '
                       'of gradient and its square')
    group.add_argument('--adam-beta2', type=float, default=0.999,
                       help='Second coefficient for computing running averages '
                       'of gradient and its square')
    group.add_argument('--adam-eps', type=float, default=1e-08,
                       help='Term added to the denominator to improve '
                       'numerical stability')
    group.add_argument('--sgd-momentum', type=float, default=0.9,
                       help='Momentum factor for sgd')

    return parser

Mohammad's avatar
Mohammad committed
375
376

def _add_training_args(parser):
    """Register batch-size, activation-checkpointing and training-loop args.

    Args:
        parser: ``argparse.ArgumentParser`` to extend.

    Returns:
        The same parser, with a 'training' argument group added.
    """
    group = parser.add_argument_group(title='training')

    group.add_argument('--micro-batch-size', type=int, default=None,
                       help='Batch size per model instance (local batch size). '
                       'Global batch size is local batch size times data '
                       'parallel size times number of micro batches.')
    # Deprecated: parse_args rejects any non-None value.
    group.add_argument('--batch-size', type=int, default=None,
                       help='Old batch size parameter, do not use. '
                       'Use --micro-batch-size instead')
    group.add_argument('--global-batch-size', type=int, default=None,
                       help='Training batch size. If set, it should be a '
                       'multiple of micro-batch-size times data-parallel-size. '
                       'If this value is None, then '
                       'use micro-batch-size * data-parallel-size as the '
                       'global batch size. This choice will result in 1 for '
                       'number of micro-batches.')
    # Values are kept as raw strings (nargs='*' without a type).
    group.add_argument('--rampup-batch-size', nargs='*', default=None,
                       help='Batch size ramp up with the following values: '
                       '  --rampup-batch-size <start batch size> '
                       '                      <batch size increment> '
                       '                      <ramp-up samples> '
                       'For example: '
                       '   --rampup-batch-size 16 8 300000 '
                       '   --global-batch-size 1024 '
                       'will start with global batch size 16 and over '
                       ' (1024 - 16) / 8 = 126 intervals will increase '
                       'the batch size linearly to 1024. In each interval '
                       'we will use approximately 300000 / 126 = 2380 samples.')
    # Deprecated: parse_args maps it onto --activations-checkpoint-method.
    group.add_argument('--checkpoint-activations', action='store_true',
                       help='Checkpoint activation to allow for training '
                       'with larger models, sequences, and batch sizes.')
    group.add_argument('--distribute-checkpointed-activations',
                       action='store_true',
                       help='If set, distribute checkpointed activations '
                       'across model parallel group.')
    group.add_argument('--activations-checkpoint-method', type=str, default=None,
                       choices=['uniform', 'block'],
                       help='1) uniform: uniformly divide the total number of '
                       'Transformer layers and checkpoint the input activation of '
                       'each divided chunk, '
                       '2) block: checkpoint the input activations of only a set '
                       'number of individual Transformer layers per pipeline '
                       'stage and do the rest without any checkpointing, '
                       'default) do not apply activations checkpoint to any layers')
    group.add_argument('--activations-checkpoint-num-layers', type=int, default=1,
                       help='1) uniform: the number of Transformer layers in each '
                       'uniformly divided checkpoint unit, '
                       '2) block: the number of individual Transformer layers '
                       'to checkpoint within each pipeline stage.')
    group.add_argument('--train-iters', type=int, default=None,
                       help='Total number of iterations to train over all '
                       'training runs. Note that either train-iters or '
                       'train-samples should be provided.')
    group.add_argument('--train-samples', type=int, default=None,
                       help='Total number of samples to train over all '
                       'training runs. Note that either train-iters or '
                       'train-samples should be provided.')
    group.add_argument('--log-interval', type=int, default=100,
                       help='Report loss and timing interval.')
    group.add_argument('--exit-interval', type=int, default=None,
                       help='Exit the program after the iteration is divisible '
                       'by this value.')
    group.add_argument('--exit-duration-in-mins', type=int, default=None,
                       help='Exit the program after this many minutes.')
    group.add_argument('--tensorboard-dir', type=str, default=None,
                       help='Write TensorBoard logs to this directory.')
    # The --no-* fusion flags are stored inverted, so fusion defaults to on.
    group.add_argument('--no-masked-softmax-fusion',
                       action='store_false',
                       help='Disable fusion of query_key_value scaling, '
                       'masking, and softmax.',
                       dest='masked_softmax_fusion')
    group.add_argument('--no-bias-gelu-fusion', action='store_false',
                       help='Disable bias and gelu fusion.',
                       dest='bias_gelu_fusion')
    group.add_argument('--no-bias-dropout-fusion', action='store_false',
                       help='Disable bias and dropout fusion.',
                       dest='bias_dropout_fusion')
    group.add_argument('--optimizer', type=str, default='adam',
                       choices=['adam', 'sgd'],
                       help='Optimizer function')
    # None here; parse_args fills in 'single' when left unset.
    group.add_argument('--dataloader-type', type=str, default=None,
                       choices=['single', 'cyclic'],
                       help='Single pass vs multiple pass data loader')

    return parser


Mohammad's avatar
Mohammad committed
463
def _add_initialization_args(parser):
    """Register seeding and weight-initialization arguments on ``parser``.

    Args:
        parser: ``argparse.ArgumentParser`` to extend.

    Returns:
        The same parser, with an 'initialization' argument group added.
    """
    init_group = parser.add_argument_group(title='initialization')

    # (flag, add_argument keyword options), registered in this order.
    arg_specs = (
        ('--seed',
         dict(type=int, default=1234,
              help='Random seed used for python, numpy, '
              'pytorch, and cuda.')),
        ('--init-method-std',
         dict(type=float, default=0.02,
              help='Standard deviation of the zero mean normal '
              'distribution used for weight initialization.')),
        ('--init-method-xavier-uniform',
         dict(action='store_true',
              help='Enable Xavier uniform parameter initialization')),
    )
    for flag, options in arg_specs:
        init_group.add_argument(flag, **options)

    return parser


Mohammad's avatar
Mohammad committed
478
def _add_learning_rate_args(parser):
    """Register learning-rate schedule arguments on ``parser``.

    Args:
        parser: ``argparse.ArgumentParser`` to extend.

    Returns:
        The same parser, with a 'learning rate' argument group added.
    """
    group = parser.add_argument_group(title='learning rate')

    group.add_argument('--lr', type=float, default=None,
                       help='Initial learning rate. Depending on decay style '
                       'and initial warmup, the learning rate at each '
                       'iteration would be different.')
    group.add_argument('--lr-decay-style', type=str, default='linear',
                       choices=['constant', 'linear', 'cosine'],
                       help='Learning rate decay function.')
    group.add_argument('--lr-decay-iters', type=int, default=None,
                       help='number of iterations to decay learning rate over,'
                       ' If None defaults to `--train-iters`')
    group.add_argument('--lr-decay-samples', type=int, default=None,
                       help='number of samples to decay learning rate over,'
                       ' If None defaults to `--train-samples`')
    group.add_argument('--lr-warmup-fraction', type=float, default=None,
                       help='fraction of lr-warmup-(iters/samples) to use '
                       'for warmup (as a float)')
    group.add_argument('--lr-warmup-iters', type=int, default=0,
                       help='number of iterations to linearly warmup '
                       'learning rate over.')
    group.add_argument('--lr-warmup-samples', type=int, default=0,
                       help='number of samples to linearly warmup '
                       'learning rate over.')
    # Deprecated: parse_args rejects any non-None value.
    group.add_argument('--warmup', type=int, default=None,
                       help='Old lr warmup argument, do not use. Use one of the '
                       '--lr-warmup-* arguments above')
    group.add_argument('--min-lr', type=float, default=0.0,
                       help='Minimum value for learning rate. The scheduler '
                       'clips values below this threshold.')
    group.add_argument('--override-lr-scheduler', action='store_true',
                       help='Reset the values of the scheduler (learning rate, '
                       'warmup iterations, minimum learning rate, maximum '
                       'number of iterations, and decay style) from input '
                       'arguments and ignore values from checkpoints. Note '
                       'that all the above values will be reset.')
    group.add_argument('--use-checkpoint-lr-scheduler', action='store_true',
                       help='Use checkpoint to set the values of the scheduler '
                       '(learning rate, warmup iterations, minimum learning '
                       'rate, maximum number of iterations, and decay style) '
                       'from checkpoint and ignore input arguments.')

    return parser


Mohammad's avatar
Mohammad committed
524
def _add_checkpointing_args(parser):
Mohammad's avatar
Mohammad committed
525
526
527
528
529
530
    group = parser.add_argument_group(title='checkpointing')

    group.add_argument('--save', type=str, default=None,
                       help='Output directory to save checkpoints to.')
    group.add_argument('--save-interval', type=int, default=None,
                       help='Number of iterations between checkpoint saves.')
531
    group.add_argument('--no-save-optim', action='store_true', default=None,
Mohammad's avatar
Mohammad committed
532
                       help='Do not save current optimizer.')
533
    group.add_argument('--no-save-rng', action='store_true', default=None,
Mohammad's avatar
Mohammad committed
534
535
536
                       help='Do not save current rng state.')
    group.add_argument('--load', type=str, default=None,
                       help='Directory containing a model checkpoint.')
Jared Casper's avatar
Jared Casper committed
537
    group.add_argument('--no-load-optim', action='store_true', default=None,
Mohammad's avatar
Mohammad committed
538
                       help='Do not load optimizer when loading checkpoint.')
Jared Casper's avatar
Jared Casper committed
539
    group.add_argument('--no-load-rng', action='store_true', default=None,
Mohammad's avatar
Mohammad committed
540
541
542
543
544
545
546
547
548
                       help='Do not load rng state when loading checkpoint.')
    group.add_argument('--finetune', action='store_true',
                       help='Load model for finetuning. Do not load optimizer '
                       'or rng state from checkpoint and set iteration to 0. '
                       'Assumed when loading a release checkpoint.')

    return parser


Mohammad's avatar
Mohammad committed
549
def _add_mixed_precision_args(parser):
Mohammad's avatar
Mohammad committed
550
551
552
553
    group = parser.add_argument_group(title='mixed precision')

    group.add_argument('--fp16', action='store_true',
                       help='Run model in fp16 mode.')
Mohammad Shoeybi's avatar
Mohammad Shoeybi committed
554
555
    group.add_argument('--bf16', action='store_true',
                       help='Run model in bfloat16 mode.')
mohammad's avatar
mohammad committed
556
557
558
559
560
561
562
563
564
565
566
567
    group.add_argument('--loss-scale', type=float, default=None,
                       help='Static loss scaling, positive power of 2 '
                       'values can improve fp16 convergence. If None, dynamic'
                       'loss scaling is used.')
    group.add_argument('--initial-loss-scale', type=float, default=2**32,
                       help='Initial loss-scale for dynamic loss scaling.')
    group.add_argument('--min-loss-scale', type=float, default=1.0,
                       help='Minimum loss scale for dynamic loss scale.')
    group.add_argument('--loss-scale-window', type=float, default=1000,
                       help='Window over which to raise/lower dynamic scale.')
    group.add_argument('--hysteresis', type=int, default=2,
                       help='hysteresis for dynamic loss scaling')
568
569
    group.add_argument('--fp32-residual-connection', action='store_true',
                       help='Move residual connections to fp32.')
570
571
572
    group.add_argument('--no-query-key-layer-scaling', action='store_false',
                       help='Do not scale Q * K^T by 1 / layer-number.',
                       dest='apply_query_key_layer_scaling')
Mohammad's avatar
Mohammad committed
573
    group.add_argument('--attention-softmax-in-fp32', action='store_true',
574
575
576
                       help='Run attention masking and softmax in fp32. '
                       'This flag is ignored unless '
                       '--no-query-key-layer-scaling is specified.')
Mohammad Shoeybi's avatar
Mohammad Shoeybi committed
577
578
579
    group.add_argument('--accumulate-allreduce-grads-in-fp32',
                       action='store_true',
                       help='Gradient accumulation and all-reduce in fp32.')
580
581
582
583
    group.add_argument('--fp16-lm-cross-entropy', action='store_true',
                       help='Move the cross entropy unreduced loss calculation'
                       'for lm head to fp16.')

Mohammad's avatar
Mohammad committed
584
585
586
    return parser


Mohammad's avatar
Mohammad committed
587
def _add_distributed_args(parser):
588
589
    group = parser.add_argument_group(title='distributed')

590
591
592
593
    group.add_argument('--tensor-model-parallel-size', type=int, default=1,
                       help='Degree of tensor model parallelism.')
    group.add_argument('--pipeline-model-parallel-size', type=int, default=1,
                       help='Degree of pipeline model parallelism.')
594
595
596
    group.add_argument('--model-parallel-size', type=int, default=None,
                       help='Old model parallel argument, do not use. Use '
                       '--tensor-model-parallel-size instead.')
597
598
    group.add_argument('--num-layers-per-virtual-pipeline-stage', type=int, default=None,
                       help='Number of layers per virtual pipeline stage')
Mohammad's avatar
Mohammad committed
599
600
601
602
    group.add_argument('--distributed-backend', default='nccl',
                       choices=['nccl', 'gloo'],
                       help='Which backend to use for distributed training.')
    group.add_argument('--DDP-impl', default='local',
Mohammad's avatar
Mohammad committed
603
                       choices=['local', 'torch'],
Mohammad's avatar
Mohammad committed
604
605
                       help='which DistributedDataParallel implementation '
                       'to use.')
606
607
608
609
    group.add_argument('--no-contiguous-buffers-in-local-ddp',
                       action='store_false', help='If set, dont use '
                       'contiguous buffer in local DDP.',
                       dest='use_contiguous_buffers_in_local_ddp')
610
611
612
    group.add_argument('--no-scatter-gather-tensors-in-pipeline', action='store_false',
                       help='Use scatter/gather to optimize communication of tensors in pipeline',
                       dest='scatter_gather_tensors_in_pipeline')
Mohammad's avatar
Mohammad committed
613
614
    group.add_argument('--local_rank', type=int, default=None,
                       help='local rank passed from distributed launcher.')
615
    group.add_argument('--lazy-mpu-init', type=bool, required=False,
616
617
618
619
620
621
622
623
                       help='If set to True, initialize_megatron() '
                       'skips DDP initialization and returns function to '
                       'complete it instead.Also turns on '
                       '--use-cpu-initialization flag. This is for '
                       'external DDP manager.' )
    group.add_argument('--use-cpu-initialization', action='store_true',
                       default=None, help='If set, affine parallel weights '
                       'initialization uses CPU' )
Lawrence McAfee's avatar
Lawrence McAfee committed
624
    group.add_argument('--empty-unused-memory-level', default=0, type=int,
625
626
627
628
                       choices=[0, 1, 2],
                       help='Call torch.cuda.empty_cache() each iteration '
                       '(training and eval), to reduce fragmentation.'
                       '0=off, 1=moderate, 2=aggressive.')
Mohammad's avatar
Mohammad committed
629
630
631
    return parser


Mohammad's avatar
Mohammad committed
632
def _add_validation_args(parser):
Mohammad's avatar
Mohammad committed
633
634
635
636
637
638
639
640
641
    group = parser.add_argument_group(title='validation')

    group.add_argument('--eval-iters', type=int, default=100,
                       help='Number of iterations to run for evaluation'
                       'validation/test for.')
    group.add_argument('--eval-interval', type=int, default=1000,
                       help='Interval between running evaluation on '
                       'validation set.')

Mohammad's avatar
Mohammad committed
642
643
644
    return parser


Mohammad's avatar
Mohammad committed
645
def _add_data_args(parser):
Mohammad's avatar
Mohammad committed
646
647
    group = parser.add_argument_group(title='data and dataloader')

mohammad's avatar
mohammad committed
648
    group.add_argument('--data-path', nargs='*', default=None,
mohammad's avatar
mohammad committed
649
650
651
652
                       help='Path to the training dataset. Accepted format:'
                       '1) a single data path, 2) multiple datasets in the'
                       'form: dataset1-weight dataset1-path dataset2-weight '
                       'dataset2-path ...')
Mohammad's avatar
Mohammad committed
653
    group.add_argument('--split', type=str, default='969, 30, 1',
Mohammad's avatar
Mohammad committed
654
655
                       help='Comma-separated list of proportions for training,'
                       ' validation, and test split. For example the split '
656
657
                       '`90,5,5` will use 90%% of data for training, 5%% for '
                       'validation and 5%% for test.')
Mohammad's avatar
Mohammad committed
658
    group.add_argument('--vocab-file', type=str, default=None,
Mohammad's avatar
Mohammad committed
659
                       help='Path to the vocab file.')
Mohammad's avatar
Mohammad committed
660
661
    group.add_argument('--merge-file', type=str, default=None,
                       help='Path to the BPE merge file.')
662
663
664
    group.add_argument('--vocab-extra-ids', type=int, default=0,
                       help='Number of additional vocabulary tokens. '
                            'They are used for span masking in the T5 model')
Mohammad's avatar
Mohammad committed
665
    group.add_argument('--seq-length', type=int, default=None,
666
                       help='Maximum sequence length to process.')
667
    group.add_argument('--encoder-seq-length', type=int, default=None,
668
669
                       help='Maximum encoder sequence length to process.'
                       'This should be exclusive of --seq-length')
670
671
    group.add_argument('--decoder-seq-length', type=int, default=None,
                       help="Maximum decoder sequence length to process.")
Mostofa Patwary's avatar
Mostofa Patwary committed
672
673
    group.add_argument('--retriever-seq-length', type=int, default=256,
                       help='Maximum sequence length for the biencoder model '
Mostofa Patwary's avatar
Mostofa Patwary committed
674
                        ' for retriever')
675
676
677
    group.add_argument('--sample-rate', type=float, default=1.0,
                       help='sample rate for training data. Supposed to be 0 '
                            ' < sample_rate < 1')
Mohammad's avatar
Mohammad committed
678
679
680
681
682
683
684
685
    group.add_argument('--mask-prob', type=float, default=0.15,
                       help='Probability of replacing a token with mask.')
    group.add_argument('--short-seq-prob', type=float, default=0.1,
                       help='Probability of producing a short sequence.')
    group.add_argument('--mmap-warmup', action='store_true',
                       help='Warm up mmap files.')
    group.add_argument('--num-workers', type=int, default=2,
                       help="Dataloader number of workers.")
Mohammad's avatar
Mohammad committed
686
687
688
    group.add_argument('--tokenizer-type', type=str,
                       default=None,
                       choices=['BertWordPieceLowerCase',
Raul Puri's avatar
Raul Puri committed
689
                                'BertWordPieceCase',
Mohammad's avatar
Mohammad committed
690
691
                                'GPT2BPETokenizer'],
                       help='What type of tokenizer to use.')
692
693
694
695
696
697
698
699
700
701
    group.add_argument('--data-impl', type=str, default='infer',
                       choices=['lazy', 'cached', 'mmap', 'infer'],
                       help='Implementation of indexed datasets.')
    group.add_argument('--reset-position-ids', action='store_true',
                       help='Reset posistion ids after end-of-document token.')
    group.add_argument('--reset-attention-mask', action='store_true',
                       help='Reset self attention maske after '
                       'end-of-document token.')
    group.add_argument('--eod-mask-loss', action='store_true',
                       help='Mask loss for the end of document tokens.')
Mohammad's avatar
Mohammad committed
702

Mohammad's avatar
Mohammad committed
703
704
    return parser

Raul Puri's avatar
Raul Puri committed
705

Mohammad's avatar
Mohammad committed
706
707
def _add_autoresume_args(parser):
    group = parser.add_argument_group(title='autoresume')
Raul Puri's avatar
Raul Puri committed
708

Mohammad's avatar
Mohammad committed
709
710
711
712
713
    group.add_argument('--adlr-autoresume', action='store_true',
                       help='Enable autoresume on adlr cluster.')
    group.add_argument('--adlr-autoresume-interval', type=int, default=1000,
                       help='Intervals over which check for autoresume'
                       'termination signal')
Raul Puri's avatar
Raul Puri committed
714

Mohammad's avatar
Mohammad committed
715
    return parser
Neel Kant's avatar
Neel Kant committed
716
717


Mostofa Patwary's avatar
Mostofa Patwary committed
718
719
def _add_biencoder_args(parser):
    group = parser.add_argument_group(title='biencoder')
Neel Kant's avatar
Neel Kant committed
720
721
722

    # network size
    group.add_argument('--ict-head-size', type=int, default=None,
723
                       help='Size of block embeddings to be used in ICT and '
Mostofa Patwary's avatar
Mostofa Patwary committed
724
                        'REALM (paper default: 128)')
725
    group.add_argument('--biencoder-projection-dim', type=int, default=0,
Mostofa Patwary's avatar
Mostofa Patwary committed
726
727
                       help='Size of projection head used in biencoder (paper'
                        ' default: 128)')
728
    group.add_argument('--biencoder-shared-query-context-model', action='store_true',
Mostofa Patwary's avatar
Mostofa Patwary committed
729
730
                        help='Whether to share the parameters of the query '
                        'and context models or not')
Neel Kant's avatar
Neel Kant committed
731
732
733
734
735

    # checkpointing
    group.add_argument('--ict-load', type=str, default=None,
                       help='Directory containing an ICTBertModel checkpoint')
    group.add_argument('--bert-load', type=str, default=None,
736
737
                       help='Directory containing an BertModel checkpoint '
                       '(needed to start ICT and REALM)')
Neel Kant's avatar
Neel Kant committed
738
739
740
741
742

    # data
    group.add_argument('--titles-data-path', type=str, default=None,
                       help='Path to titles dataset used for ICT')
    group.add_argument('--query-in-block-prob', type=float, default=0.1,
743
744
                       help='Probability of keeping query in block for '
                       'ICT dataset')
Neel Kant's avatar
Neel Kant committed
745
    group.add_argument('--use-one-sent-docs', action='store_true',
Neel Kant's avatar
Neel Kant committed
746
                       help='Whether to use one sentence documents in ICT')
747
748
    group.add_argument('--evidence-data-path', type=str, default=None,
                       help='Path to Wikipedia Evidence frm DPR paper')
Neel Kant's avatar
Neel Kant committed
749

750
    # training
751
    group.add_argument('--retriever-report-topk-accuracies', nargs='+', type=int,
Mostofa Patwary's avatar
Mostofa Patwary committed
752
753
                        default=[], help="Which top-k accuracies to report "
                        "(e.g. '1 5 20')")
Mostofa Patwary's avatar
Mostofa Patwary committed
754
    group.add_argument('--retriever-score-scaling', action='store_true',
Mostofa Patwary's avatar
Mostofa Patwary committed
755
756
                       help='Whether to scale retriever scores by inverse '
                        'square root of hidden size')
757

Neel Kant's avatar
Neel Kant committed
758
    # faiss index
Neel Kant's avatar
Neel Kant committed
759
    group.add_argument('--block-data-path', type=str, default=None,
Neel Kant's avatar
Neel Kant committed
760
                       help='Where to save/load BlockData to/from')
Mostofa Patwary's avatar
Mostofa Patwary committed
761
762
763
    group.add_argument('--embedding-path', type=str, default=None,
                       help='Where to save/load Open-Retrieval Embedding'
                        ' data to/from')
Neel Kant's avatar
Neel Kant committed
764
765
766

    # indexer
    group.add_argument('--indexer-batch-size', type=int, default=128,
767
768
                       help='How large of batches to use when doing indexing '
                       'jobs')
Neel Kant's avatar
Neel Kant committed
769
    group.add_argument('--indexer-log-interval', type=int, default=1000,
770
771
                       help='After how many batches should the indexer '
                       'report progress')
Neel Kant's avatar
Neel Kant committed
772
    return parser
773
774
775
776
777
778
779
780
781
782
783
784
785
786
787


def _add_vit_args(parser):
    group = parser.add_argument_group(title="vit")

    group.add_argument('--num-classes', type=int, default=1000,
                       help='num of classes in vision classificaiton task')
    group.add_argument('--img-dim', type=int, default=224,
                       help='Image size for vision classification task')
    group.add_argument('--num-channels', type=int, default=3,
                       help='Number of channels in input image data')
    group.add_argument('--patch-dim', type=int, default=16,
                       help='patch dimension used in vit')

    return parser