Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
ebcebeeb
Unverified
Commit
ebcebeeb
authored
Mar 24, 2025
by
Woosuk Kwon
Committed by
GitHub
Mar 24, 2025
Browse files
[V1][Spec Decode] Enable spec decode for top-p & top-k sampling (#15063)
Signed-off-by:
Woosuk Kwon
<
woosuk.kwon@berkeley.edu
>
parent
f533b583
Changes
3
Show whitespace changes
Inline
Side-by-side
Showing
3 changed files
with
219 additions
and
19 deletions
+219
-19
tests/v1/sample/test_rejection_sampler.py
tests/v1/sample/test_rejection_sampler.py
+148
-2
vllm/v1/sample/rejection_sampler.py
vllm/v1/sample/rejection_sampler.py
+70
-13
vllm/v1/spec_decode/utils.py
vllm/v1/spec_decode/utils.py
+1
-4
No files found.
tests/v1/sample/test_rejection_sampler.py
View file @
ebcebeeb
...
...
@@ -36,6 +36,8 @@ def create_logits_tensor(output_token_ids: list[list[int]],
def
create_sampling_metadata
(
all_greedy
:
bool
,
temperature
:
Optional
[
torch
.
Tensor
]
=
None
,
top_k
:
Optional
[
torch
.
Tensor
]
=
None
,
top_p
:
Optional
[
torch
.
Tensor
]
=
None
,
generators
:
Optional
[
dict
[
int
,
Any
]]
=
None
,
)
->
SamplingMetadata
:
"""Create a v1 sampling metadata object with all_greedy set
...
...
@@ -52,8 +54,8 @@ def create_sampling_metadata(
temperature
=
temperature
,
all_greedy
=
all_greedy
,
all_random
=
not
all_greedy
,
top_p
=
None
,
top_k
=
None
,
top_p
=
top_p
,
top_k
=
top_k
,
min_p
=
torch
.
empty
(
1
,
),
generators
=
generators
,
max_num_logprobs
=
0
,
...
...
@@ -462,3 +464,147 @@ def estimate_rejection_sampling_pdf(
density
=
True
)
return
hist
.
hist
def _test_masked_logits(
    rejection_sampler,
    batch_size: int,
    num_draft_tokens: int,
    vocab_size: int,
    target_logits: torch.Tensor,
    unmasked_indices: torch.Tensor,
    sampling_metadata: SamplingMetadata,
):
    """Run the rejection sampler and assert no masked token is emitted.

    Random draft probabilities and draft tokens are generated for
    ``batch_size * num_draft_tokens`` positions, the sampler is invoked with
    the given ``target_logits`` and ``sampling_metadata``, and every
    non-placeholder output token at position ``i`` must be contained in
    ``unmasked_indices[i]``.

    Args:
        rejection_sampler: Callable sampler under test.
        batch_size: Number of requests in the batch.
        num_draft_tokens: Number of draft tokens per request.
        vocab_size: Size of the vocabulary axis of the probabilities.
        target_logits: Target-model logits passed through to the sampler.
        unmasked_indices: Per-position collection of allowed token ids.
        sampling_metadata: Sampling parameters passed through to the sampler.
    """
    total_tokens = batch_size * num_draft_tokens

    # Draft distribution: softmax over uniformly random scores, one row per
    # draft position.
    random_scores = torch.rand(
        (total_tokens, vocab_size),
        dtype=torch.float32,
        device=DEVICE,
    )
    draft_probs = F.softmax(random_scores, dim=-1)

    # Draw one draft token per row, then group the draws per request as a
    # nested Python list.
    drawn = torch.multinomial(draft_probs, num_samples=1)
    draft_token_ids = drawn.reshape(batch_size, num_draft_tokens).tolist()

    # The bonus tokens are irrelevant to this check, but the sampler call
    # requires them.
    bonus_token_ids = torch.zeros(
        (batch_size, 1),
        dtype=torch.int64,
        device=DEVICE,
    )

    spec_decode_metadata = SpecDecodeMetadata.make_dummy(
        draft_token_ids,
        device=DEVICE,
    )

    sampler_output = rejection_sampler(
        spec_decode_metadata,
        draft_probs=draft_probs,
        target_logits=target_logits,
        bonus_token_ids=bonus_token_ids,
        sampling_metadata=sampling_metadata,
    )

    # Drop the trailing (bonus) column and flatten to one id per position.
    sampled_ids = sampler_output[:, :-1].flatten().tolist()

    # Every real (non-placeholder) token must come from the allowed set of
    # its own position.
    for position in range(total_tokens):
        sampled_id = sampled_ids[position]
        if sampled_id != PLACEHOLDER_TOKEN_ID:
            assert sampled_id in unmasked_indices[position]
@pytest.mark.parametrize("top_k", [1, 5, 99])
def test_top_k(rejection_sampler, top_k):
    """Check that rejection sampling with top-k only emits top-k tokens."""
    vocab_size = 100
    batch_size = 100
    num_draft_tokens = 3
    num_tokens = batch_size * num_draft_tokens

    # For every position pick `top_k` distinct random vocabulary ids.
    top_k_indices = torch.stack([
        torch.randperm(vocab_size, device=DEVICE)[:top_k]
        for _ in range(num_tokens)
    ])

    # Start from all-zero logits (a uniform distribution) and nudge only the
    # chosen ids upwards. The margin is tiny, so if the top-k mask were not
    # applied the remaining ids would still be sampled frequently.
    target_logits = torch.zeros((num_tokens, vocab_size), device=DEVICE)
    for row in range(num_tokens):
        target_logits[row, top_k_indices[row]] += 0.1

    # Per-request sampling parameters: temperature 1 and the same k for all.
    unit_temperature = torch.ones(batch_size,
                                  dtype=torch.float32,
                                  device=DEVICE)
    per_request_top_k = torch.tensor([top_k] * batch_size,
                                     device=DEVICE,
                                     dtype=torch.int64)
    sampling_metadata = create_sampling_metadata(
        all_greedy=False,
        temperature=unit_temperature,
        top_k=per_request_top_k,
    )

    _test_masked_logits(
        rejection_sampler,
        batch_size=batch_size,
        num_draft_tokens=num_draft_tokens,
        vocab_size=vocab_size,
        target_logits=target_logits,
        unmasked_indices=top_k_indices,
        sampling_metadata=sampling_metadata,
    )
@pytest.mark.parametrize("top_p", [0.5, 0.9, 0.99])
def test_top_p(rejection_sampler, top_p):
    """Test rejection sampling with top-p sampling.

    Random target logits are generated, the expected top-p (nucleus) set is
    computed per token position, and the rejection sampler must only emit
    tokens from that set.
    """
    vocab_size = 100
    batch_size = 100
    num_draft_tokens = 3
    num_tokens = batch_size * num_draft_tokens

    # Create logits drawn from a standard normal distribution.
    target_logits = torch.randn((num_tokens, vocab_size), device=DEVICE)
    temperature = torch.ones(batch_size, dtype=torch.float32, device=DEVICE)

    # `temperature` holds one value per request, while `target_logits` holds
    # one row per draft token. Expand it to [num_tokens, 1] so the division
    # scales each row by its own request's temperature. Dividing by the raw
    # [batch_size] tensor would broadcast over the vocab axis instead, which
    # only happens to work when vocab_size == batch_size.
    token_temperature = temperature.repeat_interleave(
        num_draft_tokens).unsqueeze(-1)
    rescaled_logits = target_logits / token_temperature

    # Sort ascending so the cumulative sum accumulates the low-probability
    # tail first; entries whose cumulative mass is <= 1 - top_p are masked.
    logits_sort, logits_idx = rescaled_logits.sort(dim=-1, descending=False)
    probs_sort = logits_sort.softmax(dim=-1)
    probs_sum = probs_sort.cumsum(dim=-1)
    top_p_mask = probs_sum <= 1 - top_p
    # Always keep at least the most likely token.
    top_p_mask[:, -1] = False

    # Collect the surviving (unmasked) token ids for every position.
    top_p_indices = [
        logits_idx[i][~top_p_mask[i]].tolist() for i in range(num_tokens)
    ]

    # Create sampling metadata
    sampling_metadata = create_sampling_metadata(
        all_greedy=False,
        temperature=temperature,
        top_p=torch.tensor([top_p] * batch_size,
                           device=DEVICE,
                           dtype=torch.float32),
    )

    _test_masked_logits(
        rejection_sampler,
        batch_size=batch_size,
        num_draft_tokens=num_draft_tokens,
        vocab_size=vocab_size,
        target_logits=target_logits,
        unmasked_indices=top_p_indices,
        sampling_metadata=sampling_metadata,
    )
vllm/v1/sample/rejection_sampler.py
View file @
ebcebeeb
...
...
@@ -8,6 +8,7 @@ import triton.language as tl
from
vllm.logger
import
init_logger
from
vllm.v1.sample.metadata
import
SamplingMetadata
from
vllm.v1.sample.ops.topk_topp_sampler
import
apply_top_k_top_p
from
vllm.v1.sample.ops.utils
import
compiled_softmax
from
vllm.v1.spec_decode.metadata
import
SpecDecodeMetadata
...
...
@@ -245,25 +246,81 @@ def compute_probs(
return
logits
num_tokens
=
logits
.
shape
[
0
]
batch_size
=
cu_num_draft_tokens
.
shape
[
0
]
expanded_temperature
=
torch
.
empty
(
(
num_tokens
,
1
),
dtype
=
torch
.
float32
,
device
=
logits
.
device
,
)
expand_kernel
[(
batch_size
,
)](
expanded_temperature
,
temperature
=
expand_batch_to_tokens
(
sampling_metadata
.
temperature
,
cu_num_draft_tokens
,
GREEDY_TEMPERATURE
,
# replace_from
1
,
# replace_to
MAX_NUM_TOKENS
=
MAX_SPEC_LEN
,
num_warps
=
1
,
num_tokens
,
replace_from
=
GREEDY_TEMPERATURE
,
replace_to
=
1
,
)
# TODO(woosuk): Consider using in-place op to reduce memory usage.
logits
=
logits
/
temperature
.
unsqueeze
(
-
1
)
# Get expanded top_k and top_p tensors.
top_k
=
None
if
sampling_metadata
.
top_k
is
not
None
:
top_k
=
expand_batch_to_tokens
(
sampling_metadata
.
top_k
,
cu_num_draft_tokens
,
num_tokens
,
)
top_p
=
None
if
sampling_metadata
.
top_p
is
not
None
:
top_p
=
expand_batch_to_tokens
(
sampling_metadata
.
top_p
,
cu_num_draft_tokens
,
num_tokens
,
)
output_prob
=
compiled_softmax
(
logits
,
expanded_temperature
)
# NOTE(woosuk): `apply_top_k_top_p` uses sorting to calculate the mask,
# which is slow for large vocab sizes. This may cause performance issues.
logits
=
apply_top_k_top_p
(
logits
,
top_k
,
top_p
)
output_prob
=
compiled_softmax
(
logits
)
return
output_prob
def expand_batch_to_tokens(
    x: torch.Tensor,  # [batch_size]
    cu_num_tokens: torch.Tensor,  # [batch_size]
    num_tokens: int,
    replace_from: int = 0,
    replace_to: int = 0,
) -> torch.Tensor:
    """Repeat each per-request value of `x` once per token of that request.

    The per-request token counts are given cumulatively in `cu_num_tokens`.
    For instance, with x = [a, b, c] and cu_num_tokens = [2, 5, 6] the total
    is num_tokens = 6 and the result is [a, a, b, b, b, c].

    Args:
        x: [batch_size] tensor holding one value per request.
        cu_num_tokens: [batch_size] tensor of cumulative token counts; entry
            i is the number of tokens up to and including request i.
        num_tokens: Total number of tokens, i.e. the length of the output.
        replace_from: Value in x that should be substituted (default 0).
        replace_to: Value written in place of replace_from (default 0).

    Returns:
        A [num_tokens] tensor allocated via `x.new_empty`.
    """
    num_reqs = x.shape[0]
    assert num_reqs == cu_num_tokens.shape[0]

    # Output buffer; the kernel below is responsible for filling it.
    out = x.new_empty(num_tokens)

    # One kernel program per request.
    launch_grid = (num_reqs, )
    expand_kernel[launch_grid](
        out,
        x,
        cu_num_tokens,
        replace_from,
        replace_to,
        MAX_NUM_TOKENS=MAX_SPEC_LEN,  # To avoid recompilation.
        num_warps=1,
    )
    return out
def
generate_uniform_probs
(
num_tokens
:
int
,
num_draft_tokens
:
list
[
int
],
...
...
vllm/v1/spec_decode/utils.py
View file @
ebcebeeb
...
...
@@ -3,10 +3,7 @@ from vllm.v1.worker.gpu_input_batch import InputBatch
def
is_spec_decode_supported
(
req_id
:
str
,
input_batch
:
InputBatch
)
->
bool
:
if
req_id
in
input_batch
.
top_k_reqs
or
req_id
in
input_batch
.
top_p_reqs
:
# Spec decode doesn't support top_p/top_k sampling.
return
False
elif
req_id
in
input_batch
.
min_p_reqs
:
if
req_id
in
input_batch
.
min_p_reqs
:
# Spec decode doesn't support min_p sampling.
return
False
elif
(
req_id
in
input_batch
.
frequency_penalties_reqs
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment