"vscode:/vscode.git/clone" did not exist on "8baf9a0c18c6bc700e89ad6deb200739a8242e09"
arguments.py 52.4 KB
Newer Older
Jared Casper's avatar
Jared Casper committed
1
# Copyright (c) 2022, NVIDIA CORPORATION. All rights reserved.
Raul Puri's avatar
Raul Puri committed
2

Mohammad's avatar
Mohammad committed
3
"""Megatron arguments."""
Raul Puri's avatar
Raul Puri committed
4
5
6
7

import argparse
import os

8
import torch
Raul Puri's avatar
Raul Puri committed
9

10
def parse_args(extra_args_provider=None, ignore_unknown_args=False):
    """Build the Megatron-LM argument parser and parse the command line.

    Args:
        extra_args_provider: optional callable that takes the parser,
            registers additional arguments on it and returns the parser.
        ignore_unknown_args: when true, unrecognized command-line
            arguments are ignored instead of causing a parse error.

    Returns:
        The parsed namespace, extended with `rank` and `world_size` read
        from the RANK and WORLD_SIZE environment variables (defaulting
        to 0 and 1 respectively).
    """
    parser = argparse.ArgumentParser(description='Megatron-LM Arguments',
                                     allow_abbrev=False)

    # Standard arguments, registered in a fixed order.
    standard_groups = (
        _add_network_size_args,
        _add_regularization_args,
        _add_training_args,
        _add_initialization_args,
        _add_learning_rate_args,
        _add_checkpointing_args,
        _add_mixed_precision_args,
        _add_distributed_args,
        _add_validation_args,
        _add_data_args,
        _add_autoresume_args,
        _add_biencoder_args,
        _add_vision_args,
        _add_logging_args,
        _add_inference_args,
    )
    for register in standard_groups:
        parser = register(parser)

    # Custom arguments.
    if extra_args_provider is not None:
        parser = extra_args_provider(parser)

    # Parse.
    if ignore_unknown_args:
        namespace, _unknown = parser.parse_known_args()
    else:
        namespace = parser.parse_args()

    # Args from environment.
    namespace.rank = int(os.getenv('RANK', '0'))
    namespace.world_size = int(os.getenv("WORLD_SIZE", '1'))

    return namespace

def validate_args(args, defaults=None):
    """Validate parsed arguments and fill in derived values.

    The function mutates `args` in place: it clamps the model-parallel
    sizes to the world size, derives the data-parallel size, removes
    deprecated options, applies caller-provided defaults, selects the
    parameter dtype and checks that the requested combination of options
    is consistent.

    Args:
        args: namespace produced by the argument parser; must already
            carry `rank` and `world_size`.
        defaults: optional mapping from attribute name to a default
            value. A default is applied only when the attribute is
            currently None; otherwise a warning is printed on rank 0 and
            the parsed value is kept.

    Returns:
        The same `args` object, after validation.

    Raises:
        AssertionError: if any consistency check fails.
    """
    # Avoid a shared mutable default argument.
    if defaults is None:
        defaults = {}

    # Tensor model parallel size.
    args.tensor_model_parallel_size = min(
        args.tensor_model_parallel_size, args.world_size)
    assert args.world_size % args.tensor_model_parallel_size == 0, 'world size'\
        ' ({}) is not divisible by tensor model parallel size ({})'.format(
            args.world_size, args.tensor_model_parallel_size)
    # Pipeline model parallel size.
    args.pipeline_model_parallel_size = min(
        args.pipeline_model_parallel_size,
        (args.world_size // args.tensor_model_parallel_size))
    # With a standalone embedding stage, one pipeline stage is not
    # available to the transformer layers.
    args.transformer_pipeline_model_parallel_size = (
        args.pipeline_model_parallel_size - 1
        if args.standalone_embedding_stage else
        args.pipeline_model_parallel_size
    )
    # Checks.
    model_parallel_size = args.pipeline_model_parallel_size * \
                          args.tensor_model_parallel_size
    # Three values are reported, so the message needs three placeholders.
    assert args.world_size % model_parallel_size == 0, 'world size ({}) is ' \
        'not divisible by tensor parallel size ({}) times pipeline parallel ' \
        'size ({})'.format(args.world_size, args.tensor_model_parallel_size,
                           args.pipeline_model_parallel_size)
    args.data_parallel_size = args.world_size // model_parallel_size
    if args.rank == 0:
        print('using world size: {}, data-parallel-size: {}, '
              'tensor-model-parallel size: {}, '
              'pipeline-model-parallel size: {} '.format(
                  args.world_size, args.data_parallel_size,
                  args.tensor_model_parallel_size,
                  args.pipeline_model_parallel_size), flush=True)
    if args.pipeline_model_parallel_size > 1:
        if args.pipeline_model_parallel_split_rank is not None:
            assert args.pipeline_model_parallel_split_rank < \
                    args.pipeline_model_parallel_size, 'split rank needs'\
                    ' to be less than pipeline model parallel size ({})'.format(
                            args.pipeline_model_parallel_size)

    # Deprecated arguments
    assert args.batch_size is None, '--batch-size argument is no longer ' \
        'valid, use --micro-batch-size instead'
    del args.batch_size
    assert args.warmup is None, '--warmup argument is no longer valid, use ' \
        '--lr-warmup-fraction instead'
    del args.warmup
    assert args.model_parallel_size is None, '--model-parallel-size is no ' \
        'longer valid, use --tensor-model-parallel-size instead'
    del args.model_parallel_size

    if args.checkpoint_activations:
        args.recompute_granularity = 'full'
        args.recompute_method = 'uniform'
        if args.rank == 0:
            print('--checkpoint-activations is no longer valid, '
                  'use --recompute-granularity and --recompute-method  instead. '
                  'Defaulting to recompute-granularity=full and recompute-method=uniform.')
    del args.checkpoint_activations

    if args.recompute_activations:
        args.recompute_granularity = 'selective'
    del args.recompute_activations

    # Set input defaults.
    for key in defaults:
        # For default to be valid, it should not be provided in the
        # arguments that are passed to the program. We check this by
        # ensuring the arg is set to None.
        if getattr(args, key) is not None:
            if args.rank == 0:
                # Adjacent literals (not a backslash inside the string)
                # keep source indentation out of the printed message.
                print('WARNING: overriding default arguments for {key}:{v} '
                      'with {key}:{v2}'.format(key=key, v=defaults[key],
                                               v2=getattr(args, key)),
                      flush=True)
        else:
            setattr(args, key, defaults[key])

    # Batch size.
    assert args.micro_batch_size is not None
    assert args.micro_batch_size > 0
    if args.global_batch_size is None:
        args.global_batch_size = args.micro_batch_size * args.data_parallel_size
        if args.rank == 0:
            print('setting global batch size to {}'.format(
                args.global_batch_size), flush=True)
    assert args.global_batch_size > 0
    if args.num_layers_per_virtual_pipeline_stage is not None:
        assert args.pipeline_model_parallel_size > 2, \
            'pipeline-model-parallel size should be greater than 2 with ' \
            'interleaved schedule'
        assert args.num_layers % args.num_layers_per_virtual_pipeline_stage == 0, \
            'number of layers is not divisible by number of layers per virtual ' \
            'pipeline stage'
        args.virtual_pipeline_model_parallel_size = \
            (args.num_layers // args.transformer_pipeline_model_parallel_size) // \
            args.num_layers_per_virtual_pipeline_stage
    else:
        args.virtual_pipeline_model_parallel_size = None

    # Parameters dtype.
    args.params_dtype = torch.float
    if args.fp16:
        assert not args.bf16
        args.params_dtype = torch.half
    if args.bf16:
        assert not args.fp16
        args.params_dtype = torch.bfloat16
        # bfloat16 requires gradient accumulation and all-reduce to
        # be done in fp32.
        if not args.accumulate_allreduce_grads_in_fp32:
            args.accumulate_allreduce_grads_in_fp32 = True
            if args.rank == 0:
                print('accumulate and all-reduce gradients in fp32 for '
                      'bfloat16 data type.', flush=True)

    if args.rank == 0:
        print('using {} for parameters ...'.format(args.params_dtype),
              flush=True)

    # If we do accumulation and all-reduces in fp32, we need to have local DDP
    # and we should make sure use-contiguous-buffers-in-local-ddp is not off.
    if args.accumulate_allreduce_grads_in_fp32:
        assert args.DDP_impl == 'local'
        assert args.use_contiguous_buffers_in_local_ddp
    else:
        if args.gradient_accumulation_fusion:
            args.gradient_accumulation_fusion = False
            if args.rank == 0:
                print('Gradient accumulation fusion to linear layer weight '
                      'gradient computation is supported only with fp32 '
                      'gradient accumulation. Setting gradient_accumulation_fusion '
                      'to False', flush=True)

    # If we use the distributed optimizer, we need to have local DDP
    # and we should make sure use-contiguous-buffers-in-local-ddp is on.
    if args.use_distributed_optimizer:
        assert args.DDP_impl == 'local'
        assert args.use_contiguous_buffers_in_local_ddp

    # For torch DDP, we do not use contiguous buffer
    if args.DDP_impl == 'torch':
        args.use_contiguous_buffers_in_local_ddp = False

    if args.dataloader_type is None:
        args.dataloader_type = 'single'

    # Consumed tokens.
    args.consumed_train_samples = 0
    args.consumed_valid_samples = 0

    # Iteration-based training.
    if args.train_iters:
        # If we use iteration-based training, make sure the
        # sample-based options are off.
        assert args.train_samples is None, \
            'expected iteration-based training'
        assert args.lr_decay_samples is None, \
            'expected iteration-based learning rate decay'
        assert args.lr_warmup_samples == 0, \
            'expected iteration-based learning rate warmup'
        assert args.rampup_batch_size is None, \
            'expected no batch-size rampup for iteration-based training'
        if args.lr_warmup_fraction is not None:
            assert args.lr_warmup_iters == 0, \
                'can only specify one of lr-warmup-fraction and lr-warmup-iters'

    # Sample-based training.
    if args.train_samples:
        # If we use sample-based training, make sure the
        # iteration-based options are off.
        assert args.train_iters is None, \
            'expected sample-based training'
        assert args.lr_decay_iters is None, \
            'expected sample-based learning rate decay'
        assert args.lr_warmup_iters == 0, \
            'expected sample-based learnig rate warmup'
        if args.lr_warmup_fraction is not None:
            assert args.lr_warmup_samples == 0, \
                'can only specify one of lr-warmup-fraction ' \
                'and lr-warmup-samples'

    # Check required arguments.
    required_args = ['num_layers', 'hidden_size', 'num_attention_heads',
                     'max_position_embeddings']
    for req_arg in required_args:
        _check_arg_is_not_none(args, req_arg)

    # Checks.
    if args.ffn_hidden_size is None:
        args.ffn_hidden_size = 4 * args.hidden_size

    if args.kv_channels is None:
        assert args.hidden_size % args.num_attention_heads == 0
        args.kv_channels = args.hidden_size // args.num_attention_heads

    # seq_length and encoder_seq_length mirror each other; exactly one
    # of them is expected to be provided.
    if args.seq_length is not None:
        assert args.encoder_seq_length is None
        args.encoder_seq_length = args.seq_length
    else:
        assert args.encoder_seq_length is not None
        args.seq_length = args.encoder_seq_length

    if args.seq_length is not None:
        assert args.max_position_embeddings >= args.seq_length
    if args.decoder_seq_length is not None:
        assert args.max_position_embeddings >= args.decoder_seq_length
    if args.lr is not None:
        assert args.min_lr <= args.lr
    if args.save is not None:
        assert args.save_interval is not None
    # Mixed precision checks.
    if args.fp16_lm_cross_entropy:
        assert args.fp16, 'lm cross entropy in fp16 only support in fp16 mode.'
    if args.fp32_residual_connection:
        assert args.fp16 or args.bf16, \
            'residual connection in fp32 only supported when using fp16 or bf16.'

    if args.weight_decay_incr_style == 'constant':
        assert args.start_weight_decay is None
        assert args.end_weight_decay is None
        args.start_weight_decay = args.weight_decay
        args.end_weight_decay = args.weight_decay
    else:
        assert args.start_weight_decay is not None
        assert args.end_weight_decay is not None

    TORCH_MAJOR = int(torch.__version__.split('.')[0])
    TORCH_MINOR = int(torch.__version__.split('.')[1])
    # Persistent fused layer norm.
    if TORCH_MAJOR < 1 or (TORCH_MAJOR == 1 and TORCH_MINOR < 11):
        args.no_persist_layer_norm = True
        if args.rank == 0:
            print('Persistent fused layer norm kernel is supported from '
                  'pytorch v1.11 (nvidia pytorch container paired with v1.11). '
                  'Defaulting to no_persist_layer_norm=True')

    # Activation recomputing.
    if args.distribute_saved_activations:
        assert args.tensor_model_parallel_size > 1, 'can distribute ' \
            'recomputed activations only across tensor model ' \
            'parallel groups'
        assert args.recompute_granularity == 'full', \
            'distributed recompute activations is only '\
            'application to full recompute granularity'
        assert args.recompute_method is not None, \
            'for distributed recompute activations to work you '\
            'need to use a recompute method '
        # Compare (major, minor) as a tuple so that e.g. v2.0 is
        # correctly recognised as newer than v1.10.
        assert (TORCH_MAJOR, TORCH_MINOR) >= (1, 10), \
            'distributed recompute activations are supported for pytorch ' \
            'v1.10 and above (Nvidia Pytorch container >= 21.07). Current ' \
            'pytorch version is v%s.%s.' % (TORCH_MAJOR, TORCH_MINOR)

    if args.recompute_granularity == 'selective':
        assert args.recompute_method is None, \
            'recompute method is not yet supported for ' \
            'selective recomputing granularity'

    # disable sequence parallelism when tp=1
    # to avoid change in numerics when
    # sequence_parallelism is enabled.
    if args.tensor_model_parallel_size == 1:
        args.sequence_parallel = False

    # disable async_tensor_model_parallel_allreduce when
    # model parallel memory optimization is enabled
    if args.sequence_parallel:
        args.async_tensor_model_parallel_allreduce = False

    _print_args(args)
    return args
Mohammad's avatar
Mohammad committed
317
318


Mohammad's avatar
Mohammad committed
319
320
321
def _print_args(args):
    """Print every attribute of `args`, sorted by name, on rank 0 only."""
    if args.rank != 0:
        return

    print('------------------------ arguments ------------------------',
          flush=True)
    # One line per attribute: name, dotted padding to column 48, value.
    rows = [
        '  {} {} {}'.format(name, '.' * (48 - len(name)), getattr(args, name))
        for name in vars(args)
    ]
    rows.sort(key=str.lower)
    for row in rows:
        print(row, flush=True)
    print('-------------------- end of arguments ---------------------',
          flush=True)
Mohammad's avatar
Mohammad committed
332
333


334
335
336
337
def _check_arg_is_not_none(args, arg):
    """Assert that the attribute named `arg` on `args` is not None."""
    value = getattr(args, arg)
    assert value is not None, '{} argument is None'.format(arg)


mshoeybi's avatar
mshoeybi committed
338
339
340
341
342
343
344
345
346
347
348
349
def _add_inference_args(parser):
    """Register the 'inference' argument group on `parser` and return it."""
    inference = parser.add_argument_group(title='inference')

    threshold_help = ('During inference, if batch-size times '
                      'sequence-length is smaller than this threshold '
                      'then we will not use pipelining, otherwise we will.')
    inference.add_argument('--inference-batch-times-seqlen-threshold',
                           default=512, type=int, help=threshold_help)

    return parser

    
Mohammad's avatar
Mohammad committed
350
def _add_network_size_args(parser):
    """Register the 'network size' argument group on `parser`.

    Args:
        parser: the argparse parser to extend.

    Returns:
        The same parser, with the network-size options added.
    """

    def _parse_bool(value):
        # `type=bool` would turn any non-empty string (including
        # 'False') into True. Recognise the usual false spellings and
        # keep every other non-empty value truthy, as before.
        return value.strip().lower() not in (
            '', '0', 'false', 'f', 'no', 'n', 'off')

    group = parser.add_argument_group(title='network size')

    group.add_argument('--num-layers', type=int, default=None,
                       help='Number of transformer layers.')
    group.add_argument('--hidden-size', type=int, default=None,
                       help='Transformer hidden size.')
    group.add_argument('--ffn-hidden-size', type=int, default=None,
                       help='Transformer Feed-Forward Network hidden size. '
                       'This is set to 4*hidden-size if not provided')
    group.add_argument('--num-attention-heads', type=int, default=None,
                       help='Number of transformer attention heads.')
    group.add_argument('--kv-channels', type=int, default=None,
                       help='Projection weights dimension in multi-head '
                       'attention. This is set to '
                       '   args.hidden_size // args.num_attention_heads '
                       'if not provided.')
    group.add_argument('--max-position-embeddings', type=int, default=None,
                       help='Maximum number of position embeddings to use. '
                       'This is the size of position embedding.')
    group.add_argument('--make-vocab-size-divisible-by', type=int, default=128,
                       help='Pad the vocab size to be divisible by this value. '
                       'This is added for computational efficiency reasons.')
    group.add_argument('--layernorm-epsilon', type=float, default=1e-5,
                       help='Layer norm epsilon.')
    group.add_argument('--apply-residual-connection-post-layernorm',
                       action='store_true',
                       help='If set, use original BERT residual connection '
                       'ordering.')
    group.add_argument('--openai-gelu', action='store_true',
                       help='Use OpenAIs GeLU implementation. This option '
                       'should not be used unless for backward compatibility '
                       'reasons.')
    group.add_argument('--onnx-safe', type=_parse_bool, required=False,
                       help='Use workarounds for known problems with '
                       'Torch ONNX exporter')
    group.add_argument('--bert-no-binary-head', action='store_false',
                       help='Disable BERT binary head.',
                       dest='bert_binary_head')
    group.add_argument('--num-experts', type=int, default=None,
                       help='Number of Experts in Switch Transformer (None means no Switch)')
    return parser


394
395
396
397
398
def _add_logging_args(parser):
    """Register the 'logging' argument group on `parser`.

    Args:
        parser: the argparse parser to extend.

    Returns:
        The same parser, with the logging options added.
    """
    group = parser.add_argument_group(title='logging')

    group.add_argument('--log-params-norm', action='store_true',
                       help='If set, calculate and log parameters norm.')
    group.add_argument('--log-num-zeros-in-grad', action='store_true',
                       help='If set, calculate and log the number of zeros in gradient.')
    group.add_argument('--timing-log-level', type=int,
                       default=0, choices=range(0,3),
                       help='Granularity level to measure and report timing. '
                       '   0: report only iteration time and make sure timing '
                       '      does not introduce extra overhead.'
                       '   1: report timing for operations that are executed '
                       '      very limited times (basically once) during '
                       '      each iteration (such as gradient all-reduce) '
                       '   2: report timing for operations that might be '
                       '      executed numerous times during each iteration. '
                       'Note that setting the level to 1 or 2 might '
                       'cause increase in iteration time.')
    group.add_argument('--no-barrier-with-level-1-timing', action='store_false',
                       help='If not set, use barrier with level 1 time '
                       'measurements. Note that this is up to the user '
                       'to make sure calling barrier with their timers '
                       'will not result in hangs. This can happen if for '
                       'example the user adds a level 1 timer that is not '
                       'called by all ranks.',
                       dest='barrier_with_L1_time')
    group.add_argument('--timing-log-option', type=str, default='minmax',
                       choices=['max', 'minmax', 'all'],
                       help='Options for logging timing:'
                       '  max: report the max timing across all ranks'
                       '  minmax: report min and max timings across all ranks'
                       '  all: report timings of all ranks.')
    group.add_argument('--tensorboard-log-interval', type=int, default=1,
                       help='Report to tensorboard interval.')
    group.add_argument('--tensorboard-queue-size', type=int, default=1000,
                       help='Size of the tensorboard queue for pending events '
                       'and summaries before one of the ‘add’ calls forces a '
                       'flush to disk.')
    group.add_argument('--log-timers-to-tensorboard', action='store_true',
                       help='If set, write timers to tensorboard.')
    group.add_argument('--log-batch-size-to-tensorboard', action='store_true',
                       help='If set, write batch-size to tensorboard.')
    # The historical flag name is misspelled ('learnig'). It is kept so
    # existing launch scripts continue to work, and the correctly
    # spelled flag is accepted as an alias for the same destination.
    group.add_argument('--no-log-learnig-rate-to-tensorboard',
                       '--no-log-learning-rate-to-tensorboard',
                       action='store_false',
                       help='Disable learning rate logging to tensorboard.',
                       dest='log_learning_rate_to_tensorboard')
    group.add_argument('--no-log-loss-scale-to-tensorboard',
                       action='store_false',
                       help='Disable loss-scale logging to tensorboard.',
                       dest='log_loss_scale_to_tensorboard')
    group.add_argument('--log-validation-ppl-to-tensorboard',
                       action='store_true',
                       help='If set, write validation perplexity to '
                       'tensorboard.')
    group.add_argument('--log-memory-to-tensorboard',
                       action='store_true',
                       help='Enable memory logging to tensorboard.')
    group.add_argument('--log-world-size-to-tensorboard',
                       action='store_true',
                       help='Enable world size logging to tensorboard.')

    return parser


Mohammad's avatar
Mohammad committed
459
def _add_regularization_args(parser):
    """Register the 'regularization' argument group on `parser`.

    Args:
        parser: the argparse parser to extend.

    Returns:
        The same parser, with the regularization options added.
    """
    group = parser.add_argument_group(title='regularization')

    group.add_argument('--attention-dropout', type=float, default=0.1,
                       help='Post attention dropout probability.')
    group.add_argument('--hidden-dropout', type=float, default=0.1,
                       help='Dropout probability for hidden state transformer.')
    group.add_argument('--weight-decay', type=float, default=0.01,
                       help='Weight decay coefficient for L2 regularization.')
    group.add_argument('--start-weight-decay', type=float,
                       help='Initial weight decay coefficient for L2 regularization.')
    group.add_argument('--end-weight-decay', type=float,
                       help='End of run weight decay coefficient for L2 regularization.')
    group.add_argument('--weight-decay-incr-style', type=str, default='constant',
                       choices=['constant', 'linear', 'cosine'],
                       help='Weight decay increment function.')
    group.add_argument('--clip-grad', type=float, default=1.0,
                       help='Gradient clipping based on global L2 norm.')
    group.add_argument('--adam-beta1', type=float, default=0.9,
                       help='First coefficient for computing running averages '
                       'of gradient and its square')
    group.add_argument('--adam-beta2', type=float, default=0.999,
                       help='Second coefficient for computing running averages '
                       'of gradient and its square')
    # Trailing space keeps the two adjacent literals from fusing into
    # 'improvenumerical' in the rendered help.
    group.add_argument('--adam-eps', type=float, default=1e-08,
                       help='Term added to the denominator to improve '
                       'numerical stability')
    group.add_argument('--sgd-momentum', type=float, default=0.9,
                       help='Momentum factor for sgd')

    return parser

Mohammad's avatar
Mohammad committed
491
492

def _add_training_args(parser):
    """Register the 'training' argument group on the parser.

    Args:
        parser: the argparse parser to extend.

    Returns:
        The same parser object, with the training options added.
    """
    group = parser.add_argument_group(title='training')

    group.add_argument('--micro-batch-size', type=int, default=None,
                       help='Batch size per model instance (local batch size). '
                       'Global batch size is local batch size times data '
                       'parallel size times number of micro batches.')
    group.add_argument('--batch-size', type=int, default=None,
                       help='Old batch size parameter, do not use. '
                       'Use --micro-batch-size instead')
    group.add_argument('--global-batch-size', type=int, default=None,
                       help='Training batch size. If set, it should be a '
                       'multiple of micro-batch-size times data-parallel-size. '
                       'If this value is None, then '
                       'use micro-batch-size * data-parallel-size as the '
                       'global batch size. This choice will result in 1 for '
                       'number of micro-batches.')
    # Adjacent literals are concatenated verbatim, so every fragment that is
    # followed by another word must end with a space. The backslash is
    # escaped explicitly to avoid an invalid-escape-sequence warning.
    group.add_argument('--rampup-batch-size', nargs='*', default=None,
                       help='Batch size ramp up with the following values: '
                       '  --rampup-batch-size <start batch size> '
                       '                      <batch size increment> '
                       '                      <ramp-up samples> '
                       'For example: '
                       '   --rampup-batch-size 16 8 300000 \\ '
                       '   --global-batch-size 1024 '
                       'will start with global batch size 16 and over '
                       ' (1024 - 16) / 8 = 126 intervals will increase '
                       'the batch size linearly to 1024. In each interval '
                       'we will use approximately 300000 / 126 = 2380 samples.')
    group.add_argument('--recompute-activations', action='store_true',
                       help='recompute activation to allow for training '
                       'with larger models, sequences, and batch sizes.')
    group.add_argument('--recompute-granularity', type=str, default=None,
                       choices=['full', 'selective'],
                       help='Checkpoint activations to allow for training '
                       'with larger models, sequences, and batch sizes. '
                       'It is supported at two granularities 1) full: '
                       'whole transformer layer is recomputed, '
                       '2) selective: core attention part of the transformer '
                       'layer is recomputed.')
    group.add_argument('--distribute-saved-activations',
                       action='store_true',
                       help='If set, distribute recomputed activations '
                       'across model parallel group.')
    group.add_argument('--recompute-method', type=str, default=None,
                       choices=['uniform', 'block'],
                       help='1) uniform: uniformly divide the total number of '
                       'Transformer layers and recompute the input activation of '
                       'each divided chunk at specified granularity, '
                       '2) block: recompute the input activations of only a '
                       'set number of '
                       'individual Transformer layers per pipeline stage and do the '
                       'rest without any recomputing at specified granularity, '
                       'default) do not apply activations recompute to any layers')
    group.add_argument('--recompute-num-layers', type=int, default=1,
                       help='1) uniform: the number of Transformer layers in each '
                       'uniformly divided recompute unit, '
                       '2) block: the number of individual Transformer layers '
                       'to recompute within each pipeline stage.')

    # deprecated
    group.add_argument('--checkpoint-activations', action='store_true',
                       help='Checkpoint activation to allow for training '
                       'with larger models, sequences, and batch sizes.')
    group.add_argument('--train-iters', type=int, default=None,
                       help='Total number of iterations to train over all '
                       'training runs. Note that either train-iters or '
                       'train-samples should be provided.')
    group.add_argument('--train-samples', type=int, default=None,
                       help='Total number of samples to train over all '
                       'training runs. Note that either train-iters or '
                       'train-samples should be provided.')
    group.add_argument('--log-interval', type=int, default=100,
                       help='Report loss and timing interval.')
    group.add_argument('--exit-interval', type=int, default=None,
                       help='Exit the program after the iteration is divisible '
                       'by this value.')
    group.add_argument('--exit-duration-in-mins', type=int, default=None,
                       help='Exit the program after this many minutes.')
    group.add_argument('--exit-signal-handler', action='store_true',
                       help='Dynamically save the checkpoint and shutdown the '
                       'training if SIGTERM is received')
    group.add_argument('--tensorboard-dir', type=str, default=None,
                       help='Write TensorBoard logs to this directory.')
    # The --no-* flags below use store_false, so the stored attribute
    # (named by dest) is True unless the flag is given.
    group.add_argument('--no-masked-softmax-fusion',
                       action='store_false',
                       help='Disable fusion of query_key_value scaling, '
                       'masking, and softmax.',
                       dest='masked_softmax_fusion')
    group.add_argument('--no-bias-gelu-fusion', action='store_false',
                       help='Disable bias and gelu fusion.',
                       dest='bias_gelu_fusion')
    group.add_argument('--no-bias-dropout-fusion', action='store_false',
                       help='Disable bias and dropout fusion.',
                       dest='bias_dropout_fusion')
    group.add_argument('--optimizer', type=str, default='adam',
                       choices=['adam', 'sgd'],
                       help='Optimizer function')
    group.add_argument('--dataloader-type', type=str, default=None,
                       choices=['single', 'cyclic'],
                       help='Single pass vs multiple pass data loader')
    group.add_argument('--no-async-tensor-model-parallel-allreduce',
                       action='store_false',
                       help='Disable asynchronous execution of '
                       'tensor-model-parallel all-reduce with weight '
                       'gradient computation of a column-linear layer.',
                       dest='async_tensor_model_parallel_allreduce')
    # Unlike the flags above this one is store_true with no dest, so it is
    # stored as no_persist_layer_norm and defaults to False.
    group.add_argument('--no-persist-layer-norm', action='store_true',
                       help='Disable using persistent fused layer norm kernel. '
                       'This kernel supports only a set of hidden sizes. Please '
                       'check persist_ln_hidden_sizes if your hidden '
                       'size is supported.')
    group.add_argument('--sequence-parallel', action='store_true',
                       help='Enable sequence parallel optimization.')
    group.add_argument('--no-gradient-accumulation-fusion',
                       action='store_false',
                       help='Disable fusing gradient accumulation to weight '
                       'gradient computation of linear layers',
                       dest='gradient_accumulation_fusion')
    return parser


Mohammad's avatar
Mohammad committed
613
def _add_initialization_args(parser):
    """Attach the 'initialization' option group to the parser and return it."""
    init_group = parser.add_argument_group(title='initialization')

    # (flag, keyword arguments) pairs, registered in display order.
    option_specs = (
        ('--seed', {
            'type': int,
            'default': 1234,
            'help': 'Random seed used for python, numpy, pytorch, and cuda.',
        }),
        ('--data-parallel-random-init', {
            'action': 'store_true',
            'help': 'Enable random initialization of params across data '
                    'parallel ranks',
        }),
        ('--init-method-std', {
            'type': float,
            'default': 0.02,
            'help': 'Standard deviation of the zero mean normal distribution '
                    'used for weight initialization.',
        }),
        ('--init-method-xavier-uniform', {
            'action': 'store_true',
            'help': 'Enable Xavier uniform parameter initialization',
        }),
    )
    for flag, options in option_specs:
        init_group.add_argument(flag, **options)

    return parser


Mohammad's avatar
Mohammad committed
631
def _add_learning_rate_args(parser):
    """Register the 'learning rate' argument group on the parser.

    Args:
        parser: the argparse parser to extend.

    Returns:
        The same parser object, with the learning-rate options added.
    """
    group = parser.add_argument_group(title='learning rate')

    group.add_argument('--lr', type=float, default=None,
                       help='Initial learning rate. Depending on decay style '
                       'and initial warmup, the learning rate at each '
                       'iteration would be different.')
    group.add_argument('--lr-decay-style', type=str, default='linear',
                       choices=['constant', 'linear', 'cosine'],
                       help='Learning rate decay function.')
    group.add_argument('--lr-decay-iters', type=int, default=None,
                       help='number of iterations to decay learning rate over,'
                       ' If None defaults to `--train-iters`')
    group.add_argument('--lr-decay-samples', type=int, default=None,
                       help='number of samples to decay learning rate over,'
                       ' If None defaults to `--train-samples`')
    group.add_argument('--lr-warmup-fraction', type=float, default=None,
                       help='fraction of lr-warmup-(iters/samples) to use '
                       'for warmup (as a float)')
    group.add_argument('--lr-warmup-iters', type=int, default=0,
                       help='number of iterations to linearly warmup '
                       'learning rate over.')
    group.add_argument('--lr-warmup-samples', type=int, default=0,
                       help='number of samples to linearly warmup '
                       'learning rate over.')
    group.add_argument('--warmup', type=int, default=None,
                       help='Old lr warmup argument, do not use. Use one of the '
                       '--lr-warmup-* arguments above')
    group.add_argument('--min-lr', type=float, default=0.0,
                       help='Minimum value for learning rate. The scheduler '
                       'clip values below this threshold.')
    # Note the mixed '-'/'_' spelling of the next two flags: argparse derives
    # the attribute names override_opt_param_scheduler and
    # use_checkpoint_opt_param_scheduler from them.
    group.add_argument('--override-opt_param-scheduler', action='store_true',
                       help='Reset the values of the scheduler (learning rate, '
                       'warmup iterations, minimum learning rate, maximum '
                       'number of iterations, and decay style) from input '
                       'arguments and ignore values from checkpoints. Note '
                       'that all the above values will be reset.')
    group.add_argument('--use-checkpoint-opt_param-scheduler', action='store_true',
                       help='Use checkpoint to set the values of the scheduler '
                       '(learning rate, warmup iterations, minimum learning '
                       'rate, maximum number of iterations, and decay style) '
                       'from checkpoint and ignore input arguments.')

    return parser


Mohammad's avatar
Mohammad committed
677
def _add_checkpointing_args(parser):
    """Attach the 'checkpointing' option group to the parser and return it."""
    ckpt_group = parser.add_argument_group(title='checkpointing')
    register = ckpt_group.add_argument

    # These switches are left as None (not False) when absent from the
    # command line and become True when given.
    unset_or_true = {'action': 'store_true', 'default': None}

    register('--save', type=str, default=None,
             help='Output directory to save checkpoints to.')
    register('--save-interval', type=int, default=None,
             help='Number of iterations between checkpoint saves.')
    register('--no-save-optim',
             help='Do not save current optimizer.',
             **unset_or_true)
    register('--no-save-rng',
             help='Do not save current rng state.',
             **unset_or_true)
    register('--load', type=str, default=None,
             help='Directory containing a model checkpoint.')
    register('--no-load-optim',
             help='Do not load optimizer when loading checkpoint.',
             **unset_or_true)
    register('--no-load-rng',
             help='Do not load rng state when loading checkpoint.',
             **unset_or_true)
    register('--finetune', action='store_true',
             help='Load model for finetuning. Do not load optimizer or rng '
                  'state from checkpoint and set iteration to 0. Assumed '
                  'when loading a release checkpoint.')
    # store_false: perform_initialization is True unless the flag is passed.
    register('--no-initialization', action='store_false',
             dest='perform_initialization',
             help='Do not perform initialization when building model, can '
                  'reduce startup time when definitely loading from a '
                  'checkpoint')
    register('--use-checkpoint-args', action='store_true',
             help='Override any command line arguments with arguments from '
                  'the checkpoint')

    return parser


Mohammad's avatar
Mohammad committed
710
def _add_mixed_precision_args(parser):
    """Register the 'mixed precision' argument group on the parser.

    Args:
        parser: the argparse parser to extend.

    Returns:
        The same parser object, with the mixed-precision options added.
    """
    group = parser.add_argument_group(title='mixed precision')

    group.add_argument('--fp16', action='store_true',
                       help='Run model in fp16 mode.')
    group.add_argument('--bf16', action='store_true',
                       help='Run model in bfloat16 mode.')
    group.add_argument('--loss-scale', type=float, default=None,
                       help='Static loss scaling, positive power of 2 '
                       'values can improve fp16 convergence. If None, dynamic '
                       'loss scaling is used.')
    group.add_argument('--initial-loss-scale', type=float, default=2**32,
                       help='Initial loss-scale for dynamic loss scaling.')
    group.add_argument('--min-loss-scale', type=float, default=1.0,
                       help='Minimum loss scale for dynamic loss scale.')
    group.add_argument('--loss-scale-window', type=float, default=1000,
                       help='Window over which to raise/lower dynamic scale.')
    group.add_argument('--hysteresis', type=int, default=2,
                       help='hysteresis for dynamic loss scaling')
    group.add_argument('--fp32-residual-connection', action='store_true',
                       help='Move residual connections to fp32.')
    # store_false: apply_query_key_layer_scaling is True unless the flag is
    # passed.
    group.add_argument('--no-query-key-layer-scaling', action='store_false',
                       help='Do not scale Q * K^T by 1 / layer-number.',
                       dest='apply_query_key_layer_scaling')
    group.add_argument('--attention-softmax-in-fp32', action='store_true',
                       help='Run attention masking and softmax in fp32. '
                       'This flag is ignored unless '
                       '--no-query-key-layer-scaling is specified.')
    group.add_argument('--accumulate-allreduce-grads-in-fp32',
                       action='store_true',
                       help='Gradient accumulation and all-reduce in fp32.')
    group.add_argument('--fp16-lm-cross-entropy', action='store_true',
                       help='Move the cross entropy unreduced loss calculation '
                       'for lm head to fp16.')

    return parser


Mohammad's avatar
Mohammad committed
748
def _parse_bool_arg(value):
    """Convert a command-line string into a bool.

    Using ``type=bool`` directly would turn every non-empty string,
    including "False" and "0", into True, so the accepted spellings are
    matched explicitly instead.

    Args:
        value: the raw option value (a string when it comes from argparse).

    Returns:
        True or False.

    Raises:
        argparse.ArgumentTypeError: if the value is not a recognised
            boolean spelling.
    """
    if isinstance(value, bool):
        return value
    text = str(value).strip().lower()
    if text in ('true', 't', 'yes', 'y', '1', 'on'):
        return True
    # The empty string is kept falsy, as it was with bool('').
    if text in ('false', 'f', 'no', 'n', '0', 'off', ''):
        return False
    raise argparse.ArgumentTypeError(
        'expected a boolean value (true/false), got {!r}'.format(value))


def _add_distributed_args(parser):
    """Register the 'distributed' argument group on the parser.

    Args:
        parser: the argparse parser to extend.

    Returns:
        The same parser object, with the distributed options added.
    """
    group = parser.add_argument_group(title='distributed')

    group.add_argument('--tensor-model-parallel-size', type=int, default=1,
                       help='Degree of tensor model parallelism.')
    group.add_argument('--pipeline-model-parallel-size', type=int, default=1,
                       help='Degree of pipeline model parallelism.')
    group.add_argument('--pipeline-model-parallel-split-rank',
                       type=int, default=None,
                       help='Rank where encoder and decoder should be split.')
    group.add_argument('--model-parallel-size', type=int, default=None,
                       help='Old model parallel argument, do not use. Use '
                       '--tensor-model-parallel-size instead.')
    group.add_argument('--num-layers-per-virtual-pipeline-stage', type=int, default=None,
                       help='Number of layers per virtual pipeline stage')
    group.add_argument('--distributed-backend', default='nccl',
                       choices=['nccl', 'gloo'],
                       help='Which backend to use for distributed training.')
    group.add_argument('--DDP-impl', default='local',
                       choices=['local', 'torch'],
                       help='which DistributedDataParallel implementation '
                       'to use.')
    # The two --no-* flags below use store_false, so the attribute named by
    # dest is True unless the flag is passed.
    group.add_argument('--no-contiguous-buffers-in-local-ddp',
                       action='store_false', help="If set, don't use "
                       'contiguous buffer in local DDP.',
                       dest='use_contiguous_buffers_in_local_ddp')
    group.add_argument('--no-scatter-gather-tensors-in-pipeline', action='store_false',
                       help='If set, do not use scatter/gather to optimize '
                       'communication of tensors in pipeline',
                       dest='scatter_gather_tensors_in_pipeline')
    group.add_argument('--use-ring-exchange-p2p', action='store_true',
                       default=False, help='If set, use custom-built ring exchange '
                       'for p2p communications. Note that this option will require '
                       'a custom built image that support ring-exchange p2p.')
    group.add_argument('--local_rank', type=int, default=None,
                       help='local rank passed from distributed launcher.')
    # An explicit converter is used here so that
    # "--lazy-mpu-init False" really yields False.
    group.add_argument('--lazy-mpu-init', type=_parse_bool_arg, required=False,
                       help='If set to True, initialize_megatron() '
                       'skips DDP initialization and returns function to '
                       'complete it instead. Also turns on '
                       '--use-cpu-initialization flag. This is for '
                       'external DDP manager.')
    group.add_argument('--use-cpu-initialization', action='store_true',
                       default=None, help='If set, affine parallel weights '
                       'initialization uses CPU')
    group.add_argument('--empty-unused-memory-level', default=0, type=int,
                       choices=[0, 1, 2],
                       help='Call torch.cuda.empty_cache() each iteration '
                       '(training and eval), to reduce fragmentation. '
                       '0=off, 1=moderate, 2=aggressive.')
    group.add_argument('--standalone-embedding-stage', action='store_true',
                       default=False, help='If set, *input* embedding layer '
                       'is placed on its own pipeline stage, without any '
                       'transformer layers. (For T5, this flag currently only '
                       'affects the encoder embedding.)')
    group.add_argument('--use-distributed-optimizer', action='store_true',
                       help='Use distributed optimizer.')

    return parser


Mohammad's avatar
Mohammad committed
808
def _add_validation_args(parser):
    """Register the 'validation' argument group on the parser.

    Args:
        parser: the argparse parser to extend.

    Returns:
        The same parser object, with the validation options added.
    """
    group = parser.add_argument_group(title='validation')

    # The trailing space keeps the concatenated literals from running
    # together in the rendered help.
    group.add_argument('--eval-iters', type=int, default=100,
                       help='Number of iterations to run for evaluation '
                       'validation/test for.')
    group.add_argument('--eval-interval', type=int, default=1000,
                       help='Interval between running evaluation on '
                       'validation set.')

    return parser


Mohammad's avatar
Mohammad committed
821
def _add_data_args(parser):
    """Register the 'data and dataloader' argument group on the parser.

    Args:
        parser: the argparse parser to extend.

    Returns:
        The same parser object, with the data options added.
    """
    group = parser.add_argument_group(title='data and dataloader')

    group.add_argument('--data-path', nargs='*', default=None,
                       help='Path to the training dataset. Accepted format: '
                       '1) a single data path, 2) multiple datasets in the '
                       'form: dataset1-weight dataset1-path dataset2-weight '
                       'dataset2-path ...')
    # '%%' is required here: argparse applies %-formatting to help strings.
    group.add_argument('--split', type=str, default='969, 30, 1',
                       help='Comma-separated list of proportions for training,'
                       ' validation, and test split. For example the split '
                       '`90,5,5` will use 90%% of data for training, 5%% for '
                       'validation and 5%% for test.')
    group.add_argument('--vocab-file', type=str, default=None,
                       help='Path to the vocab file.')
    group.add_argument('--merge-file', type=str, default=None,
                       help='Path to the BPE merge file.')
    group.add_argument('--vocab-extra-ids', type=int, default=0,
                       help='Number of additional vocabulary tokens. '
                            'They are used for span masking in the T5 model')
    group.add_argument('--seq-length', type=int, default=None,
                       help='Maximum sequence length to process.')
    group.add_argument('--encoder-seq-length', type=int, default=None,
                       help='Maximum encoder sequence length to process. '
                       'This should be exclusive of --seq-length')
    group.add_argument('--decoder-seq-length', type=int, default=None,
                       help="Maximum decoder sequence length to process.")
    group.add_argument('--retriever-seq-length', type=int, default=256,
                       help='Maximum sequence length for the biencoder model '
                       'for retriever')
    group.add_argument('--sample-rate', type=float, default=1.0,
                       help='sample rate for training data. Supposed to be 0 '
                       '< sample_rate < 1')
    group.add_argument('--mask-prob', type=float, default=0.15,
                       help='Probability of replacing a token with mask.')
    group.add_argument('--short-seq-prob', type=float, default=0.1,
                       help='Probability of producing a short sequence.')
    group.add_argument('--mmap-warmup', action='store_true',
                       help='Warm up mmap files.')
    group.add_argument('--num-workers', type=int, default=2,
                       help="Dataloader number of workers.")
    group.add_argument('--tokenizer-type', type=str,
                       default=None,
                       choices=['BertWordPieceLowerCase',
                                'BertWordPieceCase',
                                'GPT2BPETokenizer',
                                'SentencePieceTokenizer'],
                       help='What type of tokenizer to use.')
    group.add_argument('--tokenizer-model', type=str, default=None,
                       help='Sentencepiece tokenizer model.')
    group.add_argument('--data-impl', type=str, default='infer',
                       choices=['lazy', 'cached', 'mmap', 'infer'],
                       help='Implementation of indexed datasets.')
    group.add_argument('--reset-position-ids', action='store_true',
                       help='Reset position ids after end-of-document token.')
    group.add_argument('--reset-attention-mask', action='store_true',
                       help='Reset self attention mask after '
                       'end-of-document token.')
    group.add_argument('--eod-mask-loss', action='store_true',
                       help='Mask loss for the end of document tokens.')

    return parser

Raul Puri's avatar
Raul Puri committed
884

Mohammad's avatar
Mohammad committed
885
886
def _add_autoresume_args(parser):
    """Register the 'autoresume' argument group on the parser.

    Args:
        parser: the argparse parser to extend.

    Returns:
        The same parser object, with the autoresume options added.
    """
    group = parser.add_argument_group(title='autoresume')

    group.add_argument('--adlr-autoresume', action='store_true',
                       help='Enable autoresume on adlr cluster.')
    # The trailing space keeps the concatenated literals from running
    # together in the rendered help.
    group.add_argument('--adlr-autoresume-interval', type=int, default=1000,
                       help='Intervals over which check for autoresume '
                       'termination signal')

    return parser
Neel Kant's avatar
Neel Kant committed
895
896


Mostofa Patwary's avatar
Mostofa Patwary committed
897
898
def _add_biencoder_args(parser):
    """Add the 'biencoder' argument group to the parser.

    The options are registered in sections: network size, checkpointing,
    data, training, faiss index and indexer.

    Args:
        parser: argparse parser to extend.

    Returns:
        The same parser object that was passed in.
    """
    group = parser.add_argument_group(title='biencoder')

    # network size
    group.add_argument('--ict-head-size', type=int, default=None,
                       help='Size of block embeddings to be used in ICT and '
                       'REALM (paper default: 128)')
    group.add_argument('--biencoder-projection-dim', type=int, default=0,
                       help='Size of projection head used in biencoder (paper'
                       ' default: 128)')
    group.add_argument('--biencoder-shared-query-context-model',
                       action='store_true',
                       help='Whether to share the parameters of the query '
                       'and context models or not')

    # checkpointing
    group.add_argument('--ict-load', type=str, default=None,
                       help='Directory containing an ICTBertModel checkpoint')
    group.add_argument('--bert-load', type=str, default=None,
                       help='Directory containing a BertModel checkpoint '
                       '(needed to start ICT and REALM)')

    # data
    group.add_argument('--titles-data-path', type=str, default=None,
                       help='Path to titles dataset used for ICT')
    group.add_argument('--query-in-block-prob', type=float, default=0.1,
                       help='Probability of keeping query in block for '
                       'ICT dataset')
    group.add_argument('--use-one-sent-docs', action='store_true',
                       help='Whether to use one sentence documents in ICT')
    group.add_argument('--evidence-data-path', type=str, default=None,
                       help='Path to Wikipedia Evidence from DPR paper')

    # training
    # nargs='+' collects one or more ints; when the option is absent the
    # attribute is the empty list given as default.
    group.add_argument('--retriever-report-topk-accuracies', nargs='+',
                       type=int, default=[],
                       help="Which top-k accuracies to report "
                       "(e.g. '1 5 20')")
    group.add_argument('--retriever-score-scaling', action='store_true',
                       help='Whether to scale retriever scores by inverse '
                       'square root of hidden size')

    # faiss index
    group.add_argument('--block-data-path', type=str, default=None,
                       help='Where to save/load BlockData to/from')
    group.add_argument('--embedding-path', type=str, default=None,
                       help='Where to save/load Open-Retrieval Embedding'
                       ' data to/from')

    # indexer
    group.add_argument('--indexer-batch-size', type=int, default=128,
                       help='How large of batches to use when doing indexing '
                       'jobs')
    group.add_argument('--indexer-log-interval', type=int, default=1000,
                       help='After how many batches should the indexer '
                       'report progress')
    return parser
952
953


954
955
def _add_vision_args(parser):
    """Add the 'vision' argument group to the parser.

    The options are registered in sections: general vision arguments,
    pretraining type and backbone selection, inpainting arguments and
    dino arguments.

    Args:
        parser: argparse parser to extend.

    Returns:
        The same parser object that was passed in.
    """
    group = parser.add_argument_group(title="vision")

    # general vision arguments
    group.add_argument('--num-classes', type=int, default=1000,
                       help='num of classes in vision classification task')
    group.add_argument('--img-h', type=int, default=224,
                       help='Image height for vision classification task')
    group.add_argument('--img-w', type=int, default=224,
                       help='Image width for vision classification task')
    group.add_argument('--num-channels', type=int, default=3,
                       help='Number of channels in input image data')
    group.add_argument('--patch-dim', type=int, default=16,
                       help='patch dimension')
    group.add_argument('--classes-fraction', type=float, default=1.0,
                       help='training with fraction of classes.')
    group.add_argument('--data-per-class-fraction', type=float, default=1.0,
                       help='training with fraction of data per class.')
    # store_false with an explicit dest: args.data_sharding defaults to True
    # and passing --no-data-sharding sets it to False.
    group.add_argument('--no-data-sharding', action='store_false',
                       help='Disable data sharding.',
                       dest='data_sharding')
    group.add_argument('--head-lr-mult', type=float, default=1.0,
                       help='learning rate multiplier for head during finetuning')

    # pretraining type and backbone selection
    group.add_argument('--vision-pretraining', action='store_true',
                       help='flag to indicate vision pretraining')
    group.add_argument('--vision-pretraining-type', type=str, default='classify',
                       choices=['classify', 'inpaint', 'dino'],
                       help='pretraining objectives')
    group.add_argument('--vision-backbone-type', type=str, default='vit',
                       choices=['vit', 'mit', 'swin'],
                       help='backbone types')
    group.add_argument('--swin-backbone-type', type=str, default='tiny',
                       choices=['tiny', 'base', 'h3'],
                       help='swin backbone types')

    # inpainting arguments
    group.add_argument('--mask-type', type=str, default='random',
                       choices=['random', 'row'],
                       help='mask types')
    group.add_argument('--mask-factor', type=float, default=1.0,
                       help='mask size scaling parameter')

    # dino arguments
    group.add_argument('--iter-per-epoch', type=int, default=1250,
                       help='iterations per epoch')
    group.add_argument('--dino-local-img-size', type=int, default=96,
                       help='Image size for vision classification task')
    group.add_argument('--dino-local-crops-number', type=int, default=10,
                       help='Number of local crops')
    group.add_argument('--dino-head-hidden-size', type=int, default=2048,
                       help='Hidden dimension size in dino head')
    group.add_argument('--dino-bottleneck-size', type=int, default=256,
                       help='Bottle neck dimension in dino head ')
    group.add_argument('--dino-freeze-last-layer', type=float, default=1,
                       help='Freezing last layer weights')
    # NOTE(review): this is a store_true flag (False unless passed) while the
    # help text says "Disable" -- confirm against the code that consumes
    # args.dino_norm_last_layer before rewording.
    group.add_argument('--dino-norm-last-layer', action='store_true',
                       help='Disable Norm in last layer.')
    group.add_argument('--dino-warmup-teacher-temp', type=float, default=0.04,
                       help='warmup teacher temperature')
    group.add_argument('--dino-teacher-temp', type=float, default=0.07,
                       help='teacher temperature')
    group.add_argument('--dino-warmup-teacher-temp-epochs', type=int, default=30,
                       help='warmup teacher temperature epochs')

    return parser