chenpangpang / transformers · Commits

Commit 4735c2af
Authored Nov 08, 2019 by Rémi Louf
Committed Dec 09, 2019 by Julien Chaumond

    tweaks to the BeamSearch API

Parent: ba089c78

Showing 2 changed files with 59 additions and 57 deletions:

    transformers/generate/beam_search.py     +24  -39
    transformers/tests/beam_search_tests.py  +35  -18
transformers/generate/beam_search.py  (view file @ 4735c2af)
@@ -32,7 +32,7 @@ import logging

 logger = logging.getLogger(__name__)


-class BeamSearch(nn.Module):
+class BeamSearch(object):
     def __init__(
         self,
         model,
@@ -45,12 +45,17 @@ class BeamSearch(nn.Module):
         max_length,
         alpha=0,
         block_repeating_trigrams=True,
-        device=torch.device("cpu"),
     ):
         r"""
         Inputs:
+            **model**: instance of ``transformers.PreTrainedEncoderDecoder``
+                The pretrained encoder-decoder model that will be used to generate the sequences.
             **bos_token_id**: int
                 Id that is used by the tokenizer to represent the beginning of a sentence.
+            **pad_token_id**: int
+                Id that is used by the tokenizer for padding.
             **eos_token_id**: int
                 Id that is used by the tokenizer to represent the end of a sentence.
             **batch_size**: (`optional`) int
+                Batch size of the inputs. The value is set automatically when calling `forward`.
             **beam_size**: int
@@ -68,7 +73,7 @@ class BeamSearch(nn.Module):
         """
         super(BeamSearch, self).__init__()
         self.model = model
-        self.device = device
+        self.device = next(model.parameters()).device  # only works if all parameters of the model are stored on a single GPU
         self.bos_token_id = bos_token_id
         self.eos_token_id = eos_token_id
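The device is now inferred from the model's parameters instead of being passed in. A minimal sketch of what that lookup does, using a hypothetical toy module that is not part of this commit:

    import torch
    from torch import nn

    # Hypothetical toy module, used only to illustrate the lookup above.
    toy_model = nn.Linear(4, 4)

    # next(model.parameters()).device returns the device of the first parameter.
    # This is only representative when every parameter lives on the same device,
    # which is the caveat noted in the inline comment on the new line.
    device = next(toy_model.parameters()).device
    print(device)  # cpu (or cuda:0 if the module had been moved with .to("cuda"))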
@@ -86,10 +91,7 @@ class BeamSearch(nn.Module):
         self._init_beam_state(batch_size)

     def __len__(self):
-        try:
-            return self.growing_beams.size(1)
-        except NameError:
-            return 0
+        return self.growing_beams.size(1)

     def _init_beam_state(self, batch_size):
         """ (re-)Initialize the state of the beams. """
@@ -120,7 +122,7 @@ class BeamSearch(nn.Module):
         self._step = 0
         self.is_done = False

-    def forward(self, encoder_input_ids, **model_kwargs):
+    def __call__(self, encoder_input_ids, **model_kwargs):
         """ Generate a sequence using Beam Search. """
         # keyword arguments come in 3 flavors: encoder-specific (prefixed by
         # `encoder_`), decoder-specific (prefixed by `decoder_`) and those
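Since BeamSearch no longer subclasses nn.Module, generation goes through a plain `__call__` instead of `forward`. A minimal usage sketch, assuming `model` is a `PreTrainedEncoderDecoder` and `encoder_input_ids` already exists; the argument values below are illustrative, not taken from this diff:

    # Illustrative only: `model` and `encoder_input_ids` are assumed to exist.
    beam = BeamSearch(
        model=model,
        bos_token_id=0,
        eos_token_id=1,
        pad_token_id=2,
        batch_size=1,
        beam_size=4,
        min_length=5,
        max_length=50,
    )

    # The object is called directly; per the docstring comment above, keyword
    # arguments prefixed with `encoder_` or `decoder_` are routed to the
    # corresponding sub-model.
    summaries = beam(encoder_input_ids)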
@@ -158,28 +160,17 @@ class BeamSearch(nn.Module):
             kwargs_encoder["attention_mask"], self.beam_size, dim=0
         )

-        # grow the beam by generating sequences in an autoregressive way
+        # grow the beam iteratively
         batch_size, block_size = encoder_input_ids.size()
         self._init_beam_state(batch_size)
         for step in range(self.max_length):
-            # Add padding tokens
-            decoder_input = torch.full(
-                (self.growing_beams.size(0), block_size),
-                self.pad_token_id,
-                dtype=torch.long,
-                device=self.growing_beams.device,
-            )
-            decoder_input[:, : self.growing_beams.size(1)] = self.growing_beams
-
-            # compute decoder_attention_mask
-            decoder_mask = torch.ones_like(decoder_input)
-            idx_pad_tokens = decoder_input == self.pad_token_id
-            decoder_mask[idx_pad_tokens] = 0
-
-            kwargs_decoder["attention_mask"] = decoder_mask
+            decoder_input = fit_to_block_size(self.growing_beams, block_size, self.pad_token_id)
+            kwargs_decoder["attention_mask"] = build_mask(decoder_input)

             outputs = self.model.decoder(decoder_input, **kwargs_decoder)
-            last_token_scores = outputs[0][:, -1, :].squeeze(1)
-            log_probabilities = torch.nn.functional.log_softmax(last_token_scores, dim=0)
+            next_token_scores = outputs[0][:, -1, :].squeeze(1)
+            log_probabilities = torch.nn.functional.log_softmax(next_token_scores, dim=0)

             surviving_beams_rows = self.grow(log_probabilities)
             if self.is_done:
                 break
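The per-step loop now delegates padding and masking to `fit_to_block_size` and `build_mask`. A small stand-alone illustration of what the decoder receives at each step; the two helpers are inlined here as stand-ins, not the committed implementations:

    import torch

    # Growing beams: shape (num_beams, current_length); values are token ids.
    growing_beams = torch.tensor([[0, 4, 4], [0, 7, 9]])
    block_size, pad_token_id = 6, 2

    # Stand-in for fit_to_block_size: right-pad each beam to block_size.
    decoder_input = torch.full((growing_beams.size(0), block_size), pad_token_id, dtype=torch.long)
    decoder_input[:, : growing_beams.size(1)] = growing_beams

    # Stand-in for build_mask: 1 for real tokens, 0 for padding.
    attention_mask = (decoder_input != pad_token_id).long()

    print(decoder_input)
    # tensor([[0, 4, 4, 2, 2, 2],
    #         [0, 7, 9, 2, 2, 2]])
    print(attention_mask)
    # tensor([[1, 1, 1, 0, 0, 0],
    #         [1, 1, 1, 0, 0, 0]])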
@@ -356,20 +347,14 @@ def fit_to_block_size(sequence, block_size, pad_token_id):
     """ Adapt the source and target sequences' lengths to the block size.
     If the sequence is shorter we append padding tokens to the right.
     """
-    if len(sequence) > block_size:
-        return sequence[:block_size]
-    else:
-        return torch.cat((sequence, torch.tensor([pad_token_id] * (block_size - len(sequence)))), dim=0)
-
-
-def build_lm_labels(sequence, pad_token_id):
-    """ Padding token, encoded as 0, are represented by the value -1 so they
-    are not taken into account in the loss computation. """
-    padded = sequence.clone()
-    padded[padded == pad_token_id] = -1
-    return padded
+    padded_sequence = torch.full(
+        (sequence.size(0), block_size),
+        pad_token_id,
+        dtype=torch.long,
+        device=sequence.device,
+    )
+    padded_sequence[:, : sequence.size(1)] = sequence
+    return sequence


 def build_mask(sequence, pad_token_id):
transformers/tests/beam_search_tests.py  (view file @ 4735c2af)
 from collections import namedtuple
 import unittest

 import pytest
 import numpy as np
 import torch
 from torch import nn

 from transformers.generate import BeamSearch
 from transformers import PreTrainedEncoderDecoder


 StubTokenizer = namedtuple("Tokenizer", ["bos_token_id", "eos_token_id", "pad_token_id"])
-StubTransformer = namedtuple("Transformer", ["encoder", "decoder"])
+
+
+class StubTransformer(nn.Module):
+    def __init__(self):
+        self.encoder = None
+        self.decoder = None
+        self._parameters = {"dumy": torch.tensor([1])}
+
+    def forward(self):
+        pass


 class BeamSearchtest(unittest.TestCase):
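The stub is promoted from a namedtuple to an `nn.Module` because BeamSearch now reads its device from `next(model.parameters()).device`, so the stub must expose at least one parameter. A hypothetical, slightly more conventional variant of the same idea, shown only for context (it is not the committed code):

    import torch
    from torch import nn


    class MinimalStubTransformer(nn.Module):
        """Hypothetical stub: registers a real nn.Parameter so that
        next(model.parameters()).device has something to return."""

        def __init__(self):
            super().__init__()
            self.encoder = None
            self.decoder = None
            self.dummy = nn.Parameter(torch.tensor([1.0]))

        def forward(self):
            pass


    print(next(MinimalStubTransformer().parameters()).device)  # cpu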
@@ -18,12 +25,13 @@ class BeamSearchtest(unittest.TestCase):
         class will break the integration with the beam search.
         """
-        model = PreTrainedEncoderDecoder("encoder", "decoder")
-        tokenizer = StubTokenizer(0, 1, 2)
+        model = StubTransformer()
         try:
             _ = BeamSearch(
                 model=model,
-                tokenizer=tokenizer,
+                bos_token_id=0,
+                eos_token_id=1,
+                pad_token_id=2,
                 batch_size=1,
                 beam_size=1,
                 min_length=1,
@@ -46,8 +54,10 @@ class BeamSearchtest(unittest.TestCase):
         min_length = 5
         beam = BeamSearch(
-            model=StubTransformer("encoder", "decoder"),
-            tokenizer=StubTokenizer(bos_token_id=0, eos_token_id=eos_idx, pad_token_id=2),
+            model=StubTransformer(),
+            bos_token_id=0,
+            eos_token_id=eos_idx,
+            pad_token_id=2,
             batch_size=batch_size,
             beam_size=beam_size,
             min_length=5,
@@ -71,17 +81,17 @@ class BeamSearchtest(unittest.TestCase):
         log_probabilities[0, eos_idx] = score_distribution[0]
         for idx, score in zip(non_eos_idxs, score_distribution[1:]):
             log_probabilities[0, idx] = score

-        pytest.set_trace()
         for step in range(1, min_length + 2):
+            log_probabilities[0, eos_idx] = score_distribution[0]

             # Beam #3 and #4 terminate at the first step since the probability
             # of the [EOS] token is -1e20 > -\infty so there are only two beams left.
             # The top beam (most likely) always ends with 4 until we reach min_length.
             surviving_beams_rows = beam.grow(log_probabilities)
             if step < min_length:
                 np.testing.assert_array_equal(
-                    beam.growing_beams.numpy(),
-                    np.repeat(np.array([[0] + [4] * step]), 2, axis=0),
+                    beam.growing_beams.numpy()[0, :],
+                    np.array([0] + [4] * step),
                 )
             elif step == min_length:
                 np.testing.assert_array_equal(surviving_beams_rows.numpy(), np.array([]))
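For context on what the assertions above exercise: before `min_length` is reached the [EOS] log-probability is pushed down to -1e20 so no beam can terminate, and the most likely token (4 in this test) keeps being appended. A toy numeric sketch of that mechanism; the -1e20 clamp is assumed to mirror what `BeamSearch.grow` enforces and is not copied from this diff:

    import numpy as np

    vocab_size, eos_idx, min_length = 10, 1, 5
    log_probabilities = np.full(vocab_size, -10.0, dtype="float32")
    log_probabilities[4] = -0.1        # most likely continuation
    log_probabilities[eos_idx] = -0.5  # [EOS] would otherwise beat every token except 4

    step = 3
    if step < min_length:
        # Assumed clamp: make [EOS] impossible to select before min_length.
        log_probabilities[eos_idx] = -1e20

    print(int(log_probabilities.argmax()))  # 4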
@@ -99,8 +109,10 @@ class BeamSearchtest(unittest.TestCase):
         vocab_size = 10
         beam = BeamSearch(
-            model=StubTransformer("encoder", "decoder"),
-            tokenizer=StubTokenizer(bos_token_id=0, eos_token_id=1, pad_token_id=2),
+            model=StubTransformer(),
+            bos_token_id=0,
+            eos_token_id=1,
+            pad_token_id=2,
             batch_size=batch_size,
             beam_size=beam_size,
             min_length=2,
@@ -140,8 +152,10 @@ class BeamSearchtest(unittest.TestCase):
         vocab_size = 10
         beam = BeamSearch(
-            model=StubTransformer("encoder", "decoder"),
-            tokenizer=StubTokenizer(bos_token_id=0, eos_token_id=1, pad_token_id=2),
+            model=StubTransformer(),
+            bos_token_id=0,
+            eos_token_id=1,
+            pad_token_id=2,
             batch_size=batch_size,
             beam_size=beam_size,
             min_length=2,
@@ -167,7 +181,6 @@ class BeamSearchtest(unittest.TestCase):
             log_probabilities[::beam_size, idx] = score

             surviving_beams_rows = beam.grow(log_probabilities)
-            log_probabilities = log_probabilities.index_select(0, surviving_beams_rows)

             if step < 7:
                 self.assertFalse(
@@ -182,6 +195,8 @@ class BeamSearchtest(unittest.TestCase):
                     np.array([-1e20] * vocab_size, dtype="float32"),
                 )

+            log_probabilities = log_probabilities.index_select(0, surviving_beams_rows)
+
     def test_beam_search_example_for_one_step(self):
         """ We test that the predictions for one step of growth are correct. """
         batch_size = 2
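The `index_select` call is moved from right after `grow` (previous hunk) to the end of the loop body. A minimal illustration of what that reindexing does; the tensor values are made up:

    import torch

    log_probabilities = torch.tensor([[0.1, 0.2], [0.3, 0.4], [0.5, 0.6]])
    surviving_beams_rows = torch.tensor([0, 2])  # hypothetical output of beam.grow(...)

    # Keep only the rows of the beams that survived, in their new order.
    log_probabilities = log_probabilities.index_select(0, surviving_beams_rows)
    print(log_probabilities)
    # tensor([[0.1000, 0.2000],
    #         [0.5000, 0.6000]])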
@@ -190,8 +205,10 @@ class BeamSearchtest(unittest.TestCase):
         vocab_size = 5
         beam = BeamSearch(
-            model=StubTransformer("encoder", "decoder"),
-            tokenizer=StubTokenizer(bos_token_id=0, eos_token_id=1, pad_token_id=2),
+            model=StubTransformer(),
+            bos_token_id=0,
+            eos_token_id=1,
+            pad_token_id=2,
             batch_size=batch_size,
             beam_size=beam_size,
             min_length=2,