Unverified commit 4a4772ae, authored by Qiaolin Yu and committed by GitHub
Browse files

Support speculative decoding in hybrid attention backend (#9573)

parent c3779233
...@@ -5,6 +5,7 @@ import torch ...@@ -5,6 +5,7 @@ import torch
from sglang.srt.layers.attention.base_attn_backend import AttentionBackend from sglang.srt.layers.attention.base_attn_backend import AttentionBackend
from sglang.srt.layers.radix_attention import RadixAttention from sglang.srt.layers.radix_attention import RadixAttention
from sglang.srt.model_executor.forward_batch_info import ForwardBatch, ForwardMode from sglang.srt.model_executor.forward_batch_info import ForwardBatch, ForwardMode
from sglang.srt.model_executor.model_runner import ModelRunner
from sglang.srt.speculative.eagle_utils import EagleDraftInput, EagleVerifyInput from sglang.srt.speculative.eagle_utils import EagleDraftInput, EagleVerifyInput
...@@ -12,19 +13,27 @@ class HybridAttnBackend(AttentionBackend): ...@@ -12,19 +13,27 @@ class HybridAttnBackend(AttentionBackend):
"""Support different backends for prefill and decode.""" """Support different backends for prefill and decode."""
def __init__(
    self,
    model_runner: ModelRunner,
    prefill_backend: AttentionBackend,
    decode_backend: AttentionBackend,
):
    """Combine two attention backends: one for prefill, one for decode.

    Args:
        model_runner: Owning ModelRunner; consulted for server args
            (e.g. whether a speculative algorithm is configured).
        prefill_backend: Backend used for extend/prefill-style batches.
        decode_backend: Backend used for decode/idle batches.
    """
    self.model_runner = model_runner
    self.prefill_backend = prefill_backend
    self.decode_backend = decode_backend
def init_forward_metadata(self, forward_batch: ForwardBatch):
    """Route forward-metadata initialization to the backend for this mode.

    Decode and idle batches go to the decode backend; every other mode
    (prefill/extend, verify, ...) goes to the prefill backend.
    """
    if forward_batch.forward_mode.is_decode_or_idle():
        backend = self.decode_backend
    else:
        backend = self.prefill_backend
    backend.init_forward_metadata(forward_batch)
def init_cuda_graph_state(self, max_bs: int, max_num_tokens: int):
    """Allocate CUDA-graph state on every backend that will capture graphs.

    The decode backend always needs graph state. When a speculative
    algorithm is configured, the prefill backend's state is initialized
    too, so it can capture graphs for target_verify.
    """
    self.decode_backend.init_cuda_graph_state(max_bs, max_num_tokens)
    spec_algo = self.model_runner.server_args.speculative_algorithm
    if spec_algo is not None:
        self.prefill_backend.init_cuda_graph_state(max_bs, max_num_tokens)
def init_forward_metadata_capture_cuda_graph(
    self,
    bs: int,
    num_tokens: int,
    req_pool_indices: torch.Tensor,
    seq_lens: torch.Tensor,
    encoder_lens: Optional[torch.Tensor],
    forward_mode: ForwardMode,
    spec_info: Optional[Union[EagleDraftInput, EagleVerifyInput]],
):
    """Capture CUDA-graph metadata on the backend matching *forward_mode*.

    Decode/idle captures go to the decode backend; other modes
    (e.g. target_verify under speculative decoding) go to the
    prefill backend.
    """
    if forward_mode.is_decode_or_idle():
        backend = self.decode_backend
    else:
        backend = self.prefill_backend
    backend.init_forward_metadata_capture_cuda_graph(
        bs,
        num_tokens,
        req_pool_indices,
        seq_lens,
        encoder_lens,
        forward_mode,
        spec_info,
    )
def init_forward_metadata_replay_cuda_graph(
    self,
    bs: int,
    req_pool_indices: torch.Tensor,
    seq_lens: torch.Tensor,
    seq_lens_sum: int,
    encoder_lens: Optional[torch.Tensor],
    forward_mode: ForwardMode,
    spec_info: Optional[Union[EagleDraftInput, EagleVerifyInput]],
    seq_lens_cpu: Optional[torch.Tensor],
):
    """Prepare CUDA-graph replay metadata on the backend for *forward_mode*.

    Mirrors the capture path: decode/idle replays use the decode
    backend, everything else uses the prefill backend.
    """
    if forward_mode.is_decode_or_idle():
        backend = self.decode_backend
    else:
        backend = self.prefill_backend
    backend.init_forward_metadata_replay_cuda_graph(
        bs,
        req_pool_indices,
        seq_lens,
        seq_lens_sum,
        encoder_lens,
        forward_mode,
        spec_info,
        seq_lens_cpu,
    )
def get_cuda_graph_seq_len_fill_value(self):
    """Return the decode backend's seq-len fill value for CUDA graphs."""
    fill_value = self.decode_backend.get_cuda_graph_seq_len_fill_value()
    return fill_value
......
...@@ -1440,14 +1440,12 @@ class ModelRunner: ...@@ -1440,14 +1440,12 @@ class ModelRunner:
else self.server_args.attention_backend else self.server_args.attention_backend
) )
if self.decode_attention_backend_str != self.prefill_attention_backend_str: if self.decode_attention_backend_str != self.prefill_attention_backend_str:
assert (
self.server_args.speculative_algorithm is None
), "Currently HybridAttentionBackend does not support speculative decoding."
from sglang.srt.layers.attention.hybrid_attn_backend import ( from sglang.srt.layers.attention.hybrid_attn_backend import (
HybridAttnBackend, HybridAttnBackend,
) )
attn_backend = HybridAttnBackend( attn_backend = HybridAttnBackend(
self,
decode_backend=self._get_attention_backend_from_str( decode_backend=self._get_attention_backend_from_str(
self.decode_attention_backend_str self.decode_attention_backend_str
), ),
......
...@@ -7,6 +7,8 @@ import requests ...@@ -7,6 +7,8 @@ import requests
from sglang.srt.utils import get_device_sm, kill_process_tree from sglang.srt.utils import get_device_sm, kill_process_tree
from sglang.test.few_shot_gsm8k import run_eval as run_eval_few_shot_gsm8k from sglang.test.few_shot_gsm8k import run_eval as run_eval_few_shot_gsm8k
from sglang.test.test_utils import ( from sglang.test.test_utils import (
DEFAULT_EAGLE_DRAFT_MODEL_FOR_TEST,
DEFAULT_EAGLE_TARGET_MODEL_FOR_TEST,
DEFAULT_MODEL_NAME_FOR_TEST, DEFAULT_MODEL_NAME_FOR_TEST,
DEFAULT_MODEL_NAME_FOR_TEST_MLA, DEFAULT_MODEL_NAME_FOR_TEST_MLA,
DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH, DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH,
...@@ -36,7 +38,7 @@ class TestHybridAttnBackendBase(CustomTestCase): ...@@ -36,7 +38,7 @@ class TestHybridAttnBackendBase(CustomTestCase):
base_url = DEFAULT_URL_FOR_TEST base_url = DEFAULT_URL_FOR_TEST
accuracy_threshold = 0.65 # derived tests need to override this accuracy_threshold = 0.65 # derived tests need to override this
speculative_decode = False speculative_decode = False
spec_decode_threshold = 1.0 # derived spec decoding tests need to override this spec_decode_threshold = 2.2 # derived spec decoding tests need to override this
@classmethod @classmethod
def get_server_args(cls): def get_server_args(cls):
...@@ -49,8 +51,12 @@ class TestHybridAttnBackendBase(CustomTestCase): ...@@ -49,8 +51,12 @@ class TestHybridAttnBackendBase(CustomTestCase):
# please don't do this if you want to make your inference workload faster # please don't do this if you want to make your inference workload faster
os.environ["SGL_JIT_DEEPGEMM_PRECOMPILE"] = "false" os.environ["SGL_JIT_DEEPGEMM_PRECOMPILE"] = "false"
os.environ["SGL_ENABLE_JIT_DEEPGEMM"] = "false" os.environ["SGL_ENABLE_JIT_DEEPGEMM"] = "false"
if cls.speculative_decode:
model = DEFAULT_EAGLE_TARGET_MODEL_FOR_TEST
else:
model = cls.model
cls.process = popen_launch_server( cls.process = popen_launch_server(
cls.model, model,
cls.base_url, cls.base_url,
timeout=DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH, timeout=DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH,
other_args=cls.get_server_args(), other_args=cls.get_server_args(),
...@@ -105,5 +111,26 @@ class TestHybridAttnBackendTorchCompile(TestHybridAttnBackendBase): ...@@ -105,5 +111,26 @@ class TestHybridAttnBackendTorchCompile(TestHybridAttnBackendBase):
return DEFAULT_SERVER_ARGS + ["--enable-torch-compile"] return DEFAULT_SERVER_ARGS + ["--enable-torch-compile"]
class TestHybridAttnBackendSpeculativeDecoding(TestHybridAttnBackendBase):
    """EAGLE speculative decoding on top of the hybrid attention backend."""

    speculative_decode = True
    # This eagle test uses a very small model, so the accuracy is low.
    accuracy_threshold = 0.2

    @classmethod
    def get_server_args(cls):
        """Extend the default server args with EAGLE speculative flags."""
        spec_args = [
            "--speculative-algorithm",
            "EAGLE",
            "--speculative-draft",
            DEFAULT_EAGLE_DRAFT_MODEL_FOR_TEST,
            "--speculative-num-steps",
            "3",
            "--speculative-eagle-topk",
            "2",
            "--speculative-num-draft-tokens",
            "4",
        ]
        return DEFAULT_SERVER_ARGS + spec_args
# Allow running this test module directly with `python <file>`.
if __name__ == "__main__":
    unittest.main()
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment