Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
change
sglang
Commits
df397a72
Unverified
Commit
df397a72
authored
Sep 03, 2025
by
Ximingwang-09
Committed by
GitHub
Sep 02, 2025
Browse files
[feat] Add P/D attention select for draft model (#9755)
Co-authored-by:
纬杭
<
ximing.wxm@antgroup.com
>
parent
5dfcd6c2
Changes
1
Show whitespace changes
Inline
Side-by-side
Showing
1 changed file
with
158 additions
and
112 deletions
+158
-112
python/sglang/srt/speculative/eagle_worker.py
python/sglang/srt/speculative/eagle_worker.py
+158
-112
No files found.
python/sglang/srt/speculative/eagle_worker.py
View file @
df397a72
...
@@ -187,137 +187,183 @@ class EAGLEWorker(TpModelWorker):
...
@@ -187,137 +187,183 @@ class EAGLEWorker(TpModelWorker):
self
.
has_prefill_wrapper_verify
=
False
self
.
has_prefill_wrapper_verify
=
False
self
.
draft_extend_attn_backend
=
None
self
.
draft_extend_attn_backend
=
None
if
self
.
server_args
.
attention_backend
==
"flashinfer"
:
# Initialize decode attention backend
self
.
draft_attn_backend
=
self
.
_create_decode_backend
()
# Initialize prefill attention backend
self
.
draft_extend_attn_backend
=
self
.
_create_draft_extend_backend
()
self
.
draft_model_runner
.
draft_attn_backend
=
self
.
draft_attn_backend
def
_create_backend
(
self
,
backend_name
:
str
,
backend_map
:
dict
,
error_template
:
str
):
backend_type
=
getattr
(
self
.
server_args
,
backend_name
)
if
backend_type
is
None
:
backend_type
=
self
.
server_args
.
attention_backend
if
backend_type
not
in
backend_map
:
raise
ValueError
(
error_template
.
format
(
backend_type
=
backend_type
))
return
backend_map
[
backend_type
]()
def _create_decode_backend(self):
    """Build the multi-step draft *decode* attention backend.

    Dispatches on ``server_args.decode_attention_backend`` (falling back to
    ``server_args.attention_backend``) via the shared ``_create_backend``
    helper; raises ValueError for unsupported backends.
    """
    backend_map = {
        "flashinfer": self._create_flashinfer_decode_backend,
        "triton": self._create_triton_decode_backend,
        "aiter": self._create_aiter_decode_backend,
        "fa3": self._create_fa3_decode_backend,
        "flashmla": self._create_flashmla_decode_backend,
        "trtllm_mha": self._create_trtllm_mha_decode_backend,
        "trtllm_mla": self._create_trtllm_mla_decode_backend,
    }
    return self._create_backend(
        "decode_attention_backend",
        backend_map,
        "EAGLE is not supported in decode attention backend {backend_type}",
    )
def _create_draft_extend_backend(self):
    """Build the draft-extend (*prefill*) attention backend.

    Dispatches on ``server_args.prefill_attention_backend`` (falling back to
    ``server_args.attention_backend``) via the shared ``_create_backend``
    helper; raises ValueError for unsupported backends.  Note ``flashmla``
    has no prefill variant here, unlike the decode map.
    """
    backend_map = {
        "flashinfer": self._create_flashinfer_prefill_backend,
        "triton": self._create_triton_prefill_backend,
        "aiter": self._create_aiter_prefill_backend,
        "fa3": self._create_fa3_prefill_backend,
        "trtllm_mha": self._create_trtllm_mha_prefill_backend,
        "trtllm_mla": self._create_trtllm_mla_prefill_backend,
    }
    return self._create_backend(
        "prefill_attention_backend",
        backend_map,
        "EAGLE is not supported in prefill attention backend {backend_type}",
    )
def _create_flashinfer_decode_backend(self):
    """Build the flashinfer multi-step draft-decode backend (MHA or MLA).

    Chooses the MLA variant when ``global_server_args_dict["use_mla_backend"]``
    is set.  Side effect: marks ``has_prefill_wrapper_verify`` True since
    flashinfer provides a prefill wrapper for the verify pass.
    """
    if not global_server_args_dict["use_mla_backend"]:
        from sglang.srt.layers.attention.flashinfer_backend import (
            FlashInferMultiStepDraftBackend,
        )

        self.has_prefill_wrapper_verify = True
        return FlashInferMultiStepDraftBackend(
            self.draft_model_runner, self.topk, self.speculative_num_steps
        )
    else:
        from sglang.srt.layers.attention.flashinfer_mla_backend import (
            FlashInferMLAMultiStepDraftBackend,
        )

        self.has_prefill_wrapper_verify = True
        return FlashInferMLAMultiStepDraftBackend(
            self.draft_model_runner, self.topk, self.speculative_num_steps
        )
def _create_triton_decode_backend(self):
    """Build the triton multi-step draft-decode backend."""
    from sglang.srt.layers.attention.triton_backend import (
        TritonMultiStepDraftBackend,
    )

    return TritonMultiStepDraftBackend(
        self.draft_model_runner, self.topk, self.speculative_num_steps
    )
def _create_aiter_decode_backend(self):
    """Build the aiter (ROCm) multi-step draft-decode backend."""
    from sglang.srt.layers.attention.aiter_backend import AiterMultiStepDraftBackend

    return AiterMultiStepDraftBackend(
        self.draft_model_runner, self.topk, self.speculative_num_steps
    )
def _create_fa3_decode_backend(self):
    """Build the FlashAttention-3 multi-step draft-decode backend."""
    from sglang.srt.layers.attention.flashattention_backend import (
        FlashAttentionMultiStepBackend,
    )

    return FlashAttentionMultiStepBackend(
        self.draft_model_runner, self.topk, self.speculative_num_steps
    )
def _create_flashmla_decode_backend(self):
    """Build the FlashMLA multi-step draft-decode backend."""
    from sglang.srt.layers.attention.flashmla_backend import (
        FlashMLAMultiStepDraftBackend,
    )

    return FlashMLAMultiStepDraftBackend(
        self.draft_model_runner, self.topk, self.speculative_num_steps
    )
def _create_trtllm_mha_decode_backend(self):
    """Build the TensorRT-LLM MHA multi-step draft-decode backend.

    Side effect: marks ``has_prefill_wrapper_verify`` True, mirroring the
    flashinfer path.
    """
    from sglang.srt.layers.attention.trtllm_mha_backend import (
        TRTLLMHAAttnMultiStepDraftBackend,
    )

    self.has_prefill_wrapper_verify = True
    return TRTLLMHAAttnMultiStepDraftBackend(
        self.draft_model_runner, self.topk, self.speculative_num_steps
    )
def _create_trtllm_mla_decode_backend(self):
    """Build the TensorRT-LLM MLA multi-step draft-decode backend.

    Raises:
        ValueError: If the target model does not use the MLA backend
            (``global_server_args_dict["use_mla_backend"]`` is False).
    """
    if not global_server_args_dict["use_mla_backend"]:
        raise ValueError(
            "trtllm_mla backend requires MLA model (use_mla_backend=True)."
        )

    from sglang.srt.layers.attention.trtllm_mla_backend import (
        TRTLLMMLAMultiStepDraftBackend,
    )

    self.has_prefill_wrapper_verify = True
    return TRTLLMMLAMultiStepDraftBackend(
        self.draft_model_runner, self.topk, self.speculative_num_steps
    )
def _create_flashinfer_prefill_backend(self):
    """Build the flashinfer draft-extend (prefill) backend (MHA or MLA)."""
    if not global_server_args_dict["use_mla_backend"]:
        from sglang.srt.layers.attention.flashinfer_backend import (
            FlashInferAttnBackend,
        )

        return FlashInferAttnBackend(self.draft_model_runner, skip_prefill=False)
    else:
        from sglang.srt.layers.attention.flashinfer_mla_backend import (
            FlashInferMLAAttnBackend,
        )

        return FlashInferMLAAttnBackend(self.draft_model_runner, skip_prefill=False)
def _create_triton_prefill_backend(self):
    """Build the triton draft-extend (prefill) backend."""
    from sglang.srt.layers.attention.triton_backend import TritonAttnBackend

    return TritonAttnBackend(self.draft_model_runner, skip_prefill=False)
def _create_aiter_prefill_backend(self):
    """Build the aiter (ROCm) draft-extend (prefill) backend."""
    from sglang.srt.layers.attention.aiter_backend import AiterAttnBackend

    return AiterAttnBackend(self.draft_model_runner, skip_prefill=False)
def _create_fa3_prefill_backend(self):
    """Build the FlashAttention-3 draft-extend (prefill) backend."""
    from sglang.srt.layers.attention.flashattention_backend import (
        FlashAttentionBackend,
    )

    return FlashAttentionBackend(self.draft_model_runner, skip_prefill=False)
def _create_trtllm_mha_prefill_backend(self):
    """Build the TensorRT-LLM MHA draft-extend (prefill) backend."""
    from sglang.srt.layers.attention.trtllm_mha_backend import TRTLLMHAAttnBackend

    return TRTLLMHAAttnBackend(self.draft_model_runner, skip_prefill=False)
def _create_trtllm_mla_prefill_backend(self):
    """Build the TensorRT-LLM MLA draft-extend (prefill) backend.

    Raises:
        ValueError: If the target model does not use the MLA backend
            (``global_server_args_dict["use_mla_backend"]`` is False).
    """
    if not global_server_args_dict["use_mla_backend"]:
        raise ValueError(
            "trtllm_mla backend requires MLA model (use_mla_backend=True)."
        )

    from sglang.srt.layers.attention.trtllm_mla_backend import TRTLLMMLABackend

    return TRTLLMMLABackend(self.draft_model_runner, skip_prefill=False)
def
init_cuda_graphs
(
self
):
def
init_cuda_graphs
(
self
):
"""Capture cuda graphs."""
"""Capture cuda graphs."""
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment