"git@developer.sourcefind.cn:OpenDAS/mmpretrain-mmcv.git" did not exist on "29c24e9a5d312d14dfe92d04e53181cff317215d"
Commit 7a89a3e4 authored by Patrick von Platen

correct beam search sampling

parent c4c4c999
@@ -760,9 +760,14 @@ class TFPreTrainedModel(tf.keras.Model, TFModelUtilsMixin):
         ]
         # scores for each sentence in the beam
-        beam_scores_begin = tf.zeros((batch_size, 1), dtype=tf.float32)
-        beam_scores_end = tf.zeros((batch_size, num_beams - 1), dtype=tf.float32) * 1e-9
-        beam_scores = tf.reshape(tf.concat([beam_scores_begin, beam_scores_end], -1), (batch_size * num_beams,))
+        if do_sample is False:
+            beam_scores_begin = tf.zeros((batch_size, 1), dtype=tf.float32)
+            beam_scores_end = tf.zeros((batch_size, num_beams - 1), dtype=tf.float32) * 1e-9
+            beam_scores = tf.concat([beam_scores_begin, beam_scores_end], -1)
+        else:
+            beam_scores = tf.zeros((batch_size, num_beams), dtype=tf.float32)
+        beam_scores = tf.reshape(beam_scores, (batch_size * num_beams,))

         # cache compute states
         past = None
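For context, the two initialization paths can be run standalone. A minimal sketch with toy values (batch_size=2 and num_beams=3 are illustrative, not from the commit):

```python
import tensorflow as tf

batch_size, num_beams = 2, 3  # toy values for illustration

# Greedy path: only the first beam is meant to start "live". Note that
# tf.zeros(...) * 1e-9 is still exactly zero, so all entries start at 0.0 here.
beam_scores_begin = tf.zeros((batch_size, 1), dtype=tf.float32)
beam_scores_end = tf.zeros((batch_size, num_beams - 1), dtype=tf.float32) * 1e-9
greedy_scores = tf.reshape(
    tf.concat([beam_scores_begin, beam_scores_end], -1), (batch_size * num_beams,)
)

# Sampling path (the branch this commit adds): every beam starts at score 0,
# since each beam draws its own samples independently at the first step.
sample_scores = tf.reshape(
    tf.zeros((batch_size, num_beams), dtype=tf.float32), (batch_size * num_beams,)
)

print(greedy_scores.shape, sample_scores.shape)  # (6,) (6,)
```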
@@ -790,23 +795,24 @@ class TFPreTrainedModel(tf.keras.Model, TFModelUtilsMixin):
                 # Temperature (higher temperature => more likely to sample low probability tokens)
                 if temperature != 1.0:
                     next_token_logits = next_token_logits / temperature
+
+                scores = tf.nn.log_softmax(next_token_logits, axis=-1)  # (batch_size * num_beams, vocab_size)
+                _scores = scores + tf.broadcast_to(
+                    beam_scores[:, None], (batch_size * num_beams, vocab_size)
+                )  # (batch_size * num_beams, vocab_size)
+
                 # Top-p/top-k filtering
-                next_token_logits = tf_top_k_top_p_filtering(
-                    next_token_logits, top_k=top_k, top_p=top_p, min_tokens_to_keep=2
+                _scores = tf_top_k_top_p_filtering(
+                    _scores, top_k=top_k, top_p=top_p, min_tokens_to_keep=2
                 )  # (batch_size * num_beams, vocab_size)
                 # Sample 2 next tokens for each beam (so we have some spare tokens and match output of greedy beam search)
+                _scores = tf.reshape(_scores, (batch_size, num_beams * vocab_size))
+
                 next_tokens = tf.random.categorical(
-                    next_token_logits, dtype=tf.int32, num_samples=2
-                )  # (batch_size * num_beams, vocab_size)
+                    _scores, dtype=tf.int32, num_samples=2 * num_beams
+                )  # (batch_size, 2 * num_beams)
                 # Compute next scores
-                scores = tf.nn.log_softmax(next_token_logits, axis=-1)  # (batch_size * num_beams, vocab_size)
-                _scores = tf.gather(scores, next_tokens, batch_dims=1)  # (batch_size * num_beams, 2)
-                next_scores = _scores + tf.broadcast_to(
-                    beam_scores[:, None], (batch_size * num_beams, 2)
-                )  # (batch_size * num_beams, 2)
-                # Match shape of greedy beam search
-                next_tokens = tf.reshape(next_tokens, (batch_size, 2 * num_beams))  # (batch_size, 2 * num_beams)
-                next_scores = tf.reshape(next_scores, (batch_size, 2 * num_beams))  # (batch_size, 2 * num_beams)
+                next_scores = tf.gather(_scores, next_tokens, batch_dims=1)  # (batch_size, 2 * num_beams)
             else:
                 # do greedy beam search
                 scores = tf.nn.log_softmax(next_token_logits, axis=-1)  # (batch_size * num_beams, vocab_size)
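To see what the new branch computes, here is a self-contained sketch of the corrected sampling step with toy dimensions (batch_size, num_beams, and vocab_size below are illustrative, not from the commit). The transformers helper tf_top_k_top_p_filtering is omitted; in the real code it runs on _scores between the broadcast and the reshape:

```python
import tensorflow as tf

batch_size, num_beams, vocab_size = 2, 3, 5  # toy dimensions for illustration

next_token_logits = tf.random.normal((batch_size * num_beams, vocab_size))
beam_scores = tf.zeros((batch_size * num_beams,), dtype=tf.float32)

# Fold each beam's running score into its per-token log-probabilities, so
# sampling is over full hypothesis scores, not single-step probabilities.
scores = tf.nn.log_softmax(next_token_logits, axis=-1)
_scores = scores + tf.broadcast_to(beam_scores[:, None], (batch_size * num_beams, vocab_size))

# (tf_top_k_top_p_filtering would be applied to _scores here in the real code.)

# Flatten the beams into one categorical distribution per batch item and draw
# 2 * num_beams candidates, matching the candidate count of greedy beam search.
_scores = tf.reshape(_scores, (batch_size, num_beams * vocab_size))
next_tokens = tf.random.categorical(_scores, dtype=tf.int32, num_samples=2 * num_beams)
next_scores = tf.gather(_scores, next_tokens, batch_dims=1)

# Each flat index encodes a (beam, token) pair; downstream code recovers both.
beam_ids = next_tokens // vocab_size  # which beam each candidate came from
token_ids = next_tokens % vocab_size  # which vocab token was sampled

print(next_tokens.shape, next_scores.shape)  # (2, 6) (2, 6)
```

Sampling from the joint (num_beams * vocab_size) distribution is the substantive fix: before this commit, two tokens were sampled per beam independently and beam scores were only added afterwards, so every beam contributed candidates regardless of how poorly it was scoring.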