Unverified commit 4d23ba08, authored by Lianmin Zheng, committed by GitHub

Simplify FA3 tests (#5779)

parent 6e313c1b
@@ -30,7 +30,7 @@ suites = {
         TestFile("test_chunked_prefill.py", 336),
         TestFile("test_eagle_infer.py", 500),
         TestFile("test_ebnf_constrained.py"),
-        TestFile("test_fa3.py", 500),
+        TestFile("test_fa3.py", 400),
         TestFile("test_fp8_kernel.py", 8),
         TestFile("test_embedding_openai_server.py", 36),
         TestFile("test_hidden_states.py", 55),
@@ -92,7 +92,7 @@ suites = {
         TestFile("test_verl_engine.py", 100),
     ],
     "per-commit-8-gpu": [
-        TestFile("test_local_attn.py", 100),
+        TestFile("test_local_attn.py", 250),
     ],
     "nightly": [
         TestFile("test_nightly_gsm8k_eval.py"),
...
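For readers unfamiliar with this registry: the second argument to TestFile is an estimated runtime in seconds that the suite runner uses to budget the test files in each suite. A minimal sketch of the assumed shape (the dataclass fields and the helper below are assumptions for illustration, not taken from this commit):

# Sketch only: assumes TestFile pairs a test file name with an estimated
# runtime in seconds; total_budget is a hypothetical helper.
from dataclasses import dataclass

@dataclass
class TestFile:
    name: str                   # test module file, e.g. "test_fa3.py"
    estimated_time: float = 60  # rough runtime budget in seconds

def total_budget(suite):
    # Sum the per-file estimates to get a rough wall-clock budget for a suite.
    return sum(t.estimated_time for t in suite)

per_commit = [TestFile("test_fa3.py", 400), TestFile("test_fp8_kernel.py", 8)]
print(total_budget(per_commit))  # 408 seconds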
@@ -3,7 +3,6 @@ import unittest
 from types import SimpleNamespace

 import requests
-import torch

 from sglang.srt.utils import get_device_sm, kill_process_tree
 from sglang.test.few_shot_gsm8k import run_eval as run_eval_few_shot_gsm8k
@@ -14,6 +13,7 @@ from sglang.test.test_utils import (
     DEFAULT_MODEL_NAME_FOR_TEST_MLA_NEXTN,
     DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH,
     DEFAULT_URL_FOR_TEST,
+    CustomTestCase,
     popen_launch_server,
 )
@@ -47,9 +47,8 @@ if OFFLINE_MODE:
 # Default server arguments shared across all tests
 DEFAULT_SERVER_ARGS = [
     "--trust-remote-code",
-    "--enable-torch-compile",
     "--cuda-graph-max-bs",
-    "2",
+    "4",
     "--attention-backend",
     "fa3",
 ]
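These shared arguments are the base that every test class in the file builds on: subclasses override get_server_args() and append their own flags, the same pattern visible in the classes further down this diff. A condensed sketch of that pattern (the class name and the extra flag are illustrative, not part of this commit):

# Illustrative subclass: extends the shared FA3 server arguments.
# "--max-running-requests" is only an example flag, not from this diff.
class ExampleFA3Variant(BaseFlashAttentionTest):
    @classmethod
    def get_server_args(cls):
        args = super().get_server_args()  # starts from DEFAULT_SERVER_ARGS
        args.extend(["--max-running-requests", "32"])
        return args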
@@ -60,7 +59,7 @@ Integration test for python/sglang/srt/layers/attention/flashattention_backend.p
 @unittest.skipIf(get_device_sm() < 90, "Test requires CUDA SM 90 or higher")
-class BaseFlashAttentionTest(unittest.TestCase):
+class BaseFlashAttentionTest(CustomTestCase):
     """Base class for testing FlashAttention3."""

     model = DEFAULT_MODEL_NAME_FOR_TEST
@@ -78,13 +77,13 @@ class BaseFlashAttentionTest(unittest.TestCase):
     def setUpClass(cls):
         # disable deep gemm precompile to make launch server faster
         # please don't do this if you want to make your inference workload faster
-        os.environ["SGL_JIT_DEEPGEMM_PRECOMPILE"] = "False"
+        os.environ["SGL_JIT_DEEPGEMM_PRECOMPILE"] = "false"
+        os.environ["SGL_ENABLE_JIT_DEEPGEMM"] = "false"
         cls.process = popen_launch_server(
             cls.model,
             cls.base_url,
             timeout=DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH,
             other_args=cls.get_server_args(),
+            env=os.environ,
         )

     @classmethod
@@ -92,6 +91,8 @@ class BaseFlashAttentionTest(unittest.TestCase):
         kill_process_tree(cls.process.pid)

     def test_gsm8k(self):
+        requests.get(self.base_url + "/flush_cache")
+
         args = SimpleNamespace(
             num_shots=4,
             num_questions=100,
@@ -102,7 +103,7 @@ class BaseFlashAttentionTest(unittest.TestCase):
             data_path=GSM_DATASET_PATH,
         )
         metrics = run_eval_few_shot_gsm8k(args)
-        print(metrics)
+        print(f"{metrics=}")

         # Use the appropriate metric key based on the test class
         metric_key = "accuracy"
@@ -192,60 +193,6 @@ class TestFlashAttention3SpeculativeDecodeTopk(BaseFlashAttentionTest):
         return args


-class TestFlashAttention3SpeculativeDecodeTopk(BaseFlashAttentionTest):
-    """Test FlashAttention3 with speculative decode enabled, topk > 1"""
-
-    model = DEFAULT_MODEL_NAME_FOR_TEST
-
-    @classmethod
-    def get_server_args(cls):
-        args = super().get_server_args()
-        args.extend(
-            [
-                "--cuda-graph-max-bs",
-                "2",
-                "--speculative-algorithm",
-                "EAGLE3",
-                "--speculative-draft",
-                DEFAULT_MODEL_NAME_FOR_TEST_EAGLE3,
-                "--speculative-num-steps",
-                "5",
-                "--speculative-eagle-topk",
-                "4",
-                "--speculative-num-draft-tokens",
-                "8",
-                "--dtype",
-                "float16",
-            ]
-        )
-        return args
-
-    def test_gsm8k(self):
-        """
-        Override the test_gsm8k to further test for average speculative accept length.
-        """
-        requests.get(self.base_url + "/flush_cache")
-
-        args = SimpleNamespace(
-            num_shots=5,
-            data_path=GSM_DATASET_PATH,
-            num_questions=200,
-            max_new_tokens=512,
-            parallel=128,
-            host="http://127.0.0.1",
-            port=int(self.base_url.split(":")[-1]),
-        )
-        metrics = run_eval_few_shot_gsm8k(args)
-        print(metrics)
-        self.assertGreater(metrics["accuracy"], 0.60)
-
-        server_info = requests.get(self.base_url + "/get_server_info")
-        avg_spec_accept_length = server_info.json()["avg_spec_accept_length"]
-        print(f"{avg_spec_accept_length=}")
-        self.assertGreater(avg_spec_accept_length, 1.8)
-
-
 class TestFlashAttention3MLASpeculativeDecode(BaseFlashAttentionTest):
     """Test FlashAttention3 with speculative decode enabled with deepseek v3 test model and its nextN model"""
...
@@ -10,12 +10,13 @@ from sglang.test.test_utils import (
     DEFAULT_MODEL_NAME_FOR_TEST_LOCAL_ATTENTION,
     DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH,
     DEFAULT_URL_FOR_TEST,
+    CustomTestCase,
     popen_launch_server,
 )


 @unittest.skipIf(get_device_sm() < 90, "Test requires CUDA SM 90 or higher")
-class TestFlashAttention3LocalAttn(unittest.TestCase):
+class TestFlashAttention3LocalAttn(CustomTestCase):
     model = DEFAULT_MODEL_NAME_FOR_TEST_LOCAL_ATTENTION
     base_url = DEFAULT_URL_FOR_TEST
     accuracy_threshold = 0.90
@@ -23,7 +24,6 @@ class TestFlashAttention3LocalAttn(unittest.TestCase):
     @classmethod
     def get_server_args(cls):
         return [
-            "--trust-remote-code",
             "--cuda-graph-max-bs",
             "2",
             "--attention-backend",
@@ -36,8 +36,6 @@ class TestFlashAttention3LocalAttn(unittest.TestCase):
     @classmethod
     def setUpClass(cls):
-        # disable deep gemm precompile to make launch server faster
-        # please don't do this if you want to make your inference workload faster
         cls.process = popen_launch_server(
             cls.model,
             cls.base_url,
@@ -51,6 +49,8 @@ class TestFlashAttention3LocalAttn(unittest.TestCase):
         kill_process_tree(cls.process.pid)

     def test_gsm8k(self):
+        requests.get(self.base_url + "/flush_cache")
+
         args = SimpleNamespace(
             num_shots=4,
             num_questions=100,
...
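Both test files now share the same accuracy-check flow: flush the server cache over HTTP, run the few-shot GSM8K eval against the locally launched server, and assert on the returned accuracy. A condensed sketch of that flow as a standalone helper (the helper name and default threshold are illustrative; the argument fields mirror the ones passed in the tests above):

# Standalone sketch of the shared GSM8K check used by these tests.
from types import SimpleNamespace

import requests

from sglang.test.few_shot_gsm8k import run_eval as run_eval_few_shot_gsm8k


def check_gsm8k_accuracy(base_url, data_path, threshold=0.90):
    # Clear the server-side cache so earlier runs do not influence this one.
    requests.get(base_url + "/flush_cache")
    args = SimpleNamespace(
        num_shots=4,
        num_questions=100,
        max_new_tokens=512,
        parallel=128,
        host="http://127.0.0.1",
        port=int(base_url.split(":")[-1]),
        data_path=data_path,
    )
    metrics = run_eval_few_shot_gsm8k(args)
    print(f"{metrics=}")
    assert metrics["accuracy"] > threshold, metrics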