Commit 386847ee authored by Myle Ott's avatar Myle Ott
Browse files

Generalize eval_str_list

parent 09379ad8
......@@ -10,7 +10,7 @@ import torch
import torch.nn as nn
import torch.nn.functional as F
from fairseq import utils
from fairseq import options, utils
from fairseq.data.consts import LEFT_PAD_SOURCE, LEFT_PAD_TARGET
from fairseq.modules import BeamableMM, GradMultiply, LearnedPositionalEmbedding, LinearizedConvolution, AdaptiveSoftmax
......@@ -139,8 +139,7 @@ class FConvLanguageModel(FairseqLanguageModel):
max_positions=args.max_target_positions,
share_embed=False,
positional_embeddings=False,
adaptive_softmax_cutoff=list(
map(int, args.adaptive_softmax_cutoff.split(','))) if args.adaptive_softmax_cutoff else None,
adaptive_softmax_cutoff=options.eval_str_list(args.adaptive_softmax_cutoff, type=int),
normalization_constant=args.normalization_constant,
)
return FConvLanguageModel(decoder)
......
......@@ -34,13 +34,15 @@ def get_generation_parser(interactive=False):
return parser
def _eval_float_list(x):
def eval_str_list(x, type=float):
if x is None:
return None
if isinstance(x, str):
x = eval(x)
try:
return list(x)
return list(map(type, x))
except:
return [float(x)]
return [type(x)]
def get_eval_lm_parser():
......@@ -75,8 +77,8 @@ def parse_args_and_arch(parser, input_args=None):
args = parser.parse_args(input_args)
# Post-process args.
args.lr = _eval_float_list(args.lr)
args.update_freq = _eval_float_list(args.update_freq)
args.lr = eval_str_list(args.lr, type=float)
args.update_freq = eval_str_list(args.update_freq, type=int)
if args.max_sentences_valid is None:
args.max_sentences_valid = args.max_sentences
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment