"vscode:/vscode.git/clone" did not exist on "e44fc75acb6ddf5a331d7ef9896c0e39d87a019e"
Unverified commit 8c5930f0, authored by cicirori and committed by GitHub
Browse files

Add speculator attention backend switch (#9981)

parent 3b99f23c
...@@ -22,17 +22,45 @@ class HybridAttnBackend(AttentionBackend): ...@@ -22,17 +22,45 @@ class HybridAttnBackend(AttentionBackend):
self.prefill_backend = prefill_backend self.prefill_backend = prefill_backend
self.decode_backend = decode_backend self.decode_backend = decode_backend
def init_forward_metadata(self, forward_batch: ForwardBatch): def _select_backend(self, forward_mode: ForwardMode) -> AttentionBackend:
if forward_batch.forward_mode.is_decode_or_idle(): """
self.decode_backend.init_forward_metadata(forward_batch) Select the appropriate attention backend based on the forward mode.
Args:
forward_mode: The current forward mode indicating the operation type
Returns:
The selected attention backend (prefill or decode)
Note:
- decode_or_idle: Always uses decode backend
- target_verify or draft_extend: Uses decode backend if speculative_attention_backend is "decode", otherwise prefill backend
- prefill: Always uses prefill backend
"""
if forward_mode.is_decode_or_idle():
return self.decode_backend
elif forward_mode.is_target_verify() or forward_mode.is_draft_extend():
return (
self.decode_backend
if self.model_runner.server_args.speculative_attention_backend
== "decode"
else self.prefill_backend
)
else: else:
self.prefill_backend.init_forward_metadata(forward_batch) return self.prefill_backend
def init_forward_metadata(self, forward_batch: ForwardBatch):
backend = self._select_backend(forward_batch.forward_mode)
backend.init_forward_metadata(forward_batch)
def init_cuda_graph_state(self, max_bs: int, max_num_tokens: int): def init_cuda_graph_state(self, max_bs: int, max_num_tokens: int):
self.decode_backend.init_cuda_graph_state(max_bs, max_num_tokens) self.decode_backend.init_cuda_graph_state(max_bs, max_num_tokens)
if self.model_runner.server_args.speculative_algorithm is not None: if (
# When speculative decoding is enabled, we also need to initialize the self.model_runner.server_args.speculative_algorithm is not None
# prefill backend's cuda graph state to support target_verify. and self.model_runner.server_args.speculative_attention_backend == "prefill"
):
# When speculative decoding is enabled, we need to initialize the backend
# that will be used for target_verify.
self.prefill_backend.init_cuda_graph_state(max_bs, max_num_tokens) self.prefill_backend.init_cuda_graph_state(max_bs, max_num_tokens)
def init_forward_metadata_capture_cuda_graph( def init_forward_metadata_capture_cuda_graph(
...@@ -45,18 +73,8 @@ class HybridAttnBackend(AttentionBackend): ...@@ -45,18 +73,8 @@ class HybridAttnBackend(AttentionBackend):
forward_mode: ForwardMode, forward_mode: ForwardMode,
spec_info: Optional[Union[EagleDraftInput, EagleVerifyInput]], spec_info: Optional[Union[EagleDraftInput, EagleVerifyInput]],
): ):
if forward_mode.is_decode_or_idle(): backend = self._select_backend(forward_mode)
self.decode_backend.init_forward_metadata_capture_cuda_graph( backend.init_forward_metadata_capture_cuda_graph(
bs,
num_tokens,
req_pool_indices,
seq_lens,
encoder_lens,
forward_mode,
spec_info,
)
else:
self.prefill_backend.init_forward_metadata_capture_cuda_graph(
bs, bs,
num_tokens, num_tokens,
req_pool_indices, req_pool_indices,
...@@ -77,19 +95,8 @@ class HybridAttnBackend(AttentionBackend): ...@@ -77,19 +95,8 @@ class HybridAttnBackend(AttentionBackend):
spec_info: Optional[Union[EagleDraftInput, EagleVerifyInput]], spec_info: Optional[Union[EagleDraftInput, EagleVerifyInput]],
seq_lens_cpu: Optional[torch.Tensor], seq_lens_cpu: Optional[torch.Tensor],
): ):
if forward_mode.is_decode_or_idle(): backend = self._select_backend(forward_mode)
self.decode_backend.init_forward_metadata_replay_cuda_graph( backend.init_forward_metadata_replay_cuda_graph(
bs,
req_pool_indices,
seq_lens,
seq_lens_sum,
encoder_lens,
forward_mode,
spec_info,
seq_lens_cpu,
)
else:
self.prefill_backend.init_forward_metadata_replay_cuda_graph(
bs, bs,
req_pool_indices, req_pool_indices,
seq_lens, seq_lens,
...@@ -127,6 +134,7 @@ class HybridAttnBackend(AttentionBackend): ...@@ -127,6 +134,7 @@ class HybridAttnBackend(AttentionBackend):
save_kv_cache: bool = True, save_kv_cache: bool = True,
**kwargs, **kwargs,
): ):
return self.prefill_backend.forward_extend( backend = self._select_backend(forward_batch.forward_mode)
return backend.forward_extend(
q, k, v, layer, forward_batch, save_kv_cache, **kwargs q, k, v, layer, forward_batch, save_kv_cache, **kwargs
) )
...@@ -98,6 +98,7 @@ GLOBAL_SERVER_ARGS_KEYS = [ ...@@ -98,6 +98,7 @@ GLOBAL_SERVER_ARGS_KEYS = [
"sampling_backend", "sampling_backend",
"speculative_accept_threshold_single", "speculative_accept_threshold_single",
"speculative_accept_threshold_acc", "speculative_accept_threshold_acc",
"speculative_attention_backend",
"torchao_config", "torchao_config",
"triton_attention_reduce_in_fp32", "triton_attention_reduce_in_fp32",
"num_reserved_decode_tokens", "num_reserved_decode_tokens",
......
...@@ -1045,6 +1045,15 @@ class DeepseekV2AttentionMLA(nn.Module): ...@@ -1045,6 +1045,15 @@ class DeepseekV2AttentionMLA(nn.Module):
# Determine attention backend used by current forward batch # Determine attention backend used by current forward batch
if forward_batch.forward_mode.is_decode_or_idle(): if forward_batch.forward_mode.is_decode_or_idle():
attention_backend = global_server_args_dict["decode_attention_backend"] attention_backend = global_server_args_dict["decode_attention_backend"]
elif (
forward_batch.forward_mode.is_target_verify()
or forward_batch.forward_mode.is_draft_extend()
):
# Use the specified backend for speculative operations (both verify and draft extend)
if global_server_args_dict["speculative_attention_backend"] == "decode":
attention_backend = global_server_args_dict["decode_attention_backend"]
else: # default to prefill
attention_backend = global_server_args_dict["prefill_attention_backend"]
else: else:
attention_backend = global_server_args_dict["prefill_attention_backend"] attention_backend = global_server_args_dict["prefill_attention_backend"]
self.current_attention_backend = attention_backend self.current_attention_backend = attention_backend
......
...@@ -262,6 +262,7 @@ class ServerArgs: ...@@ -262,6 +262,7 @@ class ServerArgs:
speculative_accept_threshold_single: float = 1.0 speculative_accept_threshold_single: float = 1.0
speculative_accept_threshold_acc: float = 1.0 speculative_accept_threshold_acc: float = 1.0
speculative_token_map: Optional[str] = None speculative_token_map: Optional[str] = None
speculative_attention_backend: str = "prefill"
# Expert parallelism # Expert parallelism
ep_size: int = 1 ep_size: int = 1
...@@ -1561,6 +1562,13 @@ class ServerArgs: ...@@ -1561,6 +1562,13 @@ class ServerArgs:
help="The path of the draft model's small vocab table.", help="The path of the draft model's small vocab table.",
default=ServerArgs.speculative_token_map, default=ServerArgs.speculative_token_map,
) )
parser.add_argument(
"--speculative-attention-backend",
type=str,
choices=["prefill", "decode"],
help="Attention backend to use for speculative decoding operations (both target verify and draft extend). 'prefill' (default) or 'decode'.",
default=ServerArgs.speculative_attention_backend,
)
# Expert parallelism # Expert parallelism
parser.add_argument( parser.add_argument(
......
...@@ -191,7 +191,7 @@ class EAGLEWorker(TpModelWorker): ...@@ -191,7 +191,7 @@ class EAGLEWorker(TpModelWorker):
# Initialize decode attention backend # Initialize decode attention backend
self.draft_attn_backend = self._create_decode_backend() self.draft_attn_backend = self._create_decode_backend()
# Initialize prefill attention backend # Initialize draft extend attention backend (respects speculative_attention_backend setting)
self.draft_extend_attn_backend = self._create_draft_extend_backend() self.draft_extend_attn_backend = self._create_draft_extend_backend()
self.draft_model_runner.draft_attn_backend = self.draft_attn_backend self.draft_model_runner.draft_attn_backend = self.draft_attn_backend
...@@ -234,11 +234,15 @@ class EAGLEWorker(TpModelWorker): ...@@ -234,11 +234,15 @@ class EAGLEWorker(TpModelWorker):
"trtllm_mha": self._create_trtllm_mha_prefill_backend, "trtllm_mha": self._create_trtllm_mha_prefill_backend,
"trtllm_mla": self._create_trtllm_mla_prefill_backend, "trtllm_mla": self._create_trtllm_mla_prefill_backend,
} }
backend_name = (
"decode_attention_backend"
if self.server_args.speculative_attention_backend == "decode"
else "prefill_attention_backend"
)
return self._create_backend( return self._create_backend(
"prefill_attention_backend", backend_name,
backend_map, backend_map,
"EAGLE is not supported in prefill attention backend {backend_type}", "EAGLE is not supported in attention backend {backend_type}",
) )
def _create_flashinfer_decode_backend(self): def _create_flashinfer_decode_backend(self):
......
...@@ -132,5 +132,51 @@ class TestHybridAttnBackendSpeculativeDecoding(TestHybridAttnBackendBase): ...@@ -132,5 +132,51 @@ class TestHybridAttnBackendSpeculativeDecoding(TestHybridAttnBackendBase):
] ]
class TestHybridAttnBackendSpeculativeDecodingPrefillBackend(TestHybridAttnBackendBase):
speculative_decode = True
# This eagle test uses a very small model, so the accuracy is low.
accuracy_threshold = 0.2
@classmethod
def get_server_args(cls):
return DEFAULT_SERVER_ARGS + [
"--speculative-algorithm",
"EAGLE",
"--speculative-draft",
DEFAULT_EAGLE_DRAFT_MODEL_FOR_TEST,
"--speculative-num-steps",
"3",
"--speculative-eagle-topk",
"2",
"--speculative-num-draft-tokens",
"4",
"--speculative-attention-backend",
"prefill",
]
class TestHybridAttnBackendSpeculativeDecodingDecodeBackend(TestHybridAttnBackendBase):
speculative_decode = True
# This eagle test uses a very small model, so the accuracy is low.
accuracy_threshold = 0.2
@classmethod
def get_server_args(cls):
return DEFAULT_SERVER_ARGS + [
"--speculative-algorithm",
"EAGLE",
"--speculative-draft",
DEFAULT_EAGLE_DRAFT_MODEL_FOR_TEST,
"--speculative-num-steps",
"3",
"--speculative-eagle-topk",
"2",
"--speculative-num-draft-tokens",
"4",
"--speculative-attention-backend",
"decode",
]
if __name__ == "__main__": if __name__ == "__main__":
unittest.main() unittest.main()
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment