Commit 7d44181d authored by Myle Ott's avatar Myle Ott
Browse files

Loop over evaluation dataloader in descending order

parent f442f896
......@@ -113,12 +113,14 @@ class LanguageDatasets(object):
def eval_dataloader(self, split, num_workers=0, max_tokens=None,
max_sentences=None, max_positions=(1024, 1024),
skip_invalid_size_inputs_valid_test=False):
skip_invalid_size_inputs_valid_test=False,
descending=False):
dataset = self.splits[split]
batch_sampler = list(batches_by_size(
dataset.src, dataset.dst, max_tokens, max_sentences,
max_positions=max_positions,
ignore_invalid_inputs=skip_invalid_size_inputs_valid_test))
ignore_invalid_inputs=skip_invalid_size_inputs_valid_test,
descending=descending))
return torch.utils.data.DataLoader(
dataset, num_workers=num_workers, collate_fn=dataset.collater,
batch_sampler=batch_sampler)
......@@ -264,7 +266,8 @@ def _make_batches(src, dst, indices, max_tokens, max_sentences, max_positions,
def batches_by_size(src, dst, max_tokens=None, max_sentences=None,
max_positions=(1024, 1024), ignore_invalid_inputs=False):
max_positions=(1024, 1024), ignore_invalid_inputs=False,
descending=False):
"""Returns batches of indices sorted by size. Sequences with different
source lengths are not allowed in the same batch."""
assert isinstance(src, IndexedDataset) and isinstance(dst, IndexedDataset)
......@@ -273,6 +276,8 @@ def batches_by_size(src, dst, max_tokens=None, max_sentences=None,
if max_sentences is None:
max_sentences = float('Inf')
indices = np.argsort(src.sizes, kind='mergesort')
if descending:
indices = np.flip(indices, 0)
return _make_batches(
src, dst, indices, max_tokens, max_sentences, max_positions,
ignore_invalid_inputs, allow_different_src_lens=False)
......
......@@ -222,7 +222,9 @@ def validate(args, epoch, trainer, dataset, max_positions, subset, ngpus):
itr = dataset.eval_dataloader(
subset, max_tokens=args.max_tokens, max_sentences=args.max_sentences,
max_positions=max_positions,
skip_invalid_size_inputs_valid_test=args.skip_invalid_size_inputs_valid_test)
skip_invalid_size_inputs_valid_test=args.skip_invalid_size_inputs_valid_test,
descending=True, # largest batch first to warm the caching allocator
)
loss_meter = AverageMeter()
extra_meters = collections.defaultdict(lambda: AverageMeter())
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment