Commit 8943fc78 authored by Myle Ott's avatar Myle Ott
Browse files

Fix language inference in generate.py

parent 84b82dc6
......@@ -18,30 +18,32 @@ from fairseq.indexed_dataset import IndexedDataset, IndexedInMemoryDataset
def load_with_check(path, load_splits, src=None, dst=None):
"""Loads the train, valid, and test sets from the specified folder
and check that training files exist."""
"""Loads specified data splits (e.g., test, train or valid) from the
specified folder and check that files exist."""
def find_language_pair(files):
for split in load_splits:
for filename in files:
parts = filename.split('.')
if parts[0] == 'train' and parts[-1] == 'idx':
if parts[0] == split and parts[-1] == 'idx':
return parts[1].split('-')
def train_file_exists(src, dst):
filename = 'train.{0}-{1}.{0}.idx'.format(src, dst)
def split_exists(split, src, dst):
filename = '{0}.{1}-{2}.{1}.idx'.format(split, src, dst)
return os.path.exists(os.path.join(path, filename))
if src is None and dst is None:
# find language pair automatically
src, dst = find_language_pair(os.listdir(path))
elif train_file_exists(src, dst):
# check for src-dst langcode
pass
elif train_file_exists(dst, src):
# check for dst-src langcode
if not split_exists(load_splits[0], src, dst):
# try reversing src and dst
src, dst = dst, src
else:
raise ValueError('training file not found for {}-{}'.format(src, dst))
for split in load_splits:
if not split_exists(load_splits[0], src, dst):
raise ValueError('Data split not found: {}-{} ({})'.format(
src, dst, split))
dataset = load(path, load_splits, src, dst)
return dataset
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment