"vscode:/vscode.git/clone" did not exist on "049082e013fb71d78f3abf487916f3de2b674908"
Commit 8f9dd964 authored by Myle Ott

Improvements to data loader

parent 97d7fcb9
@@ -18,61 +18,65 @@ from fairseq.dictionary import Dictionary
 from fairseq.indexed_dataset import IndexedDataset, IndexedInMemoryDataset
 
 
-def load_with_check(path, load_splits, src=None, dst=None):
-    """Loads specified data splits (e.g., test, train or valid) from the
-    specified folder and check that files exist."""
-
-    def find_language_pair(files):
-        for split in load_splits:
-            for filename in files:
-                parts = filename.split('.')
-                if parts[0] == split and parts[-1] == 'idx':
-                    return parts[1].split('-')
-
-    def split_exists(split, src, dst):
-        filename = '{0}.{1}-{2}.{1}.idx'.format(split, src, dst)
-        return os.path.exists(os.path.join(path, filename))
-
+def infer_language_pair(path, splits):
+    """Infer language pair from filename: <split>.<lang1>-<lang2>.(...).idx"""
+    src, dst = None, None
+    for filename in os.listdir(path):
+        parts = filename.split('.')
+        for split in splits:
+            if parts[0] == split and parts[-1] == 'idx':
+                src, dst = parts[1].split('-')
+                break
+    return src, dst
+
+
+def load_dictionaries(path, src_lang, dst_lang):
+    """Load dictionaries for a given language pair."""
+    src_dict = Dictionary.load(os.path.join(path, 'dict.{}.txt'.format(src_lang)))
+    dst_dict = Dictionary.load(os.path.join(path, 'dict.{}.txt'.format(dst_lang)))
+    return src_dict, dst_dict
+
+
+def load_dataset(path, load_splits, src=None, dst=None):
+    """Loads specified data splits (e.g., test, train or valid) from the
+    specified folder and check that files exist."""
     if src is None and dst is None:
         # find language pair automatically
-        src, dst = find_language_pair(os.listdir(path))
-        if not split_exists(load_splits[0], src, dst):
-            # try reversing src and dst
-            src, dst = dst, src
-
-    for split in load_splits:
-        if not split_exists(load_splits[0], src, dst):
-            raise ValueError('Data split not found: {}-{} ({})'.format(
-                src, dst, split))
-
-    dataset = load(path, load_splits, src, dst)
-    return dataset
-
-
-def load(path, load_splits, src, dst):
-    """Loads specified data splits (e.g. test, train or valid) from the path."""
-    langcode = '{}-{}'.format(src, dst)
+        src, dst = infer_language_pair(path, load_splits)
+
+    def all_splits_exist(src, dst):
+        for split in load_splits:
+            filename = '{0}.{1}-{2}.{1}.idx'.format(split, src, dst)
+            if not os.path.exists(os.path.join(path, filename)):
+                return False
+        return True
+
+    # infer langcode
+    if all_splits_exist(src, dst):
+        langcode = '{}-{}'.format(src, dst)
+    elif all_splits_exist(dst, src):
+        langcode = '{}-{}'.format(dst, src)
+    else:
+        raise Exception('Dataset cannot be loaded from path: ' + path)
+
+    src_dict, dst_dict = load_dictionaries(path, src, dst)
+    dataset = LanguageDatasets(src, dst, src_dict, dst_dict)
 
     def fmt_path(fmt, *args):
         return os.path.join(path, fmt.format(*args))
 
-    src_dict = Dictionary.load(fmt_path('dict.{}.txt', src))
-    dst_dict = Dictionary.load(fmt_path('dict.{}.txt', dst))
-    dataset = LanguageDatasets(src, dst, src_dict, dst_dict)
-
     for split in load_splits:
         for k in itertools.count():
             prefix = "{}{}".format(split, k if k > 0 else '')
             src_path = fmt_path('{}.{}.{}', prefix, langcode, src)
+            dst_path = fmt_path('{}.{}.{}', prefix, langcode, dst)
 
             if not IndexedInMemoryDataset.exists(src_path):
                 break
 
             dataset.splits[prefix] = LanguagePairDataset(
                 IndexedInMemoryDataset(src_path),
-                IndexedInMemoryDataset(fmt_path('{}.{}.{}', prefix, langcode, dst)),
+                IndexedInMemoryDataset(dst_path),
                 pad_idx=dataset.src_dict.pad(),
                 eos_idx=dataset.src_dict.eos(),
             )
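
For orientation, here is a minimal sketch of how the reworked loading helpers might be called from user code. Only load_dataset, load_dictionaries and the dataset attributes come from the diff above; the data directory, language codes and import style are illustrative assumptions.

    from fairseq import data  # assumes the scripts' "from fairseq import data" style

    # hypothetical binarized data directory containing files such as
    # train.de-en.de.idx / train.de-en.de.bin and dict.de.txt / dict.en.txt
    data_dir = 'data-bin/iwslt14.de-en'

    # let load_dataset infer the language pair from the filenames ...
    dataset = data.load_dataset(data_dir, ['train', 'valid'])
    print(dataset.src, dataset.dst)           # e.g. 'de', 'en'
    print(sorted(dataset.splits.keys()))      # e.g. ['train', 'valid']

    # ... or pin the pair explicitly and load just the dictionaries
    src_dict, dst_dict = data.load_dictionaries(data_dir, 'de', 'en')
    print(len(src_dict), len(dst_dict))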
@@ -92,13 +96,11 @@ class LanguageDatasets(object):
         assert self.src_dict.eos() == self.dst_dict.eos()
         assert self.src_dict.unk() == self.dst_dict.unk()
 
-    def dataloader(self, split, batch_size=1, num_workers=0,
-                   max_tokens=None, seed=None, epoch=1,
-                   sample_without_replacement=0, max_positions=(1024, 1024),
-                   skip_invalid_size_inputs_valid_test=False,
-                   sort_by_source_size=False):
+    def train_dataloader(self, split, num_workers=0, max_tokens=None,
+                         max_positions=(1024, 1024), seed=None, epoch=1,
+                         sample_without_replacement=0,
+                         sort_by_source_size=False):
         dataset = self.splits[split]
-        if split.startswith('train'):
-            with numpy_seed(seed):
-                batch_sampler = shuffled_batches_by_size(
-                    dataset.src, dataset.dst,
+        with numpy_seed(seed):
+            batch_sampler = shuffled_batches_by_size(
+                dataset.src, dataset.dst,
@@ -106,22 +108,23 @@ class LanguageDatasets(object):
-                    sample=sample_without_replacement,
-                    max_positions=max_positions,
-                    sort_by_source_size=sort_by_source_size)
-        elif split.startswith('valid'):
-            batch_sampler = list(batches_by_size(
-                dataset.src, batch_size, max_tokens, dst=dataset.dst,
-                max_positions=max_positions,
-                ignore_invalid_inputs=skip_invalid_size_inputs_valid_test))
-        else:
-            batch_sampler = list(batches_by_size(
-                dataset.src, batch_size, max_tokens, max_positions=max_positions,
-                ignore_invalid_inputs=skip_invalid_size_inputs_valid_test))
-        return torch.utils.data.DataLoader(
-            dataset,
-            num_workers=num_workers,
-            collate_fn=dataset.collater,
-            batch_sampler=batch_sampler,
-        )
+                sample=sample_without_replacement,
+                max_positions=max_positions,
+                sort_by_source_size=sort_by_source_size)
+        return torch.utils.data.DataLoader(
+            dataset, num_workers=num_workers, collate_fn=dataset.collater,
+            batch_sampler=batch_sampler)
+
+    def eval_dataloader(self, split, num_workers=0, batch_size=1,
+                        max_tokens=None, consider_dst_sizes=True,
+                        max_positions=(1024, 1024),
+                        skip_invalid_size_inputs_valid_test=False):
+        dataset = self.splits[split]
+        dst_dataset = dataset.dst if consider_dst_sizes else None
+        batch_sampler = list(batches_by_size(
+            dataset.src, dataset.dst, batch_size, max_tokens,
+            max_positions=max_positions,
+            ignore_invalid_inputs=skip_invalid_size_inputs_valid_test))
+        return torch.utils.data.DataLoader(
+            dataset, num_workers=num_workers, collate_fn=dataset.collater,
+            batch_sampler=batch_sampler)
 
 
 def skip_group_enumerator(it, ngpus, offset=0):
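
Both new methods share one pattern: the batches of example indices are computed up front (shuffled and token-capped for training, length-sorted and deterministic for evaluation) and handed to torch.utils.data.DataLoader through batch_sampler, while the dataset's collater pads each batch. The toy example below illustrates only that DataLoader pattern; the dataset class and the hard-coded sampler are made up and are not fairseq code.

    import torch


    class ToyDataset(torch.utils.data.Dataset):
        """Variable-length integer sequences standing in for tokenized sentences."""

        def __init__(self, lengths):
            self.data = [list(range(1, n + 1)) for n in lengths]

        def __len__(self):
            return len(self.data)

        def __getitem__(self, index):
            return self.data[index]

        def collater(self, samples):
            # pad every sequence in the batch to the longest one (pad value 0)
            max_len = max(len(s) for s in samples)
            return torch.tensor([s + [0] * (max_len - len(s)) for s in samples])


    ds = ToyDataset([3, 3, 5, 5, 2])
    # precomputed index batches, grouping equal lengths like batches_by_size does
    batch_sampler = [[0, 1], [2, 3], [4]]
    loader = torch.utils.data.DataLoader(
        ds, collate_fn=ds.collater, batch_sampler=batch_sampler)
    for batch in loader:
        print(batch.shape)   # torch.Size([2, 3]), torch.Size([2, 5]), torch.Size([1, 2])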
@@ -174,14 +177,15 @@ class LanguagePairDataset(object):
             return LanguagePairDataset.collate_tokens(
                 [s[key] for s in samples], pad_idx, eos_idx, left_pad, move_eos_to_beginning)
 
-        ntokens = sum(len(s['target']) for s in samples)
         return {
             'id': torch.LongTensor([s['id'].item() for s in samples]),
+            'src_tokens': merge('source', left_pad=LanguagePairDataset.LEFT_PAD_SOURCE),
+            # we create a shifted version of targets for feeding the previous
+            # output token(s) into the next decoder step
             'input_tokens': merge('target', left_pad=LanguagePairDataset.LEFT_PAD_TARGET,
                                   move_eos_to_beginning=True),
             'target': merge('target', left_pad=LanguagePairDataset.LEFT_PAD_TARGET),
-            'src_tokens': merge('source', left_pad=LanguagePairDataset.LEFT_PAD_SOURCE),
-            'ntokens': ntokens,
+            'ntokens': sum(len(s['target']) for s in samples),
         }
 
     @staticmethod
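
The comment added to the collater spells out why there are two copies of the target: input_tokens is the target sequence rotated so the EOS token comes first, which gives the decoder the previous output token at every step (teacher forcing), while target keeps the original order for the loss. A standalone sketch of that rotation, assuming an EOS-terminated target tensor and ignoring padding:

    import torch


    def shift_eos_to_beginning(target, eos_idx):
        """Rotate [w1, w2, ..., wn, eos] into [eos, w1, w2, ..., wn]."""
        assert target[-1].item() == eos_idx
        return torch.cat([target.new_tensor([eos_idx]), target[:-1]])


    eos = 2
    target = torch.tensor([11, 12, 13, eos])
    print(shift_eos_to_beginning(target, eos))   # tensor([ 2, 11, 12, 13])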
@@ -218,18 +222,14 @@ def _valid_size(src_size, dst_size, max_positions):
     return True
 
 
-def batches_by_size(src, batch_size=None, max_tokens=None, dst=None,
+def batches_by_size(src, dst, batch_size=None, max_tokens=None,
                     max_positions=(1024, 1024), ignore_invalid_inputs=False):
-    """Returns batches of indices sorted by size. Sequences of different lengths
-    are not allowed in the same batch."""
-    assert isinstance(src, IndexedDataset)
-    assert dst is None or isinstance(dst, IndexedDataset)
+    """Returns batches of indices sorted by size. Sequences with different
+    source lengths are not allowed in the same batch."""
+    assert isinstance(src, IndexedDataset) and isinstance(dst, IndexedDataset)
     if max_tokens is None:
         max_tokens = float('Inf')
-    sizes = src.sizes
-    indices = np.argsort(sizes, kind='mergesort')
-    if dst is not None:
-        sizes = np.maximum(sizes, dst.sizes)
+    indices = np.argsort(src.sizes, kind='mergesort')
 
     batch = []
@@ -238,7 +238,7 @@ def batches_by_size(src, batch_size=None, max_tokens=None, dst=None,
             return False
         if len(batch) == batch_size:
             return True
-        if sizes[batch[0]] != sizes[next_idx]:
+        if src.sizes[batch[0]] != src.sizes[next_idx]:
             return True
         if num_tokens >= max_tokens:
             return True
@@ -247,21 +247,20 @@ def batches_by_size(src, batch_size=None, max_tokens=None, dst=None,
     cur_max_size = 0
     ignored = []
     for idx in indices:
-        if not _valid_size(src.sizes[idx],
-                           None if dst is None else dst.sizes[idx],
-                           max_positions):
+        if not _valid_size(src.sizes[idx], dst.sizes[idx], max_positions):
             if ignore_invalid_inputs:
                 ignored.append(idx)
                 continue
-            raise Exception("Unable to handle input id {} of size {} / {}.".format(
-                idx, src.sizes[idx], "none" if dst is None else dst.sizes[idx]))
+            raise Exception(
+                "Unable to handle input id {} of size {} / {}.".format(
+                    idx, src.sizes[idx], dst.sizes[idx]))
         if yield_batch(idx, cur_max_size * (len(batch) + 1)):
             yield batch
             batch = []
             cur_max_size = 0
         batch.append(idx)
-        cur_max_size = max(cur_max_size, sizes[idx])
+        cur_max_size = max(cur_max_size, src.sizes[idx], dst.sizes[idx])
 
     if len(ignored) > 0:
         print("Warning! {} samples are either too short or too long "
...
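
In outline, the tightened batches_by_size sorts example indices by source length and fills batches greedily, closing a batch when the sentence count, the token budget, or a change in source length is hit; source and target lengths both count toward the token budget now that dst is required. The function below is a simplified re-implementation over plain lists to show that control flow, not the fairseq function itself:

    def simple_batches_by_size(src_sizes, dst_sizes, batch_size=None, max_tokens=None):
        """Yield lists of indices; every batch holds a single source length."""
        if max_tokens is None:
            max_tokens = float('inf')
        indices = sorted(range(len(src_sizes)), key=lambda i: src_sizes[i])
        batch, cur_max = [], 0
        for idx in indices:
            sample_max = max(src_sizes[idx], dst_sizes[idx])
            batch_is_full = len(batch) > 0 and (
                len(batch) == batch_size
                or src_sizes[batch[0]] != src_sizes[idx]
                or max(cur_max, sample_max) * (len(batch) + 1) > max_tokens)
            if batch_is_full:
                yield batch
                batch, cur_max = [], 0
            batch.append(idx)
            cur_max = max(cur_max, sample_max)
        if batch:
            yield batch


    # three length-5 pairs and two length-7 pairs with a 12-token budget per batch
    print(list(simple_batches_by_size([5, 5, 5, 7, 7], [5, 6, 5, 7, 8], max_tokens=12)))
    # -> [[0, 1], [2], [3], [4]]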
@@ -34,7 +34,7 @@ def main():
     use_cuda = torch.cuda.is_available() and not args.cpu
 
     # Load dataset
-    dataset = data.load_with_check(args.data, [args.gen_subset], args.source_lang, args.target_lang)
+    dataset = data.load_dataset(args.data, [args.gen_subset], args.source_lang, args.target_lang)
     if args.source_lang is None or args.target_lang is None:
         # record inferred languages in args
         args.source_lang, args.target_lang = dataset.src, dataset.dst
@@ -67,8 +67,8 @@ def main():
     # Generate and compute BLEU score
     scorer = bleu.Scorer(dataset.dst_dict.pad(), dataset.dst_dict.eos(), dataset.dst_dict.unk())
     max_positions = min(model.max_encoder_positions() for model in models)
-    itr = dataset.dataloader(args.gen_subset, batch_size=args.batch_size,
-                             max_positions=max_positions,
-                             skip_invalid_size_inputs_valid_test=args.skip_invalid_size_inputs_valid_test)
+    itr = dataset.eval_dataloader(
+        args.gen_subset, batch_size=args.batch_size, max_positions=max_positions,
+        skip_invalid_size_inputs_valid_test=args.skip_invalid_size_inputs_valid_test)
     num_sentences = 0
     with progress_bar(itr, smoothing=0, leave=False) as t:
...
@@ -26,19 +26,17 @@ def main():
     use_cuda = torch.cuda.is_available() and not args.cpu
 
-    # Load dataset
-    # TODO: load only dictionaries
-    dataset = data.load_with_check(args.data, ['test'], args.source_lang, args.target_lang)
+    # Load dictionaries
     if args.source_lang is None or args.target_lang is None:
-        # record inferred languages in args
-        args.source_lang, args.target_lang = dataset.src, dataset.dst
+        args.source_lang, args.target_lang = data.infer_language_pair(args.data, ['test'])
+    src_dict, dst_dict = data.load_dictionaries(args.data, args.source_lang, args.target_lang)
 
     # Load ensemble
     print('| loading model(s) from {}'.format(', '.join(args.path)))
-    models = utils.load_ensemble_for_inference(args.path, dataset.src_dict, dataset.dst_dict)
+    models = utils.load_ensemble_for_inference(args.path, src_dict, dst_dict)
 
-    print('| [{}] dictionary: {} types'.format(dataset.src, len(dataset.src_dict)))
-    print('| [{}] dictionary: {} types'.format(dataset.dst, len(dataset.dst_dict)))
+    print('| [{}] dictionary: {} types'.format(args.source_lang, len(src_dict)))
+    print('| [{}] dictionary: {} types'.format(args.target_lang, len(dst_dict)))
 
     # Optimize ensemble for generation
     for model in models:
@@ -60,7 +58,7 @@ def main():
     print('Type the input sentence and press return:')
     for src_str in sys.stdin:
         src_str = src_str.strip()
-        src_tokens = tokenizer.Tokenizer.tokenize(src_str, dataset.src_dict, add_if_not_exist=False).long()
+        src_tokens = tokenizer.Tokenizer.tokenize(src_str, src_dict, add_if_not_exist=False).long()
         if use_cuda:
             src_tokens = src_tokens.cuda()
         translations = translator.generate(Variable(src_tokens.view(1, -1)))
@@ -74,7 +72,7 @@ def main():
                 src_str=src_str,
                 alignment=hypo['alignment'].int().cpu(),
                 align_dict=align_dict,
-                dst_dict=dataset.dst_dict,
+                dst_dict=dst_dict,
                 remove_bpe=args.remove_bpe)
             print('A\t{}'.format(' '.join(map(str, alignment))))
             print('H\t{}\t{}'.format(hypo['score'], hypo_str))
...
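
With the full dataset object gone from the interactive script, only the two dictionaries are needed to turn raw text into model input. A small sketch of that preprocessing step, using Tokenizer.tokenize with the same arguments as in the diff; the data directory, language pair and sentence are illustrative assumptions:

    from fairseq import data, tokenizer

    src_dict, dst_dict = data.load_dictionaries('data-bin/iwslt14.de-en', 'de', 'en')

    # raw string -> LongTensor of token indices, without growing the vocabulary
    src_str = 'ein kleiner Test .'
    src_tokens = tokenizer.Tokenizer.tokenize(
        src_str, src_dict, add_if_not_exist=False).long()

    # the translator then sees a batch of size one, as in the script above
    batch = src_tokens.view(1, -1)
    print(batch.size())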
@@ -34,7 +34,6 @@ def main():
     options.add_model_args(parser)
 
     args = utils.parse_args_and_arch(parser)
-    print(args)
 
     if args.no_progress_bar:
         progress_bar.enabled = False
@@ -45,11 +44,12 @@ def main():
     torch.manual_seed(args.seed)
 
     # Load dataset
-    dataset = data.load_with_check(args.data, ['train', 'valid'], args.source_lang, args.target_lang)
+    dataset = data.load_dataset(args.data, ['train', 'valid'], args.source_lang, args.target_lang)
     if args.source_lang is None or args.target_lang is None:
         # record inferred languages in args, so that it's saved in checkpoints
         args.source_lang, args.target_lang = dataset.src, dataset.dst
+    print(args)
     print('| [{}] dictionary: {} types'.format(dataset.src, len(dataset.src_dict)))
     print('| [{}] dictionary: {} types'.format(dataset.dst, len(dataset.dst_dict)))
     for split in ['train', 'valid']:
@@ -129,11 +129,10 @@ def train(args, epoch, batch_offset, trainer, dataset, max_positions, num_gpus):
     torch.manual_seed(seed)
     trainer.set_seed(seed)
 
-    itr = dataset.dataloader(
+    itr = dataset.train_dataloader(
         args.train_subset, num_workers=args.workers, max_tokens=args.max_tokens,
-        seed=seed, epoch=epoch, max_positions=max_positions,
+        max_positions=max_positions, seed=seed, epoch=epoch,
         sample_without_replacement=args.sample_without_replacement,
-        skip_invalid_size_inputs_valid_test=args.skip_invalid_size_inputs_valid_test,
         sort_by_source_size=(epoch <= args.curriculum))
     loss_meter = AverageMeter()
     bsz_meter = AverageMeter()    # sentences per batch
@@ -216,7 +215,7 @@ def save_checkpoint(trainer, args, epoch, batch_offset, val_loss):
 
 def validate(args, epoch, trainer, dataset, max_positions, subset, ngpus):
     """Evaluate the model on the validation set and return the average loss."""
-    itr = dataset.dataloader(
+    itr = dataset.eval_dataloader(
         subset, batch_size=None, max_tokens=args.max_tokens, max_positions=max_positions,
         skip_invalid_size_inputs_valid_test=args.skip_invalid_size_inputs_valid_test)
     loss_meter = AverageMeter()
...
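
Putting the renamed methods together, a training script now draws reshuffled, token-capped batches from train_dataloader every epoch and deterministic batches from eval_dataloader for validation. The loop below is a condensed sketch of that flow, not the actual train.py; the data path, token budget and epoch count are illustrative assumptions.

    from fairseq import data

    dataset = data.load_dataset('data-bin/iwslt14.de-en', ['train', 'valid'])

    for epoch in range(1, 11):
        # batch order changes each epoch, reproducibly, via the seed and epoch arguments
        train_itr = dataset.train_dataloader(
            'train', max_tokens=4000, seed=42, epoch=epoch,
            max_positions=(1024, 1024))
        for sample in train_itr:
            pass  # forward/backward/optimizer step would go here

        # validation batches are built once, in length-sorted order
        valid_itr = dataset.eval_dataloader(
            'valid', max_tokens=4000, max_positions=(1024, 1024))
        for sample in valid_itr:
            pass  # accumulate validation loss here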