change/sglang · Commits · Commit 5ea96ac7 (unverified)
Authored Oct 14, 2025 by Liangsheng Yin; committed by GitHub on Oct 14, 2025.
Reduce one step decode for draft model. (#11561)
Parent: 56222658

Showing 10 changed files with 253 additions and 211 deletions (+253, -211).
python/sglang/srt/layers/attention/aiter_backend.py            +3    -3
python/sglang/srt/layers/attention/flashattention_backend.py   +2    -2
python/sglang/srt/layers/attention/flashinfer_backend.py       +2    -2
python/sglang/srt/layers/attention/flashinfer_mla_backend.py   +2    -2
python/sglang/srt/layers/attention/flashmla_backend.py         +2    -2
python/sglang/srt/layers/attention/triton_backend.py           +4    -3
python/sglang/srt/layers/attention/trtllm_mha_backend.py       +2    -2
python/sglang/srt/layers/attention/trtllm_mla_backend.py       +1    -1
python/sglang/srt/speculative/draft_utils.py                   +222  -0
python/sglang/srt/speculative/eagle_worker.py                  +13   -194
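Every attention-backend change below is the same one-line edit: the multi-step draft backends now build and initialize speculative_num_steps - 1 per-step attention backends instead of speculative_num_steps. A minimal sketch of the reasoning as I read the commit title (the function and naming below are illustrative, not repo code): the first level of draft tokens falls out of the preceding extend/verify forward, so producing all the levels takes one fewer draft decode pass than the number of speculative steps.

# Illustrative only: why num_steps draft levels need num_steps - 1 decode passes.
# "Level" here means one layer of the EAGLE draft tree; the names are mine.

def draft_levels(num_steps: int) -> list[str]:
    levels = ["level 0: sampled from the extend/verify forward"]
    for i in range(num_steps - 1):  # the loop bound this commit changes
        levels.append(f"level {i + 1}: sampled from draft decode pass {i}")
    return levels

for line in draft_levels(3):
    print(line)
# level 0: sampled from the extend/verify forward
# level 1: sampled from draft decode pass 0
# level 2: sampled from draft decode pass 1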
python/sglang/srt/layers/attention/aiter_backend.py

@@ -1064,7 +1064,7 @@ class AiterMultiStepDraftBackend:
             device=model_runner.device,
         )
         self.attn_backends = []
-        for i in range(self.speculative_num_steps):
+        for i in range(self.speculative_num_steps - 1):
             self.attn_backends.append(
                 AiterAttnBackend(
                     model_runner,
@@ -1107,7 +1107,7 @@ class AiterMultiStepDraftBackend:
                 self.page_size,
             )
-        for i in range(self.speculative_num_steps):
+        for i in range(self.speculative_num_steps - 1):
             forward_batch.spec_info.kv_indptr = self.kv_indptr[i, : bs + 1]
             forward_batch.spec_info.kv_indices = kv_indices_buffer[i][
                 : seq_lens_sum * self.topk + bs * (i + 1)
@@ -1141,7 +1141,7 @@ class AiterMultiStepDraftBackend:
             dtype=torch.int32,
             device=self.device,
         )
-        for i in range(self.speculative_num_steps):
+        for i in range(self.speculative_num_steps - 1):
             self.attn_backends[i].init_cuda_graph_state(
                 max_bs, max_num_tokens, kv_indices_buf=self.cuda_graph_kv_indices[i]
             )
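A side note on the per-step KV-index slices in the hunk above: the slice bound grows with the step index. A worked sketch of the arithmetic follows; the formula is copied from the diff, while the interpretation in the comment is my assumption.

# Slice bound from the diff:
#   kv_indices_buffer[i][: seq_lens_sum * self.topk + bs * (i + 1)]
# Reading (assumption): seq_lens_sum * topk covers the cached context for all
# topk draft branches across the batch, and bs * (i + 1) adds the draft
# positions appended so far, growing by one batch-width per completed step.

def kv_indices_needed(seq_lens_sum: int, topk: int, bs: int, step: int) -> int:
    return seq_lens_sum * topk + bs * (step + 1)

print(kv_indices_needed(1024, 4, 2, 0))  # 4098 entries for step 0
print(kv_indices_needed(1024, 4, 2, 2))  # 4102 entries for step 2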
python/sglang/srt/layers/attention/flashattention_backend.py

@@ -2320,7 +2320,7 @@ class FlashAttentionMultiStepBackend:
         self.topk = topk
         self.speculative_num_steps = speculative_num_steps
         self.attn_backends = []
-        for i in range(self.speculative_num_steps):
+        for i in range(self.speculative_num_steps - 1):
             self.attn_backends.append(
                 FlashAttentionBackend(
                     model_runner,
@@ -2335,7 +2335,7 @@ class FlashAttentionMultiStepBackend:
             self.attn_backends[i].init_forward_metadata(forward_batch)

     def init_cuda_graph_state(self, max_bs: int, max_num_tokens: int):
-        for i in range(self.speculative_num_steps):
+        for i in range(self.speculative_num_steps - 1):
             self.attn_backends[i].init_cuda_graph_state(max_bs, max_num_tokens)

     def init_forward_metadata_capture_cuda_graph(
python/sglang/srt/layers/attention/flashinfer_backend.py

@@ -1405,7 +1405,7 @@ class FlashInferMultiStepDraftBackend:
             (max_bs,), dtype=torch.int32, device=model_runner.device
         )
         self.attn_backends: List[FlashInferAttnBackend] = []
-        for i in range(self.speculative_num_steps):
+        for i in range(self.speculative_num_steps - 1):
             self.attn_backends.append(
                 FlashInferAttnBackend(
                     model_runner,
@@ -1493,7 +1493,7 @@ class FlashInferMultiStepDraftBackend:
             device="cuda",
         )
-        for i in range(self.speculative_num_steps):
+        for i in range(self.speculative_num_steps - 1):
             self.attn_backends[i].init_cuda_graph_state(
                 max_bs, max_num_tokens, kv_indices_buf=self.cuda_graph_kv_indices[i]
             )
python/sglang/srt/layers/attention/flashinfer_mla_backend.py

@@ -916,7 +916,7 @@ class FlashInferMLAMultiStepDraftBackend:
         )
         self.attn_backends = []
-        for i in range(self.speculative_num_steps):
+        for i in range(self.speculative_num_steps - 1):
             self.attn_backends.append(
                 FlashInferMLAAttnBackend(
                     model_runner,
@@ -998,7 +998,7 @@ class FlashInferMLAMultiStepDraftBackend:
             device="cuda",
         )
-        for i in range(self.speculative_num_steps):
+        for i in range(self.speculative_num_steps - 1):
             self.attn_backends[i].init_cuda_graph_state(
                 max_bs, max_num_tokens, kv_indices_buf=self.cuda_graph_kv_indices[i]
             )
python/sglang/srt/layers/attention/flashmla_backend.py

@@ -478,7 +478,7 @@ class FlashMLAMultiStepDraftBackend:
         )
         self.attn_backends = []
-        for i in range(self.speculative_num_steps):
+        for i in range(self.speculative_num_steps - 1):
             self.attn_backends.append(
                 FlashMLABackend(
                     model_runner,
@@ -506,7 +506,7 @@ class FlashMLAMultiStepDraftBackend:
         self.common_template(forward_batch, call_fn)

     def init_cuda_graph_state(self, max_bs: int, max_num_tokens: int):
-        for i in range(self.speculative_num_steps):
+        for i in range(self.speculative_num_steps - 1):
             self.attn_backends[i].init_cuda_graph_state(
                 max_bs, max_num_tokens, block_kv_indices=None
             )
python/sglang/srt/layers/attention/triton_backend.py

@@ -918,7 +918,7 @@ class TritonMultiStepDraftBackend:
             device=model_runner.device,
         )
         self.attn_backends: List[TritonAttnBackend] = []
-        for i in range(self.speculative_num_steps):
+        for i in range(self.speculative_num_steps - 1):
             self.attn_backends.append(
                 TritonAttnBackend(
                     model_runner,
@@ -969,7 +969,7 @@ class TritonMultiStepDraftBackend:
         if call_fn is None:
             return
-        for i in range(self.speculative_num_steps):
+        for i in range(self.speculative_num_steps - 1):
             forward_batch.spec_info.kv_indptr = self.kv_indptr[i, : bs + 1]
             forward_batch.spec_info.kv_indices = kv_indices_buffer[i][
                 : seq_lens_sum * self.topk + bs * (i + 1)
@@ -1009,7 +1009,8 @@ class TritonMultiStepDraftBackend:
             dtype=torch.int32,
             device=self.device,
         )
-        for i in range(self.speculative_num_steps):
+        for i in range(self.speculative_num_steps - 1):
             self.attn_backends[i].init_cuda_graph_state(
                 max_bs,
                 max_num_tokens,
python/sglang/srt/layers/attention/trtllm_mha_backend.py

@@ -637,7 +637,7 @@ class TRTLLMHAAttnMultiStepDraftBackend(FlashInferMultiStepDraftBackend):
         self, model_runner: ModelRunner, topk: int, speculative_num_steps: int
     ):
         super().__init__(model_runner, topk, speculative_num_steps)
-        for i in range(speculative_num_steps):
+        for i in range(self.speculative_num_steps - 1):
             self.attn_backends[i] = TRTLLMHAAttnBackend(
                 model_runner,
                 skip_prefill=True,
@@ -651,7 +651,7 @@ class TRTLLMHAAttnMultiStepDraftBackend(FlashInferMultiStepDraftBackend):
             self.attn_backends[i].init_forward_metadata(forward_batch)

     def init_cuda_graph_state(self, max_bs: int, max_num_tokens: int):
-        for i in range(self.speculative_num_steps):
+        for i in range(self.speculative_num_steps - 1):
             self.attn_backends[i].init_cuda_graph_state(max_bs, max_num_tokens)

     def init_forward_metadata_capture_cuda_graph(
python/sglang/srt/layers/attention/trtllm_mla_backend.py

@@ -735,7 +735,7 @@ class TRTLLMMLAMultiStepDraftBackend(FlashInferMLAMultiStepDraftBackend):
     ):
         super().__init__(model_runner, topk, speculative_num_steps)
-        for i in range(self.speculative_num_steps):
+        for i in range(self.speculative_num_steps - 1):
             self.attn_backends[i] = TRTLLMMLABackend(
                 model_runner,
                 skip_prefill=True,
python/sglang/srt/speculative/draft_utils.py (new file, mode 100644)

import logging

from sglang.srt.server_args import ServerArgs, get_global_server_args
from sglang.srt.utils.common import is_blackwell

logger = logging.getLogger(__name__)


class DraftBackendFactory:
    def __init__(
        self,
        server_args: ServerArgs,
        draft_model_runner,
        topk: int,
        speculative_num_steps: int,
    ):
        self.server_args = server_args
        self.draft_model_runner = draft_model_runner
        self.topk = topk
        self.speculative_num_steps = speculative_num_steps

    def _create_backend(
        self, backend_name: str, backend_map: dict, error_template: str
    ):
        backend_type = getattr(self.server_args, backend_name)
        if backend_type is None:
            backend_type = self.server_args.attention_backend

        if backend_type not in backend_map:
            raise ValueError(error_template.format(backend_type=backend_type))

        return backend_map[backend_type]()

    def create_decode_backend(self):
        if self.speculative_num_steps == 1:

            class DummyAttnBackend:
                def __init__(self):
                    pass

                def init_forward_metadata(*args, **kwargs):
                    pass

            return DummyAttnBackend()

        backend_map = {
            "flashinfer": self._create_flashinfer_decode_backend,
            "triton": self._create_triton_decode_backend,
            "aiter": self._create_aiter_decode_backend,
            "fa3": self._create_fa3_decode_backend,
            "hybrid_linear_attn": (
                self._create_fa3_decode_backend
                if not is_blackwell()
                else self._create_triton_decode_backend
            ),
            "flashmla": self._create_flashmla_decode_backend,
            "trtllm_mha": self._create_trtllm_mha_decode_backend,
            "trtllm_mla": self._create_trtllm_mla_decode_backend,
        }

        return self._create_backend(
            "decode_attention_backend",
            backend_map,
            "EAGLE is not supported in decode attention backend {backend_type}",
        )

    def create_draft_extend_backend(self):
        backend_map = {
            "flashinfer": self._create_flashinfer_prefill_backend,
            "triton": self._create_triton_prefill_backend,
            "aiter": self._create_aiter_prefill_backend,
            "fa3": self._create_fa3_prefill_backend,
            "hybrid_linear_attn": (
                self._create_fa3_prefill_backend
                if not is_blackwell()
                else self._create_triton_prefill_backend
            ),
            "flashmla": self._create_flashmla_prefill_backend,
            "trtllm_mha": self._create_trtllm_mha_prefill_backend,
            "trtllm_mla": self._create_trtllm_mla_prefill_backend,
        }
        backend_name = (
            "decode_attention_backend"
            if self.server_args.speculative_attention_mode == "decode"
            else "prefill_attention_backend"
        )
        return self._create_backend(
            backend_name,
            backend_map,
            "EAGLE is not supported in attention backend {backend_type}",
        )

    def _create_flashinfer_decode_backend(self):
        if not get_global_server_args().use_mla_backend:
            from sglang.srt.layers.attention.flashinfer_backend import (
                FlashInferMultiStepDraftBackend,
            )

            self.has_prefill_wrapper_verify = True
            return FlashInferMultiStepDraftBackend(
                self.draft_model_runner, self.topk, self.speculative_num_steps
            )
        else:
            from sglang.srt.layers.attention.flashinfer_mla_backend import (
                FlashInferMLAMultiStepDraftBackend,
            )

            self.has_prefill_wrapper_verify = True
            return FlashInferMLAMultiStepDraftBackend(
                self.draft_model_runner, self.topk, self.speculative_num_steps
            )

    def _create_triton_decode_backend(self):
        from sglang.srt.layers.attention.triton_backend import (
            TritonMultiStepDraftBackend,
        )

        return TritonMultiStepDraftBackend(
            self.draft_model_runner, self.topk, self.speculative_num_steps
        )

    def _create_aiter_decode_backend(self):
        from sglang.srt.layers.attention.aiter_backend import AiterMultiStepDraftBackend

        return AiterMultiStepDraftBackend(
            self.draft_model_runner, self.topk, self.speculative_num_steps
        )

    def _create_fa3_decode_backend(self):
        from sglang.srt.layers.attention.flashattention_backend import (
            FlashAttentionMultiStepBackend,
        )

        return FlashAttentionMultiStepBackend(
            self.draft_model_runner, self.topk, self.speculative_num_steps
        )

    def _create_flashmla_decode_backend(self):
        from sglang.srt.layers.attention.flashmla_backend import (
            FlashMLAMultiStepDraftBackend,
        )

        return FlashMLAMultiStepDraftBackend(
            self.draft_model_runner, self.topk, self.speculative_num_steps
        )

    def _create_trtllm_mha_decode_backend(self):
        from sglang.srt.layers.attention.trtllm_mha_backend import (
            TRTLLMHAAttnMultiStepDraftBackend,
        )

        self.has_prefill_wrapper_verify = True
        return TRTLLMHAAttnMultiStepDraftBackend(
            self.draft_model_runner, self.topk, self.speculative_num_steps
        )

    def _create_trtllm_mla_decode_backend(self):
        if not get_global_server_args().use_mla_backend:
            raise ValueError(
                "trtllm_mla backend requires MLA model (use_mla_backend=True)."
            )

        from sglang.srt.layers.attention.trtllm_mla_backend import (
            TRTLLMMLAMultiStepDraftBackend,
        )

        self.has_prefill_wrapper_verify = True
        return TRTLLMMLAMultiStepDraftBackend(
            self.draft_model_runner, self.topk, self.speculative_num_steps
        )

    def _create_flashinfer_prefill_backend(self):
        if not get_global_server_args().use_mla_backend:
            from sglang.srt.layers.attention.flashinfer_backend import (
                FlashInferAttnBackend,
            )

            return FlashInferAttnBackend(self.draft_model_runner, skip_prefill=False)
        else:
            from sglang.srt.layers.attention.flashinfer_mla_backend import (
                FlashInferMLAAttnBackend,
            )

            return FlashInferMLAAttnBackend(self.draft_model_runner, skip_prefill=False)

    def _create_triton_prefill_backend(self):
        from sglang.srt.layers.attention.triton_backend import TritonAttnBackend

        return TritonAttnBackend(self.draft_model_runner, skip_prefill=False)

    def _create_aiter_prefill_backend(self):
        from sglang.srt.layers.attention.aiter_backend import AiterAttnBackend

        return AiterAttnBackend(self.draft_model_runner, skip_prefill=False)

    def _create_fa3_prefill_backend(self):
        from sglang.srt.layers.attention.flashattention_backend import (
            FlashAttentionBackend,
        )

        return FlashAttentionBackend(self.draft_model_runner, skip_prefill=False)

    def _create_trtllm_mha_prefill_backend(self):
        from sglang.srt.layers.attention.trtllm_mha_backend import TRTLLMHAAttnBackend

        return TRTLLMHAAttnBackend(self.draft_model_runner, skip_prefill=False)

    def _create_trtllm_mla_prefill_backend(self):
        if not get_global_server_args().use_mla_backend:
            raise ValueError(
                "trtllm_mla backend requires MLA model (use_mla_backend=True)."
            )

        from sglang.srt.layers.attention.trtllm_mla_backend import TRTLLMMLABackend

        return TRTLLMMLABackend(self.draft_model_runner, skip_prefill=False)

    def _create_flashmla_prefill_backend(self):
        logger.warning(
            "flashmla prefill backend is not yet supported for draft extend."
        )
        return None
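For context, a sketch of how this factory is driven; it mirrors the eagle_worker.py hunk below. Here server_args and draft_model_runner are stand-ins for objects the worker already holds, so this is wiring pseudocode rather than a self-contained program. Note also that the moved _create_* helpers still set has_prefill_wrapper_verify on self, which is now the factory instance rather than the worker.

# Wiring sketch, following the pattern in eagle_worker.py below.
factory = DraftBackendFactory(
    server_args,            # ServerArgs carrying the attention-backend choices
    draft_model_runner,     # the draft model's ModelRunner
    topk,                   # draft tree branching factor
    speculative_num_steps,  # number of speculative steps
)

# Multi-step decode backend for draft-token generation; when
# speculative_num_steps == 1 this is a no-op dummy, since zero
# per-step backends remain after the num_steps - 1 change.
draft_attn_backend = factory.create_decode_backend()

# Backend for the draft model's extend pass; whether the decode or the
# prefill server arg selects it depends on speculative_attention_mode.
draft_extend_attn_backend = factory.create_draft_extend_backend()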
python/sglang/srt/speculative/eagle_worker.py

@@ -27,7 +27,8 @@ from sglang.srt.model_executor.forward_batch_info import (
     ForwardBatch,
     ForwardMode,
 )
-from sglang.srt.server_args import ServerArgs, get_global_server_args
+from sglang.srt.server_args import ServerArgs
+from sglang.srt.speculative.draft_utils import DraftBackendFactory
 from sglang.srt.speculative.eagle_draft_cuda_graph_runner import (
     EAGLEDraftCudaGraphRunner,
 )
@@ -195,204 +196,22 @@ class EAGLEWorker(TpModelWorker):
         self.has_prefill_wrapper_verify = False
         self.draft_extend_attn_backend = None

+        draft_backend_factory = DraftBackendFactory(
+            self.server_args,
+            self.draft_model_runner,
+            self.topk,
+            self.speculative_num_steps,
+        )
+
         # Initialize decode attention backend
-        self.draft_attn_backend = self._create_decode_backend()
+        self.draft_attn_backend = draft_backend_factory.create_decode_backend()

         # Initialize draft extend attention backend (respects speculative_attention_mode setting)
-        self.draft_extend_attn_backend = self._create_draft_extend_backend()
+        self.draft_extend_attn_backend = (
+            draft_backend_factory.create_draft_extend_backend()
+        )

         self.draft_model_runner.draft_attn_backend = self.draft_attn_backend

-    def _create_backend(
-        self, backend_name: str, backend_map: dict, error_template: str
-    ):
-        backend_type = getattr(self.server_args, backend_name)
-        if backend_type is None:
-            backend_type = self.server_args.attention_backend
-
-        if backend_type not in backend_map:
-            raise ValueError(error_template.format(backend_type=backend_type))
-
-        return backend_map[backend_type]()
-
-    def _create_decode_backend(self):
-        backend_map = {
-            "flashinfer": self._create_flashinfer_decode_backend,
-            "triton": self._create_triton_decode_backend,
-            "aiter": self._create_aiter_decode_backend,
-            "fa3": self._create_fa3_decode_backend,
-            "hybrid_linear_attn": (
-                self._create_fa3_decode_backend
-                if not is_blackwell()
-                else self._create_triton_decode_backend
-            ),
-            "flashmla": self._create_flashmla_decode_backend,
-            "trtllm_mha": self._create_trtllm_mha_decode_backend,
-            "trtllm_mla": self._create_trtllm_mla_decode_backend,
-        }
-
-        return self._create_backend(
-            "decode_attention_backend",
-            backend_map,
-            "EAGLE is not supported in decode attention backend {backend_type}",
-        )
-
-    def _create_draft_extend_backend(self):
-        backend_map = {
-            "flashinfer": self._create_flashinfer_prefill_backend,
-            "triton": self._create_triton_prefill_backend,
-            "aiter": self._create_aiter_prefill_backend,
-            "fa3": self._create_fa3_prefill_backend,
-            "hybrid_linear_attn": (
-                self._create_fa3_prefill_backend
-                if not is_blackwell()
-                else self._create_triton_prefill_backend
-            ),
-            "flashmla": self._create_flashmla_prefill_backend,
-            "trtllm_mha": self._create_trtllm_mha_prefill_backend,
-            "trtllm_mla": self._create_trtllm_mla_prefill_backend,
-        }
-        backend_name = (
-            "decode_attention_backend"
-            if self.server_args.speculative_attention_mode == "decode"
-            else "prefill_attention_backend"
-        )
-        return self._create_backend(
-            backend_name,
-            backend_map,
-            "EAGLE is not supported in attention backend {backend_type}",
-        )
-
-    def _create_flashinfer_decode_backend(self):
-        if not get_global_server_args().use_mla_backend:
-            from sglang.srt.layers.attention.flashinfer_backend import (
-                FlashInferMultiStepDraftBackend,
-            )
-
-            self.has_prefill_wrapper_verify = True
-            return FlashInferMultiStepDraftBackend(
-                self.draft_model_runner, self.topk, self.speculative_num_steps
-            )
-        else:
-            from sglang.srt.layers.attention.flashinfer_mla_backend import (
-                FlashInferMLAMultiStepDraftBackend,
-            )
-
-            self.has_prefill_wrapper_verify = True
-            return FlashInferMLAMultiStepDraftBackend(
-                self.draft_model_runner, self.topk, self.speculative_num_steps
-            )
-
-    def _create_triton_decode_backend(self):
-        from sglang.srt.layers.attention.triton_backend import (
-            TritonMultiStepDraftBackend,
-        )
-
-        return TritonMultiStepDraftBackend(
-            self.draft_model_runner, self.topk, self.speculative_num_steps
-        )
-
-    def _create_aiter_decode_backend(self):
-        from sglang.srt.layers.attention.aiter_backend import AiterMultiStepDraftBackend
-
-        return AiterMultiStepDraftBackend(
-            self.draft_model_runner, self.topk, self.speculative_num_steps
-        )
-
-    def _create_fa3_decode_backend(self):
-        from sglang.srt.layers.attention.flashattention_backend import (
-            FlashAttentionMultiStepBackend,
-        )
-
-        return FlashAttentionMultiStepBackend(
-            self.draft_model_runner, self.topk, self.speculative_num_steps
-        )
-
-    def _create_flashmla_decode_backend(self):
-        from sglang.srt.layers.attention.flashmla_backend import (
-            FlashMLAMultiStepDraftBackend,
-        )
-
-        return FlashMLAMultiStepDraftBackend(
-            self.draft_model_runner, self.topk, self.speculative_num_steps
-        )
-
-    def _create_trtllm_mha_decode_backend(self):
-        from sglang.srt.layers.attention.trtllm_mha_backend import (
-            TRTLLMHAAttnMultiStepDraftBackend,
-        )
-
-        self.has_prefill_wrapper_verify = True
-        return TRTLLMHAAttnMultiStepDraftBackend(
-            self.draft_model_runner, self.topk, self.speculative_num_steps
-        )
-
-    def _create_trtllm_mla_decode_backend(self):
-        if not get_global_server_args().use_mla_backend:
-            raise ValueError(
-                "trtllm_mla backend requires MLA model (use_mla_backend=True)."
-            )
-
-        from sglang.srt.layers.attention.trtllm_mla_backend import (
-            TRTLLMMLAMultiStepDraftBackend,
-        )
-
-        self.has_prefill_wrapper_verify = True
-        return TRTLLMMLAMultiStepDraftBackend(
-            self.draft_model_runner, self.topk, self.speculative_num_steps
-        )
-
-    def _create_flashinfer_prefill_backend(self):
-        if not get_global_server_args().use_mla_backend:
-            from sglang.srt.layers.attention.flashinfer_backend import (
-                FlashInferAttnBackend,
-            )
-
-            return FlashInferAttnBackend(self.draft_model_runner, skip_prefill=False)
-        else:
-            from sglang.srt.layers.attention.flashinfer_mla_backend import (
-                FlashInferMLAAttnBackend,
-            )
-
-            return FlashInferMLAAttnBackend(self.draft_model_runner, skip_prefill=False)
-
-    def _create_triton_prefill_backend(self):
-        from sglang.srt.layers.attention.triton_backend import TritonAttnBackend
-
-        return TritonAttnBackend(self.draft_model_runner, skip_prefill=False)
-
-    def _create_aiter_prefill_backend(self):
-        from sglang.srt.layers.attention.aiter_backend import AiterAttnBackend
-
-        return AiterAttnBackend(self.draft_model_runner, skip_prefill=False)
-
-    def _create_fa3_prefill_backend(self):
-        from sglang.srt.layers.attention.flashattention_backend import (
-            FlashAttentionBackend,
-        )
-
-        return FlashAttentionBackend(self.draft_model_runner, skip_prefill=False)
-
-    def _create_trtllm_mha_prefill_backend(self):
-        from sglang.srt.layers.attention.trtllm_mha_backend import TRTLLMHAAttnBackend
-
-        return TRTLLMHAAttnBackend(self.draft_model_runner, skip_prefill=False)
-
-    def _create_trtllm_mla_prefill_backend(self):
-        if not get_global_server_args().use_mla_backend:
-            raise ValueError(
-                "trtllm_mla backend requires MLA model (use_mla_backend=True)."
-            )
-
-        from sglang.srt.layers.attention.trtllm_mla_backend import TRTLLMMLABackend
-
-        return TRTLLMMLABackend(self.draft_model_runner, skip_prefill=False)
-
-    def _create_flashmla_prefill_backend(self):
-        logger.warning(
-            "flashmla prefill backend is not yet supported for draft extend."
-        )
-        return None
-
     def init_cuda_graphs(self):
         """Capture cuda graphs."""
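One consequence of the new loop bound worth spelling out: with speculative_num_steps == 1 there are no per-step backends at all, which is why create_decode_backend in draft_utils.py short-circuits to a dummy. A self-contained sketch of that pattern (the class and function names here are mine, not from the repo):

# Minimal sketch of the speculative_num_steps == 1 special case.

class NoOpDraftBackend:
    """Stand-in mirroring DummyAttnBackend in draft_utils.py."""

    def init_forward_metadata(self, *args, **kwargs):
        pass  # nothing to prepare: there are no draft decode passes


def pick_draft_decode_backend(speculative_num_steps, make_multi_step):
    # num_steps - 1 decode passes are needed, so one step means zero passes.
    if speculative_num_steps == 1:
        return NoOpDraftBackend()
    return make_multi_step()  # e.g. a *MultiStepDraftBackend instance


backend = pick_draft_decode_backend(1, make_multi_step=None)
backend.init_forward_metadata()  # safe no-op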