Commit 386847ee authored by Myle Ott's avatar Myle Ott
Browse files

Generalize eval_str_list

parent 09379ad8
...@@ -10,7 +10,7 @@ import torch ...@@ -10,7 +10,7 @@ import torch
import torch.nn as nn import torch.nn as nn
import torch.nn.functional as F import torch.nn.functional as F
from fairseq import utils from fairseq import options, utils
from fairseq.data.consts import LEFT_PAD_SOURCE, LEFT_PAD_TARGET from fairseq.data.consts import LEFT_PAD_SOURCE, LEFT_PAD_TARGET
from fairseq.modules import BeamableMM, GradMultiply, LearnedPositionalEmbedding, LinearizedConvolution, AdaptiveSoftmax from fairseq.modules import BeamableMM, GradMultiply, LearnedPositionalEmbedding, LinearizedConvolution, AdaptiveSoftmax
...@@ -139,8 +139,7 @@ class FConvLanguageModel(FairseqLanguageModel): ...@@ -139,8 +139,7 @@ class FConvLanguageModel(FairseqLanguageModel):
max_positions=args.max_target_positions, max_positions=args.max_target_positions,
share_embed=False, share_embed=False,
positional_embeddings=False, positional_embeddings=False,
adaptive_softmax_cutoff=list( adaptive_softmax_cutoff=options.eval_str_list(args.adaptive_softmax_cutoff, type=int),
map(int, args.adaptive_softmax_cutoff.split(','))) if args.adaptive_softmax_cutoff else None,
normalization_constant=args.normalization_constant, normalization_constant=args.normalization_constant,
) )
return FConvLanguageModel(decoder) return FConvLanguageModel(decoder)
......
...@@ -34,13 +34,15 @@ def get_generation_parser(interactive=False): ...@@ -34,13 +34,15 @@ def get_generation_parser(interactive=False):
return parser return parser
def _eval_float_list(x): def eval_str_list(x, type=float):
if x is None:
return None
if isinstance(x, str): if isinstance(x, str):
x = eval(x) x = eval(x)
try: try:
return list(x) return list(map(type, x))
except: except:
return [float(x)] return [type(x)]
def get_eval_lm_parser(): def get_eval_lm_parser():
...@@ -75,8 +77,8 @@ def parse_args_and_arch(parser, input_args=None): ...@@ -75,8 +77,8 @@ def parse_args_and_arch(parser, input_args=None):
args = parser.parse_args(input_args) args = parser.parse_args(input_args)
# Post-process args. # Post-process args.
args.lr = _eval_float_list(args.lr) args.lr = eval_str_list(args.lr, type=float)
args.update_freq = _eval_float_list(args.update_freq) args.update_freq = eval_str_list(args.update_freq, type=int)
if args.max_sentences_valid is None: if args.max_sentences_valid is None:
args.max_sentences_valid = args.max_sentences args.max_sentences_valid = args.max_sentences
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment