Unverified Commit 0c41765d authored by Shauray Singh's avatar Shauray Singh Committed by GitHub
Browse files

[DOCS] Example for `LogitsProcessor` class (#24848)

* make docs

* fixup

* resolved

* remove debugs

* Revert "fixup"

This reverts commit 5e0f636aae0bf8707bc8bdaa6a9427fbf66834ed.

* prev (ignore)

* fixup broke some files

* remove files

* reverting modeling_reformer

* lang fix
parent 35c04596
......@@ -193,12 +193,38 @@ class TemperatureLogitsWarper(LogitsWarper):
class RepetitionPenaltyLogitsProcessor(LogitsProcessor):
r"""
[`LogitsProcessor`] enforcing an exponential penalty on repeated sequences.
[`LogitsProcessor`] that prevents the repetition of previous tokens through an exponential penalty. This technique
shares some similarities with coverage mechanisms and other aimed at reducing repetition. During the text
generation process, the probability distribution for the next token is determined using a formula that incorporates
token scores based on their occurrence in the generated sequence. Tokens with higher scores are less likely to be
selected. The formula can be seen in the original [paper](https://arxiv.org/pdf/1909.05858.pdf). According to the
paper a penalty of around 1.2 yields a good balance between truthful generation and lack of repetition.
Args:
repetition_penalty (`float`):
The parameter for repetition penalty. 1.0 means no penalty. See [this
paper](https://arxiv.org/pdf/1909.05858.pdf) for more details.
Examples:
```py
>>> from transformers import AutoTokenizer, AutoModelForCausalLM
>>> # Initializing the model and tokenizer for it
>>> model = AutoModelForCausalLM.from_pretrained("gpt2")
>>> tokenizer = AutoTokenizer.from_pretrained("gpt2")
>>> inputs = tokenizer(["I'm not going to"], return_tensors="pt")
>>> # This shows a normal generate without any specific parameters
>>> summary_ids = model.generate(inputs["input_ids"], max_length=20)
>>> print(tokenizer.batch_decode(summary_ids, skip_special_tokens=True)[0])
I'm not going to lie, I'm not going to lie. I'm not going to lie
>>> # This generates a penalty for repeated tokens
>>> penalized_ids = model.generate(inputs["input_ids"], max_length=20, repetition_penalty=1.2)
>>> print(tokenizer.batch_decode(biased_ids, skip_special_tokens=True)[0])
I'm not going to lie, I was really excited about this. It's a great game
```
"""
def __init__(self, penalty: float):
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment