Unverified Commit 15ddd843 authored by fzyzcjy's avatar fzyzcjy Committed by GitHub
Browse files

Add retry for flaky tests in CI (#4755)

parent 52029bd1
......@@ -9,11 +9,12 @@ from sglang.test.test_utils import (
DEFAULT_MODEL_NAME_FOR_TEST,
DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH,
DEFAULT_URL_FOR_TEST,
CustomTestCase,
popen_launch_server,
)
class TestTorchAO(unittest.TestCase):
class TestTorchAO(CustomTestCase):
@classmethod
def setUpClass(cls):
cls.model = DEFAULT_MODEL_NAME_FOR_TEST
......
......@@ -12,13 +12,14 @@ from sglang.test.test_utils import (
DEFAULT_MODEL_NAME_FOR_TEST,
DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH,
DEFAULT_URL_FOR_TEST,
CustomTestCase,
is_in_ci,
popen_launch_server,
run_bench_one_batch,
)
class TestTritonAttnBackend(unittest.TestCase):
class TestTritonAttnBackend(CustomTestCase):
def test_latency(self):
output_throughput = run_bench_one_batch(
DEFAULT_MODEL_NAME_FOR_TEST,
......
......@@ -15,9 +15,10 @@ from sglang.srt.layers.attention.triton_ops.extend_attention import (
from sglang.srt.layers.attention.triton_ops.prefill_attention import (
context_attention_fwd,
)
from sglang.test.test_utils import CustomTestCase
class TestTritonAttention(unittest.TestCase):
class TestTritonAttention(CustomTestCase):
def _set_all_seeds(self, seed):
"""Set all random seeds for reproducibility."""
......
......@@ -10,9 +10,10 @@ from sglang.srt.layers.attention.triton_ops.rocm_mla_decode_rope import (
decode_attention_fwd_grouped_rope,
)
from sglang.srt.layers.rotary_embedding import DeepseekScalingRotaryEmbedding
from sglang.test.test_utils import CustomTestCase
class TestTritonAttentionMLA(unittest.TestCase):
class TestTritonAttentionMLA(CustomTestCase):
def _set_all_seeds(self, seed):
"""Set all random seeds for reproducibility."""
......
......@@ -10,6 +10,7 @@ from sglang.test.test_utils import (
DEFAULT_SMALL_MODEL_NAME_FOR_TEST,
DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH,
DEFAULT_URL_FOR_TEST,
CustomTestCase,
is_in_ci,
popen_launch_server,
)
......@@ -18,7 +19,7 @@ from sglang.test.test_utils import (
###############################################################################
# Engine Mode Tests (Single-configuration)
###############################################################################
class TestEngineUpdateWeightsFromDisk(unittest.TestCase):
class TestEngineUpdateWeightsFromDisk(CustomTestCase):
def setUp(self):
self.model = DEFAULT_SMALL_MODEL_NAME_FOR_TEST
# Initialize the engine in offline (direct) mode.
......@@ -70,7 +71,7 @@ class TestEngineUpdateWeightsFromDisk(unittest.TestCase):
###############################################################################
# HTTP Server Mode Tests (Single-configuration)
###############################################################################
class TestServerUpdateWeightsFromDisk(unittest.TestCase):
class TestServerUpdateWeightsFromDisk(CustomTestCase):
@classmethod
def setUpClass(cls):
cls.model = DEFAULT_SMALL_MODEL_NAME_FOR_TEST
......@@ -159,7 +160,7 @@ class TestServerUpdateWeightsFromDisk(unittest.TestCase):
# - In a non-CI environment: test both Engine and Server modes, and enumerate all combinations
# with tp and dp ranging from 1 to 2.
###############################################################################
class TestUpdateWeightsFromDiskParameterized(unittest.TestCase):
class TestUpdateWeightsFromDiskParameterized(CustomTestCase):
def run_common_test(self, mode, tp, dp):
"""
Common test procedure for update_weights_from_disk.
......
......@@ -33,6 +33,7 @@ from sglang.test.test_utils import (
DEFAULT_SMALL_MODEL_NAME_FOR_TEST,
DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH,
DEFAULT_URL_FOR_TEST,
CustomTestCase,
is_in_ci,
popen_launch_server,
)
......@@ -523,7 +524,7 @@ def test_update_weights_from_distributed(
torch.cuda.empty_cache()
class TestUpdateWeightsFromDistributed(unittest.TestCase):
class TestUpdateWeightsFromDistributed(CustomTestCase):
def test_update_weights_from_distributed(self):
......
......@@ -5,7 +5,7 @@ import unittest
import torch
import sglang as sgl
from sglang.test.test_utils import DEFAULT_SMALL_MODEL_NAME_FOR_TEST
from sglang.test.test_utils import DEFAULT_SMALL_MODEL_NAME_FOR_TEST, CustomTestCase
def test_update_weights_from_tensor(tp_size):
......@@ -40,7 +40,7 @@ def test_update_weights_from_tensor(tp_size):
), f"Memory leak detected: {memory_after - memory_before} bytes"
class TestUpdateWeightsFromTensor(unittest.TestCase):
class TestUpdateWeightsFromTensor(CustomTestCase):
def test_update_weights_from_tensor(self):
tp_sizes = [1, 2]
for tp_size in tp_sizes:
......
......@@ -27,7 +27,7 @@ from sglang.test.runners import (
check_close_model_outputs,
get_dtype_str,
)
from sglang.test.test_utils import is_in_ci
from sglang.test.test_utils import CustomTestCase, is_in_ci
_MAX_NEW_TOKENS = 8
_PROMPTS = ["1+1=2, 1+2=3, 1+3=4, 1+4=5, 1+5=", "1*1=1, 1*2=2, 1*3=3, 1*4=4, 1*5="]
......@@ -73,7 +73,7 @@ ALL_OTHER_MODELS = [
]
class TestVerlEngine(unittest.TestCase):
class TestVerlEngine(CustomTestCase):
@classmethod
def setUpClass(cls):
multiprocessing.set_start_method("spawn")
......
......@@ -11,11 +11,12 @@ from sglang.test.test_utils import (
DEFAULT_SMALL_MODEL_NAME_FOR_TEST,
DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH,
DEFAULT_URL_FOR_TEST,
CustomTestCase,
popen_launch_server,
)
class TestVertexEndpoint(unittest.TestCase):
class TestVertexEndpoint(CustomTestCase):
@classmethod
def setUpClass(cls):
cls.model = DEFAULT_SMALL_MODEL_NAME_FOR_TEST
......
......@@ -18,11 +18,12 @@ from sglang.srt.utils import kill_process_tree
from sglang.test.test_utils import (
DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH,
DEFAULT_URL_FOR_TEST,
CustomTestCase,
popen_launch_server,
)
class TestVisionChunkedPrefill(unittest.TestCase):
class TestVisionChunkedPrefill(CustomTestCase):
def prepare_video_messages(self, video_path, max_frames_num=8):
# We import decord here to avoid a strange Segmentation fault (core dumped) issue.
# The following import order will cause Segmentation fault.
......
......@@ -20,6 +20,7 @@ from sglang.srt.utils import kill_process_tree
from sglang.test.test_utils import (
DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH,
DEFAULT_URL_FOR_TEST,
CustomTestCase,
popen_launch_server,
)
......@@ -35,7 +36,7 @@ AUDIO_TRUMP_SPEECH_URL = "https://raw.githubusercontent.com/sgl-project/sgl-test
AUDIO_BIRD_SONG_URL = "https://raw.githubusercontent.com/sgl-project/sgl-test-files/refs/heads/main/audios/bird_song.mp3"
class TestOpenAIVisionServer(unittest.TestCase):
class TestOpenAIVisionServer(CustomTestCase):
@classmethod
def setUpClass(cls):
cls.model = "lmms-lab/llava-onevision-qwen2-0.5b-ov"
......@@ -507,7 +508,7 @@ class TestQwen2_5_VLServer(TestOpenAIVisionServer):
cls.base_url += "/v1"
class TestVLMContextLengthIssue(unittest.TestCase):
class TestVLMContextLengthIssue(CustomTestCase):
@classmethod
def setUpClass(cls):
cls.model = "Qwen/Qwen2-VL-7B-Instruct"
......
......@@ -9,11 +9,12 @@ from sglang.test.few_shot_gsm8k import run_eval
from sglang.test.test_utils import (
DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH,
DEFAULT_URL_FOR_TEST,
CustomTestCase,
popen_launch_server,
)
class TestW8A8(unittest.TestCase):
class TestW8A8(CustomTestCase):
@classmethod
def setUpClass(cls):
cls.model = "neuralmagic/Meta-Llama-3-8B-Instruct-quantized.w8a8"
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment