Commit 108f94bc authored by Naman Goyal, committed by Facebook Github Bot

use numpy function for filter by size when possible (#845)

Summary:
For the general masked language modeling use case, filtering indices with a vectorized NumPy comparison instead of a per-index Python loop is much faster (`3 minutes vs 1 sec`).

Let me know what you think about it, myleott. If you don't like all the special-case checking, we can think about reorganizing the dataset APIs so that `sizes` is always a property computed in `__init__`.
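
As a rough illustration of where the speedup comes from, here is a minimal sketch (not part of the commit; the array sizes and `size_fn` are made up) comparing a per-index Python check against a single vectorized NumPy comparison over a precomputed `sizes` array:

```python
import numpy as np

# Toy setup: 10M samples with precomputed lengths.
sizes = np.random.randint(1, 1024, size=10_000_000)
indices = np.arange(len(sizes), dtype=np.int64)
max_positions = 512

def size_fn(idx):
    # Per-index lookup, as the dynamic path would do.
    return sizes[idx]

# Dynamic path: one Python-level call per index (slow at this scale).
kept_slow = np.fromiter(
    (idx for idx in indices if size_fn(idx) <= max_positions),
    dtype=np.int64, count=-1,
)

# Vectorized path: one boolean mask over the whole array.
kept_fast = indices[sizes <= max_positions]

assert np.array_equal(kept_slow, kept_fast)
```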
Pull Request resolved: https://github.com/fairinternal/fairseq-py/pull/845

Reviewed By: myleott

Differential Revision: D16993769

Pulled By: myleott

fbshipit-source-id: 161bba62af2965190c07c47e838ee967cb886e88
parent d2410c42
fairseq/data/data_utils.py
@@ -124,18 +124,7 @@ def collect_filtered(function, iterable, filtered):
             filtered.append(el)
 
 
-def filter_by_size(indices, size_fn, max_positions, raise_exception=False):
-    """
-    Filter indices based on their size.
-
-    Args:
-        indices (List[int]): ordered list of dataset indices
-        size_fn (callable): function that returns the size of a given index
-        max_positions (tuple): filter elements larger than this size.
-            Comparisons are done component-wise.
-        raise_exception (bool, optional): if ``True``, raise an exception if
-            any elements are filtered (default: False).
-    """
+def _filter_by_size_dynamic(indices, size_fn, max_positions, raise_exception=False):
     def check_size(idx):
         if isinstance(max_positions, float) or isinstance(max_positions, int):
             return size_fn(idx) <= max_positions
@@ -158,25 +147,55 @@ def filter_by_size(indices, size_fn, max_positions, raise_exception=False):
             # For MultiCorpusSampledDataset, will generalize it later
             if not isinstance(size_fn(idx), Iterable):
                 return all(size_fn(idx) <= b for b in max_positions)
-            return all(a is None or b is None or a <= b
-                       for a, b in zip(size_fn(idx), max_positions))
+            return all(
+                a is None or b is None or a <= b
+                for a, b in zip(size_fn(idx), max_positions)
+            )
 
     ignored = []
     itr = collect_filtered(check_size, indices, ignored)
-    for idx in itr:
-        if len(ignored) > 0 and raise_exception:
-            raise Exception((
-                'Size of sample #{} is invalid (={}) since max_positions={}, '
-                'skip this example with --skip-invalid-size-inputs-valid-test'
-            ).format(ignored[0], size_fn(ignored[0]), max_positions))
-        yield idx
+    indices = np.fromiter(itr, dtype=np.int64, count=-1)
+    return indices, ignored
+
+
+def filter_by_size(indices, dataset, max_positions, raise_exception=False):
+    """
+    Filter indices based on their size.
+
+    Args:
+        indices (List[int]): ordered list of dataset indices
+        dataset (FairseqDataset): fairseq dataset instance
+        max_positions (tuple): filter elements larger than this size.
+            Comparisons are done component-wise.
+        raise_exception (bool, optional): if ``True``, raise an exception if
+            any elements are filtered (default: False).
+    """
+    if isinstance(max_positions, float) or isinstance(max_positions, int):
+        if hasattr(dataset, 'sizes') and isinstance(dataset.sizes, np.ndarray):
+            ignored = indices[dataset.sizes > max_positions].tolist()
+            indices = indices[dataset.sizes <= max_positions]
+        elif (
+            hasattr(dataset, 'sizes') and
+            isinstance(dataset.sizes, list) and
+            len(dataset.sizes) == 1
+        ):
+            ignored = indices[dataset.sizes[0] > max_positions].tolist()
+            indices = indices[dataset.sizes[0] <= max_positions]
+        else:
+            indices, ignored = _filter_by_size_dynamic(indices, dataset.size, max_positions)
+    else:
+        indices, ignored = _filter_by_size_dynamic(indices, dataset.size, max_positions)
+
+    if len(ignored) > 0 and raise_exception:
+        raise Exception((
+            'Size of sample #{} is invalid (={}) since max_positions={}, '
+            'skip this example with --skip-invalid-size-inputs-valid-test'
+        ).format(ignored[0], dataset.size(ignored[0]), max_positions))
+    if len(ignored) > 0:
+        print((
+            '| WARNING: {} samples have invalid sizes and will be skipped, '
+            'max_positions={}, first few sample ids={}'
+        ).format(len(ignored), max_positions, ignored[:10]))
+    return indices
 
 
 def batch_by_size(
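To make the dispatch above concrete, here is a minimal usage sketch (the `ToyDataset` class is hypothetical; it models only the attributes `filter_by_size` actually inspects, and assumes the functions above are importable from `fairseq.data.data_utils`):

```python
import numpy as np
from fairseq.data.data_utils import filter_by_size

class ToyDataset:
    """Hypothetical dataset exposing the hooks filter_by_size checks."""
    def __init__(self, sizes):
        self.sizes = np.array(sizes, dtype=np.int64)  # ndarray => fast path

    def size(self, index):
        # Per-index size; used only by the _filter_by_size_dynamic fallback.
        return self.sizes[index]

dataset = ToyDataset([3, 10, 5, 12, 7])
indices = np.arange(5, dtype=np.int64)

# Scalar max_positions + ndarray `sizes`: vectorized fast path.
kept = filter_by_size(indices, dataset, max_positions=8)
# kept == array([0, 2, 4]); samples 1 and 3 are reported as skipped.

# Tuple max_positions: falls back to the per-index dynamic path,
# which calls dataset.size(idx) for each index.
kept = filter_by_size(indices, dataset, max_positions=(8,))
```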
fairseq/tasks/fairseq_task.py
@@ -133,9 +133,8 @@ class FairseqTask(object):
         # filter examples that are too large
         if max_positions is not None:
             indices = data_utils.filter_by_size(
-                indices, dataset.size, max_positions, raise_exception=(not ignore_invalid_inputs),
+                indices, dataset, max_positions, raise_exception=(not ignore_invalid_inputs),
             )
-        indices = np.fromiter(indices, dtype=np.int64, count=-1)
 
         # create mini-batches with given size constraints
         batch_sampler = data_utils.batch_by_size(
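
The call-site change follows directly: both branches of the new `filter_by_size` already return an `int64` ndarray, so the explicit `np.fromiter` conversion in `get_batch_iterator` becomes redundant. A small self-contained check of that property (illustrative values, not from the commit):

```python
import numpy as np

# The vectorized path returns an int64 ndarray directly, matching what
# np.fromiter used to produce at the call site.
sizes = np.array([3, 10, 5, 12, 7], dtype=np.int64)
indices = np.arange(5, dtype=np.int64)
kept = indices[sizes <= 8]

assert isinstance(kept, np.ndarray) and kept.dtype == np.int64
assert np.array_equal(kept, np.fromiter(iter(kept), dtype=np.int64, count=-1))
```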