zhaoyu6 / sglang · Commit 5ea96ac7 (unverified)

Reduce one step decode for draft model. (#11561)

Authored Oct 14, 2025 by Liangsheng Yin; committed via GitHub on Oct 14, 2025.
Parent: 56222658
Changes: 10 changed files with 253 additions and 211 deletions (+253 −211)
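Editor's note: the one recurring change across the attention backends below is `range(self.speculative_num_steps)` → `range(self.speculative_num_steps - 1)`, so each multi-step draft backend now allocates and initializes one fewer per-step attention backend. A minimal, self-contained sketch of one plausible reading (illustrative names only, not sglang's API): with k speculative steps, the first level of draft tokens already falls out of the preceding draft forward pass, so only k - 1 additional draft decode steps are needed; the k == 1 case needs none, which is what the DummyAttnBackend fast path in the new draft_utils.py handles.

# Toy sketch, not sglang code: k drafted levels from k - 1 decode steps,
# because level 0 falls out of the preceding forward pass.
from typing import List


def draft_levels(first_level_token: int, speculative_num_steps: int) -> List[int]:
    levels = [first_level_token]  # level 0: no extra decode step needed
    # mirrors `for i in range(self.speculative_num_steps - 1)` in the diffs
    for _ in range(speculative_num_steps - 1):
        levels.append(levels[-1] + 1)  # stand-in for one draft decode step
    return levels


print(draft_levels(100, 4))  # 4 levels, 3 decode steps -> [100, 101, 102, 103]
print(draft_levels(100, 1))  # 1 level, 0 decode steps -> [100]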
python/sglang/srt/layers/attention/aiter_backend.py             +3 −3
python/sglang/srt/layers/attention/flashattention_backend.py    +2 −2
python/sglang/srt/layers/attention/flashinfer_backend.py        +2 −2
python/sglang/srt/layers/attention/flashinfer_mla_backend.py    +2 −2
python/sglang/srt/layers/attention/flashmla_backend.py          +2 −2
python/sglang/srt/layers/attention/triton_backend.py            +4 −3
python/sglang/srt/layers/attention/trtllm_mha_backend.py        +2 −2
python/sglang/srt/layers/attention/trtllm_mla_backend.py        +1 −1
python/sglang/srt/speculative/draft_utils.py (new)              +222 −0
python/sglang/srt/speculative/eagle_worker.py                   +13 −194
python/sglang/srt/layers/attention/aiter_backend.py

@@ -1064,7 +1064,7 @@ class AiterMultiStepDraftBackend:
             device=model_runner.device,
         )
         self.attn_backends = []
-        for i in range(self.speculative_num_steps):
+        for i in range(self.speculative_num_steps - 1):
             self.attn_backends.append(
                 AiterAttnBackend(
                     model_runner,

@@ -1107,7 +1107,7 @@ class AiterMultiStepDraftBackend:
                 self.page_size,
             )
-        for i in range(self.speculative_num_steps):
+        for i in range(self.speculative_num_steps - 1):
             forward_batch.spec_info.kv_indptr = self.kv_indptr[i, : bs + 1]
             forward_batch.spec_info.kv_indices = kv_indices_buffer[i][
                 : seq_lens_sum * self.topk + bs * (i + 1)

@@ -1141,7 +1141,7 @@ class AiterMultiStepDraftBackend:
             dtype=torch.int32,
             device=self.device,
         )
-        for i in range(self.speculative_num_steps):
+        for i in range(self.speculative_num_steps - 1):
             self.attn_backends[i].init_cuda_graph_state(
                 max_bs, max_num_tokens, kv_indices_buf=self.cuda_graph_kv_indices[i]
             )
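For reference, the per-step kv_indices slice bound in the hunk above grows by `bs` indices per draft step. A quick arithmetic check with made-up sizes (the numbers are illustrative only):

# Made-up sizes, purely to evaluate the slice bound used above.
bs, seq_lens_sum, topk = 2, 10, 4
speculative_num_steps = 4
for i in range(speculative_num_steps - 1):
    bound = seq_lens_sum * topk + bs * (i + 1)
    print(f"step {i}: kv_indices_buffer[{i}][:{bound}]")
# step 0: ...[:42]
# step 1: ...[:44]
# step 2: ...[:46]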
python/sglang/srt/layers/attention/flashattention_backend.py

@@ -2320,7 +2320,7 @@ class FlashAttentionMultiStepBackend:
         self.topk = topk
         self.speculative_num_steps = speculative_num_steps
         self.attn_backends = []
-        for i in range(self.speculative_num_steps):
+        for i in range(self.speculative_num_steps - 1):
             self.attn_backends.append(
                 FlashAttentionBackend(
                     model_runner,

@@ -2335,7 +2335,7 @@ class FlashAttentionMultiStepBackend:
             self.attn_backends[i].init_forward_metadata(forward_batch)

     def init_cuda_graph_state(self, max_bs: int, max_num_tokens: int):
-        for i in range(self.speculative_num_steps):
+        for i in range(self.speculative_num_steps - 1):
             self.attn_backends[i].init_cuda_graph_state(max_bs, max_num_tokens)

     def init_forward_metadata_capture_cuda_graph(
python/sglang/srt/layers/attention/flashinfer_backend.py

@@ -1405,7 +1405,7 @@ class FlashInferMultiStepDraftBackend:
             (max_bs,), dtype=torch.int32, device=model_runner.device
         )
         self.attn_backends: List[FlashInferAttnBackend] = []
-        for i in range(self.speculative_num_steps):
+        for i in range(self.speculative_num_steps - 1):
             self.attn_backends.append(
                 FlashInferAttnBackend(
                     model_runner,

@@ -1493,7 +1493,7 @@ class FlashInferMultiStepDraftBackend:
             device="cuda",
         )
-        for i in range(self.speculative_num_steps):
+        for i in range(self.speculative_num_steps - 1):
             self.attn_backends[i].init_cuda_graph_state(
                 max_bs, max_num_tokens, kv_indices_buf=self.cuda_graph_kv_indices[i]
             )
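The init_cuda_graph_state hunks all follow one pattern: each remaining step's backend is handed a preallocated row of cuda_graph_kv_indices, so CUDA-graph capture and replay reuse fixed storage rather than allocating per step. A minimal sketch of that buffer layout (illustrative, assuming only torch; not sglang's classes):

import torch

# One preallocated KV-index row per remaining draft step (k - 1 rows),
# mirroring `kv_indices_buf=self.cuda_graph_kv_indices[i]` above.
speculative_num_steps, max_kv_indices = 4, 1024
cuda_graph_kv_indices = torch.zeros(
    (speculative_num_steps - 1, max_kv_indices), dtype=torch.int32
)
step_buffers = [cuda_graph_kv_indices[i] for i in range(speculative_num_steps - 1)]
assert all(buf.shape == (max_kv_indices,) for buf in step_buffers)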
python/sglang/srt/layers/attention/flashinfer_mla_backend.py

@@ -916,7 +916,7 @@ class FlashInferMLAMultiStepDraftBackend:
         )
         self.attn_backends = []
-        for i in range(self.speculative_num_steps):
+        for i in range(self.speculative_num_steps - 1):
             self.attn_backends.append(
                 FlashInferMLAAttnBackend(
                     model_runner,

@@ -998,7 +998,7 @@ class FlashInferMLAMultiStepDraftBackend:
             device="cuda",
         )
-        for i in range(self.speculative_num_steps):
+        for i in range(self.speculative_num_steps - 1):
             self.attn_backends[i].init_cuda_graph_state(
                 max_bs, max_num_tokens, kv_indices_buf=self.cuda_graph_kv_indices[i]
             )
python/sglang/srt/layers/attention/flashmla_backend.py

@@ -478,7 +478,7 @@ class FlashMLAMultiStepDraftBackend:
         )
         self.attn_backends = []
-        for i in range(self.speculative_num_steps):
+        for i in range(self.speculative_num_steps - 1):
             self.attn_backends.append(
                 FlashMLABackend(
                     model_runner,

@@ -506,7 +506,7 @@ class FlashMLAMultiStepDraftBackend:
         self.common_template(forward_batch, call_fn)

     def init_cuda_graph_state(self, max_bs: int, max_num_tokens: int):
-        for i in range(self.speculative_num_steps):
+        for i in range(self.speculative_num_steps - 1):
             self.attn_backends[i].init_cuda_graph_state(
                 max_bs, max_num_tokens, block_kv_indices=None
             )
python/sglang/srt/layers/attention/triton_backend.py

@@ -918,7 +918,7 @@ class TritonMultiStepDraftBackend:
             device=model_runner.device,
         )
         self.attn_backends: List[TritonAttnBackend] = []
-        for i in range(self.speculative_num_steps):
+        for i in range(self.speculative_num_steps - 1):
             self.attn_backends.append(
                 TritonAttnBackend(
                     model_runner,

@@ -969,7 +969,7 @@ class TritonMultiStepDraftBackend:
         if call_fn is None:
             return
-        for i in range(self.speculative_num_steps):
+        for i in range(self.speculative_num_steps - 1):
             forward_batch.spec_info.kv_indptr = self.kv_indptr[i, : bs + 1]
             forward_batch.spec_info.kv_indices = kv_indices_buffer[i][
                 : seq_lens_sum * self.topk + bs * (i + 1)

@@ -1009,7 +1009,8 @@ class TritonMultiStepDraftBackend:
             dtype=torch.int32,
             device=self.device,
         )
-        for i in range(self.speculative_num_steps):
+        for i in range(self.speculative_num_steps - 1):
             self.attn_backends[i].init_cuda_graph_state(
                 max_bs,
                 max_num_tokens,
python/sglang/srt/layers/attention/trtllm_mha_backend.py

@@ -637,7 +637,7 @@ class TRTLLMHAAttnMultiStepDraftBackend(FlashInferMultiStepDraftBackend):
         self, model_runner: ModelRunner, topk: int, speculative_num_steps: int
     ):
         super().__init__(model_runner, topk, speculative_num_steps)
-        for i in range(speculative_num_steps):
+        for i in range(self.speculative_num_steps - 1):
             self.attn_backends[i] = TRTLLMHAAttnBackend(
                 model_runner,
                 skip_prefill=True,

@@ -651,7 +651,7 @@ class TRTLLMHAAttnMultiStepDraftBackend(FlashInferMultiStepDraftBackend):
             self.attn_backends[i].init_forward_metadata(forward_batch)

     def init_cuda_graph_state(self, max_bs: int, max_num_tokens: int):
-        for i in range(self.speculative_num_steps):
+        for i in range(self.speculative_num_steps - 1):
             self.attn_backends[i].init_cuda_graph_state(max_bs, max_num_tokens)

     def init_forward_metadata_capture_cuda_graph(
python/sglang/srt/layers/attention/trtllm_mla_backend.py

@@ -735,7 +735,7 @@ class TRTLLMMLAMultiStepDraftBackend(FlashInferMLAMultiStepDraftBackend):
     ):
         super().__init__(model_runner, topk, speculative_num_steps)
-        for i in range(self.speculative_num_steps):
+        for i in range(self.speculative_num_steps - 1):
             self.attn_backends[i] = TRTLLMMLABackend(
                 model_runner,
                 skip_prefill=True,
python/sglang/srt/speculative/draft_utils.py (new file, 0 → 100644)

import logging

from sglang.srt.server_args import ServerArgs, get_global_server_args
from sglang.srt.utils.common import is_blackwell

logger = logging.getLogger(__name__)


class DraftBackendFactory:
    def __init__(
        self,
        server_args: ServerArgs,
        draft_model_runner,
        topk: int,
        speculative_num_steps: int,
    ):
        self.server_args = server_args
        self.draft_model_runner = draft_model_runner
        self.topk = topk
        self.speculative_num_steps = speculative_num_steps

    def _create_backend(
        self, backend_name: str, backend_map: dict, error_template: str
    ):
        backend_type = getattr(self.server_args, backend_name)
        if backend_type is None:
            backend_type = self.server_args.attention_backend

        if backend_type not in backend_map:
            raise ValueError(error_template.format(backend_type=backend_type))

        return backend_map[backend_type]()

    def create_decode_backend(self):
        if self.speculative_num_steps == 1:

            class DummyAttnBackend:
                def __init__(self):
                    pass

                def init_forward_metadata(*args, **kwargs):
                    pass

            return DummyAttnBackend()

        backend_map = {
            "flashinfer": self._create_flashinfer_decode_backend,
            "triton": self._create_triton_decode_backend,
            "aiter": self._create_aiter_decode_backend,
            "fa3": self._create_fa3_decode_backend,
            "hybrid_linear_attn": (
                self._create_fa3_decode_backend
                if not is_blackwell()
                else self._create_triton_decode_backend
            ),
            "flashmla": self._create_flashmla_decode_backend,
            "trtllm_mha": self._create_trtllm_mha_decode_backend,
            "trtllm_mla": self._create_trtllm_mla_decode_backend,
        }

        return self._create_backend(
            "decode_attention_backend",
            backend_map,
            "EAGLE is not supported in decode attention backend {backend_type}",
        )

    def create_draft_extend_backend(self):
        backend_map = {
            "flashinfer": self._create_flashinfer_prefill_backend,
            "triton": self._create_triton_prefill_backend,
            "aiter": self._create_aiter_prefill_backend,
            "fa3": self._create_fa3_prefill_backend,
            "hybrid_linear_attn": (
                self._create_fa3_prefill_backend
                if not is_blackwell()
                else self._create_triton_prefill_backend
            ),
            "flashmla": self._create_flashmla_prefill_backend,
            "trtllm_mha": self._create_trtllm_mha_prefill_backend,
            "trtllm_mla": self._create_trtllm_mla_prefill_backend,
        }
        backend_name = (
            "decode_attention_backend"
            if self.server_args.speculative_attention_mode == "decode"
            else "prefill_attention_backend"
        )
        return self._create_backend(
            backend_name,
            backend_map,
            "EAGLE is not supported in attention backend {backend_type}",
        )

    def _create_flashinfer_decode_backend(self):
        if not get_global_server_args().use_mla_backend:
            from sglang.srt.layers.attention.flashinfer_backend import (
                FlashInferMultiStepDraftBackend,
            )

            self.has_prefill_wrapper_verify = True
            return FlashInferMultiStepDraftBackend(
                self.draft_model_runner, self.topk, self.speculative_num_steps
            )
        else:
            from sglang.srt.layers.attention.flashinfer_mla_backend import (
                FlashInferMLAMultiStepDraftBackend,
            )

            self.has_prefill_wrapper_verify = True
            return FlashInferMLAMultiStepDraftBackend(
                self.draft_model_runner, self.topk, self.speculative_num_steps
            )

    def _create_triton_decode_backend(self):
        from sglang.srt.layers.attention.triton_backend import (
            TritonMultiStepDraftBackend,
        )

        return TritonMultiStepDraftBackend(
            self.draft_model_runner, self.topk, self.speculative_num_steps
        )

    def _create_aiter_decode_backend(self):
        from sglang.srt.layers.attention.aiter_backend import AiterMultiStepDraftBackend

        return AiterMultiStepDraftBackend(
            self.draft_model_runner, self.topk, self.speculative_num_steps
        )

    def _create_fa3_decode_backend(self):
        from sglang.srt.layers.attention.flashattention_backend import (
            FlashAttentionMultiStepBackend,
        )

        return FlashAttentionMultiStepBackend(
            self.draft_model_runner, self.topk, self.speculative_num_steps
        )

    def _create_flashmla_decode_backend(self):
        from sglang.srt.layers.attention.flashmla_backend import (
            FlashMLAMultiStepDraftBackend,
        )

        return FlashMLAMultiStepDraftBackend(
            self.draft_model_runner, self.topk, self.speculative_num_steps
        )

    def _create_trtllm_mha_decode_backend(self):
        from sglang.srt.layers.attention.trtllm_mha_backend import (
            TRTLLMHAAttnMultiStepDraftBackend,
        )

        self.has_prefill_wrapper_verify = True
        return TRTLLMHAAttnMultiStepDraftBackend(
            self.draft_model_runner, self.topk, self.speculative_num_steps
        )

    def _create_trtllm_mla_decode_backend(self):
        if not get_global_server_args().use_mla_backend:
            raise ValueError(
                "trtllm_mla backend requires MLA model (use_mla_backend=True)."
            )

        from sglang.srt.layers.attention.trtllm_mla_backend import (
            TRTLLMMLAMultiStepDraftBackend,
        )

        self.has_prefill_wrapper_verify = True
        return TRTLLMMLAMultiStepDraftBackend(
            self.draft_model_runner, self.topk, self.speculative_num_steps
        )

    def _create_flashinfer_prefill_backend(self):
        if not get_global_server_args().use_mla_backend:
            from sglang.srt.layers.attention.flashinfer_backend import (
                FlashInferAttnBackend,
            )

            return FlashInferAttnBackend(self.draft_model_runner, skip_prefill=False)
        else:
            from sglang.srt.layers.attention.flashinfer_mla_backend import (
                FlashInferMLAAttnBackend,
            )

            return FlashInferMLAAttnBackend(self.draft_model_runner, skip_prefill=False)

    def _create_triton_prefill_backend(self):
        from sglang.srt.layers.attention.triton_backend import TritonAttnBackend

        return TritonAttnBackend(self.draft_model_runner, skip_prefill=False)

    def _create_aiter_prefill_backend(self):
        from sglang.srt.layers.attention.aiter_backend import AiterAttnBackend

        return AiterAttnBackend(self.draft_model_runner, skip_prefill=False)

    def _create_fa3_prefill_backend(self):
        from sglang.srt.layers.attention.flashattention_backend import (
            FlashAttentionBackend,
        )

        return FlashAttentionBackend(self.draft_model_runner, skip_prefill=False)

    def _create_trtllm_mha_prefill_backend(self):
        from sglang.srt.layers.attention.trtllm_mha_backend import TRTLLMHAAttnBackend

        return TRTLLMHAAttnBackend(self.draft_model_runner, skip_prefill=False)

    def _create_trtllm_mla_prefill_backend(self):
        if not get_global_server_args().use_mla_backend:
            raise ValueError(
                "trtllm_mla backend requires MLA model (use_mla_backend=True)."
            )

        from sglang.srt.layers.attention.trtllm_mla_backend import TRTLLMMLABackend

        return TRTLLMMLABackend(self.draft_model_runner, skip_prefill=False)

    def _create_flashmla_prefill_backend(self):
        logger.warning(
            "flashmla prefill backend is not yet supported for draft extend."
        )
        return None
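A hypothetical call sequence for the new factory, mirroring its use in eagle_worker.py below; `server_args` and `draft_model_runner` are assumed to be constructed elsewhere, and the numeric arguments are example values:

from sglang.srt.speculative.draft_utils import DraftBackendFactory

factory = DraftBackendFactory(
    server_args,          # ServerArgs (attention_backend etc. already set)
    draft_model_runner,   # the draft model's ModelRunner
    topk=4,               # example value
    speculative_num_steps=4,  # example value
)
draft_attn_backend = factory.create_decode_backend()
draft_extend_attn_backend = factory.create_draft_extend_backend()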
python/sglang/srt/speculative/eagle_worker.py

@@ -27,7 +27,8 @@ from sglang.srt.model_executor.forward_batch_info import (
     ForwardBatch,
     ForwardMode,
 )
-from sglang.srt.server_args import ServerArgs, get_global_server_args
+from sglang.srt.server_args import ServerArgs
+from sglang.srt.speculative.draft_utils import DraftBackendFactory
 from sglang.srt.speculative.eagle_draft_cuda_graph_runner import (
     EAGLEDraftCudaGraphRunner,
 )

@@ -195,204 +196,22 @@ class EAGLEWorker(TpModelWorker):
         self.has_prefill_wrapper_verify = False
         self.draft_extend_attn_backend = None

-        # Initialize decode attention backend
-        self.draft_attn_backend = self._create_decode_backend()
-        # Initialize draft extend attention backend (respects speculative_attention_mode setting)
-        self.draft_extend_attn_backend = self._create_draft_extend_backend()
-
-        [the roughly 190 lines deleted here are the _create_backend dispatcher,
-        the _create_decode_backend / _create_draft_extend_backend backend maps,
-        and the per-backend _create_*_decode_backend / _create_*_prefill_backend
-        helpers, all moved into DraftBackendFactory in
-        python/sglang/srt/speculative/draft_utils.py above]
+        draft_backend_factory = DraftBackendFactory(
+            self.server_args,
+            self.draft_model_runner,
+            self.topk,
+            self.speculative_num_steps,
+        )
+
+        # Initialize decode attention backend
+        self.draft_attn_backend = draft_backend_factory.create_decode_backend()
+
+        # Initialize draft extend attention backend (respects speculative_attention_mode setting)
+        self.draft_extend_attn_backend = (
+            draft_backend_factory.create_draft_extend_backend()
+        )

         self.draft_model_runner.draft_attn_backend = self.draft_attn_backend

     def init_cuda_graphs(self):
         """Capture cuda graphs."""
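Design note: the eagle_worker.py diff is essentially an extraction refactor; the _create_* helpers move unchanged into DraftBackendFactory, and EAGLEWorker.__init__ shrinks to two factory calls. The one behavioral addition is the speculative_num_steps == 1 fast path in create_decode_backend, where no multi-step decode backend is needed at all. A standalone illustration of that no-op contract (copied from the stub in draft_utils.py above):

# With a single draft step there is no multi-step decode, so a no-op
# backend satisfies the init_forward_metadata interface.
class DummyAttnBackend:
    def init_forward_metadata(*args, **kwargs):  # swallows self via *args
        pass

backend = DummyAttnBackend()
backend.init_forward_metadata("forward_batch", any_kwarg=True)  # does nothing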