Unverified commit 0efcf323, authored by Raushan Turganbay, committed by GitHub

Move `eos_token_id` to stopping criteria (#29459)



* add eos stopping criteria

* minor fix

* Update tests/generation/test_stopping_criteria.py
Co-authored-by: Joao Gante <joaofranciscocardosogante@gmail.com>

* check eos is not None and fix tests

* make style and fixup

* Update src/transformers/generation/stopping_criteria.py
Co-authored-by: Joao Gante <joaofranciscocardosogante@gmail.com>

* Update tests/generation/test_utils.py
Co-authored-by: Joao Gante <joaofranciscocardosogante@gmail.com>

* Update tests/generation/test_utils.py
Co-authored-by: Joao Gante <joaofranciscocardosogante@gmail.com>

* Update src/transformers/generation/__init__.py
Co-authored-by: Arthur <48595927+ArthurZucker@users.noreply.github.com>

* Update src/transformers/generation/stopping_criteria.py
Co-authored-by: Arthur <48595927+ArthurZucker@users.noreply.github.com>

* Update src/transformers/generation/stopping_criteria.py
Co-authored-by: Arthur <48595927+ArthurZucker@users.noreply.github.com>

* Update src/transformers/generation/stopping_criteria.py
Co-authored-by: Arthur <48595927+ArthurZucker@users.noreply.github.com>

* camel case everywhere

* call stopping criteria list for candidate ids

* make style and fixup

* Empty commit

* Empty commit to pass flaky test

* set max length in PromptLookupCandidateGenerator

* Update src/transformers/generation/utils.py
Co-authored-by: Joao Gante <joaofranciscocardosogante@gmail.com>

* let's fix this typo in docs

* Update src/transformers/generation/utils.py
Co-authored-by: Arthur <48595927+ArthurZucker@users.noreply.github.com>

* Update src/transformers/generation/utils.py
Co-authored-by: Arthur <48595927+ArthurZucker@users.noreply.github.com>

* update PR

* empty commit

---------
Co-authored-by: Joao Gante <joaofranciscocardosogante@gmail.com>
Co-authored-by: Arthur <48595927+ArthurZucker@users.noreply.github.com>
parent 31c575bc
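
With this change, `generate` stops on the end-of-sequence token through an ordinary stopping criterion rather than through ad-hoc `eos_token_id` checks inside each decoding loop. As a minimal sketch of how the new criterion composes with the existing ones (the checkpoint and prompt below are arbitrary placeholders; only `EosTokenCriteria` itself is new in this commit):

```python
from transformers import AutoModelForCausalLM, AutoTokenizer
from transformers.generation import EosTokenCriteria, MaxLengthCriteria, StoppingCriteriaList

tokenizer = AutoTokenizer.from_pretrained("gpt2")
model = AutoModelForCausalLM.from_pretrained("gpt2")
inputs = tokenizer("The capital of France is", return_tensors="pt")

# Stop at 32 total tokens, or earlier if every sequence in the batch
# has emitted the eos token.
criteria = StoppingCriteriaList(
    [
        MaxLengthCriteria(max_length=32),
        EosTokenCriteria(eos_token_id=tokenizer.eos_token_id),
    ]
)
out = model.generate(**inputs, stopping_criteria=criteria)
print(tokenizer.decode(out[0], skip_special_tokens=True))
```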
src/transformers/generation/__init__.py
@@ -82,6 +82,7 @@ else:
         "MaxNewTokensCriteria",
         "MaxLengthCriteria",
         "MaxTimeCriteria",
+        "EosTokenCriteria",
         "StoppingCriteria",
         "StoppingCriteriaList",
         "validate_stopping_criteria",
@@ -216,6 +217,7 @@ if TYPE_CHECKING:
         WhisperTimeStampLogitsProcessor,
     )
     from .stopping_criteria import (
+        EosTokenCriteria,
         MaxLengthCriteria,
         MaxNewTokensCriteria,
         MaxTimeCriteria,
...
src/transformers/generation/candidate_generator.py
@@ -238,15 +238,20 @@ class PromptLookupCandidateGenerator(CandidateGenerator):
             The maximum ngram size to be considered for matching in the prompt
         num_output_tokens (`int`):
             The number of tokens to be output as candidate tokens.
+        max_length (`int`):
+            The number of total maximum tokens that can be generated. For decoder-only models that includes the prompt length.
+            Defaults to 20, which is the max length used as default in generation config.
     """

     def __init__(
         self,
         num_output_tokens: int = 10,
         max_matching_ngram_size: int = None,
+        max_length: int = 20,
     ):
         self.num_output_tokens = num_output_tokens
         self.max_matching_ngram_size = max_matching_ngram_size if max_matching_ngram_size else 2
+        self.max_length = max_length

         if self.max_matching_ngram_size <= 0 or self.num_output_tokens <= 0:
             raise ValueError("Invalid max_matching_ngram_size or num_output_tokens")
@@ -264,6 +269,10 @@ class PromptLookupCandidateGenerator(CandidateGenerator):
         """
         input_length = input_ids.size(1)

+        # Don't generate more than `max_length - 1` candidates since the target model generates one extra token.
+        if self.max_length == input_length + 1:
+            return input_ids, None
+
         chosen_ids = None
         match_found = False
         for ngram_size in range(min(self.max_matching_ngram_size, input_length - 1), 0, -1):
@@ -283,7 +292,7 @@ class PromptLookupCandidateGenerator(CandidateGenerator):
                 for idx in match_indices:
                     start_idx = idx + ngram_size
                     end_idx = start_idx + self.num_output_tokens
-                    end_idx = min(end_idx, input_length)
+                    end_idx = min(end_idx, input_length, self.max_length)

                     if start_idx < end_idx:
                         chosen_ids = input_ids[0, start_idx:end_idx]
...
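
For orientation: `PromptLookupCandidateGenerator` proposes candidate tokens by matching the trailing n-gram of the current sequence against earlier occurrences of the same n-gram, and the new `max_length` argument caps how far candidates may extend, because the target model always appends one extra token on top of them. A simplified, self-contained sketch of that matching loop (toy code under a batch-size-1 assumption; `prompt_lookup_candidates` is a hypothetical helper, not the library API):

```python
import torch

def prompt_lookup_candidates(input_ids, num_output_tokens=10, max_matching_ngram_size=2, max_length=20):
    """Toy version of the candidate-matching loop for a batch of size 1."""
    input_length = input_ids.size(1)
    # Never propose past `max_length - 1`: the target model adds one more token.
    if max_length == input_length + 1:
        return input_ids, None
    for ngram_size in range(min(max_matching_ngram_size, input_length - 1), 0, -1):
        windows = input_ids.unfold(dimension=1, size=ngram_size, step=1)  # all n-grams in the sequence
        ngram = input_ids[0, -ngram_size:]  # trailing n-gram to look up
        match_indices = (windows == ngram).all(dim=2).nonzero(as_tuple=True)[1]
        for idx in match_indices:
            start_idx = idx + ngram_size
            end_idx = min(start_idx + num_output_tokens, input_length, max_length)
            if start_idx < end_idx:  # skips the trailing n-gram matching itself
                chosen_ids = input_ids[0, start_idx:end_idx]
                return torch.cat((input_ids, chosen_ids.unsqueeze(0)), dim=1), None
    return input_ids, None  # no match found: fall back to normal decoding

ids = torch.tensor([[1, 2, 3, 4, 1, 2]])
print(prompt_lookup_candidates(ids)[0])  # tensor([[1, 2, 3, 4, 1, 2, 3, 4, 1, 2]])
```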
src/transformers/generation/stopping_criteria.py
@@ -2,7 +2,7 @@
 import time
 import warnings
 from abc import ABC
 from copy import deepcopy
-from typing import Optional
+from typing import List, Optional, Union

 import torch
@@ -129,6 +129,27 @@ class MaxTimeCriteria(StoppingCriteria):
         return torch.full((input_ids.shape[0],), is_done, device=input_ids.device, dtype=torch.bool)


+class EosTokenCriteria(StoppingCriteria):
+    """
+    This class can be used to stop generation whenever the "end-of-sequence" token is generated.
+    By default, it uses the `model.generation_config.eos_token_id`.
+
+    Args:
+        eos_token_id (`Union[int, List[int]]`):
+            The id of the *end-of-sequence* token. Optionally, use a list to set multiple *end-of-sequence* tokens.
+    """
+
+    def __init__(self, eos_token_id: Union[int, List[int]]):
+        if isinstance(eos_token_id, int):
+            eos_token_id = [eos_token_id]
+        self.eos_token_id = torch.tensor(eos_token_id)
+
+    @add_start_docstrings(STOPPING_CRITERIA_INPUTS_DOCSTRING)
+    def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor, **kwargs) -> torch.BoolTensor:
+        is_done = torch.isin(input_ids[:, -1], self.eos_token_id.to(input_ids.device))
+        return is_done
+
+
 class StoppingCriteriaList(list):
     @add_start_docstrings(STOPPING_CRITERIA_INPUTS_DOCSTRING)
     def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor, **kwargs) -> torch.BoolTensor:
...
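
The check itself is one vectorized comparison over the last generated token of every sequence in the batch, which is why the criterion returns a per-sequence boolean tensor. A small standalone demonstration of the same `torch.isin` logic (toy token ids chosen only for illustration):

```python
import torch

eos_token_id = torch.tensor([0, 50256])  # a list of several eos ids is supported
input_ids = torch.tensor(
    [
        [5, 7, 0],      # ends with eos id 0     -> done
        [5, 7, 50256],  # ends with eos id 50256 -> done
        [5, 7, 9],      # ends with a normal id  -> not done
    ]
)
is_done = torch.isin(input_ids[:, -1], eos_token_id)
print(is_done)  # tensor([ True,  True, False])
```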
The diff for src/transformers/generation/utils.py is collapsed.
tests/generation/test_stopping_criteria.py
@@ -26,6 +26,7 @@ if is_torch_available():
     import torch

     from transformers.generation import (
+        EosTokenCriteria,
         MaxLengthCriteria,
         MaxNewTokensCriteria,
         MaxTimeCriteria,
@@ -98,6 +99,22 @@ class StoppingCriteriaTestCase(unittest.TestCase):
         criteria = MaxTimeCriteria(max_time=0.1, initial_timestamp=time.time() - 0.2)
         self.assertTrue(all(criteria(input_ids, scores)))

+    def test_eos_token_criteria(self):
+        criteria = EosTokenCriteria(eos_token_id=0)
+
+        input_ids, scores = self._get_tensors(5)
+        input_ids[:, -1] = 0
+        self.assertTrue(all(criteria(input_ids, scores)))
+
+        input_ids, scores = self._get_tensors(5)
+        input_ids[:2, -1] = 0
+        input_ids[2, -1] = 1
+        self.assertListEqual(criteria(input_ids, scores).tolist(), [True, True, False])
+
+        input_ids, scores = self._get_tensors(5)
+        input_ids[:, -1] = 1
+        self.assertListEqual(criteria(input_ids, scores).tolist(), [False, False, False])
+
     def test_validate_stopping_criteria(self):
         validate_stopping_criteria(StoppingCriteriaList([MaxLengthCriteria(10)]), 10)
...
tests/generation/test_utils.py
@@ -1899,14 +1899,12 @@ class UtilsFunctionsTest(unittest.TestCase):
             ]
         )
         last_assistant_token_is_eos = False
-        max_matches = 5
         validated_tokens, n_matches = _speculative_sampling(
             candidate_input_ids,
             candidate_logits,
             candidate_length,
             new_logits,
             last_assistant_token_is_eos,
-            max_matches,
         )
         self.assertTrue(n_matches.item() == 2)
         self.assertTrue(validated_tokens.tolist()[0] == [1, 4, 8])
...
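
The `max_matches` argument disappears because the candidate generator now enforces `max_length` itself (see the `PromptLookupCandidateGenerator` hunk above), so `_speculative_sampling` no longer has to re-clip candidates. From the public API side, prompt lookup decoding is driven entirely through `generate`; a hedged sketch (assuming the `prompt_lookup_num_tokens` flag available in recent releases, with an arbitrary checkpoint):

```python
from transformers import AutoModelForCausalLM, AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("gpt2")
model = AutoModelForCausalLM.from_pretrained("gpt2")

# Repetitive prompts are where prompt lookup shines: the continuation of
# "result = add(a," can be copied from the earlier occurrence in the prompt.
prompt = "def add(a, b):\n    return a + b\n\nresult = add(a,"
inputs = tokenizer(prompt, return_tensors="pt")

out = model.generate(**inputs, prompt_lookup_num_tokens=10, max_new_tokens=20)
print(tokenizer.decode(out[0]))
```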