Commit 37b9c235 authored by frankang's avatar frankang Committed by Facebook Github Bot
Browse files

Fix iteration bug in GroupedIterator. Correct sent size filter. (#455)

Summary:
Fix iterating from the beginning bug when initializing the GroupedIterator. (https://github.com/pytorch/fairseq/issues/441)
 Correct filter criterion for dict type sentence size. (https://github.com/pytorch/fairseq/issues/451)
Pull Request resolved: https://github.com/pytorch/fairseq/pull/455

Differential Revision: D13725646

Pulled By: myleott

fbshipit-source-id: e698fa6f9b45460f95a75c9e9976a3aa3b6aa523
parent d259ffa9
...@@ -92,8 +92,9 @@ def filter_by_size(indices, size_fn, max_positions, raise_exception=False): ...@@ -92,8 +92,9 @@ def filter_by_size(indices, size_fn, max_positions, raise_exception=False):
assert isinstance(idx_size, dict) assert isinstance(idx_size, dict)
intersect_keys = set(max_positions.keys()) & set(idx_size.keys()) intersect_keys = set(max_positions.keys()) & set(idx_size.keys())
return all( return all(
idx_size[key] <= max_positions[key] for key in intersect_keys all(a is None or b is None or a <= b
) for a, b in zip(idx_size[key], max_positions[key]))
for key in intersect_keys)
else: else:
return all(a is None or b is None or a <= b return all(a is None or b is None or a <= b
for a, b in zip(size_fn(idx), max_positions)) for a, b in zip(size_fn(idx), max_positions))
...@@ -155,6 +156,7 @@ def batch_by_size( ...@@ -155,6 +156,7 @@ def batch_by_size(
for idx in indices: for idx in indices:
sample_lens.append(num_tokens_fn(idx)) sample_lens.append(num_tokens_fn(idx))
sample_len = max(sample_len, sample_lens[-1]) sample_len = max(sample_len, sample_lens[-1])
assert sample_len <= max_tokens, f"sentence at index {idx} exceeds max_tokens limit!"
num_tokens = (len(batch) + 1) * sample_len num_tokens = (len(batch) + 1) * sample_len
if is_batch_full(num_tokens): if is_batch_full(num_tokens):
mod_len = max( mod_len = max(
......
...@@ -197,7 +197,7 @@ class GroupedIterator(object): ...@@ -197,7 +197,7 @@ class GroupedIterator(object):
def __init__(self, iterable, chunk_size): def __init__(self, iterable, chunk_size):
self._len = int(math.ceil(len(iterable) / float(chunk_size))) self._len = int(math.ceil(len(iterable) / float(chunk_size)))
self.itr = iter(iterable) self.itr = iterable
self.chunk_size = chunk_size self.chunk_size = chunk_size
def __len__(self): def __len__(self):
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment