Unverified Commit 4df1d696 authored by HanHui's avatar HanHui Committed by GitHub
Browse files

[BUG] BarkEosPrioritizerLogitsProcessor eos_token_id use list, tensor size mismatch (#28201)



fix(generation/logits_process.py): BarkEosPrioritizerLogitsProcessor eos_token_id use list, tensor size mismatch
Co-authored-by: default avatarchenhanhui <chenhanhui@kanzhun.com>
parent 932ad8af
...@@ -2138,6 +2138,7 @@ class BarkEosPrioritizerLogitsProcessor(LogitsProcessor): ...@@ -2138,6 +2138,7 @@ class BarkEosPrioritizerLogitsProcessor(LogitsProcessor):
early_stop_scores[:, self.eos_token_id] = scores[:, self.eos_token_id] early_stop_scores[:, self.eos_token_id] = scores[:, self.eos_token_id]
do_early_stop = probs[:, self.eos_token_id] > self.min_eos_p do_early_stop = probs[:, self.eos_token_id] > self.min_eos_p
do_early_stop = torch.any(do_early_stop, dim=1, keepdim=True)
scores = torch.where(do_early_stop, early_stop_scores, scores) scores = torch.where(do_early_stop, early_stop_scores, scores)
return scores return scores
...@@ -824,3 +824,19 @@ class LogitsProcessorTest(unittest.TestCase): ...@@ -824,3 +824,19 @@ class LogitsProcessorTest(unittest.TestCase):
[float("-inf"), float("-inf"), scores[0][0], float("-inf")], [float("-inf"), float("-inf"), scores[0][0], float("-inf")],
] ]
self.assertListEqual(actual_scores.tolist(), expected_scores_list) self.assertListEqual(actual_scores.tolist(), expected_scores_list)
def test_early_stop_processor_multi_eos(self):
input_ids = None
eos_token_id = [2, 3]
min_eos_p = 0.1 ## some small float
scores = self._get_uniform_logits(2, 4)
scores[0][eos_token_id] = -6 ## less than log(min_eos_p)
esp = BarkEosPrioritizerLogitsProcessor(eos_token_id=eos_token_id, min_eos_p=min_eos_p)
actual_scores = esp(input_ids, scores)
expected_scores_list = [
scores[0].tolist(),
[float("-inf"), float("-inf"), scores[0][0], scores[0][0]],
]
self.assertListEqual(actual_scores.tolist(), expected_scores_list)
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment