Unverified Commit 9ecb1856 authored by Lianmin Zheng's avatar Lianmin Zheng Committed by GitHub
Browse files

Fix triton sliding window test case (#6981)

parent cc74499d
import time
import unittest import unittest
from types import SimpleNamespace from types import SimpleNamespace
...@@ -10,6 +9,7 @@ from sglang.test.test_utils import ( ...@@ -10,6 +9,7 @@ from sglang.test.test_utils import (
DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH, DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH,
DEFAULT_URL_FOR_TEST, DEFAULT_URL_FOR_TEST,
CustomTestCase, CustomTestCase,
is_in_ci,
popen_launch_server, popen_launch_server,
) )
...@@ -45,10 +45,6 @@ class TestSlidingWindowAttentionTriton(CustomTestCase): ...@@ -45,10 +45,6 @@ class TestSlidingWindowAttentionTriton(CustomTestCase):
) )
cls.long_context_prompt += "\nNow, summarize the story in one sentence:" cls.long_context_prompt += "\nNow, summarize the story in one sentence:"
@classmethod
def tearDownClass(cls):
pass
def _test_mmlu(self): def _test_mmlu(self):
args = SimpleNamespace( args = SimpleNamespace(
base_url=self.base_url, base_url=self.base_url,
...@@ -61,7 +57,7 @@ class TestSlidingWindowAttentionTriton(CustomTestCase): ...@@ -61,7 +57,7 @@ class TestSlidingWindowAttentionTriton(CustomTestCase):
metrics = run_eval(args) metrics = run_eval(args)
print(f"MMLU metrics with sliding window: {metrics}") print(f"MMLU metrics with sliding window: {metrics}")
self.assertGreaterEqual(metrics["score"], 0.61) self.assertGreaterEqual(metrics["score"], 0.60)
def _test_short_context_generation(self): def _test_short_context_generation(self):
response = requests.post( response = requests.post(
...@@ -97,6 +93,7 @@ class TestSlidingWindowAttentionTriton(CustomTestCase): ...@@ -97,6 +93,7 @@ class TestSlidingWindowAttentionTriton(CustomTestCase):
self.assertGreater(len(result["text"].strip()), 0) self.assertGreater(len(result["text"].strip()), 0)
print(f"Long context generation result: {result['text'][:100]}...") print(f"Long context generation result: {result['text'][:100]}...")
@unittest.skipIf(is_in_ci(), "To reduce the CI execution time.")
def test_no_cuda_graph(self): def test_no_cuda_graph(self):
self.no_cuda_graph_process = popen_launch_server( self.no_cuda_graph_process = popen_launch_server(
self.model, self.model,
...@@ -105,12 +102,12 @@ class TestSlidingWindowAttentionTriton(CustomTestCase): ...@@ -105,12 +102,12 @@ class TestSlidingWindowAttentionTriton(CustomTestCase):
other_args=self.common_args + ["--disable-cuda-graph"], other_args=self.common_args + ["--disable-cuda-graph"],
) )
self._test_short_context_generation() try:
self._test_long_context_generation() self._test_short_context_generation()
self._test_mmlu() self._test_long_context_generation()
self._test_mmlu()
kill_process_tree(self.no_cuda_graph_process.pid) finally:
time.sleep(5) kill_process_tree(self.no_cuda_graph_process.pid)
def test_cuda_graph(self): def test_cuda_graph(self):
self.cuda_graph_process = popen_launch_server( self.cuda_graph_process = popen_launch_server(
...@@ -120,12 +117,12 @@ class TestSlidingWindowAttentionTriton(CustomTestCase): ...@@ -120,12 +117,12 @@ class TestSlidingWindowAttentionTriton(CustomTestCase):
other_args=self.common_args, other_args=self.common_args,
) )
self._test_short_context_generation() try:
self._test_long_context_generation() self._test_short_context_generation()
self._test_mmlu() self._test_long_context_generation()
self._test_mmlu()
kill_process_tree(self.cuda_graph_process.pid) finally:
time.sleep(5) kill_process_tree(self.cuda_graph_process.pid)
if __name__ == "__main__": if __name__ == "__main__":
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment