Unverified commit 15ddd843, authored by fzyzcjy, committed by GitHub
Browse files

Add retry for flaky tests in CI (#4755)

parent 52029bd1
......@@ -10,11 +10,12 @@ from sglang.test.test_utils import (
DEFAULT_SMALL_MODEL_NAME_FOR_TEST,
DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH,
DEFAULT_URL_FOR_TEST,
CustomTestCase,
popen_launch_server,
)
class TestPenalty(unittest.TestCase):
class TestPenalty(CustomTestCase):
@classmethod
def setUpClass(cls):
......
......@@ -9,11 +9,12 @@ from sglang.test.test_utils import (
DEFAULT_MODEL_NAME_FOR_TEST,
DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH,
DEFAULT_URL_FOR_TEST,
CustomTestCase,
popen_launch_server,
)
class TestPyTorchSamplingBackend(unittest.TestCase):
class TestPyTorchSamplingBackend(CustomTestCase):
@classmethod
def setUpClass(cls):
cls.model = DEFAULT_MODEL_NAME_FOR_TEST
......
......@@ -8,6 +8,7 @@ from sglang.test.test_utils import (
DEFAULT_SMALL_MODEL_NAME_FOR_TEST,
DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH,
DEFAULT_URL_FOR_TEST,
CustomTestCase,
kill_process_tree,
popen_launch_server,
)
......@@ -59,7 +60,7 @@ def run_test(base_url, nodes):
assert res.status_code == 200
class TestRadixCacheFCFS(unittest.TestCase):
class TestRadixCacheFCFS(CustomTestCase):
@classmethod
def setUpClass(cls):
cls.model = DEFAULT_SMALL_MODEL_NAME_FOR_TEST
......
......@@ -20,11 +20,12 @@ from sglang.test.test_utils import (
DEFAULT_REASONING_MODEL_NAME_FOR_TEST,
DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH,
DEFAULT_URL_FOR_TEST,
CustomTestCase,
popen_launch_server,
)
class TestReasoningContentAPI(unittest.TestCase):
class TestReasoningContentAPI(CustomTestCase):
@classmethod
def setUpClass(cls):
cls.model = DEFAULT_REASONING_MODEL_NAME_FOR_TEST
......@@ -181,7 +182,7 @@ class TestReasoningContentAPI(unittest.TestCase):
assert len(response.choices[0].message.content) > 0
class TestReasoningContentWithoutParser(unittest.TestCase):
class TestReasoningContentWithoutParser(CustomTestCase):
@classmethod
def setUpClass(cls):
cls.model = DEFAULT_REASONING_MODEL_NAME_FOR_TEST
......
......@@ -15,6 +15,7 @@ from sglang.test.test_utils import (
DEFAULT_SMALL_MODEL_NAME_FOR_TEST,
DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH,
DEFAULT_URL_FOR_TEST,
CustomTestCase,
popen_launch_server,
)
......@@ -41,7 +42,7 @@ def setup_class(cls, backend: str, disable_overlap: bool):
)
class TestRegexConstrained(unittest.TestCase):
class TestRegexConstrained(CustomTestCase):
@classmethod
def setUpClass(cls):
setup_class(cls, "xgrammar", disable_overlap=False)
......
......@@ -5,13 +5,13 @@ import torch
from transformers import AutoModelForCausalLM
import sglang as sgl
from sglang.test.test_utils import DEFAULT_SMALL_MODEL_NAME_FOR_TEST
from sglang.test.test_utils import DEFAULT_SMALL_MODEL_NAME_FOR_TEST, CustomTestCase
# (temporarily) set to true to observe memory usage in nvidia-smi more clearly
_DEBUG_EXTRA = True
class TestReleaseMemoryOccupation(unittest.TestCase):
class TestReleaseMemoryOccupation(CustomTestCase):
def test_release_and_resume_occupation(self):
prompt = "Today is a sunny day and I like"
sampling_params = {"temperature": 0, "max_new_tokens": 8}
......
......@@ -7,11 +7,12 @@ from sglang.test.test_utils import (
DEFAULT_SMALL_MODEL_NAME_FOR_TEST,
DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH,
DEFAULT_URL_FOR_TEST,
CustomTestCase,
popen_launch_server,
)
class TestRequestLengthValidation(unittest.TestCase):
class TestRequestLengthValidation(CustomTestCase):
@classmethod
def setUpClass(cls):
cls.base_url = DEFAULT_URL_FOR_TEST
......
......@@ -8,11 +8,12 @@ from sglang.test.test_utils import (
DEFAULT_MODEL_NAME_FOR_TEST,
DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH,
DEFAULT_URL_FOR_TEST,
CustomTestCase,
popen_launch_server,
)
class TestRetractDecode(unittest.TestCase):
class TestRetractDecode(CustomTestCase):
@classmethod
def setUpClass(cls):
os.environ["SGLANG_TEST_RETRACT"] = "1"
......@@ -40,7 +41,7 @@ class TestRetractDecode(unittest.TestCase):
self.assertGreaterEqual(metrics["score"], 0.65)
class TestRetractDecodeChunkCache(unittest.TestCase):
class TestRetractDecodeChunkCache(CustomTestCase):
@classmethod
def setUpClass(cls):
os.environ["SGLANG_TEST_RETRACT"] = "1"
......
......@@ -13,11 +13,12 @@ from sglang.test.test_utils import (
DEFAULT_SMALL_MODEL_NAME_FOR_TEST,
DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH,
DEFAULT_URL_FOR_TEST,
CustomTestCase,
popen_launch_server,
)
class TestSageMakerServer(unittest.TestCase):
class TestSageMakerServer(CustomTestCase):
@classmethod
def setUpClass(cls):
cls.model = DEFAULT_SMALL_MODEL_NAME_FOR_TEST
......
......@@ -8,9 +8,10 @@ from sglang.srt.managers.schedule_policy import (
)
from sglang.srt.mem_cache.radix_cache import RadixCache, TreeNode
from sglang.srt.sampling.sampling_params import SamplingParams
from sglang.test.test_utils import CustomTestCase
class TestSchedulePolicy(unittest.TestCase):
class TestSchedulePolicy(CustomTestCase):
def setUp(self):
self.tree_cache = RadixCache(None, None, False)
......
......@@ -2,9 +2,10 @@ import json
import unittest
from sglang.srt.server_args import prepare_server_args
from sglang.test.test_utils import CustomTestCase
class TestPrepareServerArgs(unittest.TestCase):
class TestPrepareServerArgs(CustomTestCase):
def test_prepare_server_args(self):
server_args = prepare_server_args(
[
......
......@@ -19,6 +19,7 @@ from sglang.test.test_utils import (
DEFAULT_SMALL_MODEL_NAME_FOR_TEST,
DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH,
DEFAULT_URL_FOR_TEST,
CustomTestCase,
popen_launch_server,
)
......@@ -27,7 +28,7 @@ def remove_prefix(text: str, prefix: str) -> str:
return text[len(prefix) :] if text.startswith(prefix) else text
class TestSessionControl(unittest.TestCase):
class TestSessionControl(CustomTestCase):
@classmethod
def setUpClass(cls):
cls.model = DEFAULT_SMALL_MODEL_NAME_FOR_TEST
......@@ -560,7 +561,7 @@ class TestSessionControl(unittest.TestCase):
)
class TestSessionControlVision(unittest.TestCase):
class TestSessionControlVision(CustomTestCase):
@classmethod
def setUpClass(cls):
cls.model = "lmms-lab/llava-onevision-qwen2-7b-ov"
......
......@@ -19,11 +19,12 @@ from sglang.test.test_utils import (
DEFAULT_SMALL_VLM_MODEL_NAME,
DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH,
DEFAULT_URL_FOR_TEST,
CustomTestCase,
popen_launch_server,
)
class TestSkipTokenizerInit(unittest.TestCase):
class TestSkipTokenizerInit(CustomTestCase):
@classmethod
def setUpClass(cls):
cls.model = DEFAULT_SMALL_MODEL_NAME_FOR_TEST
......
......@@ -20,12 +20,13 @@ from sglang.test.test_utils import (
DEFAULT_SMALL_MODEL_NAME_FOR_TEST,
DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH,
DEFAULT_URL_FOR_TEST,
CustomTestCase,
popen_launch_server,
run_logprob_check,
)
class TestSRTEndpoint(unittest.TestCase):
class TestSRTEndpoint(CustomTestCase):
@classmethod
def setUpClass(cls):
cls.model = DEFAULT_SMALL_MODEL_NAME_FOR_TEST
......
......@@ -18,10 +18,11 @@ from sglang.test.few_shot_gsm8k_engine import run_eval
from sglang.test.test_utils import (
DEFAULT_SMALL_EMBEDDING_MODEL_NAME_FOR_TEST,
DEFAULT_SMALL_MODEL_NAME_FOR_TEST,
CustomTestCase,
)
class TestSRTEngine(unittest.TestCase):
class TestSRTEngine(CustomTestCase):
def test_1_engine_runtime_consistency(self):
prompt = "Today is a sunny day and I like"
......
import unittest
import sglang as sgl
from sglang.test.test_utils import DEFAULT_SMALL_MODEL_NAME_FOR_TEST
from sglang.test.test_utils import DEFAULT_SMALL_MODEL_NAME_FOR_TEST, CustomTestCase
class TestSRTEngineWithQuantArgs(unittest.TestCase):
class TestSRTEngineWithQuantArgs(CustomTestCase):
def test_1_quantization_args(self):
......
......@@ -10,11 +10,12 @@ from sglang.test.test_utils import (
DEFAULT_MODEL_NAME_FOR_TEST,
DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH,
DEFAULT_URL_FOR_TEST,
CustomTestCase,
popen_launch_server,
)
class TestTorchCompile(unittest.TestCase):
class TestTorchCompile(CustomTestCase):
@classmethod
def setUpClass(cls):
cls.model = DEFAULT_MODEL_NAME_FOR_TEST
......
......@@ -10,11 +10,12 @@ from sglang.test.test_utils import (
DEFAULT_SMALL_MOE_MODEL_NAME_FOR_TEST,
DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH,
DEFAULT_URL_FOR_TEST,
CustomTestCase,
popen_launch_server,
)
class TestTorchCompileMoe(unittest.TestCase):
class TestTorchCompileMoe(CustomTestCase):
@classmethod
def setUpClass(cls):
cls.model = DEFAULT_SMALL_MOE_MODEL_NAME_FOR_TEST
......
......@@ -12,13 +12,14 @@ from sglang.test.test_utils import (
DEFAULT_MODEL_NAME_FOR_TEST,
DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH,
DEFAULT_URL_FOR_TEST,
CustomTestCase,
is_in_ci,
popen_launch_server,
run_bench_one_batch,
)
class TestTorchNativeAttnBackend(unittest.TestCase):
class TestTorchNativeAttnBackend(CustomTestCase):
def test_latency(self):
output_throughput = run_bench_one_batch(
DEFAULT_MODEL_NAME_FOR_TEST,
......
import unittest
from sglang.test.test_utils import is_in_ci, run_bench_one_batch
from sglang.test.test_utils import CustomTestCase, is_in_ci, run_bench_one_batch
class TestTorchTP(unittest.TestCase):
class TestTorchTP(CustomTestCase):
def test_torch_native_llama(self):
output_throughput = run_bench_one_batch(
"meta-llama/Meta-Llama-3-8B",
......
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment