Unverified Commit 0b8604d0 authored by larekrow, committed by GitHub

Update logits_process.py docstrings to clarify penalty and reward cases (attempt #2) (#26784)

* Update logits_process.py docstrings + match arg fields to __init__'s

* Ran `make style`
parent 85e9d644
@@ -276,9 +276,14 @@ class RepetitionPenaltyLogitsProcessor(LogitsProcessor):
     selected. The formula can be seen in the original [paper](https://arxiv.org/pdf/1909.05858.pdf). According to the
     paper a penalty of around 1.2 yields a good balance between truthful generation and lack of repetition.
 
+    This technique can also be used to reward and thus encourage repetition in a similar manner. To penalize and reduce
+    repetition, use `penalty` values above 1.0, where a higher value penalizes more strongly. To reward and encourage
+    repetition, use `penalty` values between 0.0 and 1.0, where a lower value rewards more strongly.
+
     Args:
-        repetition_penalty (`float`):
-            The parameter for repetition penalty. 1.0 means no penalty. See [this
+        penalty (`float`):
+            The parameter for repetition penalty. 1.0 means no penalty. Above 1.0 penalizes previously generated
+            tokens. Between 0.0 and 1.0 rewards previously generated tokens. See [this
             paper](https://arxiv.org/pdf/1909.05858.pdf) for more details.
 
     Examples:
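
For intuition, a minimal usage sketch of the penalty ranges the new docstring describes (the "gpt2" checkpoint, prompt, and penalty value are arbitrary illustrative choices, not part of this commit):

from transformers import AutoModelForCausalLM, AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("gpt2")
model = AutoModelForCausalLM.from_pretrained("gpt2")

inputs = tokenizer("I like to", return_tensors="pt")

# repetition_penalty > 1.0 discourages repeating previously generated tokens;
# a value between 0.0 and 1.0 would instead reward repetition.
out = model.generate(**inputs, repetition_penalty=1.2, max_new_tokens=20, do_sample=False)
print(tokenizer.decode(out[0], skip_special_tokens=True))
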
@@ -313,7 +318,7 @@ class RepetitionPenaltyLogitsProcessor(LogitsProcessor):
     def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor) -> torch.FloatTensor:
         score = torch.gather(scores, 1, input_ids)
 
-        # if score < 0 then repetition penalty has to be multiplied to reduce the previous token probability
+        # if score < 0 then repetition penalty has to be multiplied to reduce the token probabilities
         score = torch.where(score < 0, score * self.penalty, score / self.penalty)
 
         scores.scatter_(1, input_ids, score)
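
The updated comment is easiest to check with concrete numbers. A self-contained sketch of the same torch.where update, using toy logits and an assumed penalty of 1.2:

import torch

penalty = 1.2  # > 1.0, so previously generated tokens are penalized
scores = torch.tensor([[2.0, -1.0, 0.5, 3.0]])  # toy logits over a 4-token vocab
input_ids = torch.tensor([[0, 1]])              # tokens 0 and 1 were already generated

score = torch.gather(scores, 1, input_ids)
# Negative logits are multiplied (pushed further below zero), positive logits
# are divided (pulled toward zero) -- both lower those tokens' probabilities.
score = torch.where(score < 0, score * penalty, score / penalty)
scores.scatter_(1, input_ids, score)
print(scores)  # tensor([[ 1.6667, -1.2000,  0.5000,  3.0000]])
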
@@ -322,11 +327,18 @@ class EncoderRepetitionPenaltyLogitsProcessor(LogitsProcessor):
 
 class EncoderRepetitionPenaltyLogitsProcessor(LogitsProcessor):
     r"""
-    [`LogitsProcessor`] enforcing an exponential penalty on tokens that are not in the original input.
+    [`LogitsProcessor`] that avoids hallucination by boosting the probabilities of tokens found within the original
+    input.
+
+    This technique can also be used to reward and thus encourage hallucination (or creativity) in a similar manner. To
+    penalize and reduce hallucination, use `penalty` values above 1.0, where a higher value penalizes more strongly. To
+    reward and encourage hallucination, use `penalty` values between 0.0 and 1.0, where a lower value rewards more
+    strongly.
 
     Args:
-        hallucination_penalty (`float`):
-            The parameter for hallucination penalty. 1.0 means no penalty.
+        penalty (`float`):
+            The parameter for hallucination penalty. 1.0 means no penalty. Above 1.0 penalizes hallucination. Between
+            0.0 and 1.0 rewards hallucination.
         encoder_input_ids (`torch.LongTensor`):
             The encoder_input_ids that should be repeated within the decoder ids.
     """
@@ -342,7 +354,7 @@ class EncoderRepetitionPenaltyLogitsProcessor(LogitsProcessor):
     def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor) -> torch.FloatTensor:
         score = torch.gather(scores, 1, self.encoder_input_ids)
 
-        # if score < 0 then repetition penalty has to be multiplied to reduce the previous token probability
+        # if score < 0 then hallucination penalty has to be multiplied to increase the token probabilities
         score = torch.where(score < 0, score * self.penalty, score / self.penalty)
 
         scores.scatter_(1, self.encoder_input_ids, score)
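
The encoder variant runs the same formula in the boosting direction. A toy sketch; it assumes the class stores the inverse of the user-supplied penalty (the __init__ is not shown in this hunk, but `score / self.penalty` only raises positive scores when the stored value is below 1.0, which is what "above 1.0 penalizes hallucination" requires):

import torch

user_penalty = 1.2          # user-facing value > 1.0: boost tokens from the input
penalty = 1 / user_penalty  # assumed inversion, so the formula below boosts

scores = torch.tensor([[2.0, -1.0, 0.5, 3.0]])  # toy decoder logits
encoder_input_ids = torch.tensor([[0, 1]])      # tokens present in the encoder input

score = torch.gather(scores, 1, encoder_input_ids)
# With a stored penalty < 1.0: negatives shrink toward zero and positives grow,
# so tokens from the original input become more likely (less hallucination).
score = torch.where(score < 0, score * penalty, score / penalty)
scores.scatter_(1, encoder_input_ids, score)
print(scores)  # tensor([[ 2.4000, -0.8333,  0.5000,  3.0000]])
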