Unverified commit 3ac9945e, authored by Xin Qiu and committed by GitHub

Fix beam score calculation issue for TensorFlow version (#27814)

* Fix beam score calculation issue for the TensorFlow version

* Fix transition score computation error

* Make `cur_len` represent the entire sequence length, including the decoder prompt
parent 4c5ed1d0
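For readers skimming the patch, the gist: when a decoder prompt is present, `cur_len` counts the prompt tokens too, so dividing beam scores by `cur_len ** length_penalty` over-penalizes. The patch normalizes by the number of generated tokens instead. A minimal standalone sketch of the corrected normalization (function and variable names here are illustrative, not the library code):

```python
import tensorflow as tf

# Illustrative sketch, not the library implementation: normalize a beam's
# summed log-probs by the number of *generated* tokens, excluding the prompt.
def beam_score(sum_log_probs, cur_len, decoder_prompt_len, length_penalty=1.0):
    gen_len = tf.cast(cur_len - decoder_prompt_len, dtype=tf.float32)
    return sum_log_probs / (gen_len**length_penalty)

# Example: a 5-token prompt plus 3 generated tokens with total log-prob -6.0.
# Before the fix the denominator would have been 8 (prompt included); now it is 3.
print(beam_score(tf.constant(-6.0), cur_len=8, decoder_prompt_len=5))  # -2.0
```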
@@ -2268,6 +2268,8 @@ class TFGenerationMixin:
         # 3. init tensors to use for "xla-compileable" generate function
         batch_size, num_beams, cur_len = shape_list(input_ids)
+        # store the prompt length of decoder
+        decoder_prompt_len = cur_len
 
         # per batch, beam-item holding current token in loop, pre-populated with `pad_token_id`
         input_ids_padding = tf.ones((batch_size, num_beams, max_length - cur_len), dtype=tf.int32) * (

@@ -2286,8 +2288,8 @@ class TFGenerationMixin:
         scores = tf.ones((batch_size, num_beams)) * -1.0e9
 
         # per batch beam indices
-        running_beam_indices = tf.ones((batch_size, num_beams, max_length), dtype=tf.int32) * -1
-        beam_indices = tf.ones((batch_size, num_beams, max_length), dtype=tf.int32) * -1
+        running_beam_indices = tf.ones((batch_size, num_beams, max_length - decoder_prompt_len), dtype=tf.int32) * -1
+        beam_indices = tf.ones((batch_size, num_beams, max_length - decoder_prompt_len), dtype=tf.int32) * -1
 
         # flatten beam dim
         if "encoder_outputs" in model_kwargs:

@@ -2308,6 +2310,7 @@ class TFGenerationMixin:
             scores,
             beam_indices,
             is_sent_finished,
+            decoder_prompt_len,
             model_kwargs,
         ):
             """

@@ -2318,15 +2321,17 @@ class TFGenerationMixin:
             not_max_length_yet = cur_len < max_length
 
             # 2. can the new beams still improve?
-            # early_stopping == False -> apply heuristic = always get the best score from `cur_len`. See the discussion
+            # early_stopping == False -> apply heuristic = always get the best score from `cur_len - decoder_prompt_len`. See the discussion
             # below for more details.
             # https://github.com/huggingface/transformers/pull/20901#issuecomment-1369845565
             # early_stopping == "never" -> compute the best score from max_length or cur_len, depending on the sign of
             # length_penalty. Positive length_penalty favors longer sequences, thus we use max_length there.
             if early_stopping == "never" and length_penalty > 0.0:
-                best_running_score = running_scores[:, :1] / (max_length**length_penalty)
+                best_running_score = running_scores[:, :1] / ((max_length - decoder_prompt_len) ** length_penalty)
             else:
-                best_running_score = running_scores[:, :1] / (tf.cast(cur_len, dtype=tf.float32) ** length_penalty)
+                best_running_score = running_scores[:, :1] / (
+                    tf.cast(cur_len - decoder_prompt_len, dtype=tf.float32) ** length_penalty
+                )
             worst_finished_score = tf.where(
                 is_sent_finished, tf.math.reduce_min(scores, axis=1, keepdims=True), -1.0e9
             )
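A note on the condition patched above: it bounds the best score a still-running beam could reach and compares it against the worst finished score. A hedged sketch of that bound with the prompt length subtracted (names are illustrative; the real logic lives in `beam_search_cond_fn`):

```python
import tensorflow as tf

# Illustrative sketch of the patched bound in `beam_search_cond_fn`.
def best_possible_running_score(
    running_scores, cur_len, decoder_prompt_len, max_length, length_penalty, early_stopping
):
    if early_stopping == "never" and length_penalty > 0.0:
        # A positive penalty favors long sequences, so the most optimistic
        # denominator uses the longest possible *generated* length.
        denominator = float(max_length - decoder_prompt_len) ** length_penalty
    else:
        # Heuristic: normalize by the number of tokens generated so far.
        denominator = tf.cast(cur_len - decoder_prompt_len, tf.float32) ** length_penalty
    return running_scores[:, :1] / denominator
```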
@@ -2346,6 +2351,7 @@ class TFGenerationMixin:
             scores,
             beam_indices,
             is_sent_finished,
+            decoder_prompt_len,
             model_kwargs,
         ):
             """

@@ -2387,7 +2393,9 @@ class TFGenerationMixin:
                 if output_scores:
                     all_scores.append(
                         logits_warper(
-                            flatten_beam_dim(running_sequences), flatten_beam_dim(log_probs_processed), cur_len
+                            flatten_beam_dim(running_sequences),
+                            flatten_beam_dim(log_probs_processed),
+                            cur_len,
                         )
                     )
                 if output_attentions and self.config.is_encoder_decoder:

@@ -2439,6 +2447,14 @@ class TFGenerationMixin:
             batch_modified_indices = topk_current_beam_indices + tf.broadcast_to(
                 tf.expand_dims(tf.range(batch_size) * num_beams, axis=1), topk_current_beam_indices.shape
             )
+            update_indices = tf.stack(
+                [
+                    indices_batch,
+                    indices_beam,
+                    tf.broadcast_to(cur_len - decoder_prompt_len, [batch_size * beams_to_keep]),
+                ],
+                axis=-1,
+            )
             topk_beam_indices = tf.tensor_scatter_nd_update(
                 tensor=topk_running_beam_indices,
                 indices=update_indices,
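The scatter above writes each step's parent-beam index at offset `cur_len - decoder_prompt_len`, since the beam-index buffers no longer reserve slots for the prompt. A toy-shaped sketch of the same mechanics (all shapes and values are made up for illustration):

```python
import tensorflow as tf

# Toy-shaped illustration of the scatter: beam-index buffers only cover
# generated positions, so the write offset is cur_len - decoder_prompt_len.
batch_size, num_beams, generated_capacity = 2, 3, 4
beam_indices = tf.fill([batch_size, num_beams, generated_capacity], -1)
cur_len, decoder_prompt_len = 5, 4  # one token generated so far

indices_batch = tf.repeat(tf.range(batch_size), num_beams)  # [0 0 0 1 1 1]
indices_beam = tf.tile(tf.range(num_beams), [batch_size])   # [0 1 2 0 1 2]
step = tf.fill([batch_size * num_beams], cur_len - decoder_prompt_len)
update_indices = tf.stack([indices_batch, indices_beam, step], axis=-1)

parent_beams = tf.constant([1, 0, 2, 2, 1, 0])  # hypothetical parent beams
beam_indices = tf.tensor_scatter_nd_update(beam_indices, update_indices, parent_beams)
print(beam_indices[0, :, 1])  # tf.Tensor([1 0 2], ...): written at offset 1, not 5
```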
@@ -2455,7 +2471,8 @@ class TFGenerationMixin:
             eos_in_next_token = tf.math.reduce_any(
                 tf.equal(
                     tf.broadcast_to(
-                        topk_sequences[:, :, cur_len], [len(eos_token_id)] + topk_sequences[:, :, cur_len].shape
+                        topk_sequences[:, :, cur_len],
+                        [len(eos_token_id)] + topk_sequences[:, :, cur_len].shape,
                     ),
                     tf.expand_dims(tf.expand_dims(eos_token_id, -1), -1),
                 ),

@@ -2483,7 +2500,9 @@ class TFGenerationMixin:
             # - add length penalty
             # - make sure no scores can be added anymore if beam is full
             # - make sure still running sequences cannot be chosen as finalized beam
-            topk_log_probs = topk_log_probs / (tf.cast(cur_len, dtype=tf.float32) ** length_penalty)
+            topk_log_probs = topk_log_probs / (
+                tf.cast(cur_len + 1 - decoder_prompt_len, dtype=tf.float32) ** length_penalty
+            )
             beams_in_batch_are_full = tf.broadcast_to(
                 tf.math.reduce_all(is_sent_finished, axis=-1, keepdims=True), shape_list(did_topk_just_finished)
             ) & (early_stopping is True)

@@ -2546,6 +2565,7 @@ class TFGenerationMixin:
                 next_scores,
                 next_beam_indices,
                 next_is_sent_finished,
+                decoder_prompt_len,
                 next_model_kwargs,
             )

@@ -2560,6 +2580,7 @@ class TFGenerationMixin:
             scores,
             beam_indices,
             is_sent_finished,
+            decoder_prompt_len,
             model_kwargs,
         ) = beam_search_body_fn(
             cur_len,

@@ -2570,6 +2591,7 @@ class TFGenerationMixin:
             scores,
             beam_indices,
             is_sent_finished,
+            decoder_prompt_len,
             model_kwargs,
         )

@@ -2585,6 +2607,7 @@ class TFGenerationMixin:
             scores,
             beam_indices,
             is_sent_finished,
+            decoder_prompt_len,
             _,
         ) = tf.while_loop(
             beam_search_cond_fn,

@@ -2598,6 +2621,7 @@ class TFGenerationMixin:
                 scores,
                 beam_indices,
                 is_sent_finished,
+                decoder_prompt_len,
                 model_kwargs,
             ),
             maximum_iterations=maximum_iterations,

@@ -2611,7 +2635,7 @@ class TFGenerationMixin:
         beam_indices = tf.where(none_finished[:, None, None], beam_indices, running_beam_indices)
 
         # Apply the length penalty so that running scores match the finalized scores if they are used
-        running_scores = running_scores / (tf.cast(cur_len, dtype=tf.float32) ** length_penalty)
+        running_scores = running_scores / (tf.cast(cur_len - decoder_prompt_len, dtype=tf.float32) ** length_penalty)
         scores = tf.where(none_finished[:, None], scores, running_scores)
 
         # Take best beams for each batch (the score is sorted in descending order)

@@ -2622,7 +2646,7 @@ class TFGenerationMixin:
         if not use_xla:
             # Cut for backward compatibility
             sequences = sequences[:, :cur_len]
-            beam_indices = beam_indices[:, :cur_len]
+            beam_indices = beam_indices[:, : cur_len - decoder_prompt_len]
             if return_dict_in_generate:
                 if self.config.is_encoder_decoder:
...
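To sanity-check the fix end to end, one can recompute beam scores from transition scores and compare against `sequences_scores`, along the lines of the `compute_transition_scores` docstring example. A hedged sketch (model choice is arbitrary, and this assumes the TF mixin mirrors the PyTorch helper's signature):

```python
import numpy as np
from transformers import AutoTokenizer, TFAutoModelForSeq2SeqLM

tokenizer = AutoTokenizer.from_pretrained("t5-small")
model = TFAutoModelForSeq2SeqLM.from_pretrained("t5-small")

inputs = tokenizer(["translate English to German: hello"], return_tensors="tf")
outputs = model.generate(
    **inputs, num_beams=4, max_new_tokens=10, return_dict_in_generate=True, output_scores=True
)
transition_scores = model.compute_transition_scores(
    outputs.sequences, outputs.scores, outputs.beam_indices, normalize_logits=False
)
# Log-probs are negative, so counting negative entries recovers the generated length.
length_penalty = model.generation_config.length_penalty
output_length = np.sum(transition_scores.numpy() < 0, axis=1)
reconstructed = transition_scores.numpy().sum(axis=1) / (output_length**length_penalty)
# With the fix, the recomputed scores should match the reported ones.
print(np.allclose(outputs.sequences_scores, reconstructed))
```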