Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
change
sglang
Commits
8c5930f0
Unverified
Commit
8c5930f0
authored
Sep 08, 2025
by
cicirori
Committed by
GitHub
Sep 07, 2025
Browse files
Add speculator attention backend switch (#9981)
parent
3b99f23c
Changes
6
Hide whitespace changes
Inline
Side-by-side
Showing
6 changed files
with
130 additions
and
54 deletions
+130
-54
python/sglang/srt/layers/attention/hybrid_attn_backend.py
python/sglang/srt/layers/attention/hybrid_attn_backend.py
+58
-50
python/sglang/srt/managers/schedule_batch.py
python/sglang/srt/managers/schedule_batch.py
+1
-0
python/sglang/srt/models/deepseek_v2.py
python/sglang/srt/models/deepseek_v2.py
+9
-0
python/sglang/srt/server_args.py
python/sglang/srt/server_args.py
+8
-0
python/sglang/srt/speculative/eagle_worker.py
python/sglang/srt/speculative/eagle_worker.py
+8
-4
test/srt/test_hybrid_attn_backend.py
test/srt/test_hybrid_attn_backend.py
+46
-0
No files found.
python/sglang/srt/layers/attention/hybrid_attn_backend.py
View file @
8c5930f0
...
...
@@ -22,17 +22,45 @@ class HybridAttnBackend(AttentionBackend):
self
.
prefill_backend
=
prefill_backend
self
.
decode_backend
=
decode_backend
def
init_forward_metadata
(
self
,
forward_batch
:
ForwardBatch
):
if
forward_batch
.
forward_mode
.
is_decode_or_idle
():
self
.
decode_backend
.
init_forward_metadata
(
forward_batch
)
def
_select_backend
(
self
,
forward_mode
:
ForwardMode
)
->
AttentionBackend
:
"""
Select the appropriate attention backend based on the forward mode.
Args:
forward_mode: The current forward mode indicating the operation type
Returns:
The selected attention backend (prefill or decode)
Note:
- decode_or_idle: Always uses decode backend
- target_verify or draft_extend: Uses decode backend if speculative_attention_backend is "decode", otherwise prefill backend
- prefill: Always uses prefill backend
"""
if
forward_mode
.
is_decode_or_idle
():
return
self
.
decode_backend
elif
forward_mode
.
is_target_verify
()
or
forward_mode
.
is_draft_extend
():
return
(
self
.
decode_backend
if
self
.
model_runner
.
server_args
.
speculative_attention_backend
==
"decode"
else
self
.
prefill_backend
)
else
:
self
.
prefill_backend
.
init_forward_metadata
(
forward_batch
)
return
self
.
prefill_backend
def
init_forward_metadata
(
self
,
forward_batch
:
ForwardBatch
):
backend
=
self
.
_select_backend
(
forward_batch
.
forward_mode
)
backend
.
init_forward_metadata
(
forward_batch
)
def
init_cuda_graph_state
(
self
,
max_bs
:
int
,
max_num_tokens
:
int
):
self
.
decode_backend
.
init_cuda_graph_state
(
max_bs
,
max_num_tokens
)
if
self
.
model_runner
.
server_args
.
speculative_algorithm
is
not
None
:
# When speculative decoding is enabled, we also need to initialize the
# prefill backend's cuda graph state to support target_verify.
if
(
self
.
model_runner
.
server_args
.
speculative_algorithm
is
not
None
and
self
.
model_runner
.
server_args
.
speculative_attention_backend
==
"prefill"
):
# When speculative decoding is enabled, we need to initialize the backend
# that will be used for target_verify.
self
.
prefill_backend
.
init_cuda_graph_state
(
max_bs
,
max_num_tokens
)
def
init_forward_metadata_capture_cuda_graph
(
...
...
@@ -45,26 +73,16 @@ class HybridAttnBackend(AttentionBackend):
forward_mode
:
ForwardMode
,
spec_info
:
Optional
[
Union
[
EagleDraftInput
,
EagleVerifyInput
]],
):
if
forward_mode
.
is_decode_or_idle
():
self
.
decode_backend
.
init_forward_metadata_capture_cuda_graph
(
bs
,
num_tokens
,
req_pool_indices
,
seq_lens
,
encoder_lens
,
forward_mode
,
spec_info
,
)
else
:
self
.
prefill_backend
.
init_forward_metadata_capture_cuda_graph
(
bs
,
num_tokens
,
req_pool_indices
,
seq_lens
,
encoder_lens
,
forward_mode
,
spec_info
,
)
backend
=
self
.
_select_backend
(
forward_mode
)
backend
.
init_forward_metadata_capture_cuda_graph
(
bs
,
num_tokens
,
req_pool_indices
,
seq_lens
,
encoder_lens
,
forward_mode
,
spec_info
,
)
def
init_forward_metadata_replay_cuda_graph
(
self
,
...
...
@@ -77,28 +95,17 @@ class HybridAttnBackend(AttentionBackend):
spec_info
:
Optional
[
Union
[
EagleDraftInput
,
EagleVerifyInput
]],
seq_lens_cpu
:
Optional
[
torch
.
Tensor
],
):
if
forward_mode
.
is_decode_or_idle
():
self
.
decode_backend
.
init_forward_metadata_replay_cuda_graph
(
bs
,
req_pool_indices
,
seq_lens
,
seq_lens_sum
,
encoder_lens
,
forward_mode
,
spec_info
,
seq_lens_cpu
,
)
else
:
self
.
prefill_backend
.
init_forward_metadata_replay_cuda_graph
(
bs
,
req_pool_indices
,
seq_lens
,
seq_lens_sum
,
encoder_lens
,
forward_mode
,
spec_info
,
seq_lens_cpu
,
)
backend
=
self
.
_select_backend
(
forward_mode
)
backend
.
init_forward_metadata_replay_cuda_graph
(
bs
,
req_pool_indices
,
seq_lens
,
seq_lens_sum
,
encoder_lens
,
forward_mode
,
spec_info
,
seq_lens_cpu
,
)
def
get_cuda_graph_seq_len_fill_value
(
self
):
return
self
.
decode_backend
.
get_cuda_graph_seq_len_fill_value
()
...
...
@@ -127,6 +134,7 @@ class HybridAttnBackend(AttentionBackend):
save_kv_cache
:
bool
=
True
,
**
kwargs
,
):
return
self
.
prefill_backend
.
forward_extend
(
backend
=
self
.
_select_backend
(
forward_batch
.
forward_mode
)
return
backend
.
forward_extend
(
q
,
k
,
v
,
layer
,
forward_batch
,
save_kv_cache
,
**
kwargs
)
python/sglang/srt/managers/schedule_batch.py
View file @
8c5930f0
...
...
@@ -98,6 +98,7 @@ GLOBAL_SERVER_ARGS_KEYS = [
"sampling_backend"
,
"speculative_accept_threshold_single"
,
"speculative_accept_threshold_acc"
,
"speculative_attention_backend"
,
"torchao_config"
,
"triton_attention_reduce_in_fp32"
,
"num_reserved_decode_tokens"
,
...
...
python/sglang/srt/models/deepseek_v2.py
View file @
8c5930f0
...
...
@@ -1045,6 +1045,15 @@ class DeepseekV2AttentionMLA(nn.Module):
# Determine attention backend used by current forward batch
if
forward_batch
.
forward_mode
.
is_decode_or_idle
():
attention_backend
=
global_server_args_dict
[
"decode_attention_backend"
]
elif
(
forward_batch
.
forward_mode
.
is_target_verify
()
or
forward_batch
.
forward_mode
.
is_draft_extend
()
):
# Use the specified backend for speculative operations (both verify and draft extend)
if
global_server_args_dict
[
"speculative_attention_backend"
]
==
"decode"
:
attention_backend
=
global_server_args_dict
[
"decode_attention_backend"
]
else
:
# default to prefill
attention_backend
=
global_server_args_dict
[
"prefill_attention_backend"
]
else
:
attention_backend
=
global_server_args_dict
[
"prefill_attention_backend"
]
self
.
current_attention_backend
=
attention_backend
...
...
python/sglang/srt/server_args.py
View file @
8c5930f0
...
...
@@ -262,6 +262,7 @@ class ServerArgs:
speculative_accept_threshold_single
:
float
=
1.0
speculative_accept_threshold_acc
:
float
=
1.0
speculative_token_map
:
Optional
[
str
]
=
None
speculative_attention_backend
:
str
=
"prefill"
# Expert parallelism
ep_size
:
int
=
1
...
...
@@ -1561,6 +1562,13 @@ class ServerArgs:
help
=
"The path of the draft model's small vocab table."
,
default
=
ServerArgs
.
speculative_token_map
,
)
parser
.
add_argument
(
"--speculative-attention-backend"
,
type
=
str
,
choices
=
[
"prefill"
,
"decode"
],
help
=
"Attention backend to use for speculative decoding operations (both target verify and draft extend). 'prefill' (default) or 'decode'."
,
default
=
ServerArgs
.
speculative_attention_backend
,
)
# Expert parallelism
parser
.
add_argument
(
...
...
python/sglang/srt/speculative/eagle_worker.py
View file @
8c5930f0
...
...
@@ -191,7 +191,7 @@ class EAGLEWorker(TpModelWorker):
# Initialize decode attention backend
self
.
draft_attn_backend
=
self
.
_create_decode_backend
()
# Initialize
prefill
attention backend
# Initialize
draft extend
attention backend
(respects speculative_attention_backend setting)
self
.
draft_extend_attn_backend
=
self
.
_create_draft_extend_backend
()
self
.
draft_model_runner
.
draft_attn_backend
=
self
.
draft_attn_backend
...
...
@@ -234,11 +234,15 @@ class EAGLEWorker(TpModelWorker):
"trtllm_mha"
:
self
.
_create_trtllm_mha_prefill_backend
,
"trtllm_mla"
:
self
.
_create_trtllm_mla_prefill_backend
,
}
backend_name
=
(
"decode_attention_backend"
if
self
.
server_args
.
speculative_attention_backend
==
"decode"
else
"prefill_attention_backend"
)
return
self
.
_create_backend
(
"prefill_attention_
backend
"
,
backend
_name
,
backend_map
,
"EAGLE is not supported in
prefill
attention backend {backend_type}"
,
"EAGLE is not supported in attention backend {backend_type}"
,
)
def
_create_flashinfer_decode_backend
(
self
):
...
...
test/srt/test_hybrid_attn_backend.py
View file @
8c5930f0
...
...
@@ -132,5 +132,51 @@ class TestHybridAttnBackendSpeculativeDecoding(TestHybridAttnBackendBase):
]
class
TestHybridAttnBackendSpeculativeDecodingPrefillBackend
(
TestHybridAttnBackendBase
):
speculative_decode
=
True
# This eagle test uses a very small model, so the accuracy is low.
accuracy_threshold
=
0.2
@
classmethod
def
get_server_args
(
cls
):
return
DEFAULT_SERVER_ARGS
+
[
"--speculative-algorithm"
,
"EAGLE"
,
"--speculative-draft"
,
DEFAULT_EAGLE_DRAFT_MODEL_FOR_TEST
,
"--speculative-num-steps"
,
"3"
,
"--speculative-eagle-topk"
,
"2"
,
"--speculative-num-draft-tokens"
,
"4"
,
"--speculative-attention-backend"
,
"prefill"
,
]
class
TestHybridAttnBackendSpeculativeDecodingDecodeBackend
(
TestHybridAttnBackendBase
):
speculative_decode
=
True
# This eagle test uses a very small model, so the accuracy is low.
accuracy_threshold
=
0.2
@
classmethod
def
get_server_args
(
cls
):
return
DEFAULT_SERVER_ARGS
+
[
"--speculative-algorithm"
,
"EAGLE"
,
"--speculative-draft"
,
DEFAULT_EAGLE_DRAFT_MODEL_FOR_TEST
,
"--speculative-num-steps"
,
"3"
,
"--speculative-eagle-topk"
,
"2"
,
"--speculative-num-draft-tokens"
,
"4"
,
"--speculative-attention-backend"
,
"decode"
,
]
if
__name__
==
"__main__"
:
unittest
.
main
()
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment