Commit 5d9392df authored by Myle Ott's avatar Myle Ott Committed by Facebook Github Bot
Browse files

Better handling of unspecified max_tokens and max_sentences (fixes #1427)

Summary: Pull Request resolved: https://github.com/fairinternal/fairseq-py/pull/928

Differential Revision: D18691525

Pulled By: myleott

fbshipit-source-id: e787c17434d4cb0c4621e9858e0ebec4f9951630
parent cb6c67bc
...@@ -221,8 +221,8 @@ def batch_by_size( ...@@ -221,8 +221,8 @@ def batch_by_size(
'or `python setup.py build_ext --inplace`' 'or `python setup.py build_ext --inplace`'
) )
max_tokens = max_tokens if max_tokens is not None else sys.maxsize max_tokens = max_tokens if max_tokens is not None else -1
max_sentences = max_sentences if max_sentences is not None else sys.maxsize max_sentences = max_sentences if max_sentences is not None else -1
bsz_mult = required_batch_size_multiple bsz_mult = required_batch_size_multiple
if isinstance(indices, types.GeneratorType): if isinstance(indices, types.GeneratorType):
......
...@@ -16,9 +16,9 @@ ctypedef np.int64_t DTYPE_t ...@@ -16,9 +16,9 @@ ctypedef np.int64_t DTYPE_t
cdef _is_batch_full(list batch, long num_tokens, long max_tokens, long max_sentences): cdef _is_batch_full(list batch, long num_tokens, long max_tokens, long max_sentences):
if len(batch) == 0: if len(batch) == 0:
return 0 return 0
if len(batch) == max_sentences: if max_sentences > 0 and len(batch) == max_sentences:
return 1 return 1
if num_tokens > max_tokens: if max_tokens > 0 and num_tokens > max_tokens:
return 1 return 1
return 0 return 0
...@@ -47,7 +47,7 @@ cpdef list batch_by_size_fast( ...@@ -47,7 +47,7 @@ cpdef list batch_by_size_fast(
sample_lens.append(num_tokens) sample_lens.append(num_tokens)
sample_len = max(sample_len, num_tokens) sample_len = max(sample_len, num_tokens)
assert sample_len <= max_tokens, ( assert max_tokens <= 0 or sample_len <= max_tokens, (
"sentence at index {} of size {} exceeds max_tokens " "sentence at index {} of size {} exceeds max_tokens "
"limit of {}!".format(idx, sample_len, max_tokens) "limit of {}!".format(idx, sample_len, max_tokens)
) )
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment