Commit 66d9fcf5 authored by Myle Ott's avatar Myle Ott
Browse files

Fix tests

parent f9362e87
......@@ -16,10 +16,11 @@ import sys
from torch.autograd import Variable
from torch.serialization import default_restore_location
from fairseq import criterions, data, models, progress_bar, tokenizer
from fairseq import criterions, progress_bar, tokenizer
def parse_args_and_arch(parser):
from fairseq import models
args = parser.parse_args()
args.model = models.arch_model_map[args.arch]
args = getattr(models, args.model).parse_arch(args)
......@@ -27,6 +28,7 @@ def parse_args_and_arch(parser):
def build_model(args, src_dict, dst_dict):
from fairseq import models
assert hasattr(models, args.model), 'Missing model type'
return getattr(models, args.model).build_model(args, src_dict, dst_dict)
......@@ -144,6 +146,8 @@ def load_ensemble_for_inference(filenames, src_dict=None, dst_dict=None, data_di
The source and target dictionaries can be given explicitly, or loaded from
the `data_dir` directory.
"""
from fairseq import data
# load model architectures and weights
states = []
for filename in filenames:
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment