Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
chenpangpang
transformers
Commits
9efec114
Unverified
Commit
9efec114
authored
Jan 19, 2024
by
Ofir Zafrir
Committed by
GitHub
Jan 19, 2024
Browse files
Fix `_speculative_sampling` implementation (#28508)
parent
d1578159
Changes
3
Hide whitespace changes
Inline
Side-by-side
Showing
3 changed files
with
70 additions
and
22 deletions
+70
-22
src/transformers/generation/candidate_generator.py
src/transformers/generation/candidate_generator.py
+7
-3
src/transformers/generation/utils.py
src/transformers/generation/utils.py
+25
-19
tests/generation/test_utils.py
tests/generation/test_utils.py
+38
-0
No files found.
src/transformers/generation/candidate_generator.py
View file @
9efec114
...
@@ -171,12 +171,16 @@ class AssistedCandidateGenerator(CandidateGenerator):
...
@@ -171,12 +171,16 @@ class AssistedCandidateGenerator(CandidateGenerator):
"""
"""
input_ids
=
input_ids
.
to
(
self
.
assistant_model
.
device
)
input_ids
=
input_ids
.
to
(
self
.
assistant_model
.
device
)
# Don't generate more than `max_length - 1` candidates since the target model generates one extra token.
new_cur_len
=
input_ids
.
shape
[
-
1
]
max_new_tokens
=
min
(
int
(
self
.
num_assistant_tokens
),
self
.
generation_config
.
max_length
-
new_cur_len
-
1
)
if
max_new_tokens
==
0
:
return
input_ids
,
None
# 1. If it is not the first round of candidate generation, prepare the inputs based on the input_ids length
# 1. If it is not the first round of candidate generation, prepare the inputs based on the input_ids length
# (which implicitly contains the number of accepted candidates from the previous round)
# (which implicitly contains the number of accepted candidates from the previous round)
has_past_key_values
=
self
.
assistant_kwargs
.
get
(
"past_key_values"
,
None
)
is
not
None
has_past_key_values
=
self
.
assistant_kwargs
.
get
(
"past_key_values"
,
None
)
is
not
None
if
has_past_key_values
:
if
has_past_key_values
:
new_cur_len
=
input_ids
.
shape
[
-
1
]
new_cache_size
=
new_cur_len
-
1
new_cache_size
=
new_cur_len
-
1
self
.
assistant_kwargs
[
"past_key_values"
]
=
_crop_past_key_values
(
self
.
assistant_kwargs
[
"past_key_values"
]
=
_crop_past_key_values
(
self
.
assistant_model
,
self
.
assistant_kwargs
[
"past_key_values"
],
new_cache_size
-
1
self
.
assistant_model
,
self
.
assistant_kwargs
[
"past_key_values"
],
new_cache_size
-
1
...
@@ -190,7 +194,7 @@ class AssistedCandidateGenerator(CandidateGenerator):
...
@@ -190,7 +194,7 @@ class AssistedCandidateGenerator(CandidateGenerator):
# 2. Forecast next N tokens using the assistant model.
# 2. Forecast next N tokens using the assistant model.
assistant_generation_kwargs
=
{
assistant_generation_kwargs
=
{
self
.
input_ids_key
:
input_ids
,
self
.
input_ids_key
:
input_ids
,
"max_new_tokens"
:
int
(
self
.
num_assistant
_tokens
)
,
"max_new_tokens"
:
max_new
_tokens
,
"generation_config"
:
self
.
generation_config
,
"generation_config"
:
self
.
generation_config
,
"logits_processor"
:
self
.
logits_processor
,
"logits_processor"
:
self
.
logits_processor
,
}
}
...
...
src/transformers/generation/utils.py
View file @
9efec114
...
@@ -4404,7 +4404,7 @@ class GenerationMixin:
...
@@ -4404,7 +4404,7 @@ class GenerationMixin:
else
:
else
:
selected_tokens
=
new_logits
.
argmax
(
dim
=-
1
)
selected_tokens
=
new_logits
.
argmax
(
dim
=-
1
)
candidate_new_tokens
=
candidate_input_ids
[:,
-
candidate
_len
gth
:]
candidate_new_tokens
=
candidate_input_ids
[:,
cur
_len
:]
n_matches
=
((
~
(
candidate_new_tokens
==
selected_tokens
[:,
:
-
1
])).
cumsum
(
dim
=-
1
)
<
1
).
sum
()
n_matches
=
((
~
(
candidate_new_tokens
==
selected_tokens
[:,
:
-
1
])).
cumsum
(
dim
=-
1
)
<
1
).
sum
()
# Ensure we don't generate beyond max_len or an EOS token
# Ensure we don't generate beyond max_len or an EOS token
...
@@ -4540,12 +4540,13 @@ def _speculative_sampling(
...
@@ -4540,12 +4540,13 @@ def _speculative_sampling(
NOTE: Unless otherwise stated, the variable names match those in the paper.
NOTE: Unless otherwise stated, the variable names match those in the paper.
"""
"""
new_candidate_input_ids
=
candidate_input_ids
[:,
-
candidate_length
:]
# Gets the probabilities from the logits. q_i and p_i denote the assistant and model probabilities of the tokens
# Gets the probabilities from the logits. q_i and p_i denote the assistant and model probabilities of the tokens
# selected by the assistant, respectively.
# selected by the assistant, respectively.
q
=
candidate_logits
.
softmax
(
dim
=-
1
)
q
=
candidate_logits
.
softmax
(
dim
=-
1
)
q_i
=
q
[:,
torch
.
arange
(
candidate_length
),
candidate_input_ids
[:,
-
candidate_length
:]
].
squeeze
(
0
,
1
)
q_i
=
q
[:,
torch
.
arange
(
candidate_length
),
new_
candidate_input_ids
].
squeeze
(
0
,
1
)
p
=
new_logits
.
softmax
(
dim
=-
1
)
p
=
new_logits
.
softmax
(
dim
=-
1
)
p_i
=
p
[:,
torch
.
arange
(
candidate_length
),
candidate_input_ids
[:,
-
candidate_length
:]
].
squeeze
(
0
,
1
)
p_i
=
p
[:,
torch
.
arange
(
candidate_length
),
new_
candidate_input_ids
].
squeeze
(
0
,
1
)
probability_ratio
=
p_i
/
q_i
probability_ratio
=
p_i
/
q_i
# When probability_ratio > 1 (i.e. q_i(x) < p_i(x), or "assistant probability of the candidate token is smaller
# When probability_ratio > 1 (i.e. q_i(x) < p_i(x), or "assistant probability of the candidate token is smaller
...
@@ -4553,28 +4554,33 @@ def _speculative_sampling(
...
@@ -4553,28 +4554,33 @@ def _speculative_sampling(
# (= keep with p = probability_ratio). Keep all the tokens until the first rejection
# (= keep with p = probability_ratio). Keep all the tokens until the first rejection
r_i
=
torch
.
rand_like
(
probability_ratio
)
r_i
=
torch
.
rand_like
(
probability_ratio
)
is_accepted
=
r_i
<=
probability_ratio
is_accepted
=
r_i
<=
probability_ratio
n_matches
=
(
~
is_accepted
.
cumsum
(
dim
=-
1
)
<
1
).
sum
()
# this is `n` in algorithm 1
n_matches
=
(
(
~
is_accepted
)
.
cumsum
(
dim
=-
1
)
<
1
).
sum
()
# this is `n` in algorithm 1
# Ensure we don't generate beyond max_len or an EOS token (not in algorithm 1, but needed for correct behavior)
# Ensure we don't generate beyond max_len or an EOS token (not in algorithm 1, but needed for correct behavior)
if
last_assistant_token_is_eos
and
n_matches
==
candidate_length
:
if
last_assistant_token_is_eos
and
n_matches
==
candidate_length
:
# Output length is assumed to be `n_matches + 1`. Since we won't generate another token with the target model
# due to acceptance on EOS we fix `n_matches`
n_matches
-=
1
n_matches
-=
1
n_matches
=
min
(
n_matches
,
max_matches
)
valid_tokens
=
new_candidate_input_ids
[:,
:
n_matches
+
1
]
# Next token selection: if there is a rejection, adjust the distribution from the main model before sampling.
gamma
=
candidate_logits
.
shape
[
1
]
p_n_plus_1
=
p
[:,
n_matches
,
:]
if
n_matches
<
gamma
:
q_n_plus_1
=
q
[:,
n_matches
,
:]
p_prime
=
torch
.
clamp
((
p_n_plus_1
-
q_n_plus_1
),
min
=
0
).
softmax
(
dim
=-
1
)
else
:
else
:
p_prime
=
p_n_plus_1
n_matches
=
min
(
n_matches
,
max_matches
)
t
=
torch
.
multinomial
(
p_prime
,
num_samples
=
1
).
squeeze
(
1
)[
None
,
:]
# Next token selection: if there is a rejection, adjust the distribution from the main model before sampling.
gamma
=
min
(
candidate_logits
.
shape
[
1
],
max_matches
)
p_n_plus_1
=
p
[:,
n_matches
,
:]
if
n_matches
<
gamma
:
q_n_plus_1
=
q
[:,
n_matches
,
:]
p_prime
=
torch
.
clamp
((
p_n_plus_1
-
q_n_plus_1
),
min
=
0
)
p_prime
.
div_
(
p_prime
.
sum
())
else
:
p_prime
=
p_n_plus_1
t
=
torch
.
multinomial
(
p_prime
,
num_samples
=
1
).
squeeze
(
1
)[
None
,
:]
# The selected tokens include the matches (if any) plus the next sampled tokens
# The selected tokens include the matches (if any) plus the next sampled tokens
if
n_matches
>
0
:
if
n_matches
>
0
:
valid_tokens
=
torch
.
cat
((
candidate_input_ids
[:,
-
n_matches
:
],
t
),
dim
=-
1
)
valid_tokens
=
torch
.
cat
((
new_
candidate_input_ids
[:,
:
n_matches
],
t
),
dim
=-
1
)
else
:
else
:
valid_tokens
=
t
valid_tokens
=
t
return
valid_tokens
,
n_matches
return
valid_tokens
,
n_matches
...
...
tests/generation/test_utils.py
View file @
9efec114
...
@@ -88,6 +88,7 @@ if is_torch_available():
...
@@ -88,6 +88,7 @@ if is_torch_available():
TopKLogitsWarper
,
TopKLogitsWarper
,
TopPLogitsWarper
,
TopPLogitsWarper
,
)
)
from
transformers.generation.utils
import
_speculative_sampling
class
GenerationTesterMixin
:
class
GenerationTesterMixin
:
...
@@ -2424,6 +2425,43 @@ class UtilsFunctionsTest(unittest.TestCase):
...
@@ -2424,6 +2425,43 @@ class UtilsFunctionsTest(unittest.TestCase):
self
.
assertTrue
(
torch
.
allclose
(
expected_output
,
output
,
atol
=
1e-12
))
self
.
assertTrue
(
torch
.
allclose
(
expected_output
,
output
,
atol
=
1e-12
))
def
test_speculative_sampling
(
self
):
# assume vocab size 10, input length 5 + 3 generated candidates
candidate_input_ids
=
torch
.
tensor
([[
8
,
0
,
3
,
9
,
8
,
1
,
4
,
5
]])
# input tokens
candidate_logits
=
torch
.
tensor
(
[
[
[
-
10.0
,
10.0
,
-
10.0
,
-
10.0
,
-
10.0
,
-
10.0
,
-
10.0
,
-
10.0
,
-
10.0
,
-
10.0
],
# generated 1
[
-
10.0
,
-
10.0
,
-
10.0
,
-
10.0
,
10.0
,
-
10.0
,
-
10.0
,
-
10.0
,
-
10.0
,
-
10.0
],
# generated 4
[
-
10.0
,
-
10.0
,
-
10.0
,
-
10.0
,
-
10.0
,
10.0
,
-
10.0
,
-
10.0
,
-
10.0
,
-
10.0
],
# generated 5
]
]
)
candidate_length
=
3
inf
=
float
(
"inf"
)
new_logits
=
torch
.
tensor
(
[
[
[
-
10.0
,
10.0
,
-
10.0
,
-
10.0
,
-
10.0
,
-
10.0
,
-
10.0
,
-
10.0
,
-
10.0
,
-
10.0
],
# accepts 1
[
-
10.0
,
-
10.0
,
-
10.0
,
-
10.0
,
10.0
,
-
10.0
,
-
10.0
,
-
10.0
,
-
10.0
,
-
10.0
],
# accepts 4
[
-
inf
,
-
inf
,
-
inf
,
-
inf
,
-
inf
,
-
inf
,
-
inf
,
-
inf
,
10.0
,
-
inf
],
# rejects 5, accepts 8
[
-
10.0
,
-
10.0
,
-
10.0
,
-
10.0
,
-
10.0
,
-
10.0
,
-
10.0
,
-
10.0
,
-
10.0
,
-
10.0
],
# N/A
]
]
)
last_assistant_token_is_eos
=
False
max_matches
=
5
validated_tokens
,
n_matches
=
_speculative_sampling
(
candidate_input_ids
,
candidate_logits
,
candidate_length
,
new_logits
,
last_assistant_token_is_eos
,
max_matches
,
)
self
.
assertTrue
(
n_matches
.
item
()
==
2
)
self
.
assertTrue
(
validated_tokens
.
tolist
()[
0
]
==
[
1
,
4
,
8
])
@
require_torch
@
require_torch
class
GenerationIntegrationTests
(
unittest
.
TestCase
,
GenerationIntegrationTestsMixin
):
class
GenerationIntegrationTests
(
unittest
.
TestCase
,
GenerationIntegrationTestsMixin
):
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment