Commit 0863ea68 authored by Peng-Jen Chen, committed by Facebook Github Bot

Add multi-dataset loading to multilingual_translation

Summary: Similar to TranslationTask, enable the multilingual translation task to load sharded 'train{k}' datasets (train, train1, train2, ...) from the data-bin folder.
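For illustration, here is a minimal sketch of the shard-probing convention this enables. The data-bin path and language codes are made up, and a bare os.path.exists stands in for the real indexed_dataset.dataset_exists check:

import itertools
import os

# Hypothetical layout: a de-en pair binarized into shards named
# train.de-en.*, train1.de-en.*, train2.de-en.*, and so on.
data_path, src, tgt = 'data-bin/example', 'de', 'en'

def shard_exists(split_k):
    # Same filename pattern the task probes for each shard.
    filename = os.path.join(data_path, '{}.{}-{}.{}'.format(split_k, src, tgt, src))
    return os.path.exists(filename + '.idx')

shards = []
for k in itertools.count():
    split_k = 'train' + (str(k) if k > 0 else '')
    if not shard_exists(split_k):
        break
    shards.append(split_k)
print(shards)  # e.g. ['train', 'train1', 'train2']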

Reviewed By: lematt1991

Differential Revision: D15363481

fbshipit-source-id: 5fed7be19383023b792ed2fd38e655cbcecc8b90
parent 861dd2b7
@@ -20,6 +20,7 @@ from fairseq.data import (
     indexed_dataset,
 )
 from fairseq.models import FairseqMultiModel
+from fairseq.tasks.translation import load_langpair_dataset

 from . import FairseqTask, register_task
@@ -84,6 +85,8 @@ class MultilingualTranslationTask(FairseqTask):
                             help='max number of tokens in the source sequence')
         parser.add_argument('--max-target-positions', default=1024, type=int, metavar='N',
                             help='max number of tokens in the target sequence')
+        parser.add_argument('--upsample-primary', default=1, type=int,
+                            help='amount to upsample primary dataset')
         parser.add_argument('--encoder-langtok', default=None, type=str, choices=['src', 'tgt'],
                             metavar='SRCTGT',
                             help='replace beginning-of-sentence in source sentence with source or target '
@@ -196,40 +199,19 @@ class MultilingualTranslationTask(FairseqTask):
         assert len(paths) > 0
         data_path = paths[epoch % len(paths)]

-        def split_exists(split, src, tgt, lang):
-            filename = os.path.join(data_path, '{}.{}-{}.{}'.format(split, src, tgt, lang))
-            return indexed_dataset.dataset_exists(filename, impl=self.args.dataset_impl)
-
-        src_datasets, tgt_datasets = {}, {}
-        for lang_pair in self.args.lang_pairs:
-            src, tgt = lang_pair.split('-')
-            if split_exists(split, src, tgt, src):
-                prefix = os.path.join(data_path, '{}.{}-{}.'.format(split, src, tgt))
-            elif split_exists(split, tgt, src, src):
-                prefix = os.path.join(data_path, '{}.{}-{}.'.format(split, tgt, src))
-            else:
-                continue
-            src_datasets[lang_pair] = indexed_dataset.make_dataset(prefix + src, impl=self.args.dataset_impl,
-                                                                   fix_lua_indexing=True, dictionary=self.dicts[src])
-            tgt_datasets[lang_pair] = indexed_dataset.make_dataset(prefix + tgt, impl=self.args.dataset_impl,
-                                                                   fix_lua_indexing=True, dictionary=self.dicts[tgt])
-            print('| {} {} {} examples'.format(data_path, split, len(src_datasets[lang_pair])))
-
-        if len(src_datasets) == 0:
-            raise FileNotFoundError('Dataset not found: {} ({})'.format(split, data_path))
-
         def language_pair_dataset(lang_pair):
             src, tgt = lang_pair.split('-')
-            src_dataset, tgt_dataset = src_datasets[lang_pair], tgt_datasets[lang_pair]
+            langpair_dataset = load_langpair_dataset(
+                data_path, split, src, self.dicts[src], tgt, self.dicts[tgt],
+                combine=True, dataset_impl=self.args.dataset_impl,
+                upsample_primary=self.args.upsample_primary,
+                left_pad_source=self.args.left_pad_source,
+                left_pad_target=self.args.left_pad_target,
+                max_source_positions=self.args.max_source_positions,
+                max_target_positions=self.args.max_target_positions,
+            )
             return self.alter_dataset_langtok(
-                LanguagePairDataset(
-                    src_dataset, src_dataset.sizes, self.dicts[src],
-                    tgt_dataset, tgt_dataset.sizes, self.dicts[tgt],
-                    left_pad_source=self.args.left_pad_source,
-                    left_pad_target=self.args.left_pad_target,
-                    max_source_positions=self.args.max_source_positions,
-                    max_target_positions=self.args.max_target_positions,
-                ),
+                langpair_dataset,
                 src_eos=self.dicts[tgt].eos(),
                 src_lang=src,
                 tgt_lang=tgt,
...
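Since --upsample-primary is new to this task, a quick sketch of what the ratio does once the shards are concatenated (the shard sizes here are made up):

# With three shards and --upsample-primary 2, ConcatDataset repeats the
# primary (first) shard twice; sizes are illustrative only.
sample_ratios = [1, 1, 1]
sample_ratios[0] = 2  # args.upsample_primary
sizes = [100000, 20000, 20000]
print(sum(r * n for r, n in zip(sample_ratios, sizes)))  # 240000 instead of 140000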
@@ -19,6 +19,64 @@ from fairseq.data import (

 from . import FairseqTask, register_task


+def load_langpair_dataset(
+    data_path, split,
+    src, src_dict,
+    tgt, tgt_dict,
+    combine, dataset_impl, upsample_primary,
+    left_pad_source, left_pad_target, max_source_positions, max_target_positions,
+):
+    def split_exists(split, src, tgt, lang, data_path):
+        filename = os.path.join(data_path, '{}.{}-{}.{}'.format(split, src, tgt, lang))
+        return indexed_dataset.dataset_exists(filename, impl=dataset_impl)
+
+    src_datasets = []
+    tgt_datasets = []
+
+    for k in itertools.count():
+        split_k = split + (str(k) if k > 0 else '')
+
+        # infer langcode
+        if split_exists(split_k, src, tgt, src, data_path):
+            prefix = os.path.join(data_path, '{}.{}-{}.'.format(split_k, src, tgt))
+        elif split_exists(split_k, tgt, src, src, data_path):
+            prefix = os.path.join(data_path, '{}.{}-{}.'.format(split_k, tgt, src))
+        else:
+            if k > 0:
+                break
+            else:
+                raise FileNotFoundError('Dataset not found: {} ({})'.format(split, data_path))
+
+        src_datasets.append(indexed_dataset.make_dataset(prefix + src, impl=dataset_impl,
+                                                         fix_lua_indexing=True, dictionary=src_dict))
+        tgt_datasets.append(indexed_dataset.make_dataset(prefix + tgt, impl=dataset_impl,
+                                                         fix_lua_indexing=True, dictionary=tgt_dict))
+
+        print('| {} {} {}-{} {} examples'.format(data_path, split_k, src, tgt, len(src_datasets[-1])))
+
+        if not combine:
+            break
+
+    assert len(src_datasets) == len(tgt_datasets)
+
+    if len(src_datasets) == 1:
+        src_dataset, tgt_dataset = src_datasets[0], tgt_datasets[0]
+    else:
+        sample_ratios = [1] * len(src_datasets)
+        sample_ratios[0] = upsample_primary
+        src_dataset = ConcatDataset(src_datasets, sample_ratios)
+        tgt_dataset = ConcatDataset(tgt_datasets, sample_ratios)
+
+    return LanguagePairDataset(
+        src_dataset, src_dataset.sizes, src_dict,
+        tgt_dataset, tgt_dataset.sizes, tgt_dict,
+        left_pad_source=left_pad_source,
+        left_pad_target=left_pad_target,
+        max_source_positions=max_source_positions,
+        max_target_positions=max_target_positions,
+    )
+
+
 @register_task('translation')
 class TranslationTask(FairseqTask):
     """
@@ -117,51 +175,13 @@ class TranslationTask(FairseqTask):
         assert len(paths) > 0
         data_path = paths[epoch % len(paths)]

-        def split_exists(split, src, tgt, lang, data_path):
-            filename = os.path.join(data_path, '{}.{}-{}.{}'.format(split, src, tgt, lang))
-            return indexed_dataset.dataset_exists(filename, impl=self.args.dataset_impl)
+        # infer langcode
+        src, tgt = self.args.source_lang, self.args.target_lang

-        src_datasets = []
-        tgt_datasets = []
-
-        for k in itertools.count():
-            split_k = split + (str(k) if k > 0 else '')
-
-            # infer langcode
-            src, tgt = self.args.source_lang, self.args.target_lang
-            if split_exists(split_k, src, tgt, src, data_path):
-                prefix = os.path.join(data_path, '{}.{}-{}.'.format(split_k, src, tgt))
-            elif split_exists(split_k, tgt, src, src, data_path):
-                prefix = os.path.join(data_path, '{}.{}-{}.'.format(split_k, tgt, src))
-            else:
-                if k > 0:
-                    break
-                else:
-                    raise FileNotFoundError('Dataset not found: {} ({})'.format(split, data_path))
-
-            src_datasets.append(indexed_dataset.make_dataset(prefix + src, impl=self.args.dataset_impl,
-                                                             fix_lua_indexing=True, dictionary=self.src_dict))
-            tgt_datasets.append(indexed_dataset.make_dataset(prefix + tgt, impl=self.args.dataset_impl,
-                                                             fix_lua_indexing=True, dictionary=self.tgt_dict))
-
-            print('| {} {} {} examples'.format(data_path, split_k, len(src_datasets[-1])))
-
-            if not combine:
-                break
-
-        assert len(src_datasets) == len(tgt_datasets)
-
-        if len(src_datasets) == 1:
-            src_dataset, tgt_dataset = src_datasets[0], tgt_datasets[0]
-        else:
-            sample_ratios = [1] * len(src_datasets)
-            sample_ratios[0] = self.args.upsample_primary
-            src_dataset = ConcatDataset(src_datasets, sample_ratios)
-            tgt_dataset = ConcatDataset(tgt_datasets, sample_ratios)
-
-        self.datasets[split] = LanguagePairDataset(
-            src_dataset, src_dataset.sizes, self.src_dict,
-            tgt_dataset, tgt_dataset.sizes, self.tgt_dict,
+        self.datasets[split] = load_langpair_dataset(
+            data_path, split, src, self.src_dict, tgt, self.tgt_dict,
+            combine=combine, dataset_impl=self.args.dataset_impl,
+            upsample_primary=self.args.upsample_primary,
+            left_pad_source=self.args.left_pad_source,
+            left_pad_target=self.args.left_pad_target,
+            max_source_positions=self.args.max_source_positions,
...
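One detail visible at the top of both load_dataset hunks: data_path = paths[epoch % len(paths)] rotates over several data directories across epochs. A toy illustration, assuming (as the indexing suggests) that --data holds a colon-separated list of paths:

# Toy illustration of the per-epoch rotation over multiple data paths.
paths = 'data-bin/shard0:data-bin/shard1'.split(':')
for epoch in range(4):
    print(epoch, paths[epoch % len(paths)])
# 0 data-bin/shard0
# 1 data-bin/shard1
# 2 data-bin/shard0
# 3 data-bin/shard1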