# Copyright (c) 2022, NVIDIA CORPORATION. All rights reserved.

"""Megatron arguments."""

import argparse
import dataclasses
import json
import os
import torch
import types

import torch.nn.functional as F
from megatron.global_vars import set_retro_args, get_retro_args
from tools.retro.utils import get_args_path as get_retro_args_path

from megatron.core.transformer import TransformerConfig

def parse_args(extra_args_provider=None, ignore_unknown_args=False):
    """Parse all arguments."""
    parser = argparse.ArgumentParser(description='Megatron-LM Arguments',
                                     allow_abbrev=False)

    # Standard arguments.
    parser = _add_network_size_args(parser)
    parser = _add_regularization_args(parser)
    parser = _add_training_args(parser)
    parser = _add_initialization_args(parser)
    parser = _add_learning_rate_args(parser)
    parser = _add_checkpointing_args(parser)
    parser = _add_mixed_precision_args(parser)
    parser = _add_distributed_args(parser)
    parser = _add_validation_args(parser)
    parser = _add_data_args(parser)
    parser = _add_autoresume_args(parser)
    parser = _add_biencoder_args(parser)
    parser = _add_vision_args(parser)
    parser = _add_logging_args(parser)
    parser = _add_inference_args(parser)
    parser = _add_transformer_engine_args(parser)
    parser = _add_retro_args(parser)

    # Custom arguments.
    if extra_args_provider is not None:
        parser = extra_args_provider(parser)

    # Parse.
    if ignore_unknown_args:
        args, _ = parser.parse_known_args()
    else:
        args = parser.parse_args()

    # Args from environment
    #args.rank = int(os.getenv('RANK', '0'))
    #args.world_size = int(os.getenv("WORLD_SIZE", '1'))

    return args

def validate_args(args, defaults={}):
    # Tensor model parallel size.
    args.tensor_model_parallel_size = min(
        args.tensor_model_parallel_size, args.world_size)
    assert args.world_size % args.tensor_model_parallel_size == 0, 'world size'\
        ' ({}) is not divisible by tensor model parallel size ({})'.format(
            args.world_size, args.tensor_model_parallel_size)
    # Pipeline model parallel size.
    args.pipeline_model_parallel_size = min(
        args.pipeline_model_parallel_size,
        (args.world_size // args.tensor_model_parallel_size))
    args.transformer_pipeline_model_parallel_size = (
        args.pipeline_model_parallel_size - 1
        if args.standalone_embedding_stage else
        args.pipeline_model_parallel_size
    )
    # Checks.
    model_parallel_size = args.pipeline_model_parallel_size * \
                          args.tensor_model_parallel_size
    assert args.world_size % model_parallel_size == 0, 'world size ({}) is not'\
        ' divisible by tensor parallel size ({}) times pipeline parallel ' \
        'size ({})'.format(args.world_size, args.tensor_model_parallel_size,
                           args.pipeline_model_parallel_size)
    args.data_parallel_size = args.world_size // model_parallel_size
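    # Worked example (for illustration only): with world_size=16,
    # tensor_model_parallel_size=4 and pipeline_model_parallel_size=2,
    # model_parallel_size is 8 and data_parallel_size = 16 // 8 = 2.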
    if args.rank == 0:
        print('using world size: {}, data-parallel-size: {}, '
              'tensor-model-parallel size: {}, '
              'pipeline-model-parallel size: {} '.format(
                  args.world_size, args.data_parallel_size,
                  args.tensor_model_parallel_size,
                  args.pipeline_model_parallel_size), flush=True)
    if args.pipeline_model_parallel_size > 1:
        if args.pipeline_model_parallel_split_rank is not None:
            assert args.pipeline_model_parallel_split_rank < \
                    args.pipeline_model_parallel_size, 'split rank needs'\
                    ' to be less than pipeline model parallel size ({})'.format(
                            args.pipeline_model_parallel_size)

    # Deprecated arguments
    assert args.batch_size is None, '--batch-size argument is no longer ' \
        'valid, use --micro-batch-size instead'
    del args.batch_size
    assert args.warmup is None, '--warmup argument is no longer valid, use ' \
        '--lr-warmup-fraction instead'
    del args.warmup
    assert args.model_parallel_size is None, '--model-parallel-size is no ' \
        'longer valid, use --tensor-model-parallel-size instead'
    del args.model_parallel_size

    if args.checkpoint_activations:
        if args.rank == 0:
            print('--checkpoint-activations is no longer valid, use --recompute-activations, '
                  'or, for more control, --recompute-granularity and --recompute-method.')
        exit()
    del args.checkpoint_activations

    if args.recompute_activations:
        args.recompute_granularity = 'selective'
    del args.recompute_activations

    # Set input defaults.
    for key in defaults:
        # For default to be valid, it should not be provided in the
        # arguments that are passed to the program. We check this by
        # ensuring the arg is set to None.
        if getattr(args, key, None) is not None:
            if args.rank == 0:
                print('WARNING: overriding default arguments for {key}:{v} \
                       with {key}:{v2}'.format(key=key, v=defaults[key],
                                               v2=getattr(args, key)),
                                               flush=True)
        else:
            setattr(args, key, defaults[key])

    # Batch size.
    assert args.micro_batch_size is not None
    assert args.micro_batch_size > 0
    if args.global_batch_size is None:
        args.global_batch_size = args.micro_batch_size * args.data_parallel_size
        if args.rank == 0:
            print('setting global batch size to {}'.format(
                args.global_batch_size), flush=True)
    assert args.global_batch_size > 0
    if args.num_layers_per_virtual_pipeline_stage is not None:
        assert args.pipeline_model_parallel_size > 2, \
            'pipeline-model-parallel size should be greater than 2 with ' \
            'interleaved schedule'
        assert args.num_layers % args.num_layers_per_virtual_pipeline_stage == 0, \
            'number of layers is not divisible by number of layers per virtual ' \
            'pipeline stage'
        args.virtual_pipeline_model_parallel_size = \
            (args.num_layers // args.transformer_pipeline_model_parallel_size) // \
            args.num_layers_per_virtual_pipeline_stage
    else:
        args.virtual_pipeline_model_parallel_size = None
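    # Worked example (for illustration only): with num_layers=24,
    # pipeline_model_parallel_size=4 (no standalone embedding stage) and
    # num_layers_per_virtual_pipeline_stage=3, each pipeline stage holds
    # 24 // 4 = 6 layers split into 6 // 3 = 2 model chunks, so
    # virtual_pipeline_model_parallel_size is 2.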

    # Parameters dtype.
    args.params_dtype = torch.float
    if args.fp16:
        assert not args.bf16
        args.params_dtype = torch.half
    if args.bf16:
        assert not args.fp16
        args.params_dtype = torch.bfloat16
        # bfloat16 requires gradient accumulation and all-reduce to
        # be done in fp32.
        if not args.accumulate_allreduce_grads_in_fp32:
            args.accumulate_allreduce_grads_in_fp32 = True
            if args.rank == 0:
                print('accumulate and all-reduce gradients in fp32 for '
                      'bfloat16 data type.', flush=True)

    if args.rank == 0:
        print('using {} for parameters ...'.format(args.params_dtype),
              flush=True)

    # If we do accumulation and all-reduces in fp32, we need to have local DDP
    # and we should make sure use-contiguous-buffers-in-local-ddp is not off.
    if args.accumulate_allreduce_grads_in_fp32:
        assert args.DDP_impl == 'local'
        assert args.use_contiguous_buffers_in_local_ddp

    # If we use the distributed optimizer, we need to have local DDP
    # and we should make sure use-contiguous-buffers-in-local-ddp is on.
    if args.use_distributed_optimizer:
        assert args.DDP_impl == 'local'
        assert args.use_contiguous_buffers_in_local_ddp

    # For torch DDP, we do not use contiguous buffer
    if args.DDP_impl == 'torch':
        args.use_contiguous_buffers_in_local_ddp = False

    if args.dataloader_type is None:
        args.dataloader_type = 'single'

    # Consumed samples.
    args.consumed_train_samples = 0
    args.consumed_valid_samples = 0

    # Support for variable sequence lengths across batches/microbatches.
    # Set it if the dataloader supports generation of variable sequence
    # lengths across batches/microbatches. Because of the additional
    # communication overhead it incurs under pipeline parallelism, it should
    # not be set if the sequence length is constant during training.
    args.variable_seq_lengths = False

    # Iteration-based training.
    if args.train_iters:
        # If we use iteration-based training, make sure the
        # sample-based options are off.
        assert args.train_samples is None, \
            'expected iteration-based training'
        assert args.lr_decay_samples is None, \
            'expected iteration-based learning rate decay'
        assert args.lr_warmup_samples == 0, \
            'expected iteration-based learning rate warmup'
        assert args.rampup_batch_size is None, \
            'expected no batch-size rampup for iteration-based training'
        if args.lr_warmup_fraction is not None:
            assert args.lr_warmup_iters == 0, \
                'can only specify one of lr-warmup-fraction and lr-warmup-iters'

    # Sample-based training.
    if args.train_samples:
        # If we use sample-based training, make sure the
        # iteration-based options are off.
        assert args.train_iters is None, \
            'expected sample-based training'
        assert args.lr_decay_iters is None, \
            'expected sample-based learning rate decay'
        assert args.lr_warmup_iters == 0, \
            'expected sample-based learning rate warmup'
        if args.lr_warmup_fraction is not None:
            assert args.lr_warmup_samples == 0, \
                'can only specify one of lr-warmup-fraction ' \
                'and lr-warmup-samples'

    if args.num_layers is not None:
        assert args.encoder_num_layers is None, \
            'cannot have both num-layers and encoder-num-layers specified'
        args.encoder_num_layers = args.num_layers
    else:
        assert args.encoder_num_layers is not None, \
            'either num-layers or encoder-num-layers should be specified'
        args.num_layers = args.encoder_num_layers

    # Check required arguments.
    required_args = ['num_layers', 'hidden_size', 'num_attention_heads',
                     'max_position_embeddings']
    for req_arg in required_args:
        _check_arg_is_not_none(args, req_arg)

    # Checks.
    if args.ffn_hidden_size is None:
        args.ffn_hidden_size = 4 * args.hidden_size

    if args.swiglu:
        # Reduce the dimension of the MLP since the projection happens on
        # two linear layers. This keeps the number of parameters in the
        # same ballpark as the counterpart with 4*h size.
        # We keep it a multiple of 64, which means the actual tensor size
        # will be a multiple of 64 / tp_size.
        args.ffn_hidden_size = int((4 * args.hidden_size * 2 / 3) / 64) * 64
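        # Worked example (for illustration only): hidden_size=4096 gives
        # int((4 * 4096 * 2 / 3) / 64) * 64 = 170 * 64 = 10880, versus 16384
        # for the plain 4*h MLP.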

    if args.kv_channels is None:
        assert args.hidden_size % args.num_attention_heads == 0
        args.kv_channels = args.hidden_size // args.num_attention_heads
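    # Worked example (for illustration only): hidden_size=4096 with
    # num_attention_heads=32 gives kv_channels = 4096 // 32 = 128.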

    if args.seq_length is not None:
        assert args.encoder_seq_length is None
        args.encoder_seq_length = args.seq_length
    else:
        assert args.encoder_seq_length is not None
        args.seq_length = args.encoder_seq_length

    if args.seq_length is not None:
        assert args.max_position_embeddings >= args.seq_length
    if args.decoder_seq_length is not None:
        assert args.max_position_embeddings >= args.decoder_seq_length
    if args.lr is not None:
        assert args.min_lr <= args.lr
    if args.save is not None:
        assert args.save_interval is not None
    # Mixed precision checks.
    if args.fp16_lm_cross_entropy:
        assert args.fp16, 'lm cross entropy in fp16 is only supported in fp16 mode.'
    if args.fp32_residual_connection:
        assert args.fp16 or args.bf16, \
            'residual connection in fp32 only supported when using fp16 or bf16.'

    if args.weight_decay_incr_style == 'constant':
        assert args.start_weight_decay is None
        assert args.end_weight_decay is None
        args.start_weight_decay = args.weight_decay
        args.end_weight_decay = args.weight_decay
    else:
        assert args.start_weight_decay is not None
        assert args.end_weight_decay is not None

    TORCH_MAJOR = int(torch.__version__.split('.')[0])
    TORCH_MINOR = int(torch.__version__.split('.')[1])
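    # For example, torch.__version__ == '1.13.1+cu117' parses to
    # TORCH_MAJOR=1, TORCH_MINOR=13.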
    # Persistent fused layer norm.
    if TORCH_MAJOR < 1 or (TORCH_MAJOR == 1 and TORCH_MINOR < 11):
        args.no_persist_layer_norm = True
        if args.rank == 0:
            print('Persistent fused layer norm kernel is supported from '
                  'pytorch v1.11 (nvidia pytorch container paired with v1.11). '
                  'Defaulting to no_persist_layer_norm=True')

    # Activation recomputing.
    if args.distribute_saved_activations:
        assert args.tensor_model_parallel_size > 1, 'can distribute ' \
            'recomputed activations only across tensor model ' \
            'parallel groups'
        assert args.recompute_granularity == 'full', \
            'distributed recompute activations is only '\
            'applicable to full recompute granularity'
        assert args.recompute_method is not None, \
            'for distributed recompute activations to work you '\
            'need to use a recompute method '
        assert (TORCH_MAJOR, TORCH_MINOR) >= (1, 10), \
            'distributed recompute activations are supported for pytorch ' \
            'v1.10 and above (Nvidia Pytorch container >= 21.07). Current ' \
            'pytorch version is v%s.%s.' % (TORCH_MAJOR, TORCH_MINOR)

    if args.recompute_granularity == 'selective':
        assert args.recompute_method is None, \
            'recompute method is not yet supported for ' \
            'selective recomputing granularity'

    # Disable sequence parallelism when tp=1 to avoid any change in numerics
    # when sequence parallelism is enabled.
    if args.tensor_model_parallel_size == 1:
        args.sequence_parallel = False

    # disable async_tensor_model_parallel_allreduce when
    # model parallel memory optimization is enabled
    if args.sequence_parallel:
        args.async_tensor_model_parallel_allreduce = False

    if os.environ.get('CUDA_DEVICE_MAX_CONNECTIONS') != "1":
        if args.sequence_parallel:
            raise RuntimeError(
                "Using sequence parallelism requires setting the environment variable "
                "CUDA_DEVICE_MAX_CONNECTIONS to 1")
        if args.async_tensor_model_parallel_allreduce:
            raise RuntimeError(
                "Using async gradient all reduce requires setting the environment "
                "variable CUDA_DEVICE_MAX_CONNECTIONS to 1")

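    # Illustrative launch sketch (assumptions: torchrun as the launcher and
    # pretrain_gpt.py as the entry point; only the env-var requirement comes
    # from the checks above):
    #   CUDA_DEVICE_MAX_CONNECTIONS=1 torchrun --nproc_per_node=8 \
    #       pretrain_gpt.py --sequence-parallel ...
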
    # Disable bias gelu fusion if we are disabling bias altogether
    if not args.add_bias_linear:
        args.bias_gelu_fusion = False

    # Retro checks.
    if args.retro_add_retriever:

        # Sequence parallelism unsupported.
        assert not args.sequence_parallel, \
            "retro currently does not support sequence parallelism."

        # Pipeline parallelism unsupported.
        assert args.pipeline_model_parallel_size == 1, \
            "retro currently does not support pipeline parallelism."

        # Load retro args.
        retro_args_path = get_retro_args_path(args.retro_workdir)
        assert os.path.exists(retro_args_path), "retro workdir missing args.json"
        with open(retro_args_path) as f:
            retro_args = types.SimpleNamespace(**json.load(f))
            retro_args.retro_return_doc_ids = args.retro_return_doc_ids
            retro_args.retro_gpt_retrieved_length = \
                args.retro_num_retrieved_chunks * \
                retro_args.retro_gpt_chunk_length
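            # e.g. retro_num_retrieved_chunks=2 with retro_gpt_chunk_length=64
            # gives retro_gpt_retrieved_length = 2 * 64 = 128.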
            set_retro_args(retro_args)

    # Legacy RoPE arguments
    if args.use_rotary_position_embeddings:
        args.position_embedding_type = 'rope'

    # Would just need to add 'NoPE' as a position_embedding_type to support this, but for now
    # don't allow it to keep things simple
    if not args.add_position_embedding and args.position_embedding_type != 'rope':
        raise RuntimeError('--no-position-embedding is deprecated, use --position-embedding-type')

    # Print arguments.
    _print_args("arguments", args)
    retro_args = get_retro_args()
    if retro_args and args != retro_args:
        _print_args("retro arguments", types.SimpleNamespace(**{k:v for k,v in vars(retro_args).items() if k.startswith("retro")}, rank=args.rank))

    return args


def _print_args(title, args):
    """Print arguments."""
    if args.rank == 0:
        print(f'------------------------ {title} ------------------------',
              flush=True)
        str_list = []
        for arg in vars(args):
            dots = '.' * (48 - len(arg))
            str_list.append('  {} {} {}'.format(arg, dots, getattr(args, arg)))
        for arg in sorted(str_list, key=lambda x: x.lower()):
            print(arg, flush=True)
        print(f'-------------------- end of {title} ---------------------',
              flush=True)


def _check_arg_is_not_none(args, arg):
    assert getattr(args, arg) is not None, '{} argument is None'.format(arg)

def core_transformer_config_from_args(args):

    # Translate args to core transformer configuration
    kw_args = {}
    for f in dataclasses.fields(TransformerConfig):
        if hasattr(args, f.name):
            kw_args[f.name] = getattr(args, f.name)
    kw_args['persist_layer_norm'] = not args.no_persist_layer_norm
    kw_args['layernorm_zero_centered_gamma'] = args.apply_layernorm_1p
    kw_args['deallocate_pipeline_outputs'] = True
    kw_args['pipeline_dtype'] = args.params_dtype
    kw_args['batch_p2p_comm'] = not args.overlap_p2p_comm
    if args.swiglu:
        kw_args['activation_func'] = F.silu
        kw_args['gated_linear_unit'] = True
        kw_args['bias_gelu_fusion'] = False
    if args.init_method_xavier_uniform:
        kw_args['init_method'] = torch.nn.init.xavier_uniform_
        kw_args['scaled_init_method'] = torch.nn.init.xavier_uniform_
    if args.group_query_attention:
        kw_args['num_query_groups'] = args.num_query_groups
    else:
        kw_args['num_query_groups'] = None

    return TransformerConfig(**kw_args)
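
# Usage sketch (assumption: a parsed-and-validated args namespace, e.g. from
# megatron.get_args(), is available):
#   config = core_transformer_config_from_args(get_args())
#   # 'config' can then be passed to megatron.core transformer/model builders.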

def _add_transformer_engine_args(parser):
    group = parser.add_argument_group(title='Transformer-Engine')

    group.add_argument('--fp8-format', default=None,
                       choices=['e4m3', 'hybrid'],
                       help='Which fp8 format scheme to use for FP8 tensors in the forward and backward pass',
                       dest='fp8')
    group.add_argument('--fp8-margin', type=int, default=0,
                       help='Scaling margin for fp8',
                       dest='fp8_margin')
    group.add_argument('--fp8-interval', type=int, default=1,
                       help='Scaling update interval for fp8',
                       dest='fp8_interval')
    group.add_argument('--fp8-amax-history-len', type=int, default=1,
                       help='Number of steps for which amax history is recorded per tensor',
                       dest='fp8_amax_history_len')
    group.add_argument('--fp8-amax-compute-algo', default='most_recent',
                       choices=['most_recent', 'max'],
                       help='Algorithm for computing amax from history',
                       dest='fp8_amax_compute_algo')
    group.add_argument('--no-fp8-wgrad', action='store_false',
                       help='Execute wgrad in higher precision even for FP8 runs',
                       dest='fp8_wgrad')
    group.add_argument('--transformer-impl', default='local',
                       choices=['local', 'transformer_engine'],
                       help='Which Transformer implementation to use.',
                       dest='transformer_impl')
    group.add_argument('--normalization', default='LayerNorm',
                       choices=['LayerNorm', 'RMSNorm'],
                       help='Which normalization technique to use.',
                       dest='normalization')

    return parser

def _add_inference_args(parser):
    group = parser.add_argument_group(title='inference')

    group.add_argument('--inference-batch-times-seqlen-threshold',
                       type=int, default=512,
                       help='During inference, if batch-size times '
                       'sequence-length is smaller than this threshold '
                       'then we will not use pipelining, otherwise we will.')
    group.add_argument('--max-tokens-to-oom',
                       type=int, default=12000,
                       help='Maximum number of tokens during inference. '
                       'Tokens here is # in prompt + # to generate. '
                       'Allows us to throw an error before OOM crashes the server.')
    group.add_argument('--output-bert-embeddings', action='store_true',
                       help='Output Bert embeddings (via mean pooling) from '
                       'model, rather than its binary head output or entire '
                       'hidden batch.')
    group.add_argument('--bert-embedder-type', default="megatron",
                       choices=["megatron", "huggingface"],
                       help='Select either Megatron or Huggingface as the '
                       'Bert embedder.')

    return parser


def _add_retro_args(parser):
    group = parser.add_argument_group(title='retro')

    group.add_argument('--retro-workdir', default=None,
                       help='Retro working directory, which contains the '
                       'preprocessed data for pretraining. This directory '
                       'is built during preprocessing (see '
                       'tools/retro/README.md), and contains subdirectories '
                       'for the chunk database and pretraining neighbors.')
    group.add_argument('--retro-add-retriever',
                       action='store_true', default=False,
                       help='Add a retriever to the transformer, for use in '
                       'pretraining a Retro model.')
    group.add_argument('--retro-cyclic-train-iters', type=int, default=None,
                       help='Set number of training iterations for cyclic '
                       'Retro training.')
    group.add_argument('--retro-encoder-layers', type=int, default=2,
                       help='Number of layers to use for the retrieval '
                       'encoder.')
    group.add_argument('--retro-encoder-hidden-dropout',
                       type=float, default=0.1, help='Hidden dropout for '
                       'retrieval encoder.')
    group.add_argument('--retro-encoder-attention-dropout',
                       type=float, default=0.1, help='Attention dropout for '
                       'retrieval encoder.')
    group.add_argument("--retro-num-neighbors", type=int, default=2,
                       help='Number of neighbors to retrieve during '
                       'pretraining.')
    group.add_argument("--retro-num-retrieved-chunks", type=int, default=2,
                       help='Number of chunks to retrieve from the retrieval '
                       'database.')
    group.add_argument("--retro-return-doc-ids", action="store_true",
                       help="Turn this on when preprocessing retro data.")

    # Enforce argument naming convention.
    for action in group._group_actions:
        prefix = action.dest.split("_")[0]
        assert prefix == "retro", \
            "Retro args must be prefixed with '--retro-*', for consistent " \
            "styling. Please fix '%s'." % ", ".join(action.option_strings)

    return parser


def _add_network_size_args(parser):
    group = parser.add_argument_group(title='network size')

    group.add_argument('--num-layers', type=int, default=None,
                       help='Number of transformer layers.')
    group.add_argument('--encoder-num-layers', type=int, default=None,
                       help='Number of encoder transformer layers.')
    group.add_argument('--decoder-num-layers', type=int, default=None,
                       help='Number of decoder transformer layers.')
    group.add_argument('--hidden-size', type=int, default=None,
                       help='Transformer hidden size.')
    group.add_argument('--ffn-hidden-size', type=int, default=None,
                       help='Transformer Feed-Forward Network hidden size. '
                       'This is set to 4*hidden-size if not provided')
    group.add_argument('--num-attention-heads', type=int, default=None,
                       help='Number of transformer attention heads.')
    group.add_argument('--kv-channels', type=int, default=None,
                       help='Projection weights dimension in multi-head '
                       'attention. This is set to '
                       '   args.hidden_size // args.num_attention_heads '
                       'if not provided.')
    group.add_argument('--group-query-attention', action='store_true',
                       help='Use group-query attention.')
    group.add_argument('--num-query-groups', type=int, default=1,
                       help='Number of query groups (key/value heads) used '
                       'with group-query attention.')

    group.add_argument('--max-position-embeddings', type=int, default=None,
                       help='Maximum number of position embeddings to use. '
                       'This is the size of position embedding.')
    group.add_argument('--position-embedding-type', type=str, default='learned_absolute',
                       choices=['learned_absolute', 'rope'],
                       help='Position embedding type.')
    group.add_argument('--use-rotary-position-embeddings', action='store_true',
                       help='Use rotary positional embeddings or not. '
                       'Deprecated: use --position-embedding-type')
    group.add_argument('--rotary-percent', type=float, default=1.0,
                       help='Percent of rotary dimension to use, default 100%%')
    group.add_argument('--rotary-seq-len-interpolation-factor', type=int, default=None,
                       help='Sequence length interpolation factor for rotary embeddings.')
    group.add_argument('--no-position-embedding',
                       action='store_false',
                       help='Disable position embedding. Deprecated: use --position-embedding-type',
                       dest='add_position_embedding')
    group.add_argument('--make-vocab-size-divisible-by', type=int, default=128,
                       help='Pad the vocab size to be divisible by this value. '
                       'This is added for computational efficiency reasons.')
    group.add_argument('--layernorm-epsilon', type=float, default=1e-5,
                       help='Layer norm epsilon.')
    group.add_argument('--apply-layernorm-1p', action='store_true',
                       help='Adjust LayerNorm weights such that they are centered '
                       'around zero. This improves numerical stability.')
    group.add_argument('--apply-residual-connection-post-layernorm',
                       action='store_true',
                       help='If set, use original BERT residual connection '
                       'ordering.')
    group.add_argument('--openai-gelu', action='store_true',
                       help='Use OpenAI\'s GeLU implementation. This option '
                       'should not be used except for backward compatibility '
                       'reasons.')
    group.add_argument('--squared-relu', action='store_true',
                       help='Use squared relu activation instead of default gelu')
    group.add_argument('--swiglu', action='store_true',
                       help='Use gated linear units and SiLU activation instead of default gelu')
    group.add_argument('--onnx-safe', type=bool, required=False,
                       help='Use workarounds for known problems with '
                       'Torch ONNX exporter')
    group.add_argument('--bert-no-binary-head', action='store_false',
                       help='Disable BERT binary head.',
                       dest='bert_binary_head')
    group.add_argument('--num-experts', type=int, default=None,
                       help='Number of Experts in Switch Transformer (None means no Switch)')
    group.add_argument('--untie-embeddings-and-output-weights', action='store_true',
                       help='Untie embeddings and output weights.'),
    group.add_argument('--embedding-weights-in-fp32', action='store_true',
                       help='Cast word embedding weights to fp32 before embedding fwd.'),
    return parser


def _add_logging_args(parser):
    group = parser.add_argument_group(title='logging')

    group.add_argument('--log-params-norm', action='store_true',
                       help='If set, calculate and log parameters norm.')
    group.add_argument('--log-num-zeros-in-grad', action='store_true',
                       help='If set, calculate and log the number of zeros in gradient.')
    group.add_argument('--timing-log-level', type=int,
                       default=0, choices=range(0,3),
                       help='Granularity level to measure and report timing. '
                       '   0: report only iteration time and make sure timing '
                       '      does not introduce extra overhead.'
                       '   1: report timing for operations that are executed '
                       '      very limited times (basically once) during '
                       '      each iteration (such as gradient all-reduce) '
                       '   2: report timing for operations that might be '
                       '      executed numerous times during each iteration. '
                       'Note that setting the level to 1 or 2 might '
                       'cause increase in iteration time.')
    group.add_argument('--no-barrier-with-level-1-timing', action='store_false',
                       help='If not set, use barrier with level 1 time '
                       'measurements. Note that this is up to the user '
                       'to make sure calling barrier with their timers '
                       'will not result in hangs. This can happen if for '
                       'example the user adds a level 1 timer that is not '
                       'called by all ranks.',
                       dest='barrier_with_L1_time')
    group.add_argument('--timing-log-option', type=str, default='minmax',
                       choices=['max', 'minmax', 'all'],
                       help='Options for logging timing:'
                       '  max: report the max timing across all ranks'
                       '  minmax: report min and max timings across all ranks'
                       '  all: report timings of all ranks.')
    group.add_argument('--tensorboard-log-interval', type=int, default=1,
                       help='Report to tensorboard interval.')
    group.add_argument('--tensorboard-queue-size', type=int, default=1000,
                       help='Size of the tensorboard queue for pending events '
                       'and summaries before one of the ‘add’ calls forces a '
                       'flush to disk.')
    group.add_argument('--log-timers-to-tensorboard', action='store_true',
                       help='If set, write timers to tensorboard.')
    group.add_argument('--log-batch-size-to-tensorboard', action='store_true',
                       help='If set, write batch-size to tensorboard.')
    group.add_argument('--no-log-learnig-rate-to-tensorboard',
                       action='store_false',
                       help='Disable learning rate logging to tensorboard.',
                       dest='log_learning_rate_to_tensorboard')
    group.add_argument('--no-log-loss-scale-to-tensorboard',
                       action='store_false',
                       help='Disable loss-scale logging to tensorboard.',
                       dest='log_loss_scale_to_tensorboard')
    group.add_argument('--log-validation-ppl-to-tensorboard',
                       action='store_true',
                       help='If set, write validation perplexity to '
                       'tensorboard.')
    group.add_argument('--log-memory-to-tensorboard',
                       action='store_true',
                       help='Enable memory logging to tensorboard.')
    group.add_argument('--log-world-size-to-tensorboard',
                       action='store_true',
                       help='Enable world size logging to tensorboard.')

    return parser


def _add_regularization_args(parser):
    group = parser.add_argument_group(title='regularization')

    group.add_argument('--attention-dropout', type=float, default=0.1,
                       help='Post attention dropout probability.')
    group.add_argument('--hidden-dropout', type=float, default=0.1,
                       help='Dropout probability for hidden state transformer.')
    group.add_argument('--weight-decay', type=float, default=0.01,
                       help='Weight decay coefficient for L2 regularization.')
    group.add_argument('--start-weight-decay', type=float,
                       help='Initial weight decay coefficient for L2 regularization.')
    group.add_argument('--end-weight-decay', type=float,
                       help='End of run weight decay coefficient for L2 regularization.')
    group.add_argument('--weight-decay-incr-style', type=str, default='constant',
                       choices=['constant', 'linear', 'cosine'],
                       help='Weight decay increment function.')
    group.add_argument('--clip-grad', type=float, default=1.0,
                       help='Gradient clipping based on global L2 norm.')
    group.add_argument('--adam-beta1', type=float, default=0.9,
                       help='First coefficient for computing running averages '
                       'of gradient and its square')
    group.add_argument('--adam-beta2', type=float, default=0.999,
                       help='Second coefficient for computing running averages '
                       'of gradient and its square')
    group.add_argument('--adam-eps', type=float, default=1e-08,
                       help='Term added to the denominator to improve '
                       'numerical stability')
    group.add_argument('--sgd-momentum', type=float, default=0.9,
                       help='Momentum factor for sgd')

    return parser


def _add_training_args(parser):
    group = parser.add_argument_group(title='training')

    group.add_argument('--micro-batch-size', type=int, default=None,
                       help='Batch size per model instance (local batch size). '
                       'Global batch size is local batch size times data '
                       'parallel size times number of micro batches.')
    group.add_argument('--batch-size', type=int, default=None,
                       help='Old batch size parameter, do not use. '
                       'Use --micro-batch-size instead')
    group.add_argument('--global-batch-size', type=int, default=None,
                       help='Training batch size. If set, it should be a '
                       'multiple of micro-batch-size times data-parallel-size. '
                       'If this value is None, then '
                       'use micro-batch-size * data-parallel-size as the '
                       'global batch size. This choice will result in 1 for '
                       'number of micro-batches.')
    group.add_argument('--rampup-batch-size', nargs='*', default=None,
                       help='Batch size ramp up with the following values: '
                       '  --rampup-batch-size <start batch size> '
                       '                      <batch size increment> '
                       '                      <ramp-up samples> '
                       'For example: '
                       '   --rampup-batch-size 16 8 300000 \ '
                       '   --global-batch-size 1024 '
                       'will start with global batch size 16 and over '
                       '(1024 - 16) / 8 = 126 intervals will increase '
                       'the batch size linearly to 1024. In each interval '
                       'we will use approximately 300000 / 126 = 2380 samples.')
    group.add_argument('--recompute-activations', action='store_true',
                       help='recompute activation to allow for training '
                       'with larger models, sequences, and batch sizes.')
    group.add_argument('--recompute-granularity', type=str, default=None,
                       choices=['full', 'selective'],
                       help='Checkpoint activations to allow for training '
                       'with larger models, sequences, and batch sizes. '
                       'It is supported at two granularities 1) full: '
                       'whole transformer layer is recomputed, '
                       '2) selective: core attention part of the transformer '
                       'layer is recomputed.')
    group.add_argument('--distribute-saved-activations',
                       action='store_true',
                       help='If set, distribute recomputed activations '
                       'across model parallel group.')
    group.add_argument('--recompute-method', type=str, default=None,
                       choices=['uniform', 'block'],
                       help='1) uniform: uniformly divide the total number of '
                       'Transformer layers and recompute the input activation of '
                       'each divided chunk at specified granularity, '
                       '2) recompute the input activations of only a set number of '
                       'individual Transformer layers per pipeline stage and do the '
                       'rest without any recomputing at specified granularity, '
                       'default) do not apply activations recompute to any layers')
    group.add_argument('--recompute-num-layers', type=int, default=None,
                       help='1) uniform: the number of Transformer layers in each '
                       'uniformly divided recompute unit, '
                       '2) block: the number of individual Transformer layers '
                       'to recompute within each pipeline stage.')
    group.add_argument('--profile', action='store_true',
                       help='Enable nsys profiling. When using this option, nsys '
                       'options should be specified in commandline. An example '
                       'nsys commandline is `nsys profile -s none -t nvtx,cuda '
                       '-o <path/to/output_file> --force-overwrite true '
                       '--capture-range=cudaProfilerApi '
                       '--capture-range-end=stop`.')
    group.add_argument('--profile-step-start', type=int, default=10,
                       help='Global step to start profiling.')
    group.add_argument('--profile-step-end', type=int, default=12,
                       help='Global step to stop profiling.')
    group.add_argument('--profile-ranks', nargs='+', type=int, default=[0],
                       help='Global ranks to profile.')


    # deprecated
    group.add_argument('--checkpoint-activations', action='store_true',
                       help='Checkpoint activation to allow for training '
                       'with larger models, sequences, and batch sizes.')
    group.add_argument('--train-iters', type=int, default=None,
                       help='Total number of iterations to train over all '
                       'training runs. Note that either train-iters or '
                       'train-samples should be provided.')
    group.add_argument('--train-samples', type=int, default=None,
                       help='Total number of samples to train over all '
                       'training runs. Note that either train-iters or '
                       'train-samples should be provided.')
    group.add_argument('--log-interval', type=int, default=100,
                       help='Report loss and timing interval.')
    group.add_argument('--exit-interval', type=int, default=None,
                       help='Exit the program after the iteration is divisible '
                       'by this value.')
    group.add_argument('--exit-duration-in-mins', type=int, default=None,
                       help='Exit the program after this many minutes.')
    group.add_argument('--exit-signal-handler', action='store_true',
                       help='Dynamically save the checkpoint and shutdown the '
                       'training if SIGTERM is received')
    group.add_argument('--tensorboard-dir', type=str, default=None,
                       help='Write TensorBoard logs to this directory.')
    group.add_argument('--no-masked-softmax-fusion',
                       action='store_false',
                       help='Disable fusion of query_key_value scaling, '
                       'masking, and softmax.',
                       dest='masked_softmax_fusion')
    group.add_argument('--no-bias-gelu-fusion', action='store_false',
                       help='Disable bias and gelu fusion.',
                       dest='bias_gelu_fusion')
    group.add_argument('--no-bias-dropout-fusion', action='store_false',
                       help='Disable bias and dropout fusion.',
                       dest='bias_dropout_fusion')
    group.add_argument('--use-flash-attn', action='store_true',
                       help='use FlashAttention implementation of attention. '
                       'https://arxiv.org/abs/2205.14135')
    group.add_argument('--disable-bias-linear', action='store_false',
                       help='Disable bias in the linear layers',
                       dest='add_bias_linear')
    group.add_argument('--optimizer', type=str, default='adam',
                       choices=['adam', 'sgd'],
                       help='Optimizer function')
    group.add_argument('--dataloader-type', type=str, default=None,
                       choices=['single', 'cyclic'],
                       help='Single pass vs multiple pass data loader')
    group.add_argument('--no-async-tensor-model-parallel-allreduce',
                       action='store_false',
                       help='Disable asynchronous execution of '
                       'tensor-model-parallel all-reduce with weight '
                       'gradient computation of a column-linear layer.',
                       dest='async_tensor_model_parallel_allreduce')
    group.add_argument('--no-persist-layer-norm', action='store_true',
                       help='Disable using persistent fused layer norm kernel. '
                       'This kernel supports only a set of hidden sizes. Please '
                       'check persist_ln_hidden_sizes if your hidden '
                       'size is supported.')
    group.add_argument('--sequence-parallel', action='store_true',
                       help='Enable sequence parallel optimization.')
    group.add_argument('--no-gradient-accumulation-fusion',
                       action='store_false',
                       help='Disable fusing gradient accumulation to weight '
                       'gradient computation of linear layers',
                       dest='gradient_accumulation_fusion')
    return parser


def _add_initialization_args(parser):
    group = parser.add_argument_group(title='initialization')

    group.add_argument('--seed', type=int, default=1234,
                       help='Random seed used for python, numpy, '
                       'pytorch, and cuda.')
    group.add_argument('--data-parallel-random-init', action='store_true',
                       help='Enable random initialization of params '
                       'across data parallel ranks')
    group.add_argument('--init-method-std', type=float, default=0.02,
                       help='Standard deviation of the zero mean normal '
                       'distribution used for weight initialization.')
    group.add_argument('--init-method-xavier-uniform', action='store_true',
                       help='Enable Xavier uniform parameter initialization')

    return parser


def _add_learning_rate_args(parser):
    group = parser.add_argument_group(title='learning rate')

    group.add_argument('--lr', type=float, default=None,
                       help='Initial learning rate. Depending on decay style '
                       'and initial warmup, the learning rate at each '
                       'iteration would be different.')
    group.add_argument('--lr-decay-style', type=str, default='linear',
                       choices=['constant', 'linear', 'cosine', 'inverse-square-root'],
                       help='Learning rate decay function.')
    group.add_argument('--lr-decay-iters', type=int, default=None,
                       help='Number of iterations to decay learning rate over. '
                       'If None, defaults to `--train-iters`.')
    group.add_argument('--lr-decay-samples', type=int, default=None,
                       help='Number of samples to decay learning rate over. '
                       'If None, defaults to `--train-samples`.')
    group.add_argument('--lr-warmup-fraction', type=float, default=None,
                       help='fraction of lr-warmup-(iters/samples) to use '
                       'for warmup (as a float)')
    group.add_argument('--lr-warmup-iters', type=int, default=0,
                       help='number of iterations to linearly warmup '
                       'learning rate over.')
    group.add_argument('--lr-warmup-samples', type=int, default=0,
                       help='number of samples to linearly warmup '
                       'learning rate over.')
    group.add_argument('--lr-warmup-init', type=float, default=0.0,
                       help='Initial value for learning rate warmup. The '
                       'scheduler starts warmup from this value.')
    group.add_argument('--warmup', type=int, default=None,
                       help='Old lr warmup argument, do not use. Use one of the '
                       '--lr-warmup-* arguments above')
    group.add_argument('--min-lr', type=float, default=0.0,
                       help='Minimum value for learning rate. The scheduler '
                       'clips values below this threshold.')
    group.add_argument('--override-opt_param-scheduler', action='store_true',
                       help='Reset the values of the scheduler (learning rate,'
                       'warmup iterations, minimum learning rate, maximum '
                       'number of iterations, and decay style from input '
                       'arguments and ignore values from checkpoints. Note '
                       'that all the above values will be reset.')
    group.add_argument('--use-checkpoint-opt_param-scheduler', action='store_true',
                       help='Use checkpoint to set the values of the scheduler '
                       '(learning rate, warmup iterations, minimum learning '
                       'rate, maximum number of iterations, and decay style '
                       'from checkpoint and ignore input arguments.')

    return parser
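
# Illustrative example (not from the original source; values are hypothetical,
# not recommended defaults): a warmed-up cosine schedule could be configured as:
#   --lr 1.5e-4 --min-lr 1.0e-5 --lr-decay-style cosine \
#   --lr-decay-iters 300000 --lr-warmup-fraction 0.01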


def _add_checkpointing_args(parser):
    group = parser.add_argument_group(title='checkpointing')

    group.add_argument('--save', type=str, default=None,
                       help='Output directory to save checkpoints to.')
    group.add_argument('--save-interval', type=int, default=None,
                       help='Number of iterations between checkpoint saves.')
    group.add_argument('--no-save-optim', action='store_true', default=None,
                       help='Do not save current optimizer.')
    group.add_argument('--no-save-rng', action='store_true', default=None,
                       help='Do not save current rng state.')
    group.add_argument('--load', type=str, default=None,
                       help='Directory containing a model checkpoint.')
    group.add_argument('--no-load-optim', action='store_true', default=None,
                       help='Do not load optimizer when loading checkpoint.')
    group.add_argument('--no-load-rng', action='store_true', default=None,
                       help='Do not load rng state when loading checkpoint.')
    group.add_argument('--finetune', action='store_true',
                       help='Load model for finetuning. Do not load optimizer '
                       'or rng state from checkpoint and set iteration to 0. '
                       'Assumed when loading a release checkpoint.')
    group.add_argument('--no-initialization', action='store_false',
                       help='Do not perform initialization when building model; '
                       'can reduce startup time when definitely loading from a '
                       'checkpoint.',
                       dest='perform_initialization')
    group.add_argument('--use-checkpoint-args', action='store_true',
                       help='Override any command line arguments with arguments '
                       'from the checkpoint')
    group.add_argument('--exit-on-missing-checkpoint', action='store_true',
                       help="If '--load' is set, but checkpoint is not found "
                       "(e.g., path typo), then exit instead of random "
                       "initialization.")

    return parser
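
# Illustrative example (not from the original source; paths are placeholders):
# periodic saving plus a warm start from a pretrained checkpoint, e.g. for
# finetuning, could combine:
#   --save <ckpt-dir> --save-interval 1000 --load <pretrained-dir> --finetune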


def _add_mixed_precision_args(parser):
    group = parser.add_argument_group(title='mixed precision')

    group.add_argument('--fp16', action='store_true',
                       help='Run model in fp16 mode.')
    group.add_argument('--bf16', action='store_true',
                       help='Run model in bfloat16 mode.')
    group.add_argument('--loss-scale', type=float, default=None,
                       help='Static loss scaling, positive power of 2 '
                       'values can improve fp16 convergence. If None, dynamic '
                       'loss scaling is used.')
    group.add_argument('--initial-loss-scale', type=float, default=2**32,
                       help='Initial loss-scale for dynamic loss scaling.')
    group.add_argument('--min-loss-scale', type=float, default=1.0,
                       help='Minimum loss scale for dynamic loss scale.')
    group.add_argument('--loss-scale-window', type=float, default=1000,
                       help='Window over which to raise/lower dynamic scale.')
    group.add_argument('--hysteresis', type=int, default=2,
                       help='hysteresis for dynamic loss scaling')
    group.add_argument('--fp32-residual-connection', action='store_true',
                       help='Move residual connections to fp32.')
    group.add_argument('--no-query-key-layer-scaling', action='store_false',
                       help='Do not scale Q * K^T by 1 / layer-number.',
                       dest='apply_query_key_layer_scaling')
    group.add_argument('--attention-softmax-in-fp32', action='store_true',
                       help='Run attention masking and softmax in fp32. '
                       'This flag is ignored unless '
                       '--no-query-key-layer-scaling is specified.')
    group.add_argument('--accumulate-allreduce-grads-in-fp32',
                       action='store_true',
                       help='Gradient accumulation and all-reduce in fp32.')
    group.add_argument('--fp16-lm-cross-entropy', action='store_true',
                       help='Move the cross entropy unreduced loss calculation '
                       'for lm head to fp16.')

    return parser
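
# Illustrative note (not part of the original source): leaving --loss-scale
# unset keeps dynamic loss scaling, governed by --initial-loss-scale,
# --loss-scale-window and --hysteresis; a hypothetical bf16 run that also keeps
# gradient all-reduce in fp32 could pass:
#   --bf16 --accumulate-allreduce-grads-in-fp32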


def _add_distributed_args(parser):
    group = parser.add_argument_group(title='distributed')

    group.add_argument('--tensor-model-parallel-size', type=int, default=1,
                       help='Degree of tensor model parallelism.')
    group.add_argument('--pipeline-model-parallel-size', type=int, default=1,
                       help='Degree of pipeline model parallelism.')
    group.add_argument('--pipeline-model-parallel-split-rank',
                       type=int, default=None,
                       help='Rank where encoder and decoder should be split.')
    group.add_argument('--model-parallel-size', type=int, default=None,
                       help='Old model parallel argument, do not use. Use '
                       '--tensor-model-parallel-size instead.')
    group.add_argument('--num-layers-per-virtual-pipeline-stage', type=int, default=None,
                       help='Number of layers per virtual pipeline stage')
    group.add_argument('--overlap-p2p-communication',
                       action='store_true',
                       help='overlap pipeline parallel communication with forward and backward chunks',
                       dest='overlap_p2p_comm')
    group.add_argument('--distributed-backend', default='nccl',
                       choices=['nccl', 'gloo'],
                       help='Which backend to use for distributed training.')
    group.add_argument('--distributed-timeout-minutes', type=int, default=10,
                       help='Timeout minutes for torch.distributed.')
    group.add_argument('--DDP-impl', default='local',
                       choices=['local', 'torch'],
                       help='which DistributedDataParallel implementation '
                       'to use.')
    group.add_argument('--no-contiguous-buffers-in-local-ddp',
                       action='store_false', help='If set, do not use '
                       'contiguous buffer in local DDP.',
                       dest='use_contiguous_buffers_in_local_ddp')
    group.add_argument('--no-scatter-gather-tensors-in-pipeline', action='store_false',
                       help='If not set, use scatter/gather to optimize communication of tensors in pipeline.',
                       dest='scatter_gather_tensors_in_pipeline')
    group.add_argument('--use-ring-exchange-p2p', action='store_true',
                       default=False, help='If set, use custom-built ring exchange '
                       'for p2p communications. Note that this option will require '
                       'a custom built image that supports ring-exchange p2p.')
    group.add_argument('--local_rank', type=int, default=None,
                       help='local rank passed from distributed launcher.')
    group.add_argument('--lazy-mpu-init', type=bool, required=False,
                       help='If set to True, initialize_megatron() '
                       'skips DDP initialization and returns function to '
                       'complete it instead. Also turns on '
                       '--use-cpu-initialization flag. This is for '
                       'external DDP manager.')
    group.add_argument('--use-cpu-initialization', action='store_true',
                       default=None, help='If set, affine parallel weights '
                       'initialization uses CPU')
    group.add_argument('--empty-unused-memory-level', default=0, type=int,
                       choices=[0, 1, 2],
                       help='Call torch.cuda.empty_cache() each iteration '
                       '(training and eval), to reduce fragmentation. '
                       '0=off, 1=moderate, 2=aggressive.')
    group.add_argument('--standalone-embedding-stage', action='store_true',
                       default=False, help='If set, *input* embedding layer '
                       'is placed on its own pipeline stage, without any '
                       'transformer layers. (For T5, this flag currently only '
                       'affects the encoder embedding.)')
    group.add_argument('--use-distributed-optimizer', action='store_true',
                       help='Use distributed optimizer.')

    group.add_argument('--rank', default=-1, type=int,
                       help='node rank for distributed training')
    group.add_argument('--world_size', type=int, default=-1,
                       help='Number of processes (world size) for distributed training')
    group.add_argument('--dist_url',
                       help='URL of the master node for distributed training.')

    return parser
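
# Illustrative example (not from the original source; sizes are hypothetical):
# on 64 GPUs, a 3D-parallel layout could combine
#   --tensor-model-parallel-size 4 --pipeline-model-parallel-size 4
# leaving a data-parallel size of 64 / (4 * 4) = 4, optionally together with
#   --use-distributed-optimizer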


def _add_validation_args(parser):
    group = parser.add_argument_group(title='validation')

    group.add_argument('--eval-iters', type=int, default=100,
                       help='Number of iterations to run for evaluation on '
                       'the validation/test sets.')
    group.add_argument('--eval-interval', type=int, default=1000,
                       help='Interval between running evaluation on '
                       'validation set.')
    group.add_argument('--skip-train', action='store_true',
                       default=False, help='If set, bypass the training loop, '
                       'optionally do evaluation for validation/test, and exit.')

    return parser


def _add_data_args(parser):
    group = parser.add_argument_group(title='data and dataloader')

    group.add_argument('--data-path', nargs='*', default=None,
                       help='Path to the training dataset. Accepted format: '
                       '1) a single data path, 2) multiple datasets in the '
                       'form: dataset1-weight dataset1-path dataset2-weight '
                       'dataset2-path ... It is used with --split when a '
                       'single dataset is used for all three: train, valid '
                       'and test. It is exclusive to the other '
                       '--*-data-path args')
    group.add_argument('--split', type=str, default='969, 30, 1',
                       help='Comma-separated list of proportions for training,'
                       ' validation, and test split. For example the split '
                       '`90,5,5` will use 90%% of data for training, 5%% for '
                       'validation and 5%% for test.')
    group.add_argument('--train-data-path', nargs='*', default=None,
                       help='Path to the training dataset. Accepted format: '
                       '1) a single data path, 2) multiple datasets in the '
                       'form: dataset1-weight dataset1-path dataset2-weight '
                       'dataset2-path ...')
    group.add_argument('--valid-data-path', nargs='*', default=None,
                       help='Path to the validation dataset. Accepted format: '
                       '1) a single data path, 2) multiple datasets in the '
                       'form: dataset1-weight dataset1-path dataset2-weight '
                       'dataset2-path ...')
    group.add_argument('--test-data-path', nargs='*', default=None,
                       help='Path to the test dataset. Accepted format: '
                       '1) a single data path, 2) multiple datasets in the '
                       'form: dataset1-weight dataset1-path dataset2-weight '
                       'dataset2-path ...')
    group.add_argument('--data-cache-path', default=None,
                       help='Path to a directory to hold cached index files.')

    group.add_argument('--vocab-size', type=int, default=None,
                       help='Size of vocab before EOD or padding.')
    group.add_argument('--vocab-file', type=str, default=None,
                       help='Path to the vocab file.')
    group.add_argument('--merge-file', type=str, default=None,
                       help='Path to the BPE merge file.')
    group.add_argument('--vocab-extra-ids', type=int, default=0,
                       help='Number of additional vocabulary tokens. '
                            'They are used for span masking in the T5 model')
    group.add_argument('--seq-length', type=int, default=None,
                       help='Maximum sequence length to process.')
    group.add_argument('--encoder-seq-length', type=int, default=None,
                       help='Maximum encoder sequence length to process. '
                       'This should be exclusive of --seq-length.')
    group.add_argument('--decoder-seq-length', type=int, default=None,
                       help="Maximum decoder sequence length to process.")
    group.add_argument('--retriever-seq-length', type=int, default=256,
                       help='Maximum sequence length for the biencoder model '
                       'for retriever')
    group.add_argument('--sample-rate', type=float, default=1.0,
                       help='Sample rate for training data. Supposed to be '
                            '0 < sample_rate < 1.')
    group.add_argument('--mask-prob', type=float, default=0.15,
                       help='Probability of replacing a token with mask.')
    group.add_argument('--short-seq-prob', type=float, default=0.1,
                       help='Probability of producing a short sequence.')
    group.add_argument('--mmap-warmup', action='store_true',
                       help='Warm up mmap files.')
    group.add_argument('--num-workers', type=int, default=2,
                       help="Dataloader number of workers.")
    group.add_argument('--tokenizer-type', type=str,
                       default=None,
                       choices=['BertWordPieceLowerCase',
                                'BertWordPieceCase',
                                'GPT2BPETokenizer',
                                'SentencePieceTokenizer',
                                'GPTSentencePieceTokenizer',
                                'NullTokenizer'],
                       help='What type of tokenizer to use.')
    group.add_argument('--tokenizer-model', type=str, default=None,
                       help='Sentencepiece tokenizer model.')
    group.add_argument('--data-impl', type=str, default='infer',
                       choices=['mmap', 'infer'],
                       help='Implementation of indexed datasets.')
    group.add_argument('--reset-position-ids', action='store_true',
                       help='Reset position ids after end-of-document token.')
    group.add_argument('--reset-attention-mask', action='store_true',
                       help='Reset self attention mask after '
                       'end-of-document token.')
    group.add_argument('--eod-mask-loss', action='store_true',
                       help='Mask loss for the end of document tokens.')

    return parser
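
# Illustrative note (not part of the original source): with the default
# --split '969, 30, 1' a single --data-path corpus is split roughly
# 96.9% / 3.0% / 0.1% across train / validation / test; a hypothetical weighted
# blend of two datasets could be passed as:
#   --data-path 0.7 <path-a> 0.3 <path-b>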


def _add_autoresume_args(parser):
    group = parser.add_argument_group(title='autoresume')

    group.add_argument('--adlr-autoresume', action='store_true',
                       help='Enable autoresume on adlr cluster.')
    group.add_argument('--adlr-autoresume-interval', type=int, default=1000,
                       help='Intervals over which to check for autoresume '
                       'termination signal.')

    return parser


def _add_biencoder_args(parser):
    group = parser.add_argument_group(title='biencoder')

    # network size
    group.add_argument('--ict-head-size', type=int, default=None,
                       help='Size of block embeddings to be used in ICT and '
                        'REALM (paper default: 128)')
    group.add_argument('--biencoder-projection-dim', type=int, default=0,
                       help='Size of projection head used in biencoder (paper'
                        ' default: 128)')
    group.add_argument('--biencoder-shared-query-context-model', action='store_true',
                        help='Whether to share the parameters of the query '
                        'and context models or not')

    # checkpointing
    group.add_argument('--ict-load', type=str, default=None,
                       help='Directory containing an ICTBertModel checkpoint')
    group.add_argument('--bert-load', type=str, default=None,
                       help='Directory containing a BertModel checkpoint '
                       '(needed to start ICT and REALM)')

    # data
    group.add_argument('--titles-data-path', type=str, default=None,
                       help='Path to titles dataset used for ICT')
    group.add_argument('--query-in-block-prob', type=float, default=0.1,
                       help='Probability of keeping query in block for '
                       'ICT dataset')
    group.add_argument('--use-one-sent-docs', action='store_true',
                       help='Whether to use one-sentence documents in ICT')
    group.add_argument('--evidence-data-path', type=str, default=None,
                       help='Path to Wikipedia Evidence from DPR paper')

    # training
    group.add_argument('--retriever-report-topk-accuracies', nargs='+', type=int,
                        default=[], help="Which top-k accuracies to report "
                        "(e.g. '1 5 20')")
    group.add_argument('--retriever-score-scaling', action='store_true',
                       help='Whether to scale retriever scores by inverse '
                        'square root of hidden size')

    # faiss index
    group.add_argument('--block-data-path', type=str, default=None,
                       help='Where to save/load BlockData to/from')
    group.add_argument('--embedding-path', type=str, default=None,
                       help='Where to save/load Open-Retrieval Embedding'
                        ' data to/from')

    # indexer
    group.add_argument('--indexer-batch-size', type=int, default=128,
                       help='Batch size to use when running indexing jobs')
    group.add_argument('--indexer-log-interval', type=int, default=1000,
                       help='After how many batches the indexer should '
                       'report progress')
    return parser
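
# Illustrative example (not from the original source; paths are placeholders):
# an ICT/retriever run would typically combine the checkpoint and data flags
# above, e.g.:
#   --bert-load <bert-ckpt-dir> --titles-data-path <titles-path> \
#   --retriever-report-topk-accuracies 1 5 20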


def _add_vision_args(parser):
    group = parser.add_argument_group(title="vision")

    # general vision arguments
    group.add_argument('--num-classes', type=int, default=1000,
                       help='Number of classes in vision classification task')
    group.add_argument('--img-h', type=int, default=224,
                       help='Image height for vision classification task')
    group.add_argument('--img-w', type=int, default=224,
                       help='Image width for vision classification task')
    group.add_argument('--num-channels', type=int, default=3,
                       help='Number of channels in input image data')
    group.add_argument('--patch-dim', type=int, default=16,
                       help='patch dimension')
    group.add_argument('--classes-fraction', type=float, default=1.0,
                       help='training with fraction of classes.')
    group.add_argument('--data-per-class-fraction', type=float, default=1.0,
                       help='training with fraction of data per class.')
    group.add_argument('--no-data-sharding', action='store_false',
                       help='Disable data sharding.',
                       dest='data_sharding')
    group.add_argument('--head-lr-mult', type=float, default=1.0,
                       help='learning rate multiplier for head during finetuning')

    # pretraining type and backbone selection
    group.add_argument('--vision-pretraining', action='store_true',
                       help='flag to indicate vision pretraining')
    group.add_argument('--vision-pretraining-type', type=str, default='classify',
                       choices=['classify', 'inpaint', 'dino'],
                       help='pretraining objectives')
    group.add_argument('--vision-backbone-type', type=str, default='vit',
                       choices=['vit', 'mit', 'swin'],
                       help='backbone type')
    group.add_argument('--swin-backbone-type', type=str, default='tiny',
                       choices=['tiny', 'base', 'h3'],
                       help='swin backbone type')

    # inpainting arguments
    group.add_argument('--mask-type', type=str, default='random',
                       choices=['random', 'row'],
                       help='mask types')
    group.add_argument('--mask-factor', type=float, default=1.0,
                       help='mask size scaling parameter')

    # dino arguments
    group.add_argument('--iter-per-epoch', type=int, default=1250,
                       help='iterations per epoch')
    group.add_argument('--dino-local-img-size', type=int, default=96,
                       help='Image size of local crops used in DINO pretraining')
    group.add_argument('--dino-local-crops-number', type=int, default=10,
                       help='Number of local crops')
    group.add_argument('--dino-head-hidden-size', type=int, default=2048,
                       help='Hidden dimension size in dino head')
    group.add_argument('--dino-bottleneck-size', type=int, default=256,
                       help='Bottle neck dimension in dino head ')
    group.add_argument('--dino-freeze-last-layer', type=float, default=1,
                       help='Freezing last layer weights')
    group.add_argument('--dino-norm-last-layer', action='store_true',
                       help='Disable Norm in last layer.')
    group.add_argument('--dino-warmup-teacher-temp', type=float, default=0.04,
                       help='warump teacher temperature')
    group.add_argument('--dino-teacher-temp', type=float, default=0.07,
                       help='teacher temperature')
    group.add_argument('--dino-warmup-teacher-temp-epochs', type=int, default=30,
                       help='warmup teacher temperaure epochs')

    return parser
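
# Illustrative example (not from the original source; values are hypothetical):
# a DINO-style vision pretraining run could select its objective and backbone
# with:
#   --vision-pretraining --vision-pretraining-type dino \
#   --vision-backbone-type vit --img-h 224 --img-w 224 --patch-dim 16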