Unverified Commit 9960506c authored by Motoki Wu, committed by GitHub

Fix multiple `eos_token_id`s in model.generate(...) (#21461)

* add tests with multiple eos_token_ids

* use math.prod instead of sum

* make fixup

* fix .long() and use np.prod, since math.prod does not exist in Python < 3.8

* make fixup

* add prod util

* use prod util instead of np.prod

* make fixup

* previous .long location

* use tensor ops (see the sketch below)

* remove prod

* remove prod

* update device

* make fixup

* fix none
parent 06d940ef
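
The heart of the change, applied identically in three generation methods below, is to build a tensor of EOS ids once and then reduce token-vs-id comparisons with a product instead of a Python-level sum. A minimal sketch of the tensor trick on toy values (illustrative only, not part of the diff):

    import torch

    # Two EOS ids; one freshly sampled token per sequence in the batch.
    eos_token_id_tensor = torch.tensor([873, 198])
    next_tokens = torch.tensor([873, 42, 198, 7])
    unfinished_sequences = torch.ones_like(next_tokens)

    # Tile tokens to (num_eos, batch), compare against each id, and take the
    # product over the EOS axis: it is 0 exactly where *any* id matched.
    not_eos = (
        next_tokens.tile(eos_token_id_tensor.shape[0], 1)
        .ne(eos_token_id_tensor.unsqueeze(1))
        .prod(dim=0)
    )
    unfinished_sequences = unfinished_sequences.mul(not_eos)
    print(unfinished_sequences)  # tensor([0, 1, 0, 1])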
@@ -1757,6 +1757,7 @@ class GenerationMixin:
         eos_token_id = eos_token_id if eos_token_id is not None else self.generation_config.eos_token_id
         if isinstance(eos_token_id, int):
             eos_token_id = [eos_token_id]
+        eos_token_id_tensor = torch.tensor(eos_token_id).to(input_ids.device) if eos_token_id is not None else None
         output_scores = output_scores if output_scores is not None else self.generation_config.output_scores
         output_attentions = (
             output_attentions if output_attentions is not None else self.generation_config.output_attentions
@@ -1980,8 +1981,10 @@ class GenerationMixin:
             )
 
             # if eos_token was found in one sentence, set sentence to finished
-            if eos_token_id is not None:
-                unfinished_sequences = unfinished_sequences.mul((sum(next_tokens != i for i in eos_token_id)).long())
+            if eos_token_id_tensor is not None:
+                unfinished_sequences = unfinished_sequences.mul(
+                    next_tokens.tile(eos_token_id_tensor.shape[0], 1).ne(eos_token_id_tensor.unsqueeze(1)).prod(dim=0)
+                )
 
             # stop when each sentence is finished, or if we exceed the maximum length
             if unfinished_sequences.max() == 0 or stopping_criteria(input_ids, scores):
@@ -2129,6 +2132,7 @@ class GenerationMixin:
         eos_token_id = eos_token_id if eos_token_id is not None else self.generation_config.eos_token_id
         if isinstance(eos_token_id, int):
             eos_token_id = [eos_token_id]
+        eos_token_id_tensor = torch.tensor(eos_token_id).to(input_ids.device) if eos_token_id is not None else None
         output_scores = output_scores if output_scores is not None else self.generation_config.output_scores
         output_attentions = (
             output_attentions if output_attentions is not None else self.generation_config.output_attentions
@@ -2223,8 +2227,10 @@ class GenerationMixin:
             )
 
             # if eos_token was found in one sentence, set sentence to finished
-            if eos_token_id is not None:
-                unfinished_sequences = unfinished_sequences.mul((sum(next_tokens != i for i in eos_token_id)).long())
+            if eos_token_id_tensor is not None:
+                unfinished_sequences = unfinished_sequences.mul(
+                    next_tokens.tile(eos_token_id_tensor.shape[0], 1).ne(eos_token_id_tensor.unsqueeze(1)).prod(dim=0)
+                )
 
             # stop when each sentence is finished, or if we exceed the maximum length
             if unfinished_sequences.max() == 0 or stopping_criteria(input_ids, scores):
@@ -2392,6 +2398,7 @@ class GenerationMixin:
         eos_token_id = eos_token_id if eos_token_id is not None else self.generation_config.eos_token_id
         if isinstance(eos_token_id, int):
             eos_token_id = [eos_token_id]
+        eos_token_id_tensor = torch.tensor(eos_token_id).to(input_ids.device) if eos_token_id is not None else None
         output_scores = output_scores if output_scores is not None else self.generation_config.output_scores
         output_attentions = (
             output_attentions if output_attentions is not None else self.generation_config.output_attentions
@@ -2489,8 +2496,10 @@ class GenerationMixin:
             )
 
             # if eos_token was found in one sentence, set sentence to finished
-            if eos_token_id is not None:
-                unfinished_sequences = unfinished_sequences.mul((sum(next_tokens != i for i in eos_token_id)).long())
+            if eos_token_id_tensor is not None:
+                unfinished_sequences = unfinished_sequences.mul(
+                    next_tokens.tile(eos_token_id_tensor.shape[0], 1).ne(eos_token_id_tensor.unsqueeze(1)).prod(dim=0)
+                )
 
             # stop when each sentence is finished, or if we exceed the maximum length
             if unfinished_sequences.max() == 0 or stopping_criteria(input_ids, scores):
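Why the old reduction was a bug: sum(next_tokens != i for i in eos_token_id) counts how many ids a token does not match, so with two EOS ids a token matching one of them still yields 1, which is truthy, and the sequence was never marked finished. A quick side-by-side on toy values (illustrative only):

    import torch

    eos_token_id = [873, 198]
    next_tokens = torch.tensor([873, 42])  # the first sequence just emitted an EOS id

    # Old reduction: nonzero even for the EOS-matching token -> never finishes.
    old_mask = sum(next_tokens != i for i in eos_token_id).long()  # tensor([1, 2])

    # New reduction: product over ids is 0 wherever any id matched.
    new_mask = (
        next_tokens.tile(len(eos_token_id), 1)
        .ne(torch.tensor(eos_token_id).unsqueeze(1))
        .prod(dim=0)
    )  # tensor([0, 1])

The test updates below exercise exactly this: each existing single-id test now passes a second EOS id and asserts the same generated length.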
@@ -2609,7 +2609,7 @@ class GenerationIntegrationTests(unittest.TestCase, GenerationIntegrationTestsMixin):
         self.assertTrue(expectation == len(generated_tokens[0]))
 
         torch.manual_seed(0)
-        eos_token_id = [873]
+        eos_token_id = [873, 198]
         generated_tokens = model.generate(**tokens, eos_token_id=eos_token_id, **generation_kwargs)
         self.assertTrue(expectation == len(generated_tokens[0]))
@@ -2634,7 +2634,7 @@ class GenerationIntegrationTests(unittest.TestCase, GenerationIntegrationTestsMixin):
         self.assertTrue(expectation == len(generated_tokens[0]))
 
         torch.manual_seed(0)
-        eos_token_id = [225]
+        eos_token_id = [225, 198]
         generated_tokens = model.generate(**tokens, eos_token_id=eos_token_id, **generation_kwargs)
         self.assertTrue(expectation == len(generated_tokens[0]))
@@ -2660,7 +2660,7 @@ class GenerationIntegrationTests(unittest.TestCase, GenerationIntegrationTestsMixin):
         self.assertTrue(expectation == len(generated_tokens[0]))
 
         torch.manual_seed(0)
-        eos_token_id = [846]
+        eos_token_id = [846, 198]
         generated_tokens = model.generate(**tokens, eos_token_id=eos_token_id, **generation_kwargs)
         self.assertTrue(expectation == len(generated_tokens[0]))
@@ -2683,7 +2683,7 @@ class GenerationIntegrationTests(unittest.TestCase, GenerationIntegrationTestsMixin):
         self.assertTrue(expectation == len(generated_tokens[0]))
 
         torch.manual_seed(0)
-        eos_token_id = [873]
+        eos_token_id = [873, 198]
         generated_tokens = model.generate(**tokens, eos_token_id=eos_token_id, **generation_kwargs)
         self.assertTrue(expectation == len(generated_tokens[0]))
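For completeness, an end-to-end sketch of what the fix enables; the checkpoint, prompt, and pairing of these particular ids with this tokenizer are illustrative assumptions, not taken from the tests:

    from transformers import AutoModelForCausalLM, AutoTokenizer

    tokenizer = AutoTokenizer.from_pretrained("gpt2")  # illustrative checkpoint
    model = AutoModelForCausalLM.from_pretrained("gpt2")

    tokens = tokenizer("Hello, my dog is", return_tensors="pt")
    # generate() now accepts a list of EOS ids and stops each sequence at
    # whichever id is produced first.
    generated = model.generate(**tokens, eos_token_id=[873, 198], max_new_tokens=50)
    print(tokenizer.decode(generated[0]))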