Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
9f9c38c3
Unverified
Commit
9f9c38c3
authored
Aug 01, 2025
by
Dipika Sikka
Committed by
GitHub
Aug 01, 2025
Browse files
[Speculators][Speculative Decoding] Add Qwen Eagle3 Support (#21835)
Signed-off-by:
Dipika Sikka
<
dipikasikka1@gmail.com
>
parent
a65f46be
Changes
4
Show whitespace changes
Inline
Side-by-side
Showing
4 changed files
with
46 additions
and
11 deletions
+46
-11
tests/speculative_decoding/speculators/test_eagle3.py
tests/speculative_decoding/speculators/test_eagle3.py
+12
-2
vllm/config.py
vllm/config.py
+12
-3
vllm/model_executor/models/qwen2.py
vllm/model_executor/models/qwen2.py
+15
-6
vllm/model_executor/models/qwen3.py
vllm/model_executor/models/qwen3.py
+7
-0
No files found.
tests/speculative_decoding/speculators/test_eagle3.py
View file @
9f9c38c3
...
...
@@ -6,11 +6,21 @@ import torch
@
pytest
.
mark
.
parametrize
(
"model_path"
,
[(
"nm-testing/SpeculatorLlama3-1-8B-Eagle3-converted-0717"
),
(
"nm-testing/SpeculatorLlama3-1-8B-Eagle3-converted-0717-quantized"
)])
[(
"nm-testing/SpeculatorLlama3-1-8B-Eagle3-converted-0717-quantized"
)])
def
test_llama
(
vllm_runner
,
example_prompts
,
model_path
):
with
vllm_runner
(
model_path
,
dtype
=
torch
.
bfloat16
)
as
vllm_model
:
vllm_outputs
=
vllm_model
.
generate_greedy
(
example_prompts
,
max_tokens
=
20
)
print
(
vllm_outputs
)
assert
vllm_outputs
@
pytest
.
mark
.
parametrize
(
"model_path"
,
[(
"nm-testing/Speculator-Qwen3-8B-Eagle3-converted-071-quantized"
)])
def
test_qwen
(
vllm_runner
,
example_prompts
,
model_path
):
with
vllm_runner
(
model_path
,
dtype
=
torch
.
bfloat16
)
as
vllm_model
:
vllm_outputs
=
vllm_model
.
generate_greedy
(
example_prompts
,
max_tokens
=
20
)
print
(
vllm_outputs
)
assert
vllm_outputs
vllm/config.py
View file @
9f9c38c3
...
...
@@ -3175,10 +3175,19 @@ class SpeculativeConfig:
"speculative decoding is > 1, but got "
f
"
{
self
.
disable_by_batch_size
=
}
"
)
if
self
.
method
==
"eagle3"
and
self
.
target_model_config
and
\
"llama"
not
in
self
.
target_model_config
.
hf_text_config
.
model_type
:
from
vllm.transformers_utils.configs
import
SpeculatorsConfig
eagle3_target_supported
=
[
"llama"
]
if
self
.
draft_model_config
and
isinstance
(
self
.
draft_model_config
.
hf_config
,
SpeculatorsConfig
):
eagle3_target_supported
.
append
(
"qwen"
)
if
self
.
method
==
"eagle3"
and
self
.
target_model_config
and
not
any
(
supported_model
in
self
.
target_model_config
.
hf_text_config
.
model_type
for
supported_model
in
eagle3_target_supported
):
raise
ValueError
(
"Eagle3 is only supported for
Llama models. "
f
"Eagle3 is only supported for
{
eagle3_target_supported
}
models. "
# noqa: E501
f
"Got
{
self
.
target_model_config
.
hf_text_config
.
model_type
=
}
"
)
return
self
...
...
vllm/model_executor/models/qwen2.py
View file @
9f9c38c3
...
...
@@ -330,6 +330,8 @@ class Qwen2Model(nn.Module):
else
:
self
.
norm
=
PPMissingLayer
()
self
.
aux_hidden_state_layers
:
tuple
[
int
]
=
tuple
()
def
get_input_embeddings
(
self
,
input_ids
:
torch
.
Tensor
)
->
torch
.
Tensor
:
return
self
.
embed_tokens
(
input_ids
)
...
...
@@ -350,18 +352,25 @@ class Qwen2Model(nn.Module):
assert
intermediate_tensors
is
not
None
hidden_states
=
intermediate_tensors
[
"hidden_states"
]
residual
=
intermediate_tensors
[
"residual"
]
for
layer
in
self
.
layers
[
self
.
start_layer
:
self
.
end_layer
]:
hidden_states
,
residual
=
layer
(
positions
,
hidden_states
,
residual
,
)
aux_hidden_states
=
[]
for
idx
,
layer
in
enumerate
(
self
.
layers
[
self
.
start_layer
:
self
.
end_layer
]):
if
idx
in
self
.
aux_hidden_state_layers
:
aux_hidden_states
.
append
(
hidden_states
+
residual
)
hidden_states
,
residual
=
layer
(
positions
,
hidden_states
,
residual
)
if
not
get_pp_group
().
is_last_rank
:
return
IntermediateTensors
({
"hidden_states"
:
hidden_states
,
"residual"
:
residual
})
hidden_states
,
_
=
self
.
norm
(
hidden_states
,
residual
)
if
len
(
aux_hidden_states
)
>
0
:
return
hidden_states
,
aux_hidden_states
return
hidden_states
def
load_weights
(
self
,
weights
:
Iterable
[
tuple
[
str
,
...
...
vllm/model_executor/models/qwen3.py
View file @
9f9c38c3
...
...
@@ -288,6 +288,13 @@ class Qwen3ForCausalLM(nn.Module, SupportsLoRA, SupportsPP):
self
.
make_empty_intermediate_tensors
=
(
self
.
model
.
make_empty_intermediate_tensors
)
def
set_aux_hidden_state_layers
(
self
,
layers
:
tuple
[
int
])
->
None
:
self
.
model
.
aux_hidden_state_layers
=
layers
def
get_eagle3_aux_hidden_state_layers
(
self
)
->
tuple
[
int
]:
num_layers
=
len
(
self
.
model
.
layers
)
return
(
2
,
num_layers
//
2
,
num_layers
-
3
)
def
get_input_embeddings
(
self
,
input_ids
:
torch
.
Tensor
)
->
torch
.
Tensor
:
return
self
.
model
.
get_input_embeddings
(
input_ids
)
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment