Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
8980001c
Unverified
Commit
8980001c
authored
Jan 31, 2026
by
caozuoba
Committed by
GitHub
Jan 31, 2026
Browse files
[perf] v1/spec_decode: skip softmax for all-greedy rejection sampling (#32852)
Signed-off-by:
hdj
<
1293066020@qq.com
>
parent
527bcd14
Changes
1
Hide whitespace changes
Inline
Side-by-side
Showing
1 changed file
with
11 additions
and
10 deletions
+11
-10
vllm/v1/sample/rejection_sampler.py
vllm/v1/sample/rejection_sampler.py
+11
-10
No files found.
vllm/v1/sample/rejection_sampler.py
View file @
8980001c
...
...
@@ -136,8 +136,6 @@ class RejectionSampler(nn.Module):
metadata
.
cu_num_draft_tokens
,
sampling_metadata
,
)
# Compute probability distribution from target logits.
target_probs
=
target_logits
.
softmax
(
dim
=-
1
,
dtype
=
torch
.
float32
)
output_token_ids
=
rejection_sample
(
metadata
.
draft_token_ids
,
...
...
@@ -145,7 +143,7 @@ class RejectionSampler(nn.Module):
metadata
.
max_spec_len
,
metadata
.
cu_num_draft_tokens
,
draft_probs
,
target_
prob
s
,
target_
logit
s
,
bonus_token_ids
,
sampling_metadata
,
)
...
...
@@ -353,7 +351,7 @@ def rejection_sample(
# [num_tokens, vocab_size]
draft_probs
:
torch
.
Tensor
|
None
,
# [num_tokens, vocab_size]
target_
prob
s
:
torch
.
Tensor
,
target_
logit
s
:
torch
.
Tensor
,
# [batch_size, 1]
bonus_token_ids
:
torch
.
Tensor
,
sampling_metadata
:
SamplingMetadata
,
...
...
@@ -361,17 +359,16 @@ def rejection_sample(
assert
draft_token_ids
.
ndim
==
1
assert
draft_probs
is
None
or
draft_probs
.
ndim
==
2
assert
cu_num_draft_tokens
.
ndim
==
1
assert
target_
prob
s
.
ndim
==
2
assert
target_
logit
s
.
ndim
==
2
batch_size
=
len
(
num_draft_tokens
)
num_tokens
=
draft_token_ids
.
shape
[
0
]
vocab_size
=
target_
prob
s
.
shape
[
-
1
]
device
=
target_
prob
s
.
device
vocab_size
=
target_
logit
s
.
shape
[
-
1
]
device
=
target_
logit
s
.
device
assert
draft_token_ids
.
is_contiguous
()
assert
draft_probs
is
None
or
draft_probs
.
is_contiguous
()
assert
target_probs
.
is_contiguous
()
assert
bonus_token_ids
.
is_contiguous
()
assert
target_
prob
s
.
shape
==
(
num_tokens
,
vocab_size
)
assert
target_
logit
s
.
shape
==
(
num_tokens
,
vocab_size
)
# Create output buffer.
output_token_ids
=
torch
.
full
(
...
...
@@ -387,7 +384,7 @@ def rejection_sample(
is_greedy
=
sampling_metadata
.
temperature
==
GREEDY_TEMPERATURE
if
not
sampling_metadata
.
all_random
:
# Rejection sampling for greedy sampling requests.
target_argmax
=
target_
prob
s
.
argmax
(
dim
=-
1
)
target_argmax
=
target_
logit
s
.
argmax
(
dim
=-
1
)
rejection_greedy_sample_kernel
[(
batch_size
,)](
output_token_ids
,
cu_num_draft_tokens
,
...
...
@@ -400,6 +397,10 @@ def rejection_sample(
if
sampling_metadata
.
all_greedy
:
return
output_token_ids
# Compute probability distribution from target logits.
target_probs
=
target_logits
.
softmax
(
dim
=-
1
,
dtype
=
torch
.
float32
)
assert
target_probs
.
is_contiguous
()
# Generate uniform probabilities for rejection sampling.
# [num_tokens]
uniform_probs
=
generate_uniform_probs
(
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment