Commit e21901e8 authored by Myle Ott
Browse files

Fix interactive.py

parent 8f9dd964
@@ -14,7 +14,7 @@ import traceback
 from torch.autograd import Variable
 from torch.serialization import default_restore_location
-from fairseq import criterions, models, tokenizer
+from fairseq import criterions, data, models, tokenizer

 def parse_args_and_arch(parser):
@@ -117,7 +117,12 @@ def _upgrade_state_dict(state):
     return state

-def load_ensemble_for_inference(filenames, src_dict, dst_dict):
+def load_ensemble_for_inference(filenames, src_dict=None, dst_dict=None, data_dir=None):
+    """Load an ensemble of models for inference.
+
+    The source and target dictionaries can be given explicitly, or loaded from
+    the `data_dir` directory.
+    """
     # load model architectures and weights
     states = []
     for filename in filenames:
@@ -129,13 +134,17 @@ def load_ensemble_for_inference(filenames, src_dict, dst_dict):
     args = states[0]['args']
     args = _upgrade_args(args)

+    if src_dict is None or dst_dict is None:
+        assert data_dir is not None
+        src_dict, dst_dict = data.load_dictionaries(data_dir, args.source_lang, args.target_lang)
+
     # build ensemble
     ensemble = []
     for state in states:
         model = build_model(args, src_dict, dst_dict)
         model.load_state_dict(state['model'])
         ensemble.append(model)
-    return ensemble
+    return ensemble, args

 def _upgrade_args(args):
...
@@ -41,7 +41,7 @@ def main():
     # Load ensemble
     print('| loading model(s) from {}'.format(', '.join(args.path)))
-    models = utils.load_ensemble_for_inference(args.path, dataset.src_dict, dataset.dst_dict)
+    models, _ = utils.load_ensemble_for_inference(args.path, dataset.src_dict, dataset.dst_dict)
     print('| [{}] dictionary: {} types'.format(dataset.src, len(dataset.src_dict)))
     print('| [{}] dictionary: {} types'.format(dataset.dst, len(dataset.dst_dict)))
...
@@ -26,17 +26,13 @@ def main():
     use_cuda = torch.cuda.is_available() and not args.cpu

-    # Load dictionaries
-    if args.source_lang is None or args.target_lang is None:
-        args.source_lang, args.target_lang, _ = data.infer_language_pair(args.data, ['test'])
-    src_dict, dst_dict = data.load_dictionaries(args.data, args.source_lang, args.target_lang)
-
     # Load ensemble
     print('| loading model(s) from {}'.format(', '.join(args.path)))
-    models = utils.load_ensemble_for_inference(args.path, src_dict, dst_dict)
-    print('| [{}] dictionary: {} types'.format(args.source_lang, len(src_dict)))
-    print('| [{}] dictionary: {} types'.format(args.target_lang, len(dst_dict)))
+    models, model_args = utils.load_ensemble_for_inference(args.path, data_dir=args.data)
+    src_dict, dst_dict = models[0].src_dict, models[0].dst_dict
+    print('| [{}] dictionary: {} types'.format(model_args.source_lang, len(src_dict)))
+    print('| [{}] dictionary: {} types'.format(model_args.target_lang, len(dst_dict)))

     # Optimize ensemble for generation
     for model in models:
...
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment