Unverified commit 578e18e0, authored by Ekagra Ranjan, committed by GitHub

🚨🚨🚨 Optimize Top P Sampler and fix edge case (#18984)

* init PR

* optimize top p and add edge case

* styling

* style

* revert tf and flax test

* add edge case test for FLAX and TF

* update doc with smallest set sampling for top p

* make style
parent 2700ba66
@@ -118,8 +118,8 @@ class FlaxTopPLogitsWarper(FlaxLogitsWarper):
     Args:
         top_p (`float`):
-            If set to < 1, only the most probable tokens with probabilities that add up to `top_p` or higher are kept
-            for generation.
+            If set to < 1, only the smallest set of most probable tokens with probabilities that add up to `top_p` or
+            higher are kept for generation.
         filter_value (`float`, *optional*, defaults to `-float("Inf")`):
             All filtered values will be set to this float value.
         min_tokens_to_keep (`int`, *optional*, defaults to 1):
...
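The reworded docstring describes top-p (nucleus) filtering: sort tokens by probability and keep the shortest prefix whose cumulative probability reaches `top_p`. A minimal NumPy sketch of that definition, with made-up probabilities and a hypothetical `top_p` value (illustrative only, not code from this diff):

```python
import numpy as np

# Hypothetical next-token distribution, already sorted in descending order.
probs = np.array([0.5, 0.3, 0.1, 0.1])
top_p = 0.75

cumulative = np.cumsum(probs)                   # [0.5, 0.8, 0.9, 1.0]
# Smallest prefix whose cumulative probability reaches top_p.
keep = int(np.searchsorted(cumulative, top_p)) + 1
print(probs[:keep])                             # [0.5 0.3] -> mass 0.8 >= 0.75
```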
@@ -173,8 +173,8 @@ class TopPLogitsWarper(LogitsWarper):
     Args:
         top_p (`float`):
-            If set to < 1, only the most probable tokens with probabilities that add up to `top_p` or higher are kept
-            for generation.
+            If set to < 1, only the smallest set of most probable tokens with probabilities that add up to `top_p` or
+            higher are kept for generation.
         filter_value (`float`, *optional*, defaults to `-float("Inf")`):
             All filtered values will be set to this float value.
         min_tokens_to_keep (`int`, *optional*, defaults to 1):
@@ -191,17 +191,14 @@ class TopPLogitsWarper(LogitsWarper):
         self.min_tokens_to_keep = min_tokens_to_keep

     def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor) -> torch.FloatTensor:
-        sorted_logits, sorted_indices = torch.sort(scores, descending=True)
+        sorted_logits, sorted_indices = torch.sort(scores, descending=False)
         cumulative_probs = sorted_logits.softmax(dim=-1).cumsum(dim=-1)

         # Remove tokens with cumulative top_p above the threshold (token with 0 are kept)
-        sorted_indices_to_remove = cumulative_probs > self.top_p
+        sorted_indices_to_remove = cumulative_probs <= (1 - self.top_p)
         if self.min_tokens_to_keep > 1:
-            # Keep at least min_tokens_to_keep (set to min_tokens_to_keep-1 because we add the first one below)
-            sorted_indices_to_remove[..., : self.min_tokens_to_keep - 1] = 0
-        # Shift the indices to the right to keep also the first token above the threshold
-        sorted_indices_to_remove[..., 1:] = sorted_indices_to_remove[..., :-1].clone()
-        sorted_indices_to_remove[..., 0] = 0
+            # Keep at least min_tokens_to_keep
+            sorted_indices_to_remove[..., -self.min_tokens_to_keep :] = 0

         # scatter sorted tensors to original indexing
         indices_to_remove = sorted_indices_to_remove.scatter(1, sorted_indices, sorted_indices_to_remove)
...
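Read with an ascending sort, the new mask drops a token only when the cumulative mass up to and including it is still within `1 - top_p`, so a token whose cumulative sum lands exactly on `top_p` is kept, and the old shift/clone step disappears. A rough standalone sketch of the same mask on the distribution used in the updated tests (assumed to run in float32 on CPU as the tests do; not the library file itself):

```python
import torch

# Test distribution from this PR: row 0's two largest probabilities sum to exactly top_p.
scores = torch.log(torch.tensor([[0.3, 0.1, 0.1, 0.5], [0.15, 0.3, 0.3, 0.25]]))
top_p, filter_value = 0.8, -float("inf")

sorted_logits, sorted_indices = torch.sort(scores, descending=False)
cumulative_probs = sorted_logits.softmax(dim=-1).cumsum(dim=-1)

# Drop a token when the mass at or below it is still <= 1 - top_p. The old descending
# formulation needed an extra shift and, at an exact boundary, kept one token too many.
sorted_indices_to_remove = cumulative_probs <= (1 - top_p)
indices_to_remove = sorted_indices_to_remove.scatter(1, sorted_indices, sorted_indices_to_remove)

filtered = scores.masked_fill(indices_to_remove, filter_value)
print(filtered.exp())  # ~[[0.3, 0.0, 0.0, 0.5], [0.0, 0.3, 0.3, 0.25]], matching the tests
```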
@@ -150,8 +150,8 @@ class TFTopPLogitsWarper(TFLogitsWarper):
     Args:
         top_p (`float`):
-            If set to < 1, only the most probable tokens with probabilities that add up to `top_p` or higher are kept
-            for generation.
+            If set to < 1, only the smallest set of most probable tokens with probabilities that add up to `top_p` or
+            higher are kept for generation.
         filter_value (`float`, *optional*, defaults to `-float("Inf")`):
             All filtered values will be set to this float value.
         min_tokens_to_keep (`int`, *optional*, defaults to 1):
...
@@ -990,8 +990,8 @@ class GenerationMixin:
             top_k (`int`, *optional*, defaults to `model.config.top_k` or 50 if the config does not set any value):
                 The number of highest probability vocabulary tokens to keep for top-k-filtering.
             top_p (`float`, *optional*, defaults to `model.config.top_p` or 1.0 if the config does not set any value):
-                If set to float < 1, only the most probable tokens with probabilities that add up to `top_p` or higher
-                are kept for generation.
+                If set to float < 1, only the smallest set of most probable tokens with probabilities that add up to
+                `top_p` or higher are kept for generation.
             typical_p (`float`, *optional*, defaults to `model.config.typical_p` or 1.0 if the config does not set any value):
                 The amount of probability mass from the original distribution to be considered in typical decoding. If
                 set to 1.0 it takes no effect. See [this paper](https://arxiv.org/pdf/2202.00666.pdf) for more details.
...
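For context, this is the argument end users pass to `generate`. A typical sampling call might look like the following sketch (the `gpt2` checkpoint and the prompt are just examples, not anything this diff touches):

```python
from transformers import AutoModelForCausalLM, AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("gpt2")
model = AutoModelForCausalLM.from_pretrained("gpt2")

inputs = tokenizer("The quick brown fox", return_tensors="pt")
# With do_sample=True and top_p < 1, generation samples only from the smallest set of
# tokens whose probabilities add up to at least top_p (top_k=0 disables top-k filtering).
outputs = model.generate(**inputs, do_sample=True, top_p=0.8, top_k=0, max_new_tokens=20)
print(tokenizer.decode(outputs[0], skip_special_tokens=True))
```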
@@ -110,10 +110,10 @@ class LogitsProcessorTest(unittest.TestCase):
         # create distribution and take log (inverse to Softmax as taken in TopPLogitsWarper)
         dist = np.log(np.array([[0.3, 0.1, 0.1, 0.5], [0.15, 0.3, 0.3, 0.25]]))

-        top_p_warp = FlaxTopPLogitsWarper(0.7)
+        top_p_warp = FlaxTopPLogitsWarper(0.8)
         filtered_dist = np.exp(top_p_warp(input_ids, dist, cur_len=None))

-        # dist should be filtered to keep min num values so that sum is >= 0.7
+        # dist should be filtered to keep min num values so that sum is >= top_p
         # exp (-inf) => 0
         EXPECTED_FILTERED_DIST = np.array([[0.3, 0.0, 0.0, 0.5], [0.0, 0.3, 0.3, 0.25]])
         self.assertTrue(np.allclose(filtered_dist, EXPECTED_FILTERED_DIST, atol=1e-3))
...
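Why 0.8 exercises the edge case: in the first row the two largest probabilities sum to exactly `top_p`, while the second row needs three tokens. A quick check of the arithmetic behind `EXPECTED_FILTERED_DIST` (plain Python, `smallest_top_p_set` is a throwaway helper, not library code):

```python
# Row 0, sorted descending: 0.5 + 0.3 == 0.8 == top_p          -> keep {0.5, 0.3}, drop both 0.1s
# Row 1, sorted descending: 0.3 + 0.3 == 0.6 < 0.8, + 0.25 == 0.85 >= 0.8
#                                                               -> keep {0.3, 0.3, 0.25}, drop 0.15
def smallest_top_p_set(probs, top_p):
    total, kept = 0.0, []
    for p in sorted(probs, reverse=True):
        kept.append(p)
        total += p
        if total >= top_p - 1e-9:  # small tolerance for float rounding
            break
    return kept

print(smallest_top_p_set([0.3, 0.1, 0.1, 0.5], 0.8))    # [0.5, 0.3]
print(smallest_top_p_set([0.15, 0.3, 0.3, 0.25], 0.8))  # [0.3, 0.3, 0.25]
```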
@@ -169,10 +169,10 @@ class LogitsProcessorTest(unittest.TestCase):
             torch.tensor([[0.3, 0.1, 0.1, 0.5], [0.15, 0.3, 0.3, 0.25]], device=torch_device, dtype=torch.float)
         )

-        top_p_warp = TopPLogitsWarper(0.7)
+        top_p_warp = TopPLogitsWarper(0.8)
         filtered_dist = torch.exp(top_p_warp(input_ids, dist))

-        # dist should be filtered to keep min num values so that sum is >= 0.7
+        # dist should be filtered to keep min num values so that sum is >= top_p
         # exp (-inf) => 0
         EXPECTED_FILTERED_DIST = torch.tensor(
             [[0.3, 0.0, 0.0, 0.5], [0.0, 0.3, 0.3, 0.25]], device=torch_device, dtype=torch.float
...
@@ -189,12 +189,15 @@ class TFLogitsProcessorTest(unittest.TestCase):
         # create distribution and take log (inverse to Softmax as taken in TFTopPLogitsWarper)
         dist = np.log(np.array([[0.3, 0.1, 0.1, 0.5], [0.15, 0.3, 0.3, 0.25]], dtype=np.float32))

-        top_p_warp = TFTopPLogitsWarper(0.7)
+        # top_p should have been 0.8 to test the edge case of top_p being exactly equal to sum of some token prob
+        # However, due to the numerical instability of softmax in TF we choose this as the edge case
+        # top_p as 0.8 passes when use_xla is True and fails when False. Refer PR #18984.
+        top_p_warp = TFTopPLogitsWarper(0.79999995)
         if use_xla:
             top_p_warp = tf.function(top_p_warp, jit_compile=True)
         filtered_dist = tf.exp(top_p_warp(input_ids, dist, cur_len))

-        # dist should be filtered to keep min num values so that sum is >= 0.7
+        # dist should be filtered to keep min num values so that sum is >= top_p
         # exp (-inf) => 0
         EXPECTED_FILTERED_DIST = tf.constant([[0.3, 0.0, 0.0, 0.5], [0.0, 0.3, 0.3, 0.25]], dtype=tf.float32)
         tf.debugging.assert_near(filtered_dist, EXPECTED_FILTERED_DIST, atol=1e-3)
...
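The comment in the TF test is about where the float32 cumulative sum actually lands relative to 0.8 once the probabilities have gone through log and softmax. A small NumPy probe with the same numbers (no TF required, no claim about the exact value on any particular machine) makes this easy to inspect:

```python
import numpy as np

# Reproduce the test's round trip in float32: log -> softmax -> cumulative sum.
logits = np.log(np.array([0.3, 0.1, 0.1, 0.5], dtype=np.float32))
probs = np.exp(logits - logits.max())
probs = probs / probs.sum()

cumulative = np.cumsum(np.sort(probs)[::-1])
# The second entry sits near 0.8 but can land on either side of it after rounding,
# which is why the test uses 0.79999995 rather than 0.8 as the threshold.
print(cumulative)
```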