Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
Fairseq
Commits
73a87327
"src/vscode:/vscode.git/clone" did not exist on "919e27d35751d9b87f7fe41bef60c5d5f44e53fe"
Commit
73a87327
authored
Apr 07, 2018
by
Myle Ott
Browse files
Fix batching during generation
parent
47b3b81c
Changes
1
Show whitespace changes
Inline
Side-by-side
Showing
1 changed file
with
24 additions
and
12 deletions
+24
-12
fairseq/data.py
fairseq/data.py
+24
-12
No files found.
fairseq/data.py
View file @
73a87327
...
@@ -143,7 +143,10 @@ class LanguageDatasets(object):
...
@@ -143,7 +143,10 @@ class LanguageDatasets(object):
with
numpy_seed
(
seed
):
with
numpy_seed
(
seed
):
batches
=
uneven_batches_by_size
(
batches
=
uneven_batches_by_size
(
dataset
.
src
,
dataset
.
dst
,
max_tokens
=
max_tokens
,
dataset
.
src
,
dataset
.
dst
,
max_tokens
=
max_tokens
,
max_sentences
=
max_sentences
,
max_positions
=
max_positions
)
max_sentences
=
max_sentences
,
max_positions
=
max_positions
,
# FP16: during training keep the batch size a multiple of 8
required_batch_size_multiple
=
8
,
)
frozen_batches
=
tuple
(
batches
)
# freeze
frozen_batches
=
tuple
(
batches
)
# freeze
def
dataloader
(
b
):
def
dataloader
(
b
):
...
@@ -310,8 +313,10 @@ def _valid_size(src_size, dst_size, max_positions):
...
@@ -310,8 +313,10 @@ def _valid_size(src_size, dst_size, max_positions):
def
_make_batches
(
src
,
dst
,
indices
,
max_tokens
,
max_sentences
,
max_positions
,
def
_make_batches
(
src
,
dst
,
indices
,
max_tokens
,
max_sentences
,
max_positions
,
ignore_invalid_inputs
=
False
,
allow_different_src_lens
=
False
):
ignore_invalid_inputs
=
False
,
allow_different_src_lens
=
False
,
required_batch_size_multiple
=
1
):
batch
=
[]
batch
=
[]
mult
=
required_batch_size_multiple
def
yield_batch
(
next_idx
,
num_tokens
):
def
yield_batch
(
next_idx
,
num_tokens
):
if
len
(
batch
)
==
0
:
if
len
(
batch
)
==
0
:
...
@@ -326,6 +331,7 @@ def _make_batches(src, dst, indices, max_tokens, max_sentences, max_positions,
...
@@ -326,6 +331,7 @@ def _make_batches(src, dst, indices, max_tokens, max_sentences, max_positions,
return
False
return
False
sample_len
=
0
sample_len
=
0
sample_lens
=
[]
ignored
=
[]
ignored
=
[]
for
idx
in
map
(
int
,
indices
):
for
idx
in
map
(
int
,
indices
):
src_size
=
src
.
sizes
[
idx
]
src_size
=
src
.
sizes
[
idx
]
...
@@ -339,15 +345,15 @@ def _make_batches(src, dst, indices, max_tokens, max_sentences, max_positions,
...
@@ -339,15 +345,15 @@ def _make_batches(src, dst, indices, max_tokens, max_sentences, max_positions,
" Skip this example with --skip-invalid-size-inputs-valid-test"
" Skip this example with --skip-invalid-size-inputs-valid-test"
).
format
(
idx
,
src_size
,
dst_size
,
max_positions
))
).
format
(
idx
,
src_size
,
dst_size
,
max_positions
))
sample_len
=
max
(
sample_len
,
src_size
,
dst_size
)
sample_lens
.
append
(
max
(
src_size
,
dst_size
))
sample_len
=
max
(
sample_len
,
sample_lens
[
-
1
])
num_tokens
=
(
len
(
batch
)
+
1
)
*
sample_len
num_tokens
=
(
len
(
batch
)
+
1
)
*
sample_len
while
yield_batch
(
idx
,
num_tokens
):
if
yield_batch
(
idx
,
num_tokens
):
mod8_len
=
max
(
8
*
(
len
(
batch
)
//
8
),
len
(
batch
)
%
8
)
mod8_len
=
max
(
mult
*
(
len
(
batch
)
//
mult
),
len
(
batch
)
%
mult
)
yield
batch
[:
mod8_len
]
yield
batch
[:
mod8_len
]
batch
=
batch
[
mod8_len
:]
batch
=
batch
[
mod8_len
:]
sample_len
=
max
([
max
(
src
.
sizes
[
id
],
dst
.
sizes
[
id
])
for
id
in
batch
])
if
len
(
batch
)
>
0
else
0
sample_lens
=
sample_lens
[
mod8_len
:]
sample_len
=
max
(
sample_len
,
src_size
,
dst_size
)
sample_len
=
max
(
sample_lens
)
if
len
(
sample_lens
)
>
0
else
0
num_tokens
=
(
len
(
batch
)
+
1
)
*
sample_len
batch
.
append
(
idx
)
batch
.
append
(
idx
)
...
@@ -361,7 +367,7 @@ def _make_batches(src, dst, indices, max_tokens, max_sentences, max_positions,
...
@@ -361,7 +367,7 @@ def _make_batches(src, dst, indices, max_tokens, max_sentences, max_positions,
def
batches_by_size
(
src
,
dst
,
max_tokens
=
None
,
max_sentences
=
None
,
def
batches_by_size
(
src
,
dst
,
max_tokens
=
None
,
max_sentences
=
None
,
max_positions
=
(
1024
,
1024
),
ignore_invalid_inputs
=
False
,
max_positions
=
(
1024
,
1024
),
ignore_invalid_inputs
=
False
,
descending
=
False
):
descending
=
False
,
required_batch_size_multiple
=
1
):
"""Returns batches of indices sorted by size. Sequences with different
"""Returns batches of indices sorted by size. Sequences with different
source lengths are not allowed in the same batch."""
source lengths are not allowed in the same batch."""
assert
isinstance
(
src
,
IndexedDataset
)
and
(
dst
is
None
or
isinstance
(
dst
,
IndexedDataset
))
assert
isinstance
(
src
,
IndexedDataset
)
and
(
dst
is
None
or
isinstance
(
dst
,
IndexedDataset
))
...
@@ -374,10 +380,14 @@ def batches_by_size(src, dst, max_tokens=None, max_sentences=None,
...
@@ -374,10 +380,14 @@ def batches_by_size(src, dst, max_tokens=None, max_sentences=None,
indices
=
np
.
flip
(
indices
,
0
)
indices
=
np
.
flip
(
indices
,
0
)
return
list
(
_make_batches
(
return
list
(
_make_batches
(
src
,
dst
,
indices
,
max_tokens
,
max_sentences
,
max_positions
,
src
,
dst
,
indices
,
max_tokens
,
max_sentences
,
max_positions
,
ignore_invalid_inputs
,
allow_different_src_lens
=
False
))
ignore_invalid_inputs
,
allow_different_src_lens
=
False
,
required_batch_size_multiple
=
required_batch_size_multiple
,
))
def
uneven_batches_by_size
(
src
,
dst
,
max_tokens
=
None
,
max_sentences
=
None
,
max_positions
=
(
1024
,
1024
)):
def
uneven_batches_by_size
(
src
,
dst
,
max_tokens
=
None
,
max_sentences
=
None
,
max_positions
=
(
1024
,
1024
),
required_batch_size_multiple
=
1
):
"""Returns batches of indices bucketed by size. Batches may contain
"""Returns batches of indices bucketed by size. Batches may contain
sequences of different lengths."""
sequences of different lengths."""
assert
isinstance
(
src
,
IndexedDataset
)
and
isinstance
(
dst
,
IndexedDataset
)
assert
isinstance
(
src
,
IndexedDataset
)
and
isinstance
(
dst
,
IndexedDataset
)
...
@@ -394,7 +404,9 @@ def uneven_batches_by_size(src, dst, max_tokens=None, max_sentences=None, max_po
...
@@ -394,7 +404,9 @@ def uneven_batches_by_size(src, dst, max_tokens=None, max_sentences=None, max_po
batches
=
list
(
_make_batches
(
batches
=
list
(
_make_batches
(
src
,
dst
,
indices
,
max_tokens
,
max_sentences
,
max_positions
,
src
,
dst
,
indices
,
max_tokens
,
max_sentences
,
max_positions
,
ignore_invalid_inputs
=
True
,
allow_different_src_lens
=
True
))
ignore_invalid_inputs
=
True
,
allow_different_src_lens
=
True
,
required_batch_size_multiple
=
required_batch_size_multiple
,
))
return
batches
return
batches
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment