Unverified Commit eec0d84e authored by Ashish Thomas Chempolil, committed by GitHub

[DOCS] Add example and modified docs of EtaLogitsWarper (#25125)



* added example and modified docs for EtaLogitsWarper

* make style

* fixed styling issue on 544

* removed error info and added set_seed

* Update src/transformers/generation/logits_process.py
Co-authored-by: amyeroberts <22614925+amyeroberts@users.noreply.github.com>

* Update src/transformers/generation/logits_process.py
Co-authored-by: amyeroberts <22614925+amyeroberts@users.noreply.github.com>

* updated the results

---------
Co-authored-by: amyeroberts <22614925+amyeroberts@users.noreply.github.com>
parent 8021c684
@@ -485,14 +485,65 @@ class EpsilonLogitsWarper(LogitsWarper):
class EtaLogitsWarper(LogitsWarper):
r"""
[`LogitsWarper`] that performs eta-sampling, a technique to filter out tokens with probabilities below a dynamic
cutoff value, `eta`, which is calculated based on a combination of the hyperparameter `epsilon` and the entropy of
the token probabilities, i.e. `eta := min(epsilon, sqrt(epsilon) * e^-entropy(probabilities))`. Takes the largest
`min_tokens_to_keep` tokens if no tokens satisfy this constraint. It addresses the issue of poor quality in long
samples of text generated by neural language models, leading to more coherent and fluent text. See [Truncation
Sampling as Language Model Desmoothing](https://arxiv.org/abs/2210.15191) for more information. Note: `do_sample`
must be set to `True` for this `LogitsWarper` to work.

Args:
epsilon (`float`):
A float value in the range (0, 1). Hyperparameter used to calculate the dynamic cutoff value, `eta`. The
suggested values from the paper range from 3e-4 to 4e-3, depending on the size of the model.
filter_value (`float`, *optional*, defaults to `-float("Inf")`):
All values that are found to be below the dynamic cutoff value, `eta`, are set to this float value. This
parameter is useful when logits need to be modified for very low probability tokens that should be excluded
from generation entirely.
min_tokens_to_keep (`int`, *optional*, defaults to 1):
Specifies the minimum number of tokens that must be kept for generation, regardless of their probabilities.
For example, if `min_tokens_to_keep` is set to 1, at least one token will always be kept for generation,
even if all tokens have probabilities below the cutoff `eta`.

Examples:
```python
>>> # Import required libraries
>>> from transformers import AutoModelForCausalLM, AutoTokenizer, set_seed
>>> # Set the model name
>>> model_name = "gpt2"
>>> # Initialize the model and tokenizer
>>> model = AutoModelForCausalLM.from_pretrained(model_name)
>>> tokenizer = AutoTokenizer.from_pretrained(model_name)
>>> # Set the pad token to eos token
>>> model.config.pad_token_id = model.config.eos_token_id
>>> model.generation_config.pad_token_id = model.config.eos_token_id
>>> # The below sequence intentionally contains two subjects to show the difference between the two approaches
>>> sequence = "a quadcopter flight controller (RTFQ Flip MWC) that supports I2C sensors for adding things like a barometer, magnetometer, and GPS system. The officially supported sensor block (BMP180, HMC5883L on one board) is discontinued, as far as I know, everyone involved lived to sing another day. . . disorder and an extreme state of dysmetabolism characterized by extensive erythema and a significant reduction in uncovered"
>>> # Tokenize the sequence
>>> inputs = tokenizer(sequence, return_tensors="pt")
>>> set_seed(0)
>>> # We can see that the model is generating repeating text and also is not able to continue the sequence properly
>>> outputs = model.generate(inputs["input_ids"], max_length=128)
>>> print(tokenizer.decode(outputs[0], skip_special_tokens=True))
a quadcopter flight controller (RTFQ Flip MWC) that supports I2C sensors for adding things like a barometer, magnetometer, and GPS system. The officially supported sensor block (BMP180, HMC5883L on one board) is discontinued, as far as I know, everyone involved lived to sing another day... disorder and an extreme state of dysmetabolism characterized by extensive erythema and a significant reduction in uncovered muscle mass. The patient was diagnosed with a severe erythema and a severe erythema-like condition. The patient was treated with a combination
>>> # The result is much better and coherent when we use the `eta_cutoff` parameter
>>> outputs = model.generate(
... inputs["input_ids"], max_length=128, do_sample=True, eta_cutoff=2e-2
... ) # need to set do_sample=True for eta_cutoff to work
>>> print(tokenizer.decode(outputs[0], skip_special_tokens=True))
a quadcopter flight controller (RTFQ Flip MWC) that supports I2C sensors for adding things like a barometer, magnetometer, and GPS system. The officially supported sensor block (BMP180, HMC5883L on one board) is discontinued, as far as I know, everyone involved lived to sing another day... disorder and an extreme state of dysmetabolism characterized by extensive erythema and a significant reduction in uncovered fatty acids. A significant loss of brain development. The individual also experienced high levels of a common psychiatric condition called schizophrenia, with an important and life threatening consequence.
```
"""
def __init__(self, epsilon: float, filter_value: float = -float("Inf"), min_tokens_to_keep: int = 1):
epsilon = float(epsilon)
......
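To make the dynamic cutoff in the new docstring concrete, here is a minimal sketch of the eta computation, assuming PyTorch. The helper name `eta_warp` is hypothetical; this illustrates the technique from the paper, not the library's exact implementation.

```python
import torch


def eta_warp(scores, epsilon=2e-2, filter_value=-float("inf"), min_tokens_to_keep=1):
    """Sketch of eta-sampling: mask tokens whose probability falls below eta."""
    probs = scores.softmax(dim=-1)
    # Shannon entropy of the next-token distribution, one value per batch row.
    entropy = torch.distributions.Categorical(logits=scores).entropy()
    # eta := min(epsilon, sqrt(epsilon) * e^-entropy)
    eta = torch.min(
        torch.tensor(epsilon), torch.sqrt(torch.tensor(epsilon)) * torch.exp(-entropy)
    ).unsqueeze(-1)
    indices_to_remove = probs < eta
    # Never filter out the `min_tokens_to_keep` most likely tokens.
    top_k = scores.topk(min_tokens_to_keep, dim=-1).indices
    indices_to_remove.scatter_(-1, top_k, False)
    return scores.masked_fill(indices_to_remove, filter_value)
```

Because the cutoff shrinks as entropy grows, flat (high-entropy) distributions keep more candidates than peaked ones, which is what lets eta-sampling truncate aggressively only when the model is confident.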
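The docstring's example drives the warper through `generate(..., eta_cutoff=...)`; the class can also be called directly. Below is a hedged usage sketch, assuming `EtaLogitsWarper` is importable from the top-level `transformers` namespace like the other logits warpers, and that `input_ids` (part of the common `LogitsWarper` call signature) is not consulted by this particular warper, so a dummy tensor suffices.

```python
import torch

from transformers import EtaLogitsWarper

warper = EtaLogitsWarper(epsilon=2e-2)
# Toy 1 x 4 "vocabulary" of logits; input_ids is a dummy placeholder here.
input_ids = torch.tensor([[0]])
scores = torch.log(torch.tensor([[0.70, 0.20, 0.09, 0.01]]))
filtered = warper(input_ids, scores)
print(filtered)  # logits of tokens with prob < eta are set to -inf
```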