Unverified Commit 8c5930f0 authored by cicirori, committed by GitHub

Add speculator attention backend switch (#9981)

parent 3b99f23c
@@ -22,17 +22,45 @@ class HybridAttnBackend(AttentionBackend):
         self.prefill_backend = prefill_backend
         self.decode_backend = decode_backend
 
-    def init_forward_metadata(self, forward_batch: ForwardBatch):
-        if forward_batch.forward_mode.is_decode_or_idle():
-            self.decode_backend.init_forward_metadata(forward_batch)
+    def _select_backend(self, forward_mode: ForwardMode) -> AttentionBackend:
+        """
+        Select the appropriate attention backend based on the forward mode.
+
+        Args:
+            forward_mode: The current forward mode indicating the operation type
+
+        Returns:
+            The selected attention backend (prefill or decode)
+
+        Note:
+            - decode_or_idle: Always uses decode backend
+            - target_verify or draft_extend: Uses decode backend if speculative_attention_backend is "decode", otherwise prefill backend
+            - prefill: Always uses prefill backend
+        """
+        if forward_mode.is_decode_or_idle():
+            return self.decode_backend
+        elif forward_mode.is_target_verify() or forward_mode.is_draft_extend():
+            return (
+                self.decode_backend
+                if self.model_runner.server_args.speculative_attention_backend
+                == "decode"
+                else self.prefill_backend
+            )
         else:
-            self.prefill_backend.init_forward_metadata(forward_batch)
+            return self.prefill_backend
+
+    def init_forward_metadata(self, forward_batch: ForwardBatch):
+        backend = self._select_backend(forward_batch.forward_mode)
+        backend.init_forward_metadata(forward_batch)
 
     def init_cuda_graph_state(self, max_bs: int, max_num_tokens: int):
         self.decode_backend.init_cuda_graph_state(max_bs, max_num_tokens)
-        if self.model_runner.server_args.speculative_algorithm is not None:
-            # When speculative decoding is enabled, we also need to initialize the
-            # prefill backend's cuda graph state to support target_verify.
+        if (
+            self.model_runner.server_args.speculative_algorithm is not None
+            and self.model_runner.server_args.speculative_attention_backend == "prefill"
+        ):
+            # When speculative decoding is enabled, we need to initialize the backend
+            # that will be used for target_verify.
             self.prefill_backend.init_cuda_graph_state(max_bs, max_num_tokens)
 
     def init_forward_metadata_capture_cuda_graph(
@@ -45,26 +73,16 @@ class HybridAttnBackend(AttentionBackend):
         forward_mode: ForwardMode,
         spec_info: Optional[Union[EagleDraftInput, EagleVerifyInput]],
     ):
-        if forward_mode.is_decode_or_idle():
-            self.decode_backend.init_forward_metadata_capture_cuda_graph(
-                bs,
-                num_tokens,
-                req_pool_indices,
-                seq_lens,
-                encoder_lens,
-                forward_mode,
-                spec_info,
-            )
-        else:
-            self.prefill_backend.init_forward_metadata_capture_cuda_graph(
-                bs,
-                num_tokens,
-                req_pool_indices,
-                seq_lens,
-                encoder_lens,
-                forward_mode,
-                spec_info,
-            )
+        backend = self._select_backend(forward_mode)
+        backend.init_forward_metadata_capture_cuda_graph(
+            bs,
+            num_tokens,
+            req_pool_indices,
+            seq_lens,
+            encoder_lens,
+            forward_mode,
+            spec_info,
+        )
 
     def init_forward_metadata_replay_cuda_graph(
         self,
@@ -77,28 +95,17 @@ class HybridAttnBackend(AttentionBackend):
         spec_info: Optional[Union[EagleDraftInput, EagleVerifyInput]],
         seq_lens_cpu: Optional[torch.Tensor],
     ):
-        if forward_mode.is_decode_or_idle():
-            self.decode_backend.init_forward_metadata_replay_cuda_graph(
-                bs,
-                req_pool_indices,
-                seq_lens,
-                seq_lens_sum,
-                encoder_lens,
-                forward_mode,
-                spec_info,
-                seq_lens_cpu,
-            )
-        else:
-            self.prefill_backend.init_forward_metadata_replay_cuda_graph(
-                bs,
-                req_pool_indices,
-                seq_lens,
-                seq_lens_sum,
-                encoder_lens,
-                forward_mode,
-                spec_info,
-                seq_lens_cpu,
-            )
+        backend = self._select_backend(forward_mode)
+        backend.init_forward_metadata_replay_cuda_graph(
+            bs,
+            req_pool_indices,
+            seq_lens,
+            seq_lens_sum,
+            encoder_lens,
+            forward_mode,
+            spec_info,
+            seq_lens_cpu,
+        )
 
     def get_cuda_graph_seq_len_fill_value(self):
         return self.decode_backend.get_cuda_graph_seq_len_fill_value()
@@ -127,6 +134,7 @@ class HybridAttnBackend(AttentionBackend):
         save_kv_cache: bool = True,
         **kwargs,
     ):
-        return self.prefill_backend.forward_extend(
+        backend = self._select_backend(forward_batch.forward_mode)
+        return backend.forward_extend(
             q, k, v, layer, forward_batch, save_kv_cache, **kwargs
         )
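To make the new dispatch rule easy to check at a glance, here is a minimal, self-contained sketch of the decision table that `_select_backend` implements. The `FakeMode` stub and the string return values below are illustrative stand-ins, not SGLang's real `ForwardMode` or backend objects.

from dataclasses import dataclass


@dataclass
class FakeMode:
    # Illustrative stand-in for ForwardMode; only the predicates used by
    # _select_backend are modeled.
    kind: str  # "decode", "target_verify", "draft_extend", or "prefill"

    def is_decode_or_idle(self) -> bool:
        return self.kind == "decode"

    def is_target_verify(self) -> bool:
        return self.kind == "target_verify"

    def is_draft_extend(self) -> bool:
        return self.kind == "draft_extend"


def select_backend(mode: FakeMode, speculative_attention_backend: str) -> str:
    # Same decision table as the new _select_backend, with strings standing
    # in for the prefill/decode backend objects.
    if mode.is_decode_or_idle():
        return "decode_backend"
    if mode.is_target_verify() or mode.is_draft_extend():
        return (
            "decode_backend"
            if speculative_attention_backend == "decode"
            else "prefill_backend"
        )
    return "prefill_backend"


assert select_backend(FakeMode("decode"), "prefill") == "decode_backend"
assert select_backend(FakeMode("target_verify"), "prefill") == "prefill_backend"
assert select_backend(FakeMode("target_verify"), "decode") == "decode_backend"
assert select_backend(FakeMode("draft_extend"), "decode") == "decode_backend"
assert select_backend(FakeMode("prefill"), "decode") == "prefill_backend"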
@@ -98,6 +98,7 @@ GLOBAL_SERVER_ARGS_KEYS = [
     "sampling_backend",
     "speculative_accept_threshold_single",
     "speculative_accept_threshold_acc",
+    "speculative_attention_backend",
     "torchao_config",
     "triton_attention_reduce_in_fp32",
     "num_reserved_decode_tokens",
@@ -1045,6 +1045,15 @@ class DeepseekV2AttentionMLA(nn.Module):
         # Determine attention backend used by current forward batch
         if forward_batch.forward_mode.is_decode_or_idle():
             attention_backend = global_server_args_dict["decode_attention_backend"]
+        elif (
+            forward_batch.forward_mode.is_target_verify()
+            or forward_batch.forward_mode.is_draft_extend()
+        ):
+            # Use the specified backend for speculative operations (both verify and draft extend)
+            if global_server_args_dict["speculative_attention_backend"] == "decode":
+                attention_backend = global_server_args_dict["decode_attention_backend"]
+            else:  # default to prefill
+                attention_backend = global_server_args_dict["prefill_attention_backend"]
         else:
             attention_backend = global_server_args_dict["prefill_attention_backend"]
         self.current_attention_backend = attention_backend
@@ -262,6 +262,7 @@ class ServerArgs:
     speculative_accept_threshold_single: float = 1.0
     speculative_accept_threshold_acc: float = 1.0
     speculative_token_map: Optional[str] = None
+    speculative_attention_backend: str = "prefill"
 
     # Expert parallelism
     ep_size: int = 1
@@ -1561,6 +1562,13 @@ class ServerArgs:
             help="The path of the draft model's small vocab table.",
             default=ServerArgs.speculative_token_map,
         )
+        parser.add_argument(
+            "--speculative-attention-backend",
+            type=str,
+            choices=["prefill", "decode"],
+            help="Attention backend to use for speculative decoding operations (both target verify and draft extend). 'prefill' (default) or 'decode'.",
+            default=ServerArgs.speculative_attention_backend,
+        )
 
         # Expert parallelism
         parser.add_argument(
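For reference, a short sketch of how the new switch could be used when launching a server with EAGLE speculative decoding. The model paths are placeholders and the launch entrypoint and `--model-path` flag are assumed standard SGLang usage; all `--speculative-*` flags below also appear in the tests added by this commit.

import subprocess

# Placeholder launch configuration; adjust the model paths to real checkpoints.
cmd = [
    "python", "-m", "sglang.launch_server",
    "--model-path", "path/to/target-model",
    "--speculative-algorithm", "EAGLE",
    "--speculative-draft", "path/to/eagle-draft-model",
    "--speculative-num-steps", "3",
    "--speculative-eagle-topk", "2",
    "--speculative-num-draft-tokens", "4",
    # New in this commit: route target_verify / draft_extend through the
    # decode attention backend instead of the default "prefill".
    "--speculative-attention-backend", "decode",
]
subprocess.run(cmd, check=True)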
@@ -191,7 +191,7 @@ class EAGLEWorker(TpModelWorker):
         # Initialize decode attention backend
         self.draft_attn_backend = self._create_decode_backend()
 
-        # Initialize prefill attention backend
+        # Initialize draft extend attention backend (respects speculative_attention_backend setting)
         self.draft_extend_attn_backend = self._create_draft_extend_backend()
 
         self.draft_model_runner.draft_attn_backend = self.draft_attn_backend
@@ -234,11 +234,15 @@ class EAGLEWorker(TpModelWorker):
             "trtllm_mha": self._create_trtllm_mha_prefill_backend,
             "trtllm_mla": self._create_trtllm_mla_prefill_backend,
         }
+        backend_name = (
+            "decode_attention_backend"
+            if self.server_args.speculative_attention_backend == "decode"
+            else "prefill_attention_backend"
+        )
         return self._create_backend(
-            "prefill_attention_backend",
+            backend_name,
             backend_map,
-            "EAGLE is not supported in prefill attention backend {backend_type}",
+            "EAGLE is not supported in attention backend {backend_type}",
         )
 
     def _create_flashinfer_decode_backend(self):
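The hunk above passes `backend_name` into `self._create_backend(...)`, whose body is not part of this diff. As a rough, hypothetical sketch (not the real EAGLEWorker code), such a helper would resolve the named ServerArgs attribute to a backend type and dispatch through `backend_map`:

from typing import Callable, Dict


def create_backend(server_args, backend_name: str, backend_map: Dict[str, Callable], error_template: str):
    # backend_name is a ServerArgs attribute name, e.g. "prefill_attention_backend"
    # or, with the new switch set to "decode", "decode_attention_backend".
    backend_type = getattr(server_args, backend_name)  # e.g. "flashinfer", "trtllm_mha"
    if backend_type not in backend_map:
        raise ValueError(error_template.format(backend_type=backend_type))
    return backend_map[backend_type]()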
@@ -132,5 +132,51 @@ class TestHybridAttnBackendSpeculativeDecoding(TestHybridAttnBackendBase):
     ]
 
 
+class TestHybridAttnBackendSpeculativeDecodingPrefillBackend(TestHybridAttnBackendBase):
+    speculative_decode = True
+    # This eagle test uses a very small model, so the accuracy is low.
+    accuracy_threshold = 0.2
+
+    @classmethod
+    def get_server_args(cls):
+        return DEFAULT_SERVER_ARGS + [
+            "--speculative-algorithm",
+            "EAGLE",
+            "--speculative-draft",
+            DEFAULT_EAGLE_DRAFT_MODEL_FOR_TEST,
+            "--speculative-num-steps",
+            "3",
+            "--speculative-eagle-topk",
+            "2",
+            "--speculative-num-draft-tokens",
+            "4",
+            "--speculative-attention-backend",
+            "prefill",
+        ]
+
+
+class TestHybridAttnBackendSpeculativeDecodingDecodeBackend(TestHybridAttnBackendBase):
+    speculative_decode = True
+    # This eagle test uses a very small model, so the accuracy is low.
+    accuracy_threshold = 0.2
+
+    @classmethod
+    def get_server_args(cls):
+        return DEFAULT_SERVER_ARGS + [
+            "--speculative-algorithm",
+            "EAGLE",
+            "--speculative-draft",
+            DEFAULT_EAGLE_DRAFT_MODEL_FOR_TEST,
+            "--speculative-num-steps",
+            "3",
+            "--speculative-eagle-topk",
+            "2",
+            "--speculative-num-draft-tokens",
+            "4",
+            "--speculative-attention-backend",
+            "decode",
+        ]
+
+
 if __name__ == "__main__":
     unittest.main()