Unverified Commit 0efcf323 authored by Raushan Turganbay, committed by GitHub

Move `eos_token_id` to stopping criteria (#29459)



* add eos stopping criteria

* minor fix

* Update tests/generation/test_stopping_criteria.py
Co-authored-by: Joao Gante <joaofranciscocardosogante@gmail.com>

* check eos is not None and fix tests

* make style and fixup

* Update src/transformers/generation/stopping_criteria.py
Co-authored-by: Joao Gante <joaofranciscocardosogante@gmail.com>

* Update tests/generation/test_utils.py
Co-authored-by: Joao Gante <joaofranciscocardosogante@gmail.com>

* Update tests/generation/test_utils.py
Co-authored-by: Joao Gante <joaofranciscocardosogante@gmail.com>

* Update src/transformers/generation/__init__.py
Co-authored-by: Arthur <48595927+ArthurZucker@users.noreply.github.com>

* Update src/transformers/generation/stopping_criteria.py
Co-authored-by: Arthur <48595927+ArthurZucker@users.noreply.github.com>

* Update src/transformers/generation/stopping_criteria.py
Co-authored-by: Arthur <48595927+ArthurZucker@users.noreply.github.com>

* Update src/transformers/generation/stopping_criteria.py
Co-authored-by: Arthur <48595927+ArthurZucker@users.noreply.github.com>

* camel case everywhere

* call stopping criteria list for candidate ids

* make style  and fixup

* Empty commit

* Empty commit to pass flaky test

* set max length in PromptLookupCandidateGenerator

* Update src/transformers/generation/utils.py
Co-authored-by: Joao Gante <joaofranciscocardosogante@gmail.com>

* lets fix this typo in docs

* Update src/transformers/generation/utils.py
Co-authored-by: Arthur <48595927+ArthurZucker@users.noreply.github.com>

* Update src/transformers/generation/utils.py
Co-authored-by: Arthur <48595927+ArthurZucker@users.noreply.github.com>

* update PR

* empty commit

---------
Co-authored-by: Joao Gante <joaofranciscocardosogante@gmail.com>
Co-authored-by: Arthur <48595927+ArthurZucker@users.noreply.github.com>
parent 31c575bc
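For readers skimming the diff below: the user-facing effect of this commit is that EOS handling moves out of the individual decoding loops and into a dedicated `EosTokenCriteria` stopping criterion, which `generate()` now builds from `generation_config.eos_token_id`. A minimal usage sketch follows; the `gpt2` checkpoint and the prompt are placeholders for illustration only, not part of this commit.

```python
from transformers import AutoModelForCausalLM, AutoTokenizer
from transformers.generation import EosTokenCriteria, StoppingCriteriaList

# Placeholder checkpoint, used only for illustration.
tokenizer = AutoTokenizer.from_pretrained("gpt2")
model = AutoModelForCausalLM.from_pretrained("gpt2")
inputs = tokenizer("The quick brown fox", return_tensors="pt")

# EOS stopping is now just another criterion: generate() appends one automatically
# from generation_config.eos_token_id, and it can also be passed explicitly.
criteria = StoppingCriteriaList([EosTokenCriteria(eos_token_id=tokenizer.eos_token_id)])
outputs = model.generate(**inputs, max_new_tokens=20, stopping_criteria=criteria)
print(tokenizer.decode(outputs[0], skip_special_tokens=True))
```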
@@ -82,6 +82,7 @@ else:
         "MaxNewTokensCriteria",
         "MaxLengthCriteria",
         "MaxTimeCriteria",
+        "EosTokenCriteria",
         "StoppingCriteria",
         "StoppingCriteriaList",
         "validate_stopping_criteria",
@@ -216,6 +217,7 @@ if TYPE_CHECKING:
         WhisperTimeStampLogitsProcessor,
     )
     from .stopping_criteria import (
+        EosTokenCriteria,
         MaxLengthCriteria,
         MaxNewTokensCriteria,
         MaxTimeCriteria,
...
@@ -238,15 +238,20 @@ class PromptLookupCandidateGenerator(CandidateGenerator):
             The maximum ngram size to be considered for matching in the prompt
         num_output_tokens (`int`):
             The number of tokens to be output as candidate tokens.
+        max_length (`int`):
+            The number of total maximum tokens that can be generated. For decoder-only models that includes the prompt length.
+            Defaults to 20, which is the max length used as default in generation config.
     """

     def __init__(
         self,
         num_output_tokens: int = 10,
         max_matching_ngram_size: int = None,
+        max_length: int = 20,
     ):
         self.num_output_tokens = num_output_tokens
         self.max_matching_ngram_size = max_matching_ngram_size if max_matching_ngram_size else 2
+        self.max_length = max_length

         if self.max_matching_ngram_size <= 0 or self.num_output_tokens <= 0:
             raise ValueError("Invalid max_matching_ngram_size or num_output_tokens")
@@ -264,6 +269,10 @@ class PromptLookupCandidateGenerator(CandidateGenerator):
         """
         input_length = input_ids.size(1)

+        # Don't generate more than `max_length - 1` candidates since the target model generates one extra token.
+        if self.max_length == input_length + 1:
+            return input_ids, None
+
         chosen_ids = None
         match_found = False
         for ngram_size in range(min(self.max_matching_ngram_size, input_length - 1), 0, -1):
@@ -283,7 +292,7 @@ class PromptLookupCandidateGenerator(CandidateGenerator):
             for idx in match_indices:
                 start_idx = idx + ngram_size
                 end_idx = start_idx + self.num_output_tokens
-                end_idx = min(end_idx, input_length)
+                end_idx = min(end_idx, input_length, self.max_length)

                 if start_idx < end_idx:
                     chosen_ids = input_ids[0, start_idx:end_idx]
...
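Illustration (not part of the diff): a toy sketch of what the new `max_length` argument changes in `PromptLookupCandidateGenerator`; all numbers are made up.

```python
# Suppose the sequence so far has 19 tokens and generation_config.max_length is 20.
input_length, max_length, num_output_tokens = 19, 20, 10

# With only one slot left, the generator proposes no candidates at all, because the
# target model will produce that final token itself (mirrors `return input_ids, None`).
propose_candidates = max_length != input_length + 1

# Otherwise, a candidate window starting after a matched ngram is clipped so it can
# never run past max_length (the extra `self.max_length` term added in the diff).
start_idx = 12  # hypothetical match position + ngram size
end_idx = min(start_idx + num_output_tokens, input_length, max_length)
print(propose_candidates, end_idx)
```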
@@ -2,7 +2,7 @@ import time
 import warnings
 from abc import ABC
 from copy import deepcopy
-from typing import Optional
+from typing import List, Optional, Union

 import torch
@@ -129,6 +129,27 @@ class MaxTimeCriteria(StoppingCriteria):
         return torch.full((input_ids.shape[0],), is_done, device=input_ids.device, dtype=torch.bool)


+class EosTokenCriteria(StoppingCriteria):
+    """
+    This class can be used to stop generation whenever the "end-of-sequence" token is generated.
+    By default, it uses the `model.generation_config.eos_token_id`.
+
+    Args:
+        eos_token_id (`Union[int, List[int]]`):
+            The id of the *end-of-sequence* token. Optionally, use a list to set multiple *end-of-sequence* tokens.
+    """
+
+    def __init__(self, eos_token_id: Union[int, List[int]]):
+        if isinstance(eos_token_id, int):
+            eos_token_id = [eos_token_id]
+        self.eos_token_id = torch.tensor(eos_token_id)
+
+    @add_start_docstrings(STOPPING_CRITERIA_INPUTS_DOCSTRING)
+    def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor, **kwargs) -> torch.BoolTensor:
+        is_done = torch.isin(input_ids[:, -1], self.eos_token_id.to(input_ids.device))
+        return is_done
+
+
 class StoppingCriteriaList(list):
     @add_start_docstrings(STOPPING_CRITERIA_INPUTS_DOCSTRING)
     def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor, **kwargs) -> torch.BoolTensor:
...
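Example (not part of the diff): a quick check of the new criterion on a toy batch; token id 0 stands in for the EOS id and the tensors are made up.

```python
import torch

from transformers.generation import EosTokenCriteria

criteria = EosTokenCriteria(eos_token_id=0)

# Batch of three sequences; only the first two end in the EOS id.
input_ids = torch.tensor([[5, 7, 0], [3, 0, 0], [3, 9, 4]])
scores = torch.zeros(3, 10)  # unused by this criterion, but required by the signature

print(criteria(input_ids, scores))  # tensor([ True,  True, False])
```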
@@ -75,6 +75,7 @@ from .logits_process import (
     UnbatchedClassifierFreeGuidanceLogitsProcessor,
 )
 from .stopping_criteria import (
+    EosTokenCriteria,
     MaxLengthCriteria,
     MaxTimeCriteria,
     StoppingCriteria,
@@ -690,6 +691,7 @@ class GenerationMixin:
             candidate_generator = PromptLookupCandidateGenerator(
                 num_output_tokens=generation_config.prompt_lookup_num_tokens,
                 max_matching_ngram_size=generation_config.max_matching_ngram_size,
+                max_length=generation_config.max_length,
             )
         else:
             candidate_generator = AssistedCandidateGenerator(
@@ -892,6 +894,8 @@ class GenerationMixin:
             )
         if generation_config.max_time is not None:
             criteria.append(MaxTimeCriteria(max_time=generation_config.max_time))
+        if generation_config.eos_token_id is not None:
+            criteria.append(EosTokenCriteria(eos_token_id=generation_config.eos_token_id))
         criteria = self._merge_criteria_processor_list(criteria, stopping_criteria)
         return criteria
@@ -1306,7 +1310,7 @@ class GenerationMixin:
         Return:
             [`~utils.ModelOutput`] or `torch.LongTensor`: A [`~utils.ModelOutput`] (if `return_dict_in_generate=True`
-            or when `config.return_dict_in_generate=True`) or a `torch.FloatTensor`.
+            or when `config.return_dict_in_generate=True`) or a `torch.LongTensor`.

             If the model is *not* an encoder-decoder model (`model.config.is_encoder_decoder=False`), the possible
             [`~utils.ModelOutput`] types are:
@@ -1515,7 +1519,6 @@ class GenerationMixin:
                 logits_warper=self._get_logits_warper(generation_config) if generation_config.do_sample else None,
                 stopping_criteria=prepared_stopping_criteria,
                 pad_token_id=generation_config.pad_token_id,
-                eos_token_id=generation_config.eos_token_id,
                 output_scores=generation_config.output_scores,
                 output_logits=generation_config.output_logits,
                 return_dict_in_generate=generation_config.return_dict_in_generate,
@@ -1530,7 +1533,6 @@ class GenerationMixin:
                 logits_processor=prepared_logits_processor,
                 stopping_criteria=prepared_stopping_criteria,
                 pad_token_id=generation_config.pad_token_id,
-                eos_token_id=generation_config.eos_token_id,
                 output_scores=generation_config.output_scores,
                 output_logits=generation_config.output_logits,
                 return_dict_in_generate=generation_config.return_dict_in_generate,
@@ -1550,7 +1552,6 @@ class GenerationMixin:
                 logits_processor=prepared_logits_processor,
                 stopping_criteria=prepared_stopping_criteria,
                 pad_token_id=generation_config.pad_token_id,
-                eos_token_id=generation_config.eos_token_id,
                 output_scores=generation_config.output_scores,
                 output_logits=generation_config.output_logits,
                 return_dict_in_generate=generation_config.return_dict_in_generate,
@@ -1579,7 +1580,6 @@ class GenerationMixin:
                 logits_warper=logits_warper,
                 stopping_criteria=prepared_stopping_criteria,
                 pad_token_id=generation_config.pad_token_id,
-                eos_token_id=generation_config.eos_token_id,
                 output_scores=generation_config.output_scores,
                 output_logits=generation_config.output_logits,
                 return_dict_in_generate=generation_config.return_dict_in_generate,
@@ -1613,7 +1613,6 @@ class GenerationMixin:
                 logits_processor=prepared_logits_processor,
                 stopping_criteria=prepared_stopping_criteria,
                 pad_token_id=generation_config.pad_token_id,
-                eos_token_id=generation_config.eos_token_id,
                 output_scores=generation_config.output_scores,
                 output_logits=generation_config.output_logits,
                 return_dict_in_generate=generation_config.return_dict_in_generate,
@@ -1653,7 +1652,6 @@ class GenerationMixin:
                 logits_warper=logits_warper,
                 stopping_criteria=prepared_stopping_criteria,
                 pad_token_id=generation_config.pad_token_id,
-                eos_token_id=generation_config.eos_token_id,
                 output_scores=generation_config.output_scores,
                 output_logits=generation_config.output_logits,
                 return_dict_in_generate=generation_config.return_dict_in_generate,
@@ -1687,7 +1685,6 @@ class GenerationMixin:
                 logits_processor=prepared_logits_processor,
                 stopping_criteria=prepared_stopping_criteria,
                 pad_token_id=generation_config.pad_token_id,
-                eos_token_id=generation_config.eos_token_id,
                 output_scores=generation_config.output_scores,
                 output_logits=generation_config.output_logits,
                 return_dict_in_generate=generation_config.return_dict_in_generate,
@@ -1761,7 +1758,6 @@ class GenerationMixin:
                 logits_processor=prepared_logits_processor,
                 stopping_criteria=prepared_stopping_criteria,
                 pad_token_id=generation_config.pad_token_id,
-                eos_token_id=generation_config.eos_token_id,
                 output_scores=generation_config.output_scores,
                 output_logits=generation_config.output_logits,
                 return_dict_in_generate=generation_config.return_dict_in_generate,
@@ -1916,11 +1912,28 @@ class GenerationMixin:
         logits_warper = logits_warper if logits_warper is not None else LogitsProcessorList()
         stopping_criteria = stopping_criteria if stopping_criteria is not None else StoppingCriteriaList()
         pad_token_id = pad_token_id if pad_token_id is not None else self.generation_config.pad_token_id
-        eos_token_id = eos_token_id if eos_token_id is not None else self.generation_config.eos_token_id
-        sequential = sequential if sequential is not None else self.generation_config.low_memory
+        if eos_token_id is not None:
+            logger.warning_once(
+                "`eos_token_id` is deprecated in this function and will be removed in v4.41, use"
+                " `stopping_criteria=StoppingCriteriaList([EosTokenCriteria(eos_token_id=eos_token_id)])` instead."
+                " Otherwise make sure to set `model.generation_config.eos_token_id`",
+                FutureWarning,
+            )
+            stopping_criteria.append(EosTokenCriteria(eos_token_id=eos_token_id))
+        else:
+            # TODO remove when the method is totally private
+            # need to get `eos_token_id` and add stopping criteria, so that generation does not go forever
+            eos_token_id = [
+                criteria.eos_token_id.tolist() for criteria in stopping_criteria if hasattr(criteria, "eos_token_id")
+            ]
+            eos_token_id = eos_token_id[0] if eos_token_id else None
+            if eos_token_id is None and self.generation_config.eos_token_id is not None:
+                eos_token_id = self.generation_config.eos_token_id
+                stopping_criteria.append(EosTokenCriteria(eos_token_id=eos_token_id))
         if isinstance(eos_token_id, int):
             eos_token_id = [eos_token_id]
-        eos_token_id_tensor = torch.tensor(eos_token_id).to(input_ids.device) if eos_token_id is not None else None
+        sequential = sequential if sequential is not None else self.generation_config.low_memory
         output_scores = output_scores if output_scores is not None else self.generation_config.output_scores
         output_logits = output_logits if output_logits is not None else self.generation_config.output_logits
         output_attentions = (
@@ -2186,12 +2199,6 @@ class GenerationMixin:
                 is_encoder_decoder=self.config.is_encoder_decoder,
             )

-            # if eos_token was found in one sentence, set sentence to finished
-            if eos_token_id_tensor is not None:
-                unfinished_sequences = unfinished_sequences.mul(
-                    next_tokens.tile(eos_token_id_tensor.shape[0], 1).ne(eos_token_id_tensor.unsqueeze(1)).prod(dim=0)
-                )
-
             # stop when each sentence is finished
             unfinished_sequences = unfinished_sequences & ~stopping_criteria(input_ids, scores)
             this_peer_finished = unfinished_sequences.max() == 0
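Aside (not part of the diff): with the per-token EOS check deleted above, finished rows are tracked only through the criteria list. Roughly, each criterion returns a per-sequence "done" mask and the masks are combined; a simplified sketch of that bookkeeping with made-up values:

```python
import torch

unfinished_sequences = torch.ones(3, dtype=torch.long)  # 1 = still generating

# Hypothetical per-criterion results for a batch of three sequences.
eos_done = torch.tensor([False, True, False])      # e.g. from EosTokenCriteria
length_done = torch.tensor([False, False, False])  # e.g. from MaxLengthCriteria

is_done = eos_done | length_done  # what the stopping-criteria call effectively yields
unfinished_sequences = unfinished_sequences & ~is_done
this_peer_finished = unfinished_sequences.max() == 0
print(unfinished_sequences, this_peer_finished)  # tensor([1, 0, 1]) tensor(False)
```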
@@ -2365,10 +2372,27 @@ class GenerationMixin:
             )
         stopping_criteria = validate_stopping_criteria(stopping_criteria, max_length)
         pad_token_id = pad_token_id if pad_token_id is not None else self.generation_config.pad_token_id
-        eos_token_id = eos_token_id if eos_token_id is not None else self.generation_config.eos_token_id
+        if eos_token_id is not None:
+            logger.warning_once(
+                "`eos_token_id` is deprecated in this function and will be removed in v4.41, use"
+                " `stopping_criteria=StoppingCriteriaList([EosTokenCriteria(eos_token_id=eos_token_id)])` instead."
+                " Otherwise make sure to set `model.generation_config.eos_token_id`",
+                FutureWarning,
+            )
+            stopping_criteria.append(EosTokenCriteria(eos_token_id=eos_token_id))
+        else:
+            # TODO remove when the method is totally private
+            # need to get `eos_token_id` and add stopping criteria, so that generation does not go forever
+            eos_token_id = [
+                criteria.eos_token_id.tolist() for criteria in stopping_criteria if hasattr(criteria, "eos_token_id")
+            ]
+            eos_token_id = eos_token_id[0] if eos_token_id else None
+            if eos_token_id is None and self.generation_config.eos_token_id is not None:
+                eos_token_id = self.generation_config.eos_token_id
+                stopping_criteria.append(EosTokenCriteria(eos_token_id=eos_token_id))
         if isinstance(eos_token_id, int):
             eos_token_id = [eos_token_id]
-        eos_token_id_tensor = torch.tensor(eos_token_id).to(input_ids.device) if eos_token_id is not None else None
         output_scores = output_scores if output_scores is not None else self.generation_config.output_scores
         output_attentions = (
             output_attentions if output_attentions is not None else self.generation_config.output_attentions
@@ -2463,12 +2487,6 @@ class GenerationMixin:
                 is_encoder_decoder=self.config.is_encoder_decoder,
             )

-            # if eos_token was found in one sentence, set sentence to finished
-            if eos_token_id_tensor is not None:
-                unfinished_sequences = unfinished_sequences.mul(
-                    next_tokens.tile(eos_token_id_tensor.shape[0], 1).ne(eos_token_id_tensor.unsqueeze(1)).prod(dim=0)
-                )
-
             unfinished_sequences = unfinished_sequences & ~stopping_criteria(input_ids, scores)
             this_peer_finished = unfinished_sequences.max() == 0
@@ -2650,10 +2668,27 @@ class GenerationMixin:
         stopping_criteria = validate_stopping_criteria(stopping_criteria, max_length)
         logits_warper = logits_warper if logits_warper is not None else LogitsProcessorList()
         pad_token_id = pad_token_id if pad_token_id is not None else self.generation_config.pad_token_id
-        eos_token_id = eos_token_id if eos_token_id is not None else self.generation_config.eos_token_id
+        if eos_token_id is not None:
+            logger.warning_once(
+                "`eos_token_id` is deprecated in this function and will be removed in v4.41, use"
+                " `stopping_criteria=StoppingCriteriaList([EosTokenCriteria(eos_token_id=eos_token_id)])` instead."
+                " Otherwise make sure to set `model.generation_config.eos_token_id`",
+                FutureWarning,
+            )
+            stopping_criteria.append(EosTokenCriteria(eos_token_id=eos_token_id))
+        else:
+            # TODO remove when the method is totally private
+            # need to get `eos_token_id` and add stopping criteria, so that generation does not go forever
+            eos_token_id = [
+                criteria.eos_token_id.tolist() for criteria in stopping_criteria if hasattr(criteria, "eos_token_id")
+            ]
+            eos_token_id = eos_token_id[0] if eos_token_id else None
+            if eos_token_id is None and self.generation_config.eos_token_id is not None:
+                eos_token_id = self.generation_config.eos_token_id
+                stopping_criteria.append(EosTokenCriteria(eos_token_id=eos_token_id))
         if isinstance(eos_token_id, int):
             eos_token_id = [eos_token_id]
-        eos_token_id_tensor = torch.tensor(eos_token_id).to(input_ids.device) if eos_token_id is not None else None
         output_scores = output_scores if output_scores is not None else self.generation_config.output_scores
         output_logits = output_logits if output_logits is not None else self.generation_config.output_logits
         output_attentions = (
@@ -2751,12 +2786,6 @@ class GenerationMixin:
                 is_encoder_decoder=self.config.is_encoder_decoder,
             )

-            # if eos_token was found in one sentence, set sentence to finished
-            if eos_token_id_tensor is not None:
-                unfinished_sequences = unfinished_sequences.mul(
-                    next_tokens.tile(eos_token_id_tensor.shape[0], 1).ne(eos_token_id_tensor.unsqueeze(1)).prod(dim=0)
-                )
-
             unfinished_sequences = unfinished_sequences & ~stopping_criteria(input_ids, scores)
             this_peer_finished = unfinished_sequences.max() == 0
@@ -2966,7 +2995,25 @@ class GenerationMixin:
         if len(stopping_criteria) == 0:
             warnings.warn("You don't have defined any stopping_criteria, this will likely loop forever", UserWarning)
         pad_token_id = pad_token_id if pad_token_id is not None else self.generation_config.pad_token_id
-        eos_token_id = eos_token_id if eos_token_id is not None else self.generation_config.eos_token_id
+        if eos_token_id is not None:
+            logger.warning_once(
+                "`eos_token_id` is deprecated in this function and will be removed in v4.41, use"
+                " `stopping_criteria=StoppingCriteriaList([EosTokenCriteria(eos_token_id=eos_token_id)])` instead."
+                " Otherwise make sure to set `model.generation_config.eos_token_id`",
+                FutureWarning,
+            )
+            stopping_criteria.append(EosTokenCriteria(eos_token_id=eos_token_id))
+        else:
+            # TODO remove when the method is totally private and beam scorer refactored
+            # need to get `eos_token_id` and add stopping criteria, so that generation does not go forever
+            eos_token_id = [
+                criteria.eos_token_id.tolist() for criteria in stopping_criteria if hasattr(criteria, "eos_token_id")
+            ]
+            eos_token_id = eos_token_id[0] if eos_token_id else None
+            if eos_token_id is None and self.generation_config.eos_token_id is not None:
+                eos_token_id = self.generation_config.eos_token_id
+                stopping_criteria.append(EosTokenCriteria(eos_token_id=eos_token_id))
         if isinstance(eos_token_id, int):
             eos_token_id = [eos_token_id]
         output_scores = output_scores if output_scores is not None else self.generation_config.output_scores
@@ -3351,7 +3398,25 @@ class GenerationMixin:
             )
         stopping_criteria = validate_stopping_criteria(stopping_criteria, max_length)
         pad_token_id = pad_token_id if pad_token_id is not None else self.generation_config.pad_token_id
-        eos_token_id = eos_token_id if eos_token_id is not None else self.generation_config.eos_token_id
+        if eos_token_id is not None:
+            logger.warning_once(
+                "`eos_token_id` is deprecated in this function and will be removed in v4.41, use"
+                " `stopping_criteria=StoppingCriteriaList([EosTokenCriteria(eos_token_id=eos_token_id)])` instead."
+                " Otherwise make sure to set `model.generation_config.eos_token_id`",
+                FutureWarning,
+            )
+            stopping_criteria.append(EosTokenCriteria(eos_token_id=eos_token_id))
+        else:
+            # TODO remove when the method is totally private and beam scorer refactored
+            # need to get `eos_token_id` and add stopping criteria, so that generation does not go forever
+            eos_token_id = [
+                criteria.eos_token_id.tolist() for criteria in stopping_criteria if hasattr(criteria, "eos_token_id")
+            ]
+            eos_token_id = eos_token_id[0] if eos_token_id else None
+            if eos_token_id is None and self.generation_config.eos_token_id is not None:
+                eos_token_id = self.generation_config.eos_token_id
+                stopping_criteria.append(EosTokenCriteria(eos_token_id=eos_token_id))
         if isinstance(eos_token_id, int):
             eos_token_id = [eos_token_id]
         output_scores = output_scores if output_scores is not None else self.generation_config.output_scores
@@ -3688,7 +3753,25 @@ class GenerationMixin:
             )
         stopping_criteria = validate_stopping_criteria(stopping_criteria, max_length)
         pad_token_id = pad_token_id if pad_token_id is not None else self.generation_config.pad_token_id
-        eos_token_id = eos_token_id if eos_token_id is not None else self.generation_config.eos_token_id
+        if eos_token_id is not None:
+            logger.warning_once(
+                "`eos_token_id` is deprecated in this function and will be removed in v4.41, use"
+                " `stopping_criteria=StoppingCriteriaList([EosTokenCriteria(eos_token_id=eos_token_id)])` instead."
+                " Otherwise make sure to set `model.generation_config.eos_token_id`",
+                FutureWarning,
+            )
+            stopping_criteria.append(EosTokenCriteria(eos_token_id=eos_token_id))
+        else:
+            # TODO remove when the method is totally private and beam scorer refactored
+            # need to get `eos_token_id` and add stopping criteria, so that generation does not go forever
+            eos_token_id = [
+                criteria.eos_token_id.tolist() for criteria in stopping_criteria if hasattr(criteria, "eos_token_id")
+            ]
+            eos_token_id = eos_token_id[0] if eos_token_id else None
+            if eos_token_id is None and self.generation_config.eos_token_id is not None:
+                eos_token_id = self.generation_config.eos_token_id
+                stopping_criteria.append(EosTokenCriteria(eos_token_id=eos_token_id))
         if isinstance(eos_token_id, int):
             eos_token_id = [eos_token_id]
         output_scores = output_scores if output_scores is not None else self.generation_config.output_scores
@@ -4089,7 +4172,25 @@ class GenerationMixin:
         if len(stopping_criteria) == 0:
             warnings.warn("You don't have defined any stopping_criteria, this will likely loop forever", UserWarning)
         pad_token_id = pad_token_id if pad_token_id is not None else self.generation_config.pad_token_id
-        eos_token_id = eos_token_id if eos_token_id is not None else self.generation_config.eos_token_id
+        if eos_token_id is not None:
+            logger.warning_once(
+                "`eos_token_id` is deprecated in this function and will be removed in v4.41, use"
+                " `stopping_criteria=StoppingCriteriaList([EosTokenCriteria(eos_token_id=eos_token_id)])` instead."
+                " Otherwise make sure to set `model.generation_config.eos_token_id`",
+                FutureWarning,
+            )
+            stopping_criteria.append(EosTokenCriteria(eos_token_id=eos_token_id))
+        else:
+            # TODO remove when the method is totally private and beam scorer refactored
+            # need to get `eos_token_id` and add stopping criteria, so that generation does not go forever
+            eos_token_id = [
+                criteria.eos_token_id.tolist() for criteria in stopping_criteria if hasattr(criteria, "eos_token_id")
+            ]
+            eos_token_id = eos_token_id[0] if eos_token_id else None
+            if eos_token_id is None and self.generation_config.eos_token_id is not None:
+                eos_token_id = self.generation_config.eos_token_id
+                stopping_criteria.append(EosTokenCriteria(eos_token_id=eos_token_id))
         if isinstance(eos_token_id, int):
             eos_token_id = [eos_token_id]
         output_scores = output_scores if output_scores is not None else self.generation_config.output_scores
@@ -4421,12 +4522,27 @@ class GenerationMixin:
         logits_warper = logits_warper if logits_warper is not None else LogitsProcessorList()
         stopping_criteria = stopping_criteria if stopping_criteria is not None else StoppingCriteriaList()
         pad_token_id = pad_token_id if pad_token_id is not None else self.generation_config.pad_token_id
-        eos_token_id = eos_token_id if eos_token_id is not None else self.generation_config.eos_token_id
-        if eos_token_id is not None and pad_token_id is None:
-            raise ValueError("If `eos_token_id` is defined, make sure that `pad_token_id` is defined.")
+        if eos_token_id is not None:
+            logger.warning_once(
+                "`eos_token_id` is deprecated in this function and will be removed in v4.41, use"
+                " `stopping_criteria=StoppingCriteriaList([EosTokenCriteria(eos_token_id=eos_token_id)])` instead."
+                " Otherwise make sure to set `model.generation_config.eos_token_id`",
+                FutureWarning,
+            )
+            stopping_criteria.append(EosTokenCriteria(eos_token_id=eos_token_id))
+        else:
+            # TODO remove when the method is totally private and beam scorer refactored
+            # need to get `eos_token_id` and add stopping criteria, so that generation does not go forever
+            eos_token_id = [
+                criteria.eos_token_id.tolist() for criteria in stopping_criteria if hasattr(criteria, "eos_token_id")
+            ]
+            eos_token_id = eos_token_id[0] if eos_token_id else None
+            if eos_token_id is None and self.generation_config.eos_token_id is not None:
+                eos_token_id = self.generation_config.eos_token_id
+                stopping_criteria.append(EosTokenCriteria(eos_token_id=eos_token_id))
         if isinstance(eos_token_id, int):
             eos_token_id = [eos_token_id]
-        eos_token_id_tensor = torch.tensor(eos_token_id).to(input_ids.device) if eos_token_id is not None else None
         output_scores = output_scores if output_scores is not None else self.generation_config.output_scores
         output_logits = output_logits if output_logits is not None else self.generation_config.output_logits
         output_attentions = (
@@ -4462,9 +4578,6 @@ class GenerationMixin:
         unfinished_sequences = torch.ones(batch_size, dtype=torch.long, device=input_ids.device)
         model_kwargs["cache_position"] = torch.arange(cur_len, device=input_ids.device)

-        # other auxiliary variables
-        max_len = stopping_criteria[0].max_length
-
         this_peer_finished = False
         while self._has_unfinished_sequences(this_peer_finished, synced_gpus, device=input_ids.device):
             cur_len = input_ids.shape[-1]
@@ -4476,13 +4589,7 @@ class GenerationMixin:
                 candidate_logits = candidate_logits.to(self.device)

             candidate_length = candidate_input_ids.shape[1] - input_ids.shape[1]
-            last_assistant_token_is_eos = (
-                ~candidate_input_ids[:, -1]
-                .tile(eos_token_id_tensor.shape[0], 1)
-                .ne(eos_token_id_tensor.unsqueeze(1))
-                .prod(dim=0)
-                .bool()
-            )
+            is_done_candidate = stopping_criteria(candidate_input_ids, None)

             # 2. Use the original model to obtain the next token logits given the candidate sequence. We obtain
             # `candidate_length + 1` relevant logits from this process: in the event that all candidates are correct,
@@ -4525,15 +4632,13 @@ class GenerationMixin:
            # 3. Select the accepted tokens. There are two possible cases:
            # Case 1: `do_sample=True` and we have logits for the candidates (originally from speculative decoding)
            # 👉 Apply algorithm 1 from the speculative decoding paper (https://arxiv.org/pdf/2211.17192.pdf).
-            max_matches = max_len - cur_len - 1
             if do_sample and candidate_logits is not None:
                 valid_tokens, n_matches = _speculative_sampling(
                     candidate_input_ids,
                     candidate_logits,
                     candidate_length,
                     new_logits,
-                    last_assistant_token_is_eos,
-                    max_matches,
+                    is_done_candidate,
                 )

             # Case 2: all other cases (originally from assisted generation) 👉 Compare the tokens selected from the
@@ -4550,9 +4655,8 @@ class GenerationMixin:
                 n_matches = ((~(candidate_new_tokens == selected_tokens[:, :-1])).cumsum(dim=-1) < 1).sum()

                 # Ensure we don't generate beyond max_len or an EOS token
-                if last_assistant_token_is_eos and n_matches == candidate_length:
+                if is_done_candidate and n_matches == candidate_length:
                     n_matches -= 1
-                n_matches = min(n_matches, max_matches)
                 valid_tokens = selected_tokens[:, : n_matches + 1]

             # 4. Update variables according to the number of matching assistant tokens. Remember: the token generated
@@ -4625,15 +4729,6 @@ class GenerationMixin:
                 is_encoder_decoder=self.config.is_encoder_decoder,
             )

-            # if eos_token was found in one sentence, set sentence to finished
-            if eos_token_id_tensor is not None:
-                unfinished_sequences = unfinished_sequences.mul(
-                    input_ids[:, -1]
-                    .tile(eos_token_id_tensor.shape[0], 1)
-                    .ne(eos_token_id_tensor.unsqueeze(1))
-                    .prod(dim=0)
-                )
-
             unfinished_sequences = unfinished_sequences & ~stopping_criteria(input_ids, scores)
             this_peer_finished = unfinished_sequences.max() == 0
@@ -4678,8 +4773,7 @@ def _speculative_sampling(
     candidate_logits,
     candidate_length,
     new_logits,
-    last_assistant_token_is_eos,
-    max_matches,
+    is_done_candidate,
 ):
     """
     Applies sampling as in the speculative decoding paper (https://arxiv.org/pdf/2211.17192.pdf, algorithm 1). Returns
@@ -4704,16 +4798,14 @@ def _speculative_sampling(
     n_matches = ((~is_accepted).cumsum(dim=-1) < 1).sum()  # this is `n` in algorithm 1

     # Ensure we don't generate beyond max_len or an EOS token (not in algorithm 1, but needed for correct behavior)
-    if last_assistant_token_is_eos and n_matches == candidate_length:
+    if is_done_candidate and n_matches == candidate_length:
         # Output length is assumed to be `n_matches + 1`. Since we won't generate another token with the target model
         # due to acceptance on EOS we fix `n_matches`
         n_matches -= 1
         valid_tokens = new_candidate_input_ids[:, : n_matches + 1]
     else:
-        n_matches = min(n_matches, max_matches)
-
         # Next token selection: if there is a rejection, adjust the distribution from the main model before sampling.
-        gamma = min(candidate_logits.shape[1], max_matches)
+        gamma = candidate_logits.shape[1]
         p_n_plus_1 = p[:, n_matches, :]
         if n_matches < gamma:
             q_n_plus_1 = q[:, n_matches, :]
...
@@ -26,6 +26,7 @@ if is_torch_available():
     import torch

     from transformers.generation import (
+        EosTokenCriteria,
         MaxLengthCriteria,
         MaxNewTokensCriteria,
         MaxTimeCriteria,
@@ -98,6 +99,22 @@ class StoppingCriteriaTestCase(unittest.TestCase):
         criteria = MaxTimeCriteria(max_time=0.1, initial_timestamp=time.time() - 0.2)
         self.assertTrue(all(criteria(input_ids, scores)))

+    def test_eos_token_criteria(self):
+        criteria = EosTokenCriteria(eos_token_id=0)
+
+        input_ids, scores = self._get_tensors(5)
+        input_ids[:, -1] = 0
+        self.assertTrue(all(criteria(input_ids, scores)))
+
+        input_ids, scores = self._get_tensors(5)
+        input_ids[:2, -1] = 0
+        input_ids[2, -1] = 1
+        self.assertListEqual(criteria(input_ids, scores).tolist(), [True, True, False])
+
+        input_ids, scores = self._get_tensors(5)
+        input_ids[:, -1] = 1
+        self.assertListEqual(criteria(input_ids, scores).tolist(), [False, False, False])
+
     def test_validate_stopping_criteria(self):
         validate_stopping_criteria(StoppingCriteriaList([MaxLengthCriteria(10)]), 10)
...
@@ -1899,14 +1899,12 @@ class UtilsFunctionsTest(unittest.TestCase):
             ]
         )
         last_assistant_token_is_eos = False
-        max_matches = 5

         validated_tokens, n_matches = _speculative_sampling(
             candidate_input_ids,
             candidate_logits,
             candidate_length,
             new_logits,
             last_assistant_token_is_eos,
-            max_matches,
         )
         self.assertTrue(n_matches.item() == 2)
         self.assertTrue(validated_tokens.tolist()[0] == [1, 4, 8])
...