"docs/git@developer.sourcefind.cn:OpenDAS/pytorch3d.git" did not exist on "ad8907d3738fbf4c80aa269954d1d8ba4f307530"
Commit 7d44181d authored by Myle Ott's avatar Myle Ott
Browse files

Loop over evaluation dataloader in descending order

parent f442f896
@@ -113,12 +113,14 @@ class LanguageDatasets(object):
     def eval_dataloader(self, split, num_workers=0, max_tokens=None,
                         max_sentences=None, max_positions=(1024, 1024),
-                        skip_invalid_size_inputs_valid_test=False):
+                        skip_invalid_size_inputs_valid_test=False,
+                        descending=False):
         dataset = self.splits[split]
         batch_sampler = list(batches_by_size(
             dataset.src, dataset.dst, max_tokens, max_sentences,
             max_positions=max_positions,
-            ignore_invalid_inputs=skip_invalid_size_inputs_valid_test))
+            ignore_invalid_inputs=skip_invalid_size_inputs_valid_test,
+            descending=descending))
         return torch.utils.data.DataLoader(
             dataset, num_workers=num_workers, collate_fn=dataset.collater,
             batch_sampler=batch_sampler)
@@ -264,7 +266,8 @@ def _make_batches(src, dst, indices, max_tokens, max_sentences, max_positions,
 def batches_by_size(src, dst, max_tokens=None, max_sentences=None,
-                    max_positions=(1024, 1024), ignore_invalid_inputs=False):
+                    max_positions=(1024, 1024), ignore_invalid_inputs=False,
+                    descending=False):
     """Returns batches of indices sorted by size. Sequences with different
     source lengths are not allowed in the same batch."""
     assert isinstance(src, IndexedDataset) and isinstance(dst, IndexedDataset)
@@ -273,6 +276,8 @@ def batches_by_size(src, dst, max_tokens=None, max_sentences=None,
     if max_sentences is None:
         max_sentences = float('Inf')
     indices = np.argsort(src.sizes, kind='mergesort')
+    if descending:
+        indices = np.flip(indices, 0)
     return _make_batches(
         src, dst, indices, max_tokens, max_sentences, max_positions,
         ignore_invalid_inputs, allow_different_src_lens=False)
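For intuition, here is a minimal, self-contained sketch of the sorting change in batches_by_size above; the toy sizes array is hypothetical, standing in for src.sizes:

    import numpy as np

    # Toy source lengths; in the real code these come from src.sizes.
    sizes = np.array([5, 2, 9, 2, 7])

    # Stable mergesort, so equal-length sentences keep their relative order.
    indices = np.argsort(sizes, kind='mergesort')  # ascending: [1, 3, 0, 4, 2]

    # With descending=True, reverse along axis 0 so the longest sentences
    # (and hence the largest batches) come first.
    indices = np.flip(indices, 0)                  # descending: [2, 4, 0, 3, 1]

Note that flipping a stable ascending sort also reverses the relative order of ties, which is harmless here since tied indices refer to sentences of equal length.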
@@ -222,7 +222,9 @@ def validate(args, epoch, trainer, dataset, max_positions, subset, ngpus):
     itr = dataset.eval_dataloader(
         subset, max_tokens=args.max_tokens, max_sentences=args.max_sentences,
         max_positions=max_positions,
-        skip_invalid_size_inputs_valid_test=args.skip_invalid_size_inputs_valid_test)
+        skip_invalid_size_inputs_valid_test=args.skip_invalid_size_inputs_valid_test,
+        descending=True,  # largest batch first to warm the caching allocator
+    )
     loss_meter = AverageMeter()
     extra_meters = collections.defaultdict(lambda: AverageMeter())
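As the inline comment says, validation now sees its largest batch first. PyTorch's CUDA caching allocator grows its pool on demand and then reuses freed blocks, so putting the biggest allocation first means every later (smaller) batch can be served from already-cached blocks instead of triggering fresh cudaMalloc calls, and any out-of-memory failure surfaces on the first batch rather than partway through the loop. A rough, hypothetical illustration of the effect (requires a CUDA device; exact numbers depend on the allocator's block rounding):

    import torch

    def peak_reserved(batch_sizes):
        # Allocate a stand-in "batch" per step and report how much memory
        # the caching allocator ended up reserving at its peak.
        torch.cuda.empty_cache()
        torch.cuda.reset_peak_memory_stats()
        for n in batch_sizes:
            x = torch.zeros(n, 1024, device='cuda')
            del x
        return torch.cuda.max_memory_reserved()

    print(peak_reserved([256, 512, 1024]))  # ascending: the pool grows at each step
    print(peak_reserved([1024, 512, 256]))  # descending: grows once, then reuses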