Commit 73a87327 authored by Myle Ott's avatar Myle Ott
Browse files

Fix batching during generation

parent 47b3b81c
......@@ -143,7 +143,10 @@ class LanguageDatasets(object):
with numpy_seed(seed):
batches = uneven_batches_by_size(
dataset.src, dataset.dst, max_tokens=max_tokens,
max_sentences=max_sentences, max_positions=max_positions)
max_sentences=max_sentences, max_positions=max_positions,
# FP16: during training keep the batch size a multiple of 8
required_batch_size_multiple=8,
)
frozen_batches = tuple(batches) # freeze
def dataloader(b):
......@@ -310,8 +313,10 @@ def _valid_size(src_size, dst_size, max_positions):
def _make_batches(src, dst, indices, max_tokens, max_sentences, max_positions,
ignore_invalid_inputs=False, allow_different_src_lens=False):
ignore_invalid_inputs=False, allow_different_src_lens=False,
required_batch_size_multiple=1):
batch = []
mult = required_batch_size_multiple
def yield_batch(next_idx, num_tokens):
if len(batch) == 0:
......@@ -326,6 +331,7 @@ def _make_batches(src, dst, indices, max_tokens, max_sentences, max_positions,
return False
sample_len = 0
sample_lens = []
ignored = []
for idx in map(int, indices):
src_size = src.sizes[idx]
......@@ -339,15 +345,15 @@ def _make_batches(src, dst, indices, max_tokens, max_sentences, max_positions,
" Skip this example with --skip-invalid-size-inputs-valid-test"
).format(idx, src_size, dst_size, max_positions))
sample_len = max(sample_len, src_size, dst_size)
sample_lens.append(max(src_size, dst_size))
sample_len = max(sample_len, sample_lens[-1])
num_tokens = (len(batch) + 1) * sample_len
while yield_batch(idx, num_tokens):
mod8_len = max(8 * (len(batch) // 8), len(batch) % 8)
if yield_batch(idx, num_tokens):
mod8_len = max(mult * (len(batch) // mult), len(batch) % mult)
yield batch[:mod8_len]
batch = batch[mod8_len:]
sample_len = max([max(src.sizes[id], dst.sizes[id]) for id in batch]) if len(batch) > 0 else 0
sample_len = max(sample_len, src_size, dst_size)
num_tokens = (len(batch) + 1) * sample_len
sample_lens = sample_lens[mod8_len:]
sample_len = max(sample_lens) if len(sample_lens) > 0 else 0
batch.append(idx)
......@@ -361,7 +367,7 @@ def _make_batches(src, dst, indices, max_tokens, max_sentences, max_positions,
def batches_by_size(src, dst, max_tokens=None, max_sentences=None,
max_positions=(1024, 1024), ignore_invalid_inputs=False,
descending=False):
descending=False, required_batch_size_multiple=1):
"""Returns batches of indices sorted by size. Sequences with different
source lengths are not allowed in the same batch."""
assert isinstance(src, IndexedDataset) and (dst is None or isinstance(dst, IndexedDataset))
......@@ -374,10 +380,14 @@ def batches_by_size(src, dst, max_tokens=None, max_sentences=None,
indices = np.flip(indices, 0)
return list(_make_batches(
src, dst, indices, max_tokens, max_sentences, max_positions,
ignore_invalid_inputs, allow_different_src_lens=False))
ignore_invalid_inputs, allow_different_src_lens=False,
required_batch_size_multiple=required_batch_size_multiple,
))
def uneven_batches_by_size(src, dst, max_tokens=None, max_sentences=None, max_positions=(1024, 1024)):
def uneven_batches_by_size(src, dst, max_tokens=None, max_sentences=None,
max_positions=(1024, 1024),
required_batch_size_multiple=1):
"""Returns batches of indices bucketed by size. Batches may contain
sequences of different lengths."""
assert isinstance(src, IndexedDataset) and isinstance(dst, IndexedDataset)
......@@ -394,7 +404,9 @@ def uneven_batches_by_size(src, dst, max_tokens=None, max_sentences=None, max_po
batches = list(_make_batches(
src, dst, indices, max_tokens, max_sentences, max_positions,
ignore_invalid_inputs=True, allow_different_src_lens=True))
ignore_invalid_inputs=True, allow_different_src_lens=True,
required_batch_size_multiple=required_batch_size_multiple,
))
return batches
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment