"git@developer.sourcefind.cn:chenpangpang/transformers.git" did not exist on "4ab742459735671189d774cfa336d52561655816"
Unverified commit f7196f2e authored by Santiago Castro, committed by GitHub

Fix decoding score comparison when using logits processors or warpers (#10638)

* Normalize using a logits warper

* Add a flag in `generate` to support logit renormalization

* Add in RAG
parent eb5bdcdf
@@ -679,3 +679,16 @@ class ExponentialDecayLengthPenalty(LogitsProcessor):
self.regulation_factor, cur_len - self.regulation_start
)
return scores
class LogitNormalization(LogitsProcessor, LogitsWarper):
r"""
[`LogitsWarper`] and [`LogitsProcessor`] for normalizing the scores using log-softmax. During beam search it is
important to renormalize the scores after all logits processors or warpers have been applied: the search algorithm
assumes the scores are normalized when comparing hypotheses, but it only normalizes them before the processors and
warpers run, and those may break the normalization.
"""
def __call__(self, input_ids: torch.Tensor, scores: torch.Tensor) -> torch.Tensor:
scores = scores.log_softmax(dim=-1)
return scores
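
A minimal sketch (not part of the commit) of what the new warper does: it applies log-softmax so that the exponentiated scores of each row form a proper probability distribution. The import path follows the `generation_logits_process` module touched above.

```python
# Illustrative sketch only: LogitNormalization turns raw scores into log-probabilities.
import torch
from transformers.generation_logits_process import LogitNormalization

scores = torch.tensor([[2.0, 0.5, -1.0], [0.0, 3.0, 1.5]])
normalized = LogitNormalization()(input_ids=None, scores=scores)
print(normalized.exp().sum(dim=-1))  # ~tensor([1., 1.]): each row sums to one
```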
@@ -32,6 +32,7 @@ from .generation_logits_process import (
ForcedEOSTokenLogitsProcessor,
HammingDiversityLogitsProcessor,
InfNanRemoveLogitsProcessor,
LogitNormalization,
LogitsProcessorList,
MinLengthLogitsProcessor,
NoBadWordsLogitsProcessor,
@@ -636,6 +637,7 @@ class GenerationMixin:
typical_p: Optional[float] = None,
temperature: Optional[float] = None,
num_beams: Optional[int] = None,
renormalize_logits: Optional[bool] = None,
) -> LogitsProcessorList:
"""
This class returns a [`LogitsProcessorList`] list object that contains all relevant [`LogitsWarper`] instances
@@ -660,6 +662,9 @@ class GenerationMixin:
warpers.append(TopPLogitsWarper(top_p=top_p, min_tokens_to_keep=(2 if num_beams > 1 else 1)))
if typical_p is not None and typical_p < 1.0:
warpers.append(TypicalLogitsWarper(mass=typical_p, min_tokens_to_keep=(2 if num_beams > 1 else 1)))
# `LogitNormalization` should always be the last logit processor, when present
if renormalize_logits is True:
warpers.append(LogitNormalization())
return warpers
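
To illustrate why `LogitNormalization` must come last (an illustrative sketch, not part of the diff): a warper such as `TopKLogitsWarper` masks tokens with `-inf` without rescaling the kept log-probabilities, so the scores stop summing to 1 until the final renormalization step restores them.

```python
# Sketch only: a top-k warper breaks normalization; appending LogitNormalization
# as the last warper restores it.
import torch
from transformers.generation_logits_process import (
    LogitNormalization,
    LogitsProcessorList,
    TopKLogitsWarper,
)

scores = torch.log_softmax(torch.randn(1, 10), dim=-1)  # normalized log-probs
top_k_only = TopKLogitsWarper(top_k=3)(None, scores)
print(top_k_only.exp().sum(dim=-1))  # < 1.0: the masked probability mass is gone

warpers = LogitsProcessorList([TopKLogitsWarper(top_k=3), LogitNormalization()])
print(warpers(None, scores).exp().sum(dim=-1))  # ~1.0 again after renormalization
```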
def _get_logits_processor(
@@ -682,6 +687,7 @@ class GenerationMixin:
remove_invalid_values: bool,
exponential_decay_length_penalty: Tuple,
logits_processor: Optional[LogitsProcessorList],
renormalize_logits: Optional[bool],
) -> LogitsProcessorList:
"""
This class returns a [`LogitsProcessorList`] list object that contains all relevant [`LogitsProcessor`]
@@ -754,6 +760,9 @@ class GenerationMixin:
ExponentialDecayLengthPenalty(exponential_decay_length_penalty, eos_token_id, input_ids_seq_length)
)
processors = self._merge_criteria_processor_list(processors, logits_processor)
# `LogitNormalization` should always be the last logit processor, when present
if renormalize_logits is True:
processors.append(LogitNormalization())
return processors
def _get_stopping_criteria(
@@ -858,6 +867,7 @@ class GenerationMixin:
diversity_penalty: Optional[float] = None,
prefix_allowed_tokens_fn: Optional[Callable[[int, torch.Tensor], List[int]]] = None,
logits_processor: Optional[LogitsProcessorList] = LogitsProcessorList(),
renormalize_logits: Optional[bool] = None,
stopping_criteria: Optional[StoppingCriteriaList] = StoppingCriteriaList(),
constraints: Optional[List[Constraint]] = None,
output_attentions: Optional[bool] = None,
@@ -986,6 +996,10 @@ class GenerationMixin:
Custom logits processors that complement the default logits processors built from arguments and a
model's config. If a logit processor is passed that is already created with the arguments or a model's
config an error is thrown. This feature is intended for advanced users.
renormalize_logits (`bool`, *optional*, defaults to `False`):
Whether to renormalize the logits after applying all the logits processors or warpers (including the
custom ones). It is highly recommended to set this flag to `True`, as the search algorithms assume the
score logits are normalized, but some logits processors or warpers break the normalization.
stopping_criteria (`StoppingCriteriaList`, *optional*):
Custom stopping criteria that complement the default stopping criteria built from arguments and a
model's config. If a stopping criteria is passed that is already created with the arguments or a
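
A hedged usage sketch for the new flag (the `"gpt2"` checkpoint is only an illustrative assumption, not part of the commit): with sampling warpers active, passing `renormalize_logits=True` re-applies log-softmax after every processor and warper so that beam and sample scores stay comparable.

```python
# Usage sketch, not part of the commit; "gpt2" is an arbitrary example checkpoint.
from transformers import AutoModelForCausalLM, AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("gpt2")
model = AutoModelForCausalLM.from_pretrained("gpt2")

inputs = tokenizer("The quick brown fox", return_tensors="pt")
outputs = model.generate(
    **inputs,
    do_sample=True,
    num_beams=4,
    top_k=50,
    temperature=0.7,
    max_length=40,
    renormalize_logits=True,  # re-normalize after all processors/warpers have run
)
print(tokenizer.decode(outputs[0], skip_special_tokens=True))
```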
@@ -1241,6 +1255,7 @@ class GenerationMixin:
remove_invalid_values=remove_invalid_values,
exponential_decay_length_penalty=exponential_decay_length_penalty,
logits_processor=logits_processor,
renormalize_logits=renormalize_logits,
)
# 8. prepare stopping criteria
@@ -1271,7 +1286,12 @@ class GenerationMixin:
elif is_sample_gen_mode:
# 10. prepare logits warper
logits_warper = self._get_logits_warper(
top_k=top_k,
top_p=top_p,
typical_p=typical_p,
temperature=temperature,
num_beams=num_beams,
renormalize_logits=renormalize_logits,
)
# 11. expand input_ids with `num_return_sequences` additional sequences per batch
@@ -1333,7 +1353,12 @@ class GenerationMixin:
elif is_beam_sample_gen_mode:
# 10. prepare logits warper
logits_warper = self._get_logits_warper(
top_k=top_k,
top_p=top_p,
typical_p=typical_p,
temperature=temperature,
num_beams=num_beams,
renormalize_logits=renormalize_logits,
)
if stopping_criteria.max_length is None:
...
@@ -1400,6 +1400,7 @@ class RagTokenForGeneration(RagPreTrainedModel):
n_docs: Optional[int] = None,
prefix_allowed_tokens_fn: Callable[[int, torch.Tensor], List[int]] = None,
logits_processor: Optional[LogitsProcessorList] = LogitsProcessorList(),
renormalize_logits: Optional[bool] = None,
stopping_criteria: Optional[StoppingCriteriaList] = StoppingCriteriaList(),
forced_bos_token_id: Optional[int] = None,
forced_eos_token_id: Optional[int] = None,
@@ -1624,6 +1625,7 @@ class RagTokenForGeneration(RagPreTrainedModel):
remove_invalid_values=remove_invalid_values,
exponential_decay_length_penalty=exponential_decay_length_penalty,
logits_processor=logits_processor,
renormalize_logits=renormalize_logits,
)
if num_beams == 1:
...
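
The same flag is plumbed through `RagTokenForGeneration.generate`, so RAG beam search can also renormalize scores before comparing hypotheses. A rough sketch (not from the diff), assuming the standard `facebook/rag-token-nq` checkpoint and the dummy index used in the library's documentation:

```python
# Sketch only: RAG generation with the new renormalize_logits flag.
from transformers import RagRetriever, RagTokenForGeneration, RagTokenizer

tokenizer = RagTokenizer.from_pretrained("facebook/rag-token-nq")
retriever = RagRetriever.from_pretrained(
    "facebook/rag-token-nq", index_name="exact", use_dummy_dataset=True
)
model = RagTokenForGeneration.from_pretrained("facebook/rag-token-nq", retriever=retriever)

inputs = tokenizer("who holds the record in 100m freestyle", return_tensors="pt")
generated = model.generate(
    input_ids=inputs["input_ids"], num_beams=4, renormalize_logits=True
)
print(tokenizer.batch_decode(generated, skip_special_tokens=True))
```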
@@ -33,6 +33,7 @@ if is_torch_available():
ForcedEOSTokenLogitsProcessor,
HammingDiversityLogitsProcessor,
InfNanRemoveLogitsProcessor,
LogitNormalization,
LogitsProcessorList,
MinLengthLogitsProcessor,
NoBadWordsLogitsProcessor,
@@ -537,3 +538,18 @@ class LogitsProcessorTest(unittest.TestCase):
scores_after_start[penalty_start + 1 :, eos_token_id], scores[penalty_start + 1 :, eos_token_id]
).all()
)
def test_normalization(self):
input_ids = None
scores = torch.tensor(
[[-23.18, -29.96, -43.54, 47.77], [-33.58, -26.87, -32.96, 22.51]], device=torch_device, dtype=torch.float
)
logit_normalization = LogitNormalization()
normalized_scores = logit_normalization(input_ids, scores).exp()
ones = torch.ones(scores.shape[0], device=torch_device, dtype=torch.float)
self.assertTrue(normalized_scores.sum(dim=-1).allclose(ones))
self.assertTrue(normalized_scores.allclose(scores.softmax(dim=-1)))