Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
change
sglang
Commits
8c5930f0
"vscode:/vscode.git/clone" did not exist on "e44fc75acb6ddf5a331d7ef9896c0e39d87a019e"
Unverified
Commit
8c5930f0
authored
Sep 08, 2025
by
cicirori
Committed by
GitHub
Sep 07, 2025
Browse files
Add speculator attention backend switch (#9981)
parent
3b99f23c
Changes
6
Show whitespace changes
Inline
Side-by-side
Showing
6 changed files
with
130 additions
and
54 deletions
+130
-54
python/sglang/srt/layers/attention/hybrid_attn_backend.py
python/sglang/srt/layers/attention/hybrid_attn_backend.py
+58
-50
python/sglang/srt/managers/schedule_batch.py
python/sglang/srt/managers/schedule_batch.py
+1
-0
python/sglang/srt/models/deepseek_v2.py
python/sglang/srt/models/deepseek_v2.py
+9
-0
python/sglang/srt/server_args.py
python/sglang/srt/server_args.py
+8
-0
python/sglang/srt/speculative/eagle_worker.py
python/sglang/srt/speculative/eagle_worker.py
+8
-4
test/srt/test_hybrid_attn_backend.py
test/srt/test_hybrid_attn_backend.py
+46
-0
No files found.
python/sglang/srt/layers/attention/hybrid_attn_backend.py
View file @
8c5930f0
...
@@ -22,17 +22,45 @@ class HybridAttnBackend(AttentionBackend):
...
@@ -22,17 +22,45 @@ class HybridAttnBackend(AttentionBackend):
self
.
prefill_backend
=
prefill_backend
self
.
prefill_backend
=
prefill_backend
self
.
decode_backend
=
decode_backend
self
.
decode_backend
=
decode_backend
def
init_forward_metadata
(
self
,
forward_batch
:
ForwardBatch
):
def
_select_backend
(
self
,
forward_mode
:
ForwardMode
)
->
AttentionBackend
:
if
forward_batch
.
forward_mode
.
is_decode_or_idle
():
"""
self
.
decode_backend
.
init_forward_metadata
(
forward_batch
)
Select the appropriate attention backend based on the forward mode.
Args:
forward_mode: The current forward mode indicating the operation type
Returns:
The selected attention backend (prefill or decode)
Note:
- decode_or_idle: Always uses decode backend
- target_verify or draft_extend: Uses decode backend if speculative_attention_backend is "decode", otherwise prefill backend
- prefill: Always uses prefill backend
"""
if
forward_mode
.
is_decode_or_idle
():
return
self
.
decode_backend
elif
forward_mode
.
is_target_verify
()
or
forward_mode
.
is_draft_extend
():
return
(
self
.
decode_backend
if
self
.
model_runner
.
server_args
.
speculative_attention_backend
==
"decode"
else
self
.
prefill_backend
)
else
:
else
:
self
.
prefill_backend
.
init_forward_metadata
(
forward_batch
)
return
self
.
prefill_backend
def
init_forward_metadata
(
self
,
forward_batch
:
ForwardBatch
):
backend
=
self
.
_select_backend
(
forward_batch
.
forward_mode
)
backend
.
init_forward_metadata
(
forward_batch
)
def
init_cuda_graph_state
(
self
,
max_bs
:
int
,
max_num_tokens
:
int
):
def
init_cuda_graph_state
(
self
,
max_bs
:
int
,
max_num_tokens
:
int
):
self
.
decode_backend
.
init_cuda_graph_state
(
max_bs
,
max_num_tokens
)
self
.
decode_backend
.
init_cuda_graph_state
(
max_bs
,
max_num_tokens
)
if
self
.
model_runner
.
server_args
.
speculative_algorithm
is
not
None
:
if
(
# When speculative decoding is enabled, we also need to initialize the
self
.
model_runner
.
server_args
.
speculative_algorithm
is
not
None
# prefill backend's cuda graph state to support target_verify.
and
self
.
model_runner
.
server_args
.
speculative_attention_backend
==
"prefill"
):
# When speculative decoding is enabled, we need to initialize the backend
# that will be used for target_verify.
self
.
prefill_backend
.
init_cuda_graph_state
(
max_bs
,
max_num_tokens
)
self
.
prefill_backend
.
init_cuda_graph_state
(
max_bs
,
max_num_tokens
)
def
init_forward_metadata_capture_cuda_graph
(
def
init_forward_metadata_capture_cuda_graph
(
...
@@ -45,18 +73,8 @@ class HybridAttnBackend(AttentionBackend):
...
@@ -45,18 +73,8 @@ class HybridAttnBackend(AttentionBackend):
forward_mode
:
ForwardMode
,
forward_mode
:
ForwardMode
,
spec_info
:
Optional
[
Union
[
EagleDraftInput
,
EagleVerifyInput
]],
spec_info
:
Optional
[
Union
[
EagleDraftInput
,
EagleVerifyInput
]],
):
):
if
forward_mode
.
is_decode_or_idle
():
backend
=
self
.
_select_backend
(
forward_mode
)
self
.
decode_backend
.
init_forward_metadata_capture_cuda_graph
(
backend
.
init_forward_metadata_capture_cuda_graph
(
bs
,
num_tokens
,
req_pool_indices
,
seq_lens
,
encoder_lens
,
forward_mode
,
spec_info
,
)
else
:
self
.
prefill_backend
.
init_forward_metadata_capture_cuda_graph
(
bs
,
bs
,
num_tokens
,
num_tokens
,
req_pool_indices
,
req_pool_indices
,
...
@@ -77,19 +95,8 @@ class HybridAttnBackend(AttentionBackend):
...
@@ -77,19 +95,8 @@ class HybridAttnBackend(AttentionBackend):
spec_info
:
Optional
[
Union
[
EagleDraftInput
,
EagleVerifyInput
]],
spec_info
:
Optional
[
Union
[
EagleDraftInput
,
EagleVerifyInput
]],
seq_lens_cpu
:
Optional
[
torch
.
Tensor
],
seq_lens_cpu
:
Optional
[
torch
.
Tensor
],
):
):
if
forward_mode
.
is_decode_or_idle
():
backend
=
self
.
_select_backend
(
forward_mode
)
self
.
decode_backend
.
init_forward_metadata_replay_cuda_graph
(
backend
.
init_forward_metadata_replay_cuda_graph
(
bs
,
req_pool_indices
,
seq_lens
,
seq_lens_sum
,
encoder_lens
,
forward_mode
,
spec_info
,
seq_lens_cpu
,
)
else
:
self
.
prefill_backend
.
init_forward_metadata_replay_cuda_graph
(
bs
,
bs
,
req_pool_indices
,
req_pool_indices
,
seq_lens
,
seq_lens
,
...
@@ -127,6 +134,7 @@ class HybridAttnBackend(AttentionBackend):
...
@@ -127,6 +134,7 @@ class HybridAttnBackend(AttentionBackend):
save_kv_cache
:
bool
=
True
,
save_kv_cache
:
bool
=
True
,
**
kwargs
,
**
kwargs
,
):
):
return
self
.
prefill_backend
.
forward_extend
(
backend
=
self
.
_select_backend
(
forward_batch
.
forward_mode
)
return
backend
.
forward_extend
(
q
,
k
,
v
,
layer
,
forward_batch
,
save_kv_cache
,
**
kwargs
q
,
k
,
v
,
layer
,
forward_batch
,
save_kv_cache
,
**
kwargs
)
)
python/sglang/srt/managers/schedule_batch.py
View file @
8c5930f0
...
@@ -98,6 +98,7 @@ GLOBAL_SERVER_ARGS_KEYS = [
...
@@ -98,6 +98,7 @@ GLOBAL_SERVER_ARGS_KEYS = [
"sampling_backend"
,
"sampling_backend"
,
"speculative_accept_threshold_single"
,
"speculative_accept_threshold_single"
,
"speculative_accept_threshold_acc"
,
"speculative_accept_threshold_acc"
,
"speculative_attention_backend"
,
"torchao_config"
,
"torchao_config"
,
"triton_attention_reduce_in_fp32"
,
"triton_attention_reduce_in_fp32"
,
"num_reserved_decode_tokens"
,
"num_reserved_decode_tokens"
,
...
...
python/sglang/srt/models/deepseek_v2.py
View file @
8c5930f0
...
@@ -1045,6 +1045,15 @@ class DeepseekV2AttentionMLA(nn.Module):
...
@@ -1045,6 +1045,15 @@ class DeepseekV2AttentionMLA(nn.Module):
# Determine attention backend used by current forward batch
# Determine attention backend used by current forward batch
if
forward_batch
.
forward_mode
.
is_decode_or_idle
():
if
forward_batch
.
forward_mode
.
is_decode_or_idle
():
attention_backend
=
global_server_args_dict
[
"decode_attention_backend"
]
attention_backend
=
global_server_args_dict
[
"decode_attention_backend"
]
elif
(
forward_batch
.
forward_mode
.
is_target_verify
()
or
forward_batch
.
forward_mode
.
is_draft_extend
()
):
# Use the specified backend for speculative operations (both verify and draft extend)
if
global_server_args_dict
[
"speculative_attention_backend"
]
==
"decode"
:
attention_backend
=
global_server_args_dict
[
"decode_attention_backend"
]
else
:
# default to prefill
attention_backend
=
global_server_args_dict
[
"prefill_attention_backend"
]
else
:
else
:
attention_backend
=
global_server_args_dict
[
"prefill_attention_backend"
]
attention_backend
=
global_server_args_dict
[
"prefill_attention_backend"
]
self
.
current_attention_backend
=
attention_backend
self
.
current_attention_backend
=
attention_backend
...
...
python/sglang/srt/server_args.py
View file @
8c5930f0
...
@@ -262,6 +262,7 @@ class ServerArgs:
...
@@ -262,6 +262,7 @@ class ServerArgs:
speculative_accept_threshold_single
:
float
=
1.0
speculative_accept_threshold_single
:
float
=
1.0
speculative_accept_threshold_acc
:
float
=
1.0
speculative_accept_threshold_acc
:
float
=
1.0
speculative_token_map
:
Optional
[
str
]
=
None
speculative_token_map
:
Optional
[
str
]
=
None
speculative_attention_backend
:
str
=
"prefill"
# Expert parallelism
# Expert parallelism
ep_size
:
int
=
1
ep_size
:
int
=
1
...
@@ -1561,6 +1562,13 @@ class ServerArgs:
...
@@ -1561,6 +1562,13 @@ class ServerArgs:
help
=
"The path of the draft model's small vocab table."
,
help
=
"The path of the draft model's small vocab table."
,
default
=
ServerArgs
.
speculative_token_map
,
default
=
ServerArgs
.
speculative_token_map
,
)
)
parser
.
add_argument
(
"--speculative-attention-backend"
,
type
=
str
,
choices
=
[
"prefill"
,
"decode"
],
help
=
"Attention backend to use for speculative decoding operations (both target verify and draft extend). 'prefill' (default) or 'decode'."
,
default
=
ServerArgs
.
speculative_attention_backend
,
)
# Expert parallelism
# Expert parallelism
parser
.
add_argument
(
parser
.
add_argument
(
...
...
python/sglang/srt/speculative/eagle_worker.py
View file @
8c5930f0
...
@@ -191,7 +191,7 @@ class EAGLEWorker(TpModelWorker):
...
@@ -191,7 +191,7 @@ class EAGLEWorker(TpModelWorker):
# Initialize decode attention backend
# Initialize decode attention backend
self
.
draft_attn_backend
=
self
.
_create_decode_backend
()
self
.
draft_attn_backend
=
self
.
_create_decode_backend
()
# Initialize
prefill
attention backend
# Initialize
draft extend
attention backend
(respects speculative_attention_backend setting)
self
.
draft_extend_attn_backend
=
self
.
_create_draft_extend_backend
()
self
.
draft_extend_attn_backend
=
self
.
_create_draft_extend_backend
()
self
.
draft_model_runner
.
draft_attn_backend
=
self
.
draft_attn_backend
self
.
draft_model_runner
.
draft_attn_backend
=
self
.
draft_attn_backend
...
@@ -234,11 +234,15 @@ class EAGLEWorker(TpModelWorker):
...
@@ -234,11 +234,15 @@ class EAGLEWorker(TpModelWorker):
"trtllm_mha"
:
self
.
_create_trtllm_mha_prefill_backend
,
"trtllm_mha"
:
self
.
_create_trtllm_mha_prefill_backend
,
"trtllm_mla"
:
self
.
_create_trtllm_mla_prefill_backend
,
"trtllm_mla"
:
self
.
_create_trtllm_mla_prefill_backend
,
}
}
backend_name
=
(
"decode_attention_backend"
if
self
.
server_args
.
speculative_attention_backend
==
"decode"
else
"prefill_attention_backend"
)
return
self
.
_create_backend
(
return
self
.
_create_backend
(
"prefill_attention_
backend
"
,
backend
_name
,
backend_map
,
backend_map
,
"EAGLE is not supported in
prefill
attention backend {backend_type}"
,
"EAGLE is not supported in attention backend {backend_type}"
,
)
)
def
_create_flashinfer_decode_backend
(
self
):
def
_create_flashinfer_decode_backend
(
self
):
...
...
test/srt/test_hybrid_attn_backend.py
View file @
8c5930f0
...
@@ -132,5 +132,51 @@ class TestHybridAttnBackendSpeculativeDecoding(TestHybridAttnBackendBase):
...
@@ -132,5 +132,51 @@ class TestHybridAttnBackendSpeculativeDecoding(TestHybridAttnBackendBase):
]
]
class
TestHybridAttnBackendSpeculativeDecodingPrefillBackend
(
TestHybridAttnBackendBase
):
speculative_decode
=
True
# This eagle test uses a very small model, so the accuracy is low.
accuracy_threshold
=
0.2
@
classmethod
def
get_server_args
(
cls
):
return
DEFAULT_SERVER_ARGS
+
[
"--speculative-algorithm"
,
"EAGLE"
,
"--speculative-draft"
,
DEFAULT_EAGLE_DRAFT_MODEL_FOR_TEST
,
"--speculative-num-steps"
,
"3"
,
"--speculative-eagle-topk"
,
"2"
,
"--speculative-num-draft-tokens"
,
"4"
,
"--speculative-attention-backend"
,
"prefill"
,
]
class
TestHybridAttnBackendSpeculativeDecodingDecodeBackend
(
TestHybridAttnBackendBase
):
speculative_decode
=
True
# This eagle test uses a very small model, so the accuracy is low.
accuracy_threshold
=
0.2
@
classmethod
def
get_server_args
(
cls
):
return
DEFAULT_SERVER_ARGS
+
[
"--speculative-algorithm"
,
"EAGLE"
,
"--speculative-draft"
,
DEFAULT_EAGLE_DRAFT_MODEL_FOR_TEST
,
"--speculative-num-steps"
,
"3"
,
"--speculative-eagle-topk"
,
"2"
,
"--speculative-num-draft-tokens"
,
"4"
,
"--speculative-attention-backend"
,
"decode"
,
]
if
__name__
==
"__main__"
:
if
__name__
==
"__main__"
:
unittest
.
main
()
unittest
.
main
()
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment