Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
04244fd0
Unverified
Commit
04244fd0
authored
Mar 18, 2026
by
Giancarlo Delfin
Committed by
GitHub
Mar 18, 2026
Browse files
[Model Runner V2] Spec decode rejection sampler greedy support (#37238)
Signed-off-by:
Giancarlo Delfin
<
gdelfin@inferact.ai
>
parent
9482b0b0
Changes
2
Show whitespace changes
Inline
Side-by-side
Showing
2 changed files
with
207 additions
and
71 deletions
+207
-71
vllm/v1/worker/gpu/model_runner.py
vllm/v1/worker/gpu/model_runner.py
+1
-3
vllm/v1/worker/gpu/spec_decode/rejection_sampler.py
vllm/v1/worker/gpu/spec_decode/rejection_sampler.py
+206
-68
No files found.
vllm/v1/worker/gpu/model_runner.py
View file @
04244fd0
...
...
@@ -821,9 +821,7 @@ class GPUModelRunner(LoRAModelRunnerMixin):
logits
,
input_batch
,
# Draft logits are needed for probabilistic rejection sampling.
self
.
req_states
.
draft_logits
[
input_batch
.
idx_mapping
]
if
self
.
req_states
.
draft_logits
is
not
None
else
None
,
self
.
req_states
.
draft_logits
,
)
# Get the number of sampled and rejected tokens.
...
...
vllm/v1/worker/gpu/spec_decode/rejection_sampler.py
View file @
04244fd0
...
...
@@ -68,55 +68,158 @@ def strict_rejection_sample(
@
triton
.
jit
def
_probabilistic_rejection_sample_kernel
(
def
_gather_draft_logits_and_target_argmax_kernel
(
local_target_argmax_ptr
,
local_target_argmax_stride
,
local_target_max_ptr
,
local_target_max_stride
,
# [num_logits, V]
out_draft_logits_ptr
,
out_draft_logits_stride
,
# [num_logits, V]
target_logits_ptr
,
target_logits_stride
,
# [max_num_reqs, num_speculative_steps, V]
draft_logits_ptr
,
draft_logits_stride_0
,
draft_logits_stride_1
,
# [num_logits]
expanded_idx_mapping_ptr
,
# [num_logits]
expanded_local_pos_ptr
,
# [max_num_reqs]
temp_ptr
,
vocab_size
,
num_speculative_steps
,
BLOCK_SIZE
:
tl
.
constexpr
,
):
logit_idx
=
tl
.
program_id
(
0
)
req_state_idx
=
tl
.
load
(
expanded_idx_mapping_ptr
+
logit_idx
)
draft_step_idx
=
tl
.
load
(
expanded_local_pos_ptr
+
logit_idx
)
block_idx
=
tl
.
program_id
(
1
)
block_offsets
=
block_idx
*
BLOCK_SIZE
+
tl
.
arange
(
0
,
BLOCK_SIZE
)
mask
=
block_offsets
<
vocab_size
temp
=
tl
.
load
(
temp_ptr
+
req_state_idx
).
to
(
tl
.
float32
)
if
temp
==
0.0
:
# Greedy sampling. Get the target logits argmax.
target_logits
=
tl
.
load
(
target_logits_ptr
+
logit_idx
*
target_logits_stride
+
block_offsets
,
mask
=
mask
,
other
=
float
(
"-inf"
),
).
to
(
tl
.
float32
)
value
,
idx
=
tl
.
max
(
target_logits
,
axis
=
0
,
return_indices
=
True
)
token_id
=
block_idx
*
BLOCK_SIZE
+
idx
tl
.
store
(
local_target_argmax_ptr
+
logit_idx
*
local_target_argmax_stride
+
block_idx
,
token_id
,
)
tl
.
store
(
local_target_max_ptr
+
logit_idx
*
local_target_max_stride
+
block_idx
,
value
,
)
elif
draft_step_idx
<
num_speculative_steps
:
draft_logits
=
tl
.
load
(
draft_logits_ptr
+
req_state_idx
*
draft_logits_stride_0
+
draft_step_idx
*
draft_logits_stride_1
+
block_offsets
,
mask
=
mask
,
other
=
float
(
"-inf"
),
).
to
(
tl
.
float32
)
tl
.
store
(
out_draft_logits_ptr
+
logit_idx
*
out_draft_logits_stride
+
block_offsets
,
draft_logits
,
mask
=
mask
,
)
@
triton
.
jit
def
_probabilistic_rejection_kernel
(
# [num_reqs, num_speculative_steps + 1]
sampled_ptr
,
sampled_stride
,
# [num_reqs]
rejected_steps_ptr
,
# [num_reqs]
rejected_pos_ptr
,
# [num_logits]
draft_sampled_ptr
,
# [num_logits, V]
target_probs_ptr
,
target_probs_stride
,
# [num_reqs, num_speculative_steps, V]
# [num_logits, V]
draft_probs_ptr
,
draft_probs_stride_0
,
draft_probs_stride_1
,
draft_probs_stride
,
# [num_logits, num_blocks]
local_target_argmax_ptr
,
local_target_argmax_stride
,
# [num_logits, num_blocks]
local_target_max_ptr
,
local_target_max_stride
,
# [num_reqs + 1]
cu_num_logits_ptr
,
# [num_logits]
pos_ptr
,
# [num_reqs]
idx_mapping_ptr
,
# [num_reqs]
# [max_num_reqs]
temp_ptr
,
# [max_num_reqs]
seeds_ptr
,
NUM_BLOCKS
:
tl
.
constexpr
,
PADDED_NUM_BLOCKS
:
tl
.
constexpr
,
):
req_idx
=
tl
.
program_id
(
0
)
start_idx
=
tl
.
load
(
cu_num_logits_ptr
+
req_idx
)
num_tokens
=
tl
.
load
(
cu_num_logits_ptr
+
req_idx
+
1
)
-
start_idx
seed
=
tl
.
load
(
seeds_ptr
+
tl
.
load
(
idx_mapping_ptr
+
req_idx
))
req_state_idx
=
tl
.
load
(
idx_mapping_ptr
+
req_idx
)
seed
=
tl
.
load
(
seeds_ptr
+
req_state_idx
)
temp
=
tl
.
load
(
temp_ptr
+
req_state_idx
).
to
(
tl
.
float32
)
rejected_step
=
0
accepted
=
True
for
i
in
range
(
num_tokens
-
1
):
if
accepted
:
draft_sampled
=
tl
.
load
(
draft_sampled_ptr
+
start_idx
+
i
+
1
)
logit_idx
=
start_idx
+
i
draft_sampled
=
tl
.
load
(
draft_sampled_ptr
+
logit_idx
+
1
)
if
temp
==
0.0
:
# Greedy sampling. Only accept the sampled draft token if
# it exactly matches the target argmax.
block_offsets
=
tl
.
arange
(
0
,
PADDED_NUM_BLOCKS
)
block_mask
=
block_offsets
<
NUM_BLOCKS
local_max
=
tl
.
load
(
local_target_max_ptr
+
logit_idx
*
local_target_max_stride
+
block_offsets
,
mask
=
block_mask
,
other
=
float
(
"-inf"
),
)
max_block
=
tl
.
argmax
(
local_max
,
axis
=
0
)
target_argmax
=
tl
.
load
(
local_target_argmax_ptr
+
logit_idx
*
local_target_argmax_stride
+
max_block
)
accepted
&=
target_argmax
==
draft_sampled
else
:
target_prob
=
tl
.
load
(
target_probs_ptr
+
(
start_idx
+
i
)
*
target_probs_stride
+
draft_sampled
target_probs_ptr
+
logit_idx
*
target_probs_stride
+
draft_sampled
)
draft_prob
=
tl
.
load
(
draft_probs_ptr
+
req_idx
*
draft_probs_stride_0
+
i
*
draft_probs_stride_1
+
draft_sampled
draft_probs_ptr
+
logit_idx
*
draft_probs_stride
+
draft_sampled
)
pos
=
tl
.
load
(
pos_ptr
+
start_idx
+
i
)
pos
=
tl
.
load
(
pos_ptr
+
logit_idx
)
u
=
tl
.
sum
(
tl
.
rand
(
seed
,
pos
+
tl
.
arange
(
0
,
1
)))
accepted
&=
target_prob
>
u
*
draft_prob
tl
.
store
(
sampled_ptr
+
req_idx
*
sampled_stride
+
i
,
draft_sampled
)
rejected_step
+=
accepted
tl
.
store
(
rejected_steps_ptr
+
req_idx
,
rejected_step
)
pos_val
=
tl
.
load
(
pos_ptr
+
start_idx
+
rejected_step
)
tl
.
store
(
rejected_pos_ptr
+
req_idx
,
pos_val
)
@
triton
.
jit
...
...
@@ -124,63 +227,60 @@ def _compute_residual_logits_kernel(
# [num_reqs, V]
residual_logits_ptr
,
residual_logits_stride
,
# [num_reqs]
residual_pos_ptr
,
# [num_logits, V]
target_logits_ptr
,
target_logits_stride
,
# [num_logits, V]
target_probs_ptr
,
target_probs_stride
,
# [num_reqs, num_speculative_steps, V]
# [num_logits, V]
draft_probs_ptr
,
draft_probs_stride_0
,
draft_probs_stride_1
,
draft_probs_stride
,
# [num_logits, V]
target_logits_ptr
,
target_logits_stride
,
# [num_reqs]
rejected_step_ptr
,
# [num_reqs + 1]
cu_num_logits_ptr
,
# [num_logits]
pos_ptr
,
# [num_reqs]
idx_mapping_ptr
,
# [max_num_reqs]
temp_ptr
,
vocab_size
,
BLOCK_SIZE
:
tl
.
constexpr
,
):
req_idx
=
tl
.
program_id
(
0
)
block_idx
=
tl
.
program_id
(
1
)
req_state_idx
=
tl
.
load
(
idx_mapping_ptr
+
req_idx
)
start_idx
=
tl
.
load
(
cu_num_logits_ptr
+
req_idx
)
end_idx
=
tl
.
load
(
cu_num_logits_ptr
+
req_idx
+
1
)
rejected_draft_step
=
tl
.
load
(
rejected_step_ptr
+
req_idx
)
rejected_logit_idx
=
start_idx
+
rejected_draft_step
rejected_logit_idx
=
start_idx
+
tl
.
load
(
rejected_step_ptr
+
req_idx
)
temp
=
tl
.
load
(
temp_ptr
+
req_state_idx
).
to
(
tl
.
float32
)
block_offsets
=
block_idx
*
BLOCK_SIZE
+
tl
.
arange
(
0
,
BLOCK_SIZE
)
mask
=
block_offsets
<
vocab_size
if
rejected_logit_idx
<
end_idx
-
1
:
if
temp
==
0.0
or
(
rejected_logit_idx
==
end_idx
-
1
):
# Greedy sampling / bonus token. In either case, use the
# target logits directly to reduce numerical error.
residual_logits
=
tl
.
load
(
target_logits_ptr
+
rejected_logit_idx
*
target_logits_stride
+
block_offsets
,
mask
=
mask
,
other
=
float
(
"-inf"
),
)
else
:
target_probs
=
tl
.
load
(
target_probs_ptr
+
rejected_logit_idx
*
target_probs_stride
+
block_offsets
,
mask
=
mask
,
other
=
0.0
,
)
draft_probs
=
tl
.
load
(
draft_probs_ptr
+
req_idx
*
draft_probs_stride_0
+
rejected_draft_step
*
draft_probs_stride_1
+
block_offsets
,
draft_probs_ptr
+
rejected_logit_idx
*
draft_probs_stride
+
block_offsets
,
mask
=
mask
,
other
=
0.0
,
)
residual_probs
=
tl
.
maximum
(
target_probs
-
draft_probs
,
0.0
)
residual_logits
=
tl
.
log
(
residual_probs
)
else
:
# This is a bonus token. Directly return the target logits.
residual_logits
=
tl
.
load
(
target_logits_ptr
+
rejected_logit_idx
*
target_logits_stride
+
block_offsets
,
mask
=
mask
,
other
=
0.0
,
)
tl
.
store
(
residual_logits_ptr
+
req_idx
*
residual_logits_stride
+
block_offsets
,
...
...
@@ -188,18 +288,13 @@ def _compute_residual_logits_kernel(
mask
=
mask
,
)
# First block computes the residual logit positions.
if
block_idx
==
0
:
pos_val
=
tl
.
load
(
pos_ptr
+
rejected_logit_idx
)
tl
.
store
(
residual_pos_ptr
+
req_idx
,
pos_val
)
def
probabilistic_rejection_sample
(
# [num_draft_tokens + num_reqs, V]
# [num_logits, V]
target_logits
:
torch
.
Tensor
,
# [num_reqs, num_speculative_steps, V]
# [max_num_reqs, num_speculative_steps, V]
draft_logits
:
torch
.
Tensor
,
# [num_draft_tokens + num_reqs]
# [num_logits]
draft_sampled
:
torch
.
Tensor
,
# [num_reqs + 1]
cu_num_logits
:
torch
.
Tensor
,
...
...
@@ -207,16 +302,53 @@ def probabilistic_rejection_sample(
pos
:
torch
.
Tensor
,
# [num_reqs]
idx_mapping
:
torch
.
Tensor
,
# [num_logits]
expanded_idx_mapping
:
torch
.
Tensor
,
# [num_logits]
expanded_local_pos
:
torch
.
Tensor
,
# [max_num_reqs]
temperature
:
torch
.
Tensor
,
# [max_num_reqs]
seed
:
torch
.
Tensor
,
num_speculative_steps
:
int
,
)
->
tuple
[
torch
.
Tensor
,
torch
.
Tensor
]:
num_reqs
=
cu_num_logits
.
shape
[
0
]
-
1
vocab_size
=
target_logits
.
shape
[
-
1
]
num_logits
,
vocab_size
=
target_logits
.
shape
BLOCK_SIZE
=
1024
num_blocks
=
triton
.
cdiv
(
vocab_size
,
BLOCK_SIZE
)
# Gather draft logits and target argmax for greedy sampling.
gathered_draft_logits
=
target_logits
.
new_empty
(
target_logits
.
shape
)
local_target_argmax
=
target_logits
.
new_empty
(
num_logits
,
num_blocks
,
dtype
=
torch
.
int64
)
local_target_max
=
target_logits
.
new_empty
(
num_logits
,
num_blocks
,
dtype
=
torch
.
float32
)
_gather_draft_logits_and_target_argmax_kernel
[(
num_logits
,
num_blocks
)](
local_target_argmax
,
local_target_argmax
.
stride
(
0
),
local_target_max
,
local_target_max
.
stride
(
0
),
gathered_draft_logits
,
gathered_draft_logits
.
stride
(
0
),
target_logits
,
target_logits
.
stride
(
0
),
draft_logits
,
draft_logits
.
stride
(
0
),
draft_logits
.
stride
(
1
),
expanded_idx_mapping
,
expanded_local_pos
,
temperature
,
vocab_size
,
num_speculative_steps
,
BLOCK_SIZE
=
BLOCK_SIZE
,
)
# Compute target and draft probs.
target_probs
=
torch
.
softmax
(
target_logits
,
dim
=-
1
)
draft_probs
=
torch
.
softmax
(
draft_logits
,
dim
=-
1
)
draft_probs
=
torch
.
softmax
(
gathered_draft_logits
,
dim
=-
1
)
# Rejection sample.
# [num_reqs, num_speculative_steps + 1]
...
...
@@ -225,45 +357,49 @@ def probabilistic_rejection_sample(
)
# [num_reqs]
rejected_steps
=
sampled
.
new_empty
(
num_reqs
)
_probabilistic_rejection_sample_kernel
[(
num_reqs
,)](
# [num_reqs]
rejected_pos
=
pos
.
new_empty
(
num_reqs
)
_probabilistic_rejection_kernel
[(
num_reqs
,)](
sampled
,
sampled
.
stride
(
0
),
rejected_steps
,
rejected_pos
,
draft_sampled
,
target_probs
,
target_probs
.
stride
(
0
),
draft_probs
,
draft_probs
.
stride
(
0
),
draft_probs
.
stride
(
1
),
local_target_argmax
,
local_target_argmax
.
stride
(
0
),
local_target_max
,
local_target_max
.
stride
(
0
),
cu_num_logits
,
pos
,
idx_mapping
,
temperature
,
seed
,
num_warps
=
1
,
NUM_BLOCKS
=
num_blocks
,
PADDED_NUM_BLOCKS
=
triton
.
next_power_of_2
(
num_blocks
),
)
# Compute the logits and positions to resample the rejected/bonus
# tokens from.
# [num_reqs, vocab_size]
residual_logits
=
target_logits
.
new_empty
(
num_reqs
,
vocab_size
)
# [num_reqs]
residual_pos
=
pos
.
new_empty
(
num_reqs
)
BLOCK_SIZE
=
1024
num_blocks
=
triton
.
cdiv
(
vocab_size
,
BLOCK_SIZE
)
_compute_residual_logits_kernel
[(
num_reqs
,
num_blocks
)](
residual_logits
,
residual_logits
.
stride
(
0
),
residual_pos
,
target_logits
,
target_logits
.
stride
(
0
),
target_probs
,
target_probs
.
stride
(
0
),
draft_probs
,
draft_probs
.
stride
(
0
),
draft_probs
.
stride
(
1
),
target_logits
,
target_logits
.
stride
(
0
),
rejected_steps
,
cu_num_logits
,
pos
,
idx_mapping
,
temperature
,
vocab_size
,
BLOCK_SIZE
=
BLOCK_SIZE
,
)
...
...
@@ -274,7 +410,7 @@ def probabilistic_rejection_sample(
idx_mapping
,
temperature
,
seed
,
residual_pos
,
rejected_pos
,
apply_temperature
=
False
,
)
sampled
.
scatter_
(
1
,
rejected_steps
.
unsqueeze
(
1
),
resampled
.
unsqueeze
(
1
))
...
...
@@ -333,6 +469,8 @@ class RejectionSampler:
input_batch
.
cu_num_logits
,
pos
,
input_batch
.
idx_mapping
,
input_batch
.
expanded_idx_mapping
,
input_batch
.
expanded_local_pos
,
self
.
sampler
.
sampling_states
.
temperature
.
gpu
,
self
.
sampler
.
sampling_states
.
seeds
.
gpu
,
self
.
num_speculative_steps
,
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment