Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
change
sglang
Commits
78c43d88
Unverified
Commit
78c43d88
authored
Oct 30, 2025
by
JensenFire
Committed by
GitHub
Oct 30, 2025
Browse files
[Feature] Initial eagle3 support for Deepseek-like models (#12319)
parent
7e28c67d
Changes
1
Hide whitespace changes
Inline
Side-by-side
Showing
1 changed file
with
27 additions
and
4 deletions
+27
-4
python/sglang/srt/models/deepseek_v2.py
python/sglang/srt/models/deepseek_v2.py
+27
-4
No files found.
python/sglang/srt/models/deepseek_v2.py
View file @
78c43d88
...
...
@@ -21,7 +21,7 @@ import concurrent.futures
import
logging
import
os
from
enum
import
IntEnum
,
auto
from
typing
import
Any
,
Dict
,
Iterable
,
Optional
,
Tuple
,
Union
from
typing
import
Any
,
Dict
,
Iterable
,
List
,
Optional
,
Tuple
,
Union
import
torch
import
torch.nn.functional
as
F
...
...
@@ -2841,6 +2841,7 @@ class DeepseekV2Model(nn.Module):
self
.
embed_tokens
.
embedding_dim
,
)
)
self
.
layers_to_capture
=
[]
def get_input_embeddings(self) -> nn.Module:
    """Return the token-embedding layer held in ``self.embed_tokens``.

    The returned object is the embedding layer itself rather than a tensor
    of embeddings: the same attribute is read for ``.embedding_dim``
    elsewhere in this class, which a plain tensor does not provide.
    """
    return self.embed_tokens
...
...
@@ -2897,9 +2898,11 @@ class DeepseekV2Model(nn.Module):
normal_end_layer
=
self
.
first_k_dense_replace
elif
self
.
first_k_dense_replace
<
normal_start_layer
:
normal_end_layer
=
normal_start_layer
=
0
aux_hidden_states
=
[]
for
i
in
range
(
normal_start_layer
,
normal_end_layer
):
with
get_global_expert_distribution_recorder
().
with_current_layer
(
i
):
if
i
in
self
.
layers_to_capture
:
aux_hidden_states
.
append
(
hidden_states
+
residual
)
layer
=
self
.
layers
[
i
]
hidden_states
,
residual
=
layer
(
positions
,
...
...
@@ -2937,7 +2940,9 @@ class DeepseekV2Model(nn.Module):
hidden_states
=
self
.
norm
(
hidden_states
)
else
:
hidden_states
,
_
=
self
.
norm
(
hidden_states
,
residual
)
return
hidden_states
if
len
(
aux_hidden_states
)
==
0
:
return
hidden_states
return
hidden_states
,
aux_hidden_states
class
DeepseekV2ForCausalLM
(
nn
.
Module
):
...
...
@@ -2991,6 +2996,7 @@ class DeepseekV2ForCausalLM(nn.Module):
if
isinstance
(
layer
.
mlp
,
DeepseekV2MoE
)
}
)
self
.
capture_aux_hidden_states
=
False
@
property
def
routed_experts_weights_of_layer
(
self
):
...
...
@@ -3044,10 +3050,13 @@ class DeepseekV2ForCausalLM(nn.Module):
hidden_states
=
self
.
model
(
input_ids
,
positions
,
forward_batch
,
input_embeds
,
pp_proxy_tensors
)
aux_hidden_states
=
None
if
self
.
capture_aux_hidden_states
:
hidden_states
,
aux_hidden_states
=
hidden_states
if
self
.
pp_group
.
is_last_rank
:
return
self
.
logits_processor
(
input_ids
,
hidden_states
,
self
.
lm_head
,
forward_batch
input_ids
,
hidden_states
,
self
.
lm_head
,
forward_batch
,
aux_hidden_states
)
else
:
return
hidden_states
...
...
@@ -3755,6 +3764,20 @@ class DeepseekV2ForCausalLM(nn.Module):
num_groups
=
config
.
n_group
,
)
def set_eagle3_layers_to_capture(self, layer_ids: Optional[List[int]] = None):
    """Select the layers whose inputs are captured as auxiliary hidden states.

    Args:
        layer_ids: Layer indices whose *outputs* should be captured. When
            ``None``, three default capture points are used:
            ``[2, num_layers // 2, num_layers - 3]`` with ``num_layers``
            taken from ``self.config.num_hidden_layers``.

    On ranks that are not the last pipeline-parallel rank this is a no-op.
    Otherwise it sets ``self.model.layers_to_capture`` and enables
    ``self.capture_aux_hidden_states`` only when there is at least one layer
    to capture.
    """
    if not self.pp_group.is_last_rank:
        return

    if layer_ids is None:
        num_layers = self.config.num_hidden_layers
        layers_to_capture = [2, num_layers // 2, num_layers - 3]
    else:
        # we plus 1 here because in sglang, for the ith layer, it takes the output
        # of the (i-1)th layer as aux hidden state
        layers_to_capture = [val + 1 for val in layer_ids]

    self.model.layers_to_capture = layers_to_capture
    # The model forward returns a bare tensor (not a tuple) when nothing was
    # captured, while the causal-LM forward unpacks a 2-tuple whenever this
    # flag is set. Keep the flag off for an empty capture list so the two
    # stay in agreement.
    self.capture_aux_hidden_states = bool(layers_to_capture)
AttentionBackendRegistry
.
register
(
"ascend"
,
handle_attention_ascend
)
AttentionBackendRegistry
.
register
(
"flashinfer"
,
handle_attention_flashinfer
)
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment