Unverified Commit 3b00b623 authored by unifyh's avatar unifyh Committed by GitHub
Browse files

Fix `top_k_top_p_filtering` having unexpected behavior (#17744)

- Fix `top_k_top_p_filtering` not passing `filter_value` to
   `TopPLogitsWarper` causing any top-p filtered logits to be -inf
   instead of specified value

 - Add corresponding test
parent 3ccff0d4
@@ -3347,6 +3347,8 @@ def top_k_top_p_filtering(
    )
if 0 <= top_p <= 1.0:
    logits = TopPLogitsWarper(top_p=top_p, filter_value=filter_value, min_tokens_to_keep=min_tokens_to_keep)(
        None, logits
    )
return logits
@@ -1626,6 +1626,32 @@ class UtilsFunctionsTest(unittest.TestCase):
        self.assertTrue(torch.allclose(non_inf_expected_output, non_inf_output, atol=1e-12))
        self.assertTrue(torch.all(torch.eq(non_inf_expected_idx, non_inf_idx)))
# tests whether the function uses filter_value instead of default -inf
def test_top_k_top_p_filtering_with_filter_value(self):
    """Verify that logits removed by top-k/top-p filtering are set to the
    caller-supplied ``filter_value`` (here 0.0) rather than the default -inf."""
    scores = torch.tensor(
        [
            [
                1,
                1,
                1,
                0.99,  # get filtered by top-p filtering
                0.98,  # get filtered by top-k filtering
            ]
        ],
        dtype=torch.float,
        device=torch_device,
    )
    # both filtered positions must come back as 0.0, not -inf
    expected = torch.tensor(
        [[1, 1, 1, 0, 0]],
        dtype=torch.float,
        device=torch_device,
    )
    filtered = top_k_top_p_filtering(scores, top_k=4, top_p=0.5, filter_value=0.0)
    self.assertTrue(torch.allclose(expected, filtered, atol=1e-12))
@require_torch
class GenerationIntegrationTests(unittest.TestCase):
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment