Commit 76b5ecab authored by Myle Ott's avatar Myle Ott
Browse files

Migrate all binaries to use options.parse_args_and_arch

parent cf1c64a5
...@@ -80,5 +80,5 @@ def main(args): ...@@ -80,5 +80,5 @@ def main(args):
if __name__ == '__main__': if __name__ == '__main__':
parser = options.get_eval_lm_parser() parser = options.get_eval_lm_parser()
args = parser.parse_args() args = options.parse_args_and_arch(parser)
main(args) main(args)
...@@ -69,6 +69,7 @@ def parse_args_and_arch(parser, input_args=None): ...@@ -69,6 +69,7 @@ def parse_args_and_arch(parser, input_args=None):
args, _ = parser.parse_known_args(input_args) args, _ = parser.parse_known_args(input_args)
# Add model-specific args to parser. # Add model-specific args to parser.
if hasattr(args, 'arch'):
model_specific_group = parser.add_argument_group( model_specific_group = parser.add_argument_group(
'Model-specific configuration', 'Model-specific configuration',
# Only include attributes which are explicitly given as command-line # Only include attributes which are explicitly given as command-line
...@@ -78,20 +79,26 @@ def parse_args_and_arch(parser, input_args=None): ...@@ -78,20 +79,26 @@ def parse_args_and_arch(parser, input_args=None):
ARCH_MODEL_REGISTRY[args.arch].add_args(model_specific_group) ARCH_MODEL_REGISTRY[args.arch].add_args(model_specific_group)
# Add *-specific args to parser. # Add *-specific args to parser.
if hasattr(args, 'criterion'):
CRITERION_REGISTRY[args.criterion].add_args(parser) CRITERION_REGISTRY[args.criterion].add_args(parser)
if hasattr(args, 'optimizer'):
OPTIMIZER_REGISTRY[args.optimizer].add_args(parser) OPTIMIZER_REGISTRY[args.optimizer].add_args(parser)
if hasattr(args, 'lr_scheduler'):
LR_SCHEDULER_REGISTRY[args.lr_scheduler].add_args(parser) LR_SCHEDULER_REGISTRY[args.lr_scheduler].add_args(parser)
# Parse a second time. # Parse a second time.
args = parser.parse_args(input_args) args = parser.parse_args(input_args)
# Post-process args. # Post-process args.
if hasattr(args, 'lr'):
args.lr = eval_str_list(args.lr, type=float) args.lr = eval_str_list(args.lr, type=float)
if hasattr(args, 'update_freq'):
args.update_freq = eval_str_list(args.update_freq, type=int) args.update_freq = eval_str_list(args.update_freq, type=int)
if args.max_sentences_valid is None: if hasattr(args, 'max_sentences_valid') and args.max_sentences_valid is None:
args.max_sentences_valid = args.max_sentences args.max_sentences_valid = args.max_sentences
# Apply architecture configuration. # Apply architecture configuration.
if hasattr(args, 'arch'):
ARCH_CONFIG_REGISTRY[args.arch](args) ARCH_CONFIG_REGISTRY[args.arch](args)
return args return args
......
...@@ -153,5 +153,5 @@ def main(args): ...@@ -153,5 +153,5 @@ def main(args):
if __name__ == '__main__': if __name__ == '__main__':
parser = options.get_generation_parser() parser = options.get_generation_parser()
args = parser.parse_args() args = options.parse_args_and_arch(parser)
main(args) main(args)
...@@ -149,5 +149,5 @@ def main(args): ...@@ -149,5 +149,5 @@ def main(args):
if __name__ == '__main__': if __name__ == '__main__':
parser = options.get_generation_parser(interactive=True) parser = options.get_generation_parser(interactive=True)
args = parser.parse_args() args = options.parse_args_and_arch(parser)
main(args) main(args)
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment