Unverified Commit 809dac48 authored by Joao Gante's avatar Joao Gante Committed by GitHub
Browse files

TF: XLA logits processors - minimum length, forced eos, and forced bos (#16912)



* XLA min len, forced eos, and forced bos
Co-authored-by: default avatarMatt <Rocketknight1@users.noreply.github.com>
parent f6210c49
......@@ -215,13 +215,18 @@ class TFMinLengthLogitsProcessor(TFLogitsProcessor):
self.min_length = min_length
self.eos_token_id = eos_token_id
def __call__(self, input_ids: tf.Tensor, scores: tf.Tensor, cur_len: int) -> tf.Tensor:
# TODO(Matt) - this if statement has to be rewritten for XLA. Leaving it now though since
# generate is not XLA - compileable anyways
if cur_len < self.min_length:
eos_token_id_mask = tf.broadcast_to(tf.range(scores.shape[-1]) == self.eos_token_id, scores.shape)
def _apply_eos_token_mask(self, scores: tf.Tensor) -> tf.Tensor:
eos_token_id_mask = tf.range(scores.shape[-1]) == self.eos_token_id
scores = tf.where(eos_token_id_mask, float("-inf"), scores)
return scores
def __call__(self, input_ids: tf.Tensor, scores: tf.Tensor, cur_len: int) -> tf.Tensor:
# applies eos token masking if the first argument is true
scores = tf.cond(
tf.less(cur_len, self.min_length),
lambda: self._apply_eos_token_mask(scores),
lambda: tf.identity(scores),
)
return scores
......
......@@ -18,6 +18,7 @@ import unittest
import numpy as np
from parameterized import parameterized
from transformers import is_tf_available
from transformers.testing_utils import require_tf
......@@ -47,12 +48,15 @@ class TFLogitsProcessorTest(unittest.TestCase):
scores = tf.ones((batch_size, length), dtype=tf.float32) / length
return scores
def test_min_length_dist_processor(self):
@parameterized.expand([(False,), (True,)])
def test_min_length_dist_processor(self, use_xla):
vocab_size = 20
batch_size = 4
eos_token_id = 0
min_dist_processor = TFMinLengthLogitsProcessor(min_length=10, eos_token_id=eos_token_id)
if use_xla:
min_dist_processor = tf.function(min_dist_processor, jit_compile=True)
# check that min length is applied at length 5
cur_len = 5
......@@ -256,12 +260,15 @@ class TFLogitsProcessorTest(unittest.TestCase):
[[True, True, False, True, True], [True, True, True, False, True]],
)
def test_forced_bos_token_logits_processor(self):
@parameterized.expand([(False,), (True,)])
def test_forced_bos_token_logits_processor(self, use_xla):
vocab_size = 20
batch_size = 4
bos_token_id = 0
logits_processor = TFForcedBOSTokenLogitsProcessor(bos_token_id=bos_token_id)
if use_xla:
logits_processor = tf.function(logits_processor, jit_compile=True)
# check that all scores are -inf except the bos_token_id score
cur_len = 1
......@@ -280,13 +287,16 @@ class TFLogitsProcessorTest(unittest.TestCase):
scores = logits_processor(input_ids, scores, cur_len)
self.assertFalse(tf.math.reduce_any(tf.math.is_inf((scores))))
def test_forced_eos_token_logits_processor(self):
@parameterized.expand([(False,), (True,)])
def test_forced_eos_token_logits_processor(self, use_xla):
vocab_size = 20
batch_size = 4
eos_token_id = 0
max_length = 5
logits_processor = TFForcedEOSTokenLogitsProcessor(max_length=max_length, eos_token_id=eos_token_id)
if use_xla:
logits_processor = tf.function(logits_processor, jit_compile=True)
# check that all scores are -inf except the eos_token_id when max_length-1 is reached
cur_len = 4
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment