"git@developer.sourcefind.cn:chenpangpang/transformers.git" did not exist on "d85bf954361f39f2ea38386940f40d29ed201910"
Unverified commit 0b8604d0, authored by larekrow and committed by GitHub

Update logits_process.py docstrings to clarify penalty and reward cases (attempt #2) (#26784)

* Update logits_process.py docstrings + match arg fields to __init__'s

* Ran `make style`
parent 85e9d644
@@ -276,9 +276,14 @@ class RepetitionPenaltyLogitsProcessor(LogitsProcessor):
     selected. The formula can be seen in the original [paper](https://arxiv.org/pdf/1909.05858.pdf). According to the
     paper a penalty of around 1.2 yields a good balance between truthful generation and lack of repetition.

+    This technique can also be used to reward and thus encourage repetition in a similar manner. To penalize and reduce
+    repetition, use `penalty` values above 1.0, where a higher value penalizes more strongly. To reward and encourage
+    repetition, use `penalty` values between 0.0 and 1.0, where a lower value rewards more strongly.
+
     Args:
-        repetition_penalty (`float`):
-            The parameter for repetition penalty. 1.0 means no penalty. See [this
+        penalty (`float`):
+            The parameter for repetition penalty. 1.0 means no penalty. Above 1.0 penalizes previously generated
+            tokens. Between 0.0 and 1.0 rewards previously generated tokens. See [this
             paper](https://arxiv.org/pdf/1909.05858.pdf) for more details.

     Examples:
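To see the penalize/reward split described above in action, here is a minimal, model-free sketch; the vocabulary size, token IDs, and logit values are made up for illustration:

```python
import torch
from transformers import RepetitionPenaltyLogitsProcessor

input_ids = torch.tensor([[0, 1]])         # tokens that were already generated
scores = torch.tensor([[2.0, -2.0, 2.0]])  # toy logits over a 3-token vocabulary

penalize = RepetitionPenaltyLogitsProcessor(penalty=1.2)  # > 1.0 discourages repeats
reward = RepetitionPenaltyLogitsProcessor(penalty=0.8)    # 0.0-1.0 encourages repeats

# the processor updates scores in place via scatter_, so clone before each call
print(penalize(input_ids, scores.clone()))  # tokens 0 and 1 become less likely
print(reward(input_ids, scores.clone()))    # tokens 0 and 1 become more likely
```

With `penalty=1.2`, token 0's logit drops from 2.0 to roughly 1.67 and token 1's from -2.0 to -2.4; with `penalty=0.8` both move the opposite way, exactly as the docstring now states.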
@@ -313,7 +318,7 @@ class RepetitionPenaltyLogitsProcessor(LogitsProcessor):
     def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor) -> torch.FloatTensor:
         score = torch.gather(scores, 1, input_ids)

-        # if score < 0 then repetition penalty has to be multiplied to reduce the previous token probability
+        # if score < 0 then repetition penalty has to be multiplied to reduce the token probabilities
         score = torch.where(score < 0, score * self.penalty, score / self.penalty)

         scores.scatter_(1, input_ids, score)
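The reworded comment concerns the sign handling on the line below it: to make a token less likely, a positive logit must be divided by the penalty while a negative logit must be multiplied, otherwise the adjustment would push the score the wrong way. A standalone check of the same expression:

```python
import torch

penalty = 1.2
score = torch.tensor([2.0, -2.0])
adjusted = torch.where(score < 0, score * penalty, score / penalty)
print(adjusted)  # tensor([ 1.6667, -2.4000]): both logits moved away from selection
```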
@@ -322,11 +327,18 @@ class RepetitionPenaltyLogitsProcessor(LogitsProcessor):

 class EncoderRepetitionPenaltyLogitsProcessor(LogitsProcessor):
     r"""
-    [`LogitsProcessor`] enforcing an exponential penalty on tokens that are not in the original input.
+    [`LogitsProcessor`] that avoids hallucination by boosting the probabilities of tokens found within the original
+    input.
+
+    This technique can also be used to reward and thus encourage hallucination (or creativity) in a similar manner. To
+    penalize and reduce hallucination, use `penalty` values above 1.0, where a higher value penalizes more strongly. To
+    reward and encourage hallucination, use `penalty` values between 0.0 and 1.0, where a lower value rewards more
+    strongly.

     Args:
-        hallucination_penalty (`float`):
-            The parameter for hallucination penalty. 1.0 means no penalty.
+        penalty (`float`):
+            The parameter for hallucination penalty. 1.0 means no penalty. Above 1.0 penalizes hallucination. Between
+            0.0 and 1.0 rewards hallucination.
         encoder_input_ids (`torch.LongTensor`):
             The encoder_input_ids that should be repeated within the decoder ids.
     """
@@ -342,7 +354,7 @@ class EncoderRepetitionPenaltyLogitsProcessor(LogitsProcessor):
     def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor) -> torch.FloatTensor:
         score = torch.gather(scores, 1, self.encoder_input_ids)

-        # if score < 0 then repetition penalty has to be multiplied to reduce the previous token probability
+        # if score < 0 then hallucination penalty has to be multiplied to increase the token probabilities
        score = torch.where(score < 0, score * self.penalty, score / self.penalty)

         scores.scatter_(1, self.encoder_input_ids, score)
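The encoder variant evaluates the same `torch.where` expression, but its `__init__` (not shown in this diff) stores the reciprocal `1 / penalty`, so a `penalty` above 1.0 ends up boosting tokens from the original input rather than suppressing them, as the updated comment says. A minimal sketch with toy values:

```python
import torch
from transformers import EncoderRepetitionPenaltyLogitsProcessor

encoder_input_ids = torch.tensor([[0, 1]])  # tokens from the original (encoder) input
scores = torch.tensor([[2.0, -2.0, 2.0]])   # toy logits over a 3-token vocabulary

proc = EncoderRepetitionPenaltyLogitsProcessor(penalty=2.0, encoder_input_ids=encoder_input_ids)
decoder_ids = torch.tensor([[2]])  # current decoder ids; not gathered by this processor

print(proc(decoder_ids, scores.clone()))
# tokens 0 and 1 are boosted (2.0 -> 4.0 and -2.0 -> -1.0), steering generation
# back toward the original input
```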