Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
f32c7d6f
Unverified
Commit
f32c7d6f
authored
Nov 24, 2025
by
Woosuk Kwon
Committed by
GitHub
Nov 24, 2025
Browse files
[Model Runner V2] Simplify Eagle bookkeeping with num_rejected (#29347)
Signed-off-by:
Woosuk Kwon
<
woosuk.kwon@berkeley.edu
>
parent
3cfa63ad
Changes
4
Show whitespace changes
Inline
Side-by-side
Showing
4 changed files
with
50 additions
and
30 deletions
+50
-30
vllm/v1/worker/gpu/input_batch.py
vllm/v1/worker/gpu/input_batch.py
+6
-13
vllm/v1/worker/gpu/model_runner.py
vllm/v1/worker/gpu/model_runner.py
+22
-8
vllm/v1/worker/gpu/spec_decode/eagle.py
vllm/v1/worker/gpu/spec_decode/eagle.py
+10
-9
vllm/v1/worker/gpu/spec_decode/rejection_sample.py
vllm/v1/worker/gpu/spec_decode/rejection_sample.py
+12
-0
No files found.
vllm/v1/worker/gpu/input_batch.py
View file @
f32c7d6f
...
@@ -344,8 +344,8 @@ def _post_update_kernel(
...
@@ -344,8 +344,8 @@ def _post_update_kernel(
sampled_tokens_ptr
,
sampled_tokens_ptr
,
sampled_tokens_stride
,
sampled_tokens_stride
,
num_sampled_ptr
,
num_sampled_ptr
,
num_rejected_ptr
,
query_start_loc_ptr
,
query_start_loc_ptr
,
cu_num_logits_ptr
,
):
):
req_id
=
tl
.
program_id
(
0
)
req_id
=
tl
.
program_id
(
0
)
req_state_idx
=
tl
.
load
(
idx_mapping_ptr
+
req_id
)
req_state_idx
=
tl
.
load
(
idx_mapping_ptr
+
req_id
)
...
@@ -360,17 +360,10 @@ def _post_update_kernel(
...
@@ -360,17 +360,10 @@ def _post_update_kernel(
query_start
=
tl
.
load
(
query_start_loc_ptr
+
req_id
)
query_start
=
tl
.
load
(
query_start_loc_ptr
+
req_id
)
query_end
=
tl
.
load
(
query_start_loc_ptr
+
req_id
+
1
)
query_end
=
tl
.
load
(
query_start_loc_ptr
+
req_id
+
1
)
query_len
=
query_end
-
query_start
query_len
=
query_end
-
query_start
num_rejected
=
tl
.
load
(
num_rejected_ptr
+
req_id
)
num_computed
=
tl
.
load
(
num_computed_tokens_ptr
+
req_state_idx
)
num_computed
=
tl
.
load
(
num_computed_tokens_ptr
+
req_state_idx
)
num_computed
+=
query_len
num_computed
+=
query_len
-
num_rejected
# Consider the rejected tokens in spec decoding.
if
num_sampled
>
0
:
# NOTE(woosuk): We must skip num_sampled == 0 to account for chunked prefills.
logits_start
=
tl
.
load
(
cu_num_logits_ptr
+
req_id
)
logits_end
=
tl
.
load
(
cu_num_logits_ptr
+
req_id
+
1
)
num_logits
=
logits_end
-
logits_start
num_rejected
=
num_logits
-
num_sampled
num_computed
-=
num_rejected
tl
.
store
(
num_computed_tokens_ptr
+
req_state_idx
,
num_computed
)
tl
.
store
(
num_computed_tokens_ptr
+
req_state_idx
,
num_computed
)
...
@@ -385,10 +378,10 @@ def post_update(
...
@@ -385,10 +378,10 @@ def post_update(
sampled_tokens
:
torch
.
Tensor
,
sampled_tokens
:
torch
.
Tensor
,
# [num_reqs]
# [num_reqs]
num_sampled
:
torch
.
Tensor
,
num_sampled
:
torch
.
Tensor
,
# [num_reqs]
num_rejected
:
torch
.
Tensor
,
# [num_reqs + 1]
# [num_reqs + 1]
query_start_loc
:
torch
.
Tensor
,
query_start_loc
:
torch
.
Tensor
,
# [num_reqs + 1]
cu_num_logits
:
torch
.
Tensor
,
)
->
None
:
)
->
None
:
num_reqs
=
idx_mapping
.
shape
[
0
]
num_reqs
=
idx_mapping
.
shape
[
0
]
_post_update_kernel
[(
num_reqs
,)](
_post_update_kernel
[(
num_reqs
,)](
...
@@ -398,7 +391,7 @@ def post_update(
...
@@ -398,7 +391,7 @@ def post_update(
sampled_tokens
,
sampled_tokens
,
sampled_tokens
.
stride
(
0
),
sampled_tokens
.
stride
(
0
),
num_sampled
,
num_sampled
,
num_rejected
,
query_start_loc
,
query_start_loc
,
cu_num_logits
,
num_warps
=
1
,
num_warps
=
1
,
)
)
vllm/v1/worker/gpu/model_runner.py
View file @
f32c7d6f
...
@@ -46,7 +46,10 @@ from vllm.v1.worker.gpu.input_batch import (
...
@@ -46,7 +46,10 @@ from vllm.v1.worker.gpu.input_batch import (
)
)
from
vllm.v1.worker.gpu.sampler
import
Sampler
,
compute_prompt_logprobs
from
vllm.v1.worker.gpu.sampler
import
Sampler
,
compute_prompt_logprobs
from
vllm.v1.worker.gpu.spec_decode
import
init_speculator
from
vllm.v1.worker.gpu.spec_decode
import
init_speculator
from
vllm.v1.worker.gpu.spec_decode.rejection_sample
import
rejection_sample
from
vllm.v1.worker.gpu.spec_decode.rejection_sample
import
(
get_num_rejected
,
rejection_sample
,
)
from
vllm.v1.worker.gpu.states
import
RequestState
,
SamplingMetadata
from
vllm.v1.worker.gpu.states
import
RequestState
,
SamplingMetadata
from
vllm.v1.worker.gpu.structured_outputs
import
apply_grammar_bitmask
from
vllm.v1.worker.gpu.structured_outputs
import
apply_grammar_bitmask
from
vllm.v1.worker.kv_connector_model_runner_mixin
import
KVConnectorModelRunnerMixin
from
vllm.v1.worker.kv_connector_model_runner_mixin
import
KVConnectorModelRunnerMixin
...
@@ -311,12 +314,14 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
...
@@ -311,12 +314,14 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
device
=
self
.
device
,
device
=
self
.
device
,
)
)
num_sampled
=
torch
.
ones
(
num_reqs
,
dtype
=
torch
.
int32
,
device
=
self
.
device
)
num_sampled
=
torch
.
ones
(
num_reqs
,
dtype
=
torch
.
int32
,
device
=
self
.
device
)
num_rejected
=
torch
.
zeros
(
num_reqs
,
dtype
=
torch
.
int32
,
device
=
self
.
device
)
self
.
propose_draft
(
self
.
propose_draft
(
input_batch
=
input_batch
,
input_batch
=
input_batch
,
sampling_metadata
=
sampling_metadata
,
sampling_metadata
=
sampling_metadata
,
last_hidden_states
=
hidden_states
,
last_hidden_states
=
hidden_states
,
aux_hidden_states
=
aux_hidden_states
,
aux_hidden_states
=
aux_hidden_states
,
num_sampled
=
num_sampled
,
num_sampled
=
num_sampled
,
num_rejected
=
num_rejected
,
)
)
@
torch
.
inference_mode
()
@
torch
.
inference_mode
()
...
@@ -606,7 +611,7 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
...
@@ -606,7 +611,7 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
input_batch
:
InputBatch
,
input_batch
:
InputBatch
,
sampling_metadata
:
SamplingMetadata
,
sampling_metadata
:
SamplingMetadata
,
grammar_output
:
GrammarOutput
|
None
,
grammar_output
:
GrammarOutput
|
None
,
)
->
tuple
[
SamplerOutput
,
torch
.
Tensor
]:
)
->
tuple
[
SamplerOutput
,
torch
.
Tensor
,
torch
.
Tensor
]:
sample_hidden_states
=
hidden_states
[
input_batch
.
logits_indices
]
sample_hidden_states
=
hidden_states
[
input_batch
.
logits_indices
]
logits
=
self
.
model
.
compute_logits
(
sample_hidden_states
)
logits
=
self
.
model
.
compute_logits
(
sample_hidden_states
)
if
grammar_output
is
not
None
:
if
grammar_output
is
not
None
:
...
@@ -632,6 +637,7 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
...
@@ -632,6 +637,7 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
# No draft tokens (common case).
# No draft tokens (common case).
# 0 if chunked-prefilling, 1 if not.
# 0 if chunked-prefilling, 1 if not.
num_sampled
=
(
~
is_chunked_prefilling
).
int
()
num_sampled
=
(
~
is_chunked_prefilling
).
int
()
num_rejected
=
torch
.
zeros_like
(
num_sampled
)
else
:
else
:
# Draft tokens for spec decoding.
# Draft tokens for spec decoding.
input_ids
=
input_batch
.
input_ids
[
input_batch
.
logits_indices
]
input_ids
=
input_batch
.
input_ids
[
input_batch
.
logits_indices
]
...
@@ -642,9 +648,13 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
...
@@ -642,9 +648,13 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
self
.
num_speculative_steps
,
self
.
num_speculative_steps
,
)
)
num_sampled
*=
~
is_chunked_prefilling
num_sampled
*=
~
is_chunked_prefilling
num_rejected
=
get_num_rejected
(
input_batch
.
cu_num_logits
,
num_sampled
,
)
sampler_output
.
sampled_token_ids
=
sampled_tokens
sampler_output
.
sampled_token_ids
=
sampled_tokens
# TODO(woosuk): Support logprobs with spec decoding.
# TODO(woosuk): Support logprobs with spec decoding.
return
sampler_output
,
num_sampled
return
sampler_output
,
num_sampled
,
num_rejected
def
compute_prompt_logprobs
(
def
compute_prompt_logprobs
(
self
,
self
,
...
@@ -750,6 +760,7 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
...
@@ -750,6 +760,7 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
input_batch
:
InputBatch
,
input_batch
:
InputBatch
,
sampled_tokens
:
torch
.
Tensor
,
sampled_tokens
:
torch
.
Tensor
,
num_sampled
:
torch
.
Tensor
,
num_sampled
:
torch
.
Tensor
,
num_rejected
:
torch
.
Tensor
,
)
->
None
:
)
->
None
:
# Update the number of computed tokens.
# Update the number of computed tokens.
post_update
(
post_update
(
...
@@ -758,8 +769,8 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
...
@@ -758,8 +769,8 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
self
.
req_states
.
last_sampled_tokens
,
self
.
req_states
.
last_sampled_tokens
,
sampled_tokens
,
sampled_tokens
,
num_sampled
,
num_sampled
,
num_rejected
,
input_batch
.
query_start_loc
,
input_batch
.
query_start_loc
,
input_batch
.
cu_num_logits
,
)
)
# Update the number of computed prefill tokens.
# Update the number of computed prefill tokens.
...
@@ -779,6 +790,7 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
...
@@ -779,6 +790,7 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
last_hidden_states
:
torch
.
Tensor
,
last_hidden_states
:
torch
.
Tensor
,
aux_hidden_states
:
list
[
torch
.
Tensor
]
|
None
,
aux_hidden_states
:
list
[
torch
.
Tensor
]
|
None
,
num_sampled
:
torch
.
Tensor
,
num_sampled
:
torch
.
Tensor
,
num_rejected
:
torch
.
Tensor
,
)
->
torch
.
Tensor
:
)
->
torch
.
Tensor
:
num_reqs
=
input_batch
.
num_reqs
num_reqs
=
input_batch
.
num_reqs
idx_mapping_np
=
input_batch
.
idx_mapping_np
idx_mapping_np
=
input_batch
.
idx_mapping_np
...
@@ -800,6 +812,7 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
...
@@ -800,6 +812,7 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
last_hidden_states
,
last_hidden_states
,
aux_hidden_states
,
aux_hidden_states
,
num_sampled
,
num_sampled
,
num_rejected
,
self
.
req_states
.
last_sampled_tokens
,
self
.
req_states
.
last_sampled_tokens
,
next_prefill_tokens
,
next_prefill_tokens
,
)
)
...
@@ -958,7 +971,7 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
...
@@ -958,7 +971,7 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
self
.
execute_model_state
=
None
# type: ignore
self
.
execute_model_state
=
None
# type: ignore
assert
sampling_metadata
is
not
None
assert
sampling_metadata
is
not
None
sampler_output
,
num_sampled
_tokens
=
self
.
sample
(
sampler_output
,
num_sampled
,
num_rejected
=
self
.
sample
(
hidden_states
,
input_batch
,
sampling_metadata
,
grammar_output
hidden_states
,
input_batch
,
sampling_metadata
,
grammar_output
)
)
prompt_logprobs_dict
=
self
.
compute_prompt_logprobs
(
hidden_states
,
input_batch
)
prompt_logprobs_dict
=
self
.
compute_prompt_logprobs
(
hidden_states
,
input_batch
)
...
@@ -979,7 +992,7 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
...
@@ -979,7 +992,7 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
async_output
=
AsyncOutput
(
async_output
=
AsyncOutput
(
model_runner_output
=
model_runner_output
,
model_runner_output
=
model_runner_output
,
sampler_output
=
sampler_output
,
sampler_output
=
sampler_output
,
num_sampled_tokens
=
num_sampled
_tokens
,
num_sampled_tokens
=
num_sampled
,
copy_stream
=
self
.
output_copy_stream
,
copy_stream
=
self
.
output_copy_stream
,
copy_event
=
self
.
output_copy_event
,
copy_event
=
self
.
output_copy_event
,
)
)
...
@@ -990,7 +1003,7 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
...
@@ -990,7 +1003,7 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
# This sequencing may slightly reduce latency as async D2H copy does not
# This sequencing may slightly reduce latency as async D2H copy does not
# need to wait for the postprocess to finish.
# need to wait for the postprocess to finish.
self
.
postprocess
(
self
.
postprocess
(
input_batch
,
sampler_output
.
sampled_token_ids
,
num_sampled
_tokens
input_batch
,
sampler_output
.
sampled_token_ids
,
num_sampled
,
num_rejected
)
)
if
self
.
do_spec_decode
:
if
self
.
do_spec_decode
:
_
=
self
.
propose_draft
(
_
=
self
.
propose_draft
(
...
@@ -998,7 +1011,8 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
...
@@ -998,7 +1011,8 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
sampling_metadata
,
sampling_metadata
,
hidden_states
,
hidden_states
,
None
,
# aux_hidden_states
None
,
# aux_hidden_states
num_sampled_tokens
,
num_sampled
,
num_rejected
,
)
)
if
self
.
use_async_scheduling
:
if
self
.
use_async_scheduling
:
...
...
vllm/v1/worker/gpu/spec_decode/eagle.py
View file @
f32c7d6f
...
@@ -60,6 +60,8 @@ class EagleSpeculator:
...
@@ -60,6 +60,8 @@ class EagleSpeculator:
aux_hidden_states
:
list
[
torch
.
Tensor
]
|
None
,
aux_hidden_states
:
list
[
torch
.
Tensor
]
|
None
,
# [num_reqs]
# [num_reqs]
num_sampled
:
torch
.
Tensor
,
num_sampled
:
torch
.
Tensor
,
# [num_reqs]
num_rejected
:
torch
.
Tensor
,
# [max_num_reqs, 1]
# [max_num_reqs, 1]
last_sampled
:
torch
.
Tensor
,
last_sampled
:
torch
.
Tensor
,
# [num_reqs]
# [num_reqs]
...
@@ -84,6 +86,7 @@ class EagleSpeculator:
...
@@ -84,6 +86,7 @@ class EagleSpeculator:
self
.
input_ids
,
self
.
input_ids
,
input_batch
,
input_batch
,
num_sampled
,
num_sampled
,
num_rejected
,
last_sampled
,
last_sampled
,
next_prefill_tokens
,
next_prefill_tokens
,
)
)
...
@@ -139,8 +142,8 @@ def _prepare_eagle_inputs_kernel(
...
@@ -139,8 +142,8 @@ def _prepare_eagle_inputs_kernel(
last_sampled_ptr
,
last_sampled_ptr
,
next_prefill_tokens_ptr
,
next_prefill_tokens_ptr
,
num_sampled_ptr
,
num_sampled_ptr
,
num_rejected_ptr
,
query_start_loc_ptr
,
query_start_loc_ptr
,
cu_num_logits_ptr
,
BLOCK_SIZE
:
tl
.
constexpr
,
BLOCK_SIZE
:
tl
.
constexpr
,
):
):
batch_idx
=
tl
.
program_id
(
0
)
batch_idx
=
tl
.
program_id
(
0
)
...
@@ -149,17 +152,13 @@ def _prepare_eagle_inputs_kernel(
...
@@ -149,17 +152,13 @@ def _prepare_eagle_inputs_kernel(
query_len
=
query_end
-
query_start
query_len
=
query_end
-
query_start
# Get the true query length and next token after accounting for rejected tokens.
# Get the true query length and next token after accounting for rejected tokens.
num_rejected
=
tl
.
load
(
num_rejected_ptr
+
batch_idx
)
query_len
-=
num_rejected
num_sampled
=
tl
.
load
(
num_sampled_ptr
+
batch_idx
)
num_sampled
=
tl
.
load
(
num_sampled_ptr
+
batch_idx
)
if
num_sampled
>
0
:
if
num_sampled
>
0
:
req_state_idx
=
tl
.
load
(
idx_mapping_ptr
+
batch_idx
)
req_state_idx
=
tl
.
load
(
idx_mapping_ptr
+
batch_idx
)
next_token
=
tl
.
load
(
last_sampled_ptr
+
req_state_idx
).
to
(
tl
.
int32
)
next_token
=
tl
.
load
(
last_sampled_ptr
+
req_state_idx
).
to
(
tl
.
int32
)
logits_start
=
tl
.
load
(
cu_num_logits_ptr
+
batch_idx
)
logits_end
=
tl
.
load
(
cu_num_logits_ptr
+
batch_idx
+
1
)
num_logits
=
logits_end
-
logits_start
num_rejected
=
num_logits
-
num_sampled
query_len
-=
num_rejected
else
:
else
:
# Chunked prefilling.
# Chunked prefilling.
# Get the next prefill token.
# Get the next prefill token.
...
@@ -182,6 +181,8 @@ def prepare_eagle_inputs(
...
@@ -182,6 +181,8 @@ def prepare_eagle_inputs(
input_batch
:
InputBatch
,
input_batch
:
InputBatch
,
# [num_reqs]
# [num_reqs]
num_sampled
:
torch
.
Tensor
,
num_sampled
:
torch
.
Tensor
,
# [num_reqs]
num_rejected
:
torch
.
Tensor
,
# [max_num_reqs, 1]
# [max_num_reqs, 1]
last_sampled
:
torch
.
Tensor
,
last_sampled
:
torch
.
Tensor
,
# [max_num_reqs]
# [max_num_reqs]
...
@@ -201,8 +202,8 @@ def prepare_eagle_inputs(
...
@@ -201,8 +202,8 @@ def prepare_eagle_inputs(
last_sampled
,
last_sampled
,
next_prefill_tokens
,
next_prefill_tokens
,
num_sampled
,
num_sampled
,
num_rejected
,
input_batch
.
query_start_loc
,
input_batch
.
query_start_loc
,
input_batch
.
cu_num_logits
,
BLOCK_SIZE
=
1024
,
BLOCK_SIZE
=
1024
,
)
)
return
last_token_indices
return
last_token_indices
vllm/v1/worker/gpu/spec_decode/rejection_sample.py
View file @
f32c7d6f
...
@@ -69,3 +69,15 @@ def rejection_sample(
...
@@ -69,3 +69,15 @@ def rejection_sample(
num_warps
=
1
,
num_warps
=
1
,
)
)
return
sampled
,
num_sampled
return
sampled
,
num_sampled
@torch.compile(dynamic=True)
def get_num_rejected(
    cu_num_logits: torch.Tensor,
    num_sampled: torch.Tensor,
) -> torch.Tensor:
    """Return, per request, how many logits positions were not sampled.

    Args:
        cu_num_logits: Cumulative logits counts; adjacent differences give
            the number of logits for each request (one more entry than
            ``num_sampled``).
        num_sampled: Number of tokens sampled for each request.

    Returns:
        A tensor holding ``num_logits - num_sampled`` for every request
        that sampled at least one token, and 0 for the others.
    """
    logits_per_req = torch.diff(cu_num_logits)
    sampled_any = num_sampled > 0
    # Requests with num_sampled == 0 are chunked prefills: nothing is
    # rejected for them, so their entry is forced to zero.
    return torch.where(sampled_any, logits_per_req - num_sampled, 0)
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment