arguments.py 23.1 KB
Newer Older
Raul Puri's avatar
Raul Puri committed
1
# coding=utf-8
Mohammad's avatar
Mohammad committed
2
# Copyright (c) 2020, NVIDIA CORPORATION.  All rights reserved.
Raul Puri's avatar
Raul Puri committed
3
4
5
6
7
8
9
10
11
12
13
14
15
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

Mohammad's avatar
Mohammad committed
16
"""Megatron arguments."""
Raul Puri's avatar
Raul Puri committed
17
18
19
20

import argparse
import os

21
import torch
22
from megatron import fused_kernels
Raul Puri's avatar
Raul Puri committed
23

24
25
def parse_args(extra_args_provider=None, defaults=None,
               ignore_unknown_args=False):
    """Parse all arguments.

    Builds the Megatron argument parser, parses the command line, applies
    caller-supplied defaults, derives a few dependent settings, runs
    consistency checks and finally loads any requested fused kernels.

    Args:
        extra_args_provider: optional callable that receives the
            ``argparse.ArgumentParser`` after all standard groups have been
            added and returns the parser to use, so callers can register
            additional arguments.
        defaults: optional mapping from argument name (``dest``) to value.
            A default is applied only when the argument is still ``None``
            after parsing, i.e. it was not given on the command line;
            otherwise rank 0 prints a warning and the command-line value is
            kept. ``None`` means "no defaults".
        ignore_unknown_args: if True, unrecognized command-line arguments
            are discarded (``parse_known_args``) instead of being an error.

    Returns:
        argparse.Namespace: the parsed arguments, augmented with ``rank``,
        ``world_size``, ``dynamic_loss_scale`` and ``params_dtype``.

    Raises:
        AssertionError: if a required argument is missing or two arguments
            are inconsistent with each other.
    """
    # Use None rather than a shared mutable ``{}`` as the default value.
    if defaults is None:
        defaults = {}

    parser = argparse.ArgumentParser(description='Megatron-LM Arguments',
                                     allow_abbrev=False)

    # Standard arguments.
    parser = _add_network_size_args(parser)
    parser = _add_regularization_args(parser)
    parser = _add_training_args(parser)
    parser = _add_initialization_args(parser)
    parser = _add_learning_rate_args(parser)
    parser = _add_checkpointing_args(parser)
    parser = _add_mixed_precision_args(parser)
    parser = _add_distributed_args(parser)
    parser = _add_validation_args(parser)
    parser = _add_data_args(parser)
    parser = _add_autoresume_args(parser)
    parser = _add_realm_args(parser)

    # Custom arguments.
    if extra_args_provider is not None:
        parser = extra_args_provider(parser)

    # Parse.
    if ignore_unknown_args:
        args, _ = parser.parse_known_args()
    else:
        args = parser.parse_args()

    # Distributed args, read from the environment (a single process with
    # rank 0 when the variables are absent).
    args.rank = int(os.getenv('RANK', '0'))
    args.world_size = int(os.getenv("WORLD_SIZE", '1'))
    # Model parallelism cannot span more processes than exist.
    args.model_parallel_size = min(args.model_parallel_size, args.world_size)
    if args.rank == 0:
        print('using world size: {} and model-parallel size: {} '.format(
            args.world_size, args.model_parallel_size))

    # Set input defaults. This happens before the derived settings below
    # (e.g. ``dynamic_loss_scale`` from ``loss_scale``) so that they are
    # computed from the final argument values.
    for key in defaults:
        # For default to be valid, it should not be provided in the
        # arguments that are passed to the program. We check this by
        # ensuring the arg is set to None.
        if getattr(args, key) is not None:
            if args.rank == 0:
                print('WARNING: overriding default arguments for '
                      '{key}:{v} with {key}:{v2}'.format(
                          key=key, v=defaults[key], v2=getattr(args, key)),
                      flush=True)
        else:
            setattr(args, key, defaults[key])

    # Fp16 loss scaling: without a static --loss-scale, scale dynamically.
    args.dynamic_loss_scale = args.loss_scale is None

    # Parameters dtype.
    args.params_dtype = torch.half if args.fp16 else torch.float
    if args.rank == 0:
        print('using {} for parameters ...'.format(args.params_dtype),
              flush=True)

    # Check required arguments.
    required_args = ['num_layers', 'hidden_size', 'num_attention_heads',
                     'max_position_embeddings']
    for req_arg in required_args:
        _check_arg_is_not_none(args, req_arg)

    # Checks.
    assert args.hidden_size % args.num_attention_heads == 0, \
        'hidden-size must be divisible by num-attention-heads.'
    if args.seq_length is not None:
        assert args.max_position_embeddings >= args.seq_length, \
            'max-position-embeddings must be >= seq-length.'
    if args.lr is not None:
        assert args.min_lr <= args.lr, 'min-lr must not exceed lr.'
    if args.save is not None:
        assert args.save_interval is not None, \
            'save-interval is required when save is set.'

    # Parameters sharing does not work with torch DDP.
    if (args.num_unique_layers is not None) and (args.num_layers is not None):
        assert args.num_unique_layers <= args.num_layers, \
            'num-unique-layers must not exceed num-layers.'
        assert args.num_layers % args.num_unique_layers == 0, \
            'num-layers should be divisible by num-unique-layers.'
        if args.num_unique_layers < args.num_layers:
            assert args.DDP_impl == 'local', \
                'torch-DDP does not work with parameters sharing.'

    # Mixed precision checks.
    if args.fp16_lm_cross_entropy:
        assert args.fp16, 'lm cross entropy in fp16 only support in fp16 mode.'

    # Activation checkpointing.
    if args.distribute_checkpointed_activations:
        assert args.checkpoint_activations, \
            'for distribute-checkpointed-activations to work you '\
            'need to enable checkpoint-activations'

    # load scaled_upper_triang_masked_softmax_fusion kernel
    if args.scaled_upper_triang_masked_softmax_fusion:
        fused_kernels.load_scaled_upper_triang_masked_softmax_fusion_kernel()

    # load scaled_masked_softmax_fusion kernel
    if args.scaled_masked_softmax_fusion:
        fused_kernels.load_scaled_masked_softmax_fusion_kernel()

    _print_args(args)
    return args
Mohammad's avatar
Mohammad committed
131
132


Mohammad's avatar
Mohammad committed
133
134
135
136
137
138
139
140
141
142
143
def _print_args(args):
    """Print arguments."""
    if args.rank == 0:
        print('-------------------- arguments --------------------', flush=True)
        str_list = []
        for arg in vars(args):
            dots = '.' * (32 - len(arg))
            str_list.append('  {} {} {}'.format(arg, dots, getattr(args, arg)))
        for arg in sorted(str_list, key=lambda x: x.lower()):
            print(arg, flush=True)
        print('---------------- end of arguments ----------------', flush=True)
Mohammad's avatar
Mohammad committed
144
145


146
147
148
149
def _check_arg_is_not_none(args, arg):
    assert getattr(args, arg) is not None, '{} argument is None'.format(arg)


Mohammad's avatar
Mohammad committed
150
def _add_network_size_args(parser):
Mohammad's avatar
Mohammad committed
151
    group = parser.add_argument_group(title='network size')
Mohammad's avatar
Mohammad committed
152

153
    group.add_argument('--num-layers', type=int, default=None,
Mohammad's avatar
Mohammad committed
154
                       help='Number of transformer layers.')
Mohammad's avatar
Mohammad committed
155
156
157
158
    group.add_argument('--num-unique-layers', type=int, default=None,
                       help='Number of unique transformer layers. '
                       '`num-layers` should be divisible by this value.')
    group.add_argument('--param-sharing-style', default='grouped',
mohammad's avatar
mohammad committed
159
                       choices=['grouped', 'spaced'],
Mohammad's avatar
Mohammad committed
160
161
162
163
164
                       help='Ordering of the shared parameters. For example, '
                       'for a `num-layers`=4 and `--num-unique-layers`=2, '
                       'we will have the following ordering for two unique '
                       'layers 1 and 2: '
                       '    grouped: [1, 2, 1, 2] and spaced: [1, 1, 2, 2].')
165
    group.add_argument('--hidden-size', type=int, default=None,
Mohammad's avatar
Mohammad committed
166
                       help='Tansformer hidden size.')
167
    group.add_argument('--num-attention-heads', type=int, default=None,
Mohammad's avatar
Mohammad committed
168
                       help='Number of transformer attention heads.')
169
    group.add_argument('--max-position-embeddings', type=int, default=None,
Mohammad's avatar
Mohammad committed
170
171
172
173
174
                       help='Maximum number of position embeddings to use. '
                       'This is the size of position embedding.')
    group.add_argument('--make-vocab-size-divisible-by', type=int, default=128,
                       help='Pad the vocab size to be divisible by this value.'
                       'This is added for computational efficieny reasons.')
Mohammad's avatar
Mohammad committed
175
176
    group.add_argument('--layernorm-epsilon', type=float, default=1e-5,
                       help='Layer norm epsilon.')
Mohammad's avatar
Mohammad committed
177
178
179
180
    group.add_argument('--apply-residual-connection-post-layernorm',
                       action='store_true',
                       help='If set, use original BERT residula connection '
                       'ordering.')
181
182
183
184
    group.add_argument('--openai-gelu', action='store_true',
                       help='Use OpenAIs GeLU implementation. This option'
                       'should not be used unless for backward compatibility'
                       'reasons.')
185
    group.add_argument('--onnx-safe', type=bool, required=False,
186
                       help='Use workarounds for known problems with Torch ONNX exporter')
Mohammad's avatar
Mohammad committed
187

Mohammad's avatar
Mohammad committed
188
189
190
    return parser


Mohammad's avatar
Mohammad committed
191
def _add_regularization_args(parser):
Mohammad's avatar
Mohammad committed
192
193
194
195
196
197
198
199
200
201
    group = parser.add_argument_group(title='regularization')

    group.add_argument('--attention-dropout', type=float, default=0.1,
                       help='Post attention dropout ptobability.')
    group.add_argument('--hidden-dropout', type=float, default=0.1,
                       help='Dropout probability for hidden state transformer.')
    group.add_argument('--weight-decay', type=float, default=0.01,
                       help='Weight decay coefficient for L2 regularization.')
    group.add_argument('--clip-grad', type=float, default=1.0,
                       help='Gradient clipping based on global L2 norm.')
202
203
204
205
206
207
208
    group.add_argument('--adam-beta1', type=float, default=0.9,
                       help='First coefficient for computing running averages of'
                       'gradient and its square')
    group.add_argument('--adam-beta2', type=float, default=0.999,
                       help='Second coefficient for computing running averages of'
                       'gradient and its square')
    group.add_argument('--adam-eps', type=float, default=1e-08,
209
                       help='Term added to the denominator to improve'
210
                       'numerical stability')
Mohammad's avatar
Mohammad committed
211
212
213

    return parser

Mohammad's avatar
Mohammad committed
214
215

def _add_training_args(parser):
Mohammad's avatar
Mohammad committed
216
217
    group = parser.add_argument_group(title='training')

Mohammad's avatar
Mohammad committed
218
    group.add_argument('--batch-size', type=int, default=None,
Mohammad's avatar
Mohammad committed
219
220
221
222
223
224
                       help='Batch size per model instance (local batch size). '
                       'Global batch size is local batch size times data '
                       'parallel size.')
    group.add_argument('--checkpoint-activations', action='store_true',
                       help='Checkpoint activation to allow for training '
                       'with larger models, sequences, and batch sizes.')
225
226
227
228
    group.add_argument('--distribute-checkpointed-activations',
                       action='store_true',
                       help='If set, distribute checkpointed activations '
                       'across model parallel group.')
Mohammad's avatar
Mohammad committed
229
230
    group.add_argument('--checkpoint-num-layers', type=int, default=1,
                       help='chunk size (number of layers) for checkpointing.')
Mohammad's avatar
Mohammad committed
231
    group.add_argument('--train-iters', type=int, default=None,
Mohammad's avatar
Mohammad committed
232
233
234
235
236
237
238
239
240
                       help='Total number of iterations to train over all '
                       'training runs.')
    group.add_argument('--log-interval', type=int, default=100,
                       help='Report loss and timing interval.')
    group.add_argument('--exit-interval', type=int, default=None,
                       help='Exit the program after the iteration is divisible '
                       'by this value.')
    group.add_argument('--tensorboard-dir', type=str, default=None,
                       help='Write TensorBoard logs to this directory.')
241
242
243
    group.add_argument('--scaled-upper-triang-masked-softmax-fusion',
                       action='store_true',
                       help='Enable fusion of query_key_value_scaling '
244
245
246
247
248
                       'time (upper diagonal) masking and softmax.')
    group.add_argument('--scaled-masked-softmax-fusion',
                       action='store_true',
                       help='Enable fusion of query_key_value_scaling '
                       'general masking and softmax.')
249
250
251
252
    group.add_argument('--bias-gelu-fusion', action='store_true',
                        help='Enable bias and gelu fusion.')
    group.add_argument('--bias-dropout-fusion', action='store_true',
                       help='Enable bias and dropout fusion.')
Mohammad's avatar
Mohammad committed
253
254
255
256

    return parser


Mohammad's avatar
Mohammad committed
257
def _add_initialization_args(parser):
Mohammad's avatar
Mohammad committed
258
259
260
261
262
263
264
265
    group = parser.add_argument_group(title='initialization')

    group.add_argument('--seed', type=int, default=1234,
                       help='Random seed used for python, numpy, '
                       'pytorch, and cuda.')
    group.add_argument('--init-method-std', type=float, default=0.02,
                       help='Standard deviation of the zero mean normal '
                       'distribution used for weight initialization.')
Mohammad's avatar
Mohammad committed
266

Mohammad's avatar
Mohammad committed
267
268
269
    return parser


Mohammad's avatar
Mohammad committed
270
def _add_learning_rate_args(parser):
Mohammad's avatar
Mohammad committed
271
272
    group = parser.add_argument_group(title='learning rate')

Mohammad's avatar
Mohammad committed
273
    group.add_argument('--lr', type=float, default=None,
Mohammad's avatar
Mohammad committed
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
                       help='Initial learning rate. Depending on decay style '
                       'and initial warmup, the learing rate at each '
                       'iteration would be different.')
    group.add_argument('--lr-decay-style', type=str, default='linear',
                       choices=['constant', 'linear', 'cosine', 'exponential'],
                       help='Learning rate decay function.')
    group.add_argument('--lr-decay-iters', type=int, default=None,
                       help='number of iterations to decay learning rate over,'
                       ' If None defaults to `--train-iters`')
    group.add_argument('--min-lr', type=float, default=0.0,
                       help='Minumum value for learning rate. The scheduler'
                       'clip values below this threshold.')
    group.add_argument('--warmup', type=float, default=0.01,
                       help='Percentage of total iterations to warmup on '
                       '(.01 = 1 percent of all training iters).')
    group.add_argument('--override-lr-scheduler', action='store_true',
                       help='Reset the values of the scheduler (learning rate,'
                       'warmup iterations, minimum learning rate, maximum '
                       'number of iterations, and decay style from input '
                       'arguments and ignore values from checkpoints. Note'
                       'that all the above values will be reset.')
    group.add_argument('--use-checkpoint-lr-scheduler', action='store_true',
                       help='Use checkpoint to set the values of the scheduler '
                       '(learning rate, warmup iterations, minimum learning '
                       'rate, maximum number of iterations, and decay style '
                       'from checkpoint and ignore input arguments.')

    return parser


Mohammad's avatar
Mohammad committed
304
def _add_checkpointing_args(parser):
Mohammad's avatar
Mohammad committed
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
    group = parser.add_argument_group(title='checkpointing')

    group.add_argument('--save', type=str, default=None,
                       help='Output directory to save checkpoints to.')
    group.add_argument('--save-interval', type=int, default=None,
                       help='Number of iterations between checkpoint saves.')
    group.add_argument('--no-save-optim', action='store_true',
                       help='Do not save current optimizer.')
    group.add_argument('--no-save-rng', action='store_true',
                       help='Do not save current rng state.')
    group.add_argument('--load', type=str, default=None,
                       help='Directory containing a model checkpoint.')
    group.add_argument('--no-load-optim', action='store_true',
                       help='Do not load optimizer when loading checkpoint.')
    group.add_argument('--no-load-rng', action='store_true',
                       help='Do not load rng state when loading checkpoint.')
    group.add_argument('--finetune', action='store_true',
                       help='Load model for finetuning. Do not load optimizer '
                       'or rng state from checkpoint and set iteration to 0. '
                       'Assumed when loading a release checkpoint.')

    return parser


Mohammad's avatar
Mohammad committed
329
def _add_mixed_precision_args(parser):
Mohammad's avatar
Mohammad committed
330
331
332
333
334
335
336
337
338
339
    group = parser.add_argument_group(title='mixed precision')

    group.add_argument('--fp16', action='store_true',
                       help='Run model in fp16 mode.')
    group.add_argument('--apply-query-key-layer-scaling', action='store_true',
                       help='Scale Q * K^T by 1 / layer-number. If this flag '
                       'is set, then it will automatically set '
                       'attention-softmax-in-fp32 to true')
    group.add_argument('--attention-softmax-in-fp32', action='store_true',
                       help='Run attention masking and softmax in fp32.')
Mohammad's avatar
Mohammad committed
340
341
    group.add_argument('--fp32-allreduce', action='store_true',
                       help='All-reduce in fp32')
Mohammad's avatar
Mohammad committed
342
343
344
345
346
347
348
349
350
351
    group.add_argument('--hysteresis', type=int, default=2,
                       help='hysteresis for dynamic loss scaling')
    group.add_argument('--loss-scale', type=float, default=None,
                       help='Static loss scaling, positive power of 2 '
                       'values can improve fp16 convergence. If None, dynamic'
                       'loss scaling is used.')
    group.add_argument('--loss-scale-window', type=float, default=1000,
                       help='Window over which to raise/lower dynamic scale.')
    group.add_argument('--min-scale', type=float, default=1,
                       help='Minimum loss scale for dynamic loss scale.')
352
353
354
355
    group.add_argument('--fp16-lm-cross-entropy', action='store_true',
                       help='Move the cross entropy unreduced loss calculation'
                       'for lm head to fp16.')

Mohammad's avatar
Mohammad committed
356
357
358
359

    return parser


Mohammad's avatar
Mohammad committed
360
def _add_distributed_args(parser):
361
    group = parser.add_argument_group(title='distributed')
Mohammad's avatar
Mohammad committed
362

Mohammad's avatar
Mohammad committed
363
364
    group.add_argument('--model-parallel-size', type=int, default=1,
                       help='Size of the model parallel.')
Mohammad's avatar
Mohammad committed
365
366
367
368
    group.add_argument('--distributed-backend', default='nccl',
                       choices=['nccl', 'gloo'],
                       help='Which backend to use for distributed training.')
    group.add_argument('--DDP-impl', default='local',
Mohammad's avatar
Mohammad committed
369
                       choices=['local', 'torch'],
Mohammad's avatar
Mohammad committed
370
371
372
373
                       help='which DistributedDataParallel implementation '
                       'to use.')
    group.add_argument('--local_rank', type=int, default=None,
                       help='local rank passed from distributed launcher.')
374
375
    group.add_argument('--lazy-mpu-init', type=bool, required=False,
                       help='If set to True, initialize_megatron() skips DDP initialization'
Boris Fomitchev's avatar
Boris Fomitchev committed
376
377
                       ' and returns function to complete it instead.'
                       'Also turns on --use-cpu-initialization flag.'
378
                       'This is for external DDP manager.' )
379
380
    group.add_argument('--use-cpu-initialization', action='store_true',
                       help='If set, affine parallel weights initialization uses CPU' )
Mohammad's avatar
Mohammad committed
381
382
383
    return parser


Mohammad's avatar
Mohammad committed
384
def _add_validation_args(parser):
Mohammad's avatar
Mohammad committed
385
386
387
388
389
390
391
392
393
    group = parser.add_argument_group(title='validation')

    group.add_argument('--eval-iters', type=int, default=100,
                       help='Number of iterations to run for evaluation'
                       'validation/test for.')
    group.add_argument('--eval-interval', type=int, default=1000,
                       help='Interval between running evaluation on '
                       'validation set.')

Mohammad's avatar
Mohammad committed
394
395
396
    return parser


Mohammad's avatar
Mohammad committed
397
def _add_data_args(parser):
Mohammad's avatar
Mohammad committed
398
399
    group = parser.add_argument_group(title='data and dataloader')

Mohammad's avatar
Mohammad committed
400
    group.add_argument('--data-path', type=str, default=None,
Mohammad's avatar
Mohammad committed
401
                       help='Path to combined dataset to split.')
Mohammad's avatar
Mohammad committed
402
    group.add_argument('--split', type=str, default='969, 30, 1',
Mohammad's avatar
Mohammad committed
403
404
                       help='Comma-separated list of proportions for training,'
                       ' validation, and test split. For example the split '
405
406
                       '`90,5,5` will use 90%% of data for training, 5%% for '
                       'validation and 5%% for test.')
Mohammad's avatar
Mohammad committed
407
    group.add_argument('--vocab-file', type=str, default=None,
Mohammad's avatar
Mohammad committed
408
                       help='Path to the vocab file.')
Mohammad's avatar
Mohammad committed
409
410
    group.add_argument('--merge-file', type=str, default=None,
                       help='Path to the BPE merge file.')
Mohammad's avatar
Mohammad committed
411
    group.add_argument('--seq-length', type=int, default=None,
Mohammad's avatar
Mohammad committed
412
413
414
415
416
417
418
419
420
                       help="Maximum sequence length to process.")
    group.add_argument('--mask-prob', type=float, default=0.15,
                       help='Probability of replacing a token with mask.')
    group.add_argument('--short-seq-prob', type=float, default=0.1,
                       help='Probability of producing a short sequence.')
    group.add_argument('--mmap-warmup', action='store_true',
                       help='Warm up mmap files.')
    group.add_argument('--num-workers', type=int, default=2,
                       help="Dataloader number of workers.")
Mohammad's avatar
Mohammad committed
421
422
423
    group.add_argument('--tokenizer-type', type=str,
                       default=None,
                       choices=['BertWordPieceLowerCase',
Raul Puri's avatar
Raul Puri committed
424
                                'BertWordPieceCase',
Mohammad's avatar
Mohammad committed
425
426
                                'GPT2BPETokenizer'],
                       help='What type of tokenizer to use.')
427
428
429
430
431
432
433
434
435
436
    group.add_argument('--data-impl', type=str, default='infer',
                       choices=['lazy', 'cached', 'mmap', 'infer'],
                       help='Implementation of indexed datasets.')
    group.add_argument('--reset-position-ids', action='store_true',
                       help='Reset posistion ids after end-of-document token.')
    group.add_argument('--reset-attention-mask', action='store_true',
                       help='Reset self attention maske after '
                       'end-of-document token.')
    group.add_argument('--eod-mask-loss', action='store_true',
                       help='Mask loss for the end of document tokens.')
Mohammad's avatar
Mohammad committed
437

Mohammad's avatar
Mohammad committed
438
439
    return parser

Raul Puri's avatar
Raul Puri committed
440

Mohammad's avatar
Mohammad committed
441
442
def _add_autoresume_args(parser):
    group = parser.add_argument_group(title='autoresume')
Raul Puri's avatar
Raul Puri committed
443

Mohammad's avatar
Mohammad committed
444
445
446
447
448
    group.add_argument('--adlr-autoresume', action='store_true',
                       help='Enable autoresume on adlr cluster.')
    group.add_argument('--adlr-autoresume-interval', type=int, default=1000,
                       help='Intervals over which check for autoresume'
                       'termination signal')
Raul Puri's avatar
Raul Puri committed
449

Mohammad's avatar
Mohammad committed
450
    return parser
Neel Kant's avatar
Neel Kant committed
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470


def _add_realm_args(parser):
    group = parser.add_argument_group(title='realm')

    # network size
    group.add_argument('--ict-head-size', type=int, default=None,
                       help='Size of block embeddings to be used in ICT and REALM (paper default: 128)')

    # checkpointing
    group.add_argument('--ict-load', type=str, default=None,
                       help='Directory containing an ICTBertModel checkpoint')
    group.add_argument('--bert-load', type=str, default=None,
                       help='Directory containing an BertModel checkpoint (needed to start ICT and REALM)')

    # data
    group.add_argument('--titles-data-path', type=str, default=None,
                       help='Path to titles dataset used for ICT')
    group.add_argument('--query-in-block-prob', type=float, default=0.1,
                       help='Probability of keeping query in block for ICT dataset')
Neel Kant's avatar
Neel Kant committed
471
    group.add_argument('--use-one-sent-docs', action='store_true',
Neel Kant's avatar
Neel Kant committed
472
473
                       help='Whether to use one sentence documents in ICT')

474
475
476
477
    # training
    group.add_argument('--report-topk-accuracies', nargs='+', default=[],
                       help="Which top-k accuracies to report (e.g. '1 5 20')")

Neel Kant's avatar
Neel Kant committed
478
479
480
    # faiss index
    group.add_argument('--faiss-use-gpu', action='store_true',
                       help='Whether create the FaissMIPSIndex on GPU')
Neel Kant's avatar
Neel Kant committed
481
    group.add_argument('--block-data-path', type=str, default=None,
Neel Kant's avatar
Neel Kant committed
482
                       help='Where to save/load BlockData to/from')
Neel Kant's avatar
Neel Kant committed
483
484
485
486
487
488

    # indexer
    group.add_argument('--indexer-batch-size', type=int, default=128,
                       help='How large of batches to use when doing indexing jobs')
    group.add_argument('--indexer-log-interval', type=int, default=1000,
                       help='After how many batches should the indexer report progress')
Neel Kant's avatar
Neel Kant committed
489
490
    return parser