"docs/vscode:/vscode.git/clone" did not exist on "0842c33edd5df349daddcbedb914d63e342d3c3d"
Unverified commit d7f7f29f, authored by Joao Gante, committed by GitHub
Browse files

TF: remove set_tensor_by_indices_to_value (#16729)

parent a315988b
...@@ -19,7 +19,6 @@ from typing import List ...@@ -19,7 +19,6 @@ from typing import List
import numpy as np import numpy as np
import tensorflow as tf import tensorflow as tf
from .tf_utils import set_tensor_by_indices_to_value
from .utils import add_start_docstrings from .utils import add_start_docstrings
from .utils.logging import get_logger from .utils.logging import get_logger
...@@ -221,7 +220,7 @@ class TFMinLengthLogitsProcessor(TFLogitsProcessor): ...@@ -221,7 +220,7 @@ class TFMinLengthLogitsProcessor(TFLogitsProcessor):
# generate is not XLA - compileable anyways # generate is not XLA - compileable anyways
if cur_len < self.min_length: if cur_len < self.min_length:
eos_token_id_mask = tf.broadcast_to(tf.range(scores.shape[-1]) == self.eos_token_id, scores.shape) eos_token_id_mask = tf.broadcast_to(tf.range(scores.shape[-1]) == self.eos_token_id, scores.shape)
scores = set_tensor_by_indices_to_value(scores, eos_token_id_mask, float("-inf")) scores = tf.where(eos_token_id_mask, float("-inf"), scores)
return scores return scores
...@@ -339,9 +338,7 @@ class TFNoBadWordsLogitsProcessor(TFLogitsProcessor): ...@@ -339,9 +338,7 @@ class TFNoBadWordsLogitsProcessor(TFLogitsProcessor):
[True if token in banned_tokens_slice else False for token in range(vocab_size)] [True if token in banned_tokens_slice else False for token in range(vocab_size)]
) )
scores = set_tensor_by_indices_to_value( scores = tf.where(tf.convert_to_tensor(banned_tokens_indices_mask, dtype=tf.bool), -float("inf"), scores)
scores, tf.convert_to_tensor(banned_tokens_indices_mask, dtype=tf.bool), -float("inf")
)
return scores return scores
...@@ -397,9 +394,7 @@ class TFNoRepeatNGramLogitsProcessor(TFLogitsProcessor): ...@@ -397,9 +394,7 @@ class TFNoRepeatNGramLogitsProcessor(TFLogitsProcessor):
[True if token in banned_tokens_slice else False for token in range(vocab_size)] [True if token in banned_tokens_slice else False for token in range(vocab_size)]
) )
scores = set_tensor_by_indices_to_value( scores = tf.where(tf.convert_to_tensor(banned_tokens_indices_mask, dtype=tf.bool), -float("inf"), scores)
scores, tf.convert_to_tensor(banned_tokens_indices_mask, dtype=tf.bool), -float("inf")
)
return scores return scores
......
...@@ -34,7 +34,7 @@ from .generation_tf_logits_process import ( ...@@ -34,7 +34,7 @@ from .generation_tf_logits_process import (
TFTopKLogitsWarper, TFTopKLogitsWarper,
TFTopPLogitsWarper, TFTopPLogitsWarper,
) )
from .tf_utils import set_tensor_by_indices_to_value, shape_list from .tf_utils import shape_list
from .utils import ModelOutput, logging from .utils import ModelOutput, logging
...@@ -952,8 +952,7 @@ class TFGenerationMixin: ...@@ -952,8 +952,7 @@ class TFGenerationMixin:
[True if token == eos_token_id else False for token in range(vocab_size)], dtype=tf.bool [True if token == eos_token_id else False for token in range(vocab_size)], dtype=tf.bool
) )
eos_token_indices_mask = tf.broadcast_to(is_token_logit_eos_token, [num_batch_hypotheses, vocab_size]) eos_token_indices_mask = tf.broadcast_to(is_token_logit_eos_token, [num_batch_hypotheses, vocab_size])
scores = tf.where(eos_token_indices_mask, -float("inf"), scores)
scores = set_tensor_by_indices_to_value(scores, eos_token_indices_mask, -float("inf"))
if no_repeat_ngram_size > 0: if no_repeat_ngram_size > 0:
# calculate a list of banned tokens to prevent repetitively generating the same ngrams # calculate a list of banned tokens to prevent repetitively generating the same ngrams
...@@ -969,8 +968,8 @@ class TFGenerationMixin: ...@@ -969,8 +968,8 @@ class TFGenerationMixin:
[True if token in banned_tokens_slice else False for token in range(vocab_size)] [True if token in banned_tokens_slice else False for token in range(vocab_size)]
) )
scores = set_tensor_by_indices_to_value( scores = tf.where(
scores, tf.convert_to_tensor(banned_tokens_indices_mask, dtype=tf.bool), -float("inf") tf.convert_to_tensor(banned_tokens_indices_mask, dtype=tf.bool), -float("inf"), scores
) )
if bad_words_ids is not None: if bad_words_ids is not None:
...@@ -983,8 +982,8 @@ class TFGenerationMixin: ...@@ -983,8 +982,8 @@ class TFGenerationMixin:
[True if token in banned_tokens_slice else False for token in range(vocab_size)] [True if token in banned_tokens_slice else False for token in range(vocab_size)]
) )
scores = set_tensor_by_indices_to_value( scores = tf.where(
scores, tf.convert_to_tensor(banned_tokens_indices_mask, dtype=tf.bool), -float("inf") tf.convert_to_tensor(banned_tokens_indices_mask, dtype=tf.bool), -float("inf"), scores
) )
assert shape_list(scores) == [batch_size * num_beams, vocab_size] assert shape_list(scores) == [batch_size * num_beams, vocab_size]
...@@ -2950,7 +2949,7 @@ def tf_top_k_top_p_filtering(logits, top_k=0, top_p=1.0, filter_value=-float("In ...@@ -2950,7 +2949,7 @@ def tf_top_k_top_p_filtering(logits, top_k=0, top_p=1.0, filter_value=-float("In
top_k = min(max(top_k, min_tokens_to_keep), logits_shape[-1]) # Safety check top_k = min(max(top_k, min_tokens_to_keep), logits_shape[-1]) # Safety check
# Remove all tokens with a probability less than the last token of the top-k # Remove all tokens with a probability less than the last token of the top-k
indices_to_remove = logits < tf.math.top_k(logits, k=top_k)[0][..., -1, None] indices_to_remove = logits < tf.math.top_k(logits, k=top_k)[0][..., -1, None]
logits = set_tensor_by_indices_to_value(logits, indices_to_remove, filter_value) logits = tf.where(indices_to_remove, filter_value, logits)
if top_p < 1.0: if top_p < 1.0:
sorted_indices = tf.argsort(logits, direction="DESCENDING") sorted_indices = tf.argsort(logits, direction="DESCENDING")
sorted_logits = tf.gather( sorted_logits = tf.gather(
...@@ -2979,7 +2978,7 @@ def tf_top_k_top_p_filtering(logits, top_k=0, top_p=1.0, filter_value=-float("In ...@@ -2979,7 +2978,7 @@ def tf_top_k_top_p_filtering(logits, top_k=0, top_p=1.0, filter_value=-float("In
) )
# scatter sorted tensors to original indexing # scatter sorted tensors to original indexing
indices_to_remove = scatter_values_on_batch_indices(sorted_indices_to_remove, sorted_indices) indices_to_remove = scatter_values_on_batch_indices(sorted_indices_to_remove, sorted_indices)
logits = set_tensor_by_indices_to_value(logits, indices_to_remove, filter_value) logits = tf.where(indices_to_remove, filter_value, logits)
return logits return logits
......
...@@ -23,11 +23,6 @@ from .utils import logging ...@@ -23,11 +23,6 @@ from .utils import logging
logger = logging.get_logger(__name__) logger = logging.get_logger(__name__)
def set_tensor_by_indices_to_value(tensor: tf.Tensor, indices: tf.Tensor, value: Union[tf.Tensor, int, float]):
# create value_tensor since tensor value assignment is not possible in TF
return tf.where(indices, value, tensor)
def shape_list(tensor: Union[tf.Tensor, np.ndarray]) -> List[int]: def shape_list(tensor: Union[tf.Tensor, np.ndarray]) -> List[int]:
""" """
Deal with dynamic shape in tensorflow cleanly. Deal with dynamic shape in tensorflow cleanly.
......
...@@ -37,7 +37,6 @@ if is_tf_available(): ...@@ -37,7 +37,6 @@ if is_tf_available():
TFTopKLogitsWarper, TFTopKLogitsWarper,
TFTopPLogitsWarper, TFTopPLogitsWarper,
) )
from transformers.tf_utils import set_tensor_by_indices_to_value
from ..test_modeling_tf_common import ids_tensor from ..test_modeling_tf_common import ids_tensor
...@@ -112,9 +111,9 @@ class TFLogitsProcessorTest(unittest.TestCase): ...@@ -112,9 +111,9 @@ class TFLogitsProcessorTest(unittest.TestCase):
scores = self._get_uniform_logits(batch_size=2, length=vocab_size) scores = self._get_uniform_logits(batch_size=2, length=vocab_size)
mask = tf.cast(tf.constant([[1] + 9 * [0], 10 * [0]]), tf.bool) mask = tf.cast(tf.constant([[1] + 9 * [0], 10 * [0]]), tf.bool)
scores = set_tensor_by_indices_to_value(scores, mask, -1 / vocab_size) scores = tf.where(mask, -1 / vocab_size, scores)
mask = tf.cast(tf.constant([10 * [0], 5 * [0] + [1] + 4 * [0]]), tf.bool) mask = tf.cast(tf.constant([10 * [0], 5 * [0] + [1] + 4 * [0]]), tf.bool)
scores = set_tensor_by_indices_to_value(scores, mask, 4 / vocab_size) scores = tf.where(mask, 4 / vocab_size, scores)
rep_penalty_proc = TFRepetitionPenaltyLogitsProcessor(penalty=2.0) rep_penalty_proc = TFRepetitionPenaltyLogitsProcessor(penalty=2.0)
...@@ -340,8 +339,8 @@ class TFLogitsProcessorTest(unittest.TestCase): ...@@ -340,8 +339,8 @@ class TFLogitsProcessorTest(unittest.TestCase):
scores_comp = processor(input_ids, scores_comp, cur_len=cur_len) scores_comp = processor(input_ids, scores_comp, cur_len=cur_len)
# remove inf # remove inf
scores = set_tensor_by_indices_to_value(scores, tf.math.is_inf(scores), -1e9) scores = tf.where(tf.math.is_inf(scores), -1e9, scores)
scores_comp = set_tensor_by_indices_to_value(scores_comp, tf.math.is_inf(scores_comp), -1e9) scores_comp = tf.where(tf.math.is_inf(scores_comp), -1e9, scores_comp)
# scores should be equal # scores should be equal
tf.debugging.assert_near(scores, scores_comp, atol=1e-3) tf.debugging.assert_near(scores, scores_comp, atol=1e-3)
......
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment