"vscode:/vscode.git/clone" did not exist on "1c26be904ca7326510428f66e856c49a7a9324f0"
Unverified Commit 8ad700f7 authored by Baizhou Zhang, committed by GitHub

Cleaning codes for speculative attention mode (#10149)

parent 148022fc
```diff
@@ -209,6 +209,7 @@ Please consult the documentation below and [server_args.py](https://github.com/s
 | `--speculative-accept-threshold-single` | Accept a draft token if its probability in the target model is greater than this threshold. | 1.0 |
 | `--speculative-accept-threshold-acc` | The accept probability of a draft token is raised from its target probability p to min(1, p / threshold_acc). | 1.0 |
 | `--speculative-token-map` | The path of the draft model's small vocab table. | None |
+| `--speculative-attention-mode` | Attention backend for speculative decoding operations (both target verify and draft extend). Can be one of 'prefill' (default) or 'decode'. | prefill |
 ## Expert parallelism
```
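For context on the renamed flag, a minimal, hedged usage sketch follows: the assertion reads the new `ServerArgs` field default exactly as the CLI registration further down does, and the launch command in the comment uses placeholder model paths that are not part of this commit.

```python
# Hedged sketch: the field name and default come from this commit's diff;
# the launch command below is an assumed-typical invocation with placeholders.
from sglang.srt.server_args import ServerArgs

assert ServerArgs.speculative_attention_mode == "prefill"  # renamed field, same default

# Roughly equivalent CLI (placeholder paths, not from this commit):
#   python -m sglang.launch_server --model-path <target-model> \
#       --speculative-algorithm EAGLE \
#       --speculative-draft-model-path <draft-model> \
#       --speculative-attention-mode decode
```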
```diff
@@ -34,7 +34,7 @@ class HybridAttnBackend(AttentionBackend):
         Note:
         - decode_or_idle: Always uses decode backend
-        - target_verify or draft_extend: Uses decode backend if speculative_attention_backend is "decode", otherwise prefill backend
+        - target_verify or draft_extend: Uses decode backend if speculative_attention_mode is "decode", otherwise prefill backend
         - prefill: Always uses prefill backend
         """
         if forward_mode.is_decode_or_idle():
@@ -42,8 +42,7 @@ class HybridAttnBackend(AttentionBackend):
         elif forward_mode.is_target_verify() or forward_mode.is_draft_extend():
             return (
                 self.decode_backend
-                if self.model_runner.server_args.speculative_attention_backend
-                == "decode"
+                if self.model_runner.server_args.speculative_attention_mode == "decode"
                 else self.prefill_backend
             )
         else:
@@ -57,7 +56,7 @@ class HybridAttnBackend(AttentionBackend):
         self.decode_backend.init_cuda_graph_state(max_bs, max_num_tokens)
         if (
             self.model_runner.server_args.speculative_algorithm is not None
-            and self.model_runner.server_args.speculative_attention_backend == "prefill"
+            and self.model_runner.server_args.speculative_attention_mode == "prefill"
        ):
             # When speculative decoding is enabled, we need to initialize the backend
             # that will be used for target_verify.
```
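Restated as a standalone sketch (forward modes here are plain strings, unlike the real `ForwardMode` objects queried via `is_decode_or_idle()` and friends), the dispatch rule in the docstring above is:

```python
# Standalone sketch of HybridAttnBackend's dispatch rule, not the real class.
def select_backend(forward_mode: str, speculative_attention_mode: str) -> str:
    if forward_mode in ("decode", "idle"):
        return "decode_backend"  # decode_or_idle: always the decode backend
    if forward_mode in ("target_verify", "draft_extend"):
        # Speculative phases follow --speculative-attention-mode.
        return (
            "decode_backend"
            if speculative_attention_mode == "decode"
            else "prefill_backend"
        )
    return "prefill_backend"  # prefill: always the prefill backend

assert select_backend("target_verify", "prefill") == "prefill_backend"
assert select_backend("draft_extend", "decode") == "decode_backend"
```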
```diff
@@ -98,7 +98,7 @@ GLOBAL_SERVER_ARGS_KEYS = [
     "sampling_backend",
     "speculative_accept_threshold_single",
     "speculative_accept_threshold_acc",
-    "speculative_attention_backend",
+    "speculative_attention_mode",
     "torchao_config",
     "triton_attention_reduce_in_fp32",
     "num_reserved_decode_tokens",
```
```diff
@@ -1050,7 +1050,7 @@ class DeepseekV2AttentionMLA(nn.Module):
             or forward_batch.forward_mode.is_draft_extend()
         ):
             # Use the specified backend for speculative operations (both verify and draft extend)
-            if global_server_args_dict["speculative_attention_backend"] == "decode":
+            if global_server_args_dict["speculative_attention_mode"] == "decode":
                 attention_backend = global_server_args_dict["decode_attention_backend"]
             else:  # default to prefill
                 attention_backend = global_server_args_dict["prefill_attention_backend"]
```
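The renamed key reaches this code because server arguments listed in `GLOBAL_SERVER_ARGS_KEYS` (previous hunk) are projected into `global_server_args_dict`, which the attention code then reads. A hedged sketch of that plumbing, with a simplified stand-in dataclass instead of the real `ServerArgs` and illustrative backend values:

```python
# Hedged sketch of the ServerArgs -> global_server_args_dict projection;
# FakeServerArgs and the backend values are stand-ins, not sglang internals.
from dataclasses import dataclass

@dataclass
class FakeServerArgs:
    speculative_attention_mode: str = "prefill"
    decode_attention_backend: str = "flashinfer"  # illustrative value
    prefill_attention_backend: str = "triton"     # illustrative value

KEYS = [
    "speculative_attention_mode",
    "decode_attention_backend",
    "prefill_attention_backend",
]
args = FakeServerArgs()
global_server_args_dict = {k: getattr(args, k) for k in KEYS}

# Same selection as the DeepseekV2AttentionMLA hunk above:
if global_server_args_dict["speculative_attention_mode"] == "decode":
    attention_backend = global_server_args_dict["decode_attention_backend"]
else:  # default to prefill
    attention_backend = global_server_args_dict["prefill_attention_backend"]
print(attention_backend)  # -> "triton"
```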
```diff
@@ -262,7 +262,7 @@ class ServerArgs:
     speculative_accept_threshold_single: float = 1.0
     speculative_accept_threshold_acc: float = 1.0
     speculative_token_map: Optional[str] = None
-    speculative_attention_backend: str = "prefill"
+    speculative_attention_mode: str = "prefill"

     # Expert parallelism
     ep_size: int = 1
@@ -1563,11 +1563,11 @@ class ServerArgs:
             default=ServerArgs.speculative_token_map,
         )
         parser.add_argument(
-            "--speculative-attention-backend",
+            "--speculative-attention-mode",
             type=str,
             choices=["prefill", "decode"],
-            help="Attention backend to use for speculative decoding operations (both target verify and draft extend). 'prefill' (default) or 'decode'.",
-            default=ServerArgs.speculative_attention_backend,
+            help="Attention backend for speculative decoding operations (both target verify and draft extend). Can be one of 'prefill' (default) or 'decode'.",
+            default=ServerArgs.speculative_attention_mode,
         )

         # Expert parallelism
```
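The registration above is the stock `argparse` pattern; here is a self-contained sketch where only the flag itself mirrors the diff and the surrounding parser is illustrative:

```python
# Self-contained argparse sketch; only the flag mirrors this commit's diff.
import argparse

parser = argparse.ArgumentParser()
parser.add_argument(
    "--speculative-attention-mode",
    type=str,
    choices=["prefill", "decode"],
    default="prefill",  # mirrors ServerArgs.speculative_attention_mode
    help="Attention backend for speculative decoding operations "
    "(both target verify and draft extend).",
)

args = parser.parse_args([])  # no flag given -> default applies
assert args.speculative_attention_mode == "prefill"
args = parser.parse_args(["--speculative-attention-mode", "decode"])
assert args.speculative_attention_mode == "decode"
```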
```diff
@@ -191,7 +191,7 @@ class EAGLEWorker(TpModelWorker):
         # Initialize decode attention backend
         self.draft_attn_backend = self._create_decode_backend()

-        # Initialize draft extend attention backend (respects speculative_attention_backend setting)
+        # Initialize draft extend attention backend (respects speculative_attention_mode setting)
         self.draft_extend_attn_backend = self._create_draft_extend_backend()

         self.draft_model_runner.draft_attn_backend = self.draft_attn_backend
@@ -236,7 +236,7 @@ class EAGLEWorker(TpModelWorker):
         }
         backend_name = (
             "decode_attention_backend"
-            if self.server_args.speculative_attention_backend == "decode"
+            if self.server_args.speculative_attention_mode == "decode"
             else "prefill_attention_backend"
         )
         return self._create_backend(
```
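A minimal sketch of the `backend_name` selection inside `_create_draft_extend_backend`; the constructor map below is a stand-in, since the diff only shows the name being passed on to `self._create_backend(...)`:

```python
# Stand-in sketch of the worker's backend-name selection; the lambdas are
# placeholders for the real backend constructors, which this diff does not show.
def pick_backend_name(speculative_attention_mode: str) -> str:
    return (
        "decode_attention_backend"
        if speculative_attention_mode == "decode"
        else "prefill_attention_backend"
    )

constructors = {
    "prefill_attention_backend": lambda: "prefill-backend-instance",
    "decode_attention_backend": lambda: "decode-backend-instance",
}
draft_extend_attn_backend = constructors[pick_backend_name("decode")]()
assert draft_extend_attn_backend == "decode-backend-instance"
```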
```diff
@@ -111,27 +111,6 @@ class TestHybridAttnBackendTorchCompile(TestHybridAttnBackendBase):
         return DEFAULT_SERVER_ARGS + ["--enable-torch-compile"]


-class TestHybridAttnBackendSpeculativeDecoding(TestHybridAttnBackendBase):
-    speculative_decode = True
-    # This eagle test uses a very small model, so the accuracy is low.
-    accuracy_threshold = 0.2
-
-    @classmethod
-    def get_server_args(cls):
-        return DEFAULT_SERVER_ARGS + [
-            "--speculative-algorithm",
-            "EAGLE",
-            "--speculative-draft-model-path",
-            DEFAULT_EAGLE_DRAFT_MODEL_FOR_TEST,
-            "--speculative-num-steps",
-            "3",
-            "--speculative-eagle-topk",
-            "2",
-            "--speculative-num-draft-tokens",
-            "4",
-        ]
-
-
 class TestHybridAttnBackendSpeculativeDecodingPrefillBackend(TestHybridAttnBackendBase):
     speculative_decode = True
     # This eagle test uses a very small model, so the accuracy is low.
@@ -150,7 +129,7 @@ class TestHybridAttnBackendSpeculativeDecodingPrefillBackend(TestHybridAttnBacke
             "2",
             "--speculative-num-draft-tokens",
             "4",
-            "--speculative-attention-backend",
+            "--speculative-attention-mode",
             "prefill",
         ]
@@ -173,7 +152,7 @@ class TestHybridAttnBackendSpeculativeDecodingDecodeBackend(TestHybridAttnBacken
             "2",
             "--speculative-num-draft-tokens",
             "4",
-            "--speculative-attention-backend",
+            "--speculative-attention-mode",
             "decode",
         ]
```