Unverified commit 5f0801d1, authored by Joao Gante, committed by GitHub

Generate: add SequenceBiasLogitsProcessor (#24334)

parent 45f71d79
@@ -141,6 +141,9 @@ generation.
[[autodoc]] NoRepeatNGramLogitsProcessor
    - __call__

[[autodoc]] SequenceBiasLogitsProcessor
    - __call__

[[autodoc]] NoBadWordsLogitsProcessor
    - __call__
...
@@ -970,6 +970,7 @@ else:
        "PhrasalConstraint",
        "PrefixConstrainedLogitsProcessor",
        "RepetitionPenaltyLogitsProcessor",
        "SequenceBiasLogitsProcessor",
        "StoppingCriteria",
        "StoppingCriteriaList",
        "TemperatureLogitsWarper",
@@ -4733,6 +4734,7 @@ if TYPE_CHECKING:
        PhrasalConstraint,
        PrefixConstrainedLogitsProcessor,
        RepetitionPenaltyLogitsProcessor,
        SequenceBiasLogitsProcessor,
        StoppingCriteria,
        StoppingCriteriaList,
        TemperatureLogitsWarper,
...
@@ -56,6 +56,7 @@ else:
        "NoRepeatNGramLogitsProcessor",
        "PrefixConstrainedLogitsProcessor",
        "RepetitionPenaltyLogitsProcessor",
        "SequenceBiasLogitsProcessor",
        "EncoderRepetitionPenaltyLogitsProcessor",
        "TemperatureLogitsWarper",
        "TopKLogitsWarper",
@@ -182,6 +183,7 @@ if TYPE_CHECKING:
        NoRepeatNGramLogitsProcessor,
        PrefixConstrainedLogitsProcessor,
        RepetitionPenaltyLogitsProcessor,
        SequenceBiasLogitsProcessor,
        TemperatureLogitsWarper,
        TopKLogitsWarper,
        TopPLogitsWarper,
...
@@ -142,11 +142,8 @@ class GenerationConfig(PushToHubMixin):
        no_repeat_ngram_size (`int`, *optional*, defaults to 0):
            If set to int > 0, all ngrams of that size can only occur once.
        bad_words_ids(`List[List[int]]`, *optional*):
            List of list of token ids that are not allowed to be generated. Check
            [`~generation.NoBadWordsLogitsProcessor`] for further documentation and examples.
        force_words_ids(`List[List[int]]` or `List[List[List[int]]]`, *optional*):
            List of token ids that must be generated. If given a `List[List[int]]`, this is treated as a simple list of
            words that must be included, the opposite to `bad_words_ids`. If given `List[List[List[int]]]`, this
@@ -183,6 +180,10 @@ class GenerationConfig(PushToHubMixin):
            A list of pairs of integers which indicates a mapping from generation indices to token indices that will be
            forced before sampling. For example, `[[1, 123]]` means the second generated token will always be a token
            of index 123.
        sequence_bias (`Dict[Tuple[int], float]`, *optional*):
            Dictionary that maps a sequence of tokens to its bias term. Positive biases increase the odds of the
            sequence being selected, while negative biases do the opposite. Check
            [`~generation.SequenceBiasLogitsProcessor`] for further documentation and examples.

    > Parameters that define the output variables of `generate`
@@ -262,6 +263,7 @@ class GenerationConfig(PushToHubMixin):
        self.suppress_tokens = kwargs.pop("suppress_tokens", None)
        self.begin_suppress_tokens = kwargs.pop("begin_suppress_tokens", None)
        self.forced_decoder_ids = kwargs.pop("forced_decoder_ids", None)
        self.sequence_bias = kwargs.pop("sequence_bias", None)

        # Parameters that define the output variables of `generate`
        self.num_return_sequences = kwargs.pop("num_return_sequences", 1)
...
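Since the constructor above simply pops `sequence_bias` from its kwargs, the bias can also be set on a `GenerationConfig` object instead of being passed to `generate()` directly. A minimal sketch of that route (not part of this diff), assuming a `gpt2` checkpoint and the same two-tokenizer setup used in the doctest further down:

```python
from transformers import AutoModelForCausalLM, AutoTokenizer, GenerationConfig

model = AutoModelForCausalLM.from_pretrained("gpt2")  # illustrative checkpoint
tokenizer = AutoTokenizer.from_pretrained("gpt2")
# A tokenizer with `add_prefix_space=True` yields the mid-sentence ids of the biased word
tokenizer_with_prefix_space = AutoTokenizer.from_pretrained("gpt2", add_prefix_space=True)

# Keys are tuples of token ids, values are additive bias terms (floats)
trump_ids = tuple(tokenizer_with_prefix_space(["Trump"], add_special_tokens=False).input_ids[0])
generation_config = GenerationConfig(max_new_tokens=4, num_beams=4, sequence_bias={trump_ids: -10.0})

inputs = tokenizer(["The full name of Donald is Donald"], return_tensors="pt")
out = model.generate(**inputs, generation_config=generation_config)
print(tokenizer.batch_decode(out, skip_special_tokens=True)[0])
```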
@@ -15,7 +15,7 @@
import inspect
import math
from typing import Callable, Dict, Iterable, List, Tuple, Union

import numpy as np
import torch
@@ -539,140 +539,218 @@ class EncoderNoRepeatNGramLogitsProcessor(LogitsProcessor):
        return scores


class SequenceBiasLogitsProcessor(LogitsProcessor):
    """
    [`LogitsProcessor`] that applies an additive bias on sequences. The bias is applied to the last token of a
    sequence when the next generated token can complete it. Consequently, to get the most out of biasing sequences
    with more than one token, consider using beam methods (to gracefully work around partially completed sequences
    that have a negative bias) and applying the bias to their prefixes (to ensure the bias is applied earlier).

    <Tip>

    In order to get the token ids of the sequences that you want to bias, make sure to set `add_prefix_space=True`
    when initializing the tokenizer, and use `tokenizer(bad_words, add_special_tokens=False).input_ids`. The
    `add_prefix_space` argument is only supported for some slow tokenizers, as fast tokenizers' prefixing behaviours
    come from `pre tokenizers`. Read more [here](https://huggingface.co/docs/tokenizers/api/pre-tokenizers).

    </Tip>

    Args:
        sequence_bias (`Dict[Tuple[int], float]`):
            Dictionary that maps a sequence of tokens to its bias term. Positive biases increase the odds of the
            sequence being selected, while negative biases do the opposite. If a sequence has a length of 1, its bias
            will always be applied. Otherwise, the bias will only be applied if the sequence in question is about to
            be completed (in the token selection step after this processor is applied).

    Examples:

    ```python
    >>> from transformers import AutoTokenizer, AutoModelForCausalLM

    >>> model = AutoModelForCausalLM.from_pretrained("gpt2")
    >>> tokenizer = AutoTokenizer.from_pretrained("gpt2")
    >>> inputs = tokenizer(["The full name of Donald is Donald"], return_tensors="pt")

    >>> summary_ids = model.generate(inputs["input_ids"], max_new_tokens=4)
    >>> print(tokenizer.batch_decode(summary_ids, skip_special_tokens=True)[0])
    The full name of Donald is Donald J. Trump Jr

    >>> # Now let's control generation through a bias. Please note that the tokenizer is initialized differently!
    >>> tokenizer_with_prefix_space = AutoTokenizer.from_pretrained("gpt2", add_prefix_space=True)


    >>> def get_tokens_as_tuple(word):
    ...     return tuple(tokenizer_with_prefix_space([word], add_special_tokens=False).input_ids[0])


    >>> # If we add a negative bias without beam search, it may become "stuck" in a prefix without good continuations
    >>> sequence_bias = {get_tokens_as_tuple("Trump"): -10.0}
    >>> biased_ids = model.generate(inputs["input_ids"], max_new_tokens=4, sequence_bias=sequence_bias)
    >>> print(tokenizer.batch_decode(biased_ids, skip_special_tokens=True)[0])
    The full name of Donald is Donald J. Donald,

    >>> biased_ids = model.generate(inputs["input_ids"], max_new_tokens=4, num_beams=4, sequence_bias=sequence_bias)
    >>> print(tokenizer.batch_decode(biased_ids, skip_special_tokens=True)[0])
    The full name of Donald is Donald Rumsfeld,

    >>> # We can also add a positive bias to nudge the model towards specific tokens or continuations
    >>> sequence_bias = {get_tokens_as_tuple("Donald Duck"): 10.0}
    >>> biased_ids = model.generate(inputs["input_ids"], max_new_tokens=4, num_beams=4, sequence_bias=sequence_bias)
    >>> print(tokenizer.batch_decode(biased_ids, skip_special_tokens=True)[0])
    The full name of Donald is Donald Duck.
    ```
    """

    def __init__(self, sequence_bias: Dict[Tuple[int], float]):
        self.sequence_bias = sequence_bias
        self._validate_arguments()

        # Bias variables that will be populated on the first call (for retrocompatibility purposes, the vocabulary
        # size is inferred on the first usage, which inhibits initializing here)
        self.sequences_length_greater_than_1 = []
        self.length_1_bias = None
        self.length_greather_than_1_bias = None
        self.prepared_bias_variables = False

    def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor) -> torch.FloatTensor:
        # 1 - Prepares the bias tensors. This is only needed the first time the logit processor is called.
        if not self.prepared_bias_variables:
            self._prepare_bias_variables(scores)

        # 2 - prepares an empty bias to add
        bias = torch.zeros_like(scores)

        # 3 - include the bias from length = 1
        bias += self.length_1_bias

        # 4 - include the bias from length > 1, after determining which biased sequences may be completed.
        # `matching_mask` is a (batch_size, vocab_size) boolean mask that is True for all tokens whose corresponding
        # bias should be applied. The bias is applied on the last token of the sequence, if (and only if) the
        # sequence may become complete this iteration.
        matching_mask = torch.zeros_like(scores, dtype=torch.bool)
        for sequence_ids in self.sequences_length_greater_than_1:
            if len(sequence_ids) > input_ids.shape[1]:  # the sequence is longer than the context, ignore
                continue
            prefix_length = len(sequence_ids) - 1
            last_token = sequence_ids[-1]
            matching_rows = torch.eq(
                input_ids[:, -prefix_length:],
                torch.tensor(sequence_ids[:-1], dtype=input_ids.dtype, device=input_ids.device),
            ).prod(dim=1)
            matching_mask[:, last_token] |= matching_rows.bool()
        bias += torch.where(matching_mask, self.length_greather_than_1_bias, 0.0)

        # 5 - apply the bias to the scores
        scores = scores + bias
        return scores

    def _prepare_bias_variables(self, scores: torch.FloatTensor):
        vocabulary_size = scores.shape[-1]
        sequence_bias = self.sequence_bias
        tokens_with_bias = []

        # Check biased tokens out of bounds
        invalid_biases = []
        for sequence_ids in sequence_bias:
            for token_id in sequence_ids:
                if token_id >= vocabulary_size:
                    invalid_biases.append(token_id)
        if len(invalid_biases) > 0:
            raise ValueError(
                f"The model vocabulary size is {vocabulary_size}, but the following tokens were being biased: "
                f"{invalid_biases}"
            )

        # Precompute the bias tensors to be applied. Sequences of length 1 are kept separately, as they can be
        # applied with simpler logic.
        self.length_1_bias = torch.zeros((vocabulary_size,), dtype=torch.float).to(scores.device)
        self.length_greather_than_1_bias = torch.zeros((vocabulary_size,), dtype=torch.float).to(scores.device)
        for sequence_ids, bias in sequence_bias.items():
            if len(sequence_ids) == 1:
                self.length_1_bias[sequence_ids[-1]] = bias
            else:
                self.sequences_length_greater_than_1.append(sequence_ids)
                if self.length_greather_than_1_bias[sequence_ids[-1]] != 0.0:
                    raise ValueError(
                        "Setting a bias on sequences that share a common token termination is not yet supported. "
                        "Please open an issue if you see this error message (after checking that it doesn't already "
                        "exist)."
                    )
                self.length_greather_than_1_bias[sequence_ids[-1]] = bias
            tokens_with_bias.append(sequence_ids[-1])

        self.prepared_bias_variables = True

    def _validate_arguments(self):
        sequence_bias = self.sequence_bias
        if not isinstance(sequence_bias, dict) or len(sequence_bias) == 0:
            raise ValueError(f"`sequence_bias` has to be a non-empty dictionary, but is {sequence_bias}.")
        if any(not isinstance(sequence_ids, tuple) for sequence_ids in sequence_bias.keys()):
            raise ValueError(f"`sequence_bias` has to be a dict with tuples as keys, but is {sequence_bias}.")
        if any(
            any((not isinstance(token_id, (int, np.integer)) or token_id < 0) for token_id in sequence_ids)
            or len(sequence_ids) == 0
            for sequence_ids in sequence_bias.keys()
        ):
            raise ValueError(
                f"Each key in `sequence_bias` has to be a non-empty tuple of positive integers, but is "
                f"{sequence_bias}."
            )
        if any(not isinstance(bias, float) for bias in sequence_bias.values()):
            raise ValueError(f"`sequence_bias` has to be a dict with floats as values, but is {sequence_bias}.")


class NoBadWordsLogitsProcessor(SequenceBiasLogitsProcessor):
    """
    [`LogitsProcessor`] that enforces that specified sequences will never be selected.

    <Tip>

    In order to get the token ids of the words that should not appear in the generated text, make sure to set
    `add_prefix_space=True` when initializing the tokenizer, and use `tokenizer(bad_words,
    add_special_tokens=False).input_ids`. The `add_prefix_space` argument is only supported for some slow tokenizers,
    as fast tokenizers' prefixing behaviours come from `pre tokenizers`. Read more
    [here](https://huggingface.co/docs/tokenizers/api/pre-tokenizers).

    </Tip>

    Args:
        bad_words_ids (`List[List[int]]`):
            List of list of token ids that are not allowed to be generated.
        eos_token_id (`Union[int, List[int]]`):
            The id of the *end-of-sequence* token. Optionally, use a list to set multiple *end-of-sequence* tokens.
    """

    def __init__(self, bad_words_ids: List[List[int]], eos_token_id: Union[int, List[int]]):
        self.bad_word_ids = bad_words_ids
        self._validate_arguments()

        # Filter EOS token from bad_words_ids
        if eos_token_id is None:
            eos_token_id = []
        if isinstance(eos_token_id, int):
            eos_token_id = [eos_token_id]
        bad_words_ids = list(
            filter(lambda bad_token_seq: all([bad_token_seq != [i] for i in eos_token_id]), bad_words_ids)
        )

        # Forbidding a sequence is equivalent to setting its bias to -inf
        sequence_bias = {tuple(sequence): float("-inf") for sequence in bad_words_ids}
        super().__init__(sequence_bias=sequence_bias)

    def _validate_arguments(self):
        bad_words_ids = self.bad_word_ids
        if not isinstance(bad_words_ids, list) or len(bad_words_ids) == 0:
            raise ValueError(f"`bad_words_ids` has to be a non-empty list, but is {bad_words_ids}.")
        if any(not isinstance(bad_word_ids, list) for bad_word_ids in bad_words_ids):
            raise ValueError(f"`bad_words_ids` has to be a list of lists, but is {bad_words_ids}.")
        if any(
            any((not isinstance(token_id, (int, np.integer)) or token_id < 0) for token_id in bad_word_ids)
            for bad_word_ids in bad_words_ids
        ):
            raise ValueError(
                f"Each list in `bad_words_ids` has to be a list of positive integers, but is {bad_words_ids}."
            )


class PrefixConstrainedLogitsProcessor(LogitsProcessor):
...
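The matching logic above is easiest to see on toy tensors, outside of `generate`. The sketch below (not part of the diff) reuses the shapes from the new unit test further down: a length-1 sequence is biased unconditionally, while a longer sequence only biases its final token when the preceding tokens already match the tail of `input_ids`.

```python
import torch

from transformers.generation import SequenceBiasLogitsProcessor

# Two rows of 4 tokens each, over a toy vocabulary of 5 tokens
input_ids = torch.tensor([[0, 1, 3, 1], [0, 1, 0, 1]], dtype=torch.long)
scores = torch.zeros((2, 5), dtype=torch.float)

processor = SequenceBiasLogitsProcessor(sequence_bias={(1,): 100.0, (1, 3, 1, 3): -100.0})
biased_scores = processor(input_ids, scores)

print(biased_scores)
# Token 1 gets +100.0 in both rows (length-1 bias is unconditional).
# Token 3 gets -100.0 only in the first row, whose context ends with the prefix (1, 3, 1).
```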
@@ -56,6 +56,7 @@ from .logits_process import (
    NoRepeatNGramLogitsProcessor,
    PrefixConstrainedLogitsProcessor,
    RepetitionPenaltyLogitsProcessor,
    SequenceBiasLogitsProcessor,
    SuppressTokensAtBeginLogitsProcessor,
    SuppressTokensLogitsProcessor,
    TemperatureLogitsWarper,
@@ -842,8 +843,9 @@ class GenerationMixin:
        # instantiate processors list
        processors = LogitsProcessorList()

        if generation_config.sequence_bias is not None:
            processors.append(SequenceBiasLogitsProcessor(sequence_bias=generation_config.sequence_bias))

        if generation_config.diversity_penalty is not None and generation_config.diversity_penalty > 0.0:
            processors.append(
                HammingDiversityLogitsProcessor(
...
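For callers that assemble their own processors, the wiring added above behaves much like appending the processor by hand and handing it to `generate` through its `logits_processor` argument. A rough equivalent (a sketch, not part of the diff; `gpt2` checkpoint assumed):

```python
from transformers import AutoModelForCausalLM, AutoTokenizer
from transformers.generation import LogitsProcessorList, SequenceBiasLogitsProcessor

model = AutoModelForCausalLM.from_pretrained("gpt2")
tokenizer = AutoTokenizer.from_pretrained("gpt2")
inputs = tokenizer(["The full name of Donald is Donald"], return_tensors="pt")

# Discourage the continuation " Trump" by biasing its token sequence negatively
trump_ids = tuple(tokenizer(" Trump", add_special_tokens=False).input_ids)
custom_processors = LogitsProcessorList([SequenceBiasLogitsProcessor(sequence_bias={trump_ids: -10.0})])

out = model.generate(**inputs, max_new_tokens=4, num_beams=4, logits_processor=custom_processors)
print(tokenizer.batch_decode(out, skip_special_tokens=True)[0])
```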
@@ -240,6 +240,13 @@ class RepetitionPenaltyLogitsProcessor(metaclass=DummyObject):
        requires_backends(self, ["torch"])


class SequenceBiasLogitsProcessor(metaclass=DummyObject):
    _backends = ["torch"]

    def __init__(self, *args, **kwargs):
        requires_backends(self, ["torch"])


class StoppingCriteria(metaclass=DummyObject):
    _backends = ["torch"]
...
@@ -46,6 +46,7 @@ if is_torch_available():
        NoRepeatNGramLogitsProcessor,
        PrefixConstrainedLogitsProcessor,
        RepetitionPenaltyLogitsProcessor,
        SequenceBiasLogitsProcessor,
        TemperatureLogitsWarper,
        TopKLogitsWarper,
        TopPLogitsWarper,
@@ -512,6 +513,27 @@ class LogitsProcessorTest(unittest.TestCase):
        filtered_scores = no_bad_words_dist_proc(input_ids, scores.clone())
        self.assertTrue(torch.allclose(scores, filtered_scores, atol=1e-3))

    def test_bias_dist_processor(self):
        vocab_size = 5
        batch_size = 2

        input_ids = torch.tensor([[0, 1, 3, 1], [0, 1, 0, 1]], device=torch_device, dtype=torch.long)
        positive_bias = {(1,): 100.0, (4,): 100.0}
        negative_bias = {(1, 0): -100.0, (0, 1, 2): -100.0, (1, 3, 1, 3): -100.0}
        sequence_bias = {**positive_bias, **negative_bias}

        # scores = 0 to facilitate checks
        scores = torch.zeros((batch_size, vocab_size), dtype=torch.float, device=torch_device)

        bias_dist_proc = SequenceBiasLogitsProcessor(sequence_bias=sequence_bias)
        filtered_scores = bias_dist_proc(input_ids, scores.clone())

        # batch 1: positive bias: tokens (1, 4); negative bias: tokens (0, 3); neutral: tokens (2)
        # batch 2: positive bias: tokens (1, 4); negative bias: tokens (0, 2); neutral: tokens (3)
        self.assertListEqual(
            filtered_scores.tolist(), [[-100.0, 100.0, 0.0, -100.0, 100.0], [-100.0, 100.0, -100.0, 0.0, 100.0]]
        )

    def test_processor_list(self):
        batch_size = 4
        sequence_length = 10
...