"docs/vscode:/vscode.git/clone" did not exist on "8cf4a6f0a63ed3aeed68192a9304fed2bd0ce100"
Unverified Commit fd9aa82b authored by Joao Gante's avatar Joao Gante Committed by GitHub
Browse files

TF: Fix generation repetition penalty with XLA (#18648)

parent 81ab1112
......@@ -262,9 +262,11 @@ class TFRepetitionPenaltyLogitsProcessor(TFLogitsProcessor):
# Scatters the penalties
token_penalties = tf.ones(logits.shape)
batch_size = input_ids.shape[0]
seq_len = tf.shape(input_ids)[1] # the sequence length has dynamic size, hence the dynamic shape
indexable_prev_input_ids = tf.concat(
(
tf.expand_dims(tf.repeat(tf.range(input_ids.shape[0]), input_ids.shape[1]), axis=-1),
tf.expand_dims(tf.repeat(tf.range(batch_size), seq_len), axis=-1),
tf.expand_dims(tf.reshape(input_ids, [-1]), axis=-1),
),
axis=1,
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment