Commit a4fe8c99, authored by Sergey Edunov, committed by Myle Ott
Browse files

Add back secondary set

parent 535ca991
......@@ -45,7 +45,7 @@ class TranslationTask(FairseqTask):
@staticmethod
def add_args(parser):
"""Add task-specific arguments to the parser."""
parser.add_argument('data', help='path to data directory')
parser.add_argument('data', nargs='+', help='path(s) to data directorie(s)')
parser.add_argument('-s', '--source-lang', default=None, metavar='SRC',
help='source language')
parser.add_argument('-t', '--target-lang', default=None, metavar='TARGET',
......@@ -80,13 +80,13 @@ class TranslationTask(FairseqTask):
# find language pair automatically
if args.source_lang is None or args.target_lang is None:
args.source_lang, args.target_lang = data_utils.infer_language_pair(args.data)
args.source_lang, args.target_lang = data_utils.infer_language_pair(args.data[0])
if args.source_lang is None or args.target_lang is None:
raise Exception('Could not infer language pair, please provide it explicitly')
# load dictionaries
src_dict = Dictionary.load(os.path.join(args.data, 'dict.{}.txt'.format(args.source_lang)))
tgt_dict = Dictionary.load(os.path.join(args.data, 'dict.{}.txt'.format(args.target_lang)))
src_dict = Dictionary.load(os.path.join(args.data[0], 'dict.{}.txt'.format(args.source_lang)))
tgt_dict = Dictionary.load(os.path.join(args.data[0], 'dict.{}.txt'.format(args.target_lang)))
assert src_dict.pad() == tgt_dict.pad()
assert src_dict.eos() == tgt_dict.eos()
assert src_dict.unk() == tgt_dict.unk()
......@@ -102,8 +102,8 @@ class TranslationTask(FairseqTask):
split (str): name of the split (e.g., train, valid, test)
"""
def split_exists(split, src, tgt, lang):
filename = os.path.join(self.args.data, '{}.{}-{}.{}'.format(split, src, tgt, lang))
def split_exists(split, src, tgt, lang, data_path):
filename = os.path.join(data_path, '{}.{}-{}.{}'.format(split, src, tgt, lang))
if self.args.raw_text and IndexedRawTextDataset.exists(filename):
return True
elif not self.args.raw_text and IndexedInMemoryDataset.exists(filename):
......@@ -120,29 +120,35 @@ class TranslationTask(FairseqTask):
src_datasets = []
tgt_datasets = []
data_paths = self.args.data
for data_path in data_paths:
for k in itertools.count():
split_k = split + (str(k) if k > 0 else '')
# infer langcode
src, tgt = self.args.source_lang, self.args.target_lang
if split_exists(split_k, src, tgt, src):
prefix = os.path.join(self.args.data, '{}.{}-{}.'.format(split_k, src, tgt))
elif split_exists(split_k, tgt, src, src):
prefix = os.path.join(self.args.data, '{}.{}-{}.'.format(split_k, tgt, src))
if split_exists(split_k, src, tgt, src, data_path):
prefix = os.path.join(data_path, '{}.{}-{}.'.format(split_k, src, tgt))
elif split_exists(split_k, tgt, src, src, data_path):
prefix = os.path.join(data_path, '{}.{}-{}.'.format(split_k, tgt, src))
else:
if k > 0:
break
else:
raise FileNotFoundError('Dataset not found: {} ({})'.format(split, self.args.data))
raise FileNotFoundError('Dataset not found: {} ({})'.format(split, data_path))
src_datasets.append(indexed_dataset(prefix + src, self.src_dict))
tgt_datasets.append(indexed_dataset(prefix + tgt, self.tgt_dict))
print('| {} {} {} examples'.format(self.args.data, split_k, len(src_datasets[-1])))
print('| {} {} {} examples'.format(data_path, split_k, len(src_datasets[-1])))
if not combine:
break
assert len(src_datasets) == len(tgt_datasets)
if len(src_datasets) == 1:
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment