chenpangpang / transformers · Commits

Commit 7a89a3e4, authored Mar 04, 2020 by Patrick von Platen

correct beam search sampling

parent c4c4c999
Showing 1 changed file with 21 additions and 15 deletions.

src/transformers/modeling_tf_utils.py (+21, -15)
@@ -760,9 +760,14 @@ class TFPreTrainedModel(tf.keras.Model, TFModelUtilsMixin):
         ]
         # scores for each sentence in the beam
-        beam_scores_begin = tf.zeros((batch_size, 1), dtype=tf.float32)
-        beam_scores_end = tf.zeros((batch_size, num_beams - 1), dtype=tf.float32) * 1e-9
-        beam_scores = tf.reshape(tf.concat([beam_scores_begin, beam_scores_end], -1), (batch_size * num_beams,))
+        if do_sample is False:
+            beam_scores_begin = tf.zeros((batch_size, 1), dtype=tf.float32)
+            beam_scores_end = tf.zeros((batch_size, num_beams - 1), dtype=tf.float32) * 1e-9
+            beam_scores = tf.concat([beam_scores_begin, beam_scores_end], -1)
+        else:
+            beam_scores = tf.zeros((batch_size, num_beams), dtype=tf.float32)
+
+        beam_scores = tf.reshape(beam_scores, (batch_size * num_beams,))

         # cache compute states
         past = None
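Note (not part of the commit): the hunk above makes the score initialization conditional on do_sample and moves the flattening out of the greedy branch so both branches share it. A minimal runnable sketch of the resulting initialization, with illustrative values for batch_size and num_beams, is:

import tensorflow as tf

batch_size, num_beams = 2, 3  # illustrative values, not from the commit

# Greedy branch (do_sample is False): scores for the first beam and the
# remaining beams are built separately, then concatenated.
# Note that tf.zeros(...) * 1e-9 still evaluates to all zeros, so as written
# the two branches produce the same tensor.
beam_scores_begin = tf.zeros((batch_size, 1), dtype=tf.float32)
beam_scores_end = tf.zeros((batch_size, num_beams - 1), dtype=tf.float32) * 1e-9
beam_scores = tf.concat([beam_scores_begin, beam_scores_end], -1)  # (2, 3)

# Shared tail: flatten to one score per (batch, beam) pair so it lines up
# with the (batch_size * num_beams, ...) layout of the decoder inputs.
beam_scores = tf.reshape(beam_scores, (batch_size * num_beams,))  # (6,)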
@@ -790,23 +795,24 @@ class TFPreTrainedModel(tf.keras.Model, TFModelUtilsMixin):
                 # Temperature (higher temperature => more likely to sample low probability tokens)
                 if temperature != 1.0:
                     next_token_logits = next_token_logits / temperature
+
+                scores = tf.nn.log_softmax(next_token_logits, axis=-1)  # (batch_size * num_beams, vocab_size)
+                _scores = scores + tf.broadcast_to(
+                    beam_scores[:, None], (batch_size * num_beams, vocab_size)
+                )  # (batch_size * num_beams, vocab_size)
+
                 # Top-p/top-k filtering
-                next_token_logits = tf_top_k_top_p_filtering(
-                    next_token_logits, top_k=top_k, top_p=top_p, min_tokens_to_keep=2
+                _scores = tf_top_k_top_p_filtering(
+                    _scores, top_k=top_k, top_p=top_p, min_tokens_to_keep=2
                 )  # (batch_size * num_beams, vocab_size)
                 # Sample 2 next tokens for each beam (so we have some spare tokens and match output of greedy beam search)
+                _scores = tf.reshape(_scores, (batch_size, num_beams * vocab_size))
+
                 next_tokens = tf.random.categorical(
-                    next_token_logits, dtype=tf.int32, num_samples=2
-                )  # (batch_size * num_beams, vocab_size)
+                    _scores, dtype=tf.int32, num_samples=2 * num_beams
+                )  # (batch_size, 2 * num_beams)
                 # Compute next scores
-                scores = tf.nn.log_softmax(next_token_logits, axis=-1)  # (batch_size * num_beams, vocab_size)
-                _scores = tf.gather(scores, next_tokens, batch_dims=1)  # (batch_size * num_beams, 2)
-                next_scores = _scores + tf.broadcast_to(
-                    beam_scores[:, None], (batch_size * num_beams, 2)
-                )  # (batch_size * num_beams, 2)
-                # Match shape of greedy beam search
-                next_tokens = tf.reshape(next_tokens, (batch_size, 2 * num_beams))  # (batch_size, 2 * num_beams)
-                next_scores = tf.reshape(next_scores, (batch_size, 2 * num_beams))  # (batch_size, 2 * num_beams)
+                next_scores = tf.gather(_scores, next_tokens, batch_dims=1)  # (batch_size, 2 * num_beams)
             else:
                 # do greedy beam search
                 scores = tf.nn.log_softmax(next_token_logits, axis=-1)  # (batch_size * num_beams, vocab_size)
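Note (not part of the commit): the core of this hunk is that the beam scores are now added to the log-probabilities before filtering and sampling, and candidates are drawn from one categorical distribution per batch item spanning all beams, rather than independently per beam row. A runnable sketch under assumed toy dimensions (batch_size, num_beams, vocab_size are illustrative):

import tensorflow as tf

batch_size, num_beams, vocab_size = 2, 3, 5  # illustrative, not from the commit

# Stand-in for log_softmax(next_token_logits) + broadcast beam_scores after
# top-k/top-p filtering; random values here only to make the sketch runnable.
_scores = tf.random.normal((batch_size * num_beams, vocab_size))

# Flatten the beam axis into the vocabulary axis so each batch item has one
# categorical distribution over num_beams * vocab_size candidates ...
_scores = tf.reshape(_scores, (batch_size, num_beams * vocab_size))

# ... then draw 2 * num_beams candidates per batch item, matching the
# (batch_size, 2 * num_beams) shape produced by the greedy branch.
next_tokens = tf.random.categorical(_scores, dtype=tf.int32, num_samples=2 * num_beams)
next_scores = tf.gather(_scores, next_tokens, batch_dims=1)  # (batch_size, 2 * num_beams)

# Each flattened index encodes both the source beam and the token id:
beam_indices = next_tokens // vocab_size  # which beam each candidate extends
token_ids = next_tokens % vocab_size      # which vocabulary token was drawn

Because the beam-score offsets are baked into _scores before the draw, low-scoring beams are less likely to contribute candidates at all, which mirrors how the greedy branch's top-k over all beams behaves.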