Unverified Commit 188f0955 authored by Qingquan Song's avatar Qingquan Song Committed by GitHub
Browse files

Add Speculative Decoding Eagle3 topk > 1 (#5318)


Co-authored-by: default avatarStefan He <hebiaobuaa@gmail.com>
Co-authored-by: default avatarYubo Wang <yubowang2019@gmail.com>
parent eef9433b
......@@ -221,7 +221,16 @@ class ModelRunner:
server_args = self.server_args
if server_args.attention_backend is None:
# By default, use flashinfer for non-mla attention and triton for mla attention
"""
We auto select the fastest attention backend according to the current offering
1. Models with MHA Architecture (e.g: Llama, QWen)
    1.1 We will turn on FA3 on Hopper unless the user uses spec decode with topk > 1 or page_size > 1.
1.2 In other cases, we will use flashinfer if available, otherwise use triton.
2. Models with MLA Architecture and using FA3
2.1 We will use FA3 backend on hopper.
2.2 Otherwise, we will use triton backend.
"""
if not self.use_mla_backend:
if (
is_hopper_with_cuda_12_3()
......@@ -234,9 +243,7 @@ class ModelRunner:
"flashinfer" if is_flashinfer_available() else "triton"
)
else:
if is_hopper_with_cuda_12_3() and is_no_spec_infer_or_topk_one(
server_args
):
if is_hopper_with_cuda_12_3():
server_args.attention_backend = "fa3"
else:
server_args.attention_backend = "triton"
......
......@@ -359,7 +359,18 @@ class ServerArgs:
if self.page_size > 1 and self.speculative_eagle_topk > 1:
self.speculative_eagle_topk = 1
logger.info("speculative_eagle_topk is changed to 1 when page_size > 1")
logger.info(
"speculative_eagle_topk is adjusted to 1 when page_size > 1"
)
if (
self.speculative_eagle_topk == 1
and self.speculative_num_draft_tokens != self.speculative_num_steps + 1
):
logger.info(
"speculative_num_draft_tokens is adjusted to speculative_num_steps + 1 when speculative_eagle_topk == 1"
)
self.speculative_num_draft_tokens = self.speculative_num_steps + 1
# The token generated from the verify step is counted.
        # If speculative_num_steps >= speculative_num_draft_tokens, the additional tokens will definitely be discarded.
......
......@@ -1909,6 +1909,8 @@ def is_page_size_one(server_args):
return server_args.page_size == 1
# TODO(hebiao064): Accelerate FA3 Spec Decode with topk > 1.
# TODO(hebiao064): Improve the acc rate for FA3 Spec Decode with topk == 1 and page_size > 1.
def is_no_spec_infer_or_topk_one(server_args):
return server_args.speculative_eagle_topk is None or (
server_args.speculative_eagle_topk is not None
......
......@@ -29,7 +29,7 @@ suites = {
TestFile("test_chunked_prefill.py", 336),
TestFile("test_eagle_infer.py", 500),
TestFile("test_ebnf_constrained.py"),
TestFile("test_fa3.py", 5),
TestFile("test_fa3.py", 200),
TestFile("test_fp8_kernel.py", 8),
TestFile("test_embedding_openai_server.py", 36),
TestFile("test_hidden_states.py", 55),
......
......@@ -173,6 +173,60 @@ class TestFlashAttention3SpeculativeDecode(BaseFlashAttentionTest):
self.assertGreater(avg_spec_accept_length, 1.5)
class TestFlashAttention3SpeculativeDecodeTopk(BaseFlashAttentionTest):
    """Test FlashAttention3 with speculative decode enabled, topk > 1"""

    model = "meta-llama/Llama-3.1-8B-Instruct"

    @classmethod
    def get_server_args(cls):
        """Extend the base server args with EAGLE3 speculative-decode flags (topk > 1)."""
        spec_decode_flags = [
            "--cuda-graph-max-bs",
            "2",
            "--speculative-algorithm",
            "EAGLE3",
            "--speculative-draft",
            "jamesliu1/sglang-EAGLE3-Llama-3.1-Instruct-8B",
            "--speculative-num-steps",
            "5",
            "--speculative-eagle-topk",
            "4",
            "--speculative-num-draft-tokens",
            "8",
            "--dtype",
            "float16",
        ]
        base_args = super().get_server_args()
        base_args.extend(spec_decode_flags)
        return base_args

    def test_gsm8k(self):
        """
        Override the test_gsm8k to further test for average speculative accept length.
        """
        # Reset server-side cache so the accept-length stats reflect only this run.
        requests.get(self.base_url + "/flush_cache")

        eval_args = SimpleNamespace(
            num_shots=5,
            data_path=DATA_PATH,
            num_questions=200,
            max_new_tokens=512,
            parallel=128,
            host="http://127.0.0.1",
            port=int(self.base_url.split(":")[-1]),
        )
        metrics = run_eval_few_shot_gsm8k(eval_args)
        print(metrics)
        self.assertGreater(metrics["accuracy"], 0.60)

        # Confirm speculative decoding is actually accepting drafted tokens.
        server_info = requests.get(self.base_url + "/get_server_info")
        avg_spec_accept_length = server_info.json()["avg_spec_accept_length"]
        print(f"{avg_spec_accept_length=}")
        self.assertGreater(avg_spec_accept_length, 1.8)
class TestFlashAttention3MLASpeculativeDecode(BaseFlashAttentionTest):
"""Test FlashAttention3 with speculative decode enabled."""
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment