Commit a5e49364 authored by Myle Ott's avatar Myle Ott
Browse files

Support integer learning rates

parent 9f1b37dd
...@@ -34,6 +34,15 @@ def get_generation_parser(interactive=False): ...@@ -34,6 +34,15 @@ def get_generation_parser(interactive=False):
return parser return parser
def _eval_float_list(x):
if isinstance(x, str):
x = eval(x)
try:
return list(x)
except:
return [float(x)]
def parse_args_and_arch(parser, input_args=None): def parse_args_and_arch(parser, input_args=None):
# The parser doesn't know about model/criterion/optimizer-specific args, so # The parser doesn't know about model/criterion/optimizer-specific args, so
# we parse twice. First we parse the model/criterion/optimizer, then we # we parse twice. First we parse the model/criterion/optimizer, then we
...@@ -59,9 +68,8 @@ def parse_args_and_arch(parser, input_args=None): ...@@ -59,9 +68,8 @@ def parse_args_and_arch(parser, input_args=None):
args = parser.parse_args(input_args) args = parser.parse_args(input_args)
# Post-process args. # Post-process args.
args.lr = eval(args.lr) args.lr = _eval_float_list(args.lr)
args.lr = [args.lr] if isinstance(args.lr, float) else list(args.lr) args.update_freq = _eval_float_list(args.update_freq)
args.update_freq = list(map(float, args.update_freq.split(',')))
if args.max_sentences_valid is None: if args.max_sentences_valid is None:
args.max_sentences_valid = args.max_sentences args.max_sentences_valid = args.max_sentences
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment