Unverified Commit 58e7f9bb authored by Joao Gante, committed by GitHub

Generate: All logits processors are documented and have examples (#27796)


Co-authored-by: Arthur <48595927+ArthurZucker@users.noreply.github.com>
parent 47500b1d
@@ -250,7 +250,7 @@ While the autoregressive generation process is relatively straightforward, makin
1. [Guide](generation_strategies) on how to control different generation methods, how to set up the generation configuration file, and how to stream the output;
2. [Guide](chat_templating) on the prompt template for chat LLMs;
3. [Guide](tasks/prompting) on how to get the most out of prompt design;
4. API reference on [`~generation.GenerationConfig`], [`~generation.GenerationMixin.generate`], and [generate-related classes](internal/generation_utils). Most of the classes, including the logits processors, have usage examples!

### LLM leaderboards
@@ -63,6 +63,14 @@ class GenerationConfig(PushToHubMixin):
You do not need to call any of the above methods directly. Pass custom parameter values to `.generate()`. To learn
more about decoding strategies refer to the [text generation strategies guide](../generation_strategies).
<Tip>
A large number of these flags control the logits or the stopping criteria of the generation. Make sure you check
the [generate-related classes](https://huggingface.co/docs/transformers/internal/generation_utils) for a full
description of the possible manipulations, as well as examples of their usage.
</Tip>
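As a quick, hedged illustration of the tip above (an editorial addition, not part of the library docstring; the checkpoint and flag values are only examples), generation flags can be passed directly to `generate` or bundled into a `GenerationConfig`:

```python
from transformers import AutoModelForCausalLM, AutoTokenizer, GenerationConfig

tokenizer = AutoTokenizer.from_pretrained("bigscience/bloomz-560m")
model = AutoModelForCausalLM.from_pretrained("bigscience/bloomz-560m")
inputs = tokenizer("A number:", return_tensors="pt")

# Flags passed to `generate` are forwarded to the underlying GenerationConfig; here they control
# a logits processor (`repetition_penalty`) and a stopping condition (`max_new_tokens`).
outputs = model.generate(**inputs, max_new_tokens=10, repetition_penalty=1.2)

# Equivalently, build a reusable GenerationConfig and pass it explicitly.
generation_config = GenerationConfig(max_new_tokens=10, repetition_penalty=1.2)
outputs = model.generate(**inputs, generation_config=generation_config)
print(tokenizer.batch_decode(outputs, skip_special_tokens=True)[0])
```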
Args:
> Parameters that control the length of the output
@@ -100,13 +100,39 @@ class LogitsProcessorList(list):
class MinLengthLogitsProcessor(LogitsProcessor):
r"""
[`LogitsProcessor`] enforcing a min-length by setting EOS probability to 0. Note that, for decoder-only models
like most LLMs, the length includes the prompt.

Args:
min_length (`int`):
The minimum length below which the score of `eos_token_id` is set to `-float("Inf")`.
eos_token_id (`Union[int, List[int]]`):
The id of the *end-of-sequence* token. Optionally, use a list to set multiple *end-of-sequence* tokens.
Examples:
```python
>>> from transformers import AutoModelForCausalLM, AutoTokenizer
>>> tokenizer = AutoTokenizer.from_pretrained("bigscience/bloomz-560m")
>>> model = AutoModelForCausalLM.from_pretrained("bigscience/bloomz-560m")
>>> inputs = tokenizer("A number:", return_tensors="pt")
>>> gen_out = model.generate(**inputs)
>>> print(tokenizer.batch_decode(gen_out, skip_special_tokens=True)[0])
A number: one
>>> # setting `min_length` to a value smaller than the uncontrolled output length has no impact
>>> gen_out = model.generate(**inputs, min_length=3)
>>> print(tokenizer.batch_decode(gen_out, skip_special_tokens=True)[0])
A number: one
>>> # setting a larger `min_length` will force the model to generate beyond its natural ending point, which is not
>>> # necessarily incorrect
>>> gen_out = model.generate(**inputs, min_length=10)
>>> print(tokenizer.batch_decode(gen_out, skip_special_tokens=True)[0])
A number: one thousand, nine hundred and ninety-four
```
""" """
def __init__(self, min_length: int, eos_token_id: Union[int, List[int]]): def __init__(self, min_length: int, eos_token_id: Union[int, List[int]]):
@@ -133,9 +159,7 @@ class MinLengthLogitsProcessor(LogitsProcessor):
class MinNewTokensLengthLogitsProcessor(LogitsProcessor):
r"""
[`LogitsProcessor`] enforcing a min-length of new tokens by setting EOS (End-Of-Sequence) token probability to 0.
Unlike [`MinLengthLogitsProcessor`], this processor ignores the prompt.
Args:
prompt_length_to_skip (`int`):
@@ -149,29 +173,21 @@ class MinNewTokensLengthLogitsProcessor(LogitsProcessor):
Examples:

```python
>>> from transformers import AutoModelForCausalLM, AutoTokenizer

>>> tokenizer = AutoTokenizer.from_pretrained("bigscience/bloomz-560m")
>>> model = AutoModelForCausalLM.from_pretrained("bigscience/bloomz-560m")

>>> inputs = tokenizer(["A number:"], return_tensors="pt")
>>> gen_out = model.generate(**inputs)
>>> print(tokenizer.batch_decode(gen_out, skip_special_tokens=True)[0])
A number: one

>>> # setting `min_new_tokens` will force the model to generate beyond its natural ending point, which is not
>>> # necessarily incorrect
>>> gen_out = model.generate(**inputs, min_new_tokens=2)
>>> print(tokenizer.batch_decode(gen_out, skip_special_tokens=True)[0])
A number: one thousand
```
"""
@@ -205,7 +221,8 @@ class MinNewTokensLengthLogitsProcessor(LogitsProcessor):
class TemperatureLogitsWarper(LogitsWarper):
r"""
[`LogitsWarper`] for temperature (exponential scaling output probability distribution), which effectively means
that it can control the randomness of the predicted tokens. Often used together with [`TopPLogitsWarper`] and
[`TopKLogitsWarper`].
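A minimal, hedged sketch of how temperature is typically applied through `generate` (an editorial addition, not the class's own example; the checkpoint and values are illustrative and the sampled text will vary):

```python
from transformers import AutoModelForCausalLM, AutoTokenizer, set_seed

set_seed(0)
tokenizer = AutoTokenizer.from_pretrained("bigscience/bloomz-560m")
model = AutoModelForCausalLM.from_pretrained("bigscience/bloomz-560m")
inputs = tokenizer("A number:", return_tensors="pt")

# Temperature only matters when sampling: low values sharpen the distribution
# (output close to greedy decoding), high values flatten it (more random output).
cold = model.generate(**inputs, do_sample=True, temperature=0.3, max_new_tokens=10)
hot = model.generate(**inputs, do_sample=True, temperature=1.5, max_new_tokens=10)
print(tokenizer.batch_decode(cold, skip_special_tokens=True)[0])
print(tokenizer.batch_decode(hot, skip_special_tokens=True)[0])
```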
<Tip>
@@ -269,22 +286,18 @@ class TemperatureLogitsWarper(LogitsWarper):
class RepetitionPenaltyLogitsProcessor(LogitsProcessor):
r"""
[`LogitsProcessor`] that prevents the repetition of previous tokens through a penalty. This penalty is applied at
most once per token. Note that, for decoder-only models like most LLMs, the considered tokens include the prompt.

In the original [paper](https://arxiv.org/pdf/1909.05858.pdf), the authors suggest the use of a penalty of around
1.2 to achieve a good balance between truthful generation and lack of repetition. To penalize and reduce
repetition, use `penalty` values above 1.0, where a higher value penalizes more strongly. To reward and encourage
repetition, use `penalty` values between 0.0 and 1.0, where a lower value rewards more strongly.
Args:
penalty (`float`):
The parameter for repetition penalty. 1.0 means no penalty. Above 1.0 penalizes previously generated
tokens. Between 0.0 and 1.0 rewards previously generated tokens.

Examples:
@@ -327,20 +340,39 @@ class RepetitionPenaltyLogitsProcessor(LogitsProcessor):
class EncoderRepetitionPenaltyLogitsProcessor(LogitsProcessor):
r"""
[`LogitsProcessor`] that works similarly to [`RepetitionPenaltyLogitsProcessor`], but with an *inverse* penalty
that is applied to the tokens present in the prompt. In other words, a penalty above 1.0 increases the odds of
selecting tokens that were present in the prompt.

It was designed to avoid hallucination in input-grounded tasks, like summarization. Although originally intended
for encoder-decoder models, it can also be used with decoder-only models like LLMs.
Args:
penalty (`float`):
The parameter for repetition penalty. 1.0 means no penalty. Above 1.0 rewards prompt tokens. Between 0.0
and 1.0 penalizes prompt tokens.
encoder_input_ids (`torch.LongTensor`):
The encoder_input_ids that should be repeated within the decoder ids.
Examples:
```python
>>> from transformers import AutoModelForCausalLM, AutoTokenizer
>>> tokenizer = AutoTokenizer.from_pretrained("bigscience/bloomz-560m")
>>> model = AutoModelForCausalLM.from_pretrained("bigscience/bloomz-560m")
>>> inputs = tokenizer(["Alice and Bob. The third member's name was"], return_tensors="pt")
>>> gen_out = model.generate(**inputs)
>>> print(tokenizer.batch_decode(gen_out, skip_special_tokens=True)[0])
Alice and Bob. The third member's name was not mentioned.
>>> # With the `encoder_repetition_penalty` argument we can trigger this logits processor in `generate`, which can
>>> # promote the use of prompt tokens ("Bob" in this example)
>>> gen_out = model.generate(**inputs, encoder_repetition_penalty=1.2)
>>> print(tokenizer.batch_decode(gen_out, skip_special_tokens=True)[0])
Alice and Bob. The third member's name was Bob. The third member's name was Bob.
```
""" """
def __init__(self, penalty: float, encoder_input_ids: torch.LongTensor): def __init__(self, penalty: float, encoder_input_ids: torch.LongTensor):
@@ -363,7 +395,8 @@ class EncoderRepetitionPenaltyLogitsProcessor(LogitsProcessor):
class TopPLogitsWarper(LogitsWarper):
"""
[`LogitsWarper`] that performs top-p, i.e. restricting to the smallest set of most probable tokens whose cumulative
probability is at least `top_p`. Often used together with [`TemperatureLogitsWarper`] and [`TopKLogitsWarper`].
Args:
top_p (`float`):
@@ -375,6 +408,7 @@ class TopPLogitsWarper(LogitsWarper):
Minimum number of tokens that cannot be filtered.

Examples:

```python
>>> from transformers import AutoTokenizer, AutoModelForCausalLM, set_seed
@@ -426,7 +460,8 @@ class TopPLogitsWarper(LogitsWarper):
class TopKLogitsWarper(LogitsWarper):
r"""
[`LogitsWarper`] that performs top-k, i.e. restricting to the k highest probability elements. Often used together
with [`TemperatureLogitsWarper`] and [`TopPLogitsWarper`].
Args:
top_k (`int`):
@@ -435,6 +470,29 @@ class TopKLogitsWarper(LogitsWarper):
All filtered values will be set to this float value.
min_tokens_to_keep (`int`, *optional*, defaults to 1):
Minimum number of tokens that cannot be filtered.
Examples:
```python
>>> from transformers import AutoTokenizer, AutoModelForCausalLM, set_seed
>>> set_seed(0)
>>> model = AutoModelForCausalLM.from_pretrained("distilgpt2")
>>> tokenizer = AutoTokenizer.from_pretrained("distilgpt2")
>>> inputs = tokenizer("A sequence: A, B, C, D", return_tensors="pt")
>>> # With sampling, the output is unexpected -- sometimes too unexpected.
>>> outputs = model.generate(**inputs, do_sample=True)
>>> print(tokenizer.batch_decode(outputs, skip_special_tokens=True)[0])
A sequence: A, B, C, D, G, H, I. A, M
>>> # With `top_k` sampling, the output gets restricted to the k most likely tokens.
>>> # Pro tip: In practice, LLMs use `top_k` in the 5-50 range.
>>> outputs = model.generate(**inputs, do_sample=True, top_k=2)
>>> print(tokenizer.batch_decode(outputs, skip_special_tokens=True)[0])
A sequence: A, B, C, D, E, F, G, H, I
```
""" """
def __init__(self, top_k: int, filter_value: float = -float("Inf"), min_tokens_to_keep: int = 1): def __init__(self, top_k: int, filter_value: float = -float("Inf"), min_tokens_to_keep: int = 1):
@@ -455,8 +513,11 @@ class TopKLogitsWarper(LogitsWarper):
class TypicalLogitsWarper(LogitsWarper):
r"""
[`LogitsWarper`] that performs typical decoding. Inspired by how humans use language, it prioritizes tokens whose
log probability is close to the entropy of the token probability distribution. This means that the most likely
tokens may be discarded in the process.

See [Typical Decoding for Natural Language Generation](https://arxiv.org/abs/2202.00666) for more information.
Args:
mass (`float`, *optional*, defaults to 0.9):
@@ -465,6 +526,42 @@ class TypicalLogitsWarper(LogitsWarper):
All filtered values will be set to this float value.
min_tokens_to_keep (`int`, *optional*, defaults to 1):
Minimum number of tokens that cannot be filtered.
Examples:
```python
>>> from transformers import AutoTokenizer, AutoModelForCausalLM, set_seed
>>> model = AutoModelForCausalLM.from_pretrained("bigscience/bloomz-560m")
>>> tokenizer = AutoTokenizer.from_pretrained("bigscience/bloomz-560m")
>>> inputs = tokenizer("1, 2, 3", return_tensors="pt")
>>> # We can see that greedy decoding produces a sequence of numbers
>>> outputs = model.generate(**inputs)
>>> print(tokenizer.batch_decode(outputs, skip_special_tokens=True)[0])
1, 2, 3, 4, 5, 6, 7, 8, 9, 10,
>>> # For this particular seed, we can see that sampling produces nearly the same low-information (= low entropy)
>>> # sequence
>>> set_seed(18)
>>> outputs = model.generate(**inputs, do_sample=True)
>>> print(tokenizer.batch_decode(outputs, skip_special_tokens=True)[0])
1, 2, 3, 4, 5, 6, 7, 8, 9 and 10
>>> # With `typical_p` set, the most obvious sequence is no longer produced, which may be good for your problem
>>> set_seed(18)
>>> outputs = model.generate(
... **inputs, do_sample=True, typical_p=0.1, return_dict_in_generate=True, output_scores=True
... )
>>> print(tokenizer.batch_decode(outputs.sequences, skip_special_tokens=True)[0])
1, 2, 3 and 5
>>> # We can see that the token corresponding to "4" (token 934) in the second position, the most likely token
>>> # as seen with greedy decoding, was entirely blocked out
>>> print(outputs.scores[1][0, 934])
tensor(-inf)
```
""" """
def __init__(self, mass: float = 0.9, filter_value: float = -float("Inf"), min_tokens_to_keep: int = 1): def __init__(self, mass: float = 0.9, filter_value: float = -float("Inf"), min_tokens_to_keep: int = 1):
@@ -721,7 +818,8 @@ class NoRepeatNGramLogitsProcessor(LogitsProcessor):
sentence: "She runs fast", the bi-grams (n=2) would be ("she", "runs") and ("runs", "fast"). In text generation,
avoiding repetitions of word sequences provides a more diverse output. This [`LogitsProcessor`] enforces no
repetition of n-grams by setting the scores of banned tokens to negative infinity which eliminates those tokens
from consideration when further processing the scores. Note that, for decoder-only models like most LLMs, the
prompt is also considered to obtain the n-grams.
[Fairseq](https://github.com/pytorch/fairseq/blob/a07cb6f40480928c9e0548b737aadd36ee66ac76/fairseq/sequence_generator.py#L345).
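As a hedged sketch of the behavior described above (an editorial addition; the checkpoint and prompt are only illustrative), this processor is typically enabled through the `no_repeat_ngram_size` flag of `generate`:

```python
from transformers import AutoModelForCausalLM, AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("bigscience/bloomz-560m")
model = AutoModelForCausalLM.from_pretrained("bigscience/bloomz-560m")
inputs = tokenizer("Paris is the capital of France. Paris is", return_tensors="pt")

# With `no_repeat_ngram_size=2`, no bigram may occur twice in the whole sequence
# (for decoder-only models this includes the prompt), so verbatim loops are blocked.
outputs = model.generate(**inputs, no_repeat_ngram_size=2, max_new_tokens=20)
print(tokenizer.batch_decode(outputs, skip_special_tokens=True)[0])
```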
<Tip>
@@ -774,14 +872,40 @@ class NoRepeatNGramLogitsProcessor(LogitsProcessor):
class EncoderNoRepeatNGramLogitsProcessor(LogitsProcessor):
r"""
[`LogitsProcessor`] that works similarly to [`NoRepeatNGramLogitsProcessor`], but applied exclusively to prevent
the repetition of n-grams present in the prompt.

It was designed to promote chattiness in a language model, by preventing the generation of n-grams present in
previous conversation rounds.
Args:
encoder_ngram_size (`int`):
All ngrams of size `ngram_size` can only occur within the encoder input ids.
encoder_input_ids (`int`):
The encoder_input_ids that should not be repeated within the decoder ids.
Examples:
```py
>>> from transformers import AutoTokenizer, AutoModelForCausalLM
>>> model = AutoModelForCausalLM.from_pretrained("bigscience/bloomz-560m")
>>> tokenizer = AutoTokenizer.from_pretrained("bigscience/bloomz-560m")
>>> inputs = tokenizer("Alice: I love cats. What do you love?\nBob:", return_tensors="pt")
>>> # With greedy decoding, we see Bob repeating Alice's opinion. If Bob was a chatbot, it would be a poor one.
>>> outputs = model.generate(**inputs)
>>> print(tokenizer.batch_decode(outputs, skip_special_tokens=True)[0])
Alice: I love cats. What do you love?
Bob: I love cats. What do you
>>> # With this logits processor, we can prevent Bob from repeating Alice's opinion.
>>> outputs = model.generate(**inputs, encoder_no_repeat_ngram_size=2)
>>> print(tokenizer.batch_decode(outputs, skip_special_tokens=True)[0])
Alice: I love cats. What do you love?
Bob: My cats are very cute.
```
""" """
def __init__(self, encoder_ngram_size: int, encoder_input_ids: torch.LongTensor): def __init__(self, encoder_ngram_size: int, encoder_input_ids: torch.LongTensor):
@@ -1060,6 +1184,40 @@ class PrefixConstrainedLogitsProcessor(LogitsProcessor):
arguments `inputs_ids` and the batch ID `batch_id`. It has to return a list with the allowed tokens for the
next generation step conditioned on the previously generated tokens `inputs_ids` and the batch ID
`batch_id`.
Examples:
```py
>>> from transformers import AutoTokenizer, AutoModelForCausalLM
>>> model = AutoModelForCausalLM.from_pretrained("bigscience/bloomz-560m")
>>> tokenizer = AutoTokenizer.from_pretrained("bigscience/bloomz-560m")
>>> inputs = tokenizer("Alice and Bob", return_tensors="pt")
>>> # By default, it continues generating according to the model's logits
>>> outputs = model.generate(**inputs, max_new_tokens=5)
>>> print(tokenizer.batch_decode(outputs, skip_special_tokens=True)[0])
Alice and Bob are friends
>>> # We can constrain it with `prefix_allowed_tokens_fn` to force a certain behavior based on a prefix.
>>> # For instance, we can force an entire entity to be generated when its beginning is detected.
>>> entity = tokenizer(" Bob Marley", return_tensors="pt").input_ids[0]  # 3 tokens
>>> def prefix_allowed_tokens_fn(batch_id, input_ids):
...     '''
...     Attempts to generate 'Bob Marley' when 'Bob' is detected.
...     In this case, `batch_id` is not used, but you can set rules for each batch member.
...     '''
...     if input_ids[-1] == entity[0]:
...         return [entity[1].item()]
...     elif input_ids[-2] == entity[0] and input_ids[-1] == entity[1]:
...         return [entity[2].item()]
...     return list(range(tokenizer.vocab_size))  # If no match, allow all tokens
>>> outputs = model.generate(**inputs, max_new_tokens=5, prefix_allowed_tokens_fn=prefix_allowed_tokens_fn)
>>> print(tokenizer.batch_decode(outputs, skip_special_tokens=True)[0])
Alice and Bob Marley
```
""" """
def __init__(self, prefix_allowed_tokens_fn: Callable[[int, torch.Tensor], List[int]], num_beams: int): def __init__(self, prefix_allowed_tokens_fn: Callable[[int, torch.Tensor], List[int]], num_beams: int):
@@ -1084,56 +1242,20 @@ class HammingDiversityLogitsProcessor(LogitsProcessor):
Search: Decoding Diverse Solutions from Neural Sequence Models](https://arxiv.org/pdf/1610.02424.pdf) for more
details.

Traditional beam search often generates very similar sequences across different beams.
`HammingDiversityLogitsProcessor` addresses this by penalizing beams that generate tokens already chosen by other
beams in the same time step.

Args:
diversity_penalty (`float`):
This value is subtracted from a beam's score if it generates a token already chosen by any beam from another
group at the same time step. A higher `diversity_penalty` will enforce greater diversity among the beams.
Adjusting this value can help strike a balance between diversity and natural likelihood.
num_beams (`int`):
Number of beams for beam search. 1 means no beam search.
num_beam_groups (`int`):
Number of groups to divide `num_beams` into in order to ensure diversity among different groups of beams.
See [this paper](https://arxiv.org/pdf/1610.02424.pdf) for more details.
Examples:
@@ -1146,7 +1268,13 @@ class HammingDiversityLogitsProcessor(LogitsProcessor):
>>> model = AutoModelForSeq2SeqLM.from_pretrained("t5-base")

>>> # A long text about the solar system
>>> text = (
...     "The Solar System is a gravitationally bound system comprising the Sun and the objects that orbit it, "
...     "either directly or indirectly. Of the objects that orbit the Sun directly, the largest are the eight "
...     "planets, with the remainder being smaller objects, such as the five dwarf planets and small Solar System "
...     "bodies. The Solar System formed 4.6 billion years ago from the gravitational collapse of a giant "
...     "interstellar molecular cloud."
... )
>>> inputs = tokenizer("summarize: " + text, return_tensors="pt")

>>> # Generate diverse summary
@@ -1241,11 +1369,34 @@ class HammingDiversityLogitsProcessor(LogitsProcessor):
class ForcedBOSTokenLogitsProcessor(LogitsProcessor):
r"""
[`LogitsProcessor`] that enforces the specified token as the first generated token. Used with encoder-decoder
models.
Args:
bos_token_id (`int`):
The id of the token to force as the first generated token.
Examples:
```python
>>> from transformers import AutoTokenizer, AutoModelForSeq2SeqLM
>>> model = AutoModelForSeq2SeqLM.from_pretrained("google/flan-t5-small")
>>> tokenizer = AutoTokenizer.from_pretrained("google/flan-t5-small")
>>> inputs = tokenizer("Translate from English to German: I love cats.", return_tensors="pt")
>>> # By default, it continues generating according to the model's logits
>>> outputs = model.generate(**inputs, max_new_tokens=10)
>>> print(tokenizer.batch_decode(outputs)[0])
<pad> Ich liebe Kitty.</s>
>>> # We can use `forced_bos_token_id` to force the start of generation with an encoder-decoder model
>>> # (including forcing it to end straight away with an EOS token)
>>> outputs = model.generate(**inputs, max_new_tokens=10, forced_bos_token_id=tokenizer.eos_token_id)
>>> print(tokenizer.batch_decode(outputs)[0])
<pad></s>
```
""" """
def __init__(self, bos_token_id: int): def __init__(self, bos_token_id: int):
@@ -1271,6 +1422,27 @@ class ForcedEOSTokenLogitsProcessor(LogitsProcessor):
eos_token_id (`Union[int, List[int]]`):
The id of the token to force as the last generated token when `max_length` is reached. Optionally, use a
list to set multiple *end-of-sequence* tokens.
Examples:
```python
>>> from transformers import AutoTokenizer, AutoModelForCausalLM
>>> model = AutoModelForCausalLM.from_pretrained("distilgpt2")
>>> tokenizer = AutoTokenizer.from_pretrained("distilgpt2")
>>> inputs = tokenizer("A sequence: 1, 2, 3", return_tensors="pt")
>>> # By default, it continues generating according to the model's logits
>>> outputs = model.generate(**inputs, max_new_tokens=10)
>>> print(tokenizer.batch_decode(outputs)[0])
A sequence: 1, 2, 3, 4, 5, 6, 7, 8
>>> # `forced_eos_token_id` ensures the generation ends with an EOS token
>>> outputs = model.generate(**inputs, max_new_tokens=10, forced_eos_token_id=tokenizer.eos_token_id)
>>> print(tokenizer.batch_decode(outputs)[0])
A sequence: 1, 2, 3, 4, 5, 6, 7,<|endoftext|>
```
""" """
def __init__(self, max_length: int, eos_token_id: Union[int, List[int]]): def __init__(self, max_length: int, eos_token_id: Union[int, List[int]]):
@@ -1294,6 +1466,9 @@ class InfNanRemoveLogitsProcessor(LogitsProcessor):
r"""
[`LogitsProcessor`] that removes all `nan` and `inf` values to prevent the generation method from failing. Note
that this logits processor should only be used if necessary since it can slow down the generation method.
This logits processor has no `generate` example, as there shouldn't be a correct combination of flags that warrants
its use.
""" """
@add_start_docstrings(LOGITS_PROCESSOR_INPUTS_DOCSTRING) @add_start_docstrings(LOGITS_PROCESSOR_INPUTS_DOCSTRING)
@@ -1405,6 +1580,29 @@ class LogitNormalization(LogitsProcessor, LogitsWarper):
the scores during beam search, after applying the logits processors or warpers, since the search algorithm used in
this library doesn't do it (it only does it before, but they may need re-normalization) but it still supposes that
the scores are normalized when comparing the hypotheses.
Examples:
```python
>>> from transformers import AutoTokenizer, AutoModelForCausalLM
>>> import torch
>>> model = AutoModelForCausalLM.from_pretrained("distilgpt2")
>>> tokenizer = AutoTokenizer.from_pretrained("distilgpt2")
>>> inputs = tokenizer("A sequence: 1, 2, 3", return_tensors="pt")
>>> # By default, the scores are not normalized -- the sum of their exponentials is NOT a normalized probability
>>> # distribution, summing to 1
>>> outputs = model.generate(**inputs, return_dict_in_generate=True, output_scores=True)
>>> print(torch.sum(torch.exp(outputs.scores[-1])))
tensor(816.3250)
>>> # Normalizing them may have a positive impact on beam methods, or when using the scores on your application
>>> outputs = model.generate(**inputs, renormalize_logits=True, return_dict_in_generate=True, output_scores=True)
>>> print(torch.sum(torch.exp(outputs.scores[-1])))
tensor(1.0000)
```
""" """
@add_start_docstrings(LOGITS_PROCESSOR_INPUTS_DOCSTRING) @add_start_docstrings(LOGITS_PROCESSOR_INPUTS_DOCSTRING)
@@ -1416,8 +1614,36 @@ class LogitNormalization(LogitsProcessor, LogitsWarper):
class SuppressTokensAtBeginLogitsProcessor(LogitsProcessor):
r"""
[`SuppressTokensAtBeginLogitsProcessor`] suppresses a list of tokens as soon as the `generate` function starts
generating using `begin_index` tokens. This should ensure that the tokens defined by `begin_suppress_tokens` are
not generated at the beginning. Originally created for
[Whisper](https://huggingface.co/docs/transformers/model_doc/whisper).
Examples:
```python
>>> from transformers import AutoProcessor, WhisperForConditionalGeneration
>>> from datasets import load_dataset
>>> processor = AutoProcessor.from_pretrained("openai/whisper-tiny.en")
>>> model = WhisperForConditionalGeneration.from_pretrained("openai/whisper-tiny.en")
>>> ds = load_dataset("hf-internal-testing/librispeech_asr_dummy", "clean", split="validation")
>>> inputs = processor(ds[0]["audio"]["array"], return_tensors="pt")
>>> # Whisper has `begin_suppress_tokens` set by default (= `[220, 50256]`). 50256 is the EOS token, so this means
>>> # it can't generate an EOS token in the first iteration, but it can in the others.
>>> outputs = model.generate(**inputs, return_dict_in_generate=True, output_scores=True)
>>> print(outputs.scores[1][0, 50256]) # 1 (and not 0) is the first freely generated token
tensor(-inf)
>>> print(outputs.scores[-1][0, 50256]) # in other places we can see some probability mass for EOS
tensor(29.9010)
>>> # If we disable `begin_suppress_tokens`, we can generate EOS in the first iteration.
>>> outputs = model.generate(
... **inputs, return_dict_in_generate=True, output_scores=True, begin_suppress_tokens=None
... )
>>> print(outputs.scores[1][0, 50256])
tensor(11.2027)
```
""" """
def __init__(self, begin_suppress_tokens, begin_index): def __init__(self, begin_suppress_tokens, begin_index):
@@ -1433,8 +1659,33 @@ class SuppressTokensAtBeginLogitsProcessor(LogitsProcessor):
class SuppressTokensLogitsProcessor(LogitsProcessor):
r"""
This processor can be used to suppress a list of tokens. The processor will set their log probs to `-inf` so
that they are not generated. Originally created for
[Whisper](https://huggingface.co/docs/transformers/model_doc/whisper).
Examples:
```python
>>> from transformers import AutoProcessor, WhisperForConditionalGeneration
>>> from datasets import load_dataset
>>> processor = AutoProcessor.from_pretrained("openai/whisper-tiny.en")
>>> model = WhisperForConditionalGeneration.from_pretrained("openai/whisper-tiny.en")
>>> ds = load_dataset("hf-internal-testing/librispeech_asr_dummy", "clean", split="validation")
>>> inputs = processor(ds[0]["audio"]["array"], return_tensors="pt")
>>> # Whisper has a long list of suppressed tokens. For instance, in this case, the token 1 is suppressed by default.
>>> outputs = model.generate(**inputs, return_dict_in_generate=True, output_scores=True)
>>> print(outputs.scores[1][0, 1]) # 1 (and not 0) is the first freely generated token
tensor(-inf)
>>> # If we disable `suppress_tokens`, we can generate it.
>>> outputs = model.generate(**inputs, return_dict_in_generate=True, output_scores=True, suppress_tokens=None)
>>> print(outputs.scores[1][0, 1])
tensor(5.7738)
```
"""
def __init__(self, suppress_tokens):
    self.suppress_tokens = list(suppress_tokens)
@@ -1446,9 +1697,42 @@ class SuppressTokensLogitsProcessor(LogitsProcessor):
class ForceTokensLogitsProcessor(LogitsProcessor):
r"""
This processor takes a list of pairs of integers which indicates a mapping from generation indices to token
indices that will be forced before generation. The processor will set their log probs to `inf` so that they are
sampled at their corresponding index. Originally created for
[Whisper](https://huggingface.co/docs/transformers/model_doc/whisper).
Examples:
```python
>>> from transformers import AutoProcessor, WhisperForConditionalGeneration
>>> from datasets import load_dataset
>>> processor = AutoProcessor.from_pretrained("openai/whisper-tiny.en")
>>> model = WhisperForConditionalGeneration.from_pretrained("openai/whisper-tiny.en")
>>> ds = load_dataset("hf-internal-testing/librispeech_asr_dummy", "clean", split="validation")
>>> inputs = processor(ds[0]["audio"]["array"], return_tensors="pt")
>>> # This Whisper model forces the generation to start with `50362` at the first position by default, i.e.
>>> # `"forced_decoder_ids": [[1, 50362]]`. This means all other tokens are masked out.
>>> outputs = model.generate(**inputs, return_dict_in_generate=True, output_scores=True)
>>> print(
... all(outputs.scores[0][0, i] == float("-inf") for i in range(processor.tokenizer.vocab_size) if i != 50362)
... )
True
>>> print(outputs.scores[0][0, 50362])
tensor(0.)
>>> # If we disable `forced_decoder_ids`, we stop seeing that effect
>>> outputs = model.generate(**inputs, return_dict_in_generate=True, output_scores=True, forced_decoder_ids=None)
>>> print(
... all(outputs.scores[0][0, i] == float("-inf") for i in range(processor.tokenizer.vocab_size) if i != 50362)
... )
False
>>> print(outputs.scores[0][0, 50362])
tensor(19.3140)
```
"""
def __init__(self, force_token_map: List[List[int]]):
    self.force_token_map = dict(force_token_map)
@@ -1492,7 +1776,7 @@ class WhisperTimeStampLogitsProcessor(LogitsProcessor):
Examples:
```python
>>> import torch
>>> from transformers import AutoProcessor, WhisperForConditionalGeneration, GenerationConfig
>>> from datasets import load_dataset

>>> processor = AutoProcessor.from_pretrained("openai/whisper-tiny.en")
@@ -1588,18 +1872,42 @@ class WhisperTimeStampLogitsProcessor(LogitsProcessor):
class ClassifierFreeGuidanceLogitsProcessor(LogitsProcessor):
r"""
[`LogitsProcessor`] for classifier free guidance (CFG). The scores are split over the batch dimension,
where the first half correspond to the conditional logits (predicted from the input prompt) and the second half
correspond to the unconditional logits (predicted from an empty or 'null' prompt). The processor computes a
weighted average across the conditional and unconditional logits, parameterised by the `guidance_scale`.
See [the paper](https://arxiv.org/abs/2306.05284) for more information.
<Tip warning={true}>
This logits processor is exclusively compatible with
[MusicGen](https://huggingface.co/docs/transformers/main/en/model_doc/musicgen)
</Tip>
Args:
guidance_scale (float):
The guidance scale for classifier free guidance (CFG). CFG is enabled by setting `guidance_scale > 1`.
Higher guidance scale encourages the model to generate samples that are more closely linked to the input
prompt, usually at the expense of poorer quality.
Examples:
```python
>>> from transformers import AutoProcessor, MusicgenForConditionalGeneration
>>> processor = AutoProcessor.from_pretrained("facebook/musicgen-small")
>>> model = MusicgenForConditionalGeneration.from_pretrained("facebook/musicgen-small")
>>> inputs = processor(
... text=["80s pop track with bassy drums and synth", "90s rock song with loud guitars and heavy drums"],
... padding=True,
... return_tensors="pt",
... )
>>> audio_values = model.generate(**inputs, do_sample=True, guidance_scale=3, max_new_tokens=256)
```
""" """
def __init__(self, guidance_scale): def __init__(self, guidance_scale):
@@ -1629,7 +1937,15 @@ class ClassifierFreeGuidanceLogitsProcessor(LogitsProcessor):
class AlternatingCodebooksLogitsProcessor(LogitsProcessor):
r"""
[`LogitsProcessor`] enforcing alternated generation between the two codebooks of Bark.
<Tip warning={true}>
This logits processor is exclusively compatible with
[Bark](https://huggingface.co/docs/transformers/en/model_doc/bark)'s fine submodel. See the model documentation
for examples.
</Tip>
Args:
input_start_len (`int`):
@@ -1664,10 +1980,10 @@ class AlternatingCodebooksLogitsProcessor(LogitsProcessor):
class UnbatchedClassifierFreeGuidanceLogitsProcessor(LogitsProcessor):
r"""
Logits processor for Classifier-Free Guidance (CFG). The processor computes a weighted average across scores
from prompt conditional and prompt unconditional (or negative) logits, parameterized by the `guidance_scale`.
The unconditional scores are computed internally by prompting `model` with the `unconditional_ids` branch.
See [the paper](https://arxiv.org/abs/2306.17806) for more information.
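A hedged usage sketch (an editorial addition, not the class's own example): in `generate`, setting `guidance_scale` on a decoder-only model is what typically activates this processor, optionally together with `negative_prompt_ids`; the checkpoint and values below are illustrative and the sampled text will vary:

```python
from transformers import AutoModelForCausalLM, AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("gpt2")
model = AutoModelForCausalLM.from_pretrained("gpt2")
inputs = tokenizer(["Today, a dragon flew over Paris, France,"], return_tensors="pt")

# `guidance_scale > 1` pushes the scores towards the prompt-conditional distribution and away
# from the unconditional one that this processor computes internally.
outputs = model.generate(**inputs, guidance_scale=1.5, max_new_tokens=20)
print(tokenizer.batch_decode(outputs, skip_special_tokens=True)[0])
```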
@@ -1784,6 +2100,13 @@ class UnbatchedClassifierFreeGuidanceLogitsProcessor(LogitsProcessor):
class BarkEosPrioritizerLogitsProcessor(LogitsProcessor):
r"""This processor ensures that the EOS token is selected if its probability is greater than the `min_eos_p`.
<Tip warning={true}>
This logits processor is exclusively compatible with
[Bark](https://huggingface.co/docs/transformers/en/model_doc/bark). See the model documentation for examples.
</Tip>
Args:
eos_token_id (`Union[int, List[int]]`):
The id of the *end-of-sequence* token. Optionally, use a list to set multiple *end-of-sequence* tokens.
@@ -1031,16 +1031,9 @@ class GenerationMixin:
generation_config.encoder_no_repeat_ngram_size is not None
and generation_config.encoder_no_repeat_ngram_size > 0
):
processors.append(
    EncoderNoRepeatNGramLogitsProcessor(generation_config.encoder_no_repeat_ngram_size, encoder_input_ids)
)
if generation_config.bad_words_ids is not None:
    processors.append(
        NoBadWordsLogitsProcessor(generation_config.bad_words_ids, generation_config.eos_token_id)