Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
a4047d4e
Unverified
Commit
a4047d4e
authored
Feb 21, 2026
by
Woosuk Kwon
Committed by
GitHub
Feb 21, 2026
Browse files
[Model Runner V2] Support Eagle3 (no CUDA graph) (#35029)
Signed-off-by:
Woosuk Kwon
<
woosuk@inferact.ai
>
parent
965fe459
Changes
7
Hide whitespace changes
Inline
Side-by-side
Showing
7 changed files
with
169 additions
and
49 deletions
+169
-49
vllm/v1/worker/gpu/model_runner.py
vllm/v1/worker/gpu/model_runner.py
+67
-34
vllm/v1/worker/gpu/spec_decode/__init__.py
vllm/v1/worker/gpu/spec_decode/__init__.py
+1
-1
vllm/v1/worker/gpu/spec_decode/eagle/__init__.py
vllm/v1/worker/gpu/spec_decode/eagle/__init__.py
+0
-0
vllm/v1/worker/gpu/spec_decode/eagle/cudagraph.py
vllm/v1/worker/gpu/spec_decode/eagle/cudagraph.py
+0
-0
vllm/v1/worker/gpu/spec_decode/eagle/eagle3_utils.py
vllm/v1/worker/gpu/spec_decode/eagle/eagle3_utils.py
+46
-0
vllm/v1/worker/gpu/spec_decode/eagle/speculator.py
vllm/v1/worker/gpu/spec_decode/eagle/speculator.py
+3
-14
vllm/v1/worker/gpu/spec_decode/eagle/utils.py
vllm/v1/worker/gpu/spec_decode/eagle/utils.py
+52
-0
No files found.
vllm/v1/worker/gpu/model_runner.py
View file @
a4047d4e
...
...
@@ -66,6 +66,9 @@ from vllm.v1.worker.gpu.sample.output import SamplerOutput
from
vllm.v1.worker.gpu.sample.prompt_logprob
import
PromptLogprobsWorker
from
vllm.v1.worker.gpu.sample.sampler
import
Sampler
from
vllm.v1.worker.gpu.spec_decode
import
init_speculator
from
vllm.v1.worker.gpu.spec_decode.eagle.eagle3_utils
import
(
set_eagle3_aux_hidden_state_layers
,
)
from
vllm.v1.worker.gpu.spec_decode.rejection_sample
import
rejection_sample
from
vllm.v1.worker.gpu.spec_decode.utils
import
DraftTokensHandler
from
vllm.v1.worker.gpu.states
import
RequestState
...
...
@@ -133,14 +136,42 @@ class GPUModelRunner(LoRAModelRunnerMixin):
self
.
output_copy_stream
=
torch
.
cuda
.
Stream
(
self
.
device
)
self
.
output_copy_event
=
torch
.
cuda
.
Event
()
# Pipeline parallelism.
self
.
pp_size
=
self
.
parallel_config
.
pipeline_parallel_size
self
.
use_pp
=
self
.
pp_size
>
1
if
self
.
use_pp
:
self
.
is_first_pp_rank
=
get_pp_group
().
is_first_rank
self
.
is_last_pp_rank
=
get_pp_group
().
is_last_rank
else
:
self
.
is_first_pp_rank
=
True
self
.
is_last_pp_rank
=
True
# Decode context parallelism.
self
.
dcp_size
=
self
.
parallel_config
.
decode_context_parallel_size
self
.
use_dcp
=
self
.
dcp_size
>
1
self
.
dcp_rank
=
get_dcp_group
().
rank_in_group
if
self
.
use_dcp
else
0
self
.
cp_interleave
=
self
.
parallel_config
.
cp_kv_cache_interleave_size
self
.
speculator
=
None
self
.
use_aux_hidden_state_outputs
=
False
if
self
.
speculative_config
is
not
None
:
self
.
do_spec_decode
=
True
self
.
num_speculative_steps
=
self
.
speculative_config
.
num_speculative_tokens
self
.
speculator
=
init_speculator
(
self
.
vllm_config
,
self
.
device
)
if
self
.
is_last_pp_rank
:
self
.
speculator
=
init_speculator
(
self
.
vllm_config
,
self
.
device
)
if
self
.
speculative_config
.
method
==
"eagle3"
:
# EAGLE3 may require auxiliary hidden states from target model outputs.
self
.
use_aux_hidden_state_outputs
=
True
if
self
.
pp_size
>
1
:
raise
ValueError
(
"EAGLE3 with pipeline parallel is not supported."
)
else
:
self
.
do_spec_decode
=
False
self
.
num_speculative_steps
=
0
self
.
speculator
=
None
# Draft tokens propagation - for spec-dec + struct outputs.
self
.
draft_tokens_handler
=
DraftTokensHandler
(
self
.
device
)
self
.
req_states
=
RequestState
(
max_num_reqs
=
self
.
max_num_reqs
,
max_model_len
=
self
.
max_model_len
,
...
...
@@ -176,28 +207,9 @@ class GPUModelRunner(LoRAModelRunnerMixin):
)
# LoRA-related workers.
self
.
lora_state
=
LoraState
(
max_num_reqs
=
self
.
max_num_reqs
)
# Draft tokens propagation - for spec-dec + struct outputs.
self
.
draft_tokens_handler
=
DraftTokensHandler
(
self
.
device
)
# KV Connector if configured.
self
.
kv_connector
:
KVConnector
=
NO_OP_KV_CONNECTOR
# Pipeline parallelism.
self
.
use_pp
=
self
.
parallel_config
.
pipeline_parallel_size
>
1
if
self
.
use_pp
:
self
.
is_first_pp_rank
=
get_pp_group
().
is_first_rank
self
.
is_last_pp_rank
=
get_pp_group
().
is_last_rank
else
:
self
.
is_first_pp_rank
=
True
self
.
is_last_pp_rank
=
True
# Decode context parallelism.
self
.
dcp_size
=
self
.
parallel_config
.
decode_context_parallel_size
self
.
use_dcp
=
self
.
dcp_size
>
1
self
.
dcp_rank
=
get_dcp_group
().
rank_in_group
if
self
.
use_dcp
else
0
self
.
cp_interleave
=
self
.
parallel_config
.
cp_kv_cache_interleave_size
def
update_max_model_len
(
self
,
max_model_len
:
int
)
->
None
:
self
.
max_model_len
=
max_model_len
self
.
req_states
.
max_model_len
=
max_model_len
...
...
@@ -220,7 +232,11 @@ class GPUModelRunner(LoRAModelRunnerMixin):
self
.
model
=
self
.
load_lora_model
(
self
.
model
,
self
.
vllm_config
,
self
.
device
)
if
self
.
do_spec_decode
:
if
self
.
use_aux_hidden_state_outputs
:
assert
self
.
speculative_config
is
not
None
set_eagle3_aux_hidden_state_layers
(
self
.
model
,
self
.
speculative_config
)
if
self
.
speculator
is
not
None
:
self
.
speculator
.
load_model
(
self
.
model
)
time_after_load
=
time
.
perf_counter
()
...
...
@@ -271,7 +287,7 @@ class GPUModelRunner(LoRAModelRunnerMixin):
self
.
kv_cache_config
,
self
.
vllm_config
,
self
.
device
)
check_attention_cp_compatibility
(
self
.
vllm_config
)
if
self
.
do_
spec
_decod
e
:
if
self
.
spec
ulator
is
not
Non
e
:
# HACK(woosuk)
self
.
speculator
.
set_attn
(
self
.
kv_cache_config
,
...
...
@@ -359,7 +375,7 @@ class GPUModelRunner(LoRAModelRunnerMixin):
return
None
,
None
assert
self
.
execute_model_state
is
not
None
hidden_states
,
input_batch
,
_
=
self
.
execute_model_state
hidden_states
,
_
,
input_batch
,
_
=
self
.
execute_model_state
assert
hidden_states
is
not
None
# Last PP rank always has hidden_states
sample_hidden_states
=
hidden_states
[
input_batch
.
logits_indices
]
return
hidden_states
,
sample_hidden_states
...
...
@@ -399,7 +415,7 @@ class GPUModelRunner(LoRAModelRunnerMixin):
assert
sample_hidden_states
is
not
None
self
.
_dummy_sampler_run
(
sample_hidden_states
)
if
self
.
do_
spec
_decod
e
:
if
self
.
spec
ulator
is
not
Non
e
:
num_tokens_across_dp
=
make_num_tokens_across_dp
(
self
.
parallel_config
.
data_parallel_size
,
self
.
max_num_tokens
)
...
...
@@ -465,7 +481,7 @@ class GPUModelRunner(LoRAModelRunnerMixin):
kv_cache_config
=
self
.
kv_cache_config
,
has_lora
=
self
.
lora_config
is
not
None
,
)
if
self
.
do_
spec
_decod
e
:
if
self
.
spec
ulator
is
not
Non
e
:
self
.
speculator
.
capture_model
()
end_time
=
time
.
perf_counter
()
...
...
@@ -964,9 +980,14 @@ class GPUModelRunner(LoRAModelRunnerMixin):
# NOTE(woosuk): Here, we don't need to pass the input tensors,
# because they are already copied to the CUDA graph input buffers.
self
.
kv_connector
.
pre_forward
(
scheduler_output
)
hidden_states
=
self
.
cudagraph_manager
.
run_fullgraph
(
model_output
=
self
.
cudagraph_manager
.
run_fullgraph
(
input_batch
.
num_tokens_after_padding
)
if
self
.
use_aux_hidden_state_outputs
:
hidden_states
,
aux_hidden_states
=
model_output
else
:
hidden_states
=
model_output
aux_hidden_states
=
None
else
:
# For piecewise and eager mode, just call model().
positions
=
input_batch
.
positions
...
...
@@ -998,12 +1019,17 @@ class GPUModelRunner(LoRAModelRunnerMixin):
slot_mapping
=
input_batch
.
slot_mappings
,
):
self
.
kv_connector
.
pre_forward
(
scheduler_output
)
hidden_states
=
self
.
model
(
model_output
=
self
.
model
(
input_ids
=
input_ids
,
positions
=
positions
,
inputs_embeds
=
inputs_embeds
,
intermediate_tensors
=
intermediate_tensors
,
)
if
self
.
use_aux_hidden_state_outputs
:
hidden_states
,
aux_hidden_states
=
model_output
else
:
hidden_states
=
model_output
aux_hidden_states
=
None
kv_connector_output
=
self
.
kv_connector
.
post_forward
(
scheduler_output
)
...
...
@@ -1011,12 +1037,17 @@ class GPUModelRunner(LoRAModelRunnerMixin):
# Non-last PP rank: return IntermediateTensors for sending.
assert
isinstance
(
hidden_states
,
IntermediateTensors
)
hidden_states
.
kv_connector_output
=
kv_connector_output
self
.
execute_model_state
=
(
None
,
input_batch
,
kv_connector_output
)
self
.
execute_model_state
=
(
None
,
None
,
input_batch
,
kv_connector_output
)
return
hidden_states
assert
isinstance
(
hidden_states
,
torch
.
Tensor
)
# Last rank (or no PP): hidden_states is a tensor for sampling.
self
.
execute_model_state
=
(
hidden_states
,
input_batch
,
kv_connector_output
)
assert
isinstance
(
hidden_states
,
torch
.
Tensor
)
self
.
execute_model_state
=
(
hidden_states
,
aux_hidden_states
,
input_batch
,
kv_connector_output
,
)
return
None
@
torch
.
inference_mode
()
...
...
@@ -1024,7 +1055,9 @@ class GPUModelRunner(LoRAModelRunnerMixin):
self
,
grammar_output
:
GrammarOutput
|
None
)
->
AsyncOutput
|
ModelRunnerOutput
|
None
:
assert
self
.
execute_model_state
is
not
None
hidden_states
,
input_batch
,
kv_connector_output
=
self
.
execute_model_state
hidden_states
,
aux_hidden_states
,
input_batch
,
kv_connector_output
=
(
self
.
execute_model_state
)
self
.
execute_model_state
=
None
# type: ignore
if
not
self
.
is_last_pp_rank
:
...
...
@@ -1084,11 +1117,11 @@ class GPUModelRunner(LoRAModelRunnerMixin):
self
.
postprocess
(
input_batch
,
sampler_output
.
sampled_token_ids
,
num_sampled
,
num_rejected
)
if
self
.
do_
spec
_decod
e
:
if
self
.
spec
ulator
is
not
Non
e
:
draft_tokens
=
self
.
propose_draft
(
input_batch
,
hidden_states
,
None
,
#
aux_hidden_states
aux_hidden_states
,
num_sampled
,
num_rejected
,
)
...
...
vllm/v1/worker/gpu/spec_decode/__init__.py
View file @
a4047d4e
...
...
@@ -9,7 +9,7 @@ def init_speculator(vllm_config: VllmConfig, device: torch.device):
speculative_config
=
vllm_config
.
speculative_config
assert
speculative_config
is
not
None
if
speculative_config
.
use_eagle
():
from
vllm.v1.worker.gpu.spec_decode.eagle
import
EagleSpeculator
from
vllm.v1.worker.gpu.spec_decode.eagle
.speculator
import
EagleSpeculator
return
EagleSpeculator
(
vllm_config
,
device
)
raise
NotImplementedError
(
f
"
{
speculative_config
.
method
}
is not supported yet."
)
vllm/v1/worker/gpu/spec_decode/eagle/__init__.py
0 → 100644
View file @
a4047d4e
vllm/v1/worker/gpu/spec_decode/eagle
_
cudagraph.py
→
vllm/v1/worker/gpu/spec_decode/eagle
/
cudagraph.py
View file @
a4047d4e
File moved
vllm/v1/worker/gpu/spec_decode/eagle/eagle3_utils.py
0 → 100644
View file @
a4047d4e
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from
typing
import
cast
import
torch.nn
as
nn
from
vllm.config
import
SpeculativeConfig
from
vllm.logger
import
init_logger
from
vllm.model_executor.models.interfaces
import
SupportsEagle3
,
supports_eagle3
logger
=
init_logger
(
__name__
)
def
set_eagle3_aux_hidden_state_layers
(
model
:
nn
.
Module
,
spec_config
:
SpeculativeConfig
,
)
->
None
:
if
not
supports_eagle3
(
model
):
raise
RuntimeError
(
"Model does not support EAGLE3 interface"
)
# mypy may infer the class-level overload for supports_eagle3.
# Narrow explicitly to the runtime protocol instance.
if
isinstance
(
model
,
type
):
raise
RuntimeError
(
"Expected model instance for EAGLE3 configuration"
)
eagle3_model
=
cast
(
SupportsEagle3
,
model
)
aux_layers
=
get_eagle3_aux_layers_from_config
(
spec_config
)
if
aux_layers
:
logger
.
info
(
"Using Eagle3 auxiliary layers from config: %s"
,
aux_layers
)
else
:
aux_layers
=
eagle3_model
.
get_eagle3_aux_hidden_state_layers
()
logger
.
info
(
"Using Eagle3 auxiliary layers from model: %s"
,
aux_layers
)
eagle3_model
.
set_aux_hidden_state_layers
(
aux_layers
)
def
get_eagle3_aux_layers_from_config
(
spec_config
:
SpeculativeConfig
,
)
->
tuple
[
int
,
...]
|
None
:
if
not
(
spec_config
and
spec_config
.
draft_model_config
):
return
None
hf_config
=
spec_config
.
draft_model_config
.
hf_config
if
not
hasattr
(
hf_config
,
"eagle_aux_hidden_state_layer_ids"
):
return
None
layer_ids
=
hf_config
.
eagle_aux_hidden_state_layer_ids
if
layer_ids
and
isinstance
(
layer_ids
,
(
list
,
tuple
)):
return
tuple
(
layer_ids
)
return
None
vllm/v1/worker/gpu/spec_decode/eagle.py
→
vllm/v1/worker/gpu/spec_decode/eagle
/speculator
.py
View file @
a4047d4e
...
...
@@ -9,7 +9,6 @@ from vllm.config import VllmConfig
from
vllm.config.compilation
import
CUDAGraphMode
from
vllm.forward_context
import
BatchDescriptor
,
set_forward_context
from
vllm.logger
import
init_logger
from
vllm.model_executor.model_loader
import
get_model
from
vllm.triton_utils
import
tl
,
triton
from
vllm.v1.attention.backend
import
AttentionMetadataBuilder
from
vllm.v1.kv_cache_interface
import
KVCacheConfig
...
...
@@ -20,7 +19,8 @@ from vllm.v1.worker.gpu.attn_utils import (
from
vllm.v1.worker.gpu.block_table
import
BlockTables
from
vllm.v1.worker.gpu.input_batch
import
InputBatch
,
InputBuffers
from
vllm.v1.worker.gpu.sample.gumbel
import
gumbel_sample
from
vllm.v1.worker.gpu.spec_decode.eagle_cudagraph
import
EagleCudaGraphManager
from
vllm.v1.worker.gpu.spec_decode.eagle.cudagraph
import
EagleCudaGraphManager
from
vllm.v1.worker.gpu.spec_decode.eagle.utils
import
load_eagle_model
logger
=
init_logger
(
__name__
)
...
...
@@ -73,18 +73,7 @@ class EagleSpeculator:
self
.
cudagraph_manager
=
EagleCudaGraphManager
(
vllm_config
,
device
)
def
load_model
(
self
,
target_model
:
nn
.
Module
)
->
None
:
from
vllm.compilation.backends
import
set_model_tag
with
set_model_tag
(
"eagle_head"
):
self
.
model
=
get_model
(
vllm_config
=
self
.
vllm_config
,
model_config
=
self
.
draft_model_config
)
share_lm_head
=
True
if
share_lm_head
and
hasattr
(
target_model
,
"lm_head"
):
if
hasattr
(
self
.
model
,
"lm_head"
):
del
self
.
model
.
lm_head
self
.
model
.
lm_head
=
target_model
.
lm_head
self
.
model
=
load_eagle_model
(
target_model
,
self
.
vllm_config
)
def
set_attn
(
self
,
...
...
vllm/v1/worker/gpu/spec_decode/eagle/utils.py
0 → 100644
View file @
a4047d4e
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import
torch.nn
as
nn
from
vllm.config
import
VllmConfig
from
vllm.model_executor.model_loader
import
get_model
def
load_eagle_model
(
target_model
:
nn
.
Module
,
vllm_config
:
VllmConfig
)
->
nn
.
Module
:
from
vllm.compilation.backends
import
set_model_tag
speculative_config
=
vllm_config
.
speculative_config
assert
speculative_config
is
not
None
draft_model_config
=
speculative_config
.
draft_model_config
with
set_model_tag
(
"eagle_head"
):
eagle_model
=
get_model
(
vllm_config
=
vllm_config
,
model_config
=
draft_model_config
)
# Share target embeddings when the draft checkpoint does not include
# its own vocab embedding table.
share_embeddings
=
True
if
hasattr
(
eagle_model
,
"has_own_embed_tokens"
):
share_embeddings
=
not
eagle_model
.
has_own_embed_tokens
if
share_embeddings
:
target_language_model
=
(
target_model
.
get_language_model
()
if
hasattr
(
target_model
,
"get_language_model"
)
else
target_model
)
inner_model
=
getattr
(
target_language_model
,
"model"
,
None
)
target_embed_tokens
=
None
if
inner_model
is
not
None
:
if
hasattr
(
inner_model
,
"embed_tokens"
):
target_embed_tokens
=
inner_model
.
embed_tokens
elif
hasattr
(
inner_model
,
"embedding"
):
target_embed_tokens
=
inner_model
.
embedding
if
target_embed_tokens
is
not
None
and
hasattr
(
eagle_model
,
"model"
):
if
hasattr
(
eagle_model
.
model
,
"embed_tokens"
):
del
eagle_model
.
model
.
embed_tokens
eagle_model
.
model
.
embed_tokens
=
target_embed_tokens
# Only share target lm_head when the draft model does not own one.
share_lm_head
=
True
if
hasattr
(
eagle_model
,
"has_own_lm_head"
):
share_lm_head
=
not
eagle_model
.
has_own_lm_head
if
share_lm_head
and
hasattr
(
target_model
,
"lm_head"
):
if
hasattr
(
eagle_model
,
"lm_head"
):
del
eagle_model
.
lm_head
eagle_model
.
lm_head
=
target_model
.
lm_head
return
eagle_model
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment