Commit 10989715 authored by Patrick von Platen's avatar Patrick von Platen
Browse files

rename variable

parent cf062905
......@@ -990,7 +990,7 @@ class TFPreTrainedModel(tf.keras.Model, TFModelUtilsMixin):
next_scores, (batch_size, num_beams * vocab_size)
) # (batch_size, num_beams * vocab_size)
next_scores, next_tokens = tf.math.top_k(next_scores, 2 * num_beams, sorted=True)
next_scores, next_tokens = tf.math.top_k(next_scores, k=2 * num_beams, sorted=True)
assert shape_list(next_scores) == shape_list(next_tokens) == [batch_size, 2 * num_beams]
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment