Commit 20dfba73 authored by Naman Goyal's avatar Naman Goyal Committed by Facebook Github Bot
Browse files

fixed numpy based size filtering (#854)

Summary:
This bug got introduced in my [commit](https://github.com/fairinternal/fairseq-py/commit/9624f9651478bcb88022decf7e1b0685b410133b) for fast numpy based size filtering.
Pull Request resolved: https://github.com/fairinternal/fairseq-py/pull/854

Differential Revision: D17150350

fbshipit-source-id: cb564119543e116d6a17784d1c22e9bce7059a0c
parent 8d4588b1
...@@ -171,11 +171,11 @@ def filter_by_size(indices, dataset, max_positions, raise_exception=False): ...@@ -171,11 +171,11 @@ def filter_by_size(indices, dataset, max_positions, raise_exception=False):
""" """
if isinstance(max_positions, float) or isinstance(max_positions, int): if isinstance(max_positions, float) or isinstance(max_positions, int):
if hasattr(dataset, 'sizes') and isinstance(dataset.sizes, np.ndarray): if hasattr(dataset, 'sizes') and isinstance(dataset.sizes, np.ndarray):
ignored = indices[dataset.sizes > max_positions].tolist() ignored = indices[dataset.sizes[indices] > max_positions].tolist()
indices = indices[dataset.sizes <= max_positions] indices = indices[dataset.sizes[indices] <= max_positions]
elif hasattr(dataset, 'sizes') and isinstance(dataset.sizes, list) and len(dataset.sizes) == 1: elif hasattr(dataset, 'sizes') and isinstance(dataset.sizes, list) and len(dataset.sizes) == 1:
ignored = indices[dataset.sizes[0] > max_positions].tolist() ignored = indices[dataset.sizes[0][indices] > max_positions].tolist()
indices = indices[dataset.sizes[0] <= max_positions] indices = indices[dataset.sizes[0][indices] <= max_positions]
else: else:
indices, ignored = _filter_by_size_dynamic(indices, dataset.size, max_positions) indices, ignored = _filter_by_size_dynamic(indices, dataset.size, max_positions)
else: else:
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment