Unverified Commit fb0ae129 authored by Joao Gante, committed by GitHub

TF: XLA bad words logits processor and list of processors (#16974)

parent 57e6464a
@@ -14,7 +14,7 @@
 # limitations under the License.
 import inspect
-from typing import List
+from typing import List, Tuple

 import numpy as np
 import tensorflow as tf
@@ -38,7 +38,10 @@ TF_LOGITS_PROCESSOR_INPUTS_DOCSTRING = r"""
             [What are input IDs?](../glossary#input-ids)
         scores (`tf.Tensor` of shape `(batch_size, config.vocab_size)`):
             Prediction scores of a language modeling head. These can be logits for each vocabulary when not using beam
-            search or log softmax for each vocabulary token when using beam search
+            search or log softmax for each vocabulary token when using beam search.
+        cur_len (`int`):
+            The current length of valid input sequence tokens. In the TF implementation, the input_ids' sequence length
+            is the maximum length generate can produce, and we need to know which of its tokens are valid.
         kwargs:
             Additional logits processor specific kwargs.
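To make the padded-inputs convention above concrete, here is a minimal standalone sketch (not part of the diff; the tensor values are invented): `input_ids` is allocated at the maximum generation length, and a processor only looks at the first `cur_len` tokens of each row.

import tensorflow as tf

# Padded to the maximum length `generate` can produce (8 here); only the first
# `cur_len` tokens per row are real, the rest is placeholder padding.
input_ids = tf.constant([[5, 3, 9, 0, 0, 0, 0, 0],
                         [7, 7, 2, 0, 0, 0, 0, 0]])
cur_len = 3

valid_ids = input_ids[:, :cur_len]  # the slice a processor should actually inspect
print(valid_ids.numpy())  # only the first three tokens of each row remain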
@@ -51,7 +54,7 @@ class TFLogitsProcessor:
     """Abstract base class for all logit processors that can be applied during generation."""

     @add_start_docstrings(TF_LOGITS_PROCESSOR_INPUTS_DOCSTRING)
-    def __call__(self, input_ids: tf.Tensor, scores: tf.Tensor) -> tf.Tensor:
+    def __call__(self, input_ids: tf.Tensor, scores: tf.Tensor, cur_len: int) -> tf.Tensor:
         """TF method for processing logits."""
         raise NotImplementedError(
             f"{self.__class__} is an abstract class. Only classes inheriting this class can be called."
@@ -62,7 +65,7 @@ class TFLogitsWarper:
     """Abstract base class for all logit warpers that can be applied during generation with multinomial sampling."""

     @add_start_docstrings(TF_LOGITS_PROCESSOR_INPUTS_DOCSTRING)
-    def __call__(self, input_ids: tf.Tensor, scores: tf.Tensor) -> tf.Tensor:
+    def __call__(self, input_ids: tf.Tensor, scores: tf.Tensor, cur_len: int) -> tf.Tensor:
         """TF method for warping logits."""
         raise NotImplementedError(
             f"{self.__class__} is an abstract class. Only classes inheriting this class can be called."
@@ -77,18 +80,18 @@ class TFLogitsProcessorList(list):
     """

     @add_start_docstrings(TF_LOGITS_PROCESSOR_INPUTS_DOCSTRING)
-    def __call__(self, input_ids: tf.Tensor, scores: tf.Tensor, **kwargs) -> tf.Tensor:
+    def __call__(self, input_ids: tf.Tensor, scores: tf.Tensor, cur_len: int, **kwargs) -> tf.Tensor:
         for processor in self:
             function_args = inspect.signature(processor.__call__).parameters
-            if len(function_args) > 2:
+            if len(function_args) > 3:
                 if not all(arg in kwargs for arg in list(function_args.keys())[2:]):
                     raise ValueError(
                         f"Make sure that all the required parameters: {list(function_args.keys())} for "
                         f"{processor.__class__} are passed to the logits processor."
                     )
-                scores = processor(input_ids, scores, **kwargs)
+                scores = processor(input_ids, scores, cur_len, **kwargs)
             else:
-                scores = processor(input_ids, scores)
+                scores = processor(input_ids, scores, cur_len)
         return scores
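As a usage sketch (my own example, not part of the diff): with the unified signature, every entry in a `TFLogitsProcessorList` now receives `cur_len` positionally, so a list that mixes warpers and processors can be invoked with one call. The classes are assumed to be importable from the top-level `transformers` namespace, as they are in this era of the library.

import tensorflow as tf
from transformers import TFLogitsProcessorList, TFTemperatureLogitsWarper, TFTopKLogitsWarper

processors = TFLogitsProcessorList(
    [
        TFTemperatureLogitsWarper(temperature=0.7),
        TFTopKLogitsWarper(top_k=5),
    ]
)

batch_size, vocab_size, max_len = 2, 50, 16
input_ids = tf.zeros((batch_size, max_len), dtype=tf.int32)  # padded inputs
scores = tf.random.normal((batch_size, vocab_size))
cur_len = 4  # number of valid tokens in `input_ids`

# `cur_len` is forwarded to every processor in the list.
new_scores = processors(input_ids, scores, cur_len)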
@@ -107,7 +110,7 @@ class TFTemperatureLogitsWarper(TFLogitsWarper):
         self.temperature = temperature

-    def __call__(self, input_ids: tf.Tensor, scores: tf.Tensor) -> tf.Tensor:
+    def __call__(self, input_ids: tf.Tensor, scores: tf.Tensor, cur_len: int) -> tf.Tensor:
         scores = scores / self.temperature
         return scores
@@ -133,7 +136,7 @@ class TFTopKLogitsWarper(TFLogitsWarper):
         self.filter_value = filter_value
         self.min_tokens_to_keep = min_tokens_to_keep

-    def __call__(self, input_ids: tf.Tensor, scores: tf.Tensor) -> tf.Tensor:
+    def __call__(self, input_ids: tf.Tensor, scores: tf.Tensor, cur_len: int) -> tf.Tensor:
         top_k = min(max(self.top_k, self.min_tokens_to_keep), scores.shape[-1])  # Safety check
         # Boolean mask containing all tokens with a probability less than the last token of the top-k
         indices_to_remove = scores < tf.math.top_k(scores, k=top_k)[0][..., -1:]
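For intuition, a toy example of the masking trick used above (my own illustration, not part of the change): each score is compared against the k-th largest value of its row, and everything strictly below it is pushed to the filter value.

import tensorflow as tf

scores = tf.constant([[0.1, 2.0, -1.0, 3.0, 0.5]])
top_k = 2
filter_value = -float("inf")

# Keep the 2 largest logits (3.0 and 2.0); everything below the 2nd-largest is masked.
kth_largest = tf.math.top_k(scores, k=top_k)[0][..., -1:]
indices_to_remove = scores < kth_largest
filtered = tf.where(indices_to_remove, filter_value, scores)
# filtered -> [[-inf, 2.0, -inf, 3.0, -inf]]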
@@ -163,7 +166,7 @@ class TFTopPLogitsWarper(TFLogitsWarper):
         self.filter_value = filter_value
         self.min_tokens_to_keep = min_tokens_to_keep

-    def __call__(self, input_ids: tf.Tensor, scores: tf.Tensor) -> tf.Tensor:
+    def __call__(self, input_ids: tf.Tensor, scores: tf.Tensor, cur_len: int) -> tf.Tensor:
         topk_scores, topk_indices = tf.math.top_k(scores, scores.shape[-1])
         mask_scores = tf.fill(scores.shape, self.filter_value)
@@ -305,58 +308,75 @@ class TFNoBadWordsLogitsProcessor(TFLogitsProcessor):
                     f"Each list in `bad_words_ids` has to be a list of positive integers, but is {bad_words_ids}."
                 )

-        self.bad_words_ids = bad_words_ids
-
-    def calc_banned_bad_words_ids(self, prev_input_ids):
-        banned_tokens = []
-
-        def _tokens_match(prev_tokens, tokens):
-            if len(tokens) == 0:
-                # if bad word tokens is just one token always ban it
-                return True
-            if len(tokens) > len(prev_tokens):
-                # if bad word tokens are longer than prev tokens they can't be equal
-                return False
-
-            if prev_tokens[-len(tokens) :] == tokens:
-                # if tokens match
-                return True
-            else:
-                return False
-
-        for prev_input_ids_slice in prev_input_ids:
-            banned_tokens_slice = []
-            for banned_token_seq in self.bad_words_ids:
-                assert (
-                    len(banned_token_seq) > 0
-                ), f"Banned words token sequences {self.bad_words_ids} cannot have an empty list"
-
-                if _tokens_match(prev_input_ids_slice.numpy().tolist(), banned_token_seq[:-1]) is False:
-                    # if tokens do not match continue
-                    continue
-
-                banned_tokens_slice.append(banned_token_seq[-1])
-
-            banned_tokens.append(banned_tokens_slice)
-
-        return banned_tokens
+        # stores the information about bad words in three tensors:
+        # 1. a rectangular tensor with the forbidden sequences (padded with `-1`), for full data comparisons
+        self.bad_word_seqs_ids = tf.ragged.constant(bad_words_ids).to_tensor(default_value=-1)
+        # 2. a tensor with the unpadded length of each forbidden sequence, for quick length comparisons
+        bad_word_seqs_len = [len(bad_words) for bad_words in bad_words_ids]
+        if any([word_len == 0 for word_len in bad_word_seqs_len]):
+            raise ValueError(f"Banned words token sequences {bad_words_ids} cannot have an empty list")
+        self.bad_word_seqs_len = tf.convert_to_tensor(bad_word_seqs_len, dtype=tf.int32)
+        # 3. a tensor containing the last token for each sequence, for easy access to the tokens that may be banned
+        self.seq_forbidden_tokens = tf.convert_to_tensor([bad_words[-1] for bad_words in bad_words_ids])
+
+    def _calc_row_banned_bad_tokens(self, row_input_ids: tf.Tensor) -> tf.Tensor:
+        def _tokens_match(bad_word_seq_number):
+            def _len_one():
+                # If the bad sequence only has one token, always mask it
+                return tf.cond(
+                    tf.math.equal(self.bad_word_seqs_len[bad_word_seq_number], 1),
+                    lambda: tf.ones((), dtype=tf.bool),
+                    _len_greater_than_cur_len,
+                )
+
+            def _len_greater_than_cur_len():
+                # Otherwise, if the bad sequence is longer than the current length they can't ever match
+                return tf.cond(
+                    tf.math.greater(self.bad_word_seqs_len[bad_word_seq_number], row_input_ids.shape[0]),
+                    lambda: tf.zeros((), dtype=tf.bool),
+                    _match_found,
+                )
+
+            def _match_found():
+                # Finally, runs the actual comparison. Can only be called if the previous comparisons do not yield
+                # an answer (otherwise we get indexing exceptions)
+                compare_len = self.bad_word_seqs_len[bad_word_seq_number] - 1
+                return tf.cond(
+                    tf.math.reduce_all(
+                        tf.math.equal(
+                            row_input_ids[-compare_len:], self.bad_word_seqs_ids[bad_word_seq_number, :compare_len]
+                        )
+                    ),
+                    lambda: tf.ones((), dtype=tf.bool),
+                    lambda: tf.zeros((), dtype=tf.bool),
+                )
+
+            match = _len_one()
+            return match
+
+        # Compares the current row against all bad word sequences, obtaining a mask with the matches.
+        match_mask = tf.map_fn(_tokens_match, tf.range(self.bad_word_seqs_ids.shape[0]), fn_output_signature=tf.bool)
+        row_banned_tokens = self.seq_forbidden_tokens[match_mask]
+        return row_banned_tokens

     def __call__(self, input_ids: tf.Tensor, scores: tf.Tensor, cur_len: int) -> tf.Tensor:
-        vocab_size = scores.shape[-1]
-
-        # calculate a list of banned tokens according to bad words
-        banned_tokens = self.calc_banned_bad_words_ids(input_ids[:, :cur_len])
-
-        banned_tokens_indices_mask = []
-        for banned_tokens_slice in banned_tokens:
-            banned_tokens_indices_mask.append(
-                [True if token in banned_tokens_slice else False for token in range(vocab_size)]
-            )
-
-        scores = tf.where(tf.convert_to_tensor(banned_tokens_indices_mask, dtype=tf.bool), -float("inf"), scores)
+        # We want to mask some banned tokens, at a score level. Since the banned tokens depend on the previous
+        # `input_ids`, they may have a different length for each row, and they may even be empty for some rows.
+        # To remain simple and XLA-compatible, we work in a per-row fashion.
+        # TODO (Joao): this function might trigger XLA retracing as `cur_len` increases. Fix it if it becomes
+        # a frequent choke point. (make `cur_len` a tensor?)
+        def _get_row_updated_score(row_inputs: Tuple[tf.Tensor]) -> tf.Tensor:
+            row_input_ids, row_score = row_inputs
+            banned_tokens = self._calc_row_banned_bad_tokens(row_input_ids[:cur_len])
+            banned_tokens_mask = tf.scatter_nd(
+                indices=tf.expand_dims(banned_tokens, axis=-1),
+                updates=tf.ones_like(banned_tokens, dtype=tf.bool),
+                shape=row_score.shape,
+            )
+            row_score = tf.where(banned_tokens_mask, -float("inf"), row_score)
+            return row_score
+
+        scores = tf.map_fn(_get_row_updated_score, (input_ids, scores), fn_output_signature=tf.float32)
         return scores
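To see what the three constructor tensors hold, and how a row's banned tokens become a vocabulary mask, here is a small standalone sketch of my own (all values invented; `banned_tokens` is hardcoded rather than computed by the matching logic above):

import tensorflow as tf

bad_words_ids = [[4], [7, 2], [9, 3, 5]]

# 1. forbidden sequences, right-padded with -1 into a rectangular tensor
bad_word_seqs_ids = tf.ragged.constant(bad_words_ids).to_tensor(default_value=-1)
# rows: [4, -1, -1], [7, 2, -1], [9, 3, 5]

# 2. unpadded length of each forbidden sequence
bad_word_seqs_len = tf.constant([len(seq) for seq in bad_words_ids])  # [1, 2, 3]

# 3. last token of each sequence: the candidate tokens that may get banned
seq_forbidden_tokens = tf.constant([seq[-1] for seq in bad_words_ids])  # [4, 2, 5]

# Per-row banned tokens are then turned into a boolean mask over the vocabulary.
vocab_size = 12
banned_tokens = tf.constant([4, 2])  # pretend these sequences matched this row
banned_mask = tf.scatter_nd(
    indices=tf.expand_dims(banned_tokens, axis=-1),
    updates=tf.ones_like(banned_tokens, dtype=tf.bool),
    shape=(vocab_size,),
)
row_score = tf.where(banned_mask, -float("inf"), tf.zeros((vocab_size,)))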
@@ -401,6 +421,11 @@ class TFNoRepeatNGramLogitsProcessor(TFLogitsProcessor):
     def __call__(self, input_ids: tf.Tensor, scores: tf.Tensor, cur_len: int) -> tf.Tensor:
+        # TODO (joao): enable XLA on this logits processor. See discussion and attempts in
+        # https://github.com/huggingface/transformers/pull/16974
+        if not tf.executing_eagerly():
+            raise NotImplementedError("TFNoRepeatNGramLogitsProcessor is only implemented for eager execution.")
+
         batch_size, vocab_size = scores.shape
         banned_tokens = self.calc_banned_ngram_tokens(input_ids, batch_size, cur_len)
...
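The guard added above makes the n-gram processor fail fast when traced: `tf.executing_eagerly()` returns `False` inside a `tf.function`, so graph/XLA users get a clear error instead of silently wrong behavior. A quick sketch of the mechanism it relies on (my own illustration, with an invented function name):

import tensorflow as tf

def apply_processor(scores):
    if not tf.executing_eagerly():
        raise NotImplementedError("only implemented for eager execution")
    return scores  # eager-only work would happen here

apply_processor(tf.zeros((2, 5)))  # fine in eager mode

xla_fn = tf.function(apply_processor, jit_compile=True)
try:
    xla_fn(tf.zeros((2, 5)))  # raises at trace time
except NotImplementedError as err:
    print(err)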
@@ -2030,7 +2030,7 @@ class TFGenerationMixin:
             if not use_xla:
                 input_ids = tf.reshape(generated.concat(), (-1, batch_size))
                 input_ids = tf.transpose(input_ids[: current_pos[0]])
-            next_tokens_scores = logits_processor(input_ids, next_token_logits, cur_len=current_pos[0])
+            next_tokens_scores = logits_processor(input_ids, next_token_logits, current_pos[0])

             # argmax
             next_tokens = tf.argmax(next_tokens_scores, axis=-1, output_type=tf.int32)
@@ -2301,8 +2301,8 @@ class TFGenerationMixin:
             if not use_xla:
                 input_ids = tf.reshape(generated.concat(), (-1, batch_size))
                 input_ids = tf.transpose(input_ids[:cur_len])
-            next_tokens_scores = logits_processor(input_ids, next_token_logits, cur_len=cur_len)
-            next_tokens_scores = logits_warper(input_ids, next_tokens_scores)
+            next_tokens_scores = logits_processor(input_ids, next_token_logits, cur_len)
+            next_tokens_scores = logits_warper(input_ids, next_tokens_scores, cur_len)

             # sample
             if seed is not None:
@@ -2726,7 +2726,7 @@ class TFGenerationMixin:
             # add new logprobs to existing running logprobs scores.
             log_probs = tf.nn.log_softmax(logits)
             log_probs = logits_processor(
-                flatten_beam_dim(running_sequences_seq_last), flatten_beam_dim(log_probs), cur_len=cur_len
+                flatten_beam_dim(running_sequences_seq_last), flatten_beam_dim(log_probs), cur_len
             )
             log_probs = unflatten_beam_dim(log_probs, batch_size, num_beams)
             log_probs = log_probs + tf.expand_dims(running_scores, axis=2)
...
@@ -75,6 +75,7 @@ class TFLogitsProcessorTest(unittest.TestCase):
     @parameterized.expand([(False,), (True,)])
     def test_temperature_dist_warper(self, use_xla):
         input_ids = None
+        cur_len = None
         length = 20

         scores = self._get_uniform_logits(batch_size=2, length=length)
@@ -94,8 +95,8 @@ class TFLogitsProcessorTest(unittest.TestCase):
             temp_dist_warper_sharper = tf.function(temp_dist_warper_sharper, jit_compile=True)
             temp_dist_warper_smoother = tf.function(temp_dist_warper_smoother, jit_compile=True)

-        warped_prob_sharp = tf.nn.softmax(temp_dist_warper_sharper(input_ids, tf.identity(scores)), axis=-1)
-        warped_prob_smooth = tf.nn.softmax(temp_dist_warper_smoother(input_ids, tf.identity(scores)), axis=-1)
+        warped_prob_sharp = tf.nn.softmax(temp_dist_warper_sharper(input_ids, tf.identity(scores), cur_len), axis=-1)
+        warped_prob_smooth = tf.nn.softmax(temp_dist_warper_smoother(input_ids, tf.identity(scores), cur_len), axis=-1)

         # uniform distribution stays uniform
         tf.debugging.assert_near(probs[0, :], warped_prob_sharp[0, :], atol=1e-3)
@@ -142,6 +143,7 @@ class TFLogitsProcessorTest(unittest.TestCase):
     @parameterized.expand([(False,), (True,)])
     def test_top_k_dist_warper(self, use_xla):
         input_ids = None
+        cur_len = None
         vocab_size = 10
         batch_size = 2
@@ -153,7 +155,7 @@ class TFLogitsProcessorTest(unittest.TestCase):
         if use_xla:
             top_k_warp = tf.function(top_k_warp, jit_compile=True)
-        scores = top_k_warp(input_ids, ramp_logits)
+        scores = top_k_warp(input_ids, ramp_logits, cur_len)

         # check that correct tokens are filtered
         self.assertListEqual(tf.math.is_inf(scores[0]).numpy().tolist(), 7 * [True] + 3 * [False])
@@ -167,12 +169,12 @@ class TFLogitsProcessorTest(unittest.TestCase):
         if use_xla:
             top_k_warp_safety_check = tf.function(top_k_warp_safety_check, jit_compile=True)
-        scores = top_k_warp_safety_check(input_ids, logits)
+        scores = top_k_warp_safety_check(input_ids, logits, cur_len)

         # uniform dist is not changed
         self.assertListEqual(tf.math.reduce_sum(tf.where(scores == 0.0, 1, 0), axis=-1).numpy().tolist(), [0, 0])

         ramp_logits = np.broadcast_to(np.arange(length, dtype=np.float32), (batch_size, length)).copy()
-        scores = top_k_warp_safety_check(input_ids, ramp_logits)
+        scores = top_k_warp_safety_check(input_ids, ramp_logits, cur_len)

         # min_tokens overwrites k: 3 tokens are kept => 2 tokens are nullified
         self.assertListEqual(tf.math.reduce_sum(tf.where(scores == 0.0, 1, 0), axis=-1).numpy().tolist(), [2, 2])
@@ -180,6 +182,7 @@ class TFLogitsProcessorTest(unittest.TestCase):
     @parameterized.expand([(False,), (True,)])
     def test_top_p_dist_warper(self, use_xla):
         input_ids = None
+        cur_len = None
         vocab_size = 10
         batch_size = 2
@@ -189,7 +192,7 @@ class TFLogitsProcessorTest(unittest.TestCase):
         top_p_warp = TFTopPLogitsWarper(0.7)
         if use_xla:
             top_p_warp = tf.function(top_p_warp, jit_compile=True)
-        filtered_dist = tf.exp(top_p_warp(input_ids, dist))
+        filtered_dist = tf.exp(top_p_warp(input_ids, dist, cur_len))

         # dist should be filtered to keep min num values so that sum is >= 0.7
         # exp (-inf) => 0
@@ -208,7 +211,7 @@ class TFLogitsProcessorTest(unittest.TestCase):
         top_p_warp = TFTopPLogitsWarper(0.9, min_tokens_to_keep=2, filter_value=0.0)
         if use_xla:
             top_p_warp = tf.function(top_p_warp, jit_compile=True)
-        filtered_dist = top_p_warp(input_ids, ramp_logits)
+        filtered_dist = top_p_warp(input_ids, ramp_logits, cur_len)

         # first batch should keep three tokens, second batch would keep only 1, but due to `min_tokens_to_keep=2` keeps
         # 2.
@@ -242,7 +245,8 @@ class TFLogitsProcessorTest(unittest.TestCase):
             tf.math.is_inf(filtered_scores_3_gram).numpy().tolist(), [[False, False, False], [True, False, False]]
         )

-    def test_no_bad_words_dist_processor(self):
+    @parameterized.expand([(False,), (True,)])
+    def test_no_bad_words_dist_processor(self, use_xla):
         vocab_size = 5
         batch_size = 2
         eos_token_id = 4
@@ -255,6 +259,8 @@ class TFLogitsProcessorTest(unittest.TestCase):
         scores = self._get_uniform_logits(batch_size, vocab_size)

         no_bad_words_dist_proc = TFNoBadWordsLogitsProcessor(bad_words_ids=bad_word_tokens, eos_token_id=eos_token_id)
+        if use_xla:
+            no_bad_words_dist_proc = tf.function(no_bad_words_dist_proc, jit_compile=True)

         filtered_scores = no_bad_words_dist_proc(input_ids, tf.identity(scores), cur_len)
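Outside the test suite, the same pattern works for XLA-compiling a processor instance: the processors are callables, so wrapping them in `tf.function(..., jit_compile=True)` is enough. A standalone sketch of mine (class names from the diff, data invented):

import tensorflow as tf
from transformers import TFNoBadWordsLogitsProcessor

processor = TFNoBadWordsLogitsProcessor(bad_words_ids=[[1], [2, 3]], eos_token_id=4)
xla_processor = tf.function(processor, jit_compile=True)

batch_size, vocab_size, max_len = 2, 5, 8
input_ids = tf.zeros((batch_size, max_len), dtype=tf.int32)
scores = tf.random.normal((batch_size, vocab_size))
cur_len = 4

# The single-token bad word [1] always matches, so token 1 is pushed to -inf.
filtered = xla_processor(input_ids, scores, cur_len)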
@@ -322,7 +328,9 @@ class TFLogitsProcessorTest(unittest.TestCase):
         scores = logits_processor(input_ids, scores, cur_len)
         self.assertFalse(tf.math.reduce_any(tf.math.is_inf((scores))))

-    def test_processor_list(self):
+    @parameterized.expand([(False,), (True,)])
+    def test_processor_list(self, use_xla):
+        # TODO (Joao): reintroduce TFNoRepeatNGramLogitsProcessor when it gets compatible with XLA
         batch_size = 4
         cur_len = 10
         vocab_size = 15
@@ -341,16 +349,24 @@ class TFLogitsProcessorTest(unittest.TestCase):
         rep_penalty_proc = TFRepetitionPenaltyLogitsProcessor(penalty=2.0)
         top_k_warp = TFTopKLogitsWarper(3)
         top_p_warp = TFTopPLogitsWarper(0.8)
-        no_repeat_proc = TFNoRepeatNGramLogitsProcessor(2)
+        # no_repeat_proc = TFNoRepeatNGramLogitsProcessor(2)
         no_bad_words_dist_proc = TFNoBadWordsLogitsProcessor(bad_words_ids=[[1]], eos_token_id=eos_token_id)
+        if use_xla:
+            min_dist_proc = tf.function(min_dist_proc, jit_compile=True)
+            temp_dist_warp = tf.function(temp_dist_warp, jit_compile=True)
+            rep_penalty_proc = tf.function(rep_penalty_proc, jit_compile=True)
+            top_k_warp = tf.function(top_k_warp, jit_compile=True)
+            top_p_warp = tf.function(top_p_warp, jit_compile=True)
+            # no_repeat_proc = tf.function(no_repeat_proc, jit_compile=True)
+            no_bad_words_dist_proc = tf.function(no_bad_words_dist_proc, jit_compile=True)

         # no processor list
         scores = min_dist_proc(input_ids, scores, cur_len)
-        scores = temp_dist_warp(input_ids, scores)
+        scores = temp_dist_warp(input_ids, scores, cur_len)
         scores = rep_penalty_proc(input_ids, scores, cur_len)
-        scores = top_k_warp(input_ids, scores)
-        scores = top_p_warp(input_ids, scores)
-        scores = no_repeat_proc(input_ids, scores, cur_len)
+        scores = top_k_warp(input_ids, scores, cur_len)
+        scores = top_p_warp(input_ids, scores, cur_len)
+        # scores = no_repeat_proc(input_ids, scores, cur_len)
         scores = no_bad_words_dist_proc(input_ids, scores, cur_len)

         # with processor list
@@ -361,11 +377,11 @@ class TFLogitsProcessorTest(unittest.TestCase):
                 rep_penalty_proc,
                 top_k_warp,
                 top_p_warp,
-                no_repeat_proc,
+                # no_repeat_proc,
                 no_bad_words_dist_proc,
             ]
         )
-        scores_comp = processor(input_ids, scores_comp, cur_len=cur_len)
+        scores_comp = processor(input_ids, scores_comp, cur_len)

         # remove inf
         scores = tf.where(tf.math.is_inf(scores), -1e9, scores)
...
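For completeness, a rough end-to-end usage sketch of my own (the "gpt2" checkpoint and generation settings are assumptions, not part of this commit; XLA generation may additionally want a padded tokenizer setup to avoid retracing): with this change, `bad_words_ids` can be combined with an XLA-compiled `generate`.

import tensorflow as tf
from transformers import AutoTokenizer, TFAutoModelForCausalLM

tokenizer = AutoTokenizer.from_pretrained("gpt2")
model = TFAutoModelForCausalLM.from_pretrained("gpt2")

# Token sequences that must never be generated.
bad_words_ids = tokenizer(["ugly"], add_special_tokens=False).input_ids

xla_generate = tf.function(model.generate, jit_compile=True)
inputs = tokenizer(["The movie was"], return_tensors="tf")
outputs = xla_generate(**inputs, bad_words_ids=bad_words_ids, max_length=20)
print(tokenizer.batch_decode(outputs, skip_special_tokens=True))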