"git@developer.sourcefind.cn:chenpangpang/transformers.git" did not exist on "c5037b459e117b9286c611092f38663f6cb763b0"
Unverified commit eb5bdcdf authored by Joao Gante, committed by GitHub

TF generate: handle case without cache in beam search (#16704)

parent 9c9db751
@@ -2514,6 +2514,7 @@ class TFGenerationMixin:
         # 3. init tensors to use for "xla-compileable" generate function
         batch_size, num_beams, cur_len = input_ids.shape
+        input_ids_length = cur_len

         # per batch, beam-item holding current token in loop, pre-populated with `pad_token_id`
         sequences = tf.TensorArray(
@@ -2568,7 +2569,14 @@ class TFGenerationMixin:
         # 4. define "xla-compile-able" stop-condition and auto-regressive function
         # define stop-condition and auto-regressive function
         def beam_search_cond_fn(
-            cur_len, running_sequences, running_scores, sequences, scores, is_sent_finished, model_kwargs
+            cur_len,
+            running_sequences,
+            running_scores,
+            sequences,
+            scores,
+            is_sent_finished,
+            model_kwargs,
+            input_ids_length,
         ):
             """
             Beam Search termination condition function -- halts the generation loop if any of these conditions becomes
@@ -2597,7 +2605,7 @@ class TFGenerationMixin:
             scores,
             is_sent_finished,
             model_kwargs,
-            input_ids_length=1,
+            input_ids_length,
             intermediary_running_sequences=None,
         ):
             """
@@ -2754,9 +2762,11 @@ class TFGenerationMixin:
             # if we don't cache past key values we need the whole input
             if model_kwargs.get("past", None) is None:
-                input_ids_length = cur_len + 1
+                next_input_ids_length = cur_len + 1
                 # let's throw out `past` since we don't want `None` tensors
                 model_kwargs.pop("past", None)
+            else:
+                next_input_ids_length = 1

             # 9. Prepare the `tf.TensorArray` for the next iteration
             next_sequences = sequences.unstack(tf.transpose(next_sequences_seq_last, perm=[2, 0, 1]))
@@ -2772,6 +2782,7 @@ class TFGenerationMixin:
                 next_scores,
                 next_is_sent_finished,
                 next_model_kwargs,
+                next_input_ids_length,
             )

         # 5. run generation
@@ -2780,8 +2791,7 @@ class TFGenerationMixin:
             beam_search_body_fn, intermediary_running_sequences=intermediary_running_sequences
         )

-        # 1st generation step has to be run before to initialize `past`
-        beam_search_body_fn_first_iter = partial(beam_search_body_fn, input_ids_length=cur_len)
+        # 1st generation step has to be run before to initialize `past` (if active)
         (
             cur_len,
             running_sequences,
@@ -2790,20 +2800,44 @@ class TFGenerationMixin:
             scores,
             is_sent_finished,
             model_kwargs,
-        ) = beam_search_body_fn_first_iter(
-            cur_len, running_sequences, running_scores, sequences, scores, is_sent_finished, model_kwargs
+            input_ids_length,
+        ) = beam_search_body_fn(
+            cur_len,
+            running_sequences,
+            running_scores,
+            sequences,
+            scores,
+            is_sent_finished,
+            model_kwargs,
+            input_ids_length,
         )

         # 2-to-n generation steps can then be run in autoregressive fashion (only in case 1st generation step does
         # NOT yield EOS token though)
         if beam_search_cond_fn(
-            cur_len, running_sequences, running_scores, sequences, scores, is_sent_finished, model_kwargs
+            cur_len,
+            running_sequences,
+            running_scores,
+            sequences,
+            scores,
+            is_sent_finished,
+            model_kwargs,
+            input_ids_length,
         ):
             maximum_iterations = max_length - cur_len
-            cur_len, running_sequences, running_scores, sequences, scores, is_sent_finished, _ = tf.while_loop(
+            cur_len, running_sequences, running_scores, sequences, scores, is_sent_finished, _, _ = tf.while_loop(
                 beam_search_cond_fn,
                 beam_search_body_fn,
-                (cur_len, running_sequences, running_scores, sequences, scores, is_sent_finished, model_kwargs),
+                (
+                    cur_len,
+                    running_sequences,
+                    running_scores,
+                    sequences,
+                    scores,
+                    is_sent_finished,
+                    model_kwargs,
+                    input_ids_length,
+                ),
                 maximum_iterations=maximum_iterations,
            )
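What the change does: the per-step input length used to be frozen into the body function with `functools.partial` (`input_ids_length=cur_len` for the first step, the default `input_ids_length=1` afterwards). That only works when a `past` cache is active; without a cache every step must re-feed the whole prefix, so the length changes each iteration and has to travel through `tf.while_loop` as a loop variable. Below is a minimal, self-contained sketch of that pattern; `use_cache`, `cond_fn`, and `body_fn` are illustrative stand-ins, not the transformers implementation.

import tensorflow as tf

use_cache = False  # stand-in for `model_kwargs` carrying a `past` cache
max_length = tf.constant(10)

def cond_fn(cur_len, input_ids_length):
    # Keep generating until the sequence reaches `max_length`.
    return cur_len < max_length

def body_fn(cur_len, input_ids_length):
    # With a cache the next step consumes only the newest token (length 1);
    # without one it must re-consume the whole prefix generated so far.
    next_input_ids_length = tf.constant(1) if use_cache else cur_len + 1
    return cur_len + 1, next_input_ids_length

# The length is part of the loop state, so the compiled graph keeps the same
# fixed structure of loop variables in the cached and uncached cases.
cur_len, input_ids_length = tf.while_loop(
    cond_fn,
    body_fn,
    loop_vars=(tf.constant(5), tf.constant(5)),  # a 5-token prompt is fed whole first
)
print(int(cur_len), int(input_ids_length))  # 10 10 uncached; 10 1 with use_cache = True

The same stability requirement explains the context line that pops `past` from `model_kwargs`: `tf.while_loop` needs a consistent loop-variable structure across iterations, and a `None` entry would break it (per the diff's own comment, "we don't want `None` tensors").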