zhaoyu6 / sglang · Commit 4a4772ae (unverified)
Authored Aug 28, 2025 by Qiaolin Yu; committed by GitHub on Aug 28, 2025
Support speculative decoding in hybrid attention backend (#9573)
Parent: c3779233

Showing 3 changed files with 83 additions and 26 deletions.
python/sglang/srt/layers/attention/hybrid_attn_backend.py  (+53, -21)
python/sglang/srt/model_executor/model_runner.py  (+1, -3)
test/srt/test_hybrid_attn_backend.py  (+29, -2)
python/sglang/srt/layers/attention/hybrid_attn_backend.py (view file @ 4a4772ae)

@@ -5,6 +5,7 @@ import torch
 from sglang.srt.layers.attention.base_attn_backend import AttentionBackend
 from sglang.srt.layers.radix_attention import RadixAttention
 from sglang.srt.model_executor.forward_batch_info import ForwardBatch, ForwardMode
+from sglang.srt.model_executor.model_runner import ModelRunner
 from sglang.srt.speculative.eagle_utils import EagleDraftInput, EagleVerifyInput
@@ -12,19 +13,27 @@ class HybridAttnBackend(AttentionBackend):
     """Support different backends for prefill and decode."""
 
     def __init__(
-        self, prefill_backend: AttentionBackend, decode_backend: AttentionBackend
+        self,
+        model_runner: ModelRunner,
+        prefill_backend: AttentionBackend,
+        decode_backend: AttentionBackend,
     ):
+        self.model_runner = model_runner
         self.prefill_backend = prefill_backend
         self.decode_backend = decode_backend
 
     def init_forward_metadata(self, forward_batch: ForwardBatch):
-        if forward_batch.forward_mode.is_decode():
+        if forward_batch.forward_mode.is_decode_or_idle():
             self.decode_backend.init_forward_metadata(forward_batch)
         else:
             self.prefill_backend.init_forward_metadata(forward_batch)
 
     def init_cuda_graph_state(self, max_bs: int, max_num_tokens: int):
         self.decode_backend.init_cuda_graph_state(max_bs, max_num_tokens)
+        if self.model_runner.server_args.speculative_algorithm is not None:
+            # When speculative decoding is enabled, we also need to initialize the
+            # prefill backend's cuda graph state to support target_verify.
+            self.prefill_backend.init_cuda_graph_state(max_bs, max_num_tokens)
 
     def init_forward_metadata_capture_cuda_graph(
         self,
@@ -36,6 +45,7 @@ class HybridAttnBackend(AttentionBackend):
         forward_mode: ForwardMode,
         spec_info: Optional[Union[EagleDraftInput, EagleVerifyInput]],
     ):
+        if forward_mode.is_decode_or_idle():
             self.decode_backend.init_forward_metadata_capture_cuda_graph(
                 bs,
                 num_tokens,
@@ -45,6 +55,16 @@ class HybridAttnBackend(AttentionBackend):
                 forward_mode,
                 spec_info,
             )
+        else:
+            self.prefill_backend.init_forward_metadata_capture_cuda_graph(
+                bs,
+                num_tokens,
+                req_pool_indices,
+                seq_lens,
+                encoder_lens,
+                forward_mode,
+                spec_info,
+            )
 
     def init_forward_metadata_replay_cuda_graph(
         self,
@@ -57,6 +77,7 @@ class HybridAttnBackend(AttentionBackend):
         spec_info: Optional[Union[EagleDraftInput, EagleVerifyInput]],
         seq_lens_cpu: Optional[torch.Tensor],
     ):
+        if forward_mode.is_decode_or_idle():
             self.decode_backend.init_forward_metadata_replay_cuda_graph(
                 bs,
                 req_pool_indices,
@@ -67,6 +88,17 @@ class HybridAttnBackend(AttentionBackend):
                 spec_info,
                 seq_lens_cpu,
             )
+        else:
+            self.prefill_backend.init_forward_metadata_replay_cuda_graph(
+                bs,
+                req_pool_indices,
+                seq_lens,
+                seq_lens_sum,
+                encoder_lens,
+                forward_mode,
+                spec_info,
+                seq_lens_cpu,
+            )
 
     def get_cuda_graph_seq_len_fill_value(self):
         return self.decode_backend.get_cuda_graph_seq_len_fill_value()
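The routing rule behind this file is easier to see in isolation: decode and idle batches go to the decode backend, and everything else, including the speculative target_verify pass, goes to the prefill backend, which is why the prefill backend's CUDA graph state now has to be initialized whenever a speculative algorithm is configured. Below is a minimal, self-contained sketch of that rule, not the sglang implementation; ForwardMode is reduced to an illustrative subset, and StubBackend and HybridDispatcher are made-up names.

from enum import Enum, auto


class ForwardMode(Enum):
    # Illustrative subset; the real sglang ForwardMode has more members.
    DECODE = auto()
    IDLE = auto()
    EXTEND = auto()          # normal prefill / extend
    TARGET_VERIFY = auto()   # speculative-decoding verification batch

    def is_decode_or_idle(self) -> bool:
        return self in (ForwardMode.DECODE, ForwardMode.IDLE)


class StubBackend:
    # Hypothetical stand-in for an AttentionBackend; only records CUDA graph init.
    def __init__(self, name: str) -> None:
        self.name = name
        self.cuda_graph_ready = False

    def init_cuda_graph_state(self, max_bs: int, max_num_tokens: int) -> None:
        self.cuda_graph_ready = True


class HybridDispatcher:
    # Toy mirror of HybridAttnBackend's routing rule after this commit.
    def __init__(self, prefill_backend, decode_backend, speculative: bool) -> None:
        self.prefill_backend = prefill_backend
        self.decode_backend = decode_backend
        self.speculative = speculative

    def backend_for(self, mode: ForwardMode):
        # Decode/idle batches use the decode backend; everything else,
        # including target_verify, uses the prefill backend.
        return self.decode_backend if mode.is_decode_or_idle() else self.prefill_backend

    def init_cuda_graph_state(self, max_bs: int, max_num_tokens: int) -> None:
        self.decode_backend.init_cuda_graph_state(max_bs, max_num_tokens)
        if self.speculative:
            # target_verify runs through the prefill backend inside CUDA graphs,
            # so its graph state must be prepared as well.
            self.prefill_backend.init_cuda_graph_state(max_bs, max_num_tokens)


if __name__ == "__main__":
    hybrid = HybridDispatcher(StubBackend("prefill"), StubBackend("decode"), speculative=True)
    hybrid.init_cuda_graph_state(max_bs=8, max_num_tokens=1024)
    assert hybrid.backend_for(ForwardMode.TARGET_VERIFY).name == "prefill"
    assert hybrid.backend_for(ForwardMode.DECODE).name == "decode"
    assert hybrid.prefill_backend.cuda_graph_ready and hybrid.decode_backend.cuda_graph_ready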
python/sglang/srt/model_executor/model_runner.py (view file @ 4a4772ae)

@@ -1440,14 +1440,12 @@ class ModelRunner:
             else self.server_args.attention_backend
         )
         if self.decode_attention_backend_str != self.prefill_attention_backend_str:
-            assert (
-                self.server_args.speculative_algorithm is None
-            ), "Currently HybridAttentionBackend does not support speculative decoding."
             from sglang.srt.layers.attention.hybrid_attn_backend import (
                 HybridAttnBackend,
             )
 
             attn_backend = HybridAttnBackend(
+                self,
                 decode_backend=self._get_attention_backend_from_str(
                     self.decode_attention_backend_str
                 ),
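With the assertion removed, ModelRunner now passes itself as the new first constructor argument so the backend can read server_args.speculative_algorithm. The hunk above is truncated after the decode_backend keyword; the sketch below completes the call from the __init__ signature in the first file, so treat the prefill_backend line as an inference rather than a quote from this commit.

# Hedged sketch of the constructor call inside ModelRunner after this commit.
# The prefill_backend keyword is inferred from HybridAttnBackend.__init__ above,
# not copied from the visible hunk.
attn_backend = HybridAttnBackend(
    self,  # the ModelRunner; lets the backend check server_args.speculative_algorithm
    prefill_backend=self._get_attention_backend_from_str(
        self.prefill_attention_backend_str
    ),
    decode_backend=self._get_attention_backend_from_str(
        self.decode_attention_backend_str
    ),
)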
test/srt/test_hybrid_attn_backend.py (view file @ 4a4772ae)

@@ -7,6 +7,8 @@ import requests
 from sglang.srt.utils import get_device_sm, kill_process_tree
 from sglang.test.few_shot_gsm8k import run_eval as run_eval_few_shot_gsm8k
 from sglang.test.test_utils import (
+    DEFAULT_EAGLE_DRAFT_MODEL_FOR_TEST,
+    DEFAULT_EAGLE_TARGET_MODEL_FOR_TEST,
     DEFAULT_MODEL_NAME_FOR_TEST,
     DEFAULT_MODEL_NAME_FOR_TEST_MLA,
     DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH,
@@ -36,7 +38,7 @@ class TestHybridAttnBackendBase(CustomTestCase):
     base_url = DEFAULT_URL_FOR_TEST
     accuracy_threshold = 0.65  # derived tests need to override this
     speculative_decode = False
-    spec_decode_threshold = 1.0  # derived spec decoding tests need to override this
+    spec_decode_threshold = 2.2  # derived spec decoding tests need to override this
 
     @classmethod
     def get_server_args(cls):
@@ -49,8 +51,12 @@ class TestHybridAttnBackendBase(CustomTestCase):
         # please don't do this if you want to make your inference workload faster
         os.environ["SGL_JIT_DEEPGEMM_PRECOMPILE"] = "false"
         os.environ["SGL_ENABLE_JIT_DEEPGEMM"] = "false"
+        if cls.speculative_decode:
+            model = DEFAULT_EAGLE_TARGET_MODEL_FOR_TEST
+        else:
+            model = cls.model
         cls.process = popen_launch_server(
-            cls.model,
+            model,
             cls.base_url,
             timeout=DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH,
             other_args=cls.get_server_args(),
@@ -105,5 +111,26 @@ class TestHybridAttnBackendTorchCompile(TestHybridAttnBackendBase):
         return DEFAULT_SERVER_ARGS + ["--enable-torch-compile"]
 
 
+class TestHybridAttnBackendSpeculativeDecoding(TestHybridAttnBackendBase):
+    speculative_decode = True
+    # This eagle test uses a very small model, so the accuracy is low.
+    accuracy_threshold = 0.2
+
+    @classmethod
+    def get_server_args(cls):
+        return DEFAULT_SERVER_ARGS + [
+            "--speculative-algorithm",
+            "EAGLE",
+            "--speculative-draft",
+            DEFAULT_EAGLE_DRAFT_MODEL_FOR_TEST,
+            "--speculative-num-steps",
+            "3",
+            "--speculative-eagle-topk",
+            "2",
+            "--speculative-num-draft-tokens",
+            "4",
+        ]
+
+
 if __name__ == "__main__":
     unittest.main()
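For a quick manual check outside the unittest harness, the same speculative flags can be handed to the standard launcher. A hedged usage sketch: python -m sglang.launch_server and --model-path are the usual sglang entry point and flag, the model paths are placeholders rather than the test constants, and the base server args that select distinct prefill and decode attention backends (the test's DEFAULT_SERVER_ARGS, not shown in this hunk) still have to be supplied separately.

# Hedged usage sketch; model paths are placeholders, and the prefill/decode
# attention-backend flags from DEFAULT_SERVER_ARGS must be added separately.
import shlex

speculative_flags = [
    "--speculative-algorithm", "EAGLE",
    "--speculative-draft", "path/to/eagle-draft-model",  # placeholder
    "--speculative-num-steps", "3",
    "--speculative-eagle-topk", "2",
    "--speculative-num-draft-tokens", "4",
]
cmd = [
    "python", "-m", "sglang.launch_server",
    "--model-path", "path/to/target-model",  # placeholder
] + speculative_flags
print(" ".join(shlex.quote(part) for part in cmd))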