Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
5a4b4b37
Unverified
Commit
5a4b4b37
authored
Aug 12, 2025
by
Rahul Tuli
Committed by
GitHub
Aug 12, 2025
Browse files
Add: `SupportsEagle3` interface for explicit EAGLE3 support (#22642)
Signed-off-by:
Rahul Tuli
<
rtuli@redhat.com
>
parent
e5d3d63c
Changes
5
Hide whitespace changes
Inline
Side-by-side
Showing
5 changed files
with
81 additions
and
8 deletions
+81
-8
tests/speculative_decoding/speculators/test_eagle3.py
tests/speculative_decoding/speculators/test_eagle3.py
+16
-2
vllm/model_executor/models/interfaces.py
vllm/model_executor/models/interfaces.py
+53
-0
vllm/model_executor/models/llama.py
vllm/model_executor/models/llama.py
+2
-2
vllm/model_executor/models/qwen3.py
vllm/model_executor/models/qwen3.py
+2
-2
vllm/v1/worker/gpu_model_runner.py
vllm/v1/worker/gpu_model_runner.py
+8
-2
No files found.
tests/speculative_decoding/speculators/test_eagle3.py
View file @
5a4b4b37
...
...
@@ -3,12 +3,20 @@
import
pytest
import
torch
from
vllm.model_executor.models.interfaces
import
supports_eagle3
@
pytest
.
mark
.
parametrize
(
"model_path"
,
[(
"nm-testing/SpeculatorLlama3-1-8B-Eagle3-converted-0717-quantized"
)])
def
test_llama
(
vllm_runner
,
example_prompts
,
model_path
):
def
test_llama
(
vllm_runner
,
example_prompts
,
model_path
,
monkeypatch
):
# Set environment variable for V1 engine serialization
monkeypatch
.
setenv
(
"VLLM_ALLOW_INSECURE_SERIALIZATION"
,
"1"
)
with
vllm_runner
(
model_path
,
dtype
=
torch
.
bfloat16
)
as
vllm_model
:
eagle3_supported
=
vllm_model
.
apply_model
(
supports_eagle3
)
assert
eagle3_supported
vllm_outputs
=
vllm_model
.
generate_greedy
(
example_prompts
,
max_tokens
=
20
)
print
(
vllm_outputs
)
...
...
@@ -18,8 +26,14 @@ def test_llama(vllm_runner, example_prompts, model_path):
@
pytest
.
mark
.
parametrize
(
"model_path"
,
[(
"nm-testing/Speculator-Qwen3-8B-Eagle3-converted-071-quantized"
)])
def
test_qwen
(
vllm_runner
,
example_prompts
,
model_path
):
def
test_qwen
(
vllm_runner
,
example_prompts
,
model_path
,
monkeypatch
):
# Set environment variable for V1 engine serialization
monkeypatch
.
setenv
(
"VLLM_ALLOW_INSECURE_SERIALIZATION"
,
"1"
)
with
vllm_runner
(
model_path
,
dtype
=
torch
.
bfloat16
)
as
vllm_model
:
eagle3_supported
=
vllm_model
.
apply_model
(
supports_eagle3
)
assert
eagle3_supported
vllm_outputs
=
vllm_model
.
generate_greedy
(
example_prompts
,
max_tokens
=
20
)
print
(
vllm_outputs
)
...
...
vllm/model_executor/models/interfaces.py
View file @
5a4b4b37
...
...
@@ -823,3 +823,56 @@ def supports_v0_only(
model
:
Union
[
type
[
object
],
object
],
)
->
Union
[
TypeIs
[
type
[
SupportsV0Only
]],
TypeIs
[
SupportsV0Only
]]:
return
getattr
(
model
,
"supports_v0_only"
,
False
)
@
runtime_checkable
class
SupportsEagle3
(
Protocol
):
"""The interface required for models that support
EAGLE3 speculative decoding."""
supports_eagle3
:
ClassVar
[
Literal
[
True
]]
=
True
"""
A flag that indicates this model supports EAGLE3
speculative decoding.
Note:
There is no need to redefine this flag if this class is in the
MRO of your model class.
"""
def
set_aux_hidden_state_layers
(
self
,
layers
:
tuple
[
int
,
...])
->
None
:
"""
Set which layers should output auxiliary
hidden states for EAGLE3.
Args:
layers: Tuple of layer indices that should output auxiliary
hidden states.
"""
...
def
get_eagle3_aux_hidden_state_layers
(
self
)
->
tuple
[
int
,
...]:
"""
Get the layer indices that should output auxiliary hidden states
for EAGLE3.
Returns:
Tuple of layer indices for auxiliary hidden state outputs.
"""
...
@
overload
def
supports_eagle3
(
model
:
type
[
object
])
->
TypeIs
[
type
[
SupportsEagle3
]]:
...
@
overload
def
supports_eagle3
(
model
:
object
)
->
TypeIs
[
SupportsEagle3
]:
...
def
supports_eagle3
(
model
:
Union
[
type
[
object
],
object
],
)
->
Union
[
TypeIs
[
type
[
SupportsEagle3
]],
TypeIs
[
SupportsEagle3
]]:
return
isinstance
(
model
,
SupportsEagle3
)
vllm/model_executor/models/llama.py
View file @
5a4b4b37
...
...
@@ -49,7 +49,7 @@ from vllm.model_executor.model_loader.weight_utils import (
from
vllm.model_executor.sampling_metadata
import
SamplingMetadata
from
vllm.sequence
import
IntermediateTensors
from
.interfaces
import
SupportsLoRA
,
SupportsPP
from
.interfaces
import
SupportsEagle3
,
SupportsLoRA
,
SupportsPP
from
.utils
import
(
AutoWeightsLoader
,
PPMissingLayer
,
extract_layer_index
,
is_pp_missing_parameter
,
make_empty_intermediate_tensors_factory
,
make_layers
,
...
...
@@ -463,7 +463,7 @@ class LlamaModel(nn.Module):
return
loaded_params
class
LlamaForCausalLM
(
nn
.
Module
,
SupportsLoRA
,
SupportsPP
):
class
LlamaForCausalLM
(
nn
.
Module
,
SupportsLoRA
,
SupportsPP
,
SupportsEagle3
):
packed_modules_mapping
=
{
"qkv_proj"
:
[
"q_proj"
,
"k_proj"
,
"v_proj"
],
"gate_up_proj"
:
[
"gate_proj"
,
"up_proj"
]
...
...
vllm/model_executor/models/qwen3.py
View file @
5a4b4b37
...
...
@@ -44,7 +44,7 @@ from vllm.model_executor.layers.vocab_parallel_embedding import ParallelLMHead
from
vllm.model_executor.sampling_metadata
import
SamplingMetadata
from
vllm.sequence
import
IntermediateTensors
from
.interfaces
import
SupportsLoRA
,
SupportsPP
from
.interfaces
import
SupportsEagle3
,
SupportsLoRA
,
SupportsPP
from
.qwen2
import
Qwen2MLP
as
Qwen3MLP
from
.qwen2
import
Qwen2Model
from
.utils
import
(
AutoWeightsLoader
,
PPMissingLayer
,
extract_layer_index
,
...
...
@@ -261,7 +261,7 @@ class Qwen3Model(Qwen2Model):
decoder_layer_type
=
Qwen3DecoderLayer
)
class
Qwen3ForCausalLM
(
nn
.
Module
,
SupportsLoRA
,
SupportsPP
):
class
Qwen3ForCausalLM
(
nn
.
Module
,
SupportsLoRA
,
SupportsPP
,
SupportsEagle3
):
packed_modules_mapping
=
{
"qkv_proj"
:
[
"q_proj"
,
...
...
vllm/v1/worker/gpu_model_runner.py
View file @
5a4b4b37
...
...
@@ -35,6 +35,7 @@ from vllm.model_executor.layers.mamba.mamba_mixer2 import MambaBase
from
vllm.model_executor.layers.rotary_embedding
import
MRotaryEmbedding
from
vllm.model_executor.model_loader
import
TensorizerLoader
,
get_model_loader
from
vllm.model_executor.models.interfaces
import
(
is_mixture_of_experts
,
supports_eagle3
,
supports_transcription
)
from
vllm.model_executor.models.interfaces_base
import
(
VllmModelForPooling
,
is_pooling_model
,
is_text_generation_model
)
...
...
@@ -1981,8 +1982,13 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
logger
.
info
(
"Loading drafter model..."
)
self
.
drafter
.
load_model
(
self
.
model
)
if
self
.
use_aux_hidden_state_outputs
:
self
.
model
.
set_aux_hidden_state_layers
(
self
.
model
.
get_eagle3_aux_hidden_state_layers
())
if
supports_eagle3
(
self
.
model
):
self
.
model
.
set_aux_hidden_state_layers
(
self
.
model
.
get_eagle3_aux_hidden_state_layers
())
else
:
raise
RuntimeError
(
"Model does not support EAGLE3 interface but "
"aux_hidden_state_outputs was requested"
)
time_after_load
=
time
.
perf_counter
()
self
.
model_memory_usage
=
m
.
consumed_memory
logger
.
info
(
"Model loading took %.4f GiB and %.6f seconds"
,
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment