Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
cd32d6f5
Unverified
Commit
cd32d6f5
authored
Mar 12, 2026
by
Nick Hill
Committed by
GitHub
Mar 13, 2026
Browse files
[Model Runner V2] Some code simplification (#36929)
Signed-off-by:
Nick Hill
<
nickhill123@gmail.com
>
parent
aaa3092f
Changes
3
Hide whitespace changes
Inline
Side-by-side
Showing
3 changed files
with
17 additions
and
59 deletions
+17
-59
vllm/config/speculative.py
vllm/config/speculative.py
+1
-4
vllm/v1/worker/gpu/sample/gumbel.py
vllm/v1/worker/gpu/sample/gumbel.py
+3
-13
vllm/v1/worker/gpu/spec_decode/rejection_sampler.py
vllm/v1/worker/gpu/spec_decode/rejection_sampler.py
+13
-42
No files found.
vllm/config/speculative.py
View file @
cd32d6f5
...
...
@@ -57,10 +57,7 @@ SpeculativeMethod = Literal[
EagleModelTypes
,
NgramGPUTypes
,
]
RejectionSampleMethod
=
Literal
[
"strict"
,
"probabilistic"
,
]
RejectionSampleMethod
=
Literal
[
"strict"
,
"probabilistic"
]
@
config
...
...
vllm/v1/worker/gpu/sample/gumbel.py
View file @
cd32d6f5
...
...
@@ -81,7 +81,7 @@ def _gumbel_sample_kernel(
logits
=
logits
.
to
(
tl
.
float32
)
temp
=
tl
.
load
(
temp_ptr
+
req_state_idx
).
to
(
tl
.
float32
)
if
(
temp
!=
0.0
)
and
APPLY_TEMPERATURE
:
if
temp
!=
0.0
and
APPLY_TEMPERATURE
:
# Apply temperature.
# NOTE(woosuk): Match the behavior of _temperature_kernel.
# E.g., if the kernel uses tl.div_rn, we should use tl.div_rn here too.
...
...
@@ -127,18 +127,8 @@ def gumbel_sample(
num_tokens
,
vocab_size
=
logits
.
shape
BLOCK_SIZE
=
1024
num_blocks
=
triton
.
cdiv
(
vocab_size
,
BLOCK_SIZE
)
local_argmax
=
torch
.
empty
(
num_tokens
,
num_blocks
,
dtype
=
torch
.
int64
,
device
=
logits
.
device
,
)
local_max
=
torch
.
empty
(
num_tokens
,
num_blocks
,
dtype
=
torch
.
float32
,
device
=
logits
.
device
,
)
local_argmax
=
logits
.
new_empty
(
num_tokens
,
num_blocks
,
dtype
=
torch
.
int64
)
local_max
=
logits
.
new_empty
(
num_tokens
,
num_blocks
,
dtype
=
torch
.
float32
)
_gumbel_sample_kernel
[(
num_tokens
,
num_blocks
)](
local_argmax
,
local_argmax
.
stride
(
0
),
...
...
vllm/v1/worker/gpu/spec_decode/rejection_sampler.py
View file @
cd32d6f5
...
...
@@ -53,17 +53,8 @@ def strict_rejection_sample(
num_speculative_steps
,
)
->
tuple
[
torch
.
Tensor
,
torch
.
Tensor
]:
num_reqs
=
cu_num_logits
.
shape
[
0
]
-
1
sampled
=
torch
.
empty
(
num_reqs
,
num_speculative_steps
+
1
,
dtype
=
target_sampled
.
dtype
,
device
=
target_sampled
.
device
,
)
num_sampled
=
torch
.
empty
(
num_reqs
,
dtype
=
torch
.
int32
,
device
=
target_sampled
.
device
,
)
sampled
=
target_sampled
.
new_empty
(
num_reqs
,
num_speculative_steps
+
1
)
num_sampled
=
target_sampled
.
new_empty
(
num_reqs
,
dtype
=
torch
.
int32
)
_strict_rejection_sample_kernel
[(
num_reqs
,)](
sampled
,
sampled
.
stride
(
0
),
...
...
@@ -216,12 +207,11 @@ def probabilistic_rejection_sample(
pos
:
torch
.
Tensor
,
# [num_reqs]
idx_mapping
:
torch
.
Tensor
,
temperature
,
seed
s
,
num_speculative_steps
,
temperature
:
torch
.
Tensor
,
seed
:
torch
.
Tensor
,
num_speculative_steps
:
int
,
)
->
tuple
[
torch
.
Tensor
,
torch
.
Tensor
]:
num_reqs
=
cu_num_logits
.
shape
[
0
]
-
1
device
=
target_logits
.
device
vocab_size
=
target_logits
.
shape
[
-
1
]
# Compute target and draft probs.
...
...
@@ -230,18 +220,11 @@ def probabilistic_rejection_sample(
# Rejection sample.
# [num_reqs, num_speculative_steps + 1]
sampled
=
torch
.
empty
(
num_reqs
,
num_speculative_steps
+
1
,
dtype
=
torch
.
int64
,
device
=
device
,
sampled
=
draft_sampled
.
new_empty
(
num_reqs
,
num_speculative_steps
+
1
,
dtype
=
torch
.
int64
)
# [num_reqs]
rejected_steps
=
torch
.
empty
(
num_reqs
,
dtype
=
torch
.
int64
,
device
=
device
,
)
rejected_steps
=
sampled
.
new_empty
(
num_reqs
)
_probabilistic_rejection_sample_kernel
[(
num_reqs
,)](
sampled
,
sampled
.
stride
(
0
),
...
...
@@ -255,25 +238,16 @@ def probabilistic_rejection_sample(
cu_num_logits
,
pos
,
idx_mapping
,
seed
s
,
seed
,
num_warps
=
1
,
)
# Compute the logits and positions to resample the rejected/bonus
# tokens from.
# [num_reqs, vocab_size]
residual_logits
=
torch
.
empty
(
num_reqs
,
vocab_size
,
dtype
=
target_logits
.
dtype
,
device
=
device
,
)
residual_logits
=
target_logits
.
new_empty
(
num_reqs
,
vocab_size
)
# [num_reqs]
residual_pos
=
torch
.
empty
(
num_reqs
,
dtype
=
pos
.
dtype
,
device
=
device
,
)
residual_pos
=
pos
.
new_empty
(
num_reqs
)
BLOCK_SIZE
=
1024
num_blocks
=
triton
.
cdiv
(
vocab_size
,
BLOCK_SIZE
)
_compute_residual_logits_kernel
[(
num_reqs
,
num_blocks
)](
...
...
@@ -299,7 +273,7 @@ def probabilistic_rejection_sample(
residual_logits
,
idx_mapping
,
temperature
,
seed
s
,
seed
,
residual_pos
,
apply_temperature
=
False
,
)
...
...
@@ -331,10 +305,7 @@ class RejectionSampler:
num_nans
=
get_num_nans
(
logits
)
if
self
.
sampler
.
compute_nans
else
None
if
self
.
use_strict_rejection_sampling
:
sampler_output
=
self
.
sampler
(
logits
,
input_batch
,
)
sampler_output
=
self
.
sampler
(
logits
,
input_batch
)
logprobs_tensors
=
sampler_output
.
logprobs_tensors
sampled
,
num_sampled
=
strict_rejection_sample
(
sampler_output
.
sampled_token_ids
.
view
(
-
1
),
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment