Unverified Commit 439f1cab authored by Patrick von Platen's avatar Patrick von Platen Committed by GitHub
Browse files

[Generate] beam search should generate without replacement (#4845)

* fix flaky beam search

* fix typo
parent c0554776
...@@ -1145,8 +1145,8 @@ class TFPreTrainedModel(tf.keras.Model, TFModelUtilsMixin): ...@@ -1145,8 +1145,8 @@ class TFPreTrainedModel(tf.keras.Model, TFModelUtilsMixin):
# Sample 2 next tokens for each beam (so we have some spare tokens and match output of greedy beam search) # Sample 2 next tokens for each beam (so we have some spare tokens and match output of greedy beam search)
_scores = tf.reshape(_scores, (batch_size, num_beams * vocab_size)) _scores = tf.reshape(_scores, (batch_size, num_beams * vocab_size))
next_tokens = tf.random.categorical( next_tokens = sample_without_replacement(
_scores, dtype=tf.int32, num_samples=2 * num_beams _scores, num_samples=2 * num_beams
) # (batch_size, 2 * num_beams) ) # (batch_size, 2 * num_beams)
# Compute next scores # Compute next scores
next_scores = tf.gather(_scores, next_tokens, batch_dims=1) # (batch_size, 2 * num_beams) next_scores = tf.gather(_scores, next_tokens, batch_dims=1) # (batch_size, 2 * num_beams)
...@@ -1736,6 +1736,17 @@ def shape_list(x): ...@@ -1736,6 +1736,17 @@ def shape_list(x):
return [dynamic[i] if s is None else s for i, s in enumerate(static)] return [dynamic[i] if s is None else s for i, s in enumerate(static)]
def sample_without_replacement(logits, num_samples):
"""
categorical sampling witouth replacement is currently not implemented
the gumbel-max trick will do for now
see https://github.com/tensorflow/tensorflow/issues/9260 for more info
"""
z = -tf.math.log(tf.random.uniform(shape_list(logits), 0, 1))
_, indices = tf.nn.top_k(logits + z, num_samples)
return indices
def get_initializer(initializer_range=0.02): def get_initializer(initializer_range=0.02):
"""Creates a `tf.initializers.truncated_normal` with the given range. """Creates a `tf.initializers.truncated_normal` with the given range.
Args: Args:
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment