Commit 89136ff7 (Unverified)
Authored Jul 20, 2023 by Joao Gante, committed by GitHub on Jul 20, 2023

Generate: sequence bias can handle same terminations (#24822)

Parent: 37d8611a
Showing 2 changed files with 11 additions and 28 deletions (+11 -28)
src/transformers/generation/logits_process.py  +8 -28
tests/generation/test_logits_process.py        +3 -0
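For context: SequenceBiasLogitsProcessor backs the sequence_bias argument of generate(). Below is a minimal sketch of the case this commit unblocks, two biased sequences ending in the same token, assuming a standard transformers/torch install and using "gpt2" purely as a placeholder checkpoint; the phrases and bias values are illustrative only.

# Sketch only: the two biased phrases below end in the same token id (" Duck"), which previously
# raised a ValueError inside SequenceBiasLogitsProcessor and now simply works.
from transformers import AutoModelForCausalLM, AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("gpt2")
model = AutoModelForCausalLM.from_pretrained("gpt2")
inputs = tokenizer("The full name of Donald is Donald", return_tensors="pt")

def token_ids(text):
    # illustrative helper (not part of the diff): token ids of `text` as it appears mid-sentence
    return tuple(tokenizer(text, add_special_tokens=False).input_ids)

sequence_bias = {
    token_ids(" Donald Duck"): -10.0,
    token_ids(" Daffy Duck"): -10.0,  # shares its termination token with the entry above
}
output = model.generate(**inputs, max_new_tokens=4, sequence_bias=sequence_bias)
print(tokenizer.decode(output[0], skip_special_tokens=True))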
src/transformers/generation/logits_process.py (view file @ 89136ff7)
@@ -624,9 +624,7 @@ class SequenceBiasLogitsProcessor(LogitsProcessor):
         # Bias variables that will be populated on the first call (for retrocompatibility purposes, the vocabulary size
         # is infered in the first usage, which inhibits initializing here)
-        self.sequences_length_greater_than_1 = []
         self.length_1_bias = None
-        self.length_greather_than_1_bias = None
         self.prepared_bias_variables = False

     @add_start_docstrings(LOGITS_PROCESSOR_INPUTS_DOCSTRING)
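Aside from the two deleted attributes, the lazy initialization pattern is unchanged: the bias tensors only materialize on the first call, because the vocabulary size is read from the scores tensor. A small sketch of that behavior, with arbitrary token ids and an arbitrary vocabulary size of 10:

# Sketch only: length_1_bias stays None until the first call supplies the vocabulary size.
import torch
from transformers.generation.logits_process import SequenceBiasLogitsProcessor

processor = SequenceBiasLogitsProcessor(sequence_bias={(5,): 2.0, (3, 7): -1.0})
print(processor.length_1_bias)        # None, nothing is built in __init__

input_ids = torch.tensor([[1, 3]])    # (batch_size=1, seq_len=2)
scores = torch.zeros((1, 10))         # (batch_size=1, vocab_size=10)
_ = processor(input_ids, scores)

print(processor.length_1_bias.shape)  # torch.Size([10]), inferred from scores.shape[-1]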
@@ -642,11 +640,9 @@ class SequenceBiasLogitsProcessor(LogitsProcessor):
         bias += self.length_1_bias

         # 4 - include the bias from length > 1, after determining which biased sequences may be completed.
-        # `matching_mask` is a (batch_size, vocab_size) boolean mask that is True for all tokens whose corresponding
-        # bias should be applied. The bias is applied on the last token of the sequence, if (and only if) the sequence
-        # may become complete this iteration.
-        matching_mask = torch.zeros_like(scores, dtype=torch.bool)
-        for sequence_ids in self.sequences_length_greater_than_1:
+        for sequence_ids, sequence_bias in self.sequence_bias.items():
+            if len(sequence_ids) == 1:  # the sequence is of length 1, already applied
+                continue
             if len(sequence_ids) > input_ids.shape[1]:  # the sequence is longer than the context, ignore
                 continue
             prefix_length = len(sequence_ids) - 1
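To illustrate the two continue branches above: length-1 sequences were already folded into length_1_bias just above, and a biased sequence longer than the current context cannot be completing on this step, so it contributes nothing yet. A hedged sketch with arbitrary ids and a vocabulary size of 5:

# Sketch only: a 4-token biased sequence is skipped while the context is still too short.
import torch
from transformers.generation.logits_process import SequenceBiasLogitsProcessor

processor = SequenceBiasLogitsProcessor(sequence_bias={(1, 3, 1, 3): -100.0})
scores = torch.zeros((1, 5))

short_context = torch.tensor([[0, 1]])       # seq_len=2 < 4, the sequence is ignored
print(processor(short_context, scores))      # all zeros

long_context = torch.tensor([[0, 1, 3, 1]])  # ends with the prefix (1, 3, 1)
print(processor(long_context, scores))       # -100.0 on column 3, the sequence's last token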
@@ -655,12 +651,9 @@ class SequenceBiasLogitsProcessor(LogitsProcessor):
                 input_ids[:, -prefix_length:],
                 torch.tensor(sequence_ids[:-1], dtype=input_ids.dtype, device=input_ids.device),
             ).prod(dim=1)
-            matching_mask[:, last_token] |= matching_rows.bool()
-        bias += torch.where(
-            matching_mask,
-            self.length_greather_than_1_bias,
-            torch.tensor(0.0, device=self.length_greather_than_1_bias.device),
-        )
+            bias[:, last_token] += torch.where(
+                matching_rows.bool(),
+                sequence_bias,
+                torch.tensor(0.0, device=input_ids.device),
+            )

         # 5 - apply the bias to the scores
         scores = scores + bias
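The replacement applies each sequence's own bias directly to its termination token, so two sequences ending in the same token accumulate on that column instead of colliding in one precomputed per-token tensor. A standalone sketch of the masking arithmetic, with arbitrary shapes and values:

# Standalone sketch of the new accumulation step, outside the processor; ids and values are arbitrary.
import torch

input_ids = torch.tensor([[0, 1, 3, 1], [0, 1, 0, 1]])  # (batch_size=2, seq_len=4)
bias = torch.zeros((2, 5))                               # (batch_size, vocab_size)

for sequence_ids, sequence_bias in {(1, 3): -100.0, (0, 1, 3): -7.0}.items():
    prefix_length = len(sequence_ids) - 1
    last_token = sequence_ids[-1]
    # rows whose most recent tokens match the sequence prefix may complete it on this step
    matching_rows = torch.eq(
        input_ids[:, -prefix_length:],
        torch.tensor(sequence_ids[:-1], dtype=input_ids.dtype),
    ).prod(dim=1)
    # both example sequences end in token 3; their contributions simply add up per row
    bias[:, last_token] += torch.where(matching_rows.bool(), sequence_bias, torch.tensor(0.0))

print(bias)  # row 0: -100.0 at column 3; row 1: -107.0 at column 3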
@@ -668,12 +661,10 @@ class SequenceBiasLogitsProcessor(LogitsProcessor):
     def _prepare_bias_variables(self, scores: torch.FloatTensor):
         vocabulary_size = scores.shape[-1]
-        sequence_bias = self.sequence_bias
-        tokens_with_bias = []

         # Check biased tokens out of bounds
         invalid_biases = []
-        for sequence_ids in sequence_bias:
+        for sequence_ids in self.sequence_bias:
             for token_id in sequence_ids:
                 if token_id >= vocabulary_size:
                     invalid_biases.append(token_id)
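The only change in this hunk is dropping the local alias and the unused tokens_with_bias list; the out-of-bounds check itself is untouched. For reference, a hedged sketch of what that check catches on the first call (the exact error message lives just below this hunk and is not reproduced here):

# Sketch only: a biased token id outside the vocabulary is collected by the loop above and
# rejected with a ValueError raised a few lines below this hunk.
import torch
from transformers.generation.logits_process import SequenceBiasLogitsProcessor

processor = SequenceBiasLogitsProcessor(sequence_bias={(42,): 5.0})  # 42 >= vocab_size used below

input_ids = torch.tensor([[0, 1]])
scores = torch.zeros((1, 10))  # vocab_size = 10

try:
    processor(input_ids, scores)
except ValueError as err:
    print(err)  # reports the out-of-vocabulary biased token id(s)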
@@ -686,20 +677,9 @@ class SequenceBiasLogitsProcessor(LogitsProcessor):
         # Precompute the bias tensors to be applied. Sequences of length 1 are kept separately, as they can be applied
         # with simpler logic.
         self.length_1_bias = torch.zeros((vocabulary_size,), dtype=torch.float).to(scores.device)
-        self.length_greather_than_1_bias = torch.zeros((vocabulary_size,), dtype=torch.float).to(scores.device)
-        for sequence_ids, bias in sequence_bias.items():
+        for sequence_ids, bias in self.sequence_bias.items():
             if len(sequence_ids) == 1:
                 self.length_1_bias[sequence_ids[-1]] = bias
-            else:
-                self.sequences_length_greater_than_1.append(sequence_ids)
-                if self.length_greather_than_1_bias[sequence_ids[-1]] != 0.0:
-                    raise ValueError(
-                        "Setting a bias on sequences that share a common token termination is not yet supported. "
-                        "Please open an issue if you see this error message (after checking that it doesn't already "
-                        "exist)."
-                    )
-                self.length_greather_than_1_bias[sequence_ids[-1]] = bias
-            tokens_with_bias.append(sequence_ids[-1])

         self.prepared_bias_variables = True
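This hunk is the core of the fix: the shared length_greather_than_1_bias tensor and its "common token termination" ValueError are gone, so bias dictionaries whose multi-token sequences end in the same token are now accepted. A small sketch with arbitrary ids:

# Sketch only: two multi-token sequences sharing termination token 3. Before this commit the
# first call raised the "common token termination" ValueError removed above; now both biases
# are applied per sequence and accumulate on token 3 wherever their prefixes match.
import torch
from transformers.generation.logits_process import SequenceBiasLogitsProcessor

processor = SequenceBiasLogitsProcessor(sequence_bias={(1, 3): -100.0, (0, 1, 3): -100.0})

input_ids = torch.tensor([[0, 1, 0, 1]])
scores = torch.zeros((1, 5))
print(processor(input_ids, scores))  # no error; column 3 receives both biases (total -200.0)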
tests/generation/test_logits_process.py (view file @ 89136ff7)
@@ -520,6 +520,9 @@ class LogitsProcessorTest(unittest.TestCase):
         input_ids = torch.tensor([[0, 1, 3, 1], [0, 1, 0, 1]], device=torch_device, dtype=torch.long)

         positive_bias = {(1,): 100.0, (4,): 100.0}
         negative_bias = {(1, 0): -100.0, (0, 1, 2): -100.0, (1, 3, 1, 3): -100.0}
+        # biases the same termination twice, to ensure we can handle overlapping terminations (it won't have an effect
+        # on the test cases, though)
+        negative_bias.update({(1, 3, 1, 3, 1, 3): -100.0})
         sequence_bias = {**positive_bias, **negative_bias}

         # scores = 0 to facilitate checks
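Mirroring the updated test setup outside unittest makes the expected effect easy to eyeball: with zero scores, the processor's output equals the applied bias. A hedged sketch on CPU (torch_device from the test is replaced by plain CPU tensors):

# Sketch mirroring the test: with scores = 0, the returned scores are exactly the bias pattern.
import torch
from transformers.generation.logits_process import SequenceBiasLogitsProcessor

input_ids = torch.tensor([[0, 1, 3, 1], [0, 1, 0, 1]], dtype=torch.long)

positive_bias = {(1,): 100.0, (4,): 100.0}
negative_bias = {(1, 0): -100.0, (0, 1, 2): -100.0, (1, 3, 1, 3): -100.0}
# the overlapping termination added by this commit; harmless here because it is longer than the context
negative_bias.update({(1, 3, 1, 3, 1, 3): -100.0})
sequence_bias = {**positive_bias, **negative_bias}

scores = torch.zeros((2, 5))  # scores = 0 to facilitate checks
processor = SequenceBiasLogitsProcessor(sequence_bias=sequence_bias)
print(processor(input_ids, scores))
# expected:
# tensor([[-100.,  100.,    0., -100.,  100.],
#         [-100.,  100., -100.,    0.,  100.]])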