chenpangpang / transformers · Commits

Commit 9362eb4a, authored Mar 05, 2020 by patrickvonplaten (parent c8035e11):

    refactored beam search according to torch implementation

Showing 1 changed file with 47 additions and 26 deletions:

src/transformers/modeling_tf_utils.py (+47 / -26)
```diff
@@ -557,6 +557,7 @@ class TFPreTrainedModel(tf.keras.Model, TFModelUtilsMixin):
         else:
             assert len(shape_list(input_ids)) == 2, "Input prompt should be of shape (batch_size, sequence length)."
 
+        # not allow to duplicate outputs when greedy decoding
         if do_sample is False:
             if num_beams == 1:
                 # no_beam_search greedy generation conditions
```
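The assert leans on the `shape_list` helper defined earlier in this same file, which prefers static dimensions and falls back to dynamic ones that are only known at run time. A minimal sketch of that idea (not copied from the commit, but matching how the helper is commonly written):

```python
import tensorflow as tf

def shape_list(x):
    """Return tensor dimensions, using static values where available
    and dynamic tf.shape() values for dimensions unknown at build time."""
    static = x.shape.as_list()
    dynamic = tf.shape(x)
    return [dynamic[i] if s is None else s for i, s in enumerate(static)]

ids = tf.constant([[5, 7, 9], [2, 4, 6]])
assert len(shape_list(ids)) == 2  # (batch_size, sequence_length)
```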
```diff
@@ -580,13 +581,23 @@ class TFPreTrainedModel(tf.keras.Model, TFModelUtilsMixin):
         cur_len = shape_list(input_ids)[1]
         vocab_size = self.config.vocab_size
 
-        if num_return_sequences != 1 and do_sample:
-            # Expand input to num return sequences
-            input_ids = tf.broadcast_to(tf.expand_dims(input_ids, 1), (batch_size, num_return_sequences, cur_len))
+        # set effective batch size and effective batch multiplier according to do_sample
         if do_sample:
             effective_batch_size = batch_size * num_return_sequences
-            input_ids = tf.reshape(input_ids, (effective_batch_size, cur_len))
+            effective_batch_mult = num_return_sequences
         else:
             effective_batch_size = batch_size
+            effective_batch_mult = 1
 
+        # Expand input ids if num_beams > 1 or num_return_sequences > 1
+        if num_return_sequences > 1 or num_beams > 1:
+            input_ids_len = shape_list(input_ids)[-1]
+            input_ids = tf.broadcast_to(
+                tf.expand_dims(input_ids, 1), (batch_size, effective_batch_mult * num_beams, input_ids_len)
+            )
+            input_ids = tf.reshape(
+                input_ids, (effective_batch_size * num_beams, input_ids_len)
+            )  # shape: (batch_size * num_return_sequences * num_beams, cur_len)
 
         if num_beams > 1:
             output = self._generate_beam_search(
```
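The expand-then-reshape pattern duplicates each prompt `effective_batch_mult * num_beams` times while keeping all copies of one example adjacent in the flattened batch, which the beam bookkeeping later relies on. A small self-contained demonstration with toy values (not taken from the commit):

```python
import tensorflow as tf

batch_size, num_beams, effective_batch_mult = 2, 3, 1
input_ids = tf.constant([[11, 12], [21, 22]])  # (batch_size, cur_len)
input_ids_len = input_ids.shape[-1]

# insert a beam axis, broadcast along it, then fold it back into the batch axis
expanded = tf.broadcast_to(
    tf.expand_dims(input_ids, 1), (batch_size, effective_batch_mult * num_beams, input_ids_len)
)
flat = tf.reshape(expanded, (batch_size * effective_batch_mult * num_beams, input_ids_len))

print(flat.numpy())
# [[11 12] [11 12] [11 12] [21 22] [21 22] [21 22]]
# all beams of one example stay contiguous
```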
```diff
@@ -701,12 +712,12 @@ class TFPreTrainedModel(tf.keras.Model, TFModelUtilsMixin):
             # unfinished_sents is set to zero if eos in sentence
             unfinished_sents -= is_sents_unfinished_and_token_to_add_is_eos
 
+            cur_len = cur_len + 1
+
             # stop when there is a </s> in each sentence, or if we exceed the maximum length
             if tf.math.reduce_max(unfinished_sents) == 0:
                 break
 
-            cur_len = cur_len + 1
-
         # if there are different sentence lengths in the batch, some batches have to be padded
         min_sent_length = tf.math.reduce_min(sent_lengths)
         max_sent_length = tf.math.reduce_max(sent_lengths)
```
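Moving the increment above the `break` matters: a token has already been appended in this iteration, so `cur_len` must be bumped before the loop can exit, otherwise the final length undercounts by one whenever the stop condition fires. A toy loop illustrating the counting order (plain Python, not the commit's code):

```python
def generate_length(tokens_to_emit, max_length):
    cur_len = 1  # the prompt token
    while cur_len < max_length:
        tokens_to_emit -= 1    # "append" one token this iteration
        cur_len = cur_len + 1  # count it before checking the stop condition
        if tokens_to_emit == 0:  # all sentences finished
            break
    return cur_len

assert generate_length(tokens_to_emit=3, max_length=10) == 4  # prompt + 3 generated
```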
```diff
@@ -750,10 +761,6 @@ class TFPreTrainedModel(tf.keras.Model, TFModelUtilsMixin):
         """ Generate sequences for each example with beam search.
         """
 
-        # Expand input to num beams
-        input_ids = tf.broadcast_to(tf.expand_dims(input_ids, 1), (batch_size, num_beams, cur_len))
-        input_ids = tf.reshape(input_ids, (batch_size * num_beams, cur_len))  # (batch_size * num_beams, cur_len)
-
         # generated hypotheses
         generated_hyps = [
             BeamHypotheses(num_beams, max_length, length_penalty, early_stopping=False) for _ in range(batch_size)
         ]
```
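`BeamHypotheses` (defined further down in the same file) keeps the `num_beams` best finished sequences per example, ranked by a length-penalized score. A condensed, hypothetical sketch of how such a container works; the class name here is made up and details may differ from the real one:

```python
class BeamHypothesesSketch:
    """Keep the n_hyp best hypotheses seen so far (simplified illustration)."""

    def __init__(self, n_hyp, length_penalty):
        self.n_hyp = n_hyp
        self.length_penalty = length_penalty
        self.hyps = []  # list of (score, token_list)

    def add(self, tokens, sum_logprobs):
        # length-normalize the accumulated log-probability
        score = sum_logprobs / len(tokens) ** self.length_penalty
        self.hyps.append((score, tokens))
        if len(self.hyps) > self.n_hyp:
            # drop the currently worst hypothesis
            self.hyps.remove(min(self.hyps, key=lambda h: h[0]))

beams = BeamHypothesesSketch(n_hyp=2, length_penalty=1.0)
beams.add([1, 2, 3], sum_logprobs=-3.0)
beams.add([1, 2], sum_logprobs=-1.0)
beams.add([1, 2, 3, 4], sum_logprobs=-8.0)
assert len(beams.hyps) == 2  # only the two best survive
```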
```diff
@@ -768,7 +775,6 @@ class TFPreTrainedModel(tf.keras.Model, TFModelUtilsMixin):
         beam_scores = tf.zeros((batch_size, num_beams), dtype=tf.float32)
         beam_scores = tf.reshape(beam_scores, (batch_size * num_beams,))
 
         # cache compute states
         past = None
```
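The flat `(batch_size * num_beams,)` vector holds each beam's accumulated log-probability; at every step it is added to the per-token log-softmax scores before re-ranking. In miniature, with made-up numbers and assuming that bookkeeping:

```python
import tensorflow as tf

batch_size, num_beams, vocab_size = 1, 2, 4
beam_scores = tf.zeros((batch_size * num_beams,), dtype=tf.float32)

# pretend log-softmax output for the next token of each beam
scores = tf.math.log(
    tf.constant([[0.7, 0.1, 0.1, 0.1],
                 [0.4, 0.3, 0.2, 0.1]])
)  # (batch_size * num_beams, vocab_size)

# candidate score = running beam score + next-token log-probability
next_scores = scores + beam_scores[:, None]
print(tf.reshape(next_scores, (batch_size, num_beams * vocab_size)).numpy())
```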
```diff
@@ -813,6 +819,11 @@ class TFPreTrainedModel(tf.keras.Model, TFModelUtilsMixin):
                 )  # (batch_size, 2 * num_beams)
                 # Compute next scores
                 next_scores = tf.gather(_scores, next_tokens, batch_dims=1)  # (batch_size, 2 * num_beams)
+
+                # sort the sampled vector to make sure that the first num_beams samples are the best
+                next_scores_indices = tf.argsort(next_scores, direction="DESCENDING", axis=1)
+                next_scores = tf.gather(next_scores, next_scores_indices, batch_dims=1)  # (batch_size, num_beams * 2)
+                next_tokens = tf.gather(next_tokens, next_scores_indices, batch_dims=1)  # (batch_size, num_beams * 2)
             else:
                 # do greedy beam search
                 scores = tf.nn.log_softmax(next_token_logits, axis=-1)  # (batch_size * num_beams, vocab_size)
```
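The added sorting step exists because sampling, unlike `top_k`, returns candidates in no particular score order, while the code downstream assumes the first `num_beams` columns are the best ones. A runnable miniature of the argsort/gather pattern (toy scores, not from the diff):

```python
import tensorflow as tf

next_scores = tf.constant([[-2.0, -0.5, -3.0, -1.0]])  # (batch_size, 2 * num_beams), unsorted
next_tokens = tf.constant([[7, 3, 9, 5]])               # sampled token ids, same layout

# sort scores descending per row and reorder tokens to match
next_scores_indices = tf.argsort(next_scores, direction="DESCENDING", axis=1)
next_scores = tf.gather(next_scores, next_scores_indices, batch_dims=1)
next_tokens = tf.gather(next_tokens, next_scores_indices, batch_dims=1)

print(next_scores.numpy())  # [[-0.5 -1.  -2.  -3. ]]
print(next_tokens.numpy())  # [[3 5 7 9]]
```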
```diff
@@ -826,6 +837,7 @@ class TFPreTrainedModel(tf.keras.Model, TFModelUtilsMixin):
                 next_scores = tf.reshape(
                     next_scores, (batch_size, num_beams * vocab_size)
                 )  # (batch_size, num_beams * vocab_size)
 
                 next_scores, next_tokens = tf.math.top_k(next_scores, 2 * num_beams, sorted=True)
 
+            assert shape_list(next_scores) == shape_list(next_tokens) == [batch_size, 2 * num_beams]
```
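`top_k` runs over the flattened `(num_beams * vocab_size)` axis so that candidates from all beams of one example compete against each other; each flat index then encodes both its beam and its token via floor division and modulo, which the next hunk exploits. A toy check of that encoding (illustrative values):

```python
import tensorflow as tf

num_beams, vocab_size = 2, 5
# flat scores over num_beams * vocab_size candidates for one example
next_scores = tf.constant([[0.1, 0.0, 0.9, 0.0, 0.0,    # beam 0
                            0.0, 0.8, 0.0, 0.0, 0.0]])  # beam 1

top_scores, top_idx = tf.math.top_k(next_scores, 2 * num_beams, sorted=True)
for idx in top_idx[0].numpy():
    beam_id, token_id = idx // vocab_size, idx % vocab_size
    print(f"flat index {idx} -> beam {beam_id}, token {token_id}")
# flat index 2 -> beam 0, token 2   (score 0.9)
# flat index 6 -> beam 1, token 1   (score 0.8)
```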
```diff
@@ -861,14 +873,13 @@ class TFPreTrainedModel(tf.keras.Model, TFModelUtilsMixin):
                     beam_id = idx // vocab_size
                     token_id = idx % vocab_size
 
+                    effective_beam_id = batch_idx * num_beams + beam_id
                     # add to generated hypotheses if end of sentence or last iteration
                     if eos_token_ids is not None and token_id.numpy() in eos_token_ids:
-                        generated_hyps[batch_idx].add(
-                            tf.identity(input_ids[batch_idx * num_beams + beam_id, :cur_len]), score.numpy()
-                        )
+                        generated_hyps[batch_idx].add(tf.identity(input_ids[effective_beam_id]), score.numpy())
                     else:
                         # add next predicted token if it is not eos_token
-                        next_sent_beam.append((score, token_id, batch_idx * num_beams + beam_id))
+                        next_sent_beam.append((score, token_id, effective_beam_id))
 
                     # the beam for next step is full
                     if len(next_sent_beam) == num_beams:
```
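`effective_beam_id` is plain row arithmetic: with beams stored contiguously per example, beam `beam_id` of example `batch_idx` lives at row `batch_idx * num_beams + beam_id` of the flattened `(batch_size * num_beams, ...)` tensors. A quick sanity check of the mapping (not from the commit):

```python
batch_size, num_beams = 3, 4

for batch_idx in range(batch_size):
    for beam_id in range(num_beams):
        effective_beam_id = batch_idx * num_beams + beam_id
        # the inverse mapping recovers both coordinates
        assert effective_beam_id // num_beams == batch_idx
        assert effective_beam_id % num_beams == beam_id
```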
```diff
@@ -893,24 +904,34 @@ class TFPreTrainedModel(tf.keras.Model, TFModelUtilsMixin):
             if past:
                 past = self._reorder_cache(past, beam_idx)
 
+            # update current length
+            cur_len = cur_len + 1
+
             # stop when we are done with each sentence
             if all(done):
                 break
 
-            # update current length
-            cur_len = cur_len + 1
-
         # finalize all open beam hypotheses and add to generated hypotheses
         for batch_idx in range(batch_size):
-            # Add all open beam hypothesis to generated_hyps
-            if not done[batch_idx]:
-                for idx, score in zip(next_tokens[batch_idx], next_scores[batch_idx]):
+            if done[batch_idx]:
+                continue
+
+            # test that beam scores match previously calculated scores if not eos and batch_idx not done
+            if eos_token_ids is not None and all(
+                (token_id % vocab_size).numpy().item() not in eos_token_ids for token_id in next_tokens[batch_idx]
+            ):
+                assert tf.reduce_all(
+                    next_scores[batch_idx, :num_beams] == tf.reshape(beam_scores, (batch_size, num_beams))[batch_idx]
+                ), "If batch_idx is not done, final next scores: {} have to equal to accumulated beam_scores: {}".format(
+                    next_scores[:, :num_beams][batch_idx], tf.reshape(beam_scores, (batch_size, num_beams))[batch_idx]
+                )
-                    # get beam and token IDs
-                    beam_id = idx // vocab_size
-                    token_id = idx % vocab_size
-                    generated_hyps[batch_idx].add(
-                        tf.identity(input_ids[batch_idx * num_beams + beam_id, :cur_len]), score.numpy()
-                    )
+
+            # need to add best num_beams hypotheses to generated hyps
+            for beam_id in range(num_beams):
+                effective_beam_id = batch_idx * num_beams + beam_id
+                final_score = beam_scores[effective_beam_id].numpy().item()
+                final_tokens = input_ids[effective_beam_id]
+                generated_hyps[batch_idx].add(final_tokens, final_score)
 
         # depending on whether greedy generation is wanted or not define different output_batch_size and output_num_return_sequences_per_batch
         output_batch_size = batch_size if do_sample else batch_size * num_return_sequences
```
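The new assert encodes an invariant of the loop: for a batch that never produced an EOS among its top candidates, the first `num_beams` entries of `next_scores` must still equal the running `beam_scores`, since nothing was removed ahead of them during re-ranking. A toy illustration with made-up tensors:

```python
import tensorflow as tf

batch_size, num_beams = 1, 2
# running scores of the surviving beams, flat as in the main loop
beam_scores = tf.constant([-0.5, -1.0])
# scores of the 2 * num_beams candidates kept at the last step, best first
next_scores = tf.constant([[-0.5, -1.0, -2.5, -3.0]])

batch_idx = 0
assert tf.reduce_all(
    next_scores[batch_idx, :num_beams]
    == tf.reshape(beam_scores, (batch_size, num_beams))[batch_idx]
)
```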