Commit 8f9dd964 authored by Myle Ott

Improvements to data loader

parent 97d7fcb9
@@ -18,61 +18,65 @@ from fairseq.dictionary import Dictionary
from fairseq.indexed_dataset import IndexedDataset, IndexedInMemoryDataset
def load_with_check(path, load_splits, src=None, dst=None):
"""Loads specified data splits (e.g., test, train or valid) from the
specified folder and check that files exist."""
def find_language_pair(files):
for split in load_splits:
for filename in files:
def infer_language_pair(path, splits):
"""Infer language pair from filename: <split>.<lang1>-<lang2>.(...).idx"""
src, dst = None, None
for filename in os.listdir(path):
parts = filename.split('.')
for split in splits:
if parts[0] == split and parts[-1] == 'idx':
return parts[1].split('-')
src, dst = parts[1].split('-')
break
return src, dst
def split_exists(split, src, dst):
filename = '{0}.{1}-{2}.{1}.idx'.format(split, src, dst)
return os.path.exists(os.path.join(path, filename))
def load_dictionaries(path, src_lang, dst_lang):
"""Load dictionaries for a given language pair."""
src_dict = Dictionary.load(os.path.join(path, 'dict.{}.txt'.format(src_lang)))
dst_dict = Dictionary.load(os.path.join(path, 'dict.{}.txt'.format(dst_lang)))
return src_dict, dst_dict
def load_dataset(path, load_splits, src=None, dst=None):
"""Loads specified data splits (e.g., test, train or valid) from the
specified folder and check that files exist."""
if src is None and dst is None:
# find language pair automatically
src, dst = find_language_pair(os.listdir(path))
if not split_exists(load_splits[0], src, dst):
# try reversing src and dst
src, dst = dst, src
src, dst = infer_language_pair(path, load_splits)
def all_splits_exist(src, dst):
for split in load_splits:
if not split_exists(load_splits[0], src, dst):
raise ValueError('Data split not found: {}-{} ({})'.format(
src, dst, split))
dataset = load(path, load_splits, src, dst)
return dataset
def load(path, load_splits, src, dst):
"""Loads specified data splits (e.g. test, train or valid) from the path."""
filename = '{0}.{1}-{2}.{1}.idx'.format(split, src, dst)
if not os.path.exists(os.path.join(path, filename)):
return False
return True
# infer langcode
if all_splits_exist(src, dst):
langcode = '{}-{}'.format(src, dst)
elif all_splits_exist(dst, src):
langcode = '{}-{}'.format(dst, src)
else:
raise Exception('Dataset cannot be loaded from path: ' + path)
src_dict, dst_dict = load_dictionaries(path, src, dst)
dataset = LanguageDatasets(src, dst, src_dict, dst_dict)
def fmt_path(fmt, *args):
return os.path.join(path, fmt.format(*args))
src_dict = Dictionary.load(fmt_path('dict.{}.txt', src))
dst_dict = Dictionary.load(fmt_path('dict.{}.txt', dst))
dataset = LanguageDatasets(src, dst, src_dict, dst_dict)
for split in load_splits:
for k in itertools.count():
prefix = "{}{}".format(split, k if k > 0 else '')
src_path = fmt_path('{}.{}.{}', prefix, langcode, src)
dst_path = fmt_path('{}.{}.{}', prefix, langcode, dst)
if not IndexedInMemoryDataset.exists(src_path):
break
dataset.splits[prefix] = LanguagePairDataset(
IndexedInMemoryDataset(src_path),
IndexedInMemoryDataset(fmt_path('{}.{}.{}', prefix, langcode, dst)),
IndexedInMemoryDataset(dst_path),
pad_idx=dataset.src_dict.pad(),
eos_idx=dataset.src_dict.eos(),
)
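For orientation, a minimal usage sketch of the new loading helpers above (infer_language_pair, load_dictionaries, load_dataset). The data directory and the de/en language codes are hypothetical placeholders, and the `from fairseq import data` import is assumed to match the scripts changed further down:

```python
from fairseq import data  # assumed import path, as used by the scripts below

# Hypothetical data-bin directory containing e.g. train.de-en.de.{idx,bin},
# train.de-en.en.{idx,bin}, dict.de.txt and dict.en.txt
data_dir = 'data-bin/iwslt14.de-en'

# Infer the language pair from the index filenames when it is not given.
src, dst = data.infer_language_pair(data_dir, ['train', 'valid'])

# Load just the dictionaries ...
src_dict, dst_dict = data.load_dictionaries(data_dir, src, dst)

# ... or load dictionaries plus all requested splits in one call.
dataset = data.load_dataset(data_dir, ['train', 'valid'], src, dst)
print(dataset.src, dataset.dst, list(dataset.splits.keys()))
```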
@@ -92,13 +96,11 @@ class LanguageDatasets(object):
assert self.src_dict.eos() == self.dst_dict.eos()
assert self.src_dict.unk() == self.dst_dict.unk()
def dataloader(self, split, batch_size=1, num_workers=0,
max_tokens=None, seed=None, epoch=1,
sample_without_replacement=0, max_positions=(1024, 1024),
skip_invalid_size_inputs_valid_test=False,
def train_dataloader(self, split, num_workers=0, max_tokens=None,
max_positions=(1024, 1024), seed=None, epoch=1,
sample_without_replacement=0,
sort_by_source_size=False):
dataset = self.splits[split]
if split.startswith('train'):
with numpy_seed(seed):
batch_sampler = shuffled_batches_by_size(
dataset.src, dataset.dst,
@@ -106,22 +108,23 @@ class LanguageDatasets(object):
sample=sample_without_replacement,
max_positions=max_positions,
sort_by_source_size=sort_by_source_size)
elif split.startswith('valid'):
return torch.utils.data.DataLoader(
dataset, num_workers=num_workers, collate_fn=dataset.collater,
batch_sampler=batch_sampler)
def eval_dataloader(self, split, num_workers=0, batch_size=1,
max_tokens=None, consider_dst_sizes=True,
max_positions=(1024, 1024),
skip_invalid_size_inputs_valid_test=False):
dataset = self.splits[split]
dst_dataset = dataset.dst if consider_dst_sizes else None
batch_sampler = list(batches_by_size(
dataset.src, batch_size, max_tokens, dst=dataset.dst,
dataset.src, dataset.dst, batch_size, max_tokens,
max_positions=max_positions,
ignore_invalid_inputs=skip_invalid_size_inputs_valid_test))
else:
batch_sampler = list(batches_by_size(
dataset.src, batch_size, max_tokens, max_positions=max_positions,
ignore_invalid_inputs=skip_invalid_size_inputs_valid_test))
return torch.utils.data.DataLoader(
dataset,
num_workers=num_workers,
collate_fn=dataset.collater,
batch_sampler=batch_sampler,
)
dataset, num_workers=num_workers, collate_fn=dataset.collater,
batch_sampler=batch_sampler)
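A short sketch of how the two loaders split out above would be used, continuing from the `dataset` object of the previous example; the token budget and worker count are arbitrary example values:

```python
# Shuffled, size-bucketed batches for training (reproducible via the seed).
train_itr = dataset.train_dataloader(
    'train', num_workers=1, max_tokens=4000,
    max_positions=(1024, 1024), seed=1, epoch=1)

# Deterministic batches for validation/test; sentences exceeding
# max_positions can be skipped instead of raising an exception.
valid_itr = dataset.eval_dataloader(
    'valid', batch_size=None, max_tokens=4000, max_positions=(1024, 1024),
    skip_invalid_size_inputs_valid_test=True)

for batch in valid_itr:
    # each batch is the dict built by LanguagePairDataset.collater (next hunk)
    src_tokens, target = batch['src_tokens'], batch['target']
```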
def skip_group_enumerator(it, ngpus, offset=0):
@@ -174,14 +177,15 @@ class LanguagePairDataset(object):
return LanguagePairDataset.collate_tokens(
[s[key] for s in samples], pad_idx, eos_idx, left_pad, move_eos_to_beginning)
ntokens = sum(len(s['target']) for s in samples)
return {
'id': torch.LongTensor([s['id'].item() for s in samples]),
'src_tokens': merge('source', left_pad=LanguagePairDataset.LEFT_PAD_SOURCE),
# we create a shifted version of targets for feeding the previous
# output token(s) into the next decoder step
'input_tokens': merge('target', left_pad=LanguagePairDataset.LEFT_PAD_TARGET,
move_eos_to_beginning=True),
'target': merge('target', left_pad=LanguagePairDataset.LEFT_PAD_TARGET),
'src_tokens': merge('source', left_pad=LanguagePairDataset.LEFT_PAD_SOURCE),
'ntokens': ntokens,
'ntokens': sum(len(s['target']) for s in samples),
}
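To make the `move_eos_to_beginning` shift above concrete, a toy illustration of how a target sequence becomes the decoder input (`input_tokens`); the token and eos indices are made up:

```python
# The decoder is fed the previous output tokens, so the end-of-sentence
# marker moves from the end of the target to the front of the input:
#   target:       [w1, w2, w3, eos]
#   input_tokens: [eos, w1, w2, w3]
eos = 2                      # assumed eos index; the real one comes from the dictionary
target = [14, 25, 36, eos]   # made-up token indices
input_tokens = [eos] + target[:-1]
assert input_tokens == [eos, 14, 25, 36]
```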
@staticmethod
@@ -218,18 +222,14 @@ def _valid_size(src_size, dst_size, max_positions):
return True
def batches_by_size(src, batch_size=None, max_tokens=None, dst=None,
def batches_by_size(src, dst, batch_size=None, max_tokens=None,
max_positions=(1024, 1024), ignore_invalid_inputs=False):
"""Returns batches of indices sorted by size. Sequences of different lengths
are not allowed in the same batch."""
assert isinstance(src, IndexedDataset)
assert dst is None or isinstance(dst, IndexedDataset)
"""Returns batches of indices sorted by size. Sequences with different
source lengths are not allowed in the same batch."""
assert isinstance(src, IndexedDataset) and isinstance(dst, IndexedDataset)
if max_tokens is None:
max_tokens = float('Inf')
sizes = src.sizes
indices = np.argsort(sizes, kind='mergesort')
if dst is not None:
sizes = np.maximum(sizes, dst.sizes)
indices = np.argsort(src.sizes, kind='mergesort')
batch = []
@@ -238,7 +238,7 @@ def batches_by_size(src, batch_size=None, max_tokens=None, dst=None,
return False
if len(batch) == batch_size:
return True
if sizes[batch[0]] != sizes[next_idx]:
if src.sizes[batch[0]] != src.sizes[next_idx]:
return True
if num_tokens >= max_tokens:
return True
@@ -247,21 +247,20 @@ def batches_by_size(src, batch_size=None, max_tokens=None, dst=None,
cur_max_size = 0
ignored = []
for idx in indices:
if not _valid_size(src.sizes[idx],
None if dst is None else dst.sizes[idx],
max_positions):
if not _valid_size(src.sizes[idx], dst.sizes[idx], max_positions):
if ignore_invalid_inputs:
ignored.append(idx)
continue
raise Exception("Unable to handle input id {} of size {} / {}.".format(
idx, src.sizes[idx], "none" if dst is None else dst.sizes[idx]))
raise Exception(
"Unable to handle input id {} of size {} / {}.".format(
idx, src.sizes[idx], dst.sizes[idx]))
if yield_batch(idx, cur_max_size * (len(batch) + 1)):
yield batch
batch = []
cur_max_size = 0
batch.append(idx)
cur_max_size = max(cur_max_size, sizes[idx])
cur_max_size = max(cur_max_size, src.sizes[idx], dst.sizes[idx])
if len(ignored) > 0:
print("Warning! {} samples are either too short or too long "
@@ -34,7 +34,7 @@ def main():
use_cuda = torch.cuda.is_available() and not args.cpu
# Load dataset
dataset = data.load_with_check(args.data, [args.gen_subset], args.source_lang, args.target_lang)
dataset = data.load_dataset(args.data, [args.gen_subset], args.source_lang, args.target_lang)
if args.source_lang is None or args.target_lang is None:
# record inferred languages in args
args.source_lang, args.target_lang = dataset.src, dataset.dst
@@ -67,8 +67,8 @@ def main():
# Generate and compute BLEU score
scorer = bleu.Scorer(dataset.dst_dict.pad(), dataset.dst_dict.eos(), dataset.dst_dict.unk())
max_positions = min(model.max_encoder_positions() for model in models)
itr = dataset.dataloader(args.gen_subset, batch_size=args.batch_size,
max_positions=max_positions,
itr = dataset.eval_dataloader(
args.gen_subset, batch_size=args.batch_size, max_positions=max_positions,
skip_invalid_size_inputs_valid_test=args.skip_invalid_size_inputs_valid_test)
num_sentences = 0
with progress_bar(itr, smoothing=0, leave=False) as t:
@@ -26,19 +26,17 @@ def main():
use_cuda = torch.cuda.is_available() and not args.cpu
# Load dataset
# TODO: load only dictionaries
dataset = data.load_with_check(args.data, ['test'], args.source_lang, args.target_lang)
# Load dictionaries
if args.source_lang is None or args.target_lang is None:
# record inferred languages in args
args.source_lang, args.target_lang = dataset.src, dataset.dst
args.source_lang, args.target_lang = data.infer_language_pair(args.data, ['test'])
src_dict, dst_dict = data.load_dictionaries(args.data, args.source_lang, args.target_lang)
# Load ensemble
print('| loading model(s) from {}'.format(', '.join(args.path)))
models = utils.load_ensemble_for_inference(args.path, dataset.src_dict, dataset.dst_dict)
models = utils.load_ensemble_for_inference(args.path, src_dict, dst_dict)
print('| [{}] dictionary: {} types'.format(dataset.src, len(dataset.src_dict)))
print('| [{}] dictionary: {} types'.format(dataset.dst, len(dataset.dst_dict)))
print('| [{}] dictionary: {} types'.format(args.source_lang, len(src_dict)))
print('| [{}] dictionary: {} types'.format(args.target_lang, len(dst_dict)))
# Optimize ensemble for generation
for model in models:
@@ -60,7 +58,7 @@ def main():
print('Type the input sentence and press return:')
for src_str in sys.stdin:
src_str = src_str.strip()
src_tokens = tokenizer.Tokenizer.tokenize(src_str, dataset.src_dict, add_if_not_exist=False).long()
src_tokens = tokenizer.Tokenizer.tokenize(src_str, src_dict, add_if_not_exist=False).long()
if use_cuda:
src_tokens = src_tokens.cuda()
translations = translator.generate(Variable(src_tokens.view(1, -1)))
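A sketch of the dictionary-only flow used in this script, which avoids loading any dataset splits; the data directory and the input sentence are placeholders:

```python
from fairseq import data, tokenizer   # assumed imports, matching the calls above

data_dir = 'data-bin/iwslt14.de-en'   # hypothetical preprocessed data directory

# Infer the language pair and load only the dictionaries (no training data).
src_lang, dst_lang = data.infer_language_pair(data_dir, ['test'])
src_dict, dst_dict = data.load_dictionaries(data_dir, src_lang, dst_lang)

# Map a raw sentence to dictionary indices, as in the stdin loop above.
src_tokens = tokenizer.Tokenizer.tokenize(
    'hello world .', src_dict, add_if_not_exist=False).long()
```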
@@ -74,7 +72,7 @@ def main():
src_str=src_str,
alignment=hypo['alignment'].int().cpu(),
align_dict=align_dict,
dst_dict=dataset.dst_dict,
dst_dict=dst_dict,
remove_bpe=args.remove_bpe)
print('A\t{}'.format(' '.join(map(str, alignment))))
print('H\t{}\t{}'.format(hypo['score'], hypo_str))
@@ -34,7 +34,6 @@ def main():
options.add_model_args(parser)
args = utils.parse_args_and_arch(parser)
print(args)
if args.no_progress_bar:
progress_bar.enabled = False
@@ -45,11 +44,12 @@ def main():
torch.manual_seed(args.seed)
# Load dataset
dataset = data.load_with_check(args.data, ['train', 'valid'], args.source_lang, args.target_lang)
dataset = data.load_dataset(args.data, ['train', 'valid'], args.source_lang, args.target_lang)
if args.source_lang is None or args.target_lang is None:
# record inferred languages in args, so that it's saved in checkpoints
args.source_lang, args.target_lang = dataset.src, dataset.dst
print(args)
print('| [{}] dictionary: {} types'.format(dataset.src, len(dataset.src_dict)))
print('| [{}] dictionary: {} types'.format(dataset.dst, len(dataset.dst_dict)))
for split in ['train', 'valid']:
@@ -129,11 +129,10 @@ def train(args, epoch, batch_offset, trainer, dataset, max_positions, num_gpus):
torch.manual_seed(seed)
trainer.set_seed(seed)
itr = dataset.dataloader(
itr = dataset.train_dataloader(
args.train_subset, num_workers=args.workers, max_tokens=args.max_tokens,
seed=seed, epoch=epoch, max_positions=max_positions,
max_positions=max_positions, seed=seed, epoch=epoch,
sample_without_replacement=args.sample_without_replacement,
skip_invalid_size_inputs_valid_test=args.skip_invalid_size_inputs_valid_test,
sort_by_source_size=(epoch <= args.curriculum))
loss_meter = AverageMeter()
bsz_meter = AverageMeter() # sentences per batch
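The `sort_by_source_size=(epoch <= args.curriculum)` argument above implements a simple length-based curriculum; for a hypothetical `--curriculum 2` the flag behaves as follows:

```python
curriculum = 2   # hypothetical value of args.curriculum
for epoch in range(1, 5):
    # Early epochs use batches ordered by source length (curriculum);
    # later epochs fall back to the usual shuffled, size-bucketed batches.
    sort_by_source_size = (epoch <= curriculum)
    print(epoch, sort_by_source_size)   # 1 True, 2 True, 3 False, 4 False
```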
@@ -216,7 +215,7 @@ def save_checkpoint(trainer, args, epoch, batch_offset, val_loss):
def validate(args, epoch, trainer, dataset, max_positions, subset, ngpus):
"""Evaluate the model on the validation set and return the average loss."""
itr = dataset.dataloader(
itr = dataset.eval_dataloader(
subset, batch_size=None, max_tokens=args.max_tokens, max_positions=max_positions,
skip_invalid_size_inputs_valid_test=args.skip_invalid_size_inputs_valid_test)
loss_meter = AverageMeter()