Unverified commit 99c8226b authored by Joao Gante, committed by GitHub

TF: XLA repetition penalty (#16879)

parent ec81c11a
@@ -241,18 +241,29 @@ class TFRepetitionPenaltyLogitsProcessor(TFLogitsProcessor):
         self.penalty = penalty
 
-    def _create_score_penalties(self, input_ids, logits):
-        # create logit penalties for already seen input_ids
-        token_penalties = np.ones(logits.shape)
-        prev_input_ids = [np.unique(input_id) for input_id in input_ids.numpy()]
-        for i, prev_input_id in enumerate(prev_input_ids):
-            logit_penalized = logits[i].numpy()[prev_input_id]
-            logit_penalties = np.zeros(logit_penalized.shape)
-            # if previous logit score is < 0 then multiply repetition penalty else divide
-            logit_penalties[logit_penalized < 0] = self.penalty
-            logit_penalties[logit_penalized > 0] = 1 / self.penalty
-            np.put(token_penalties[i], prev_input_id, logit_penalties)
-        return tf.convert_to_tensor(token_penalties, dtype=tf.float32)
+    def _create_score_penalties(self, input_ids: tf.Tensor, logits: tf.Tensor) -> tf.Tensor:
+        # We want to populate the penalties in the positions of `input_ids`. Since XLA can't handle shapes unknown
+        # before runtime, `tf.unique` can't be used. Therefore, we may have redundant updates, when a given row has
+        # the same token multiple times.
+
+        # Gathers the penalties to apply
+        logit_penalties = tf.gather(logits, input_ids, axis=1, batch_dims=1)
+        logit_penalties = tf.where(logit_penalties > 0, 1 / self.penalty, logit_penalties)
+        logit_penalties = tf.where(logit_penalties < 0, self.penalty, logit_penalties)
+
+        # Scatters the penalties
+        token_penalties = tf.ones(logits.shape)
+        indexable_prev_input_ids = tf.concat(
+            (
+                tf.expand_dims(tf.repeat(tf.range(input_ids.shape[0]), input_ids.shape[1]), axis=-1),
+                tf.expand_dims(tf.reshape(input_ids, [-1]), axis=-1),
+            ),
+            axis=1,
+        )
+        token_penalties = tf.tensor_scatter_nd_update(
+            token_penalties, indices=indexable_prev_input_ids, updates=tf.reshape(logit_penalties, [-1])
+        )
+        return token_penalties
 
     def __call__(self, input_ids: tf.Tensor, scores: tf.Tensor, cur_len: int) -> tf.Tensor:
         score_penalties = self._create_score_penalties(input_ids[:, :cur_len], scores)
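For reference, here is a minimal standalone sketch (not part of the commit) of the gather/scatter pattern used above. The toy `input_ids`, `logits`, and the `tf.stack`-based index construction are illustrative assumptions; the diff builds the same indices with `tf.concat`/`tf.expand_dims`.

```python
import tensorflow as tf

# Toy inputs: batch of 2 sequences over a vocabulary of 4 tokens.
# Token 3 appears twice in row 1, so it receives a redundant (but harmless) update.
penalty = 2.0
input_ids = tf.constant([[1, 0], [3, 3]])
logits = tf.constant([[-1.0, 2.0, 0.5, 0.1],
                      [0.2, -0.3, 0.4, -0.5]])

# Gather the logits of already-seen tokens and turn them into multiplicative penalties:
# positive logits are divided by the penalty, negative logits are multiplied by it.
logit_penalties = tf.gather(logits, input_ids, axis=1, batch_dims=1)
logit_penalties = tf.where(logit_penalties > 0, 1 / penalty, logit_penalties)
logit_penalties = tf.where(logit_penalties < 0, penalty, logit_penalties)

# Build (batch, token) index pairs with shapes known at trace time, then scatter the
# penalties into a tensor of ones.
batch_idx = tf.repeat(tf.range(input_ids.shape[0]), input_ids.shape[1])
indices = tf.stack([batch_idx, tf.reshape(input_ids, [-1])], axis=1)
token_penalties = tf.tensor_scatter_nd_update(
    tf.ones(logits.shape), indices, tf.reshape(logit_penalties, [-1])
)
print(token_penalties)  # 1.0 everywhere except at the seen token ids
```

Because every shape here is known when the function is traced, the whole computation can be compiled with XLA, unlike the removed NumPy version, which relied on `.numpy()` and `np.unique` and therefore required eager execution.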
@@ -101,7 +101,7 @@ class TFLogitsProcessorTest(unittest.TestCase):
         self.assertGreater(tf.math.reduce_max(probs[1, :]), tf.math.reduce_max(warped_prob_smooth[1, :]))
         self.assertLess(tf.math.reduce_min(probs[1, :]), tf.math.reduce_min(warped_prob_smooth[1, :]))
 
-    def test_repetition_penalty_dist_process(self):
+    def _get_repetition_penalty_inputs(self):
         vocab_size = 10
         cur_len = 2
@@ -114,17 +114,31 @@ class TFLogitsProcessorTest(unittest.TestCase):
         scores = tf.where(mask, -1 / vocab_size, scores)
         mask = tf.cast(tf.constant([10 * [0], 5 * [0] + [1] + 4 * [0]]), tf.bool)
         scores = tf.where(mask, 4 / vocab_size, scores)
+        return vocab_size, cur_len, input_ids, scores
 
-        rep_penalty_proc = TFRepetitionPenaltyLogitsProcessor(penalty=2.0)
-
-        scores = rep_penalty_proc(input_ids, tf.identity(scores), cur_len)
-
-        # check that values were correctly changed
+    def _check_repetition_penalty_outputs(self, scores, vocab_size):
+        # check that values were correctly changed (negative scores for used tokens should increase, others
+        # should decrease)
         self.assertAlmostEqual(scores[0, 0].numpy(), -(1 / vocab_size) * 2)
         self.assertAlmostEqual(scores[0, 1].numpy(), (1 / vocab_size) / 2)
         self.assertAlmostEqual(scores[0, 2].numpy(), (1 / vocab_size))  # unused tokens should see no change
 
         self.assertAlmostEqual(scores[1, 0].numpy(), (1 / vocab_size) / 2)
         self.assertAlmostEqual(scores[1, 5].numpy(), (4 / vocab_size) / 2)
         self.assertAlmostEqual(scores[1, 2].numpy(), (1 / vocab_size))  # unused tokens should see no change
 
+    def test_repetition_penalty_dist_process(self):
+        vocab_size, cur_len, input_ids, scores = self._get_repetition_penalty_inputs()
+
+        rep_penalty_proc = TFRepetitionPenaltyLogitsProcessor(penalty=2.0)
+        scores = rep_penalty_proc(input_ids, tf.identity(scores), cur_len)
+        self._check_repetition_penalty_outputs(scores, vocab_size)
+
+    def test_repetition_penalty_dist_process_xla(self):
+        vocab_size, cur_len, input_ids, scores = self._get_repetition_penalty_inputs()
+
+        rep_penalty_proc = TFRepetitionPenaltyLogitsProcessor(penalty=2.0)
+        rep_penalty_proc = tf.function(rep_penalty_proc, jit_compile=True)  # added line wrt non-XLA test
+        scores = rep_penalty_proc(input_ids, tf.identity(scores), cur_len)
+        self._check_repetition_penalty_outputs(scores, vocab_size)
+
     def test_top_k_dist_warper(self):
         input_ids = None
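The XLA test above just wraps the processor's `__call__` with `tf.function(jit_compile=True)`. A hypothetical usage sketch along the same lines follows; it assumes `TFRepetitionPenaltyLogitsProcessor` is importable from the top-level `transformers` package and uses toy shapes.

```python
import tensorflow as tf
from transformers import TFRepetitionPenaltyLogitsProcessor

# Build the processor as in the tests, then wrap it so it is traced and XLA-compiled.
processor = TFRepetitionPenaltyLogitsProcessor(penalty=2.0)
xla_processor = tf.function(processor, jit_compile=True)

input_ids = tf.constant([[1, 0], [5, 0]])  # tokens generated so far (batch of 2)
scores = tf.random.uniform((2, 10), minval=-1.0, maxval=1.0)  # next-token logits
new_scores = xla_processor(input_ids, scores, cur_len=2)
print(new_scores.shape)  # (2, 10)
```

The old NumPy-based implementation could not be wrapped this way because it called `.numpy()` on tensors, which is not available while tracing a compiled function.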