change / sglang / Commits

Commit 750940ae (unverified), authored Oct 29, 2025 by Rain H, committed via GitHub on Oct 29, 2025.

Eagle3 DP attention for Qwen3 MoE (#12002)

Parent: 42f8ea40

Showing 9 changed files with 219 additions and 27 deletions (+219/-27).
- python/sglang/srt/layers/communicator.py (+23/-1)
- python/sglang/srt/models/llama_eagle3.py (+11/-1)
- python/sglang/srt/models/qwen2_moe.py (+30/-15)
- python/sglang/srt/models/qwen3_moe.py (+16/-8)
- python/sglang/srt/server_args.py (+1/-1)
- python/sglang/srt/speculative/eagle_worker.py (+6/-1)
- python/sglang/test/test_utils.py (+2/-0)
- test/srt/run_suite.py (+1/-0)
- test/srt/test_eagle_dp_attention.py (new file, +129/-0)
python/sglang/srt/layers/communicator.py

```diff
@@ -15,7 +15,7 @@
 from dataclasses import dataclass
 from enum import Enum, auto
 from functools import partial
-from typing import Dict, Optional
+from typing import Dict, List, Optional

 import torch
@@ -216,6 +216,28 @@ class LayerCommunicator:
             get_global_server_args().speculative_algorithm
         )

+    def prepare_attn_and_capture_last_layer_outputs(
+        self,
+        hidden_states: torch.Tensor,
+        residual: torch.Tensor,
+        forward_batch: ForwardBatch,
+        captured_last_layer_outputs: Optional[List[torch.Tensor]] = None,
+    ):
+        hidden_states, residual = self.prepare_attn(
+            hidden_states, residual, forward_batch
+        )
+        if captured_last_layer_outputs is not None:
+            gathered_last_layer_output = self._communicate_simple_fn(
+                hidden_states=residual,
+                forward_batch=forward_batch,
+                context=self._context,
+            )
+            if gathered_last_layer_output is residual:
+                # Clone to avoid modifying the original residual by Custom RMSNorm inplace operation
+                gathered_last_layer_output = residual.clone()
+            captured_last_layer_outputs.append(gathered_last_layer_output)
+        return hidden_states, residual
+
     def prepare_attn(
         self,
         hidden_states: torch.Tensor,
```
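The new helper wraps the existing `prepare_attn` and, when a capture list is passed, gathers the layer's input residual across attention-DP ranks (via `_communicate_simple_fn`) before recording it as an EAGLE3 auxiliary hidden state. The `is`-check guards the case where the gather is a no-op and returns the same tensor object. A minimal single-process sketch of that clone-on-alias rationale (the `fake_dp_gather` stand-in and shapes are illustrative assumptions, not sglang's real communicator):

```python
from typing import List, Optional

import torch


def fake_dp_gather(local: torch.Tensor) -> torch.Tensor:
    # Stand-in for _communicate_simple_fn: with one rank the "gather"
    # returns the very same tensor object, i.e. the aliasing case.
    return local


def prepare_and_capture(
    residual: torch.Tensor, captured: Optional[List[torch.Tensor]] = None
) -> torch.Tensor:
    if captured is not None:
        gathered = fake_dp_gather(residual)
        if gathered is residual:
            # Clone so a later in-place op on `residual` (e.g. a fused
            # RMSNorm) cannot corrupt the captured auxiliary state.
            gathered = residual.clone()
        captured.append(gathered)
    return residual


aux: List[torch.Tensor] = []
res = torch.randn(4, 8)
prepare_and_capture(res, aux)
res.mul_(0.0)  # simulate the in-place normalization that follows
assert aux[0].abs().sum() > 0  # the captured copy survives intact
```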
python/sglang/srt/models/llama_eagle3.py

```diff
@@ -19,6 +19,7 @@ from sglang.srt.utils import add_prefix
 # https://github.com/SafeAILab/EAGLE/blob/main/eagle/model/cnets.py
 """Inference-only LLaMA-EAGLE model compatible with HuggingFace weights."""

+import copy
 from typing import Iterable, Optional, Tuple

 import torch
@@ -161,6 +162,10 @@ class LlamaModel(nn.Module):
         if hidden_states.shape[-1] != embeds.shape[-1]:
             hidden_states = self.fc(hidden_states)

+        # idle batch
+        if hidden_states.shape[0] == 0:
+            return hidden_states, [hidden_states]
+
         residual = None
         hidden_states, residual = self.midlayer(
             positions,
@@ -212,7 +217,12 @@ class LlamaForCausalLMEagle3(LlamaForCausalLM):
                 prefix=add_prefix("lm_head", prefix),
             )

-        self.logits_processor = LogitsProcessor(config)
+        config_ = copy.deepcopy(config)
+        config_.vocab_size = (
+            config_.draft_vocab_size
+        )  # draft logits processor has it's own vocab size
+        self.logits_processor = LogitsProcessor(config_)

         self.capture_aux_hidden_states = True
         self.hot_token_id = None
```
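Two behaviors worth noting here: under DP attention a rank can receive an empty local batch, hence the early `shape[0] == 0` return, and the draft head scores only a reduced draft vocabulary, so the config is deep-copied with `vocab_size` overridden before building the `LogitsProcessor`; `hot_token_id` then maps draft-vocab indices back to target-vocab ids. A rough sketch of that remapping with made-up sizes (only `hot_token_id` and `draft_vocab_size` come from the diff):

```python
import torch

target_vocab_size = 128256  # assumed target vocab, for illustration
draft_vocab_size = 32000    # assumed config_.draft_vocab_size

# hot_token_id[i] = target-vocab id of the i-th draft-vocab token.
hot_token_id = torch.randint(0, target_vocab_size, (draft_vocab_size,))

draft_logits = torch.randn(1, draft_vocab_size)  # sized by the copied config
draft_top = draft_logits.argmax(dim=-1)   # index within the draft vocab
proposed = hot_token_id[draft_top]        # remapped to the target vocab
print(proposed.shape)  # torch.Size([1])
```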
python/sglang/srt/models/qwen2_moe.py

```diff
@@ -473,10 +473,16 @@ class Qwen2MoeDecoderLayer(nn.Module):
         hidden_states: torch.Tensor,
         forward_batch: ForwardBatch,
         residual: Optional[torch.Tensor],
+        captured_last_layer_outputs: Optional[List[torch.Tensor]] = None,
     ) -> Tuple[torch.Tensor, torch.Tensor]:
-        hidden_states, residual = self.layer_communicator.prepare_attn(
-            hidden_states, residual, forward_batch
+        hidden_states, residual = (
+            self.layer_communicator.prepare_attn_and_capture_last_layer_outputs(
+                hidden_states,
+                residual,
+                forward_batch,
+                captured_last_layer_outputs=captured_last_layer_outputs,
+            )
         )

         if hidden_states.shape[0] != 0:
@@ -553,6 +559,11 @@ class Qwen2MoeModel(nn.Module):
         # For EAGLE3 support
         self.layers_to_capture = []

+    def set_eagle3_layers_to_capture(self, layers_to_capture: List[int]):
+        self.layers_to_capture = layers_to_capture
+        for layer_id in self.layers_to_capture:
+            setattr(self.layers[layer_id], "_is_layer_to_capture", True)
+
     def forward(
         self,
         input_ids: torch.Tensor,
@@ -585,12 +596,6 @@
             )
         else:
             for i in range(self.start_layer, self.end_layer):
-                if i in self.layers_to_capture:
-                    aux_hidden_states.append(
-                        hidden_states + residual
-                        if residual is not None
-                        else hidden_states
-                    )
                 ctx = (
                     nullcontext()
                     if get_global_server_args().enable_piecewise_cuda_graph
@@ -599,7 +604,15 @@
                 with ctx:
                     layer = self.layers[i]
                     hidden_states, residual = layer(
-                        positions, hidden_states, forward_batch, residual
+                        positions,
+                        hidden_states,
+                        forward_batch,
+                        residual,
+                        captured_last_layer_outputs=(
+                            aux_hidden_states
+                            if getattr(layer, "_is_layer_to_capture", False)
+                            else None
+                        ),
                     )

         if not self.pp_group.is_last_rank:
             return PPProxyTensors(
@@ -830,13 +843,15 @@ class Qwen2MoeForCausalLM(nn.Module):
         self.capture_aux_hidden_states = True
         if layer_ids is None:
             num_layers = self.config.num_hidden_layers
-            self.model.layers_to_capture = [
-                2,
-                num_layers // 2,
-                num_layers - 3,
-            ]  # Specific layers for EAGLE3 support
+            self.model.set_eagle3_layers_to_capture(
+                [
+                    2,
+                    num_layers // 2,
+                    num_layers - 3,
+                ]
+            )  # Specific layers for EAGLE3 support
         else:
-            self.model.layers_to_capture = [val + 1 for val in layer_ids]
+            self.model.set_eagle3_layers_to_capture([val + 1 for val in layer_ids])


 EntryClass = Qwen2MoeForCausalLM
```
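The capture point moves from the outer layer loop, which summed `hidden_states + residual` before any DP gathering, into the flagged layer itself, where the communicator can gather across DP ranks first. A toy sketch of the flag-and-dispatch pattern (`ToyLayer`/`ToyModel` are invented for illustration; only the `_is_layer_to_capture` attribute and the `getattr` dispatch mirror the diff):

```python
from typing import List, Optional

import torch
import torch.nn as nn


class ToyLayer(nn.Module):
    def forward(self, x, captured: Optional[List[torch.Tensor]] = None):
        if captured is not None:
            captured.append(x.clone())  # capture this layer's input
        return x + 1


class ToyModel(nn.Module):
    def __init__(self, n: int = 6):
        super().__init__()
        self.layers = nn.ModuleList([ToyLayer() for _ in range(n)])

    def set_eagle3_layers_to_capture(self, layers_to_capture: List[int]):
        for layer_id in layers_to_capture:
            setattr(self.layers[layer_id], "_is_layer_to_capture", True)

    def forward(self, x):
        aux: List[torch.Tensor] = []
        for layer in self.layers:
            x = layer(
                x,
                captured=aux
                if getattr(layer, "_is_layer_to_capture", False)
                else None,
            )
        return x, aux


m = ToyModel()
m.set_eagle3_layers_to_capture([2, 3, 5])
_, aux = m(torch.zeros(1))
print(len(aux))  # 3 captured auxiliary states
```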
python/sglang/srt/models/qwen3_moe.py

```diff
@@ -537,10 +537,16 @@ class Qwen3MoeDecoderLayer(nn.Module):
         hidden_states: torch.Tensor,
         forward_batch: ForwardBatch,
         residual: Optional[torch.Tensor],
+        captured_last_layer_outputs: Optional[List[torch.Tensor]] = None,
     ) -> Tuple[torch.Tensor, torch.Tensor]:
-        hidden_states, residual = self.layer_communicator.prepare_attn(
-            hidden_states, residual, forward_batch
+        hidden_states, residual = (
+            self.layer_communicator.prepare_attn_and_capture_last_layer_outputs(
+                hidden_states,
+                residual,
+                forward_batch,
+                captured_last_layer_outputs=captured_last_layer_outputs,
+            )
         )

         if hidden_states.shape[0] != 0:
@@ -772,13 +778,15 @@ class Qwen3MoeForCausalLM(nn.Module):
         self.capture_aux_hidden_states = True
         if layer_ids is None:
             num_layers = self.config.num_hidden_layers
-            self.model.layers_to_capture = [
-                2,
-                num_layers // 2,
-                num_layers - 3,
-            ]  # Specific layers for EAGLE3 support
+            self.model.set_eagle3_layers_to_capture(
+                [
+                    2,
+                    num_layers // 2,
+                    num_layers - 3,
+                ]
+            )  # Specific layers for EAGLE3 support
         else:
-            self.model.layers_to_capture = [val + 1 for val in layer_ids]
+            self.model.set_eagle3_layers_to_capture([val + 1 for val in layer_ids])

     def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
         stacked_params_mapping = [
```
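For scale: assuming `num_hidden_layers = 48` (the published depth of Qwen/Qwen3-30B-A3B, the target model in the new test), the default capture set resolves as:

```python
num_layers = 48  # assumed Qwen3-30B-A3B depth
print([2, num_layers // 2, num_layers - 3])  # [2, 24, 45]
```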
python/sglang/srt/server_args.py

```diff
@@ -822,7 +822,7 @@ class ServerArgs:
             capture_bs = (
                 list(range(1, 9, 1))
                 + list(range(10, 33, 2))
-                + list(range(40, 64, 4))
+                + list(range(40, 65, 4))
                 + list(range(72, 257, 8))
                 + list(range(272, self.cuda_graph_max_bs + 1, 16))
             )
```
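The one-character change widens the third range so that batch size 64 itself gets a CUDA graph: `range(40, 64, 4)` stops at 60. That matters because the new test launches with `--cuda-graph-max-bs 64`. A quick check of the resulting schedule:

```python
cuda_graph_max_bs = 64  # as in the new test
capture_bs = (
    list(range(1, 9, 1))
    + list(range(10, 33, 2))
    + list(range(40, 65, 4))  # now ends at 64
    + list(range(72, 257, 8))
    + list(range(272, cuda_graph_max_bs + 1, 16))
)
print(64 in capture_bs)        # True
print(64 in range(40, 64, 4))  # False: the old bound skipped it
```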
python/sglang/srt/speculative/eagle_worker.py

```diff
@@ -5,6 +5,7 @@ from typing import List, Optional, Tuple
 import torch

 from sglang.srt.distributed import get_tp_group
+from sglang.srt.layers.dp_attention import get_attention_tp_group
 from sglang.srt.layers.logits_processor import LogitsProcessorOutput
 from sglang.srt.layers.sampler import get_token_ids_logprobs, get_top_logprobs
 from sglang.srt.managers.schedule_batch import ScheduleBatch
@@ -117,7 +118,11 @@ class EAGLEWorker(TpModelWorker):
         self.hot_token_id = None

         # Init draft worker
-        with empty_context():
+        if server_args.enable_dp_attention and self.speculative_algorithm.is_eagle3():
+            ctx = draft_tp_context(get_attention_tp_group())
+        else:
+            ctx = empty_context()
+        with ctx:
             super().__init__(
                 server_args=server_args,
                 gpu_id=gpu_id,
```
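When DP attention is on and the algorithm is EAGLE3, the draft worker now initializes inside the attention-TP group's context rather than a no-op context, so draft modules are built against the attention-TP topology. A generic sketch of the selection pattern (the context bodies are stand-ins; `draft_tp_context` and `get_attention_tp_group` are the real sglang names):

```python
from contextlib import contextmanager


@contextmanager
def empty_context():
    yield


@contextmanager
def draft_tp_context(tp_group):
    # Stand-in: the real helper temporarily switches the global TP group
    # so modules constructed inside use the attention-TP layout.
    print(f"building draft model in TP group: {tp_group}")
    yield


enable_dp_attention, is_eagle3 = True, True
ctx = (
    draft_tp_context("attention-tp-group")
    if enable_dp_attention and is_eagle3
    else empty_context()
)
with ctx:
    pass  # super().__init__(...) happens here in the real worker
```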
python/sglang/test/test_utils.py

```diff
@@ -84,6 +84,8 @@ DEFAULT_MODEL_NAME_FOR_TEST_AWQ_INT4 = (
 DEFAULT_EAGLE_TARGET_MODEL_FOR_TEST = "meta-llama/Llama-2-7b-chat-hf"
 DEFAULT_EAGLE_DRAFT_MODEL_FOR_TEST = "lmsys/sglang-EAGLE-llama2-chat-7B"
 DEFAULT_EAGLE_TARGET_MODEL_FOR_TEST_EAGLE3 = "meta-llama/Llama-3.1-8B-Instruct"
+DEFAULT_EAGLE_DP_ATTENTION_TARGET_MODEL_FOR_TEST = "Qwen/Qwen3-30B-A3B"
+DEFAULT_EAGLE_DP_ATTENTION_DRAFT_MODEL_FOR_TEST = "Tengyunw/qwen3_30b_moe_eagle3"
 DEFAULT_MODEL_NAME_FOR_TEST_EAGLE3 = "lmsys/sglang-EAGLE3-LLaMA3.1-Instruct-8B"
 DEFAULT_STANDALONE_SPECULATIVE_TARGET_MODEL_FOR_TEST = (
     "meta-llama/Llama-3.1-8B-Instruct"
```
test/srt/run_suite.py

```diff
@@ -158,6 +158,7 @@ suites = {
         TestFile("test_load_weights_from_remote_instance.py", 72),
         TestFile("test_patch_torch.py", 19),
         TestFile("test_release_memory_occupation.py", 257),
+        TestFile("test_eagle_dp_attention.py", 200),
     ],
     "per-commit-4-gpu": [
         TestFile("models/test_qwen3_next_models.py", 291),
```
test/srt/test_eagle_dp_attention.py (new file, mode 100644)

```python
import unittest
from types import SimpleNamespace

import requests

from sglang.test.few_shot_gsm8k import run_eval as run_eval_few_shot_gsm8k
from sglang.test.send_one import BenchArgs, send_one_prompt
from sglang.test.test_utils import (
    DEFAULT_EAGLE_DP_ATTENTION_DRAFT_MODEL_FOR_TEST,
    DEFAULT_EAGLE_DP_ATTENTION_TARGET_MODEL_FOR_TEST,
    DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH,
    DEFAULT_URL_FOR_TEST,
    CustomTestCase,
    is_in_amd_ci,
    is_in_ci,
    kill_process_tree,
    popen_launch_server,
    write_github_step_summary,
)


class TestEAGLE3EngineDPAttention(CustomTestCase):
    @classmethod
    def setUpClass(cls):
        cls.model = DEFAULT_EAGLE_DP_ATTENTION_TARGET_MODEL_FOR_TEST
        cls.base_url = DEFAULT_URL_FOR_TEST
        other_args = [
            "--trust-remote-code",
            "--speculative-algorithm",
            "EAGLE3",
            "--speculative-num-steps",
            "6",
            "--speculative-eagle-topk",
            "10",
            "--speculative-num-draft-tokens",
            "32",
            "--speculative-draft-model-path",
            DEFAULT_EAGLE_DP_ATTENTION_DRAFT_MODEL_FOR_TEST,
            "--tp-size",
            "2",
            "--dp-size",
            "2",
            "--enable-dp-attention",
            "--enable-dp-lm-head",
            "--moe-dense-tp-size",
            "1",
            "--attention-backend",
            "fa3",
            "--mem-fraction-static",
            "0.75",
            "--cuda-graph-max-bs",
            "64",
        ]
        cls.process = popen_launch_server(
            cls.model,
            cls.base_url,
            timeout=DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH,
            other_args=other_args,
        )

    @classmethod
    def tearDownClass(cls):
        kill_process_tree(cls.process.pid)

    def test_a_gsm8k(self):
        """Test GSM8K evaluation - append 'a' to run first alphabetically"""
        requests.get(self.base_url + "/flush_cache")

        args = SimpleNamespace(
            num_shots=5,
            data_path=None,
            num_questions=200,
            max_new_tokens=512,
            parallel=128,
            host="http://127.0.0.1",
            port=int(self.base_url.split(":")[-1]),
        )
        metrics = run_eval_few_shot_gsm8k(args)
        print(f"{metrics=}")

        server_info = requests.get(self.base_url + "/get_server_info")
        server_data = server_info.json()

        # Try to get avg_spec_accept_length
        avg_spec_accept_length = None
        if "internal_states" in server_data and len(server_data["internal_states"]) > 0:
            internal_state = server_data["internal_states"][0]
            if "avg_spec_accept_length" in internal_state:
                avg_spec_accept_length = internal_state["avg_spec_accept_length"]
            elif "spec_accept_length" in internal_state:
                avg_spec_accept_length = internal_state["spec_accept_length"]

        print(f"{avg_spec_accept_length=}")

        if is_in_ci():
            write_github_step_summary(
                f"### test_gsm8k (EAGLE3 DP Attention)\n"
                f'{metrics["accuracy"]=:.3f}\n'
                f"{avg_spec_accept_length=:.2f}\n"
            )

        self.assertGreater(metrics["accuracy"], 0.91)
        if avg_spec_accept_length is not None:
            self.assertGreater(avg_spec_accept_length, 2.5)

    def test_bs_1_speed(self):
        """Test batch size 1 speed with EAGLE3 DP Attention"""
        args = BenchArgs(port=int(self.base_url.split(":")[-1]), max_new_tokens=2048)
        acc_length, speed = send_one_prompt(args)

        print(f"{acc_length=:.2f} {speed=:.2f}")

        if is_in_ci():
            write_github_step_summary(
                f"### test_bs_1_speed (EAGLE3 DP Attention)\n"
                f"{acc_length=:.2f}\n"
                f"{speed=:.2f} token/s\n"
            )

        if is_in_amd_ci():
            self.assertGreater(acc_length, 2.0)
        else:
            self.assertGreater(acc_length, 2.3)
        if is_in_amd_ci():
            self.assertGreater(speed, 10)
        else:
            self.assertGreater(speed, 40)


if __name__ == "__main__":
    unittest.main()
```
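For reference, the same configuration can be brought up outside the test harness; a sketch via subprocess, mirroring `other_args` above (`sglang.launch_server` is the standard entry point; the port is arbitrary):

```python
import subprocess

# Sketch: launch the Qwen3-30B-A3B target with its EAGLE3 draft under
# DP attention, using the same flags as the test's other_args.
cmd = [
    "python", "-m", "sglang.launch_server",
    "--model-path", "Qwen/Qwen3-30B-A3B",
    "--speculative-algorithm", "EAGLE3",
    "--speculative-draft-model-path", "Tengyunw/qwen3_30b_moe_eagle3",
    "--speculative-num-steps", "6",
    "--speculative-eagle-topk", "10",
    "--speculative-num-draft-tokens", "32",
    "--tp-size", "2", "--dp-size", "2",
    "--enable-dp-attention", "--enable-dp-lm-head",
    "--moe-dense-tp-size", "1",
    "--attention-backend", "fa3",
    "--mem-fraction-static", "0.75",
    "--cuda-graph-max-bs", "64",
    "--trust-remote-code",
    "--port", "30000",  # arbitrary
]
subprocess.run(cmd, check=True)
```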