Commit e21901e8 authored by Myle Ott's avatar Myle Ott
Browse files

Fix interactive.py

parent 8f9dd964
......@@ -14,7 +14,7 @@ import traceback
from torch.autograd import Variable
from torch.serialization import default_restore_location
from fairseq import criterions, models, tokenizer
from fairseq import criterions, data, models, tokenizer
def parse_args_and_arch(parser):
......@@ -117,7 +117,12 @@ def _upgrade_state_dict(state):
return state
def load_ensemble_for_inference(filenames, src_dict, dst_dict):
def load_ensemble_for_inference(filenames, src_dict=None, dst_dict=None, data_dir=None):
"""Load an ensemble of models for inference.
The source and target dictionaries can be given explicitly, or loaded from
the `data_dir` directory.
"""
# load model architectures and weights
states = []
for filename in filenames:
......@@ -129,13 +134,17 @@ def load_ensemble_for_inference(filenames, src_dict, dst_dict):
args = states[0]['args']
args = _upgrade_args(args)
if src_dict is None or dst_dict is None:
assert data_dir is not None
src_dict, dst_dict = data.load_dictionaries(data_dir, args.source_lang, args.target_lang)
# build ensemble
ensemble = []
for state in states:
model = build_model(args, src_dict, dst_dict)
model.load_state_dict(state['model'])
ensemble.append(model)
return ensemble
return ensemble, args
def _upgrade_args(args):
......
......@@ -41,7 +41,7 @@ def main():
# Load ensemble
print('| loading model(s) from {}'.format(', '.join(args.path)))
models = utils.load_ensemble_for_inference(args.path, dataset.src_dict, dataset.dst_dict)
models, _ = utils.load_ensemble_for_inference(args.path, dataset.src_dict, dataset.dst_dict)
print('| [{}] dictionary: {} types'.format(dataset.src, len(dataset.src_dict)))
print('| [{}] dictionary: {} types'.format(dataset.dst, len(dataset.dst_dict)))
......
......@@ -26,17 +26,13 @@ def main():
use_cuda = torch.cuda.is_available() and not args.cpu
# Load dictionaries
if args.source_lang is None or args.target_lang is None:
args.source_lang, args.target_lang, _ = data.infer_language_pair(args.data, ['test'])
src_dict, dst_dict = data.load_dictionaries(args.data, args.source_lang, args.target_lang)
# Load ensemble
print('| loading model(s) from {}'.format(', '.join(args.path)))
models = utils.load_ensemble_for_inference(args.path, src_dict, dst_dict)
models, model_args = utils.load_ensemble_for_inference(args.path, data_dir=args.data)
src_dict, dst_dict = models[0].src_dict, models[0].dst_dict
print('| [{}] dictionary: {} types'.format(args.source_lang, len(src_dict)))
print('| [{}] dictionary: {} types'.format(args.target_lang, len(dst_dict)))
print('| [{}] dictionary: {} types'.format(model_args.source_lang, len(src_dict)))
print('| [{}] dictionary: {} types'.format(model_args.target_lang, len(dst_dict)))
# Optimize ensemble for generation
for model in models:
......
Markdown is supported
Attach a file by drag & drop or click to upload.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment.