Commit 79d64c49 (unverified), authored Jan 09, 2024 by Cade Daniel, committed by GitHub on Jan 09, 2024

[Speculative decoding 1/9] Optimized rejection sampler (#2336)

Parent: 74cd5abd
Showing 2 changed files with 784 additions and 0 deletions (+784 -0):
    tests/samplers/test_rejection_sampler.py            +392 -0
    vllm/model_executor/layers/rejection_sampler.py      +392 -0

tests/samplers/test_rejection_sampler.py — new file (mode 100644) @ 79d64c49

"""Tests for rejection sampling."""
import
pytest
from
typing
import
List
,
Tuple
import
torch
import
torch.nn.functional
as
F
from
vllm.model_executor.utils
import
set_random_seed
from
vllm.model_executor.layers.rejection_sampler
import
RejectionSampler
def
mock_causal_accepted_tensor
(
k
:
int
,
last_accepted_indices
:
torch
.
Tensor
)
->
torch
.
Tensor
:
"""Generate an "accepted" tensor which should yield causally-accepted tokens
up to last accepted indices.
Tokens after last_accepted_indices+1 may also be accepted, although they
will not be causally accepted.
"""
batch_size
=
last_accepted_indices
.
shape
[
0
]
accepted
=
(
torch
.
arange
(
k
).
expand
(
batch_size
,
k
)
<=
last_accepted_indices
.
unsqueeze
(
-
1
).
broadcast_to
(
batch_size
,
k
)).
to
(
device
=
"cuda"
)
# Sprinkle accepted values after the contiguous initial accepted values.
# This replicates the behavior of rejection sampling, which may "accept"
# a token that cannot be accepted because of causality.
sprinkle_candidates
=
(
torch
.
arange
(
k
).
expand
(
batch_size
,
k
)
>
last_accepted_indices
.
unsqueeze
(
-
1
).
broadcast_to
(
batch_size
,
k
)
+
1
)
sprinkle
=
torch
.
rand
(
batch_size
,
k
,
device
=
"cuda"
)
>
0.5
accepted
[
sprinkle_candidates
]
=
sprinkle
[
sprinkle_candidates
]
return
accepted
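
# For intuition (illustrative values, not part of the original test): with k=3
# and last_accepted_indices=[1, -1], the causal part of `accepted` is
#     [[True, True, False],
#      [False, False, False]]
# and random "sprinkled" True values may then appear only at positions strictly
# after last_accepted_indices + 1, i.e. columns 1-2 of the second row here.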

@pytest.mark.parametrize("seed", list(range(10)))
@pytest.mark.parametrize("which_tokens_accepted", [
    "all_tokens_accepted", "no_tokens_accepted", "some_tokens_accepted"
])
@torch.inference_mode()
def test_correct_output_format(which_tokens_accepted: str, seed: int):
    """Verify the output has correct format given predetermined accepted matrix.
    """
    set_random_seed(seed)

    batch_size = 10
    k = 5
    vocab_size = 3000

    if which_tokens_accepted == "all_tokens_accepted":
        accepted = mock_causal_accepted_tensor(
            k, -1 + k * torch.ones((batch_size, ), dtype=torch.long))
    elif which_tokens_accepted == "no_tokens_accepted":
        accepted = mock_causal_accepted_tensor(
            k, -torch.ones((batch_size, ), dtype=torch.long))
    elif which_tokens_accepted == "some_tokens_accepted":
        last_accepted_indices = torch.randint(low=-1,
                                              high=k,
                                              size=(batch_size, ))
        accepted = mock_causal_accepted_tensor(k, last_accepted_indices)
    else:
        raise AssertionError()

    recovered_token_ids = torch.randint(low=0,
                                        high=vocab_size,
                                        size=(batch_size, k),
                                        dtype=torch.int64,
                                        device="cuda")
    draft_token_ids = torch.randint(low=0,
                                    high=vocab_size,
                                    size=(batch_size, k),
                                    dtype=torch.int64,
                                    device="cuda")
    bonus_token_ids = torch.randint(low=0,
                                    high=vocab_size,
                                    size=(batch_size, 1),
                                    dtype=torch.int64,
                                    device="cuda")

    rejection_sampler = RejectionSampler()
    rejection_sampler.init_gpu_tensors(rank=0)
    output_token_ids = rejection_sampler._create_output(  # pylint: disable=protected-access
        accepted,
        recovered_token_ids,
        draft_token_ids,
        bonus_token_ids,
    )

    if which_tokens_accepted == "all_tokens_accepted":
        # Expect all tokens to be equal to draft tokens.
        assert torch.equal(output_token_ids[:, :-1], draft_token_ids)

        # Expect all bonus tokens to be included.
        assert torch.equal(output_token_ids[:, -1:], bonus_token_ids)
    elif which_tokens_accepted == "no_tokens_accepted":
        # Expect first token to be equal to recovered tokens.
        assert torch.equal(output_token_ids[:, 0], recovered_token_ids[:, 0])

        # Expect everything else to be -1.
        assert torch.equal(output_token_ids[:, 1:],
                           torch.ones_like(output_token_ids[:, 1:]) * -1)
    elif which_tokens_accepted == "some_tokens_accepted":
        recovered_plus_bonus = torch.cat(
            (recovered_token_ids, bonus_token_ids), dim=-1)

        # Assert first rejected token is a recovered token or bonus token.
        assert torch.equal(
            recovered_plus_bonus[torch.arange(0, batch_size),
                                 last_accepted_indices + 1],
            output_token_ids[torch.arange(0, batch_size),
                             last_accepted_indices + 1])

        # Assert every subsequent token is -1.
        subsequent_mask = torch.arange(0, k + 1).expand(
            batch_size, k + 1) >= (last_accepted_indices + 2).unsqueeze(-1)
        assert torch.all(output_token_ids[subsequent_mask] == -1)

@pytest.mark.parametrize("k", list(range(1, 6)))
@pytest.mark.parametrize("vocab_size", [30_000, 50_000])
@pytest.mark.parametrize("batch_size", list(range(1, 32)))
@torch.inference_mode()
def test_no_crash_with_varying_dims(k: int, vocab_size: int, batch_size: int):
    rejection_sampler = RejectionSampler()
    rejection_sampler.init_gpu_tensors(rank=0)

    draft_probs = torch.rand(batch_size,
                             k,
                             vocab_size,
                             dtype=torch.float32,
                             device="cuda")
    target_probs = torch.rand(batch_size,
                              k,
                              vocab_size,
                              dtype=torch.float32,
                              device="cuda")
    bonus_token_ids = torch.randint(low=0,
                                    high=vocab_size,
                                    size=(batch_size, 1),
                                    dtype=torch.int64,
                                    device="cuda")
    draft_token_ids = torch.randint(low=0,
                                    high=vocab_size,
                                    size=(batch_size, k),
                                    dtype=torch.int64,
                                    device="cuda")

    rejection_sampler(target_probs, bonus_token_ids, draft_probs,
                      draft_token_ids)

@pytest.mark.parametrize("above_or_below_vocab_range", ["above", "below"])
@pytest.mark.parametrize("which_token_ids",
                         ["bonus_token_ids", "draft_token_ids"])
@torch.inference_mode()
def test_raises_when_vocab_oob(above_or_below_vocab_range: str,
                               which_token_ids: str):
    k = 3
    batch_size = 5
    vocab_size = 30_000

    rejection_sampler = RejectionSampler(strict_mode=True)
    rejection_sampler.init_gpu_tensors(rank=0)

    draft_probs = torch.rand(batch_size,
                             k,
                             vocab_size,
                             dtype=torch.float32,
                             device="cuda")
    target_probs = torch.rand(batch_size,
                              k,
                              vocab_size,
                              dtype=torch.float32,
                              device="cuda")
    bonus_token_ids = torch.randint(low=0,
                                    high=vocab_size,
                                    size=(batch_size, 1),
                                    dtype=torch.int64,
                                    device="cuda")
    draft_token_ids = torch.randint(low=0,
                                    high=vocab_size,
                                    size=(batch_size, k),
                                    dtype=torch.int64,
                                    device="cuda")

    oob_token_ids = None
    if which_token_ids == "bonus_token_ids":
        oob_token_ids = bonus_token_ids
    elif which_token_ids == "draft_token_ids":
        oob_token_ids = draft_token_ids
    else:
        raise AssertionError()

    if above_or_below_vocab_range == "above":
        rogue_token_id = vocab_size + 1
    elif above_or_below_vocab_range == "below":
        rogue_token_id = -1
    else:
        raise AssertionError()

    oob_token_ids[0][0] = rogue_token_id

    with pytest.raises(AssertionError):
        rejection_sampler(target_probs, bonus_token_ids, draft_probs,
                          draft_token_ids)

@pytest.mark.parametrize("draft_and_target_probs_equal", [True, False])
@pytest.mark.parametrize("seed", list(range(5)))
@torch.inference_mode()
def test_rejection_sampling_approximates_target_distribution(
        seed: int, draft_and_target_probs_equal: bool):
    """Verify rejection sampling approximates target distribution,
    despite sampling from a potentially distinct draft distribution.

    This is done by first creating a random target probability
    distribution and a random draft probability distribution. We then
    sample token ids from the rejection sampler using these draft
    and target distributions. The samples are used to estimate
    the output probability distribution, which we expect to approximate
    the target distribution.

    A basic distance metric is used to determine similarity between
    distributions.

    We expect that as we increase the number of samples,
    the distance between the observed distribution and the target
    distribution decreases. To measure this, we compare the distance
    of the observed distribution against both the target distribution
    and a uniform random distribution. We expect the distance between
    the observed distribution and the target distribution to improve
    much more than the distance improvement between the observed
    distribution and the random distribution.

    When draft_and_target_probs_equal=True, the draft and target
    probabilities are exactly equal. Rejection sampling should
    still work without any NaNs or exceptions.
    """
    set_random_seed(seed)

    helper = _CorrectnessTestHelper(
        vocab_size=10,
        rejection_sampler=RejectionSampler(),
    )

    draft_probs, target_probs, reference_probs = helper.generate_probs_for_test(
        draft_and_target_probs_equal)

    sample_sizes = [10, 100, 1_000, 10_000, 100_000]
    distance_wrt_reference = []
    distance_wrt_target = []

    for num_samples in sample_sizes:
        (reference_vs_rejsample_dist,
         target_vs_rejsample_dist) = helper.run_and_compare_distributions(
             draft_probs,
             target_probs,
             reference_probs,
             num_samples,
         )

        distance_wrt_reference.append(reference_vs_rejsample_dist)
        distance_wrt_target.append(target_vs_rejsample_dist)

        relative_change_in_distance_wrt_target = get_ratio_first_to_last(
            distance_wrt_target)
        relative_change_in_distance_wrt_reference = get_ratio_first_to_last(
            distance_wrt_reference)

        print(f"{num_samples=} {target_vs_rejsample_dist=:.05f} "
              f"{reference_vs_rejsample_dist=:.05f}")
        print(f"{num_samples=} {relative_change_in_distance_wrt_target=:.02f} "
              f"{relative_change_in_distance_wrt_reference=:.02f}")

    relative_change_in_distance_wrt_target = get_ratio_first_to_last(
        distance_wrt_target)
    relative_change_in_distance_wrt_reference = get_ratio_first_to_last(
        distance_wrt_reference)

    expected_improvement_multiplier = 20
    assert (relative_change_in_distance_wrt_target >
            relative_change_in_distance_wrt_reference *
            expected_improvement_multiplier)
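
# For intuition on the final assertion (numbers are illustrative only): if the
# distance to the target distribution shrinks from 0.10 at 10 samples to 0.001
# at 100_000 samples (a 100x improvement) while the distance to the random
# reference distributions only shrinks from 0.05 to 0.04 (a 1.25x improvement),
# then 100 > 1.25 * 20 and the test passes.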

def get_ratio_first_to_last(elements: List[float]) -> float:
    return elements[0] / elements[-1]

class _CorrectnessTestHelper:
    """Class that packages together logic required for the unit-level
    rejection sampling correctness test.
    """

    def __init__(self, vocab_size: int, rejection_sampler: RejectionSampler):
        self.rejection_sampler = rejection_sampler
        self.vocab_size = vocab_size
        self.vocab_range = (0, vocab_size)

        self.rejection_sampler.init_gpu_tensors(rank=0)

        # Keep test simple, use k=1
        self.k = 1

        # Bonus tokens not used, but rejection sampler requires
        # correct shape.
        self.num_bonus_tokens = 1

    def generate_probs_for_test(
        self, draft_and_target_probs_equal: bool
    ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
        draft_probs, target_probs = [
            F.softmax(
                torch.rand(self.vocab_size, dtype=torch.float32),
                dim=-1,
            ) for _ in range(2)
        ]

        num_reference_probs = 100
        reference_probs = F.softmax(
            torch.rand(num_reference_probs,
                       self.vocab_size,
                       dtype=torch.float32),
            dim=-1,
        )

        if draft_and_target_probs_equal:
            target_probs = draft_probs.clone()

        return draft_probs, target_probs, reference_probs

    def run_and_compare_distributions(self, draft_probs: torch.Tensor,
                                      target_probs: torch.Tensor,
                                      reference_probs: torch.Tensor,
                                      num_samples: int) -> Tuple[float, float]:
        # Sample using rejection sampling.
        rej_sample_probs = self._estimate_rejection_sampling_pdf(
            draft_probs, target_probs, num_samples)

        # Average distance from reference probs.
        reference_vs_rejsample_dist = torch.dist(
            reference_probs,
            rej_sample_probs).item() / reference_probs.shape[0]
        target_vs_rejsample_dist = torch.dist(target_probs,
                                              rej_sample_probs).item()

        return reference_vs_rejsample_dist, target_vs_rejsample_dist

    def _estimate_rejection_sampling_pdf(
        self,
        draft_probs: torch.Tensor,
        target_probs: torch.Tensor,
        num_samples: int,
    ) -> torch.Tensor:
        # Repeat draft probs num_samples times.
        draft_probs = draft_probs.reshape(1, self.k, self.vocab_size).repeat(
            num_samples, 1, 1)

        # Repeat target probs num_samples * k times.
        # Rejection sampler requires bonus token probs, but they aren't used.
        target_probs = target_probs.reshape(1, 1, self.vocab_size).repeat(
            num_samples, self.k, 1)

        # Randomly sample draft token ids from draft probs.
        draft_token_ids = torch.multinomial(draft_probs[:, 0, :],
                                            num_samples=1,
                                            replacement=True).reshape(
                                                num_samples, self.k)

        # Bonus tokens not used but required.
        bonus_token_ids = torch.zeros((1, self.num_bonus_tokens),
                                      dtype=torch.int64,
                                      device="cuda").repeat(num_samples, 1)

        # Get output tokens via rejection sampling.
        output_token_ids = self.rejection_sampler(target_probs.to("cuda"),
                                                  bonus_token_ids.to("cuda"),
                                                  draft_probs.to("cuda"),
                                                  draft_token_ids.to("cuda"))

        # Remove bonus tokens
        output_token_ids = output_token_ids[:, :-1].flatten()

        # Estimate probability density function
        hist = torch.histogram(output_token_ids.to(dtype=torch.float,
                                                   device="cpu"),
                               bins=self.vocab_size,
                               range=self.vocab_range,
                               density=True)

        return hist.hist

vllm/model_executor/layers/rejection_sampler.py — new file (mode 100644) @ 79d64c49

from typing import Tuple, Optional
from functools import cached_property

import torch
import torch.nn as nn
import torch.jit


class RejectionSampler(nn.Module):
    """Apply modified rejection sampling as described in "Accelerating Large
    Language Model Decoding with Speculative Sampling"
    https://arxiv.org/pdf/2302.01318.pdf.
    """

    def __init__(self, strict_mode: bool = False):
        """Create a rejection sampler.

        Args:
            strict_mode: Whether or not to perform shape/device/dtype checks
                during sampling. This catches correctness issues but adds
                nontrivial latency.
        """
        super().__init__()
        self.probs_dtype = torch.float32
        self.token_id_dtype = torch.int64
        self._strict_mode = strict_mode

        # NOTE: A "bonus token" is accepted iff all proposal tokens are
        # accepted. There is always only one possible bonus token. We store this
        # value in a variable for readability.
        self._num_bonus_tokens = 1

        self.num_accepted_tokens: Optional[torch.Tensor] = None
        self.num_emitted_tokens: Optional[torch.Tensor] = None
        self.num_draft_tokens: int = 0

    def init_gpu_tensors(self, rank: int) -> None:
        assert self.num_accepted_tokens is None
        device = f"cuda:{rank}"
        self.num_accepted_tokens = torch.tensor(0,
                                                dtype=torch.long,
                                                device=device)
        self.num_emitted_tokens = torch.tensor(0,
                                               dtype=torch.long,
                                               device=device)

    def forward(
        self,
        target_probs: torch.Tensor,
        bonus_token_ids: torch.Tensor,
        draft_probs: torch.Tensor,
        draft_token_ids: torch.Tensor,
    ) -> torch.Tensor:
        """Sample token ids using rejection sampling. This accepts or rejects
        tokens proposed by the draft model using the probability of each token
        according to the draft and target models.

        In the worst case where all draft tokens are rejected, it is guaranteed
        one correct token will be emitted.

        In the case where all draft tokens are accepted, a bonus token will be
        accepted as it's cheap to have the target model score this speculative
        sequence.

        Args:
            target_probs: The probability distribution over token ids given
                context according to the target model.
                shape = [batch_size, num_speculative_tokens, vocab_size]

            bonus_token_ids: The "bonus" token ids that are accepted iff all
                speculative tokens in a sequence are accepted.
                shape = [batch_size, num_bonus_tokens]

            draft_probs: The probability distribution over token ids given
                context according to the draft model.
                shape = [batch_size, num_speculative_tokens, vocab_size]

            draft_token_ids: The token ids that were sampled from the draft
                probabilities.
                shape = [batch_size, num_speculative_tokens]

        Returns:
            output_token_ids: The token ids sampled via rejection sampling,
                or -1 if unable to sample a token because the previous token
                was rejected.
                shape = [batch_size, num_speculative_tokens + num_bonus_tokens]
        """
        # Only perform shape/dtype/device checking in strict mode, as it adds
        # overhead.
        if self._strict_mode:
            self._raise_if_incorrect_shape(target_probs, bonus_token_ids,
                                           draft_probs, draft_token_ids)
            self._raise_if_incorrect_dtype(target_probs, bonus_token_ids,
                                           draft_probs, draft_token_ids)
            self._raise_if_inconsistent_device(target_probs, bonus_token_ids,
                                               draft_probs, draft_token_ids)
            self._raise_if_out_of_bounds_vocab(target_probs.shape[-1],
                                               bonus_token_ids,
                                               draft_token_ids)

        accepted, recovered_token_ids = self._batch_modified_rejection_sampling(
            target_probs,
            draft_probs,
            draft_token_ids,
        )

        output_token_ids = self._create_output(
            accepted,
            recovered_token_ids,
            draft_token_ids,
            bonus_token_ids,
        )
        return output_token_ids

    def _batch_modified_rejection_sampling(
            self,
            target_probs: torch.Tensor,  # [batch_size, k, vocab_size]
            draft_probs: torch.Tensor,  # [batch_size, k, vocab_size]
            draft_token_ids: torch.Tensor,  # [batch_size, k]
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """Perform modified rejection sampling on each sequence.

        Returns:
            A tuple of two tensors:
            0: A bool tensor of which tokens in each sequence are accepted.
                shape = [batch_size, k]
            1: Token ids sampled from a recovered distribution, to be used
                when a token is rejected.
                shape = [batch_size, k]
        """
        batch_size, k, vocab_size = draft_probs.shape

        # shape [batch_size, k]
        accepted = self._get_accepted(target_probs, draft_probs,
                                      draft_token_ids)

        recovered_probs = self._get_recovered_probs(
            target_probs, draft_probs).reshape(batch_size * k, vocab_size)

        recovered_token_ids = _multinomial(recovered_probs,
                                           num_samples=1).reshape(
                                               batch_size, k)
        return accepted, recovered_token_ids

    def _get_accepted(
            self,
            target_probs: torch.Tensor,  # [batch_size, k, vocab_size]
            draft_probs: torch.Tensor,  # [batch_size, k, vocab_size]
            draft_token_ids: torch.Tensor,  # [batch_size, k]
    ) -> torch.Tensor:
        r"""Create bool matrix over the proposed draft tokens. If
        True, then a token can be accepted, else it should be
        rejected.

        Given :math:`q(\hat{x}_{n+1}|x_1, \dots, x_n)`, the probability of
        :math:`\hat{x}_{n+1}` given context :math:`x_1, \dots, x_n` according
        to the target model, and :math:`p(\hat{x}_{n+1}|x_1, \dots, x_n)`, the
        same conditional probability according to the draft model, the token
        is accepted with probability:

        .. math::
            \min\left(1, \frac{q(\hat{x}_{n+1}|x_1, \dots, x_n)}
                              {p(\hat{x}_{n+1}|x_1, \dots, x_n)}\right)

        This implementation does not apply causality. When using the output,
        if a token is rejected, subsequent tokens should not be used.

        Returns a bool tensor of shape [batch_size, k] specifying which tokens
        are accepted.
        """
        batch_size, k, _ = draft_probs.shape
        batch_indices = torch.arange(batch_size,
                                     device=target_probs.device)[:, None]
        probs_indicies = torch.arange(k, device=target_probs.device)

        # shape [batch_size, k]
        selected_draft_probs = draft_probs[batch_indices, probs_indicies,
                                           draft_token_ids]

        # shape [batch_size, k]
        selected_target_probs = target_probs[batch_indices, probs_indicies,
                                             draft_token_ids]

        uniform_rand = torch.rand(batch_size,
                                  k,
                                  dtype=self.probs_dtype,
                                  device=target_probs.device)
        capped_ratio = torch.minimum(
            selected_target_probs / selected_draft_probs,
            torch.full((1, ), 1, device=target_probs.device))
        accepted = uniform_rand < capped_ratio

        return accepted
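
    # For intuition (illustrative numbers, not from the commit): if the draft
    # proposes a token with draft probability p = 0.10 and the target model
    # assigns it q = 0.30, the ratio q/p = 3.0 is capped at 1, so the token is
    # always accepted. If instead q = 0.02, the token is accepted only with
    # probability 0.02 / 0.10 = 0.2.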

    def _get_recovered_probs(
            self,
            target_probs: torch.Tensor,  # [k, vocab_size]
            draft_probs: torch.Tensor,  # [k, vocab_size]
    ) -> torch.Tensor:
        r"""Create a probability distribution for each proposed token which can
        be sampled if the proposed token is rejected.

        When this routine is applied sequentially, the true distribution of the
        target model is recovered (within hardware numerics).

        The probability distribution used in this rejection case is constructed
        as follows. Given :math:`q(x|x_1, \dots, x_n)`, the probability of
        :math:`x` given context :math:`x_1, \dots, x_n` according to the target
        model and :math:`p(x|x_1, \dots, x_n)`, the same conditional probability
        according to the draft model:

        .. math::
            x_{n+1} \sim (q(x|x_1, \dots, x_n) - p(x|x_1, \dots, x_n))_+

        where :math:`(f(x))_+` is defined as:

        .. math::
            (f(x))_+ = \frac{\max(0, f(x))}{\sum_x \max(0, f(x))}

        See https://github.com/vllm-project/vllm/pull/2336 for a visualization
        of the draft, target, and recovered probability distributions.

        Returns a tensor of shape [batch_size, k, vocab_size].

        Note: This batches operations on GPU and thus constructs the recovered
        distribution for all tokens, even if they are accepted. This causes
        division-by-zero errors, so we use self._smallest_positive_value to
        avoid that. This introduces some drift to the distribution.
        """
        _, k, _ = draft_probs.shape

        # shape [batch_size, k, vocab_size]
        difference = target_probs - draft_probs

        # TODO(cade): Can we use logprobs instead of probs, and avoid the
        # division-by-zero errors without introducing distribution drift?

        # shape [batch_size, k, vocab_size]
        f = torch.clamp(difference, min=self._smallest_positive_value)

        # shape [batch_size, k, vocab_size]
        recovered_probs = f / torch.sum(f, dim=-1).reshape(-1, k, 1)

        return recovered_probs
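
    # For intuition (illustrative numbers, not from the commit): with a vocab
    # of three tokens, target q = [0.5, 0.3, 0.2] and draft p = [0.7, 0.2, 0.1]
    # give q - p = [-0.2, 0.1, 0.1]; clamping negatives to ~0 and normalizing
    # yields the recovered distribution [~0, 0.5, 0.5], from which the
    # replacement token is sampled when the draft token is rejected.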

    @cached_property
    def _smallest_positive_value(self) -> float:
        """Return the smallest positive value representable by the probs dtype.
        This value is used when constructing a distribution from which to sample
        recovered tokens in the first rejection case.

        See _get_recovered_probs for more details.

        Note that this isn't actually the smallest positive value representable
        by float32, but the smallest positive normal value.
        See https://en.wikipedia.org/wiki/Subnormal_number for more information.
        """
        return torch.finfo(self.probs_dtype).tiny

    def _create_output(
            self,
            accepted: torch.Tensor,  # [batch_size, k]
            recovered_token_ids: torch.Tensor,  # [batch_size, k]
            draft_token_ids: torch.Tensor,  # [batch_size, k]
            bonus_token_ids: torch.Tensor,  # [batch_size]
    ) -> torch.Tensor:
        """Format output. Returns a matrix of token ids. When
        a token is rejected via rejection sampling, all subsequent
        token ids are set to -1 for the sequence.

        shape = [batch_size, k + num_bonus_tokens]
        """
        bonus_token_ids = bonus_token_ids.squeeze()
        batch_size, k = recovered_token_ids.shape

        # Determine the index of the first False value for each row.
        limits = (accepted == 0).max(1).indices
        limits[~(accepted == 0).any(1)] = k

        # Create masks using the indices.
        indices = torch.arange(k, device=accepted.device).unsqueeze(0)
        accepted_mask = indices < limits.unsqueeze(1)
        after_false_mask = indices == limits.unsqueeze(1)

        # Create an extended output tensor
        output_with_bonus_tokens = -torch.ones(
            (batch_size, k + self._num_bonus_tokens),
            dtype=self.token_id_dtype,
            device=accepted.device)
        output = output_with_bonus_tokens[:, :k]

        # Fill in the first k columns of the output tensor using masks and data
        # tensors.
        output[:, :k] = torch.where(accepted_mask, draft_token_ids,
                                    -torch.ones_like(draft_token_ids))

        # Fill the last column.
        # We check output directly as accepted may have True values inconsistent
        # with causal acceptance.
        output_with_bonus_tokens[:, -1] = torch.where(output[:, -1] != -1,
                                                      bonus_token_ids, -1)

        # Fill the recovered token ids.
        output.mul_(~after_false_mask).add_(
            recovered_token_ids.mul(after_false_mask))

        self.num_accepted_tokens += accepted.sum()
        self.num_emitted_tokens += (output_with_bonus_tokens != -1).sum()
        self.num_draft_tokens += batch_size * k

        return output_with_bonus_tokens
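
    # For intuition (illustrative values, not from the commit): with k = 5 and
    # accepted = [True, True, False, True, False] for some row, the first
    # rejection is at index 2, so that output row becomes
    # [draft_0, draft_1, recovered_2, -1, -1, -1]; the bonus slot is -1 because
    # not every draft token was accepted.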

    def _raise_if_incorrect_shape(
        self,
        target_probs: torch.Tensor,
        bonus_token_ids: torch.Tensor,
        draft_probs: torch.Tensor,
        draft_token_ids: torch.Tensor,
    ) -> None:
        (target_batch_size, num_target_probs,
         target_vocab_size) = target_probs.shape
        bonus_batch_size, num_bonus_tokens = bonus_token_ids.shape
        draft_batch_size, num_draft_probs, draft_vocab_size = draft_probs.shape
        draft_token_ids_batch_size, num_draft_token_ids = draft_token_ids.shape

        assert draft_batch_size == target_batch_size
        assert num_draft_probs == num_target_probs
        assert (draft_vocab_size == target_vocab_size
                ), f"{draft_vocab_size=} {target_vocab_size=}"

        assert draft_token_ids_batch_size == draft_batch_size
        assert num_draft_token_ids == num_draft_probs

        assert bonus_batch_size == target_batch_size
        assert num_bonus_tokens == self._num_bonus_tokens

    def _raise_if_incorrect_dtype(
        self,
        target_probs: torch.Tensor,
        bonus_token_ids: torch.Tensor,
        draft_probs: torch.Tensor,
        draft_token_ids: torch.Tensor,
    ) -> None:
        assert all(probs.dtype == self.probs_dtype
                   for probs in [target_probs, draft_probs])
        assert all(token_ids.dtype == self.token_id_dtype
                   for token_ids in [bonus_token_ids, draft_token_ids])

    def _raise_if_inconsistent_device(
        self,
        target_probs: torch.Tensor,
        bonus_token_ids: torch.Tensor,
        draft_probs: torch.Tensor,
        draft_token_ids: torch.Tensor,
    ) -> None:
        devices = [
            t.device for t in
            [target_probs, bonus_token_ids, draft_probs, draft_token_ids]
        ]
        assert all([devices[0] == device for device in devices])

    def _raise_if_out_of_bounds_vocab(
        self,
        vocab_size: int,
        bonus_token_ids: torch.Tensor,
        draft_token_ids: torch.Tensor,
    ) -> None:
        assert torch.all(bonus_token_ids < vocab_size)
        assert torch.all(bonus_token_ids >= 0)
        assert torch.all(draft_token_ids < vocab_size)
        assert torch.all(draft_token_ids >= 0)

# torch.multinomial forces a GPU<->CPU sync.
# Therefore, we use an optimized implementation instead that skips the sync.
# Note that we always sample with replacement.
# probs will be modified in place, but this is fine, as we pass
# in a copy already.
@torch.jit.script
def _multinomial(
    probs: torch.Tensor,
    num_samples: int,
) -> torch.Tensor:
    if num_samples > 1:
        # This is equivalent to torch.repeat_interleaved (which also
        # forces a GPU<->CPU sync).
        probs = probs[:, None, :].expand(probs.shape[0], num_samples,
                                         probs.shape[1]).contiguous().view(
                                             -1, probs.shape[1])
    q = torch.empty_like(probs).exponential_(1.0)
    return probs.div_(q).argmax(dim=1).view(-1, num_samples)
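
For reference, a minimal sketch of the expected call pattern, mirroring test_no_crash_with_varying_dims above (assumes a CUDA device is available; the random tensors are placeholders for real draft/target model outputs and are not part of this commit):

import torch
from vllm.model_executor.layers.rejection_sampler import RejectionSampler

batch_size, k, vocab_size = 2, 4, 32_000

sampler = RejectionSampler()
sampler.init_gpu_tensors(rank=0)

# Per-position probability distributions from the draft and target models.
draft_probs = torch.rand(batch_size, k, vocab_size, device="cuda").softmax(dim=-1)
target_probs = torch.rand(batch_size, k, vocab_size, device="cuda").softmax(dim=-1)

# Tokens sampled from the draft model, plus one bonus token per sequence.
draft_token_ids = torch.randint(0, vocab_size, (batch_size, k), dtype=torch.int64, device="cuda")
bonus_token_ids = torch.randint(0, vocab_size, (batch_size, 1), dtype=torch.int64, device="cuda")

# shape [batch_size, k + 1]; positions after the first rejection are -1.
output_token_ids = sampler(target_probs, bonus_token_ids, draft_probs, draft_token_ids)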