"src/git@developer.sourcefind.cn:renzhc/diffusers_dcu.git" did not exist on "cbc2ec8f44449cbc888256499d71bb6d7196aaa2"
Commit 0add50c2 authored by Naman Goyal, committed by Facebook Github Bot

allowing sharded dataset (#696)



Summary:
Co-authored-by: myleott <myleott@fb.com>

Changing `data` to be a `str` containing a colon-separated list of paths, for loading sharded datasets. This change is useful for loading large datasets that cannot fit into memory. The large dataset can be sharded, and each shard is then loaded in one epoch in a round-robin manner.

For example, if there are `5` shards of data and `10` epochs, then the shards will be iterated over in the order `[0, 1, 2, 3, 4, 0, 1, 2, 3, 4]`.
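
The selection logic is simply modular indexing into the colon-separated path list, as the tasks below do via `paths[epoch % len(paths)]`. Here is a minimal sketch of that behaviour; `select_shard` is a hypothetical helper name used only for illustration and is not part of this commit:

```python
# Minimal sketch (not part of the commit) of round-robin shard selection.
def select_shard(data: str, epoch: int) -> str:
    paths = data.split(':')           # colon-separated list of shard directories
    assert len(paths) > 0
    # epoch 0 -> shard 0, epoch 5 -> shard 0 again (with 5 shards)
    return paths[epoch % len(paths)]

# 5 shards, 10 epochs -> shard indices [0, 1, 2, 3, 4, 0, 1, 2, 3, 4]
shards = ['shard0', 'shard1', 'shard2', 'shard3', 'shard4']
print([shards.index(select_shard(':'.join(shards), epoch)) for epoch in range(10)])
```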

myleott: we need to look into `translation.py`, as it currently already expects a list and then concatenates the datasets.
Pull Request resolved: https://github.com/pytorch/fairseq/pull/696

Differential Revision: D15214049

fbshipit-source-id: 03e43a7b69c7aefada2ca668abf1eac1969fe013
parent 57da383c
@@ -79,11 +79,12 @@ class EpochBatchIterator(object):
         num_workers (int, optional): how many subprocesses to use for data
             loading. 0 means the data will be loaded in the main process
             (default: 0).
+        epoch (int, optional): The epoch to start the iterator from.
     """

     def __init__(
         self, dataset, collate_fn, batch_sampler, seed=1, num_shards=1, shard_id=0,
-        num_workers=0,
+        num_workers=0, epoch=0,
     ):
         assert isinstance(dataset, torch.utils.data.Dataset)
         self.dataset = dataset
@@ -94,7 +95,7 @@ class EpochBatchIterator(object):
         self.shard_id = shard_id
         self.num_workers = num_workers

-        self.epoch = 0
+        self.epoch = epoch
         self._cur_epoch_itr = None
         self._next_epoch_itr = None
         self._supports_prefetch = getattr(dataset, 'supports_prefetch', False)

@@ -42,7 +42,8 @@ class CrossLingualLMTask(FairseqTask):
     @staticmethod
     def add_args(parser):
         """Add task-specific arguments to the parser."""
-        parser.add_argument('data', help='path to data directory')
+        parser.add_argument('data', help='colon separated path to data directories list, \
+                            will be iterated upon during epochs in round-robin manner')
         parser.add_argument('--tokens-per-sample', default=512, type=int,
                             help='max number of total tokens over all segments'
                                  ' per sample')
@@ -106,12 +107,16 @@ class CrossLingualLMTask(FairseqTask):
         return cls(args, dictionary)

-    def _load_single_lang_dataset(self, split):
+    def _load_single_lang_dataset(self, split, epoch):
         loaded_datasets = []

+        paths = self.args.data.split(':')
+        assert len(paths) > 0
+        data_path = paths[epoch % len(paths)]
+
         for k in itertools.count():
             split_k = split + (str(k) if k > 0 else '')
-            path = os.path.join(self.args.data, split_k)
+            path = os.path.join(data_path, split_k)

             if self.args.raw_text and IndexedRawTextDataset.exists(path):
                 ds = IndexedRawTextDataset(path, self.dictionary)
@@ -124,7 +129,7 @@ class CrossLingualLMTask(FairseqTask):
                 if k > 0:
                     break
                 else:
-                    raise FileNotFoundError('Dataset not found: {} ({})'.format(split, self.args.data))
+                    raise FileNotFoundError('Dataset not found: {} ({})'.format(split, data_path))

             # Since we append each block with the classification_token,
             # we need to effectively create blocks of length
@@ -136,7 +141,7 @@ class CrossLingualLMTask(FairseqTask):
                 )
             )

-            print('| {} {} {} examples'.format(self.args.data, split_k, len(loaded_datasets[-1])))
+            print('| {} {} {} examples'.format(data_path, split_k, len(loaded_datasets[-1])))

         if len(loaded_datasets) == 1:
             dataset = loaded_datasets[0]
@@ -147,7 +152,7 @@ class CrossLingualLMTask(FairseqTask):
         return dataset, sizes

-    def load_dataset(self, split, combine=False, **kwargs):
+    def load_dataset(self, split, epoch=0, combine=False, **kwargs):
         """Load a given dataset split.

         Args:
             split (str): name of the split (e.g., train, valid, test)
@@ -162,7 +167,7 @@ class CrossLingualLMTask(FairseqTask):
             # Datasets are expected to be in "split.lang" format (Eg: train.en)
             language_split = '{}.{}'.format(split, lang)

-            block_dataset, sizes = self._load_single_lang_dataset(split=language_split)
+            block_dataset, sizes = self._load_single_lang_dataset(split=language_split, epoch=epoch)

             dataset_map[lang] = MaskedLMDataset(
                 dataset=block_dataset,
@@ -182,6 +187,6 @@ class CrossLingualLMTask(FairseqTask):
             dataset_map, default_key=self.default_key
         )
         print('| {} {} {} examples'.format(
-            self.args.data, split, len(self.datasets[split])
+            self.args.data.split(':')[epoch], split, len(self.datasets[split])
             )
         )
\ No newline at end of file

@@ -92,7 +92,7 @@ class FairseqTask(object):
     def get_batch_iterator(
         self, dataset, max_tokens=None, max_sentences=None, max_positions=None,
         ignore_invalid_inputs=False, required_batch_size_multiple=1,
-        seed=1, num_shards=1, shard_id=0, num_workers=0,
+        seed=1, num_shards=1, shard_id=0, num_workers=0, epoch=0,
     ):
         """
         Get an iterator that yields batches of data from the given dataset.
@@ -118,6 +118,7 @@ class FairseqTask(object):
             num_workers (int, optional): how many subprocesses to use for data
                 loading. 0 means the data will be loaded in the main process
                 (default: 0).
+            epoch (int, optional): The epoch to start the iterator from.

         Returns:
             ~fairseq.iterators.EpochBatchIterator: a batched iterator over the
@@ -149,6 +150,7 @@ class FairseqTask(object):
             num_shards=num_shards,
             shard_id=shard_id,
             num_workers=num_workers,
+            epoch=epoch,
         )

     def build_model(self, args):

@@ -104,7 +104,9 @@ class LanguageModelingTask(FairseqTask):
         dictionary = None
         output_dictionary = None
         if args.data:
-            dictionary = Dictionary.load(os.path.join(args.data, 'dict.txt'))
+            paths = args.data.split(':')
+            assert len(paths) > 0
+            dictionary = Dictionary.load(os.path.join(paths[0], 'dict.txt'))
             print('| dictionary: {} types'.format(len(dictionary)))
             output_dictionary = dictionary
             if args.output_dictionary_size >= 0:
@@ -136,7 +138,7 @@ class LanguageModelingTask(FairseqTask):
         return model

-    def load_dataset(self, split, combine=False, **kwargs):
+    def load_dataset(self, split, epoch=0, combine=False, **kwargs):
         """Load a given dataset split.

         Args:
@@ -145,9 +147,13 @@ class LanguageModelingTask(FairseqTask):
         loaded_datasets = []

+        paths = self.args.data.split(':')
+        assert len(paths) > 0
+        data_path = paths[epoch % len(paths)]
+
         for k in itertools.count():
             split_k = split + (str(k) if k > 0 else '')
-            path = os.path.join(self.args.data, split_k)
+            path = os.path.join(data_path, split_k)

             if self.args.raw_text and IndexedRawTextDataset.exists(path):
                 ds = IndexedRawTextDataset(path, self.dictionary)
@@ -160,7 +166,7 @@ class LanguageModelingTask(FairseqTask):
                 if k > 0:
                     break
                 else:
-                    raise FileNotFoundError('Dataset not found: {} ({})'.format(split, self.args.data))
+                    raise FileNotFoundError('Dataset not found: {} ({})'.format(split, data_path))

             loaded_datasets.append(
                 TokenBlockDataset(
@@ -170,7 +176,7 @@ class LanguageModelingTask(FairseqTask):
                 )
             )

-            print('| {} {} {} examples'.format(self.args.data, split_k, len(loaded_datasets[-1])))
+            print('| {} {} {} examples'.format(data_path, split_k, len(loaded_datasets[-1])))

             if not combine:
                 break

@@ -135,7 +135,9 @@ class MultilingualTranslationTask(FairseqTask):
         # load dictionaries
         dicts = OrderedDict()
         for lang in sorted_langs:
-            dicts[lang] = Dictionary.load(os.path.join(args.data, 'dict.{}.txt'.format(lang)))
+            paths = args.data.split(':')
+            assert len(paths) > 0
+            dicts[lang] = Dictionary.load(os.path.join(paths[0], 'dict.{}.txt'.format(lang)))
             if len(dicts) > 0:
                 assert dicts[lang].pad() == dicts[sorted_langs[0]].pad()
                 assert dicts[lang].eos() == dicts[sorted_langs[0]].eos()
@@ -185,11 +187,15 @@ class MultilingualTranslationTask(FairseqTask):
             new_tgt_bos=new_tgt_bos,
         )

-    def load_dataset(self, split, **kwargs):
+    def load_dataset(self, split, epoch=0, **kwargs):
         """Load a dataset split."""
+        paths = self.args.data.split(':')
+        assert len(paths) > 0
+        data_path = paths[epoch % len(paths)]
+
         def split_exists(split, src, tgt, lang):
-            filename = os.path.join(self.args.data, '{}.{}-{}.{}'.format(split, src, tgt, lang))
+            filename = os.path.join(data_path, '{}.{}-{}.{}'.format(split, src, tgt, lang))
             if self.args.raw_text and IndexedRawTextDataset.exists(filename):
                 return True
             elif not self.args.raw_text and IndexedDataset.exists(filename):
@@ -210,17 +216,17 @@ class MultilingualTranslationTask(FairseqTask):
         for lang_pair in self.args.lang_pairs:
             src, tgt = lang_pair.split('-')
             if split_exists(split, src, tgt, src):
-                prefix = os.path.join(self.args.data, '{}.{}-{}.'.format(split, src, tgt))
+                prefix = os.path.join(data_path, '{}.{}-{}.'.format(split, src, tgt))
             elif split_exists(split, tgt, src, src):
-                prefix = os.path.join(self.args.data, '{}.{}-{}.'.format(split, tgt, src))
+                prefix = os.path.join(data_path, '{}.{}-{}.'.format(split, tgt, src))
             else:
                 continue
             src_datasets[lang_pair] = indexed_dataset(prefix + src, self.dicts[src])
             tgt_datasets[lang_pair] = indexed_dataset(prefix + tgt, self.dicts[tgt])

-            print('| {} {} {} examples'.format(self.args.data, split, len(src_datasets[lang_pair])))
+            print('| {} {} {} examples'.format(data_path, split, len(src_datasets[lang_pair])))

         if len(src_datasets) == 0:
-            raise FileNotFoundError('Dataset not found: {} ({})'.format(split, self.args.data))
+            raise FileNotFoundError('Dataset not found: {} ({})'.format(split, data_path))

         def language_pair_dataset(lang_pair):
             src, tgt = lang_pair.split('-')

@@ -132,14 +132,18 @@ class SemisupervisedTranslationTask(MultilingualTranslationTask):
         dicts, training = MultilingualTranslationTask.prepare(args, **kwargs)
         return cls(args, dicts, training)

-    def load_dataset(self, split, **kwargs):
+    def load_dataset(self, split, epoch=0, **kwargs):
         """Load a dataset split."""
+        paths = self.args.data.split(':')
+        assert len(paths) > 0
+        data_path = paths[epoch % len(paths)]
+
         def split_exists(split, src, tgt, lang):
             if src is not None:
-                filename = os.path.join(self.args.data, '{}.{}-{}.{}'.format(split, src, tgt, lang))
+                filename = os.path.join(data_path, '{}.{}-{}.{}'.format(split, src, tgt, lang))
             else:
-                filename = os.path.join(self.args.data, '{}.{}-None.{}'.format(split, src, tgt))
+                filename = os.path.join(data_path, '{}.{}-None.{}'.format(split, src, tgt))
             if self.args.raw_text and IndexedRawTextDataset.exists(filename):
                 return True
             elif not self.args.raw_text and IndexedDataset.exists(filename):
@@ -162,16 +166,16 @@ class SemisupervisedTranslationTask(MultilingualTranslationTask):
         for lang_pair in self.args.lang_pairs:
             src, tgt = lang_pair.split('-')
             if split_exists(split, src, tgt, src):
-                prefix = os.path.join(self.args.data, '{}.{}-{}.'.format(split, src, tgt))
+                prefix = os.path.join(data_path, '{}.{}-{}.'.format(split, src, tgt))
             elif split_exists(split, tgt, src, src):
-                prefix = os.path.join(self.args.data, '{}.{}-{}.'.format(split, tgt, src))
+                prefix = os.path.join(data_path, '{}.{}-{}.'.format(split, tgt, src))
             else:
                 continue
             src_datasets[lang_pair] = indexed_dataset(prefix + src, self.dicts[src])
             tgt_datasets[lang_pair] = indexed_dataset(prefix + tgt, self.dicts[tgt])

-            print('| parallel-{} {} {} examples'.format(self.args.data, split, len(src_datasets[lang_pair])))
+            print('| parallel-{} {} {} examples'.format(data_path, split, len(src_datasets[lang_pair])))

         if len(src_datasets) == 0:
-            raise FileNotFoundError('Dataset not found: {} ({})'.format(split, self.args.data))
+            raise FileNotFoundError('Dataset not found: {} ({})'.format(split, data_path))

         # back translation datasets
         backtranslate_datasets = {}
@@ -179,8 +183,8 @@ class SemisupervisedTranslationTask(MultilingualTranslationTask):
             for lang_pair in self.args.lang_pairs:
                 src, tgt = lang_pair.split('-')
                 if not split_exists(split, tgt, None, tgt):
-                    raise FileNotFoundError('Dataset not found: backtranslation {} ({})'.format(split, self.args.data))
-                filename = os.path.join(self.args.data, '{}.{}-None.{}'.format(split, tgt, tgt))
+                    raise FileNotFoundError('Dataset not found: backtranslation {} ({})'.format(split, data_path))
+                filename = os.path.join(data_path, '{}.{}-None.{}'.format(split, tgt, tgt))
                 dataset = indexed_dataset(filename, self.dicts[tgt])
                 lang_pair_dataset_tgt = LanguagePairDataset(
                     dataset,
@@ -216,7 +220,7 @@ class SemisupervisedTranslationTask(MultilingualTranslationTask):
                     ).collater,
                 )
                 print('| backtranslate-{}: {} {} {} examples'.format(
-                    tgt, self.args.data, split, len(backtranslate_datasets[lang_pair]),
+                    tgt, data_path, split, len(backtranslate_datasets[lang_pair]),
                 ))
                 self.backtranslate_datasets[lang_pair] = backtranslate_datasets[lang_pair]
@@ -227,7 +231,7 @@ class SemisupervisedTranslationTask(MultilingualTranslationTask):
                 _, tgt = lang_pair.split('-')
                 if not split_exists(split, tgt, None, tgt):
                     continue
-                filename = os.path.join(self.args.data, '{}.{}-None.{}'.format(split, tgt, tgt))
+                filename = os.path.join(data_path, '{}.{}-None.{}'.format(split, tgt, tgt))
                 tgt_dataset1 = indexed_dataset(filename, self.dicts[tgt])
                 tgt_dataset2 = indexed_dataset(filename, self.dicts[tgt])
                 noising_dataset = NoisingDataset(
@@ -255,7 +259,7 @@ class SemisupervisedTranslationTask(MultilingualTranslationTask):
                     tgt_lang=tgt,
                 )
                 print('| denoising-{}: {} {} {} examples'.format(
-                    tgt, self.args.data, split, len(noising_datasets[lang_pair]),
+                    tgt, data_path, split, len(noising_datasets[lang_pair]),
                 ))

         def language_pair_dataset(lang_pair):

@@ -48,7 +48,8 @@ class TranslationTask(FairseqTask):
     def add_args(parser):
         """Add task-specific arguments to the parser."""
         # fmt: off
-        parser.add_argument('data', nargs='+', help='path(s) to data directorie(s)')
+        parser.add_argument('data', help='colon separated path to data directories list, \
+                            will be iterated upon during epochs in round-robin manner')
         parser.add_argument('-s', '--source-lang', default=None, metavar='SRC',
                             help='source language')
         parser.add_argument('-t', '--target-lang', default=None, metavar='TARGET',
@@ -84,19 +85,17 @@ class TranslationTask(FairseqTask):
         args.left_pad_source = options.eval_bool(args.left_pad_source)
         args.left_pad_target = options.eval_bool(args.left_pad_target)

-        # upgrade old checkpoints
-        if isinstance(args.data, str):
-            args.data = [args.data]
+        paths = args.data.split(':')
+        assert len(paths) > 0
         # find language pair automatically
         if args.source_lang is None or args.target_lang is None:
-            args.source_lang, args.target_lang = data_utils.infer_language_pair(args.data[0])
+            args.source_lang, args.target_lang = data_utils.infer_language_pair(paths[0])
         if args.source_lang is None or args.target_lang is None:
             raise Exception('Could not infer language pair, please provide it explicitly')

         # load dictionaries
-        src_dict = cls.load_dictionary(os.path.join(args.data[0], 'dict.{}.txt'.format(args.source_lang)))
-        tgt_dict = cls.load_dictionary(os.path.join(args.data[0], 'dict.{}.txt'.format(args.target_lang)))
+        src_dict = cls.load_dictionary(os.path.join(paths[0], 'dict.{}.txt'.format(args.source_lang)))
+        tgt_dict = cls.load_dictionary(os.path.join(paths[0], 'dict.{}.txt'.format(args.target_lang)))
         assert src_dict.pad() == tgt_dict.pad()
         assert src_dict.eos() == tgt_dict.eos()
         assert src_dict.unk() == tgt_dict.unk()
@@ -105,12 +104,15 @@ class TranslationTask(FairseqTask):
         return cls(args, src_dict, tgt_dict)

-    def load_dataset(self, split, combine=False, **kwargs):
+    def load_dataset(self, split, epoch=0, combine=False, **kwargs):
         """Load a given dataset split.

         Args:
             split (str): name of the split (e.g., train, valid, test)
         """
+        paths = self.args.data.split(':')
+        assert len(paths) > 0
+        data_path = paths[epoch % len(paths)]

         def split_exists(split, src, tgt, lang, data_path):
             filename = os.path.join(data_path, '{}.{}-{}.{}'.format(split, src, tgt, lang))
@@ -133,29 +135,28 @@ class TranslationTask(FairseqTask):
         src_datasets = []
         tgt_datasets = []

-        for dk, data_path in enumerate(self.args.data):
-            for k in itertools.count():
-                split_k = split + (str(k) if k > 0 else '')
-
-                # infer langcode
-                src, tgt = self.args.source_lang, self.args.target_lang
-                if split_exists(split_k, src, tgt, src, data_path):
-                    prefix = os.path.join(data_path, '{}.{}-{}.'.format(split_k, src, tgt))
-                elif split_exists(split_k, tgt, src, src, data_path):
-                    prefix = os.path.join(data_path, '{}.{}-{}.'.format(split_k, tgt, src))
-                else:
-                    if k > 0 or dk > 0:
-                        break
-                    else:
-                        raise FileNotFoundError('Dataset not found: {} ({})'.format(split, data_path))
-
-                src_datasets.append(indexed_dataset(prefix + src, self.src_dict))
-                tgt_datasets.append(indexed_dataset(prefix + tgt, self.tgt_dict))
-
-                print('| {} {} {} examples'.format(data_path, split_k, len(src_datasets[-1])))
-
-                if not combine:
-                    break
+        for k in itertools.count():
+            split_k = split + (str(k) if k > 0 else '')
+
+            # infer langcode
+            src, tgt = self.args.source_lang, self.args.target_lang
+            if split_exists(split_k, src, tgt, src, data_path):
+                prefix = os.path.join(data_path, '{}.{}-{}.'.format(split_k, src, tgt))
+            elif split_exists(split_k, tgt, src, src, data_path):
+                prefix = os.path.join(data_path, '{}.{}-{}.'.format(split_k, tgt, src))
+            else:
+                if k > 0:
+                    break
+                else:
+                    raise FileNotFoundError('Dataset not found: {} ({})'.format(split, data_path))
+
+            src_datasets.append(indexed_dataset(prefix + src, self.src_dict))
+            tgt_datasets.append(indexed_dataset(prefix + tgt, self.tgt_dict))
+
+            print('| {} {} {} examples'.format(data_path, split_k, len(src_datasets[-1])))
+
+            if not combine:
+                break

         assert len(src_datasets) == len(tgt_datasets)

@@ -68,10 +68,13 @@ class TestLoadCheckpoint(unittest.TestCase):
         [p.start() for p in self.applied_patches]

     def test_load_partial_checkpoint(self):
         with contextlib.redirect_stdout(StringIO()):
             trainer, epoch_itr = get_trainer_and_epoch_itr(2, 150, 200, 50)
-            train.load_checkpoint(self.args_mock, trainer, epoch_itr)
+            with patch('train.reload_train', return_value=epoch_itr):
+                train.load_checkpoint(self.args_mock, trainer, epoch_itr, 512, None)

             self.assertEqual(epoch_itr.epoch, 2)
             self.assertEqual(epoch_itr.iterations_in_epoch, 50)
@@ -86,7 +89,8 @@ class TestLoadCheckpoint(unittest.TestCase):
         with contextlib.redirect_stdout(StringIO()):
             trainer, epoch_itr = get_trainer_and_epoch_itr(2, 150, 300, 150)
-            train.load_checkpoint(self.args_mock, trainer, epoch_itr)
+            with patch('train.reload_train', return_value=epoch_itr):
+                train.load_checkpoint(self.args_mock, trainer, epoch_itr, 512, None)
             itr = epoch_itr.next_epoch_itr(shuffle=False)

             self.assertEqual(epoch_itr.epoch, 3)
@@ -98,7 +102,7 @@ class TestLoadCheckpoint(unittest.TestCase):
             trainer, epoch_itr = get_trainer_and_epoch_itr(0, 150, 0, 0)
             self.patches['os.path.isfile'].return_value = False
-            train.load_checkpoint(self.args_mock, trainer, epoch_itr)
+            train.load_checkpoint(self.args_mock, trainer, epoch_itr, 512, None)
             itr = epoch_itr.next_epoch_itr(shuffle=False)

             self.assertEqual(epoch_itr.epoch, 1)

@@ -44,7 +44,9 @@ def main(args, init_distributed=False):
    task = tasks.setup_task(args)

    # Load dataset splits
-    load_dataset_splits(args, task)
+    task.load_dataset(args.train_subset, combine=True, epoch=0)
+    for valid_sub_split in args.valid_subset.split(','):
+        task.load_dataset(valid_sub_split, combine=True, epoch=0)

    # Build model and criterion
    model = task.build_model(args)
@@ -64,15 +66,16 @@ def main(args, init_distributed=False):
        args.max_sentences,
    ))

+    max_positions = utils.resolve_max_positions(
+        task.max_positions(),
+        model.max_positions(),
+    )
    # Initialize dataloader
    epoch_itr = task.get_batch_iterator(
        dataset=task.dataset(args.train_subset),
        max_tokens=args.max_tokens,
        max_sentences=args.max_sentences,
-        max_positions=utils.resolve_max_positions(
-            task.max_positions(),
-            model.max_positions(),
-        ),
+        max_positions=max_positions,
        ignore_invalid_inputs=True,
        required_batch_size_multiple=args.required_batch_size_multiple,
        seed=args.seed,
@@ -82,7 +85,7 @@
    )

    # Load the latest checkpoint if one is available
-    load_checkpoint(args, trainer, epoch_itr)
+    load_checkpoint(args, trainer, epoch_itr, max_positions, task)

    # Train until the learning rate gets too small
    max_epoch = args.max_epoch or math.inf
@@ -105,10 +108,34 @@
        # save checkpoint
        if epoch_itr.epoch % args.save_interval == 0:
            save_checkpoint(args, trainer, epoch_itr, valid_losses[0])

+        epoch_itr = reload_train(args, epoch_itr, max_positions, task)
+
    train_meter.stop()
    print('| done training in {:.1f} seconds'.format(train_meter.sum))


+def reload_train(args, epoch_itr, max_positions, task):
+    # nothing needs to be done when the dataset is not sharded.
+    if len(args.data.split(":")) == 1:
+        return epoch_itr
+    print("| Reloading shard of train data at epoch: ", epoch_itr.epoch)
+    task.load_dataset(args.train_subset, combine=True, epoch=epoch_itr.epoch)
+    epoch_itr = task.get_batch_iterator(
+        dataset=task.dataset(args.train_subset),
+        max_tokens=args.max_tokens,
+        max_sentences=args.max_sentences,
+        max_positions=max_positions,
+        ignore_invalid_inputs=True,
+        required_batch_size_multiple=args.required_batch_size_multiple,
+        seed=args.seed,
+        num_shards=args.distributed_world_size,
+        shard_id=args.distributed_rank,
+        num_workers=args.num_workers,
+        epoch=epoch_itr.epoch,
+    )
+    return epoch_itr
+
+
 def train(args, trainer, task, epoch_itr):
    """Train the model for one epoch."""
    # Update parameters every N batches
@@ -335,9 +362,8 @@ def save_checkpoint(args, trainer, epoch_itr, val_loss):
            os.remove(old_chk)


-def load_checkpoint(args, trainer, epoch_itr):
+def load_checkpoint(args, trainer, epoch_itr, max_positions, task):
    """Load a checkpoint and replay dataloader to match."""
    # Only rank 0 should attempt to create the required dir
    if args.distributed_rank == 0:
        os.makedirs(args.save_dir, exist_ok=True)
@@ -351,7 +377,14 @@
                                              eval(args.optimizer_overrides))
        if extra_state is not None:
            # replay train iterator to match checkpoint
-            epoch_itr.load_state_dict(extra_state['train_iterator'])
+            epoch_itr_state = extra_state['train_iterator']
+            # If the loaded checkpoint is not at epoch 0, reload train dataset,
+            # as it could be potentially sharded.
+            if epoch_itr_state['epoch'] != 0:
+                epoch_itr = reload_train(args, epoch_itr, max_positions, task)
+            epoch_itr.load_state_dict(epoch_itr_state)

            print('| loaded checkpoint {} (epoch {} @ {} updates)'.format(
                checkpoint_path, epoch_itr.epoch, trainer.get_num_updates()))
@@ -366,19 +399,6 @@
        return False


-def load_dataset_splits(args, task):
-    task.load_dataset(args.train_subset, combine=True)
-    for split in args.valid_subset.split(','):
-        for k in itertools.count():
-            split_k = split + (str(k) if k > 0 else '')
-            try:
-                task.load_dataset(split_k, combine=False)
-            except FileNotFoundError as e:
-                if k > 0:
-                    break
-                raise e
-
-
 def distributed_main(i, args, start_rank=0):
    args.device_id = i
    if args.distributed_rank is None:  # torch.multiprocessing.spawn