chenpangpang / transformers · Commits
"git@developer.sourcefind.cn:chenpangpang/transformers.git" did not exist on "c5037b459e117b9286c611092f38663f6cb763b0"
Commit eb5bdcdf (Unverified)
Authored Apr 12, 2022 by Joao Gante; committed by GitHub on Apr 12, 2022

TF generate: handle case without cache in beam search (#16704)
Parent: 9c9db751

Showing 1 changed file with 44 additions and 10 deletions

src/transformers/generation_tf_utils.py (+44, -10)
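The diff below threads `input_ids_length` through the XLA-compatible beam search loop so that TF generation also works when no key/value cache is used (i.e. when `model_kwargs["past"]` is `None`). As a minimal sketch of the user-facing scenario this targets -- not part of the commit, and assuming the public transformers TF API of that era -- beam search with caching disabled looks like this:

    from transformers import AutoTokenizer, TFAutoModelForSeq2SeqLM

    tokenizer = AutoTokenizer.from_pretrained("t5-small")
    model = TFAutoModelForSeq2SeqLM.from_pretrained("t5-small")

    inputs = tokenizer("translate English to German: I love TensorFlow.", return_tensors="tf")
    # num_beams > 1 selects beam search; use_cache=False is forwarded to the model,
    # so `past` stays None -- the branch this commit handles.
    output_ids = model.generate(inputs["input_ids"], num_beams=4, max_length=32, use_cache=False)
    print(tokenizer.decode(output_ids[0], skip_special_tokens=True))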
@@ -2514,6 +2514,7 @@ class TFGenerationMixin:

         # 3. init tensors to use for "xla-compileable" generate function
         batch_size, num_beams, cur_len = input_ids.shape
+        input_ids_length = cur_len

         # per batch, beam-item holding current token in loop, pre-populated with `pad_token_id`
         sequences = tf.TensorArray(
@@ -2568,7 +2569,14 @@ class TFGenerationMixin:
         # 4. define "xla-compile-able" stop-condition and auto-regressive function
         # define stop-condition and auto-regressive function
         def beam_search_cond_fn(
-            cur_len, running_sequences, running_scores, sequences, scores, is_sent_finished, model_kwargs
+            cur_len,
+            running_sequences,
+            running_scores,
+            sequences,
+            scores,
+            is_sent_finished,
+            model_kwargs,
+            input_ids_length,
         ):
             """
             Beam Search termination condition function -- halts the generation loop if any of these conditions becomes
@@ -2597,7 +2605,7 @@ class TFGenerationMixin:
             scores,
             is_sent_finished,
             model_kwargs,
-            input_ids_length=1,
+            input_ids_length,
             intermediary_running_sequences=None,
         ):
             """
@@ -2754,9 +2762,11 @@ class TFGenerationMixin:

             # if we don't cache past key values we need the whole input
             if model_kwargs.get("past", None) is None:
-                input_ids_length = cur_len + 1
+                next_input_ids_length = cur_len + 1
                 # let's throw out `past` since we don't want `None` tensors
                 model_kwargs.pop("past", None)
+            else:
+                next_input_ids_length = 1

             # 9. Prepare the `tf.TensorArray` for the next iteration
             next_sequences = sequences.unstack(tf.transpose(next_sequences_seq_last, perm=[2, 0, 1]))
@@ -2772,6 +2782,7 @@ class TFGenerationMixin:
                 next_scores,
                 next_is_sent_finished,
                 next_model_kwargs,
+                next_input_ids_length,
             )

         # 5. run generation
@@ -2780,8 +2791,7 @@ class TFGenerationMixin:
             beam_search_body_fn, intermediary_running_sequences=intermediary_running_sequences
         )

-        # 1st generation step has to be run before to initialize `past`
-        beam_search_body_fn_first_iter = partial(beam_search_body_fn, input_ids_length=cur_len)
+        # 1st generation step has to be run before to initialize `past` (if active)
         (
             cur_len,
             running_sequences,
@@ -2790,20 +2800,44 @@ class TFGenerationMixin:
             scores,
             is_sent_finished,
             model_kwargs,
-        ) = beam_search_body_fn_first_iter(
-            cur_len, running_sequences, running_scores, sequences, scores, is_sent_finished, model_kwargs
+            input_ids_length,
+        ) = beam_search_body_fn(
+            cur_len,
+            running_sequences,
+            running_scores,
+            sequences,
+            scores,
+            is_sent_finished,
+            model_kwargs,
+            input_ids_length,
         )

         # 2-to-n generation steps can then be run in autoregressive fashion (only in case 1st generation step does
         # NOT yield EOS token though)
         if beam_search_cond_fn(
-            cur_len, running_sequences, running_scores, sequences, scores, is_sent_finished, model_kwargs
+            cur_len,
+            running_sequences,
+            running_scores,
+            sequences,
+            scores,
+            is_sent_finished,
+            model_kwargs,
+            input_ids_length,
         ):
             maximum_iterations = max_length - cur_len
-            cur_len, running_sequences, running_scores, sequences, scores, is_sent_finished, _ = tf.while_loop(
+            cur_len, running_sequences, running_scores, sequences, scores, is_sent_finished, _, _ = tf.while_loop(
                 beam_search_cond_fn,
                 beam_search_body_fn,
-                (cur_len, running_sequences, running_scores, sequences, scores, is_sent_finished, model_kwargs),
+                (
+                    cur_len,
+                    running_sequences,
+                    running_scores,
+                    sequences,
+                    scores,
+                    is_sent_finished,
+                    model_kwargs,
+                    input_ids_length,
+                ),
                 maximum_iterations=maximum_iterations,
             )
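For context on the `_, _ = tf.while_loop(...)` unpacking above: `input_ids_length` becomes an extra loop variable, so the loop returns one more value that the caller discards. A minimal, self-contained sketch of that pattern (not from the commit; names are illustrative):

    import tensorflow as tf

    def cond_fn(cur_len, input_ids_length):
        return cur_len < 10

    def body_fn(cur_len, input_ids_length):
        # consume `input_ids_length` tokens on this step, then feed a single token next time
        return cur_len + input_ids_length, tf.constant(1)

    # start with a prompt of 3 tokens; the extra loop variable is discarded with `_`
    cur_len, _ = tf.while_loop(cond_fn, body_fn, (tf.constant(3), tf.constant(3)))
    print(int(cur_len))  # 10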