"git@developer.sourcefind.cn:chenpangpang/transformers.git" did not exist on "d066c3731bed1755f93ea64f0f00981b805532de"
Unverified Commit 3bbc2451, authored by Nick Hill, committed by GitHub

Small simplification to TopKLogitsWarper (#21130)

The max(top_k, min_tokens_to_keep) computed on every call can be done once up-front, since both values are fixed when the warper is constructed.
parent 0dde5897
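
For context, the refactor simply hoists a per-call max() into the constructor. A minimal standalone sketch of the pattern (hypothetical class names, not code from this repository):

# Before: max(top_k, min_tokens_to_keep) is recomputed on every call,
# even though both operands are fixed when the warper is constructed.
class WarperBefore:
    def __init__(self, top_k: int, min_tokens_to_keep: int = 1):
        self.top_k = top_k
        self.min_tokens_to_keep = min_tokens_to_keep

    def effective_k(self, vocab_size: int) -> int:
        return min(max(self.top_k, self.min_tokens_to_keep), vocab_size)

# After: the max() is folded into __init__; only the vocab-size clamp
# remains per call, since it depends on the incoming scores tensor.
class WarperAfter:
    def __init__(self, top_k: int, min_tokens_to_keep: int = 1):
        self.top_k = max(top_k, min_tokens_to_keep)

    def effective_k(self, vocab_size: int) -> int:
        return min(self.top_k, vocab_size)

# Both variants agree for any inputs, so the change is behavior-preserving.
assert WarperBefore(3, 5).effective_k(50) == WarperAfter(3, 5).effective_k(50) == 5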
@@ -171,15 +171,14 @@ class FlaxTopKLogitsWarper(FlaxLogitsWarper):
         if not isinstance(top_k, int) or top_k <= 0:
             raise ValueError(f"`top_k` has to be a strictly positive integer, but is {top_k}")
 
-        self.top_k = top_k
+        self.top_k = max(top_k, min_tokens_to_keep)
         self.filter_value = filter_value
-        self.min_tokens_to_keep = min_tokens_to_keep
 
     def __call__(self, input_ids: jnp.ndarray, scores: jnp.ndarray, cur_len: int) -> jnp.ndarray:
         batch_size, vocab_size = scores.shape
         next_scores_flat = jnp.full(batch_size * vocab_size, self.filter_value)
 
-        topk = min(max(self.top_k, self.min_tokens_to_keep), scores.shape[-1])  # Safety check
+        topk = min(self.top_k, scores.shape[-1])  # Safety check
         topk_scores, topk_indices = lax.top_k(scores, topk)
         shift = jnp.broadcast_to((jnp.arange(batch_size) * vocab_size)[:, None], (batch_size, topk)).flatten()
         topk_scores_flat = topk_scores.flatten()
@@ -262,12 +262,11 @@ class TopKLogitsWarper(LogitsWarper):
         if not isinstance(top_k, int) or top_k <= 0:
             raise ValueError(f"`top_k` has to be a strictly positive integer, but is {top_k}")
 
-        self.top_k = top_k
+        self.top_k = max(top_k, min_tokens_to_keep)
         self.filter_value = filter_value
-        self.min_tokens_to_keep = min_tokens_to_keep
 
     def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor) -> torch.FloatTensor:
-        top_k = min(max(self.top_k, self.min_tokens_to_keep), scores.size(-1))  # Safety check
+        top_k = min(self.top_k, scores.size(-1))  # Safety check
         # Remove all tokens with a probability less than the last token of the top-k
         indices_to_remove = scores < torch.topk(scores, top_k)[0][..., -1, None]
         scores = scores.masked_fill(indices_to_remove, self.filter_value)
@@ -132,12 +132,11 @@ class TFTopKLogitsWarper(TFLogitsWarper):
         if not isinstance(top_k, int) or top_k <= 0:
             raise ValueError(f"`top_k` has to be a strictly positive integer, but is {top_k}")
 
-        self.top_k = top_k
+        self.top_k = max(top_k, min_tokens_to_keep)
         self.filter_value = filter_value
-        self.min_tokens_to_keep = min_tokens_to_keep
 
     def __call__(self, input_ids: tf.Tensor, scores: tf.Tensor, cur_len: int) -> tf.Tensor:
-        top_k = min(max(self.top_k, self.min_tokens_to_keep), scores.shape[-1])  # Safety check
+        top_k = min(self.top_k, scores.shape[-1])  # Safety check
         # Boolean mask containing all tokens with a probability less than the last token of the top-k
         indices_to_remove = scores < tf.math.top_k(scores, k=top_k)[0][..., -1:]
         next_scores = tf.where(indices_to_remove, self.filter_value, scores)
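
To illustrate that behavior is unchanged, a minimal sketch exercising the PyTorch warper; the top-level import and the dummy tensor shapes are assumptions for illustration, not part of this commit:

import torch
from transformers import TopKLogitsWarper

# Dummy logits: a batch of 2 sequences over a 50-token vocabulary.
scores = torch.randn(2, 50)
input_ids = torch.zeros((2, 1), dtype=torch.long)  # unused by this warper

# min_tokens_to_keep (5) exceeds top_k (3), so the constructor now
# stores self.top_k = max(3, 5) = 5 once instead of recomputing per call.
warper = TopKLogitsWarper(top_k=3, filter_value=-float("inf"), min_tokens_to_keep=5)
warped = warper(input_ids, scores)

# Exactly 5 logits per row survive the filter (barring ties in randn).
assert torch.isfinite(warped).sum(dim=-1).eq(5).all()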