Commit 76b5ecab authored by Myle Ott's avatar Myle Ott
Browse files

Migrate all binaries to use options.parse_args_and_arch

parent cf1c64a5
......@@ -80,5 +80,5 @@ def main(args):
if __name__ == '__main__':
parser = options.get_eval_lm_parser()
args = parser.parse_args()
args = options.parse_args_and_arch(parser)
main(args)
......@@ -69,30 +69,37 @@ def parse_args_and_arch(parser, input_args=None):
args, _ = parser.parse_known_args(input_args)
# Add model-specific args to parser.
model_specific_group = parser.add_argument_group(
'Model-specific configuration',
# Only include attributes which are explicitly given as command-line
# arguments or which have default values.
argument_default=argparse.SUPPRESS,
)
ARCH_MODEL_REGISTRY[args.arch].add_args(model_specific_group)
if hasattr(args, 'arch'):
model_specific_group = parser.add_argument_group(
'Model-specific configuration',
# Only include attributes which are explicitly given as command-line
# arguments or which have default values.
argument_default=argparse.SUPPRESS,
)
ARCH_MODEL_REGISTRY[args.arch].add_args(model_specific_group)
# Add *-specific args to parser.
CRITERION_REGISTRY[args.criterion].add_args(parser)
OPTIMIZER_REGISTRY[args.optimizer].add_args(parser)
LR_SCHEDULER_REGISTRY[args.lr_scheduler].add_args(parser)
if hasattr(args, 'criterion'):
CRITERION_REGISTRY[args.criterion].add_args(parser)
if hasattr(args, 'optimizer'):
OPTIMIZER_REGISTRY[args.optimizer].add_args(parser)
if hasattr(args, 'lr_scheduler'):
LR_SCHEDULER_REGISTRY[args.lr_scheduler].add_args(parser)
# Parse a second time.
args = parser.parse_args(input_args)
# Post-process args.
args.lr = eval_str_list(args.lr, type=float)
args.update_freq = eval_str_list(args.update_freq, type=int)
if args.max_sentences_valid is None:
if hasattr(args, 'lr'):
args.lr = eval_str_list(args.lr, type=float)
if hasattr(args, 'update_freq'):
args.update_freq = eval_str_list(args.update_freq, type=int)
if hasattr(args, 'max_sentences_valid') and args.max_sentences_valid is None:
args.max_sentences_valid = args.max_sentences
# Apply architecture configuration.
ARCH_CONFIG_REGISTRY[args.arch](args)
if hasattr(args, 'arch'):
ARCH_CONFIG_REGISTRY[args.arch](args)
return args
......
......@@ -153,5 +153,5 @@ def main(args):
if __name__ == '__main__':
parser = options.get_generation_parser()
args = parser.parse_args()
args = options.parse_args_and_arch(parser)
main(args)
......@@ -149,5 +149,5 @@ def main(args):
if __name__ == '__main__':
parser = options.get_generation_parser(interactive=True)
args = parser.parse_args()
args = options.parse_args_and_arch(parser)
main(args)
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment