Unverified Commit 0efcf323 authored by Raushan Turganbay's avatar Raushan Turganbay Committed by GitHub
Browse files

Move `eos_token_id` to stopping criteria (#29459)



* add eos stopping criteria

* minor fix

* Update tests/generation/test_stopping_criteria.py
Co-authored-by: default avatarJoao Gante <joaofranciscocardosogante@gmail.com>

* check eos is not None and fix tests

* make style and fixup

* Update src/transformers/generation/stopping_criteria.py
Co-authored-by: default avatarJoao Gante <joaofranciscocardosogante@gmail.com>

* Update tests/generation/test_utils.py
Co-authored-by: default avatarJoao Gante <joaofranciscocardosogante@gmail.com>

* Update tests/generation/test_utils.py
Co-authored-by: default avatarJoao Gante <joaofranciscocardosogante@gmail.com>

* Update src/transformers/generation/__init__.py
Co-authored-by: default avatarArthur <48595927+ArthurZucker@users.noreply.github.com>

* Update src/transformers/generation/stopping_criteria.py
Co-authored-by: default avatarArthur <48595927+ArthurZucker@users.noreply.github.com>

* Update src/transformers/generation/stopping_criteria.py
Co-authored-by: default avatarArthur <48595927+ArthurZucker@users.noreply.github.com>

* Update src/transformers/generation/stopping_criteria.py
Co-authored-by: default avatarArthur <48595927+ArthurZucker@users.noreply.github.com>

* camel case everywhere

* call stopping criteria list for candidate ids

* make style  and fixup

* Empty commit

* Empty commit to pass flaky test

* set max length in PromptLookupCandidateGenerator

* Update src/transformers/generation/utils.py
Co-authored-by: default avatarJoao Gante <joaofranciscocardosogante@gmail.com>

* lets fix this typo in docs

* Update src/transformers/generation/utils.py
Co-authored-by: default avatarArthur <48595927+ArthurZucker@users.noreply.github.com>

* Update src/transformers/generation/utils.py
Co-authored-by: default avatarArthur <48595927+ArthurZucker@users.noreply.github.com>

* update PR

* empty commit

---------
Co-authored-by: default avatarJoao Gante <joaofranciscocardosogante@gmail.com>
Co-authored-by: default avatarArthur <48595927+ArthurZucker@users.noreply.github.com>
parent 31c575bc
......@@ -82,6 +82,7 @@ else:
"MaxNewTokensCriteria",
"MaxLengthCriteria",
"MaxTimeCriteria",
"EosTokenCriteria",
"StoppingCriteria",
"StoppingCriteriaList",
"validate_stopping_criteria",
......@@ -216,6 +217,7 @@ if TYPE_CHECKING:
WhisperTimeStampLogitsProcessor,
)
from .stopping_criteria import (
EosTokenCriteria,
MaxLengthCriteria,
MaxNewTokensCriteria,
MaxTimeCriteria,
......
......@@ -238,15 +238,20 @@ class PromptLookupCandidateGenerator(CandidateGenerator):
The maximum ngram size to be considered for matching in the prompt
num_output_tokens (`int`):
The number of tokens to be output as candidate tokens.
max_length (`int`):
The number of total maximum tokens that can be generated. For decoder-only models that includes the prompt length.
Defaults to 20, which is the max length used as default in generation config.
"""
def __init__(
self,
num_output_tokens: int = 10,
max_matching_ngram_size: int = None,
max_length: int = 20,
):
self.num_output_tokens = num_output_tokens
self.max_matching_ngram_size = max_matching_ngram_size if max_matching_ngram_size else 2
self.max_length = max_length
if self.max_matching_ngram_size <= 0 or self.num_output_tokens <= 0:
raise ValueError("Invalid max_matching_ngram_size or num_output_tokens")
......@@ -264,6 +269,10 @@ class PromptLookupCandidateGenerator(CandidateGenerator):
"""
input_length = input_ids.size(1)
# Don't generate more than `max_length - 1` candidates since the target model generates one extra token.
if self.max_length == input_length + 1:
return input_ids, None
chosen_ids = None
match_found = False
for ngram_size in range(min(self.max_matching_ngram_size, input_length - 1), 0, -1):
......@@ -283,7 +292,7 @@ class PromptLookupCandidateGenerator(CandidateGenerator):
for idx in match_indices:
start_idx = idx + ngram_size
end_idx = start_idx + self.num_output_tokens
end_idx = min(end_idx, input_length)
end_idx = min(end_idx, input_length, self.max_length)
if start_idx < end_idx:
chosen_ids = input_ids[0, start_idx:end_idx]
......
......@@ -2,7 +2,7 @@ import time
import warnings
from abc import ABC
from copy import deepcopy
from typing import Optional
from typing import List, Optional, Union
import torch
......@@ -129,6 +129,27 @@ class MaxTimeCriteria(StoppingCriteria):
return torch.full((input_ids.shape[0],), is_done, device=input_ids.device, dtype=torch.bool)
class EosTokenCriteria(StoppingCriteria):
"""
This class can be used to stop generation whenever the "end-of-sequence" token is generated.
By default, it uses the `model.generation_config.eos_token_id`.
Args:
eos_token_id (`Union[int, List[int]]`):
The id of the *end-of-sequence* token. Optionally, use a list to set multiple *end-of-sequence* tokens.
"""
def __init__(self, eos_token_id: Union[int, List[int]]):
if isinstance(eos_token_id, int):
eos_token_id = [eos_token_id]
self.eos_token_id = torch.tensor(eos_token_id)
@add_start_docstrings(STOPPING_CRITERIA_INPUTS_DOCSTRING)
def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor, **kwargs) -> torch.BoolTensor:
is_done = torch.isin(input_ids[:, -1], self.eos_token_id.to(input_ids.device))
return is_done
class StoppingCriteriaList(list):
@add_start_docstrings(STOPPING_CRITERIA_INPUTS_DOCSTRING)
def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor, **kwargs) -> torch.BoolTensor:
......
This diff is collapsed.
......@@ -26,6 +26,7 @@ if is_torch_available():
import torch
from transformers.generation import (
EosTokenCriteria,
MaxLengthCriteria,
MaxNewTokensCriteria,
MaxTimeCriteria,
......@@ -98,6 +99,22 @@ class StoppingCriteriaTestCase(unittest.TestCase):
criteria = MaxTimeCriteria(max_time=0.1, initial_timestamp=time.time() - 0.2)
self.assertTrue(all(criteria(input_ids, scores)))
def test_eos_token_criteria(self):
criteria = EosTokenCriteria(eos_token_id=0)
input_ids, scores = self._get_tensors(5)
input_ids[:, -1] = 0
self.assertTrue(all(criteria(input_ids, scores)))
input_ids, scores = self._get_tensors(5)
input_ids[:2, -1] = 0
input_ids[2, -1] = 1
self.assertListEqual(criteria(input_ids, scores).tolist(), [True, True, False])
input_ids, scores = self._get_tensors(5)
input_ids[:, -1] = 1
self.assertListEqual(criteria(input_ids, scores).tolist(), [False, False, False])
def test_validate_stopping_criteria(self):
validate_stopping_criteria(StoppingCriteriaList([MaxLengthCriteria(10)]), 10)
......
......@@ -1899,14 +1899,12 @@ class UtilsFunctionsTest(unittest.TestCase):
]
)
last_assistant_token_is_eos = False
max_matches = 5
validated_tokens, n_matches = _speculative_sampling(
candidate_input_ids,
candidate_logits,
candidate_length,
new_logits,
last_assistant_token_is_eos,
max_matches,
)
self.assertTrue(n_matches.item() == 2)
self.assertTrue(validated_tokens.tolist()[0] == [1, 4, 8])
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment