"src/vscode:/vscode.git/clone" did not exist on "919e27d35751d9b87f7fe41bef60c5d5f44e53fe"
Commit 73a87327 authored by Myle Ott's avatar Myle Ott
Browse files

Fix batching during generation

parent 47b3b81c
...@@ -143,7 +143,10 @@ class LanguageDatasets(object): ...@@ -143,7 +143,10 @@ class LanguageDatasets(object):
with numpy_seed(seed): with numpy_seed(seed):
batches = uneven_batches_by_size( batches = uneven_batches_by_size(
dataset.src, dataset.dst, max_tokens=max_tokens, dataset.src, dataset.dst, max_tokens=max_tokens,
max_sentences=max_sentences, max_positions=max_positions) max_sentences=max_sentences, max_positions=max_positions,
# FP16: during training keep the batch size a multiple of 8
required_batch_size_multiple=8,
)
frozen_batches = tuple(batches) # freeze frozen_batches = tuple(batches) # freeze
def dataloader(b): def dataloader(b):
...@@ -310,8 +313,10 @@ def _valid_size(src_size, dst_size, max_positions): ...@@ -310,8 +313,10 @@ def _valid_size(src_size, dst_size, max_positions):
def _make_batches(src, dst, indices, max_tokens, max_sentences, max_positions, def _make_batches(src, dst, indices, max_tokens, max_sentences, max_positions,
ignore_invalid_inputs=False, allow_different_src_lens=False): ignore_invalid_inputs=False, allow_different_src_lens=False,
required_batch_size_multiple=1):
batch = [] batch = []
mult = required_batch_size_multiple
def yield_batch(next_idx, num_tokens): def yield_batch(next_idx, num_tokens):
if len(batch) == 0: if len(batch) == 0:
...@@ -326,6 +331,7 @@ def _make_batches(src, dst, indices, max_tokens, max_sentences, max_positions, ...@@ -326,6 +331,7 @@ def _make_batches(src, dst, indices, max_tokens, max_sentences, max_positions,
return False return False
sample_len = 0 sample_len = 0
sample_lens = []
ignored = [] ignored = []
for idx in map(int, indices): for idx in map(int, indices):
src_size = src.sizes[idx] src_size = src.sizes[idx]
...@@ -339,15 +345,15 @@ def _make_batches(src, dst, indices, max_tokens, max_sentences, max_positions, ...@@ -339,15 +345,15 @@ def _make_batches(src, dst, indices, max_tokens, max_sentences, max_positions,
" Skip this example with --skip-invalid-size-inputs-valid-test" " Skip this example with --skip-invalid-size-inputs-valid-test"
).format(idx, src_size, dst_size, max_positions)) ).format(idx, src_size, dst_size, max_positions))
sample_len = max(sample_len, src_size, dst_size) sample_lens.append(max(src_size, dst_size))
sample_len = max(sample_len, sample_lens[-1])
num_tokens = (len(batch) + 1) * sample_len num_tokens = (len(batch) + 1) * sample_len
while yield_batch(idx, num_tokens): if yield_batch(idx, num_tokens):
mod8_len = max(8 * (len(batch) // 8), len(batch) % 8) mod8_len = max(mult * (len(batch) // mult), len(batch) % mult)
yield batch[:mod8_len] yield batch[:mod8_len]
batch = batch[mod8_len:] batch = batch[mod8_len:]
sample_len = max([max(src.sizes[id], dst.sizes[id]) for id in batch]) if len(batch) > 0 else 0 sample_lens = sample_lens[mod8_len:]
sample_len = max(sample_len, src_size, dst_size) sample_len = max(sample_lens) if len(sample_lens) > 0 else 0
num_tokens = (len(batch) + 1) * sample_len
batch.append(idx) batch.append(idx)
...@@ -361,7 +367,7 @@ def _make_batches(src, dst, indices, max_tokens, max_sentences, max_positions, ...@@ -361,7 +367,7 @@ def _make_batches(src, dst, indices, max_tokens, max_sentences, max_positions,
def batches_by_size(src, dst, max_tokens=None, max_sentences=None, def batches_by_size(src, dst, max_tokens=None, max_sentences=None,
max_positions=(1024, 1024), ignore_invalid_inputs=False, max_positions=(1024, 1024), ignore_invalid_inputs=False,
descending=False): descending=False, required_batch_size_multiple=1):
"""Returns batches of indices sorted by size. Sequences with different """Returns batches of indices sorted by size. Sequences with different
source lengths are not allowed in the same batch.""" source lengths are not allowed in the same batch."""
assert isinstance(src, IndexedDataset) and (dst is None or isinstance(dst, IndexedDataset)) assert isinstance(src, IndexedDataset) and (dst is None or isinstance(dst, IndexedDataset))
...@@ -374,10 +380,14 @@ def batches_by_size(src, dst, max_tokens=None, max_sentences=None, ...@@ -374,10 +380,14 @@ def batches_by_size(src, dst, max_tokens=None, max_sentences=None,
indices = np.flip(indices, 0) indices = np.flip(indices, 0)
return list(_make_batches( return list(_make_batches(
src, dst, indices, max_tokens, max_sentences, max_positions, src, dst, indices, max_tokens, max_sentences, max_positions,
ignore_invalid_inputs, allow_different_src_lens=False)) ignore_invalid_inputs, allow_different_src_lens=False,
required_batch_size_multiple=required_batch_size_multiple,
))
def uneven_batches_by_size(src, dst, max_tokens=None, max_sentences=None, max_positions=(1024, 1024)): def uneven_batches_by_size(src, dst, max_tokens=None, max_sentences=None,
max_positions=(1024, 1024),
required_batch_size_multiple=1):
"""Returns batches of indices bucketed by size. Batches may contain """Returns batches of indices bucketed by size. Batches may contain
sequences of different lengths.""" sequences of different lengths."""
assert isinstance(src, IndexedDataset) and isinstance(dst, IndexedDataset) assert isinstance(src, IndexedDataset) and isinstance(dst, IndexedDataset)
...@@ -394,7 +404,9 @@ def uneven_batches_by_size(src, dst, max_tokens=None, max_sentences=None, max_po ...@@ -394,7 +404,9 @@ def uneven_batches_by_size(src, dst, max_tokens=None, max_sentences=None, max_po
batches = list(_make_batches( batches = list(_make_batches(
src, dst, indices, max_tokens, max_sentences, max_positions, src, dst, indices, max_tokens, max_sentences, max_positions,
ignore_invalid_inputs=True, allow_different_src_lens=True)) ignore_invalid_inputs=True, allow_different_src_lens=True,
required_batch_size_multiple=required_batch_size_multiple,
))
return batches return batches
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment