Unverified Commit baab5e7c authored by Joao Gante, committed by GitHub

TF generate refactor - Sample (#15793)



* Add TF logits warpers

* Add sample method

* add tests for TF logit wrappers

* TF generate sample tests now run on CPU
Co-authored-by: Matt <Rocketknight1@users.noreply.github.com>
parent 96ae92be
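The commits above add TF equivalents of the logits warpers and a TF `sample` method. A minimal sketch of the user-facing behaviour (not part of this diff; the checkpoint name and outputs are illustrative): with `do_sample=True`, the `top_k`, `top_p` and `temperature` arguments are routed through the new warpers inside `model.generate`.

```python
# Hedged sketch, not part of this commit: exercising TF sampling after the refactor.
# "gpt2" is just an example checkpoint; outputs depend on the seed and TF version.
import tensorflow as tf
from transformers import AutoTokenizer, TFAutoModelForCausalLM

tokenizer = AutoTokenizer.from_pretrained("gpt2")
model = TFAutoModelForCausalLM.from_pretrained("gpt2")

input_ids = tokenizer("Today is a nice day and", return_tensors="tf").input_ids
tf.random.set_seed(42)  # sampling is stochastic; seed it for reproducibility
output_ids = model.generate(
    input_ids,
    do_sample=True,   # use the sampling path
    max_length=20,
    top_k=50,         # handled by TFTopKLogitsWarper
    top_p=0.9,        # handled by TFTopPLogitsWarper
    temperature=0.7,  # handled by TFTemperatureLogitsWarper
)
print(tokenizer.batch_decode(output_ids, skip_special_tokens=True))
```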
...@@ -154,6 +154,18 @@ generation.
[[autodoc]] TFLogitsProcessorList
- __call__
[[autodoc]] TFLogitsWarper
- __call__
[[autodoc]] TFTemperatureLogitsWarper
- __call__
[[autodoc]] TFTopPLogitsWarper
- __call__
[[autodoc]] TFTopKLogitsWarper
- __call__
[[autodoc]] TFMinLengthLogitsProcessor
- __call__
......
...@@ -1656,10 +1656,14 @@ if is_tf_available():
_import_structure["generation_tf_logits_process"] = [
"TFLogitsProcessor",
"TFLogitsProcessorList",
"TFLogitsWarper",
"TFMinLengthLogitsProcessor",
"TFNoBadWordsLogitsProcessor",
"TFNoRepeatNGramLogitsProcessor",
"TFRepetitionPenaltyLogitsProcessor",
"TFTemperatureLogitsWarper",
"TFTopKLogitsWarper",
"TFTopPLogitsWarper",
]
_import_structure["generation_tf_utils"] = ["tf_top_k_top_p_filtering"]
_import_structure["keras_callbacks"] = ["KerasMetricCallback", "PushToHubCallback"]
...@@ -3706,10 +3710,14 @@ if TYPE_CHECKING:
from .generation_tf_logits_process import (
TFLogitsProcessor,
TFLogitsProcessorList,
TFLogitsWarper,
TFMinLengthLogitsProcessor,
TFNoBadWordsLogitsProcessor,
TFNoRepeatNGramLogitsProcessor,
TFRepetitionPenaltyLogitsProcessor,
TFTemperatureLogitsWarper,
TFTopKLogitsWarper,
TFTopPLogitsWarper,
)
from .generation_tf_utils import tf_top_k_top_p_filtering
from .keras_callbacks import KerasMetricCallback, PushToHubCallback
......
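With the import structure updated, the new warpers are importable from the top-level `transformers` package. A quick illustrative check (assumes a TensorFlow install, otherwise the dummy objects below are used):

```python
# Illustrative check that the new symbols are exposed at the package root.
from transformers import (
    TFLogitsWarper,
    TFTemperatureLogitsWarper,
    TFTopKLogitsWarper,
    TFTopPLogitsWarper,
)

print(issubclass(TFTopPLogitsWarper, TFLogitsWarper))  # expected: True
```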
...@@ -94,7 +94,7 @@ class FlaxLogitsProcessorList(list):
class FlaxTemperatureLogitsWarper(FlaxLogitsWarper):
r"""
[`LogitsWarper`] for temperature (exponential scaling output probability distribution).
[`FlaxLogitsWarper`] for temperature (exponential scaling output probability distribution).
Args:
temperature (`float`):
...@@ -114,7 +114,7 @@ class FlaxTemperatureLogitsWarper(FlaxLogitsWarper):
class FlaxTopPLogitsWarper(FlaxLogitsWarper):
"""
[`LogitsWarper`] that performs top-p, i.e. restricting to top tokens summing to prob_cut_off <= prob_cut_off.
[`FlaxLogitsWarper`] that performs top-p, i.e. restricting to top tokens summing to prob_cut_off <= prob_cut_off.
Args:
top_p (`float`):
...@@ -155,7 +155,7 @@ class FlaxTopPLogitsWarper(FlaxLogitsWarper):
class FlaxTopKLogitsWarper(FlaxLogitsWarper):
r"""
[`LogitsWarper`] that performs top-k, i.e. restricting to the k highest probability elements.
[`FlaxLogitsWarper`] that performs top-k, i.e. restricting to the k highest probability elements.
Args:
top_k (`int`):
......
...@@ -326,7 +326,7 @@ class FlaxGenerationMixin:
raise NotImplementedError("`Beam sampling is currently not implemented.")
def _get_logits_warper(
self, top_k: int = None, top_p: float = None, temperature: float = None
self, top_k: Optional[int] = None, top_p: Optional[float] = None, temperature: Optional[float] = None
) -> FlaxLogitsProcessorList:
"""
This class returns a [`FlaxLogitsProcessorList`] list object that contains all relevant [`FlaxLogitsWarper`]
......
...@@ -58,6 +58,17 @@ class TFLogitsProcessor:
)
class TFLogitsWarper:
"""Abstract base class for all logit warpers that can be applied during generation with multinomial sampling."""
@add_start_docstrings(TF_LOGITS_PROCESSOR_INPUTS_DOCSTRING)
def __call__(self, input_ids: tf.Tensor, scores: tf.Tensor) -> tf.Tensor:
"""TF method for warping logits."""
raise NotImplementedError(
f"{self.__class__} is an abstract class. Only classes inheriting this class can be called."
)
class TFLogitsProcessorList(list):
"""
This class can be used to create a list of [`TFLogitsProcessor`] to subsequently process a `scores` input tensor.
...@@ -81,6 +92,109 @@ class TFLogitsProcessorList(list):
return scores
class TFTemperatureLogitsWarper(TFLogitsWarper):
r"""
[`TFLogitsWarper`] for temperature (exponential scaling output probability distribution).
Args:
temperature (`float`):
The value used to module the logits distribution.
"""
def __init__(self, temperature: float):
if not isinstance(temperature, float) or not (temperature > 0):
raise ValueError(f"`temperature` has to be a strictly positive float, but is {temperature}")
self.temperature = temperature
def __call__(self, input_ids: tf.Tensor, scores: tf.Tensor) -> tf.Tensor:
scores = scores / self.temperature
return scores
class TFTopKLogitsWarper(TFLogitsWarper):
r"""
[`TFLogitsWarper`] that performs top-k, i.e. restricting to the k highest probability elements.
Args:
top_k (`int`):
The number of highest probability vocabulary tokens to keep for top-k-filtering.
filter_value (`float`, *optional*, defaults to `-float("Inf")`):
All filtered values will be set to this float value.
min_tokens_to_keep (`int`, *optional*, defaults to 1):
Minimum number of tokens that cannot be filtered.
"""
def __init__(self, top_k: int, filter_value: float = -float("Inf"), min_tokens_to_keep: int = 1):
if not isinstance(top_k, int) or top_k <= 0:
raise ValueError(f"`top_k` has to be a strictly positive integer, but is {top_k}")
self.top_k = top_k
self.filter_value = filter_value
self.min_tokens_to_keep = min_tokens_to_keep
def __call__(self, input_ids: tf.Tensor, scores: tf.Tensor) -> tf.Tensor:
top_k = min(max(self.top_k, self.min_tokens_to_keep), scores.shape[-1]) # Safety check
# Boolean mask containing all tokens with a probability less than the last token of the top-k
indices_to_remove = scores < tf.math.top_k(scores, k=top_k)[0][..., -1:]
next_scores = tf.where(indices_to_remove, self.filter_value, scores)
return next_scores
class TFTopPLogitsWarper(TFLogitsWarper):
"""
[`TFLogitsWarper`] that performs top-p, i.e. restricting to top tokens summing to <= prob_cut_off.
Args:
top_p (`float`):
If set to < 1, only the most probable tokens with probabilities that add up to `top_p` or higher are kept
for generation.
filter_value (`float`, *optional*, defaults to `-float("Inf")`):
All filtered values will be set to this float value.
min_tokens_to_keep (`int`, *optional*, defaults to 1):
Minimum number of tokens that cannot be filtered.
"""
def __init__(self, top_p: float, filter_value: float = -float("Inf"), min_tokens_to_keep: int = 1):
if not isinstance(top_p, float) or (top_p < 0 or top_p > 1.0):
raise ValueError(f"`top_p` has to be a float > 0 and < 1, but is {top_p}")
self.top_p = top_p
self.filter_value = filter_value
self.min_tokens_to_keep = min_tokens_to_keep
def __call__(self, input_ids: tf.Tensor, scores: tf.Tensor) -> tf.Tensor:
topk_scores, topk_indices = tf.math.top_k(scores, scores.shape[-1])
mask_scores = tf.fill(scores.shape, self.filter_value)
cumulative_probs = tf.math.cumsum(tf.nn.softmax(topk_scores, axis=-1), axis=-1)
score_mask = cumulative_probs < self.top_p
# Also include the token that is higher than top_p (the first false = shift and insert a True on the left)
score_mask = tf.concat((tf.ones([score_mask.shape[0], 1], dtype=tf.bool), score_mask[:, :-1]), axis=-1)
# Ensure min tokens to keep
score_mask = tf.concat(
(
tf.ones([score_mask.shape[0], self.min_tokens_to_keep], dtype=tf.bool),
score_mask[:, self.min_tokens_to_keep :],
),
axis=-1,
)
# Mask the values that do not fit the criteria
topk_next_scores = tf.where(score_mask, topk_scores, mask_scores)
# Undo the topk sorting: converts the 2D matrix of per-row original indices of shape (batch_size, vocab_size)
# to a 3D tensor of shape (batch_size, vocab_size, 2) containing the original score coordinate, from which we
# can scatter (i.e. `scatter_indices[row, col, :]` is a tensor containing `[row, topk_indices[row, col]]`)
scatter_rows = tf.tile(tf.expand_dims(tf.range(topk_indices.shape[0]), axis=-1), [1, topk_indices.shape[-1]])
scatter_indices = tf.stack((scatter_rows, topk_indices), axis=-1)
next_scores = tf.scatter_nd(scatter_indices, topk_next_scores, shape=topk_next_scores.shape)
return next_scores
class TFMinLengthLogitsProcessor(TFLogitsProcessor):
r"""
[`TFLogitsProcessor`] enforcing a min-length by setting EOS probability to 0.
......
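The three warpers added above only transform the `scores` tensor, so they can be chained directly (or via `TFLogitsProcessorList`) before a multinomial draw. A self-contained sketch with made-up shapes, not taken from the diff:

```python
# Hedged sketch: chain the new warpers on dummy logits and take one sampling step.
import tensorflow as tf
from transformers import TFTemperatureLogitsWarper, TFTopKLogitsWarper, TFTopPLogitsWarper

batch_size, vocab_size = 2, 10
scores = tf.random.normal((batch_size, vocab_size))  # fake next-token logits
input_ids = None  # these warpers do not look at the input ids

for warper in (
    TFTemperatureLogitsWarper(temperature=0.7),  # sharpen the distribution
    TFTopKLogitsWarper(top_k=5),                 # keep the 5 highest-scoring tokens
    TFTopPLogitsWarper(top_p=0.9),               # then keep the smallest set covering 0.9 prob
):
    scores = warper(input_ids, scores)

# Filtered tokens hold -inf, i.e. zero probability under tf.random.categorical.
next_tokens = tf.random.categorical(scores, num_samples=1)
print(next_tokens.numpy())
```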
This diff is collapsed.
...@@ -556,8 +556,8 @@ class GenerationMixin:
input_ids: torch.LongTensor,
expand_size: int = 1,
is_encoder_decoder: bool = False,
attention_mask: torch.LongTensor = None,
encoder_outputs: ModelOutput = None,
attention_mask: Optional[torch.LongTensor] = None,
encoder_outputs: Optional[ModelOutput] = None,
**model_kwargs,
) -> Tuple[torch.LongTensor, Dict[str, Any]]:
expanded_return_idx = (
...@@ -617,11 +617,11 @@ class GenerationMixin:
def _get_logits_warper(
self,
top_k: int = None,
top_p: float = None,
typical_p: float = None,
temperature: float = None,
num_beams: int = None,
top_k: Optional[int] = None,
top_p: Optional[float] = None,
typical_p: Optional[float] = None,
temperature: Optional[float] = None,
num_beams: Optional[int] = None,
) -> LogitsProcessorList:
"""
This class returns a [`LogitsProcessorList`] list object that contains all relevant [`LogitsWarper`] instances
......
...@@ -31,6 +31,13 @@ class TFLogitsProcessorList(metaclass=DummyObject):
requires_backends(self, ["tf"])
class TFLogitsWarper(metaclass=DummyObject):
_backends = ["tf"]
def __init__(self, *args, **kwargs):
requires_backends(self, ["tf"])
class TFMinLengthLogitsProcessor(metaclass=DummyObject):
_backends = ["tf"]
...@@ -59,6 +66,27 @@ class TFRepetitionPenaltyLogitsProcessor(metaclass=DummyObject):
requires_backends(self, ["tf"])
class TFTemperatureLogitsWarper(metaclass=DummyObject):
_backends = ["tf"]
def __init__(self, *args, **kwargs):
requires_backends(self, ["tf"])
class TFTopKLogitsWarper(metaclass=DummyObject):
_backends = ["tf"]
def __init__(self, *args, **kwargs):
requires_backends(self, ["tf"])
class TFTopPLogitsWarper(metaclass=DummyObject):
_backends = ["tf"]
def __init__(self, *args, **kwargs):
requires_backends(self, ["tf"])
def tf_top_k_top_p_filtering(*args, **kwargs):
requires_backends(tf_top_k_top_p_filtering, ["tf"])
......
...@@ -51,7 +51,7 @@ class LogitsProcessorTest(unittest.TestCase):
scores = torch.ones((batch_size, length), device=torch_device, dtype=torch.float) / length
return scores
def test_min_lenght_dist_processor(self):
def test_min_length_dist_processor(self):
vocab_size = 20
batch_size = 4
eos_token_id = 0
......
...@@ -16,6 +16,8 @@
import unittest
import numpy as np
from transformers import is_tf_available
from transformers.testing_utils import require_tf
...@@ -29,6 +31,9 @@ if is_tf_available():
TFNoBadWordsLogitsProcessor,
TFNoRepeatNGramLogitsProcessor,
TFRepetitionPenaltyLogitsProcessor,
TFTemperatureLogitsWarper,
TFTopKLogitsWarper,
TFTopPLogitsWarper,
)
from transformers.tf_utils import set_tensor_by_indices_to_value
...@@ -38,7 +43,7 @@ if is_tf_available():
@require_tf
class TFLogitsProcessorTest(unittest.TestCase):
def _get_uniform_logits(self, batch_size: int, length: int):
scores = tf.ones((batch_size, length), dtype=tf.float32) / length
scores = np.ones((batch_size, length), dtype=np.float32) / length
return scores
def test_min_length_dist_processor(self):
...@@ -60,6 +65,37 @@ class TFLogitsProcessorTest(unittest.TestCase):
scores_before_min_length = min_dist_processor(input_ids, scores)
self.assertFalse(tf.math.reduce_any(tf.math.is_inf(scores_before_min_length)).numpy())
def test_temperature_dist_warper(self):
input_ids = None
length = 20
scores = self._get_uniform_logits(batch_size=2, length=length)
# tweak scores to not be uniform anymore
scores[1, 5] = (1 / length) + 0.1 # peak, 1st batch
scores[1, 10] = (1 / length) - 0.4 # valley, 1st batch
# compute softmax
probs = tf.nn.softmax(scores, axis=-1)
temp_dist_warper_sharper = TFTemperatureLogitsWarper(temperature=0.5)
temp_dist_warper_smoother = TFTemperatureLogitsWarper(temperature=1.3)
warped_prob_sharp = tf.nn.softmax(temp_dist_warper_sharper(input_ids, tf.identity(scores)), axis=-1)
warped_prob_smooth = tf.nn.softmax(temp_dist_warper_smoother(input_ids, tf.identity(scores)), axis=-1)
# uniform distribution stays uniform
tf.debugging.assert_near(probs[0, :], warped_prob_sharp[0, :], atol=1e-3)
tf.debugging.assert_near(probs[0, :], warped_prob_smooth[0, :], atol=1e-3)
# sharp peaks get higher, valleys get lower
self.assertLess(tf.math.reduce_max(probs[1, :]), tf.math.reduce_max(warped_prob_sharp[1, :]))
self.assertGreater(tf.math.reduce_min(probs[1, :]), tf.math.reduce_min(warped_prob_sharp[1, :]))
# smooth peaks get lower, valleys get higher
self.assertGreater(tf.math.reduce_max(probs[1, :]), tf.math.reduce_max(warped_prob_smooth[1, :]))
self.assertLess(tf.math.reduce_min(probs[1, :]), tf.math.reduce_min(warped_prob_smooth[1, :]))
def test_repetition_penalty_dist_process(self):
input_ids = tf.constant([[0, 1], [5, 0]], dtype=tf.int32)
vocab_size = 10
...@@ -82,6 +118,73 @@ class TFLogitsProcessorTest(unittest.TestCase):
self.assertAlmostEqual(scores[1, 0].numpy(), (1 / vocab_size) / 2)
self.assertAlmostEqual(scores[1, 5].numpy(), (4 / vocab_size) / 2)
def test_top_k_dist_warper(self):
input_ids = None
vocab_size = 10
batch_size = 2
# create ramp distribution
ramp_logits = np.broadcast_to(np.arange(vocab_size)[None, :], (batch_size, vocab_size)).copy()
ramp_logits[1:, : vocab_size // 2] = ramp_logits[1:, : vocab_size // 2] + vocab_size
top_k_warp = TFTopKLogitsWarper(3)
scores = top_k_warp(input_ids, ramp_logits)
# check that correct tokens are filtered
self.assertListEqual(tf.math.is_inf(scores[0]).numpy().tolist(), 7 * [True] + 3 * [False])
self.assertListEqual(tf.math.is_inf(scores[1]).numpy().tolist(), 2 * [True] + 3 * [False] + 5 * [True])
# check special cases
length = 5
logits = self._get_uniform_logits(batch_size=batch_size, length=length)
top_k_warp_safety_check = TFTopKLogitsWarper(top_k=1, filter_value=0.0, min_tokens_to_keep=3)
scores = top_k_warp_safety_check(input_ids, logits)
# uniform dist is not changed
self.assertListEqual(tf.math.reduce_sum(tf.where(scores == 0.0, 1, 0), axis=-1).numpy().tolist(), [0, 0])
ramp_logits = np.broadcast_to(np.arange(length)[None, :], (batch_size, length)).copy()
scores = top_k_warp_safety_check(input_ids, ramp_logits)
# min_tokens overwrites k: 3 tokens are kept => 2 tokens are nullified
self.assertListEqual(tf.math.reduce_sum(tf.where(scores == 0.0, 1, 0), axis=-1).numpy().tolist(), [2, 2])
def test_top_p_dist_warper(self):
input_ids = None
vocab_size = 10
batch_size = 2
# create distribution and take log (inverse to Softmax as taken in TFTopPLogitsWarper)
dist = np.log(np.array([[0.3, 0.1, 0.1, 0.5], [0.15, 0.3, 0.3, 0.25]], dtype=np.float32))
top_p_warp = TFTopPLogitsWarper(0.7)
filtered_dist = tf.exp(top_p_warp(input_ids, dist))
# dist should be filtered to keep min num values so that sum is >= 0.7
# exp (-inf) => 0
EXPECTED_FILTERED_DIST = tf.constant([[0.3, 0.0, 0.0, 0.5], [0.0, 0.3, 0.3, 0.25]], dtype=tf.float32)
tf.debugging.assert_near(filtered_dist, EXPECTED_FILTERED_DIST, atol=1e-3)
# check edge cases with negative and extreme logits
ramp_logits = np.broadcast_to(
np.arange(vocab_size, dtype=np.float32)[None, :], (batch_size, vocab_size)
).copy() - (vocab_size // 2)
# make ramp_logits more extreme
ramp_logits[1] = ramp_logits[1] * 100.0
# make sure at least 2 tokens are kept
top_p_warp = TFTopPLogitsWarper(0.9, min_tokens_to_keep=2, filter_value=0.0)
filtered_dist = top_p_warp(input_ids, ramp_logits)
# first batch should keep three tokens, second batch would keep only 1, but due to `min_tokens_to_keep=2` keeps
# 2.
self.assertListEqual(
tf.math.reduce_sum(tf.where(filtered_dist != 0.0, 1, 0), axis=-1).numpy().tolist(), [3, 2]
)
def test_no_repeat_ngram_dist_processor(self):
vocab_size = 3
batch_size = 2
...@@ -140,13 +243,19 @@ class TFLogitsProcessorTest(unittest.TestCase):
# instantiate all dist processors
min_dist_proc = TFMinLengthLogitsProcessor(min_length=10, eos_token_id=eos_token_id)
temp_dist_warp = TFTemperatureLogitsWarper(temperature=0.5)
rep_penalty_proc = TFRepetitionPenaltyLogitsProcessor(penalty=2.0)
top_k_warp = TFTopKLogitsWarper(3)
top_p_warp = TFTopPLogitsWarper(0.8)
no_repeat_proc = TFNoRepeatNGramLogitsProcessor(2)
no_bad_words_dist_proc = TFNoBadWordsLogitsProcessor(bad_words_ids=[[1]], eos_token_id=eos_token_id)
# no processor list
scores = min_dist_proc(input_ids, scores)
scores = temp_dist_warp(input_ids, scores)
scores = rep_penalty_proc(input_ids, scores)
scores = top_k_warp(input_ids, scores)
scores = top_p_warp(input_ids, scores)
scores = no_repeat_proc(input_ids, scores)
scores = no_bad_words_dist_proc(input_ids, scores)
...@@ -154,7 +263,10 @@ class TFLogitsProcessorTest(unittest.TestCase):
processor = TFLogitsProcessorList(
[
min_dist_proc,
temp_dist_warp,
rep_penalty_proc,
top_k_warp,
top_p_warp,
no_repeat_proc,
no_bad_words_dist_proc,
]
......
...@@ -488,9 +488,12 @@ class TFGPT2ModelLanguageGenerationTest(unittest.TestCase):
"top_k": 500,
"top_p": 0.9,
}
tf.random.set_seed(42) # deterministic sampling sequence -> deterministic generation
# forces the generation to happen on CPU, to avoid GPU-related quirks
with tf.device(":/CPU:0"):
tf.random.set_seed(42) # deterministic sampling sequence -> deterministic generation
output_ids = model.generate(input_ids, **generation_kwargs)
output_strings = tokenizer.batch_decode(output_ids, skip_special_tokens=True)
expected_output_string = [
......
...@@ -497,8 +497,10 @@ class TFT5GenerationIntegrationTests(unittest.TestCase):
"top_k": 500,
"top_p": 0.9,
}
tf.random.set_seed(42) # deterministic sampling sequence -> deterministic generation
# forces the generation to happen on CPU, to avoid GPU-related quirks
with tf.device(":/CPU:0"):
tf.random.set_seed(42) # deterministic sampling sequence -> deterministic generation
output_ids = model.generate(input_ids, **generation_kwargs)
output_strings = tokenizer.batch_decode(output_ids, skip_special_tokens=True)
......
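The two test changes above pin the sampled generation to CPU and re-seed TF's global RNG so the drawn token ids stay stable across runs. The same pattern in isolation (a hedged sketch, using the canonical `/CPU:0` device string rather than the string from the diff):

```python
# Hedged sketch of the determinism pattern used in the updated tests: run the sampling
# ops on CPU and reset the global seed before each draw so the results should match.
import tensorflow as tf

logits = tf.math.log([[0.2, 0.3, 0.5]])  # toy next-token distribution

with tf.device("/CPU:0"):
    tf.random.set_seed(42)
    draws_a = tf.random.categorical(logits, num_samples=5)
    tf.random.set_seed(42)
    draws_b = tf.random.categorical(logits, num_samples=5)

print(bool(tf.reduce_all(draws_a == draws_b)))  # expected: True
```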
...@@ -947,7 +947,7 @@ class TFModelTesterMixin:
if config.bos_token_id is None:
# if bos token id is not defined model needs input_ids
with self.assertRaises(AssertionError):
with self.assertRaises(ValueError):
model.generate(do_sample=True, max_length=5)
# num_return_sequences = 1
self._check_generated_ids(model.generate(input_ids, do_sample=True))
......