Commit 91c78477 authored by taineleau, committed by Facebook GitHub Bot

add ConcatDataset support for XLM

Summary: Pull Request resolved: https://github.com/pytorch/fairseq/pull/684

Differential Revision: D15154631

Pulled By: myleott

fbshipit-source-id: 5e7dd9651d9ed239b60c51b9a11d08c80307d3ba
parent ff74ca94
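For orientation, the central idea of ConcatDataset is to expose several datasets behind one contiguous index space by mapping a global index to a (sub-dataset, local index) pair via cumulative lengths. The following is a minimal, self-contained sketch of that idea; SimpleConcatDataset and the list-based shards are illustrative stand-ins, not fairseq's actual implementation.

import bisect


class SimpleConcatDataset:
    def __init__(self, datasets):
        self.datasets = list(datasets)
        # cumulative_sizes[i] = total number of items in datasets[0..i]
        self.cumulative_sizes = []
        total = 0
        for ds in self.datasets:
            total += len(ds)
            self.cumulative_sizes.append(total)

    def __len__(self):
        return self.cumulative_sizes[-1] if self.cumulative_sizes else 0

    def __getitem__(self, idx):
        # Find which sub-dataset the global index falls into, then offset it.
        ds_idx = bisect.bisect_right(self.cumulative_sizes, idx)
        offset = self.cumulative_sizes[ds_idx - 1] if ds_idx > 0 else 0
        return self.datasets[ds_idx][idx - offset]


# Usage: two "shards" behave like one dataset.
shard_a = [10, 11, 12]
shard_b = [20, 21]
combined = SimpleConcatDataset([shard_a, shard_b])
assert len(combined) == 5
assert combined[4] == 21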
fairseq/data/masked_lm_dataset.py
@@ -17,6 +17,7 @@ from . import FairseqDataset, data_utils
 from fairseq.data import Dictionary
 from fairseq.data.block_pair_dataset import BlockPairDataset
 from fairseq.data.token_block_dataset import TokenBlockDataset
+from fairseq.data.concat_dataset import ConcatDataset


 class MaskedLMDataset(FairseqDataset):
@@ -77,8 +78,10 @@ class MaskedLMDataset(FairseqDataset):
         # Make sure the input datasets are the ones supported
         assert (
             isinstance(dataset, TokenBlockDataset) or
-            isinstance(dataset, BlockPairDataset)
-        ), "MaskedLMDataset only wraps TokenBlockDataset or BlockPairDataset"
+            isinstance(dataset, BlockPairDataset) or
+            isinstance(dataset, ConcatDataset)
+        ), "MaskedLMDataset only wraps TokenBlockDataset or BlockPairDataset or " \
+            "ConcatDataset"
         self.dataset = dataset
         self.sizes = np.array(sizes)
@@ -355,4 +358,4 @@ class MaskedLMDataset(FairseqDataset):
         return getattr(self.dataset, "supports_prefetch", False)

     def prefetch(self, indices):
-        self.dataset.prefetch(indices)
+        self.dataset.prefetch(indices)
\ No newline at end of file
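The widened assertion can be illustrated with a self-contained sketch; the placeholder classes below merely stand in for fairseq's TokenBlockDataset, BlockPairDataset, and ConcatDataset.

class TokenBlockDataset:
    pass


class BlockPairDataset:
    pass


class ConcatDataset:
    pass


def check_wrapped_dataset(dataset):
    # Mirrors the assert in MaskedLMDataset.__init__: after this commit the
    # wrapper also accepts a ConcatDataset (e.g. several concatenated shards).
    assert isinstance(
        dataset, (TokenBlockDataset, BlockPairDataset, ConcatDataset)
    ), "MaskedLMDataset only wraps TokenBlockDataset or BlockPairDataset or ConcatDataset"


check_wrapped_dataset(TokenBlockDataset())  # accepted before and after
check_wrapped_dataset(ConcatDataset())      # accepted after this commit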
fairseq/tasks/cross_lingual_lm.py
@@ -5,14 +5,18 @@
 # the root directory of this source tree. An additional grant of patent rights
 # can be found in the PATENTS file in the same directory.

+import itertools
 import os
 from collections import OrderedDict

+import numpy as np
+
 from fairseq import tokenizer
 from fairseq.data.masked_lm_dictionary import MaskedLMDictionary

 from fairseq.data import (
+    ConcatDataset,
     IndexedCachedDataset,
     IndexedDataset,
     IndexedRawTextDataset,
@@ -102,19 +106,12 @@ class CrossLingualLMTask(FairseqTask):
         return cls(args, dictionary)

-    def load_dataset(self, split, combine=False):
-        """Load a given dataset split.
-
-        Args:
-            split (str): name of the split (e.g., train, valid, test)
-        """
-        dataset_map = OrderedDict()
-
-        for lang in self.langs2id.keys():
-            if self.default_key is None:
-                self.default_key = lang
-
-            # Datasets are expected to be in "split.lang" format (Eg: train.en)
-            language_split = '{}.{}'.format(split, lang)
-            path = os.path.join(self.args.data, language_split)
+    def _load_single_lang_dataset(self, split):
+        loaded_datasets = []
+
+        for k in itertools.count():
+            split_k = split + (str(k) if k > 0 else '')
+            path = os.path.join(self.args.data, split_k)

             if self.args.raw_text and IndexedRawTextDataset.exists(path):
                 ds = IndexedRawTextDataset(path, self.dictionary)
......@@ -124,23 +121,52 @@ class CrossLingualLMTask(FairseqTask):
else:
ds = IndexedCachedDataset(path, fix_lua_indexing=True)
else:
raise FileNotFoundError('Dataset not found: {} ({})'.format(
language_split, self.args.data))
if k > 0:
break
else:
raise FileNotFoundError('Dataset not found: {} ({})'.format(split, self.args.data))
# Since we append each block with the classification_token,
# we need to effectively create blocks of length
# tokens_per_sample-1
block_dataset = TokenBlockDataset(
dataset=ds,
sizes=ds.sizes,
block_size=self.args.tokens_per_sample-1,
pad=self.dictionary.pad(),
eos=self.dictionary.eos()
loaded_datasets.append(
TokenBlockDataset(
ds, ds.sizes, self.args.tokens_per_sample - 1,
pad=self.dictionary.pad(), eos=self.dictionary.eos(),
)
)
print('| {} {} {} examples'.format(self.args.data, split_k, len(loaded_datasets[-1])))
if len(loaded_datasets) == 1:
dataset = loaded_datasets[0]
sizes = dataset.sizes
else:
dataset = ConcatDataset(loaded_datasets)
sizes = np.concatenate([ds.sizes for ds in loaded_datasets])
return dataset, sizes
def load_dataset(self, split, combine=False, **kwargs):
"""Load a given dataset split.
Args:
split (str): name of the split (e.g., train, valid, test)
"""
dataset_map = OrderedDict()
for lang in self.langs2id.keys():
if self.default_key is None:
self.default_key = lang
# Datasets are expected to be in "split.lang" format (Eg: train.en)
language_split = '{}.{}'.format(split, lang)
block_dataset, sizes = self._load_single_lang_dataset(split=language_split)
dataset_map[lang] = MaskedLMDataset(
dataset=block_dataset,
sizes=block_dataset.sizes,
sizes=sizes,
vocab=self.dictionary,
pad_idx=self.dictionary.pad(),
mask_idx=self.dictionary.mask(),
@@ -158,4 +184,4 @@ class CrossLingualLMTask(FairseqTask):
         print('| {} {} {} examples'.format(
             self.args.data, split, len(self.datasets[split])
             )
-        )
+        )
\ No newline at end of file
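The new _load_single_lang_dataset follows fairseq's usual sharded-split convention: it probes split, split1, split2, ... until a shard is missing, then merges whatever was found. Below is a minimal standalone sketch of that pattern, assuming caller-supplied load_shard and exists helpers; both are placeholders rather than fairseq APIs.

import itertools
import os

import numpy as np


def load_sharded_split(data_dir, split, load_shard, exists=os.path.exists):
    loaded = []
    for k in itertools.count():
        split_k = split + (str(k) if k > 0 else '')  # e.g. train.en, train.en1, ...
        path = os.path.join(data_dir, split_k)
        if not exists(path):
            if k > 0:
                break  # at least one shard was found; stop probing
            raise FileNotFoundError(
                'Dataset not found: {} ({})'.format(split, data_dir))
        loaded.append(load_shard(path))
        print('| {} {} {} examples'.format(data_dir, split_k, len(loaded[-1])))

    if len(loaded) == 1:
        return loaded[0], loaded[0].sizes
    # The real task wraps the shard list in fairseq's ConcatDataset; this sketch
    # just returns the shards together with their concatenated size array.
    return loaded, np.concatenate([ds.sizes for ds in loaded])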