Commit 91c78477 authored by taineleau, committed by Facebook Github Bot

add ConcatDataset support for XLM

Summary: Pull Request resolved: https://github.com/pytorch/fairseq/pull/684

Differential Revision: D15154631

Pulled By: myleott

fbshipit-source-id: 5e7dd9651d9ed239b60c51b9a11d08c80307d3ba
parent ff74ca94
fairseq/data/masked_lm_dataset.py

@@ -17,6 +17,7 @@ from . import FairseqDataset, data_utils
 from fairseq.data import Dictionary
 from fairseq.data.block_pair_dataset import BlockPairDataset
 from fairseq.data.token_block_dataset import TokenBlockDataset
+from fairseq.data.concat_dataset import ConcatDataset


 class MaskedLMDataset(FairseqDataset):
@@ -77,8 +78,10 @@ class MaskedLMDataset(FairseqDataset):
         # Make sure the input datasets are the ones supported
         assert (
             isinstance(dataset, TokenBlockDataset) or
-            isinstance(dataset, BlockPairDataset)
-        ), "MaskedLMDataset only wraps TokenBlockDataset or BlockPairDataset"
+            isinstance(dataset, BlockPairDataset) or
+            isinstance(dataset, ConcatDataset)
+        ), "MaskedLMDataset only wraps TokenBlockDataset or BlockPairDataset or " \
+           "ConcatDataset"
         self.dataset = dataset
         self.sizes = np.array(sizes)
@@ -355,4 +358,4 @@ class MaskedLMDataset(FairseqDataset):
         return getattr(self.dataset, "supports_prefetch", False)

     def prefetch(self, indices):
-        self.dataset.prefetch(indices)
\ No newline at end of file
+        self.dataset.prefetch(indices)
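The change above works because ConcatDataset looks like any other FairseqDataset from the wrapper's point of view: it maps a flat example index onto its constituent datasets via cumulative lengths, so MaskedLMDataset can index it exactly as it would a TokenBlockDataset or BlockPairDataset. Below is a minimal, self-contained sketch of that index-mapping idea; it is not fairseq's implementation, and the class and shard names are illustrative only. Note that in the task change further down, the per-shard sizes are concatenated explicitly with np.concatenate and passed to MaskedLMDataset rather than read off the concatenated dataset.

import bisect

import numpy as np


class _Shard:
    """Stub dataset: a list of token-id lists plus per-example lengths."""

    def __init__(self, examples):
        self.examples = examples
        self.sizes = np.array([len(x) for x in examples])

    def __len__(self):
        return len(self.examples)

    def __getitem__(self, index):
        return self.examples[index]


class TinyConcatDataset:
    """Illustrative concatenation wrapper: maps a flat index to (shard, local index)."""

    def __init__(self, datasets):
        self.datasets = list(datasets)
        # cumulative example counts, e.g. [2, 4] for two 2-example shards
        self.cumulative_sizes = np.cumsum([len(d) for d in self.datasets]).tolist()
        # flat per-example sizes; the task diff below builds the equivalent
        # array externally with np.concatenate and hands it to MaskedLMDataset
        self.sizes = np.concatenate([d.sizes for d in self.datasets])

    def __len__(self):
        return self.cumulative_sizes[-1]

    def __getitem__(self, index):
        shard = bisect.bisect_right(self.cumulative_sizes, index)
        local = index if shard == 0 else index - self.cumulative_sizes[shard - 1]
        return self.datasets[shard][local]


shards = [_Shard([[1, 2, 3], [4, 5]]), _Shard([[6], [7, 8, 9, 10]])]
concat = TinyConcatDataset(shards)
assert len(concat) == 4
assert concat[3] == [7, 8, 9, 10]  # index 3 resolves to the second shard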
fairseq/tasks/cross_lingual_lm.py

@@ -5,14 +5,18 @@
 # the root directory of this source tree. An additional grant of patent rights
 # can be found in the PATENTS file in the same directory.

+import itertools
 import os
 from collections import OrderedDict

+import numpy as np
+
 from fairseq import tokenizer
 from fairseq.data.masked_lm_dictionary import MaskedLMDictionary

 from fairseq.data import (
+    ConcatDataset,
     IndexedCachedDataset,
     IndexedDataset,
     IndexedRawTextDataset,
@@ -102,19 +106,12 @@ class CrossLingualLMTask(FairseqTask):

         return cls(args, dictionary)

-    def load_dataset(self, split, combine=False):
-        """Load a given dataset split.
-
-        Args:
-            split (str): name of the split (e.g., train, valid, test)
-        """
-        dataset_map = OrderedDict()
-        for lang in self.langs2id.keys():
-            if self.default_key is None:
-                self.default_key = lang
-            # Datasets are expected to be in "split.lang" format (Eg: train.en)
-            language_split = '{}.{}'.format(split, lang)
-            path = os.path.join(self.args.data, language_split)
+    def _load_single_lang_dataset(self, split):
+        loaded_datasets = []
+
+        for k in itertools.count():
+            split_k = split + (str(k) if k > 0 else '')
+            path = os.path.join(self.args.data, split_k)

             if self.args.raw_text and IndexedRawTextDataset.exists(path):
                 ds = IndexedRawTextDataset(path, self.dictionary)
@@ -124,23 +121,52 @@ class CrossLingualLMTask(FairseqTask):
                 else:
                     ds = IndexedCachedDataset(path, fix_lua_indexing=True)
             else:
-                raise FileNotFoundError('Dataset not found: {} ({})'.format(
-                    language_split, self.args.data))
+                if k > 0:
+                    break
+                else:
+                    raise FileNotFoundError('Dataset not found: {} ({})'.format(split, self.args.data))

             # Since we append each block with the classification_token,
             # we need to effectively create blocks of length
             # tokens_per_sample-1
-            block_dataset = TokenBlockDataset(
-                dataset=ds,
-                sizes=ds.sizes,
-                block_size=self.args.tokens_per_sample-1,
-                pad=self.dictionary.pad(),
-                eos=self.dictionary.eos()
-            )
+            loaded_datasets.append(
+                TokenBlockDataset(
+                    ds, ds.sizes, self.args.tokens_per_sample - 1,
+                    pad=self.dictionary.pad(), eos=self.dictionary.eos(),
+                )
+            )
+
+            print('| {} {} {} examples'.format(self.args.data, split_k, len(loaded_datasets[-1])))
+
+        if len(loaded_datasets) == 1:
+            dataset = loaded_datasets[0]
+            sizes = dataset.sizes
+        else:
+            dataset = ConcatDataset(loaded_datasets)
+            sizes = np.concatenate([ds.sizes for ds in loaded_datasets])
+
+        return dataset, sizes
+
+    def load_dataset(self, split, combine=False, **kwargs):
+        """Load a given dataset split.
+
+        Args:
+            split (str): name of the split (e.g., train, valid, test)
+        """
+        dataset_map = OrderedDict()
+        for lang in self.langs2id.keys():
+            if self.default_key is None:
+                self.default_key = lang
+            # Datasets are expected to be in "split.lang" format (Eg: train.en)
+            language_split = '{}.{}'.format(split, lang)
+
+            block_dataset, sizes = self._load_single_lang_dataset(split=language_split)

             dataset_map[lang] = MaskedLMDataset(
                 dataset=block_dataset,
-                sizes=block_dataset.sizes,
+                sizes=sizes,
                 vocab=self.dictionary,
                 pad_idx=self.dictionary.pad(),
                 mask_idx=self.dictionary.mask(),
@@ -158,4 +184,4 @@ class CrossLingualLMTask(FairseqTask):
         print('| {} {} {} examples'.format(
             self.args.data, split, len(self.datasets[split])
             )
-        )
\ No newline at end of file
+        )
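The new _load_single_lang_dataset follows the sharded-data pattern visible in the hunks above: for a split named train.en it looks for train.en, train.en1, train.en2, ... and concatenates whatever exists, failing only if the first shard is missing. Below is a small standalone sketch of that pattern, assuming a shard "exists" when a file of that name is present; load_sharded_split, load_shard, and the temporary file layout are illustrative placeholders, not fairseq's indexed-dataset API.

import itertools
import os
import tempfile


def load_sharded_split(data_dir, split, load_shard):
    """Collect shards named `split`, `split1`, `split2`, ... until one is missing.

    `load_shard(path)` is any callable that reads one shard into a dataset-like
    object with a length. FileNotFoundError is raised only when even the first
    shard is absent, mirroring the behaviour in the diff above.
    """
    shards = []
    for k in itertools.count():
        split_k = split + (str(k) if k > 0 else '')
        path = os.path.join(data_dir, split_k)
        if not os.path.exists(path):
            if k > 0:
                break  # ran out of shards; use what was found
            raise FileNotFoundError('Dataset not found: {} ({})'.format(split, data_dir))
        shards.append(load_shard(path))
        print('| {} {} {} examples'.format(data_dir, split_k, len(shards[-1])))
    # with a single shard there is nothing to concatenate; with several,
    # fairseq wraps them in ConcatDataset and concatenates their sizes
    return shards[0] if len(shards) == 1 else shards


# usage sketch: write two tiny text shards (one example per line) and load them back
tmp = tempfile.mkdtemp()
for name, lines in [('train.en', 'a b c\nd e\n'), ('train.en1', 'f\n')]:
    with open(os.path.join(tmp, name), 'w', encoding='utf-8') as f:
        f.write(lines)

loaded = load_sharded_split(
    tmp, 'train.en',
    load_shard=lambda p: open(p, encoding='utf-8').read().splitlines(),
)
assert [len(s) for s in loaded] == [2, 1]  # train.en and train.en1 were both picked up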