Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
8b346309
Unverified
Commit
8b346309
authored
Mar 13, 2026
by
Benjamin Chislett
Committed by
GitHub
Mar 13, 2026
Browse files
[Refactor] Consolidate SupportsEagle (#36063)
Signed-off-by:
Benjamin Chislett
<
bchislett@nvidia.com
>
parent
54a6db82
Changes
24
Hide whitespace changes
Inline
Side-by-side
Showing
4 changed files
with
17 additions
and
15 deletions
+17
-15
vllm/model_executor/models/step1.py
vllm/model_executor/models/step1.py
+12
-12
vllm/model_executor/models/transformers/base.py
vllm/model_executor/models/transformers/base.py
+1
-1
vllm/v1/worker/gpu/spec_decode/eagle/eagle3_utils.py
vllm/v1/worker/gpu/spec_decode/eagle/eagle3_utils.py
+1
-1
vllm/v1/worker/gpu_model_runner.py
vllm/v1/worker/gpu_model_runner.py
+3
-1
No files found.
vllm/model_executor/models/step1.py
View file @
8b346309
...
@@ -31,7 +31,12 @@ from vllm.model_executor.layers.vocab_parallel_embedding import (
...
@@ -31,7 +31,12 @@ from vllm.model_executor.layers.vocab_parallel_embedding import (
VocabParallelEmbedding
,
VocabParallelEmbedding
,
)
)
from
vllm.model_executor.model_loader.weight_utils
import
default_weight_loader
from
vllm.model_executor.model_loader.weight_utils
import
default_weight_loader
from
vllm.model_executor.models.interfaces
import
SupportsPP
from
vllm.model_executor.models.interfaces
import
(
EagleModelMixin
,
SupportsEagle
,
SupportsEagle3
,
SupportsPP
,
)
from
vllm.model_executor.models.utils
import
(
from
vllm.model_executor.models.utils
import
(
AutoWeightsLoader
,
AutoWeightsLoader
,
PPMissingLayer
,
PPMissingLayer
,
...
@@ -274,7 +279,7 @@ class StepDecoderLayer(nn.Module):
...
@@ -274,7 +279,7 @@ class StepDecoderLayer(nn.Module):
return
loaded_params
return
loaded_params
class
StepDecoderModel
(
nn
.
Module
):
class
StepDecoderModel
(
nn
.
Module
,
EagleModelMixin
):
def
__init__
(
self
,
*
,
vllm_config
:
VllmConfig
,
prefix
:
str
=
""
):
def
__init__
(
self
,
*
,
vllm_config
:
VllmConfig
,
prefix
:
str
=
""
):
super
().
__init__
()
super
().
__init__
()
config
=
vllm_config
.
model_config
.
hf_config
config
=
vllm_config
.
model_config
.
hf_config
...
@@ -303,9 +308,6 @@ class StepDecoderModel(nn.Module):
...
@@ -303,9 +308,6 @@ class StepDecoderModel(nn.Module):
else
:
else
:
self
.
norm
=
PPMissingLayer
()
self
.
norm
=
PPMissingLayer
()
self
.
aux_hidden_state_layers
:
tuple
[
int
,
...]
=
getattr
(
config
,
"aux_hidden_state_layers"
,
()
)
self
.
make_empty_intermediate_tensors
=
make_empty_intermediate_tensors_factory
(
self
.
make_empty_intermediate_tensors
=
make_empty_intermediate_tensors_factory
(
[
"hidden_states"
,
"residual"
],
[
"hidden_states"
,
"residual"
],
config
.
hidden_size
,
config
.
hidden_size
,
...
@@ -333,14 +335,12 @@ class StepDecoderModel(nn.Module):
...
@@ -333,14 +335,12 @@ class StepDecoderModel(nn.Module):
hidden_states
=
intermediate_tensors
[
"hidden_states"
]
hidden_states
=
intermediate_tensors
[
"hidden_states"
]
residual
=
intermediate_tensors
[
"residual"
]
residual
=
intermediate_tensors
[
"residual"
]
aux_hidden_states
=
[]
aux_hidden_states
=
self
.
_maybe_add_hidden_state
([],
0
,
hidden_states
,
residual
)
for
idx
,
layer
in
enumerate
(
self
.
layers
[
self
.
start_layer
:
self
.
end_layer
]):
for
idx
,
layer
in
enumerate
(
self
.
layers
[
self
.
start_layer
:
self
.
end_layer
]):
if
idx
in
self
.
aux_hidden_state_layers
:
if
residual
is
None
:
aux_hidden_states
.
append
(
hidden_states
)
else
:
aux_hidden_states
.
append
(
hidden_states
+
residual
)
hidden_states
,
residual
=
layer
(
positions
,
hidden_states
,
residual
)
hidden_states
,
residual
=
layer
(
positions
,
hidden_states
,
residual
)
self
.
_maybe_add_hidden_state
(
aux_hidden_states
,
idx
+
1
,
hidden_states
,
residual
)
if
not
get_pp_group
().
is_last_rank
:
if
not
get_pp_group
().
is_last_rank
:
return
IntermediateTensors
(
return
IntermediateTensors
(
...
@@ -353,7 +353,7 @@ class StepDecoderModel(nn.Module):
...
@@ -353,7 +353,7 @@ class StepDecoderModel(nn.Module):
return
hidden_states
return
hidden_states
class
Step1ForCausalLM
(
nn
.
Module
,
SupportsPP
):
class
Step1ForCausalLM
(
nn
.
Module
,
SupportsPP
,
SupportsEagle
,
SupportsEagle3
):
packed_modules_mapping
=
STEP_PACKED_MODULES_MAPPING
packed_modules_mapping
=
STEP_PACKED_MODULES_MAPPING
def
__init__
(
self
,
*
,
vllm_config
:
VllmConfig
,
prefix
:
str
=
""
):
def
__init__
(
self
,
*
,
vllm_config
:
VllmConfig
,
prefix
:
str
=
""
):
...
...
vllm/model_executor/models/transformers/base.py
View file @
8b346309
...
@@ -618,6 +618,6 @@ class Base(
...
@@ -618,6 +618,6 @@ class Base(
# Ensure that the capture hooks are installed before dynamo traces the model
# Ensure that the capture hooks are installed before dynamo traces the model
maybe_install_capturing_hooks
(
self
.
model
)
maybe_install_capturing_hooks
(
self
.
model
)
def
get_eagle3_aux_hidden_state_layers
(
self
)
->
tuple
[
int
,
...]:
def
get_eagle3_
default_
aux_hidden_state_layers
(
self
)
->
tuple
[
int
,
...]:
num_layers
=
self
.
text_config
.
num_hidden_layers
num_layers
=
self
.
text_config
.
num_hidden_layers
return
(
2
,
num_layers
//
2
,
num_layers
-
3
)
return
(
2
,
num_layers
//
2
,
num_layers
-
3
)
vllm/v1/worker/gpu/spec_decode/eagle/eagle3_utils.py
View file @
8b346309
...
@@ -27,7 +27,7 @@ def set_eagle3_aux_hidden_state_layers(
...
@@ -27,7 +27,7 @@ def set_eagle3_aux_hidden_state_layers(
if
aux_layers
:
if
aux_layers
:
logger
.
info
(
"Using Eagle3 auxiliary layers from config: %s"
,
aux_layers
)
logger
.
info
(
"Using Eagle3 auxiliary layers from config: %s"
,
aux_layers
)
else
:
else
:
aux_layers
=
eagle3_model
.
get_eagle3_aux_hidden_state_layers
()
aux_layers
=
eagle3_model
.
get_eagle3_
default_
aux_hidden_state_layers
()
logger
.
info
(
"Using Eagle3 auxiliary layers from model: %s"
,
aux_layers
)
logger
.
info
(
"Using Eagle3 auxiliary layers from model: %s"
,
aux_layers
)
eagle3_model
.
set_aux_hidden_state_layers
(
aux_layers
)
eagle3_model
.
set_aux_hidden_state_layers
(
aux_layers
)
...
...
vllm/v1/worker/gpu_model_runner.py
View file @
8b346309
...
@@ -4556,7 +4556,9 @@ class GPUModelRunner(
...
@@ -4556,7 +4556,9 @@ class GPUModelRunner(
aux_layers
,
aux_layers
,
)
)
else
:
else
:
aux_layers
=
self
.
model
.
get_eagle3_aux_hidden_state_layers
()
aux_layers
=
(
self
.
model
.
get_eagle3_default_aux_hidden_state_layers
()
)
self
.
model
.
set_aux_hidden_state_layers
(
aux_layers
)
self
.
model
.
set_aux_hidden_state_layers
(
aux_layers
)
time_after_load
=
time
.
perf_counter
()
time_after_load
=
time
.
perf_counter
()
...
...
Prev
1
2
Next
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment