sglang · Commit 4a4772ae

Support speculative decoding in hybrid attention backend (#9573)

Unverified commit 4a4772ae, authored Aug 28, 2025 by Qiaolin Yu, committed via GitHub on Aug 28, 2025.
Parent: c3779233
Showing 3 changed files with 83 additions and 26 deletions (+83, -26):
  python/sglang/srt/layers/attention/hybrid_attn_backend.py   +53  -21
  python/sglang/srt/model_executor/model_runner.py              +1   -3
  test/srt/test_hybrid_attn_backend.py                          +29   -2
python/sglang/srt/layers/attention/hybrid_attn_backend.py (view file @ 4a4772ae)

@@ -5,6 +5,7 @@ import torch
 from sglang.srt.layers.attention.base_attn_backend import AttentionBackend
 from sglang.srt.layers.radix_attention import RadixAttention
 from sglang.srt.model_executor.forward_batch_info import ForwardBatch, ForwardMode
+from sglang.srt.model_executor.model_runner import ModelRunner
 from sglang.srt.speculative.eagle_utils import EagleDraftInput, EagleVerifyInput
@@ -12,19 +13,27 @@ class HybridAttnBackend(AttentionBackend):
     """Support different backends for prefill and decode."""
 
     def __init__(
-        self, prefill_backend: AttentionBackend, decode_backend: AttentionBackend
+        self,
+        model_runner: ModelRunner,
+        prefill_backend: AttentionBackend,
+        decode_backend: AttentionBackend,
     ):
+        self.model_runner = model_runner
         self.prefill_backend = prefill_backend
         self.decode_backend = decode_backend
 
     def init_forward_metadata(self, forward_batch: ForwardBatch):
-        if forward_batch.forward_mode.is_decode():
+        if forward_batch.forward_mode.is_decode_or_idle():
             self.decode_backend.init_forward_metadata(forward_batch)
         else:
             self.prefill_backend.init_forward_metadata(forward_batch)
 
     def init_cuda_graph_state(self, max_bs: int, max_num_tokens: int):
         self.decode_backend.init_cuda_graph_state(max_bs, max_num_tokens)
+        if self.model_runner.server_args.speculative_algorithm is not None:
+            # When speculative decoding is enabled, we also need to initialize the
+            # prefill backend's cuda graph state to support target_verify.
+            self.prefill_backend.init_cuda_graph_state(max_bs, max_num_tokens)
 
     def init_forward_metadata_capture_cuda_graph(
         self,
@@ -36,6 +45,7 @@ class HybridAttnBackend(AttentionBackend):
         forward_mode: ForwardMode,
         spec_info: Optional[Union[EagleDraftInput, EagleVerifyInput]],
     ):
+        if forward_mode.is_decode_or_idle():
             self.decode_backend.init_forward_metadata_capture_cuda_graph(
                 bs,
                 num_tokens,
@@ -45,6 +55,16 @@ class HybridAttnBackend(AttentionBackend):
                 forward_mode,
                 spec_info,
             )
+        else:
+            self.prefill_backend.init_forward_metadata_capture_cuda_graph(
+                bs,
+                num_tokens,
+                req_pool_indices,
+                seq_lens,
+                encoder_lens,
+                forward_mode,
+                spec_info,
+            )
 
     def init_forward_metadata_replay_cuda_graph(
         self,
@@ -57,6 +77,7 @@ class HybridAttnBackend(AttentionBackend):
         spec_info: Optional[Union[EagleDraftInput, EagleVerifyInput]],
         seq_lens_cpu: Optional[torch.Tensor],
     ):
+        if forward_mode.is_decode_or_idle():
             self.decode_backend.init_forward_metadata_replay_cuda_graph(
                 bs,
                 req_pool_indices,
@@ -67,6 +88,17 @@ class HybridAttnBackend(AttentionBackend):
                 spec_info,
                 seq_lens_cpu,
             )
+        else:
+            self.prefill_backend.init_forward_metadata_replay_cuda_graph(
+                bs,
+                req_pool_indices,
+                seq_lens,
+                seq_lens_sum,
+                encoder_lens,
+                forward_mode,
+                spec_info,
+                seq_lens_cpu,
+            )
 
     def get_cuda_graph_seq_len_fill_value(self):
         return self.decode_backend.get_cuda_graph_seq_len_fill_value()
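The routing rule introduced above: decode and idle batches go to the decode backend, while every other mode (prefill/extend, and the target_verify batches produced by speculative decoding) falls through to the prefill backend, which is why the prefill backend now also needs CUDA graph state whenever a speculative algorithm is configured. A minimal standalone sketch of that dispatch, using stub classes rather than the real sglang backends:

# Toy stand-ins; the real classes are AttentionBackend subclasses in sglang.
class _StubBackend:
    def __init__(self, name):
        self.name = name

    def init_forward_metadata(self, batch):
        print(f"{self.name} backend handles mode={batch['mode']}")


class _HybridDispatch:
    """Mirrors HybridAttnBackend's routing: decode/idle -> decode backend, else prefill."""

    def __init__(self, prefill_backend, decode_backend):
        self.prefill_backend = prefill_backend
        self.decode_backend = decode_backend

    def init_forward_metadata(self, batch):
        if batch["mode"] in ("decode", "idle"):
            self.decode_backend.init_forward_metadata(batch)
        else:
            self.prefill_backend.init_forward_metadata(batch)


hybrid = _HybridDispatch(_StubBackend("prefill"), _StubBackend("decode"))
hybrid.init_forward_metadata({"mode": "decode"})         # routed to the decode backend
hybrid.init_forward_metadata({"mode": "target_verify"})  # routed to the prefill backend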
python/sglang/srt/model_executor/model_runner.py (view file @ 4a4772ae)

@@ -1440,14 +1440,12 @@ class ModelRunner:
             else self.server_args.attention_backend
         )
         if self.decode_attention_backend_str != self.prefill_attention_backend_str:
-            assert (
-                self.server_args.speculative_algorithm is None
-            ), "Currently HybridAttentionBackend does not support speculative decoding."
             from sglang.srt.layers.attention.hybrid_attn_backend import (
                 HybridAttnBackend,
             )
 
             attn_backend = HybridAttnBackend(
+                self,
                 decode_backend=self._get_attention_backend_from_str(
                     self.decode_attention_backend_str
                 ),
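Passing `self` (the ModelRunner) as the first constructor argument is what lets the hybrid backend read `server_args.speculative_algorithm` and decide whether the prefill backend also needs CUDA graph state, so the old assertion can be dropped. A rough, self-contained sketch of that gating, with SimpleNamespace standing in for the real ModelRunner and backends:

from types import SimpleNamespace


class _GraphBackend:
    def __init__(self, name):
        self.name = name

    def init_cuda_graph_state(self, max_bs, max_num_tokens):
        print(f"{self.name}: CUDA graph state (max_bs={max_bs}, max_num_tokens={max_num_tokens})")


def init_cuda_graph_state(model_runner, decode_backend, prefill_backend, max_bs, max_num_tokens):
    # Same shape as HybridAttnBackend.init_cuda_graph_state in the diff above.
    decode_backend.init_cuda_graph_state(max_bs, max_num_tokens)
    if model_runner.server_args.speculative_algorithm is not None:
        # target_verify batches run through the prefill backend, so it needs graph state too.
        prefill_backend.init_cuda_graph_state(max_bs, max_num_tokens)


runner = SimpleNamespace(server_args=SimpleNamespace(speculative_algorithm="EAGLE"))
init_cuda_graph_state(runner, _GraphBackend("decode"), _GraphBackend("prefill"), max_bs=32, max_num_tokens=4096)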
test/srt/test_hybrid_attn_backend.py (view file @ 4a4772ae)

@@ -7,6 +7,8 @@ import requests
 from sglang.srt.utils import get_device_sm, kill_process_tree
 from sglang.test.few_shot_gsm8k import run_eval as run_eval_few_shot_gsm8k
 from sglang.test.test_utils import (
+    DEFAULT_EAGLE_DRAFT_MODEL_FOR_TEST,
+    DEFAULT_EAGLE_TARGET_MODEL_FOR_TEST,
     DEFAULT_MODEL_NAME_FOR_TEST,
     DEFAULT_MODEL_NAME_FOR_TEST_MLA,
     DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH,
@@ -36,7 +38,7 @@ class TestHybridAttnBackendBase(CustomTestCase):
     base_url = DEFAULT_URL_FOR_TEST
     accuracy_threshold = 0.65  # derived tests need to override this
     speculative_decode = False
-    spec_decode_threshold = 1.0  # derived spec decoding tests need to override this
+    spec_decode_threshold = 2.2  # derived spec decoding tests need to override this
 
     @classmethod
     def get_server_args(cls):
@@ -49,8 +51,12 @@ class TestHybridAttnBackendBase(CustomTestCase):
         # please don't do this if you want to make your inference workload faster
         os.environ["SGL_JIT_DEEPGEMM_PRECOMPILE"] = "false"
         os.environ["SGL_ENABLE_JIT_DEEPGEMM"] = "false"
+        if cls.speculative_decode:
+            model = DEFAULT_EAGLE_TARGET_MODEL_FOR_TEST
+        else:
+            model = cls.model
         cls.process = popen_launch_server(
-            cls.model,
+            model,
             cls.base_url,
             timeout=DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH,
             other_args=cls.get_server_args(),
@@ -105,5 +111,26 @@ class TestHybridAttnBackendTorchCompile(TestHybridAttnBackendBase):
         return DEFAULT_SERVER_ARGS + ["--enable-torch-compile"]
 
 
+class TestHybridAttnBackendSpeculativeDecoding(TestHybridAttnBackendBase):
+    speculative_decode = True
+    # This eagle test uses a very small model, so the accuracy is low.
+    accuracy_threshold = 0.2
+
+    @classmethod
+    def get_server_args(cls):
+        return DEFAULT_SERVER_ARGS + [
+            "--speculative-algorithm",
+            "EAGLE",
+            "--speculative-draft",
+            DEFAULT_EAGLE_DRAFT_MODEL_FOR_TEST,
+            "--speculative-num-steps",
+            "3",
+            "--speculative-eagle-topk",
+            "2",
+            "--speculative-num-draft-tokens",
+            "4",
+        ]
+
+
 if __name__ == "__main__":
     unittest.main()
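For reference, the new TestHybridAttnBackendSpeculativeDecoding case boils down to launching the server with EAGLE flags on top of the suite's DEFAULT_SERVER_ARGS (whose contents are not shown in this diff). A hedged sketch of the equivalent launch command, with placeholder model paths since the DEFAULT_EAGLE_* constants are defined elsewhere in sglang.test.test_utils:

# Placeholders; the real values come from sglang.test.test_utils.
eagle_target_model = "<DEFAULT_EAGLE_TARGET_MODEL_FOR_TEST>"
eagle_draft_model = "<DEFAULT_EAGLE_DRAFT_MODEL_FOR_TEST>"
default_server_args = []  # the suite's DEFAULT_SERVER_ARGS, not visible in this diff

cmd = (
    ["python3", "-m", "sglang.launch_server", "--model-path", eagle_target_model]
    + default_server_args
    + [
        "--speculative-algorithm", "EAGLE",
        "--speculative-draft", eagle_draft_model,
        "--speculative-num-steps", "3",
        "--speculative-eagle-topk", "2",
        "--speculative-num-draft-tokens", "4",
    ]
)
print(" ".join(cmd))  # the test itself wraps this launch in popen_launch_server(...)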