Unverified Commit 3bbc2451 authored by Nick Hill, committed by GitHub

Small simplification to TopKLogitsWarper (#21130)

The max of top_k and min_tokens_to_keep, currently computed on every call, can be computed once up-front in the constructor.
parent 0dde5897
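
The effect of the change, sketched below as a minimal standalone version of the PyTorch warper (the class name TopKSketch and the trimmed __call__ signature are illustrative, not the transformers API): max(top_k, min_tokens_to_keep) is folded into self.top_k once at construction time, so each call only needs the clamp against the vocabulary size.

import torch


class TopKSketch:
    """Minimal sketch of the simplified warper; not the transformers class."""

    def __init__(self, top_k: int, filter_value: float = -float("inf"), min_tokens_to_keep: int = 1):
        if not isinstance(top_k, int) or top_k <= 0:
            raise ValueError(f"`top_k` has to be a strictly positive integer, but is {top_k}")
        # Done once here rather than on every call, as in the diff below
        self.top_k = max(top_k, min_tokens_to_keep)
        self.filter_value = filter_value

    def __call__(self, scores: torch.FloatTensor) -> torch.FloatTensor:
        top_k = min(self.top_k, scores.size(-1))  # Safety check against small vocabularies
        # Mask every token that scores below the k-th best token in its row
        indices_to_remove = scores < torch.topk(scores, top_k)[0][..., -1, None]
        return scores.masked_fill(indices_to_remove, self.filter_value)


# Example: keep the two highest-scoring tokens per row, mask the rest to -inf
scores = torch.tensor([[0.1, 0.4, 0.2, 0.9]])
print(TopKSketch(top_k=2)(scores))  # tensor([[-inf, 0.4000, -inf, 0.9000]])
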
@@ -171,15 +171,14 @@ class FlaxTopKLogitsWarper(FlaxLogitsWarper):
         if not isinstance(top_k, int) or top_k <= 0:
             raise ValueError(f"`top_k` has to be a strictly positive integer, but is {top_k}")
 
-        self.top_k = top_k
+        self.top_k = max(top_k, min_tokens_to_keep)
         self.filter_value = filter_value
-        self.min_tokens_to_keep = min_tokens_to_keep
 
     def __call__(self, input_ids: jnp.ndarray, scores: jnp.ndarray, cur_len: int) -> jnp.ndarray:
         batch_size, vocab_size = scores.shape
         next_scores_flat = jnp.full(batch_size * vocab_size, self.filter_value)
 
-        topk = min(max(self.top_k, self.min_tokens_to_keep), scores.shape[-1])  # Safety check
+        topk = min(self.top_k, scores.shape[-1])  # Safety check
         topk_scores, topk_indices = lax.top_k(scores, topk)
         shift = jnp.broadcast_to((jnp.arange(batch_size) * vocab_size)[:, None], (batch_size, topk)).flatten()
         topk_scores_flat = topk_scores.flatten()
@@ -262,12 +262,11 @@ class TopKLogitsWarper(LogitsWarper):
         if not isinstance(top_k, int) or top_k <= 0:
             raise ValueError(f"`top_k` has to be a strictly positive integer, but is {top_k}")
 
-        self.top_k = top_k
+        self.top_k = max(top_k, min_tokens_to_keep)
         self.filter_value = filter_value
-        self.min_tokens_to_keep = min_tokens_to_keep
 
     def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor) -> torch.FloatTensor:
-        top_k = min(max(self.top_k, self.min_tokens_to_keep), scores.size(-1))  # Safety check
+        top_k = min(self.top_k, scores.size(-1))  # Safety check
         # Remove all tokens with a probability less than the last token of the top-k
         indices_to_remove = scores < torch.topk(scores, top_k)[0][..., -1, None]
         scores = scores.masked_fill(indices_to_remove, self.filter_value)
@@ -132,12 +132,11 @@ class TFTopKLogitsWarper(TFLogitsWarper):
         if not isinstance(top_k, int) or top_k <= 0:
             raise ValueError(f"`top_k` has to be a strictly positive integer, but is {top_k}")
 
-        self.top_k = top_k
+        self.top_k = max(top_k, min_tokens_to_keep)
         self.filter_value = filter_value
-        self.min_tokens_to_keep = min_tokens_to_keep
 
     def __call__(self, input_ids: tf.Tensor, scores: tf.Tensor, cur_len: int) -> tf.Tensor:
-        top_k = min(max(self.top_k, self.min_tokens_to_keep), scores.shape[-1])  # Safety check
+        top_k = min(self.top_k, scores.shape[-1])  # Safety check
         # Boolean mask containing all tokens with a probability less than the last token of the top-k
         indices_to_remove = scores < tf.math.top_k(scores, k=top_k)[0][..., -1:]
         next_scores = tf.where(indices_to_remove, self.filter_value, scores)