Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
cf4cd6c2
Unverified
Commit
cf4cd6c2
authored
Oct 09, 2025
by
Rahul Tuli
Committed by
GitHub
Oct 09, 2025
Browse files
Add: Support for multiple hidden layers in Eagle3 (#26164)
Signed-off-by:
Rahul Tuli
<
rtuli@redhat.com
>
parent
b9604418
Changes
2
Hide whitespace changes
Inline
Side-by-side
Showing
2 changed files
with
29 additions
and
13 deletions
+29
-13
tests/speculative_decoding/speculators/test_eagle3.py
tests/speculative_decoding/speculators/test_eagle3.py
+4
-0
vllm/model_executor/models/llama_eagle3.py
vllm/model_executor/models/llama_eagle3.py
+25
-13
No files found.
tests/speculative_decoding/speculators/test_eagle3.py
View file @
cf4cd6c2
...
...
@@ -22,6 +22,10 @@ from vllm.model_executor.models.interfaces import supports_eagle3
"nm-testing/Speculator-Qwen3-8B-Eagle3-converted-071-quantized-w4a16"
,
id
=
"qwen3-eagle3-speculator-w4a16-verifier"
,
),
pytest
.
param
(
"nm-testing/random-weights-llama3.1.8b-2layer-eagle3"
,
id
=
"llama3-eagl3-multiple-layers"
,
),
],
)
def
test_eagle3_speculators_model
(
...
...
vllm/model_executor/models/llama_eagle3.py
View file @
cf4cd6c2
...
...
@@ -34,15 +34,20 @@ class LlamaDecoderLayer(LlamaDecoderLayer):
vllm_config
:
VllmConfig
,
prefix
:
str
=
""
,
config
:
Optional
[
LlamaConfig
]
=
None
,
layer_idx
:
int
=
0
,
)
->
None
:
super
().
__init__
(
vllm_config
,
prefix
=
prefix
,
config
=
config
)
config
=
config
or
vllm_config
.
model_config
.
hf_config
quant_config
=
self
.
get_quant_config
(
vllm_config
)
# First layer uses 2*hidden_size (embeds + hidden_states concatenated)
# Subsequent layers use hidden_size (only hidden_states, no embeds)
qkv_input_size
=
2
*
self
.
hidden_size
if
layer_idx
==
0
else
self
.
hidden_size
# override qkv
self
.
self_attn
.
qkv_proj
=
QKVParallelLinear
(
2
*
self
.
hidden
_size
,
qkv_input
_size
,
self
.
self_attn
.
head_dim
,
self
.
self_attn
.
total_num_heads
,
self
.
self_attn
.
total_num_kv_heads
,
...
...
@@ -52,6 +57,7 @@ class LlamaDecoderLayer(LlamaDecoderLayer):
)
self
.
hidden_norm
=
RMSNorm
(
config
.
hidden_size
,
eps
=
config
.
rms_norm_eps
)
self
.
layer_idx
=
layer_idx
if
getattr
(
config
,
"norm_before_residual"
,
False
):
self
.
_residual_norm
=
self
.
_norm_before_residual
...
...
@@ -90,11 +96,15 @@ class LlamaDecoderLayer(LlamaDecoderLayer):
hidden_states
:
torch
.
Tensor
,
residual
:
Optional
[
torch
.
Tensor
],
)
->
tuple
[
torch
.
Tensor
,
torch
.
Tensor
]:
embeds
=
self
.
input_layernorm
(
embeds
)
hidden_states
,
residual
=
self
.
_residual_norm
(
hidden_states
=
hidden_states
)
if
self
.
layer_idx
==
0
:
# First layer: concatenate embeds with hidden_states
embeds
=
self
.
input_layernorm
(
embeds
)
hidden_states
,
residual
=
self
.
_residual_norm
(
hidden_states
=
hidden_states
)
hidden_states
=
torch
.
cat
([
embeds
,
hidden_states
],
dim
=-
1
)
else
:
# Subsequent layers: process hidden_states and residuals only
hidden_states
,
residual
=
self
.
input_layernorm
(
hidden_states
,
residual
)
hidden_states
=
torch
.
cat
([
embeds
,
hidden_states
],
dim
=-
1
)
# Self Attention
hidden_states
=
self
.
self_attn
(
positions
=
positions
,
...
...
@@ -133,9 +143,11 @@ class LlamaModel(nn.Module):
[
LlamaDecoderLayer
(
current_vllm_config
,
prefix
=
maybe_prefix
(
prefix
,
f
"layers.
{
start_layer_id
}
"
),
prefix
=
maybe_prefix
(
prefix
,
f
"layers.
{
layer_idx
+
start_layer_id
}
"
),
config
=
self
.
config
,
layer_idx
=
layer_idx
,
)
for
layer_idx
in
range
(
self
.
config
.
num_hidden_layers
)
]
)
if
hasattr
(
self
.
config
,
"target_hidden_size"
):
...
...
@@ -166,13 +178,13 @@ class LlamaModel(nn.Module):
assert
hidden_states
.
shape
[
-
1
]
==
input_embeds
.
shape
[
-
1
]
residual
=
None
hidden_states
,
residual
=
self
.
layers
[
0
](
positions
,
input_embed
s
,
hidden_state
s
,
residual
,
)
for
layer
in
self
.
layers
:
hidden_states
,
residual
=
layer
(
positions
=
position
s
,
embeds
=
input_embed
s
,
hidden_states
=
hidden_states
,
residual
=
residual
,
)
hidden_states
,
hidden_prenorm
=
self
.
norm
(
hidden_states
,
residual
)
return
hidden_states
,
hidden_prenorm
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment