chenpangpang / transformers

Unverified commit 5f0801d1, authored Jun 21, 2023 by Joao Gante and committed via GitHub on Jun 21, 2023:
Generate: add SequenceBiasLogitsProcessor (#24334)
Parent commit: 45f71d79

Showing 8 changed files with 241 additions and 123 deletions:
  docs/source/en/internal/generation_utils.md          +3    -0
  src/transformers/__init__.py                          +2    -0
  src/transformers/generation/__init__.py               +2    -0
  src/transformers/generation/configuration_utils.py    +7    -5
  src/transformers/generation/logits_process.py         +194  -116
  src/transformers/generation/utils.py                  +4    -2
  src/transformers/utils/dummy_pt_objects.py            +7    -0
  tests/generation/test_logits_process.py               +22   -0
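The user-facing result of these changes is a new `sequence_bias` argument accepted by `generate`. Below is a minimal sketch condensed from the docstring example this commit adds to `SequenceBiasLogitsProcessor` (the "gpt2" checkpoint and the prompt are the ones used in that example):

from transformers import AutoModelForCausalLM, AutoTokenizer

model = AutoModelForCausalLM.from_pretrained("gpt2")
tokenizer = AutoTokenizer.from_pretrained("gpt2")
inputs = tokenizer(["The full name of Donald is Donald"], return_tensors="pt")

# Bias keys are tuples of token ids; build them with a tokenizer that prepends a space.
tokenizer_with_prefix_space = AutoTokenizer.from_pretrained("gpt2", add_prefix_space=True)
trump_ids = tuple(tokenizer_with_prefix_space(["Trump"], add_special_tokens=False).input_ids[0])

# A negative bias discourages the biased sequence; beam search lets it take effect gracefully.
biased_ids = model.generate(
    inputs["input_ids"], max_new_tokens=4, num_beams=4, sequence_bias={trump_ids: -10.0}
)
print(tokenizer.batch_decode(biased_ids, skip_special_tokens=True)[0])

Per the doctest in the diff below, this steers the continuation away from "Trump" (producing "The full name of Donald is Donald Rumsfeld,").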
docs/source/en/internal/generation_utils.md

@@ -141,6 +141,9 @@ generation.
 [[autodoc]] NoRepeatNGramLogitsProcessor
     - __call__

+[[autodoc]] SequenceBiasLogitsProcessor
+    - __call__
+
 [[autodoc]] NoBadWordsLogitsProcessor
     - __call__
src/transformers/__init__.py

@@ -970,6 +970,7 @@ else:
         "PhrasalConstraint",
         "PrefixConstrainedLogitsProcessor",
         "RepetitionPenaltyLogitsProcessor",
+        "SequenceBiasLogitsProcessor",
         "StoppingCriteria",
         "StoppingCriteriaList",
         "TemperatureLogitsWarper",

@@ -4733,6 +4734,7 @@ if TYPE_CHECKING:
         PhrasalConstraint,
         PrefixConstrainedLogitsProcessor,
         RepetitionPenaltyLogitsProcessor,
+        SequenceBiasLogitsProcessor,
         StoppingCriteria,
         StoppingCriteriaList,
         TemperatureLogitsWarper,
src/transformers/generation/__init__.py

@@ -56,6 +56,7 @@ else:
         "NoRepeatNGramLogitsProcessor",
         "PrefixConstrainedLogitsProcessor",
         "RepetitionPenaltyLogitsProcessor",
+        "SequenceBiasLogitsProcessor",
         "EncoderRepetitionPenaltyLogitsProcessor",
         "TemperatureLogitsWarper",
         "TopKLogitsWarper",

@@ -182,6 +183,7 @@ if TYPE_CHECKING:
         NoRepeatNGramLogitsProcessor,
         PrefixConstrainedLogitsProcessor,
         RepetitionPenaltyLogitsProcessor,
+        SequenceBiasLogitsProcessor,
         TemperatureLogitsWarper,
         TopKLogitsWarper,
         TopPLogitsWarper,
src/transformers/generation/configuration_utils.py

@@ -142,11 +142,8 @@ class GenerationConfig(PushToHubMixin):
         no_repeat_ngram_size (`int`, *optional*, defaults to 0):
             If set to int > 0, all ngrams of that size can only occur once.
         bad_words_ids(`List[List[int]]`, *optional*):
-            List of token ids that are not allowed to be generated. In order to get the token ids of the words that
-            should not appear in the generated text, make sure to set `add_prefix_space=True` when initializing the
-            tokenizer, and use `tokenizer(bad_words, add_special_tokens=False).input_ids`. The `add_prefix_space`
-            argument is only supported for some slow tokenizers, as fast tokenizers' prefixing behaviours come from
-            `pre tokenizers`. Read more [here](https://huggingface.co/docs/tokenizers/api/pre-tokenizers).
+            List of list of token ids that are not allowed to be generated. Check
+            [`~generation.NoBadWordsLogitsProcessor`] for further documentation and examples.
         force_words_ids(`List[List[int]]` or `List[List[List[int]]]`, *optional*):
             List of token ids that must be generated. If given a `List[List[int]]`, this is treated as a simple list of
             words that must be included, the opposite to `bad_words_ids`. If given `List[List[List[int]]]`, this

@@ -183,6 +180,10 @@ class GenerationConfig(PushToHubMixin):
             A list of pairs of integers which indicates a mapping from generation indices to token indices that will be
             forced before sampling. For example, `[[1, 123]]` means the second generated token will always be a token
             of index 123.
+        sequence_bias (`Dict[Tuple[int], float]`, *optional*):
+            Dictionary that maps a sequence of tokens to its bias term. Positive biases increase the odds of the
+            sequence being selected, while negative biases do the opposite. Check
+            [`~generation.SequenceBiasLogitsProcessor`] for further documentation and examples.

         > Parameters that define the output variables of `generate`

@@ -262,6 +263,7 @@ class GenerationConfig(PushToHubMixin):
         self.suppress_tokens = kwargs.pop("suppress_tokens", None)
         self.begin_suppress_tokens = kwargs.pop("begin_suppress_tokens", None)
         self.forced_decoder_ids = kwargs.pop("forced_decoder_ids", None)
+        self.sequence_bias = kwargs.pop("sequence_bias", None)

         # Parameters that define the output variables of `generate`
         self.num_return_sequences = kwargs.pop("num_return_sequences", 1)
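Since `GenerationConfig` now pops `sequence_bias` from its kwargs like any other option, the bias can also be stored on a config object instead of being passed ad hoc. A small sketch (the token-id tuples are placeholders, not real vocabulary entries):

from transformers import GenerationConfig

# Keys are tuples of token ids, values are float biases (validated later by the processor).
generation_config = GenerationConfig(
    max_new_tokens=4,
    num_beams=4,
    sequence_bias={(40, 99): -10.0, (7,): 5.0},  # placeholder ids
)
# model.generate(input_ids, generation_config=generation_config) would then pick the bias up.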
src/transformers/generation/logits_process.py

@@ -15,7 +15,7 @@
 import inspect
 import math
-from typing import Callable, Iterable, List, Optional, Tuple, Union
+from typing import Callable, Dict, Iterable, List, Tuple, Union

 import numpy as np
 import torch

@@ -539,140 +539,218 @@ class EncoderNoRepeatNGramLogitsProcessor(LogitsProcessor):
         return scores

(The previous monolithic `NoBadWordsLogitsProcessor` implementation, with its `_calc_static_bad_word_mask`,
`_tokens_match`, `_calc_banned_bad_words_ids` and `_set_scores_to_inf_for_banned_tokens` helpers, is removed
in this hunk. In its place, the commit adds `SequenceBiasLogitsProcessor` and reimplements
`NoBadWordsLogitsProcessor` as a thin subclass of it. All code below is added.)

class SequenceBiasLogitsProcessor(LogitsProcessor):
    """
    [`LogitsProcessor`] that applies an additive bias on sequences. The bias is applied to the last token of a sequence
    when the next generated token can complete it. Consequently, to take the most of biasing sequences with more than
    one token, consider using beam methods (to gracefully work around partially completed sequences that have a
    negative bias) and applying the bias to their prefixes (to ensure the bias is applied earlier).

    <Tip>

    In order to get the token ids of the sequences that you want to bias, make sure to set `add_prefix_space=True` when
    initializing the tokenizer, and use `tokenizer(bad_words, add_special_tokens=False).input_ids`. The
    `add_prefix_space` argument is only supported for some slow tokenizers, as fast tokenizers' prefixing behaviours
    come from `pre tokenizers`. Read more [here](https://huggingface.co/docs/tokenizers/api/pre-tokenizers).

    </Tip>

    Args:
        sequence_bias (`Dict[Tuple[int], float]`):
            Dictionary that maps a sequence of tokens to its bias term. Positive biases increase the odds of the
            sequence being selected, while negative biases do the opposite. If a sequence has a length of 1, its bias
            will always be applied. Otherwise, the bias will only be applied if the sequence in question is about to be
            completed (in the token selection step after this processor is applied).

    Examples:

    ```python
    >>> from transformers import AutoTokenizer, AutoModelForCausalLM

    >>> model = AutoModelForCausalLM.from_pretrained("gpt2")
    >>> tokenizer = AutoTokenizer.from_pretrained("gpt2")
    >>> inputs = tokenizer(["The full name of Donald is Donald"], return_tensors="pt")

    >>> summary_ids = model.generate(inputs["input_ids"], max_new_tokens=4)
    >>> print(tokenizer.batch_decode(summary_ids, skip_special_tokens=True)[0])
    The full name of Donald is Donald J. Trump Jr

    >>> # Now let's control generation through a bias. Please note that the tokenizer is initialized differently!
    >>> tokenizer_with_prefix_space = AutoTokenizer.from_pretrained("gpt2", add_prefix_space=True)


    >>> def get_tokens_as_tuple(word):
    ...     return tuple(tokenizer_with_prefix_space([word], add_special_tokens=False).input_ids[0])


    >>> # If we add a negative bias without beam search, it may become "stuck" in a prefix without good continuations
    >>> sequence_bias = {get_tokens_as_tuple("Trump"): -10.0}
    >>> biased_ids = model.generate(inputs["input_ids"], max_new_tokens=4, sequence_bias=sequence_bias)
    >>> print(tokenizer.batch_decode(biased_ids, skip_special_tokens=True)[0])
    The full name of Donald is Donald J. Donald,

    >>> biased_ids = model.generate(inputs["input_ids"], max_new_tokens=4, num_beams=4, sequence_bias=sequence_bias)
    >>> print(tokenizer.batch_decode(biased_ids, skip_special_tokens=True)[0])
    The full name of Donald is Donald Rumsfeld,

    >>> # We can also add a positive bias to nudge the model towards specific tokens or continuations
    >>> sequence_bias = {get_tokens_as_tuple("Donald Duck"): 10.0}
    >>> biased_ids = model.generate(inputs["input_ids"], max_new_tokens=4, num_beams=4, sequence_bias=sequence_bias)
    >>> print(tokenizer.batch_decode(biased_ids, skip_special_tokens=True)[0])
    The full name of Donald is Donald Duck.
    ```
    """

    def __init__(self, sequence_bias: Dict[Tuple[int], float]):
        self.sequence_bias = sequence_bias
        self._validate_arguments()

        # Bias variables that will be populated on the first call (for retrocompatibility purposes, the vocabulary size
        # is infered in the first usage, which inhibits initializing here)
        self.sequences_length_greater_than_1 = []
        self.length_1_bias = None
        self.length_greather_than_1_bias = None
        self.prepared_bias_variables = False

    def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor) -> torch.FloatTensor:
        # 1 - Prepares the bias tensors. This is only needed the first time the logit processor is called.
        if not self.prepared_bias_variables:
            self._prepare_bias_variables(scores)

        # 2 - prepares an empty bias to add
        bias = torch.zeros_like(scores)

        # 3 - include the bias from length = 1
        bias += self.length_1_bias

        # 4 - include the bias from length > 1, after determining which biased sequences may be completed.
        # `matching_mask` is a (batch_size, vocab_size) boolean mask that is True for all tokens whose corresponding
        # bias should be applied. The bias is applied on the last token of the sequence, if (and only if) the sequence
        # may become complete this iteration.
        matching_mask = torch.zeros_like(scores, dtype=torch.bool)
        for sequence_ids in self.sequences_length_greater_than_1:
            if len(sequence_ids) > input_ids.shape[1]:  # the sequence is longer than the context, ignore
                continue
            prefix_length = len(sequence_ids) - 1
            last_token = sequence_ids[-1]
            matching_rows = torch.eq(
                input_ids[:, -prefix_length:],
                torch.tensor(sequence_ids[:-1], dtype=input_ids.dtype, device=input_ids.device),
            ).prod(dim=1)
            matching_mask[:, last_token] |= matching_rows.bool()
        bias += torch.where(matching_mask, self.length_greather_than_1_bias, 0.0)

        # 5 - apply the bias to the scores
        scores = scores + bias
        return scores

    def _prepare_bias_variables(self, scores: torch.FloatTensor):
        vocabulary_size = scores.shape[-1]
        sequence_bias = self.sequence_bias
        tokens_with_bias = []

        # Check biased tokens out of bounds
        invalid_biases = []
        for sequence_ids in sequence_bias:
            for token_id in sequence_ids:
                if token_id >= vocabulary_size:
                    invalid_biases.append(token_id)
        if len(invalid_biases) > 0:
            raise ValueError(
                f"The model vocabulary size is {vocabulary_size}, but the following tokens were being biased: "
                f"{invalid_biases}"
            )

        # Precompute the bias tensors to be applied. Sequences of length 1 are kept separately, as they can be applied
        # with simpler logic.
        self.length_1_bias = torch.zeros((vocabulary_size,), dtype=torch.float).to(scores.device)
        self.length_greather_than_1_bias = torch.zeros((vocabulary_size,), dtype=torch.float).to(scores.device)
        for sequence_ids, bias in sequence_bias.items():
            if len(sequence_ids) == 1:
                self.length_1_bias[sequence_ids[-1]] = bias
            else:
                self.sequences_length_greater_than_1.append(sequence_ids)
                if self.length_greather_than_1_bias[sequence_ids[-1]] != 0.0:
                    raise ValueError(
                        "Setting a bias on sequences that share a common token termination is not yet supported. "
                        "Please open an issue if you see this error message (after checking that it doesn't already "
                        "exist)."
                    )
                self.length_greather_than_1_bias[sequence_ids[-1]] = bias
            tokens_with_bias.append(sequence_ids[-1])

        self.prepared_bias_variables = True

    def _validate_arguments(self):
        sequence_bias = self.sequence_bias
        if not isinstance(sequence_bias, dict) or len(sequence_bias) == 0:
            raise ValueError(f"`sequence_bias` has to be a non-empty dictionary, but is {sequence_bias}.")
        if any(not isinstance(sequence_ids, tuple) for sequence_ids in sequence_bias.keys()):
            raise ValueError(f"`sequence_bias` has to be a dict with tuples as keys, but is {sequence_bias}.")
        if any(
            any((not isinstance(token_id, (int, np.integer)) or token_id < 0) for token_id in sequence_ids)
            or len(sequence_ids) == 0
            for sequence_ids in sequence_bias.keys()
        ):
            raise ValueError(
                f"Each key in `sequence_bias` has to be a non-empty tuple of positive integers, but is "
                f"{sequence_bias}."
            )
        if any(not isinstance(bias, float) for bias in sequence_bias.values()):
            raise ValueError(f"`sequence_bias` has to be a dict with floats as values, but is {sequence_bias}.")


class NoBadWordsLogitsProcessor(SequenceBiasLogitsProcessor):
    """
    [`LogitsProcessor`] that enforces that specified sequences will never be selected.

    <Tip>

    In order to get the token ids of the words that should not appear in the generated text, make sure to set
    `add_prefix_space=True` when initializing the tokenizer, and use `tokenizer(bad_words,
    add_special_tokens=False).input_ids`. The `add_prefix_space` argument is only supported for some slow tokenizers,
    as fast tokenizers' prefixing behaviours come from `pre tokenizers`. Read more
    [here](https://huggingface.co/docs/tokenizers/api/pre-tokenizers).

    </Tip>

    Args:
        bad_words_ids (`List[List[int]]`):
            List of list of token ids that are not allowed to be generated.
        eos_token_id (`Union[int, List[int]]`):
            The id of the *end-of-sequence* token. Optionally, use a list to set multiple *end-of-sequence* tokens.
    """

    def __init__(self, bad_words_ids: List[List[int]], eos_token_id: Union[int, List[int]]):
        self.bad_word_ids = bad_words_ids
        self._validate_arguments()

        # Filter EOS token from bad_words_ids
        if eos_token_id is None:
            eos_token_id = []
        if isinstance(eos_token_id, int):
            eos_token_id = [eos_token_id]
        bad_words_ids = list(
            filter(lambda bad_token_seq: all([bad_token_seq != [i] for i in eos_token_id]), bad_words_ids)
        )

        # Forbidding a sequence is equivalent to setting its bias to -inf
        sequence_bias = {tuple(sequence): float("-inf") for sequence in bad_words_ids}
        super().__init__(sequence_bias=sequence_bias)

    def _validate_arguments(self):
        bad_words_ids = self.bad_word_ids
        if not isinstance(bad_words_ids, list) or len(bad_words_ids) == 0:
            raise ValueError(f"`bad_words_ids` has to be a non-empty list, but is {bad_words_ids}.")
        if any(not isinstance(bad_word_ids, list) for bad_word_ids in bad_words_ids):
            raise ValueError(f"`bad_words_ids` has to be a list of lists, but is {bad_words_ids}.")
        if any(
            any((not isinstance(token_id, (int, np.integer)) or token_id < 0) for token_id in bad_word_ids)
            for bad_word_ids in bad_words_ids
        ):
            raise ValueError(
                f"Each list in `bad_words_ids` has to be a list of positive integers, but is {bad_words_ids}."
            )


class PrefixConstrainedLogitsProcessor(LogitsProcessor):
    ...
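To see the mechanics in isolation, the new processor can be called directly on a toy scores tensor (this mirrors the test added further below; the values here are illustrative):

import torch

from transformers import SequenceBiasLogitsProcessor

# Two sequences of 4 tokens each, over a 5-token vocabulary; all logits start at zero.
input_ids = torch.tensor([[0, 1, 3, 1], [0, 1, 0, 1]], dtype=torch.long)
scores = torch.zeros((2, 5), dtype=torch.float)

# A length-1 key is always applied; a longer key only biases its last token when the
# preceding context matches the rest of the key.
processor = SequenceBiasLogitsProcessor(sequence_bias={(1,): 100.0, (1, 0): -100.0})
biased = processor(input_ids, scores)
print(biased)  # column 1 is +100 everywhere; column 0 is -100 because both rows end in token 1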
src/transformers/generation/utils.py

@@ -56,6 +56,7 @@ from .logits_process import (
     NoRepeatNGramLogitsProcessor,
     PrefixConstrainedLogitsProcessor,
     RepetitionPenaltyLogitsProcessor,
+    SequenceBiasLogitsProcessor,
     SuppressTokensAtBeginLogitsProcessor,
     SuppressTokensLogitsProcessor,
     TemperatureLogitsWarper,

@@ -842,8 +843,9 @@ class GenerationMixin:
         # instantiate processors list
         processors = LogitsProcessorList()

-        # the following idea is largely copied from this PR: https://github.com/huggingface/transformers/pull/5420/files
-        # all samplers can be found in `generation_utils_samplers.py`
+        if generation_config.sequence_bias is not None:
+            processors.append(SequenceBiasLogitsProcessor(sequence_bias=generation_config.sequence_bias))
+
         if generation_config.diversity_penalty is not None and generation_config.diversity_penalty > 0.0:
             processors.append(
                 HammingDiversityLogitsProcessor(
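The hook above only fires when `generation_config.sequence_bias` is set, but the same processor can also be supplied explicitly through `generate`'s `logits_processor` argument. A hedged sketch of that route (the biased token id is a placeholder):

from transformers import (
    AutoModelForCausalLM,
    AutoTokenizer,
    LogitsProcessorList,
    SequenceBiasLogitsProcessor,
)

model = AutoModelForCausalLM.from_pretrained("gpt2")
tokenizer = AutoTokenizer.from_pretrained("gpt2")
inputs = tokenizer("The full name of Donald is Donald", return_tensors="pt")

# Custom processors passed this way are merged with those derived from the generation config.
custom = LogitsProcessorList([SequenceBiasLogitsProcessor(sequence_bias={(1212,): -10.0})])  # placeholder id
out = model.generate(inputs["input_ids"], max_new_tokens=4, logits_processor=custom)
print(tokenizer.batch_decode(out, skip_special_tokens=True)[0])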
src/transformers/utils/dummy_pt_objects.py

@@ -240,6 +240,13 @@ class RepetitionPenaltyLogitsProcessor(metaclass=DummyObject):
         requires_backends(self, ["torch"])


+class SequenceBiasLogitsProcessor(metaclass=DummyObject):
+    _backends = ["torch"]
+
+    def __init__(self, *args, **kwargs):
+        requires_backends(self, ["torch"])
+
+
 class StoppingCriteria(metaclass=DummyObject):
     _backends = ["torch"]
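For context, the dummy object above only exists so that `from transformers import SequenceBiasLogitsProcessor` keeps resolving in environments without PyTorch and fails with an actionable message on use. Roughly, as an illustrative sketch of the pattern (not code from the repository):

# Illustrative sketch only: in a torch-less environment, the dummy stands in for the real class
# and raises as soon as it is instantiated.
class SequenceBiasLogitsProcessor:
    def __init__(self, *args, **kwargs):
        raise ImportError(
            "SequenceBiasLogitsProcessor requires the PyTorch library, but it was not found in your environment."
        )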
tests/generation/test_logits_process.py

@@ -46,6 +46,7 @@ if is_torch_available():
     NoRepeatNGramLogitsProcessor,
     PrefixConstrainedLogitsProcessor,
     RepetitionPenaltyLogitsProcessor,
+    SequenceBiasLogitsProcessor,
     TemperatureLogitsWarper,
     TopKLogitsWarper,
     TopPLogitsWarper,

@@ -512,6 +513,27 @@ class LogitsProcessorTest(unittest.TestCase):
         filtered_scores = no_bad_words_dist_proc(input_ids, scores.clone())
         self.assertTrue(torch.allclose(scores, filtered_scores, atol=1e-3))

+    def test_bias_dist_processor(self):
+        vocab_size = 5
+        batch_size = 2
+
+        input_ids = torch.tensor([[0, 1, 3, 1], [0, 1, 0, 1]], device=torch_device, dtype=torch.long)
+        positive_bias = {(1,): 100.0, (4,): 100.0}
+        negative_bias = {(1, 0): -100.0, (0, 1, 2): -100.0, (1, 3, 1, 3): -100.0}
+        sequence_bias = {**positive_bias, **negative_bias}
+
+        # scores = 0 to facilitate checks
+        scores = torch.zeros((batch_size, vocab_size), dtype=torch.float, device=torch_device)
+
+        bias_dist_proc = SequenceBiasLogitsProcessor(sequence_bias=sequence_bias)
+        filtered_scores = bias_dist_proc(input_ids, scores.clone())
+
+        # batch 1: positive bias: tokens (1, 4); negative bias: tokens (0, 3); neutral: tokens (2)
+        # batch 2: positive bias: tokens (1, 4); negative bias: tokens (0, 2); neutral: tokens (3)
+        self.assertListEqual(
+            filtered_scores.tolist(), [[-100.0, 100.0, 0.0, -100.0, 100.0], [-100.0, 100.0, -100.0, 0.0, 100.0]]
+        )
+
     def test_processor_list(self):
         batch_size = 4
         sequence_length = 10