Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
cc313cb7
Unverified
Commit
cc313cb7
authored
Nov 24, 2025
by
Woosuk Kwon
Committed by
GitHub
Nov 24, 2025
Browse files
[Model Runner V2] Implement Single-step Eagle 1 (#29300)
Signed-off-by:
Woosuk Kwon
<
woosuk.kwon@berkeley.edu
>
parent
26a46558
Changes
5
Hide whitespace changes
Inline
Side-by-side
Showing
5 changed files
with
300 additions
and
2 deletions
+300
-2
vllm/v1/worker/gpu/input_batch.py
vllm/v1/worker/gpu/input_batch.py
+3
-0
vllm/v1/worker/gpu/model_runner.py
vllm/v1/worker/gpu/model_runner.py
+79
-0
vllm/v1/worker/gpu/sampler.py
vllm/v1/worker/gpu/sampler.py
+3
-2
vllm/v1/worker/gpu/spec_decode/__init__.py
vllm/v1/worker/gpu/spec_decode/__init__.py
+18
-0
vllm/v1/worker/gpu/spec_decode/eagle.py
vllm/v1/worker/gpu/spec_decode/eagle.py
+197
-0
No files found.
vllm/v1/worker/gpu/input_batch.py
View file @
cc313cb7
...
...
@@ -37,6 +37,9 @@ class InputBuffers:
self
.
seq_lens
=
torch
.
zeros
(
max_num_reqs
,
dtype
=
torch
.
int32
,
device
=
device
)
self
.
cu_num_logits
=
self
.
_make_buffer
(
max_num_reqs
+
1
,
dtype
=
torch
.
int32
)
# Spec decoding.
self
.
next_prefill_tokens
=
self
.
_make_buffer
(
max_num_reqs
,
dtype
=
torch
.
int32
)
# Structured outputs.
self
.
bitmask_indices
=
self
.
_make_buffer
(
max_num_reqs
,
dtype
=
torch
.
int32
)
self
.
grammar_bitmask
=
self
.
_make_buffer
(
...
...
vllm/v1/worker/gpu/model_runner.py
View file @
cc313cb7
...
...
@@ -45,6 +45,7 @@ from vllm.v1.worker.gpu.input_batch import (
prepare_prefill_inputs
,
)
from
vllm.v1.worker.gpu.sampler
import
Sampler
,
compute_prompt_logprobs
from
vllm.v1.worker.gpu.spec_decode
import
init_speculator
from
vllm.v1.worker.gpu.spec_decode.rejection_sample
import
rejection_sample
from
vllm.v1.worker.gpu.states
import
RequestState
,
SamplingMetadata
from
vllm.v1.worker.gpu.structured_outputs
import
apply_grammar_bitmask
...
...
@@ -97,16 +98,20 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
if
self
.
use_async_scheduling
:
self
.
input_prep_event
=
torch
.
cuda
.
Event
()
self
.
structured_outputs_event
=
torch
.
cuda
.
Event
()
self
.
spec_decode_event
=
torch
.
cuda
.
Event
()
else
:
self
.
input_prep_event
=
None
self
.
structured_outputs_event
=
None
self
.
spec_decode_event
=
None
if
self
.
speculative_config
is
not
None
:
self
.
do_spec_decode
=
True
self
.
num_speculative_steps
=
self
.
speculative_config
.
num_speculative_tokens
self
.
speculator
=
init_speculator
(
self
.
vllm_config
,
self
.
device
)
else
:
self
.
do_spec_decode
=
False
self
.
num_speculative_steps
=
0
self
.
speculator
=
None
self
.
req_states
=
RequestState
(
max_num_reqs
=
self
.
max_num_reqs
,
...
...
@@ -153,6 +158,8 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
self
.
vllm_config
,
self
.
device
,
)
if
self
.
do_spec_decode
:
self
.
speculator
.
load_model
(
self
.
model
)
time_after_load
=
time
.
perf_counter
()
self
.
model_memory_usage
=
m
.
consumed_memory
...
...
@@ -285,6 +292,33 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
logits
=
self
.
model
.
compute_logits
(
hidden_states
)
self
.
sampler
(
logits
,
sampling_metadata
)
@
torch
.
inference_mode
()
def
_dummy_speculator_run
(
self
,
hidden_states
:
torch
.
Tensor
,
aux_hidden_states
:
list
[
torch
.
Tensor
]
|
None
,
)
->
None
:
num_tokens
=
hidden_states
.
shape
[
0
]
num_reqs
=
min
(
num_tokens
,
self
.
max_num_reqs
)
input_batch
=
InputBatch
.
make_dummy
(
num_reqs
=
num_reqs
,
num_tokens
=
num_tokens
,
input_buffers
=
self
.
input_buffers
,
device
=
self
.
device
,
)
sampling_metadata
=
SamplingMetadata
.
make_dummy
(
num_reqs
=
num_reqs
,
device
=
self
.
device
,
)
num_sampled
=
torch
.
ones
(
num_reqs
,
dtype
=
torch
.
int32
,
device
=
self
.
device
)
self
.
propose_draft
(
input_batch
=
input_batch
,
sampling_metadata
=
sampling_metadata
,
last_hidden_states
=
hidden_states
,
aux_hidden_states
=
aux_hidden_states
,
num_sampled
=
num_sampled
,
)
@
torch
.
inference_mode
()
def
profile_run
(
self
)
->
None
:
hidden_states
,
sample_hidden_states
=
self
.
_dummy_run
(
...
...
@@ -292,6 +326,8 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
skip_attn
=
True
,
)
self
.
_dummy_sampler_run
(
sample_hidden_states
)
if
self
.
do_spec_decode
:
self
.
_dummy_speculator_run
(
hidden_states
,
None
)
torch
.
cuda
.
synchronize
()
del
hidden_states
,
sample_hidden_states
gc
.
collect
()
...
...
@@ -727,6 +763,41 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
self
.
req_states
.
prefill_len
.
np
[
idx_mapping_np
],
)
@
torch
.
inference_mode
()
def
propose_draft
(
self
,
input_batch
:
InputBatch
,
sampling_metadata
:
SamplingMetadata
,
last_hidden_states
:
torch
.
Tensor
,
aux_hidden_states
:
list
[
torch
.
Tensor
]
|
None
,
num_sampled
:
torch
.
Tensor
,
)
->
torch
.
Tensor
:
num_reqs
=
input_batch
.
num_reqs
idx_mapping_np
=
input_batch
.
idx_mapping_np
with
async_barrier
(
self
.
spec_decode_event
):
self
.
input_buffers
.
next_prefill_tokens
.
np
[:
num_reqs
]
=
(
self
.
req_states
.
prefill_token_ids
[
idx_mapping_np
,
self
.
req_states
.
num_computed_prefill_tokens
[
idx_mapping_np
],
]
)
next_prefill_tokens
=
self
.
input_buffers
.
next_prefill_tokens
.
copy_to_gpu
(
num_reqs
)
assert
self
.
speculator
is
not
None
draft_tokens
=
self
.
speculator
.
propose
(
input_batch
,
sampling_metadata
,
last_hidden_states
,
aux_hidden_states
,
num_sampled
,
self
.
req_states
.
last_sampled_tokens
,
next_prefill_tokens
,
)
self
.
req_states
.
draft_tokens
[
input_batch
.
idx_mapping
]
=
draft_tokens
return
draft_tokens
def
get_cudagraph_and_dp_padding
(
self
,
scheduler_output
:
SchedulerOutput
,
...
...
@@ -913,6 +984,14 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
self
.
postprocess
(
input_batch
,
sampler_output
.
sampled_token_ids
,
num_sampled_tokens
)
if
self
.
do_spec_decode
:
_
=
self
.
propose_draft
(
input_batch
,
sampling_metadata
,
hidden_states
,
None
,
# aux_hidden_states
num_sampled_tokens
,
)
if
self
.
use_async_scheduling
:
return
async_output
...
...
vllm/v1/worker/gpu/sampler.py
View file @
cc313cb7
...
...
@@ -100,8 +100,9 @@ def _gumbel_sample_kernel(
mask
=
mask
,
other
=
float
(
"-inf"
),
)
logits
=
logits
.
to
(
tl
.
float32
)
temp
=
tl
.
load
(
temp_ptr
+
req_idx
)
temp
=
tl
.
load
(
temp_ptr
+
req_idx
)
.
to
(
tl
.
float32
)
if
temp
!=
0.0
:
# Calculate the seed for gumbel noise.
seed
=
tl
.
load
(
seeds_ptr
+
req_idx
)
...
...
@@ -116,7 +117,7 @@ def _gumbel_sample_kernel(
# Apply temperature.
if
APPLY_TEMPERATURE
:
# NOTE(woosuk): Use div_rn to match the behavior of torch.
logits
=
tl
.
div_rn
(
logits
,
temp
.
to
(
tl
.
float32
)
)
logits
=
tl
.
div_rn
(
logits
,
temp
)
# Apply gumbel noise.
logits
=
tl
.
where
(
mask
,
logits
+
gumbel_noise
,
float
(
"-inf"
))
...
...
vllm/v1/worker/gpu/spec_decode/__init__.py
View file @
cc313cb7
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import
torch
from
vllm.config
import
VllmConfig
def
init_speculator
(
vllm_config
:
VllmConfig
,
device
:
torch
.
device
,
):
speculative_config
=
vllm_config
.
speculative_config
assert
speculative_config
is
not
None
if
speculative_config
.
use_eagle
():
from
vllm.v1.worker.gpu.spec_decode.eagle
import
EagleSpeculator
return
EagleSpeculator
(
vllm_config
,
device
)
raise
NotImplementedError
(
f
"
{
speculative_config
.
method
}
is not supported yet."
)
vllm/v1/worker/gpu/spec_decode/eagle.py
0 → 100644
View file @
cc313cb7
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import
torch
import
torch.nn
as
nn
from
vllm.config
import
VllmConfig
from
vllm.config.compilation
import
CUDAGraphMode
from
vllm.forward_context
import
set_forward_context
from
vllm.model_executor.model_loader
import
get_model
from
vllm.triton_utils
import
tl
,
triton
from
vllm.v1.worker.gpu.input_batch
import
InputBatch
from
vllm.v1.worker.gpu.sampler
import
gumbel_sample
from
vllm.v1.worker.gpu.states
import
SamplingMetadata
class
EagleSpeculator
:
def
__init__
(
self
,
vllm_config
:
VllmConfig
,
device
:
torch
.
device
):
self
.
vllm_config
=
vllm_config
self
.
device
=
device
self
.
speculative_config
=
vllm_config
.
speculative_config
assert
self
.
speculative_config
is
not
None
self
.
method
=
self
.
speculative_config
.
method
self
.
num_speculative_steps
=
self
.
speculative_config
.
num_speculative_tokens
self
.
draft_model_config
=
self
.
speculative_config
.
draft_model_config
self
.
scheduler_config
=
vllm_config
.
scheduler_config
self
.
max_num_reqs
=
self
.
scheduler_config
.
max_num_seqs
self
.
max_num_tokens
=
self
.
scheduler_config
.
max_num_batched_tokens
self
.
input_ids
=
torch
.
zeros
(
self
.
max_num_tokens
,
dtype
=
torch
.
int32
,
device
=
device
)
self
.
positions
=
torch
.
zeros
(
self
.
max_num_tokens
,
dtype
=
torch
.
int64
,
device
=
device
)
def
load_model
(
self
,
target_model
:
nn
.
Module
)
->
None
:
from
vllm.compilation.backends
import
set_model_tag
with
set_model_tag
(
"eagle_head"
):
self
.
model
=
get_model
(
vllm_config
=
self
.
vllm_config
,
model_config
=
self
.
draft_model_config
)
share_lm_head
=
True
if
share_lm_head
and
hasattr
(
target_model
,
"lm_head"
):
if
hasattr
(
self
.
model
,
"lm_head"
):
del
self
.
model
.
lm_head
self
.
model
.
lm_head
=
target_model
.
lm_head
@
torch
.
inference_mode
()
def
propose
(
self
,
input_batch
:
InputBatch
,
sampling_metadata
:
SamplingMetadata
,
# [num_tokens, hidden_size]
last_hidden_states
:
torch
.
Tensor
,
# num_layers x [num_tokens, hidden_size]
aux_hidden_states
:
list
[
torch
.
Tensor
]
|
None
,
# [num_reqs]
num_sampled
:
torch
.
Tensor
,
# [max_num_reqs, 1]
last_sampled
:
torch
.
Tensor
,
# [num_reqs]
next_prefill_tokens
:
torch
.
Tensor
,
)
->
torch
.
Tensor
:
if
aux_hidden_states
:
assert
self
.
method
==
"eagle3"
hidden_states
=
self
.
model
.
combine_hidden_states
(
torch
.
cat
(
aux_hidden_states
,
dim
=-
1
)
)
else
:
hidden_states
=
last_hidden_states
# Get the input ids and last token indices for the speculator.
last_token_indices
=
prepare_eagle_inputs
(
self
.
input_ids
,
input_batch
,
num_sampled
,
last_sampled
,
next_prefill_tokens
,
)
input_ids
=
self
.
input_ids
[:
input_batch
.
num_tokens_after_padding
]
# Prefill: Run the eagle speculator with eager mode.
with
set_forward_context
(
input_batch
.
attn_metadata
,
self
.
vllm_config
,
num_tokens
=
input_batch
.
num_tokens_after_padding
,
cudagraph_runtime_mode
=
CUDAGraphMode
.
NONE
,
):
ret_hidden_states
=
self
.
model
(
input_ids
=
input_ids
,
positions
=
input_batch
.
positions
,
hidden_states
=
hidden_states
,
)
if
self
.
method
==
"mtp"
:
last_hidden_states
=
ret_hidden_states
hidden_states
=
ret_hidden_states
else
:
last_hidden_states
,
hidden_states
=
ret_hidden_states
sample_hidden_states
=
last_hidden_states
[
last_token_indices
]
logits
=
self
.
model
.
compute_logits
(
sample_hidden_states
)
num_reqs
=
input_batch
.
num_reqs
cu_num_logits
=
input_batch
.
cu_num_logits
[:
num_reqs
]
temperature
=
sampling_metadata
.
temperature
[
cu_num_logits
]
seed
=
sampling_metadata
.
seeds
[
cu_num_logits
]
# NOTE(woosuk): We must add 1 to the positions to match the Gumbel noise
# used for draft and target sampling.
pos
=
input_batch
.
positions
[
last_token_indices
]
+
1
draft_tokens
=
gumbel_sample
(
logits
,
temperature
,
seed
,
pos
,
apply_temperature
=
True
)
if
self
.
num_speculative_steps
==
1
:
# Early exit.
return
draft_tokens
.
view
(
-
1
,
1
)
raise
NotImplementedError
(
"num_speculative_steps > 1 is not supported yet."
)
@
triton
.
jit
def
_prepare_eagle_inputs_kernel
(
last_token_indices_ptr
,
eagle_input_ids_ptr
,
target_input_ids_ptr
,
idx_mapping_ptr
,
last_sampled_ptr
,
next_prefill_tokens_ptr
,
num_sampled_ptr
,
query_start_loc_ptr
,
cu_num_logits_ptr
,
BLOCK_SIZE
:
tl
.
constexpr
,
):
batch_idx
=
tl
.
program_id
(
0
)
query_start
=
tl
.
load
(
query_start_loc_ptr
+
batch_idx
)
query_end
=
tl
.
load
(
query_start_loc_ptr
+
batch_idx
+
1
)
query_len
=
query_end
-
query_start
# Get the true query length and next token after accounting for rejected tokens.
num_sampled
=
tl
.
load
(
num_sampled_ptr
+
batch_idx
)
if
num_sampled
>
0
:
req_state_idx
=
tl
.
load
(
idx_mapping_ptr
+
batch_idx
)
next_token
=
tl
.
load
(
last_sampled_ptr
+
req_state_idx
).
to
(
tl
.
int32
)
logits_start
=
tl
.
load
(
cu_num_logits_ptr
+
batch_idx
)
logits_end
=
tl
.
load
(
cu_num_logits_ptr
+
batch_idx
+
1
)
num_logits
=
logits_end
-
logits_start
num_rejected
=
num_logits
-
num_sampled
query_len
-=
num_rejected
else
:
# Chunked prefilling.
# Get the next prefill token.
next_token
=
tl
.
load
(
next_prefill_tokens_ptr
+
batch_idx
)
# Shift target_input_ids by one.
for
i
in
range
(
1
,
query_len
,
BLOCK_SIZE
):
block
=
i
+
tl
.
arange
(
0
,
BLOCK_SIZE
)
mask
=
block
<
query_len
input_ids
=
tl
.
load
(
target_input_ids_ptr
+
query_start
+
block
,
mask
=
mask
)
tl
.
store
(
eagle_input_ids_ptr
+
query_start
+
block
-
1
,
input_ids
,
mask
=
mask
)
last_token_index
=
query_start
+
query_len
-
1
tl
.
store
(
last_token_indices_ptr
+
batch_idx
,
last_token_index
)
tl
.
store
(
eagle_input_ids_ptr
+
last_token_index
,
next_token
)
def
prepare_eagle_inputs
(
eagle_input_ids
:
torch
.
Tensor
,
input_batch
:
InputBatch
,
# [num_reqs]
num_sampled
:
torch
.
Tensor
,
# [max_num_reqs, 1]
last_sampled
:
torch
.
Tensor
,
# [max_num_reqs]
next_prefill_tokens
:
torch
.
Tensor
,
)
->
torch
.
Tensor
:
num_reqs
=
input_batch
.
num_reqs
last_token_indices
=
torch
.
empty
(
num_reqs
,
dtype
=
torch
.
int64
,
device
=
eagle_input_ids
.
device
,
)
_prepare_eagle_inputs_kernel
[(
num_reqs
,)](
last_token_indices
,
eagle_input_ids
,
input_batch
.
input_ids
,
input_batch
.
idx_mapping
,
last_sampled
,
next_prefill_tokens
,
num_sampled
,
input_batch
.
query_start_loc
,
input_batch
.
cu_num_logits
,
BLOCK_SIZE
=
1024
,
)
return
last_token_indices
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment