Commit 368ad0d3 authored by Mohammad

made size arguments optional so they can be set from the input function

parent 7d75b3b5
@@ -35,8 +35,6 @@ def parse_args(extra_args_provider=None, defaults={}):
     parser = _add_validation_args(parser)
     parser = _add_data_args(parser)
     parser = _add_autoresume_args(parser)
-    # TODO: Refactor
-    parser = _add_gpt2_args(parser)

     # Custom arguments.
     if extra_args_provider is not None:
@@ -54,6 +52,12 @@ def parse_args(extra_args_provider=None, defaults={}):
             'defaults can only be overwritten for args with None values.'
         setattr(args, key, defaults[key])

+    # Check required arguments.
+    _check_arg_is_not_none(args, 'num_layers')
+    _check_arg_is_not_none(args, 'hidden_size')
+    _check_arg_is_not_none(args, 'num_attention_heads')
+    _check_arg_is_not_none(args, 'max_position_embeddings')
+
     # Distributed args.
     args.rank = int(os.getenv('RANK', '0'))
     args.world_size = int(os.getenv("WORLD_SIZE", '1'))
@@ -93,16 +97,20 @@ def _print_args(args):
     print('---------------- end of arguments ----------------', flush=True)


+def _check_arg_is_not_none(args, arg):
+    assert getattr(args, arg) is not None, '{} argument is None'.format(arg)
+
+
 def _add_network_size_args(parser):
     group = parser.add_argument_group(title='network size')

-    group.add_argument('--num-layers', type=int, required=True,
+    group.add_argument('--num-layers', type=int, default=None,
                        help='Number of transformer layers.')
-    group.add_argument('--hidden-size', type=int, required=True,
+    group.add_argument('--hidden-size', type=int, default=None,
                        help='Transformer hidden size.')
-    group.add_argument('--num-attention-heads', type=int, required=True,
+    group.add_argument('--num-attention-heads', type=int, default=None,
                        help='Number of transformer attention heads.')
-    group.add_argument('--max-position-embeddings', type=int, required=True,
+    group.add_argument('--max-position-embeddings', type=int, default=None,
                        help='Maximum number of position embeddings to use. '
                             'This is the size of position embedding.')
     group.add_argument('--make-vocab-size-divisible-by', type=int, default=128,
@@ -342,28 +350,3 @@ def _add_autoresume_args(parser):
                        'termination signal')

     return parser
-
-
-########################################################################
-
-
-def _add_gpt2_args(parser):
-    group = parser.add_argument_group(title='gpt2')
-
-    group.add_argument('--input-data-sizes-file', type=str, default='sizes.txt',
-                       help='The filename containing all the shards '
-                            'sizes for numpy data loader')
-
-    return parser
-
-
-def add_data_args_(parser):
-    """Train/valid/test data arguments."""
-
-    group = parser.add_argument_group('data', 'data configurations')
-
-    group.add_argument('--data-loader', type=str, default=None,
-                       choices=['raw', 'lazy', 'tfrecords', 'numpy', 'binary'],
-                       help='Which data loader to use. Default varies by model.')
-
-    return parser