Commit 8df95dcc authored by Myle Ott's avatar Myle Ott
Browse files

Upgrade args with max_source_positions and max_target_positions

parent 5ef59abd
......@@ -127,6 +127,7 @@ def load_ensemble_for_inference(filenames, src_dict, dst_dict):
torch.load(filename, map_location=lambda s, l: default_restore_location(s, 'cpu'))
)
args = states[0]['args']
args = _upgrade_args(args)
# build ensemble
ensemble = []
......@@ -137,6 +138,13 @@ def load_ensemble_for_inference(filenames, src_dict, dst_dict):
return ensemble
def _upgrade_args(args):
if not hasattr(args, 'max_source_positions'):
args.max_source_positions = args.max_positions
args.max_target_positions = args.max_positions
return args
def prepare_sample(sample, volatile=False, cuda_device=None):
"""Wrap input tensors in Variable class."""
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment