Unverified Commit 99c8226b authored by Joao Gante's avatar Joao Gante Committed by GitHub
Browse files

TF: XLA repetition penalty (#16879)

parent ec81c11a
...@@ -241,18 +241,29 @@ class TFRepetitionPenaltyLogitsProcessor(TFLogitsProcessor): ...@@ -241,18 +241,29 @@ class TFRepetitionPenaltyLogitsProcessor(TFLogitsProcessor):
self.penalty = penalty self.penalty = penalty
def _create_score_penalties(self, input_ids, logits): def _create_score_penalties(self, input_ids: tf.Tensor, logits: tf.Tensor) -> tf.Tensor:
# create logit penalties for already seen input_ids # We want to populate the penalties in the positions of `input_ids`. Since XLA can't handle shapes unknown
token_penalties = np.ones(logits.shape) # before runtime, `tf.unique` can't be used. Therefore, we may have redundant updates, when a given row has
prev_input_ids = [np.unique(input_id) for input_id in input_ids.numpy()] # the same token multiple times.
for i, prev_input_id in enumerate(prev_input_ids):
logit_penalized = logits[i].numpy()[prev_input_id] # Gathers the penalties to apply
logit_penalties = np.zeros(logit_penalized.shape) logit_penalties = tf.gather(logits, input_ids, axis=1, batch_dims=1)
# if previous logit score is < 0 then multiply repetition penalty else divide logit_penalties = tf.where(logit_penalties > 0, 1 / self.penalty, logit_penalties)
logit_penalties[logit_penalized < 0] = self.penalty logit_penalties = tf.where(logit_penalties < 0, self.penalty, logit_penalties)
logit_penalties[logit_penalized > 0] = 1 / self.penalty
np.put(token_penalties[i], prev_input_id, logit_penalties) # Scatters the penalties
return tf.convert_to_tensor(token_penalties, dtype=tf.float32) token_penalties = tf.ones(logits.shape)
indexable_prev_input_ids = tf.concat(
(
tf.expand_dims(tf.repeat(tf.range(input_ids.shape[0]), input_ids.shape[1]), axis=-1),
tf.expand_dims(tf.reshape(input_ids, [-1]), axis=-1),
),
axis=1,
)
token_penalties = tf.tensor_scatter_nd_update(
token_penalties, indices=indexable_prev_input_ids, updates=tf.reshape(logit_penalties, [-1])
)
return token_penalties
def __call__(self, input_ids: tf.Tensor, scores: tf.Tensor, cur_len: int) -> tf.Tensor: def __call__(self, input_ids: tf.Tensor, scores: tf.Tensor, cur_len: int) -> tf.Tensor:
score_penalties = self._create_score_penalties(input_ids[:, :cur_len], scores) score_penalties = self._create_score_penalties(input_ids[:, :cur_len], scores)
......
...@@ -101,7 +101,7 @@ class TFLogitsProcessorTest(unittest.TestCase): ...@@ -101,7 +101,7 @@ class TFLogitsProcessorTest(unittest.TestCase):
self.assertGreater(tf.math.reduce_max(probs[1, :]), tf.math.reduce_max(warped_prob_smooth[1, :])) self.assertGreater(tf.math.reduce_max(probs[1, :]), tf.math.reduce_max(warped_prob_smooth[1, :]))
self.assertLess(tf.math.reduce_min(probs[1, :]), tf.math.reduce_min(warped_prob_smooth[1, :])) self.assertLess(tf.math.reduce_min(probs[1, :]), tf.math.reduce_min(warped_prob_smooth[1, :]))
def test_repetition_penalty_dist_process(self): def _get_repetition_penalty_inputs(self):
vocab_size = 10 vocab_size = 10
cur_len = 2 cur_len = 2
...@@ -114,17 +114,31 @@ class TFLogitsProcessorTest(unittest.TestCase): ...@@ -114,17 +114,31 @@ class TFLogitsProcessorTest(unittest.TestCase):
scores = tf.where(mask, -1 / vocab_size, scores) scores = tf.where(mask, -1 / vocab_size, scores)
mask = tf.cast(tf.constant([10 * [0], 5 * [0] + [1] + 4 * [0]]), tf.bool) mask = tf.cast(tf.constant([10 * [0], 5 * [0] + [1] + 4 * [0]]), tf.bool)
scores = tf.where(mask, 4 / vocab_size, scores) scores = tf.where(mask, 4 / vocab_size, scores)
return vocab_size, cur_len, input_ids, scores
rep_penalty_proc = TFRepetitionPenaltyLogitsProcessor(penalty=2.0) def _check_repetition_penalty_outputs(self, scores, vocab_size):
# check that values were correctly changed (negative scores for used tokens should increase, others
scores = rep_penalty_proc(input_ids, tf.identity(scores), cur_len) # should decrease)
# check that values were correctly changed
self.assertAlmostEqual(scores[0, 0].numpy(), -(1 / vocab_size) * 2) self.assertAlmostEqual(scores[0, 0].numpy(), -(1 / vocab_size) * 2)
self.assertAlmostEqual(scores[0, 1].numpy(), (1 / vocab_size) / 2) self.assertAlmostEqual(scores[0, 1].numpy(), (1 / vocab_size) / 2)
self.assertAlmostEqual(scores[0, 2].numpy(), (1 / vocab_size)) # unused tokens should see no change
self.assertAlmostEqual(scores[1, 0].numpy(), (1 / vocab_size) / 2) self.assertAlmostEqual(scores[1, 0].numpy(), (1 / vocab_size) / 2)
self.assertAlmostEqual(scores[1, 5].numpy(), (4 / vocab_size) / 2) self.assertAlmostEqual(scores[1, 5].numpy(), (4 / vocab_size) / 2)
self.assertAlmostEqual(scores[0, 2].numpy(), (1 / vocab_size)) # unused tokens should see no change
def test_repetition_penalty_dist_process(self):
vocab_size, cur_len, input_ids, scores = self._get_repetition_penalty_inputs()
rep_penalty_proc = TFRepetitionPenaltyLogitsProcessor(penalty=2.0)
scores = rep_penalty_proc(input_ids, tf.identity(scores), cur_len)
self._check_repetition_penalty_outputs(scores, vocab_size)
def test_repetition_penalty_dist_process_xla(self):
vocab_size, cur_len, input_ids, scores = self._get_repetition_penalty_inputs()
rep_penalty_proc = TFRepetitionPenaltyLogitsProcessor(penalty=2.0)
rep_penalty_proc = tf.function(rep_penalty_proc, jit_compile=True) # added line wrt non-XLA test
scores = rep_penalty_proc(input_ids, tf.identity(scores), cur_len)
self._check_repetition_penalty_outputs(scores, vocab_size)
def test_top_k_dist_warper(self): def test_top_k_dist_warper(self):
input_ids = None input_ids = None
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment