chenpangpang / transformers

Commit 9efec114 (unverified)
Fix `_speculative_sampling` implementation (#28508)

Authored Jan 19, 2024 by Ofir Zafrir; committed by GitHub on Jan 19, 2024
Parent commit: d1578159
Showing 3 changed files with 70 additions and 22 deletions (+70 / -22):

    src/transformers/generation/candidate_generator.py    +7   -3
    src/transformers/generation/utils.py                  +25  -19
    tests/generation/test_utils.py                        +38   -0
src/transformers/generation/candidate_generator.py
@@ -171,12 +171,16 @@ class AssistedCandidateGenerator(CandidateGenerator):
         """
         input_ids = input_ids.to(self.assistant_model.device)
 
+        # Don't generate more than `max_length - 1` candidates since the target model generates one extra token.
+        new_cur_len = input_ids.shape[-1]
+        max_new_tokens = min(int(self.num_assistant_tokens), self.generation_config.max_length - new_cur_len - 1)
+        if max_new_tokens == 0:
+            return input_ids, None
+
         # 1. If it is not the first round of candidate generation, prepare the inputs based on the input_ids length
         # (which implicitly contains the number of accepted candidates from the previous round)
         has_past_key_values = self.assistant_kwargs.get("past_key_values", None) is not None
         if has_past_key_values:
-            new_cur_len = input_ids.shape[-1]
-
             new_cache_size = new_cur_len - 1
             self.assistant_kwargs["past_key_values"] = _crop_past_key_values(
                 self.assistant_model, self.assistant_kwargs["past_key_values"], new_cache_size - 1
@@ -190,7 +194,7 @@ class AssistedCandidateGenerator(CandidateGenerator):
         # 2. Forecast next N tokens using the assistant model.
         assistant_generation_kwargs = {
             self.input_ids_key: input_ids,
-            "max_new_tokens": int(self.num_assistant_tokens),
+            "max_new_tokens": max_new_tokens,
             "generation_config": self.generation_config,
             "logits_processor": self.logits_processor,
         }
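Note: the new `max_new_tokens` cap is what connects the two hunks above: the budget is
computed once at the top of `get_candidates` and then passed to the assistant's
`generate` call in place of the raw `int(self.num_assistant_tokens)`. A minimal sketch
of the arithmetic, with hypothetical values (`max_length`, `new_cur_len`, and
`num_assistant_tokens` below are made up, not taken from the commit):

    # Hypothetical values illustrating the cap `max_length - new_cur_len - 1`.
    max_length = 20           # generation_config.max_length (assumed)
    new_cur_len = 18          # current prompt + accepted tokens (assumed)
    num_assistant_tokens = 5  # assistant's proposal budget (assumed)

    # The target model verifies the candidates and emits one extra token itself,
    # so the assistant may propose at most `max_length - new_cur_len - 1` tokens.
    max_new_tokens = min(int(num_assistant_tokens), max_length - new_cur_len - 1)
    assert max_new_tokens == 1

With `new_cur_len == 19` the budget drops to 0 and `get_candidates` now returns
`(input_ids, None)` immediately instead of asking the assistant for zero tokens.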
src/transformers/generation/utils.py
@@ -4404,7 +4404,7 @@ class GenerationMixin:
                 else:
                     selected_tokens = new_logits.argmax(dim=-1)
 
-                candidate_new_tokens = candidate_input_ids[:, -candidate_length:]
+                candidate_new_tokens = candidate_input_ids[:, cur_len:]
                 n_matches = ((~(candidate_new_tokens == selected_tokens[:, :-1])).cumsum(dim=-1) < 1).sum()
 
                 # Ensure we don't generate beyond max_len or an EOS token
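Note: both `[:, -candidate_length:]` and `[:, cur_len:]` are meant to select the newly
proposed tokens; the commit standardizes the greedy branch on slicing from `cur_len`.
That branch counts how many assistant tokens survive before the first disagreement
with the target model's own argmax picks. A toy illustration of the `cumsum` trick on
the `n_matches` line (the tensors below are hypothetical):

    import torch

    # Assistant proposed [1, 4, 5]; the target model, re-scoring the same positions,
    # picks [1, 4, 8] plus one extra token (3) for the position after the candidates.
    candidate_new_tokens = torch.tensor([[1, 4, 5]])
    selected_tokens = torch.tensor([[1, 4, 8, 3]])

    mismatch = ~(candidate_new_tokens == selected_tokens[:, :-1])  # [[False, False, True]]
    # The cumsum stays at 0 until the first mismatch, so `< 1` flags the matching prefix.
    n_matches = (mismatch.cumsum(dim=-1) < 1).sum()
    assert n_matches.item() == 2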
@@ -4540,12 +4540,13 @@ def _speculative_sampling(
     NOTE: Unless otherwise stated, the variable names match those in the paper.
     """
+    new_candidate_input_ids = candidate_input_ids[:, -candidate_length:]
     # Gets the probabilities from the logits. q_i and p_i denote the assistant and model probabilities of the tokens
     # selected by the assistant, respectively.
     q = candidate_logits.softmax(dim=-1)
-    q_i = q[:, torch.arange(candidate_length), candidate_input_ids[:, -candidate_length:]].squeeze(0, 1)
+    q_i = q[:, torch.arange(candidate_length), new_candidate_input_ids].squeeze(0, 1)
     p = new_logits.softmax(dim=-1)
-    p_i = p[:, torch.arange(candidate_length), candidate_input_ids[:, -candidate_length:]].squeeze(0, 1)
+    p_i = p[:, torch.arange(candidate_length), new_candidate_input_ids].squeeze(0, 1)
     probability_ratio = p_i / q_i
 
     # When probability_ratio > 1 (i.e. q_i(x) < p_i(x), or "assistant probability of the candidate token is smaller
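Note: `new_candidate_input_ids` just names the candidate slice once so that the two
gathers (and the later `valid_tokens` assembly) index the same tensor. A shape sketch
of the gather with assumed sizes (batch 1, 3 candidates, vocabulary 10; the
multi-dimensional `Tensor.squeeze(0, 1)` requires torch >= 2.0):

    import torch

    batch, candidate_length, vocab = 1, 3, 10
    q = torch.rand(batch, candidate_length, vocab).softmax(dim=-1)  # assistant probabilities
    new_candidate_input_ids = torch.randint(vocab, (batch, candidate_length))

    # For each position i, gather the probability the assistant assigned to the token
    # it actually emitted there; the result has one entry per candidate token.
    q_i = q[:, torch.arange(candidate_length), new_candidate_input_ids].squeeze(0, 1)
    assert q_i.shape == (candidate_length,)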
@@ -4553,28 +4554,33 @@ def _speculative_sampling(
     # (= keep with p = probability_ratio). Keep all the tokens until the first rejection
     r_i = torch.rand_like(probability_ratio)
     is_accepted = r_i <= probability_ratio
-    n_matches = (~is_accepted.cumsum(dim=-1) < 1).sum()  # this is `n` in algorithm 1
+    n_matches = ((~is_accepted).cumsum(dim=-1) < 1).sum()  # this is `n` in algorithm 1
 
     # Ensure we don't generate beyond max_len or an EOS token (not in algorithm 1, but needed for correct behavior)
     if last_assistant_token_is_eos and n_matches == candidate_length:
+        # Output length is assumed to be `n_matches + 1`. Since we won't generate another token with the target model
+        # due to acceptance on EOS we fix `n_matches`
         n_matches -= 1
-    n_matches = min(n_matches, max_matches)
-
-    # Next token selection: if there is a rejection, adjust the distribution from the main model before sampling.
-    gamma = candidate_logits.shape[1]
-    p_n_plus_1 = p[:, n_matches, :]
-    if n_matches < gamma:
-        q_n_plus_1 = q[:, n_matches, :]
-        p_prime = torch.clamp((p_n_plus_1 - q_n_plus_1), min=0).softmax(dim=-1)
-    else:
-        p_prime = p_n_plus_1
-    t = torch.multinomial(p_prime, num_samples=1).squeeze(1)[None, :]
-
-    # The selected tokens include the matches (if any) plus the next sampled tokens
-    if n_matches > 0:
-        valid_tokens = torch.cat((candidate_input_ids[:, -n_matches:], t), dim=-1)
-    else:
-        valid_tokens = t
+        valid_tokens = new_candidate_input_ids[:, : n_matches + 1]
+    else:
+        n_matches = min(n_matches, max_matches)
+
+        # Next token selection: if there is a rejection, adjust the distribution from the main model before sampling.
+        gamma = min(candidate_logits.shape[1], max_matches)
+        p_n_plus_1 = p[:, n_matches, :]
+        if n_matches < gamma:
+            q_n_plus_1 = q[:, n_matches, :]
+            p_prime = torch.clamp((p_n_plus_1 - q_n_plus_1), min=0)
+            p_prime.div_(p_prime.sum())
+        else:
+            p_prime = p_n_plus_1
+        t = torch.multinomial(p_prime, num_samples=1).squeeze(1)[None, :]
+
+        # The selected tokens include the matches (if any) plus the next sampled tokens
+        if n_matches > 0:
+            valid_tokens = torch.cat((new_candidate_input_ids[:, :n_matches], t), dim=-1)
+        else:
+            valid_tokens = t
 
     return valid_tokens, n_matches
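Note: the two behavioral fixes in this hunk are easy to verify in isolation: `~` must
bind to the acceptance mask before `cumsum`, and the residual distribution from the
speculative decoding paper is a renormalized `max(0, p - q)`, not a softmax of the
difference. A quick check with toy values (not taken from the commit):

    import torch

    # 1) Operator precedence in `n_matches`.
    is_accepted = torch.tensor([True, True, False, True])
    n_matches = ((~is_accepted).cumsum(dim=-1) < 1).sum()
    assert n_matches.item() == 2  # stops at the first rejection

    # The old form negates the integer cumsum instead: cumsum gives [1, 2, 2, 3],
    # bitwise ~ gives [-2, -3, -3, -4], and every entry compares < 1.
    broken = (~is_accepted.cumsum(dim=-1) < 1).sum()
    assert broken.item() == 4  # counts every token, ignoring the rejection

    # 2) Residual distribution after a rejection.
    p = torch.tensor([0.5, 0.3, 0.2])  # target model probabilities (toy)
    q = torch.tensor([0.6, 0.2, 0.2])  # assistant probabilities (toy)
    p_prime = torch.clamp(p - q, min=0)
    p_prime.div_(p_prime.sum())        # -> [0., 1., 0.]
    wrong = (p - q).softmax(dim=-1)    # strictly positive everywhere
    assert p_prime[0] == 0 and wrong[0] > 0  # softmax leaks mass to rejected tokens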
tests/generation/test_utils.py
@@ -88,6 +88,7 @@ if is_torch_available():
         TopKLogitsWarper,
         TopPLogitsWarper,
     )
+    from transformers.generation.utils import _speculative_sampling
 
 
 class GenerationTesterMixin:
@@ -2424,6 +2425,43 @@ class UtilsFunctionsTest(unittest.TestCase):
         self.assertTrue(torch.allclose(expected_output, output, atol=1e-12))
 
+    def test_speculative_sampling(self):
+        # assume vocab size 10, input length 5 + 3 generated candidates
+        candidate_input_ids = torch.tensor([[8, 0, 3, 9, 8, 1, 4, 5]])  # input tokens
+        candidate_logits = torch.tensor(
+            [
+                [
+                    [-10.0, 10.0, -10.0, -10.0, -10.0, -10.0, -10.0, -10.0, -10.0, -10.0],  # generated 1
+                    [-10.0, -10.0, -10.0, -10.0, 10.0, -10.0, -10.0, -10.0, -10.0, -10.0],  # generated 4
+                    [-10.0, -10.0, -10.0, -10.0, -10.0, 10.0, -10.0, -10.0, -10.0, -10.0],  # generated 5
+                ]
+            ]
+        )
+        candidate_length = 3
+        inf = float("inf")
+        new_logits = torch.tensor(
+            [
+                [
+                    [-10.0, 10.0, -10.0, -10.0, -10.0, -10.0, -10.0, -10.0, -10.0, -10.0],  # accepts 1
+                    [-10.0, -10.0, -10.0, -10.0, 10.0, -10.0, -10.0, -10.0, -10.0, -10.0],  # accepts 4
+                    [-inf, -inf, -inf, -inf, -inf, -inf, -inf, -inf, 10.0, -inf],  # rejects 5, accepts 8
+                    [-10.0, -10.0, -10.0, -10.0, -10.0, -10.0, -10.0, -10.0, -10.0, -10.0],  # N/A
+                ]
+            ]
+        )
+        last_assistant_token_is_eos = False
+        max_matches = 5
+        validated_tokens, n_matches = _speculative_sampling(
+            candidate_input_ids,
+            candidate_logits,
+            candidate_length,
+            new_logits,
+            last_assistant_token_is_eos,
+            max_matches,
+        )
+        self.assertTrue(n_matches.item() == 2)
+        self.assertTrue(validated_tokens.tolist()[0] == [1, 4, 8])
+
+
 @require_torch
 class GenerationIntegrationTests(unittest.TestCase, GenerationIntegrationTestsMixin):
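Note: the test is deterministic even though `_speculative_sampling` draws
`r_i = torch.rand_like(probability_ratio)`: where the candidate and target logit rows
are identical, `p_i == q_i`, so `probability_ratio` is exactly 1 and any `r_i` in
[0, 1) accepts; the `-inf` row makes `p_i` exactly 0 for the rejected candidate. A
sanity check of that reasoning (not part of the commit); to run only this test from a
repository checkout, an invocation like
`pytest tests/generation/test_utils.py -k test_speculative_sampling` should work:

    import torch

    # Logits of +/-10 make the softmax effectively one-hot.
    probs = torch.tensor([-10.0, 10.0, -10.0]).softmax(dim=-1)
    assert probs[1] > 0.999

    # At the third candidate, `new_logits` is -inf everywhere except token 8, so the
    # proposed token 5 has p_i == 0 and is always rejected; resampling from the
    # residual distribution then yields 8, giving [1, 4, 8] and n_matches == 2.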