Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
chenpangpang
transformers
Commits
eb5bdcdf
Unverified
Commit
eb5bdcdf
authored
Apr 12, 2022
by
Joao Gante
Committed by
GitHub
Apr 12, 2022
Browse files
TF generate: handle case without cache in beam search (#16704)
parent
9c9db751
Changes
1
Hide whitespace changes
Inline
Side-by-side
Showing
1 changed file
with
44 additions
and
10 deletions
+44
-10
src/transformers/generation_tf_utils.py
src/transformers/generation_tf_utils.py
+44
-10
No files found.
src/transformers/generation_tf_utils.py
View file @
eb5bdcdf
...
@@ -2514,6 +2514,7 @@ class TFGenerationMixin:
...
@@ -2514,6 +2514,7 @@ class TFGenerationMixin:
# 3. init tensors to use for "xla-compileable" generate function
# 3. init tensors to use for "xla-compileable" generate function
batch_size
,
num_beams
,
cur_len
=
input_ids
.
shape
batch_size
,
num_beams
,
cur_len
=
input_ids
.
shape
input_ids_length
=
cur_len
# per batch, beam-item holding current token in loop, pre-populated with `pad_token_id`
# per batch, beam-item holding current token in loop, pre-populated with `pad_token_id`
sequences
=
tf
.
TensorArray
(
sequences
=
tf
.
TensorArray
(
...
@@ -2568,7 +2569,14 @@ class TFGenerationMixin:
...
@@ -2568,7 +2569,14 @@ class TFGenerationMixin:
# 4. define "xla-compile-able" stop-condition and auto-regressive function
# 4. define "xla-compile-able" stop-condition and auto-regressive function
# define stop-condition and auto-regressive function
# define stop-condition and auto-regressive function
def
beam_search_cond_fn
(
def
beam_search_cond_fn
(
cur_len
,
running_sequences
,
running_scores
,
sequences
,
scores
,
is_sent_finished
,
model_kwargs
cur_len
,
running_sequences
,
running_scores
,
sequences
,
scores
,
is_sent_finished
,
model_kwargs
,
input_ids_length
,
):
):
"""
"""
Beam Search termination condition function -- halts the generation loop if any of these conditions becomes
Beam Search termination condition function -- halts the generation loop if any of these conditions becomes
...
@@ -2597,7 +2605,7 @@ class TFGenerationMixin:
...
@@ -2597,7 +2605,7 @@ class TFGenerationMixin:
scores
,
scores
,
is_sent_finished
,
is_sent_finished
,
model_kwargs
,
model_kwargs
,
input_ids_length
=
1
,
input_ids_length
,
intermediary_running_sequences
=
None
,
intermediary_running_sequences
=
None
,
):
):
"""
"""
...
@@ -2754,9 +2762,11 @@ class TFGenerationMixin:
...
@@ -2754,9 +2762,11 @@ class TFGenerationMixin:
# if we don't cache past key values we need the whole input
# if we don't cache past key values we need the whole input
if
model_kwargs
.
get
(
"past"
,
None
)
is
None
:
if
model_kwargs
.
get
(
"past"
,
None
)
is
None
:
input_ids_length
=
cur_len
+
1
next_
input_ids_length
=
cur_len
+
1
# let's throw out `past` since we don't want `None` tensors
# let's throw out `past` since we don't want `None` tensors
model_kwargs
.
pop
(
"past"
,
None
)
model_kwargs
.
pop
(
"past"
,
None
)
else
:
next_input_ids_length
=
1
# 9. Prepare the `tf.TensorArray` for the next iteration
# 9. Prepare the `tf.TensorArray` for the next iteration
next_sequences
=
sequences
.
unstack
(
tf
.
transpose
(
next_sequences_seq_last
,
perm
=
[
2
,
0
,
1
]))
next_sequences
=
sequences
.
unstack
(
tf
.
transpose
(
next_sequences_seq_last
,
perm
=
[
2
,
0
,
1
]))
...
@@ -2772,6 +2782,7 @@ class TFGenerationMixin:
...
@@ -2772,6 +2782,7 @@ class TFGenerationMixin:
next_scores
,
next_scores
,
next_is_sent_finished
,
next_is_sent_finished
,
next_model_kwargs
,
next_model_kwargs
,
next_input_ids_length
,
)
)
# 5. run generation
# 5. run generation
...
@@ -2780,8 +2791,7 @@ class TFGenerationMixin:
...
@@ -2780,8 +2791,7 @@ class TFGenerationMixin:
beam_search_body_fn
,
intermediary_running_sequences
=
intermediary_running_sequences
beam_search_body_fn
,
intermediary_running_sequences
=
intermediary_running_sequences
)
)
# 1st generation step has to be run before to initialize `past`
# 1st generation step has to be run before to initialize `past` (if active)
beam_search_body_fn_first_iter
=
partial
(
beam_search_body_fn
,
input_ids_length
=
cur_len
)
(
(
cur_len
,
cur_len
,
running_sequences
,
running_sequences
,
...
@@ -2790,20 +2800,44 @@ class TFGenerationMixin:
...
@@ -2790,20 +2800,44 @@ class TFGenerationMixin:
scores
,
scores
,
is_sent_finished
,
is_sent_finished
,
model_kwargs
,
model_kwargs
,
)
=
beam_search_body_fn_first_iter
(
input_ids_length
,
cur_len
,
running_sequences
,
running_scores
,
sequences
,
scores
,
is_sent_finished
,
model_kwargs
)
=
beam_search_body_fn
(
cur_len
,
running_sequences
,
running_scores
,
sequences
,
scores
,
is_sent_finished
,
model_kwargs
,
input_ids_length
,
)
)
# 2-to-n generation steps can then be run in autoregressive fashion (only in case 1st generation step does
# 2-to-n generation steps can then be run in autoregressive fashion (only in case 1st generation step does
# NOT yield EOS token though)
# NOT yield EOS token though)
if
beam_search_cond_fn
(
if
beam_search_cond_fn
(
cur_len
,
running_sequences
,
running_scores
,
sequences
,
scores
,
is_sent_finished
,
model_kwargs
cur_len
,
running_sequences
,
running_scores
,
sequences
,
scores
,
is_sent_finished
,
model_kwargs
,
input_ids_length
,
):
):
maximum_iterations
=
max_length
-
cur_len
maximum_iterations
=
max_length
-
cur_len
cur_len
,
running_sequences
,
running_scores
,
sequences
,
scores
,
is_sent_finished
,
_
=
tf
.
while_loop
(
cur_len
,
running_sequences
,
running_scores
,
sequences
,
scores
,
is_sent_finished
,
_
,
_
=
tf
.
while_loop
(
beam_search_cond_fn
,
beam_search_cond_fn
,
beam_search_body_fn
,
beam_search_body_fn
,
(
cur_len
,
running_sequences
,
running_scores
,
sequences
,
scores
,
is_sent_finished
,
model_kwargs
),
(
cur_len
,
running_sequences
,
running_scores
,
sequences
,
scores
,
is_sent_finished
,
model_kwargs
,
input_ids_length
,
),
maximum_iterations
=
maximum_iterations
,
maximum_iterations
=
maximum_iterations
,
)
)
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment