"src/git@developer.sourcefind.cn:renzhc/diffusers_dcu.git" did not exist on "841504bb1a32321641ec19db3cc65376af9e2bd7"
Commit 76b5ecab authored by Myle Ott's avatar Myle Ott
Browse files

Migrate all binaries to use options.parse_args_and_arch

parent cf1c64a5
...@@ -80,5 +80,5 @@ def main(args): ...@@ -80,5 +80,5 @@ def main(args):
if __name__ == '__main__': if __name__ == '__main__':
parser = options.get_eval_lm_parser() parser = options.get_eval_lm_parser()
args = parser.parse_args() args = options.parse_args_and_arch(parser)
main(args) main(args)
...@@ -69,30 +69,37 @@ def parse_args_and_arch(parser, input_args=None): ...@@ -69,30 +69,37 @@ def parse_args_and_arch(parser, input_args=None):
args, _ = parser.parse_known_args(input_args) args, _ = parser.parse_known_args(input_args)
# Add model-specific args to parser. # Add model-specific args to parser.
model_specific_group = parser.add_argument_group( if hasattr(args, 'arch'):
'Model-specific configuration', model_specific_group = parser.add_argument_group(
# Only include attributes which are explicitly given as command-line 'Model-specific configuration',
# arguments or which have default values. # Only include attributes which are explicitly given as command-line
argument_default=argparse.SUPPRESS, # arguments or which have default values.
) argument_default=argparse.SUPPRESS,
ARCH_MODEL_REGISTRY[args.arch].add_args(model_specific_group) )
ARCH_MODEL_REGISTRY[args.arch].add_args(model_specific_group)
# Add *-specific args to parser. # Add *-specific args to parser.
CRITERION_REGISTRY[args.criterion].add_args(parser) if hasattr(args, 'criterion'):
OPTIMIZER_REGISTRY[args.optimizer].add_args(parser) CRITERION_REGISTRY[args.criterion].add_args(parser)
LR_SCHEDULER_REGISTRY[args.lr_scheduler].add_args(parser) if hasattr(args, 'optimizer'):
OPTIMIZER_REGISTRY[args.optimizer].add_args(parser)
if hasattr(args, 'lr_scheduler'):
LR_SCHEDULER_REGISTRY[args.lr_scheduler].add_args(parser)
# Parse a second time. # Parse a second time.
args = parser.parse_args(input_args) args = parser.parse_args(input_args)
# Post-process args. # Post-process args.
args.lr = eval_str_list(args.lr, type=float) if hasattr(args, 'lr'):
args.update_freq = eval_str_list(args.update_freq, type=int) args.lr = eval_str_list(args.lr, type=float)
if args.max_sentences_valid is None: if hasattr(args, 'update_freq'):
args.update_freq = eval_str_list(args.update_freq, type=int)
if hasattr(args, 'max_sentences_valid') and args.max_sentences_valid is None:
args.max_sentences_valid = args.max_sentences args.max_sentences_valid = args.max_sentences
# Apply architecture configuration. # Apply architecture configuration.
ARCH_CONFIG_REGISTRY[args.arch](args) if hasattr(args, 'arch'):
ARCH_CONFIG_REGISTRY[args.arch](args)
return args return args
......
...@@ -153,5 +153,5 @@ def main(args): ...@@ -153,5 +153,5 @@ def main(args):
if __name__ == '__main__': if __name__ == '__main__':
parser = options.get_generation_parser() parser = options.get_generation_parser()
args = parser.parse_args() args = options.parse_args_and_arch(parser)
main(args) main(args)
...@@ -149,5 +149,5 @@ def main(args): ...@@ -149,5 +149,5 @@ def main(args):
if __name__ == '__main__': if __name__ == '__main__':
parser = options.get_generation_parser(interactive=True) parser = options.get_generation_parser(interactive=True)
args = parser.parse_args() args = options.parse_args_and_arch(parser)
main(args) main(args)
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment