Unverified Commit b9403e95 authored by Karim Foda, committed by GitHub

Add hallucination filter (#18675)



* Add hallucination penalty

* Make quality changes

* Inverse penalty

* Fix imports & quality

* Fix name spelling issue

* Set encoder_repetition_penalty and fix quality

* Fix failing test

* Add to config_common_kwargs

* Fix modeling_rag error

* Update src/transformers/generation_logits_process.py
Co-authored-by: Joao Gante <joaofranciscocardosogante@gmail.com>

* Remove breakpoint

* Make style fixes

* Update encoder_repetition_penalty default value

* Merge latest main changes

* Make fixup changes

* Add EncoderRepetitionPenaltyLogitsProcessor to generation/__init__.py

* Fix repo-inconsistency

* Remove venv

* Remove tensorflow-macos & add tests

* Add documentation

* Fix quality issues

* Move encoder_repetition_penalty to config

* Update src/transformers/configuration_utils.py
Co-authored-by: Joao Gante <joaofranciscocardosogante@gmail.com>

* Update src/transformers/generation/configuration_utils.py
Co-authored-by: Joao Gante <joaofranciscocardosogante@gmail.com>

* Remove encoder_repetition_penalty from tests

* Fix type error

* Fix format error
Co-authored-by: Joao Gante <joaofranciscocardosogante@gmail.com>
parent e9b4800d
@@ -58,6 +58,7 @@ else:
         "NoRepeatNGramLogitsProcessor",
         "PrefixConstrainedLogitsProcessor",
         "RepetitionPenaltyLogitsProcessor",
+        "EncoderRepetitionPenaltyLogitsProcessor",
         "TemperatureLogitsWarper",
         "TopKLogitsWarper",
         "TopPLogitsWarper",
@@ -164,6 +165,7 @@ if TYPE_CHECKING:
     from .beam_search import BeamHypotheses, BeamScorer, BeamSearchScorer, ConstrainedBeamSearchScorer
     from .logits_process import (
         EncoderNoRepeatNGramLogitsProcessor,
+        EncoderRepetitionPenaltyLogitsProcessor,
         EpsilonLogitsWarper,
         EtaLogitsWarper,
         ExponentialDecayLengthPenalty,
......
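The two hunks above register the new processor in both the lazy `_import_structure` table and the `TYPE_CHECKING` branch, so it resolves from the public `transformers.generation` module either way. A quick sanity check, assuming a torch install (the symbol sits behind `is_torch_available()`):

import transformers.generation as generation

# resolved lazily via _import_structure on first attribute access
print(generation.EncoderRepetitionPenaltyLogitsProcessor)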
@@ -127,6 +127,9 @@ class GenerationConfig(PushToHubMixin):
         repetition_penalty (`float`, *optional*, defaults to 1.0):
             The parameter for repetition penalty. 1.0 means no penalty. See [this
             paper](https://arxiv.org/pdf/1909.05858.pdf) for more details.
+        encoder_repetition_penalty (`float`, *optional*, defaults to 1.0):
+            The parameter for encoder_repetition_penalty. An exponential penalty on sequences that are not in the
+            original input. 1.0 means no penalty.
         length_penalty (`float`, *optional*, defaults to 1.0):
             Exponential penalty to the length that is used with beam-based generation. It is applied as an exponent to
             the sequence length, which in turn is used to divide the score of the sequence. Since the score is the log
@@ -239,6 +242,7 @@ class GenerationConfig(PushToHubMixin):
         self.eta_cutoff = kwargs.pop("eta_cutoff", 0.0)
         self.diversity_penalty = kwargs.pop("diversity_penalty", 0.0)
         self.repetition_penalty = kwargs.pop("repetition_penalty", 1.0)
+        self.encoder_repetition_penalty = kwargs.pop("encoder_repetition_penalty", 1.0)
         self.length_penalty = kwargs.pop("length_penalty", 1.0)
         self.no_repeat_ngram_size = kwargs.pop("no_repeat_ngram_size", 0)
         self.bad_words_ids = kwargs.pop("bad_words_ids", None)
......
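A minimal end-to-end sketch of the new knob. This is not part of the diff: the checkpoint and values are illustrative, and it assumes `generate` accepts a `generation_config` argument as on current `main`:

from transformers import AutoModelForSeq2SeqLM, AutoTokenizer, GenerationConfig

tokenizer = AutoTokenizer.from_pretrained("facebook/bart-large-cnn")
model = AutoModelForSeq2SeqLM.from_pretrained("facebook/bart-large-cnn")

inputs = tokenizer("The Eiffel Tower is 324 metres tall.", return_tensors="pt")

# values > 1.0 boost tokens that occur in the encoder input, steering the
# decoder away from "hallucinated" tokens; the default of 1.0 is a no-op
generation_config = GenerationConfig(encoder_repetition_penalty=1.5, max_new_tokens=32)
summary_ids = model.generate(**inputs, generation_config=generation_config)
print(tokenizer.decode(summary_ids[0], skip_special_tokens=True))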
@@ -204,6 +204,34 @@ class RepetitionPenaltyLogitsProcessor(LogitsProcessor):
         return scores
 
 
+class EncoderRepetitionPenaltyLogitsProcessor(LogitsProcessor):
+    r"""
+    [`LogitsProcessor`] enforcing an exponential penalty on tokens that are not in the original input.
+
+    Args:
+        penalty (`float`):
+            The parameter for the hallucination penalty. 1.0 means no penalty.
+        encoder_input_ids (`torch.LongTensor`):
+            The encoder_input_ids that should be repeated within the decoder ids.
+    """
+
+    def __init__(self, penalty: float, encoder_input_ids: torch.LongTensor):
+        if not isinstance(penalty, float) or not (penalty > 0):
+            raise ValueError(f"`penalty` has to be a strictly positive float, but is {penalty}")
+
+        # store the inverse, so that a penalty > 1.0 *boosts* tokens found in the encoder input
+        self.penalty = 1 / penalty
+        self.encoder_input_ids = encoder_input_ids
+
+    def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor) -> torch.FloatTensor:
+        score = torch.gather(scores, 1, self.encoder_input_ids)
+
+        # if score < 0, multiplying by the inverted penalty (< 1) raises it; otherwise dividing raises it
+        score = torch.where(score < 0, score * self.penalty, score / self.penalty)
+
+        scores.scatter_(1, self.encoder_input_ids, score)
+        return scores
+
+
 class TopPLogitsWarper(LogitsWarper):
     """
     [`LogitsWarper`] that performs top-p, i.e. restricting to top tokens summing to prob_cut_off <= prob_cut_off.
......
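To make the arithmetic in `__call__` concrete, here is a self-contained check with toy numbers (illustrative values only; it assumes nothing beyond the class as defined above):

import torch

from transformers.generation import EncoderRepetitionPenaltyLogitsProcessor

encoder_input_ids = torch.tensor([[3, 7]])  # tokens the encoder saw
scores = torch.full((1, 10), 0.1)  # logits over a toy vocabulary of 10
scores[0, 3] = -0.1  # one encoder token starts with a negative logit

proc = EncoderRepetitionPenaltyLogitsProcessor(penalty=2.0, encoder_input_ids=encoder_input_ids)
scores = proc(input_ids=None, scores=scores)  # input_ids is unused by this processor

# penalty=2.0 stores a factor of 0.5, so encoder tokens move up:
# negative logits are halved toward zero, positive ones are doubled
print(scores[0, 3].item())  # -0.05
print(scores[0, 7].item())  # 0.2
print(scores[0, 2].item())  # 0.1, untouched: token 2 is not in the encoder input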
@@ -39,6 +39,7 @@ from .beam_search import BeamScorer, BeamSearchScorer, ConstrainedBeamSearchScorer
 from .configuration_utils import GenerationConfig
 from .logits_process import (
     EncoderNoRepeatNGramLogitsProcessor,
+    EncoderRepetitionPenaltyLogitsProcessor,
     EpsilonLogitsWarper,
     EtaLogitsWarper,
     ExponentialDecayLengthPenalty,
@@ -799,6 +800,15 @@ class GenerationMixin:
                     num_beam_groups=generation_config.num_beam_groups,
                 )
             )
+        if (
+            generation_config.encoder_repetition_penalty is not None
+            and generation_config.encoder_repetition_penalty != 1.0
+        ):
+            processors.append(
+                EncoderRepetitionPenaltyLogitsProcessor(
+                    penalty=generation_config.encoder_repetition_penalty, encoder_input_ids=encoder_input_ids
+                )
+            )
         if generation_config.repetition_penalty is not None and generation_config.repetition_penalty != 1.0:
             processors.append(RepetitionPenaltyLogitsProcessor(penalty=generation_config.repetition_penalty))
         if generation_config.no_repeat_ngram_size is not None and generation_config.no_repeat_ngram_size > 0:
......
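The guard above only registers the processor when the penalty differs from the 1.0 no-op. The helper that builds this list is internal to `GenerationMixin`, so this sketch mimics the wiring by hand with the public classes (values illustrative):

import torch

from transformers.generation import (
    EncoderRepetitionPenaltyLogitsProcessor,
    LogitsProcessorList,
)

encoder_input_ids = torch.tensor([[4, 8, 15]])
encoder_repetition_penalty = 1.2  # anything other than None or 1.0 activates the branch

processors = LogitsProcessorList()
if encoder_repetition_penalty is not None and encoder_repetition_penalty != 1.0:
    processors.append(
        EncoderRepetitionPenaltyLogitsProcessor(
            penalty=encoder_repetition_penalty, encoder_input_ids=encoder_input_ids
        )
    )

# the list is applied like a single processor at each decoding step
scores = processors(encoder_input_ids, torch.ones(1, 20))
print(scores[0, 4].item())  # ~1.2, boosted: token 4 appears in the encoder input
print(scores[0, 0].item())  # 1.0, untouched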
@@ -28,6 +28,7 @@ if is_torch_available():
     from transformers.generation import (
         EncoderNoRepeatNGramLogitsProcessor,
+        EncoderRepetitionPenaltyLogitsProcessor,
         EpsilonLogitsWarper,
         EtaLogitsWarper,
         ExponentialDecayLengthPenalty,
@@ -175,6 +176,31 @@ class LogitsProcessorTest(unittest.TestCase):
         self.assertAlmostEqual(scores[1, 0].item(), (1 / vocab_size) / 2)
         self.assertAlmostEqual(scores[1, 5].item(), (4 / vocab_size) / 2)
 
+    def test_encoder_repetition_penalty_dist_process(self):
+        input_ids = torch.tensor([[0, 1], [5, 0]], device=torch_device, dtype=torch.long)
+        vocab_size = 10
+
+        scores = self._get_uniform_logits(batch_size=2, length=vocab_size)
+
+        # give some tokens non-uniform values so the boost is visible
+        scores[0, 0] = -(1 / vocab_size)
+        scores[1, 5] = 4 / vocab_size
+
+        rep_penalty_proc = EncoderRepetitionPenaltyLogitsProcessor(penalty=2.0, encoder_input_ids=input_ids)
+
+        scores = rep_penalty_proc(input_ids, scores.clone())
+
+        # check that values were correctly changed
+        self.assertAlmostEqual(scores[0, 0].item(), -(1 / vocab_size) / 2)
+        self.assertAlmostEqual(scores[0, 1].item(), (1 / vocab_size) * 2)
+
+        self.assertAlmostEqual(scores[1, 0].item(), (1 / vocab_size) * 2)
+        self.assertAlmostEqual(scores[1, 5].item(), (4 / vocab_size) * 2)
+
+        # check that values not in the encoder ids were NOT changed
+        self.assertAlmostEqual(scores[0, 2].item(), (1 / vocab_size))
+        self.assertAlmostEqual(scores[1, 2].item(), (1 / vocab_size))
+
     def test_top_k_dist_warper(self):
         input_ids = None
         vocab_size = 10
......