Unverified Commit fd9aa82b authored by Joao Gante's avatar Joao Gante Committed by GitHub
Browse files

TF: Fix generation repetition penalty with XLA (#18648)

parent 81ab1112
...@@ -262,9 +262,11 @@ class TFRepetitionPenaltyLogitsProcessor(TFLogitsProcessor): ...@@ -262,9 +262,11 @@ class TFRepetitionPenaltyLogitsProcessor(TFLogitsProcessor):
# Scatters the penalties # Scatters the penalties
token_penalties = tf.ones(logits.shape) token_penalties = tf.ones(logits.shape)
batch_size = input_ids.shape[0]
seq_len = tf.shape(input_ids)[1] # the sequence length has dynamic size, hence the dynamic shape
indexable_prev_input_ids = tf.concat( indexable_prev_input_ids = tf.concat(
( (
tf.expand_dims(tf.repeat(tf.range(input_ids.shape[0]), input_ids.shape[1]), axis=-1), tf.expand_dims(tf.repeat(tf.range(batch_size), seq_len), axis=-1),
tf.expand_dims(tf.reshape(input_ids, [-1]), axis=-1), tf.expand_dims(tf.reshape(input_ids, [-1]), axis=-1),
), ),
axis=1, axis=1,
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment