OpenDAS / Fairseq · Commits

Commit 73a87327, authored Apr 07, 2018 by Myle Ott
Fix batching during generation
parent 47b3b81c

Showing 1 changed file with 24 additions and 12 deletions.

fairseq/data.py  (+24, -12)
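In outline: the commit threads a new `required_batch_size_multiple` keyword (default 1) through `_make_batches`, `batches_by_size`, and `uneven_batches_by_size`, and only the training dataloader passes 8 (for FP16-friendly batch shapes), so the multiple-of-8 rounding no longer applies when batching for generation. Inside `_make_batches`, per-sentence lengths are now tracked in a `sample_lens` list and the `while yield_batch(...)` loop becomes a single `if`, keeping the running `sample_len` consistent with the sentences actually carried over into the next batch.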
@@ -143,7 +143,10 @@ class LanguageDatasets(object):
         with numpy_seed(seed):
             batches = uneven_batches_by_size(
                 dataset.src, dataset.dst, max_tokens=max_tokens,
-                max_sentences=max_sentences, max_positions=max_positions)
+                max_sentences=max_sentences, max_positions=max_positions,
+                # FP16: during training keep the batch size a multiple of 8
+                required_batch_size_multiple=8,
+            )
         frozen_batches = tuple(batches)  # freeze
 
         def dataloader(b):
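For context, the `with numpy_seed(seed):` block above makes the shuffling of training batches reproducible. A minimal sketch of what such a seeding context manager usually looks like (illustrative only, not copied from this commit):

```python
import contextlib

import numpy as np


@contextlib.contextmanager
def numpy_seed(seed):
    """Temporarily seed the NumPy RNG, restoring the previous state on exit."""
    if seed is None:
        yield
        return
    state = np.random.get_state()
    np.random.seed(seed)
    try:
        yield
    finally:
        np.random.set_state(state)
```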
@@ -310,8 +313,10 @@ def _valid_size(src_size, dst_size, max_positions):
 
 def _make_batches(src, dst, indices, max_tokens, max_sentences, max_positions,
-                  ignore_invalid_inputs=False, allow_different_src_lens=False):
+                  ignore_invalid_inputs=False, allow_different_src_lens=False,
+                  required_batch_size_multiple=1):
     batch = []
+    mult = required_batch_size_multiple
 
     def yield_batch(next_idx, num_tokens):
         if len(batch) == 0:
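`yield_batch` is defined inside `_make_batches` and its body is collapsed in this diff; it answers whether the batch built so far should be emitted before `next_idx` is added. A rough sketch of the kind of checks involved, assuming the names from the enclosing `_make_batches` scope (illustrative, not the file's exact code):

```python
def yield_batch(next_idx, num_tokens):
    # An empty batch is never emitted.
    if len(batch) == 0:
        return False
    # Emit if the sentence cap or the padded-token budget would be exceeded,
    # or if mixed source lengths are disallowed and the next sentence differs.
    if max_sentences is not None and len(batch) == max_sentences:
        return True
    if max_tokens is not None and num_tokens > max_tokens:
        return True
    if not allow_different_src_lens and src.sizes[batch[0]] != src.sizes[next_idx]:
        return True
    return False
```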
@@ -326,6 +331,7 @@ def _make_batches(src, dst, indices, max_tokens, max_sentences, max_positions,
             return False
 
     sample_len = 0
+    sample_lens = []
     ignored = []
     for idx in map(int, indices):
         src_size = src.sizes[idx]
@@ -339,15 +345,15 @@ def _make_batches(src, dst, indices, max_tokens, max_sentences, max_positions,
                 " Skip this example with --skip-invalid-size-inputs-valid-test"
             ).format(idx, src_size, dst_size, max_positions))
 
-        sample_len = max(sample_len, src_size, dst_size)
+        sample_lens.append(max(src_size, dst_size))
+        sample_len = max(sample_len, sample_lens[-1])
         num_tokens = (len(batch) + 1) * sample_len
-        while yield_batch(idx, num_tokens):
-            mod8_len = max(8 * (len(batch) // 8), len(batch) % 8)
+        if yield_batch(idx, num_tokens):
+            mod8_len = max(mult * (len(batch) // mult), len(batch) % mult)
             yield batch[:mod8_len]
             batch = batch[mod8_len:]
-            sample_len = max([max(src.sizes[id], dst.sizes[id]) for id in batch]) if len(batch) > 0 else 0
-            sample_len = max(sample_len, src_size, dst_size)
-            num_tokens = (len(batch) + 1) * sample_len
+            sample_lens = sample_lens[mod8_len:]
+            sample_len = max(sample_lens) if len(sample_lens) > 0 else 0
 
         batch.append(idx)
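The `mod8_len` expression is the heart of the multiple-of-`mult` behaviour: when `yield_batch` fires, the candidate batch is trimmed down to the largest multiple of `mult` (or, if the batch is already smaller than `mult`, kept whole), the trimmed tail is carried over, and `sample_lens` is sliced in step so the carried-over sentences keep the lengths they were admitted with; the removed lines instead recomputed `sample_len` from the surviving indices inside a `while` loop. A small standalone illustration of the trimming rule (the helper name is hypothetical, not part of the file):

```python
def trim_to_multiple(batch, mult):
    """Split a candidate batch into (emit, carry) using the same rule as
    mod8_len above: keep the largest multiple of `mult`, unless the batch
    is smaller than `mult`, in which case keep it all."""
    keep = max(mult * (len(batch) // mult), len(batch) % mult)
    return batch[:keep], batch[keep:]


# 13 candidate sentences with mult=8: emit 8, carry 5 into the next batch.
assert trim_to_multiple(list(range(13)), 8) == (list(range(8)), list(range(8, 13)))
# 5 candidate sentences with mult=8: below one full multiple, so emit all 5.
assert trim_to_multiple(list(range(5)), 8) == (list(range(5)), [])
# mult=1, the default used for generation, never trims anything.
assert trim_to_multiple(list(range(13)), 1) == (list(range(13)), [])
```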
@@ -361,7 +367,7 @@ def _make_batches(src, dst, indices, max_tokens, max_sentences, max_positions,
 
 def batches_by_size(src, dst, max_tokens=None, max_sentences=None,
                     max_positions=(1024, 1024), ignore_invalid_inputs=False,
-                    descending=False):
+                    descending=False, required_batch_size_multiple=1):
     """Returns batches of indices sorted by size. Sequences with different
     source lengths are not allowed in the same batch."""
     assert isinstance(src, IndexedDataset) and (dst is None or isinstance(dst, IndexedDataset))
@@ -374,10 +380,14 @@ def batches_by_size(src, dst, max_tokens=None, max_sentences=None,
         indices = np.flip(indices, 0)
     return list(_make_batches(
         src, dst, indices, max_tokens, max_sentences, max_positions,
-        ignore_invalid_inputs, allow_different_src_lens=False))
+        ignore_invalid_inputs, allow_different_src_lens=False,
+        required_batch_size_multiple=required_batch_size_multiple,
+    ))
 
 
-def uneven_batches_by_size(src, dst, max_tokens=None, max_sentences=None, max_positions=(1024, 1024)):
+def uneven_batches_by_size(src, dst, max_tokens=None, max_sentences=None,
+                           max_positions=(1024, 1024),
+                           required_batch_size_multiple=1):
     """Returns batches of indices bucketed by size. Batches may contain
     sequences of different lengths."""
     assert isinstance(src, IndexedDataset) and isinstance(dst, IndexedDataset)
@@ -394,7 +404,9 @@ def uneven_batches_by_size(src, dst, max_tokens=None, max_sentences=None, max_po
     batches = list(_make_batches(
         src, dst, indices, max_tokens, max_sentences, max_positions,
-        ignore_invalid_inputs=True, allow_different_src_lens=True))
+        ignore_invalid_inputs=True, allow_different_src_lens=True,
+        required_batch_size_multiple=required_batch_size_multiple,
+    ))
     return batches
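With the keyword threaded through, only training opts into the multiple-of-8 rounding; generation-time callers keep untouched batch sizes by leaving the default. A hedged usage sketch, assuming `src` and `dst` are already-loaded `IndexedDataset` instances and with an illustrative token budget:

```python
from fairseq import data  # the module shown in this diff (fairseq/data.py)

# Training: round batch sizes to a multiple of 8 for FP16-friendly shapes.
train_batches = data.uneven_batches_by_size(
    src, dst, max_tokens=4000, max_positions=(1024, 1024),
    required_batch_size_multiple=8,
)

# Generation: the default of 1 leaves batch sizes alone, which is the
# behaviour the commit title ("Fix batching during generation") points at.
gen_batches = data.batches_by_size(
    src, dst, max_tokens=4000, max_positions=(1024, 1024),
    required_batch_size_multiple=1,
)
```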