import os
import unittest
from types import SimpleNamespace

import requests

from sglang.srt.utils import get_device_sm, kill_process_tree
from sglang.test.few_shot_gsm8k import run_eval as run_eval_few_shot_gsm8k
from sglang.test.test_utils import (
    DEFAULT_MODEL_NAME_FOR_TEST_LOCAL_ATTENTION,
    DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH,
    DEFAULT_URL_FOR_TEST,
    popen_launch_server,
)


@unittest.skipIf(get_device_sm() < 90, "Test requires CUDA SM 90 or higher")
class TestFlashAttention3LocalAttn(unittest.TestCase):
    """Few-shot GSM8K accuracy check against a server using the "fa3" backend.

    One server process is launched for the whole class in ``setUpClass`` and
    killed in ``tearDownClass``. The class is skipped when ``get_device_sm()``
    reports a value below 90.
    """

    # Model served for the test and the URL the server is launched on.
    model = DEFAULT_MODEL_NAME_FOR_TEST_LOCAL_ATTENTION
    base_url = DEFAULT_URL_FOR_TEST
    # The measured GSM8K accuracy must be strictly greater than this value.
    accuracy_threshold = 0.90

    @classmethod
    def get_server_args(cls):
        """Return the extra command-line arguments passed to the server.

        Returns:
            A new list of strings: the ``--trust-remote-code`` switch followed
            by each flag and its value, in a fixed order.
        """
        # Flags that take a value; dict insertion order fixes the argv order.
        flag_values = {
            "--cuda-graph-max-bs": "2",
            "--attention-backend": "fa3",
            "--tp": "4",
            "--context-length": "1000000",
        }
        server_args = ["--trust-remote-code"]
        for flag, value in flag_values.items():
            server_args += [flag, value]
        return server_args

    @classmethod
    def setUpClass(cls):
        """Launch the server once for all tests and keep its process handle."""
        # NOTE(review): an earlier comment here said deep gemm precompile is
        # disabled to speed up launch, but nothing in this method changes the
        # environment -- the server receives os.environ as-is. Confirm whether
        # an environment override was intended.
        cls.process = popen_launch_server(
            cls.model,
            cls.base_url,
            timeout=DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH,
            other_args=cls.get_server_args(),
            env=os.environ,
        )

    @classmethod
    def tearDownClass(cls):
        """Kill the server process started in ``setUpClass`` and its children."""
        server_pid = cls.process.pid
        kill_process_tree(server_pid)

    def test_gsm8k(self):
        # Runs the few-shot GSM8K evaluation and requires the reported
        # accuracy to exceed ``accuracy_threshold``.

        # The port is whatever follows the last ":" in the base URL.
        port_text = self.base_url.rpartition(":")[-1]
        args = SimpleNamespace(
            num_shots=4,
            num_questions=100,
            max_new_tokens=512,
            parallel=128,
            host="http://127.0.0.1",
            port=int(port_text),
            data_path=None,
        )

        metrics = run_eval_few_shot_gsm8k(args)
        print(f"{metrics=}")

        self.assertGreater(metrics["accuracy"], self.accuracy_threshold)


if __name__ == "__main__":
    unittest.main()