Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
7a5adad4
Unverified
Commit
7a5adad4
authored
Feb 20, 2026
by
Xin Yang
Committed by
GitHub
Feb 20, 2026
Browse files
[Kernel] Optimize sample_recovered_tokens_kernel (#34974)
Signed-off-by:
Xin Yang
<
xyangx@amazon.com
>
parent
59c62332
Changes
2
Hide whitespace changes
Inline
Side-by-side
Showing
2 changed files
with
178 additions
and
33 deletions
+178
-33
tests/v1/sample/test_rejection_sampler.py
tests/v1/sample/test_rejection_sampler.py
+126
-1
vllm/v1/sample/rejection_sampler.py
vllm/v1/sample/rejection_sampler.py
+52
-32
No files found.
tests/v1/sample/test_rejection_sampler.py
View file @
7a5adad4
...
...
@@ -11,7 +11,11 @@ from tests.v1.sample.utils import create_allowed_token_ids
from
vllm.platforms
import
current_platform
from
vllm.v1.sample.logits_processor
import
LogitsProcessors
from
vllm.v1.sample.metadata
import
SamplingMetadata
from
vllm.v1.sample.rejection_sampler
import
PLACEHOLDER_TOKEN_ID
,
RejectionSampler
from
vllm.v1.sample.rejection_sampler
import
(
PLACEHOLDER_TOKEN_ID
,
RejectionSampler
,
sample_recovered_tokens
,
)
from
vllm.v1.sample.sampler
import
Sampler
,
SamplerOutput
from
vllm.v1.spec_decode.metadata
import
SpecDecodeMetadata
...
...
@@ -518,6 +522,70 @@ def estimate_rejection_sampling_pdf(
return
hist
.
hist
def
native_sample_recovered_tokens
(
max_spec_len
:
int
,
num_draft_tokens
:
list
[
int
],
cu_num_draft_tokens
:
torch
.
Tensor
,
# [batch_size]
draft_token_ids
:
torch
.
Tensor
,
# [num_tokens]
draft_probs
:
torch
.
Tensor
|
None
,
# [num_tokens, vocab_size]
target_probs
:
torch
.
Tensor
,
# [num_tokens, vocab_size]
sampling_metadata
:
SamplingMetadata
,
device
:
torch
.
device
,
)
->
torch
.
Tensor
:
batch_size
=
len
(
num_draft_tokens
)
vocab_size
=
target_probs
.
shape
[
-
1
]
q
=
torch
.
empty
(
(
batch_size
,
vocab_size
),
dtype
=
torch
.
float32
,
device
=
device
,
)
q
.
exponential_
()
states
=
{
i
:
generator
.
get_state
()
for
i
,
generator
in
sampling_metadata
.
generators
.
items
()
}
for
i
,
generator
in
sampling_metadata
.
generators
.
items
():
# Do not generate random numbers for requests with no draft tokens.
# This can be important for reproducibility.
if
num_draft_tokens
[
i
]
>
0
:
q
[
i
].
exponential_
(
generator
=
generator
)
# In order to generate the same exponential later, reset the CUDA RNG
# state because RNG state advances after each call.
generator
.
set_state
(
states
[
i
])
inv_q
=
q
.
reciprocal
()
out
=
torch
.
empty_like
(
draft_token_ids
)
for
req_idx
in
range
(
batch_size
):
start_idx
=
0
if
req_idx
==
0
else
int
(
cu_num_draft_tokens
[
req_idx
-
1
].
item
())
end_idx
=
int
(
cu_num_draft_tokens
[
req_idx
].
item
())
num_tokens
=
end_idx
-
start_idx
for
pos
in
range
(
max_spec_len
):
if
pos
>=
num_tokens
:
continue
token_idx
=
start_idx
+
pos
if
draft_probs
is
None
:
# prob is target_probs[token_idx] except draft_token_id is zeroed
prob
=
target_probs
[
token_idx
].
clone
()
draft_token_id
=
draft_token_ids
[
token_idx
]
prob
[
draft_token_id
]
=
0.0
else
:
prob
=
(
target_probs
[
token_idx
]
-
draft_probs
[
token_idx
]).
clamp_min_
(
0.0
)
score
=
prob
*
inv_q
[
req_idx
]
recovered_id
=
torch
.
argmax
(
score
,
dim
=-
1
)
out
[
token_idx
]
=
recovered_id
return
out
def
_test_masked_logits
(
rejection_sampler
,
batch_size
:
int
,
...
...
@@ -778,3 +846,60 @@ def test_allowed_token_ids(rejection_sampler):
device
=
logits
.
device
,
)
assert
torch
.
equal
(
output
.
sampled_token_ids
,
expected
)
@pytest.mark.parametrize("batch_size", [1, 100])
@pytest.mark.parametrize("vocab_size", [100, 8192, 10000])
@pytest.mark.parametrize("max_spec_len", [1, 3])
@pytest.mark.parametrize("no_draft_probs", [True, False])
def test_sample_recovered_tokens(
    batch_size: int, vocab_size: int, max_spec_len: int, no_draft_probs: bool
):
    """Check `sample_recovered_tokens` against the native reference.

    Both implementations receive identical inputs and the same seeded
    per-request generators; their outputs must match exactly.
    """
    num_tokens = batch_size * max_spec_len
    probs_shape = (num_tokens, vocab_size)

    # Random draft distribution.
    draft_logits = torch.rand(*probs_shape, dtype=torch.float32, device=DEVICE)
    draft_probs = F.softmax(draft_logits, dim=-1)

    # Random target distribution.
    target_logits = torch.rand(*probs_shape, dtype=torch.float32, device=DEVICE)
    target_probs = F.softmax(target_logits, dim=-1)

    # One draft token per position, drawn from the draft distribution.
    sampled = torch.multinomial(draft_probs, num_samples=1)
    draft_token_ids = sampled.to(torch.int32)

    # One independently seeded generator per request.
    generators = {}
    for req in range(batch_size):
        generators[req] = torch.Generator(device=DEVICE).manual_seed(req)

    sampling_metadata = create_sampling_metadata(
        all_greedy=False,
        temperature=torch.ones(batch_size, dtype=torch.float32, device=DEVICE),
        generators=generators,
    )

    per_request_ids = draft_token_ids.reshape(batch_size, max_spec_len).tolist()
    spec_decode_metadata = create_spec_decode_metadata(per_request_ids, target_logits)

    # Both calls below get exactly the same positional arguments.
    shared_args = (
        max_spec_len,
        spec_decode_metadata.num_draft_tokens,
        spec_decode_metadata.cu_num_draft_tokens,
        draft_token_ids,
        None if no_draft_probs else draft_probs,
        target_probs,
        sampling_metadata,
    )

    # The reference runs first; it restores the generator states it used.
    ref_recovered_token_ids = native_sample_recovered_tokens(
        *shared_args, device=DEVICE
    )
    recovered_token_ids = sample_recovered_tokens(*shared_args, device=DEVICE)

    assert torch.equal(recovered_token_ids, ref_recovered_token_ids)
vllm/v1/sample/rejection_sampler.py
View file @
7a5adad4
...
...
@@ -623,16 +623,19 @@ def sample_recovered_tokens(
if
num_draft_tokens
[
i
]
>
0
:
q
[
i
].
exponential_
(
generator
=
generator
)
inv_q
=
q
.
reciprocal
()
recovered_token_ids
=
torch
.
empty_like
(
draft_token_ids
)
BLOCK_SIZE
=
8192
sample_recovered_tokens_kernel
[(
batch_size
,
max_spec_len
)](
recovered_token_ids
,
cu_num_draft_tokens
,
draft_token_ids
,
draft_probs
,
target_probs
,
q
,
inv_
q
,
vocab_size
,
triton
.
next_power_of_2
(
vocab_size
)
,
BLOCK_SIZE
,
NO_DRAFT_PROBS
=
draft_probs
is
None
,
)
return
recovered_token_ids
...
...
@@ -776,9 +779,9 @@ def sample_recovered_tokens_kernel(
draft_token_ids_ptr
,
# [num_tokens]
draft_probs_ptr
,
# [num_tokens, vocab_size] or None
target_probs_ptr
,
# [num_tokens, vocab_size]
q_ptr
,
# [batch_size, vocab_size]
inv_
q_ptr
,
# [batch_size, vocab_size]
vocab_size
,
PADDED_VOCAB
_SIZE
:
tl
.
constexpr
,
BLOCK
_SIZE
:
tl
.
constexpr
,
NO_DRAFT_PROBS
:
tl
.
constexpr
,
):
req_idx
=
tl
.
program_id
(
0
)
...
...
@@ -791,33 +794,50 @@ def sample_recovered_tokens_kernel(
if
pos
>=
num_draft_tokens
:
return
vocab_offset
=
tl
.
arange
(
0
,
PADDED_VOCAB_SIZE
)
token_idx
=
start_idx
+
pos
if
NO_DRAFT_PROBS
:
draft_token_id
=
tl
.
load
(
draft_token_ids_ptr
+
start_idx
+
pos
)
prob
=
tl
.
load
(
target_probs_ptr
+
(
start_idx
+
pos
)
*
vocab_size
+
vocab_offset
,
mask
=
((
vocab_offset
<
vocab_size
)
&
(
vocab_offset
!=
draft_token_id
)),
other
=
0
,
)
else
:
draft_prob
=
tl
.
load
(
draft_probs_ptr
+
(
start_idx
+
pos
)
*
vocab_size
+
vocab_offset
,
mask
=
vocab_offset
<
vocab_size
,
other
=
0
,
)
target_prob
=
tl
.
load
(
target_probs_ptr
+
(
start_idx
+
pos
)
*
vocab_size
+
vocab_offset
,
mask
=
vocab_offset
<
vocab_size
,
other
=
0
,
draft_token_id
=
tl
.
load
(
draft_token_ids_ptr
+
token_idx
)
max_val
=
float
(
"-inf"
)
recovered_id
=
0
for
v
in
range
(
0
,
vocab_size
,
BLOCK_SIZE
):
vocab_offset
=
v
+
tl
.
arange
(
0
,
BLOCK_SIZE
)
vocab_mask
=
vocab_offset
<
vocab_size
if
NO_DRAFT_PROBS
:
prob
=
tl
.
load
(
target_probs_ptr
+
token_idx
*
vocab_size
+
vocab_offset
,
mask
=
(
vocab_mask
&
(
vocab_offset
!=
draft_token_id
)),
other
=
0.0
,
)
else
:
draft_prob
=
tl
.
load
(
draft_probs_ptr
+
token_idx
*
vocab_size
+
vocab_offset
,
mask
=
vocab_mask
,
other
=
0.0
,
)
target_prob
=
tl
.
load
(
target_probs_ptr
+
token_idx
*
vocab_size
+
vocab_offset
,
mask
=
vocab_mask
,
other
=
0.0
,
)
prob
=
tl
.
maximum
(
target_prob
-
draft_prob
,
0.0
)
# NOTE(woosuk): We don't need `prob = prob / tl.sum(prob)` here because
# `tl.argmax` will select the maximum value.
inv_q
=
tl
.
load
(
inv_q_ptr
+
req_idx
*
vocab_size
+
vocab_offset
,
mask
=
vocab_mask
,
other
=
0.0
,
)
prob
=
tl
.
maximum
(
target_prob
-
draft_prob
,
0
)
# NOTE(woosuk): We don't need `prob = prob / tl.sum(prob)` here because
# `tl.argmax` will select the maximum value.
q
=
tl
.
load
(
q_ptr
+
req_idx
*
vocab_size
+
vocab_offset
,
mask
=
vocab_offset
<
vocab_size
,
other
=
float
(
"-inf"
),
)
recovered_id
=
tl
.
argmax
(
prob
/
q
,
axis
=-
1
)
tl
.
store
(
output_token_ids_ptr
+
start_idx
+
pos
,
recovered_id
)
# Local tile reduction
score
=
prob
*
inv_q
local_max
,
local_id
=
tl
.
max
(
score
,
axis
=
0
,
return_indices
=
True
)
if
local_max
>
max_val
:
max_val
=
local_max
recovered_id
=
v
+
local_id
tl
.
store
(
output_token_ids_ptr
+
token_idx
,
recovered_id
)
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment