Unverified commit 15ddd843, authored by fzyzcjy and committed by GitHub
Browse files

Add retry for flaky tests in CI (#4755)

parent 52029bd1
...@@ -9,11 +9,12 @@ from sglang.test.test_utils import ( ...@@ -9,11 +9,12 @@ from sglang.test.test_utils import (
DEFAULT_MODEL_NAME_FOR_TEST, DEFAULT_MODEL_NAME_FOR_TEST,
DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH, DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH,
DEFAULT_URL_FOR_TEST, DEFAULT_URL_FOR_TEST,
CustomTestCase,
popen_launch_server, popen_launch_server,
) )
class TestTorchAO(unittest.TestCase): class TestTorchAO(CustomTestCase):
@classmethod @classmethod
def setUpClass(cls): def setUpClass(cls):
cls.model = DEFAULT_MODEL_NAME_FOR_TEST cls.model = DEFAULT_MODEL_NAME_FOR_TEST
......
...@@ -12,13 +12,14 @@ from sglang.test.test_utils import ( ...@@ -12,13 +12,14 @@ from sglang.test.test_utils import (
DEFAULT_MODEL_NAME_FOR_TEST, DEFAULT_MODEL_NAME_FOR_TEST,
DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH, DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH,
DEFAULT_URL_FOR_TEST, DEFAULT_URL_FOR_TEST,
CustomTestCase,
is_in_ci, is_in_ci,
popen_launch_server, popen_launch_server,
run_bench_one_batch, run_bench_one_batch,
) )
class TestTritonAttnBackend(unittest.TestCase): class TestTritonAttnBackend(CustomTestCase):
def test_latency(self): def test_latency(self):
output_throughput = run_bench_one_batch( output_throughput = run_bench_one_batch(
DEFAULT_MODEL_NAME_FOR_TEST, DEFAULT_MODEL_NAME_FOR_TEST,
......
...@@ -15,9 +15,10 @@ from sglang.srt.layers.attention.triton_ops.extend_attention import ( ...@@ -15,9 +15,10 @@ from sglang.srt.layers.attention.triton_ops.extend_attention import (
from sglang.srt.layers.attention.triton_ops.prefill_attention import ( from sglang.srt.layers.attention.triton_ops.prefill_attention import (
context_attention_fwd, context_attention_fwd,
) )
from sglang.test.test_utils import CustomTestCase
class TestTritonAttention(unittest.TestCase): class TestTritonAttention(CustomTestCase):
def _set_all_seeds(self, seed): def _set_all_seeds(self, seed):
"""Set all random seeds for reproducibility.""" """Set all random seeds for reproducibility."""
......
...@@ -10,9 +10,10 @@ from sglang.srt.layers.attention.triton_ops.rocm_mla_decode_rope import ( ...@@ -10,9 +10,10 @@ from sglang.srt.layers.attention.triton_ops.rocm_mla_decode_rope import (
decode_attention_fwd_grouped_rope, decode_attention_fwd_grouped_rope,
) )
from sglang.srt.layers.rotary_embedding import DeepseekScalingRotaryEmbedding from sglang.srt.layers.rotary_embedding import DeepseekScalingRotaryEmbedding
from sglang.test.test_utils import CustomTestCase
class TestTritonAttentionMLA(unittest.TestCase): class TestTritonAttentionMLA(CustomTestCase):
def _set_all_seeds(self, seed): def _set_all_seeds(self, seed):
"""Set all random seeds for reproducibility.""" """Set all random seeds for reproducibility."""
......
...@@ -10,6 +10,7 @@ from sglang.test.test_utils import ( ...@@ -10,6 +10,7 @@ from sglang.test.test_utils import (
DEFAULT_SMALL_MODEL_NAME_FOR_TEST, DEFAULT_SMALL_MODEL_NAME_FOR_TEST,
DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH, DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH,
DEFAULT_URL_FOR_TEST, DEFAULT_URL_FOR_TEST,
CustomTestCase,
is_in_ci, is_in_ci,
popen_launch_server, popen_launch_server,
) )
...@@ -18,7 +19,7 @@ from sglang.test.test_utils import ( ...@@ -18,7 +19,7 @@ from sglang.test.test_utils import (
############################################################################### ###############################################################################
# Engine Mode Tests (Single-configuration) # Engine Mode Tests (Single-configuration)
############################################################################### ###############################################################################
class TestEngineUpdateWeightsFromDisk(unittest.TestCase): class TestEngineUpdateWeightsFromDisk(CustomTestCase):
def setUp(self): def setUp(self):
self.model = DEFAULT_SMALL_MODEL_NAME_FOR_TEST self.model = DEFAULT_SMALL_MODEL_NAME_FOR_TEST
# Initialize the engine in offline (direct) mode. # Initialize the engine in offline (direct) mode.
...@@ -70,7 +71,7 @@ class TestEngineUpdateWeightsFromDisk(unittest.TestCase): ...@@ -70,7 +71,7 @@ class TestEngineUpdateWeightsFromDisk(unittest.TestCase):
############################################################################### ###############################################################################
# HTTP Server Mode Tests (Single-configuration) # HTTP Server Mode Tests (Single-configuration)
############################################################################### ###############################################################################
class TestServerUpdateWeightsFromDisk(unittest.TestCase): class TestServerUpdateWeightsFromDisk(CustomTestCase):
@classmethod @classmethod
def setUpClass(cls): def setUpClass(cls):
cls.model = DEFAULT_SMALL_MODEL_NAME_FOR_TEST cls.model = DEFAULT_SMALL_MODEL_NAME_FOR_TEST
...@@ -159,7 +160,7 @@ class TestServerUpdateWeightsFromDisk(unittest.TestCase): ...@@ -159,7 +160,7 @@ class TestServerUpdateWeightsFromDisk(unittest.TestCase):
# - In a non-CI environment: test both Engine and Server modes, and enumerate all combinations # - In a non-CI environment: test both Engine and Server modes, and enumerate all combinations
# with tp and dp ranging from 1 to 2. # with tp and dp ranging from 1 to 2.
############################################################################### ###############################################################################
class TestUpdateWeightsFromDiskParameterized(unittest.TestCase): class TestUpdateWeightsFromDiskParameterized(CustomTestCase):
def run_common_test(self, mode, tp, dp): def run_common_test(self, mode, tp, dp):
""" """
Common test procedure for update_weights_from_disk. Common test procedure for update_weights_from_disk.
......
...@@ -33,6 +33,7 @@ from sglang.test.test_utils import ( ...@@ -33,6 +33,7 @@ from sglang.test.test_utils import (
DEFAULT_SMALL_MODEL_NAME_FOR_TEST, DEFAULT_SMALL_MODEL_NAME_FOR_TEST,
DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH, DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH,
DEFAULT_URL_FOR_TEST, DEFAULT_URL_FOR_TEST,
CustomTestCase,
is_in_ci, is_in_ci,
popen_launch_server, popen_launch_server,
) )
...@@ -523,7 +524,7 @@ def test_update_weights_from_distributed( ...@@ -523,7 +524,7 @@ def test_update_weights_from_distributed(
torch.cuda.empty_cache() torch.cuda.empty_cache()
class TestUpdateWeightsFromDistributed(unittest.TestCase): class TestUpdateWeightsFromDistributed(CustomTestCase):
def test_update_weights_from_distributed(self): def test_update_weights_from_distributed(self):
......
...@@ -5,7 +5,7 @@ import unittest ...@@ -5,7 +5,7 @@ import unittest
import torch import torch
import sglang as sgl import sglang as sgl
from sglang.test.test_utils import DEFAULT_SMALL_MODEL_NAME_FOR_TEST from sglang.test.test_utils import DEFAULT_SMALL_MODEL_NAME_FOR_TEST, CustomTestCase
def test_update_weights_from_tensor(tp_size): def test_update_weights_from_tensor(tp_size):
...@@ -40,7 +40,7 @@ def test_update_weights_from_tensor(tp_size): ...@@ -40,7 +40,7 @@ def test_update_weights_from_tensor(tp_size):
), f"Memory leak detected: {memory_after - memory_before} bytes" ), f"Memory leak detected: {memory_after - memory_before} bytes"
class TestUpdateWeightsFromTensor(unittest.TestCase): class TestUpdateWeightsFromTensor(CustomTestCase):
def test_update_weights_from_tensor(self): def test_update_weights_from_tensor(self):
tp_sizes = [1, 2] tp_sizes = [1, 2]
for tp_size in tp_sizes: for tp_size in tp_sizes:
......
...@@ -27,7 +27,7 @@ from sglang.test.runners import ( ...@@ -27,7 +27,7 @@ from sglang.test.runners import (
check_close_model_outputs, check_close_model_outputs,
get_dtype_str, get_dtype_str,
) )
from sglang.test.test_utils import is_in_ci from sglang.test.test_utils import CustomTestCase, is_in_ci
_MAX_NEW_TOKENS = 8 _MAX_NEW_TOKENS = 8
_PROMPTS = ["1+1=2, 1+2=3, 1+3=4, 1+4=5, 1+5=", "1*1=1, 1*2=2, 1*3=3, 1*4=4, 1*5="] _PROMPTS = ["1+1=2, 1+2=3, 1+3=4, 1+4=5, 1+5=", "1*1=1, 1*2=2, 1*3=3, 1*4=4, 1*5="]
...@@ -73,7 +73,7 @@ ALL_OTHER_MODELS = [ ...@@ -73,7 +73,7 @@ ALL_OTHER_MODELS = [
] ]
class TestVerlEngine(unittest.TestCase): class TestVerlEngine(CustomTestCase):
@classmethod @classmethod
def setUpClass(cls): def setUpClass(cls):
multiprocessing.set_start_method("spawn") multiprocessing.set_start_method("spawn")
......
...@@ -11,11 +11,12 @@ from sglang.test.test_utils import ( ...@@ -11,11 +11,12 @@ from sglang.test.test_utils import (
DEFAULT_SMALL_MODEL_NAME_FOR_TEST, DEFAULT_SMALL_MODEL_NAME_FOR_TEST,
DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH, DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH,
DEFAULT_URL_FOR_TEST, DEFAULT_URL_FOR_TEST,
CustomTestCase,
popen_launch_server, popen_launch_server,
) )
class TestVertexEndpoint(unittest.TestCase): class TestVertexEndpoint(CustomTestCase):
@classmethod @classmethod
def setUpClass(cls): def setUpClass(cls):
cls.model = DEFAULT_SMALL_MODEL_NAME_FOR_TEST cls.model = DEFAULT_SMALL_MODEL_NAME_FOR_TEST
......
...@@ -18,11 +18,12 @@ from sglang.srt.utils import kill_process_tree ...@@ -18,11 +18,12 @@ from sglang.srt.utils import kill_process_tree
from sglang.test.test_utils import ( from sglang.test.test_utils import (
DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH, DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH,
DEFAULT_URL_FOR_TEST, DEFAULT_URL_FOR_TEST,
CustomTestCase,
popen_launch_server, popen_launch_server,
) )
class TestVisionChunkedPrefill(unittest.TestCase): class TestVisionChunkedPrefill(CustomTestCase):
def prepare_video_messages(self, video_path, max_frames_num=8): def prepare_video_messages(self, video_path, max_frames_num=8):
# We import decord here to avoid a strange Segmentation fault (core dumped) issue. # We import decord here to avoid a strange Segmentation fault (core dumped) issue.
# The following import order will cause Segmentation fault. # The following import order will cause Segmentation fault.
......
...@@ -20,6 +20,7 @@ from sglang.srt.utils import kill_process_tree ...@@ -20,6 +20,7 @@ from sglang.srt.utils import kill_process_tree
from sglang.test.test_utils import ( from sglang.test.test_utils import (
DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH, DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH,
DEFAULT_URL_FOR_TEST, DEFAULT_URL_FOR_TEST,
CustomTestCase,
popen_launch_server, popen_launch_server,
) )
...@@ -35,7 +36,7 @@ AUDIO_TRUMP_SPEECH_URL = "https://raw.githubusercontent.com/sgl-project/sgl-test ...@@ -35,7 +36,7 @@ AUDIO_TRUMP_SPEECH_URL = "https://raw.githubusercontent.com/sgl-project/sgl-test
AUDIO_BIRD_SONG_URL = "https://raw.githubusercontent.com/sgl-project/sgl-test-files/refs/heads/main/audios/bird_song.mp3" AUDIO_BIRD_SONG_URL = "https://raw.githubusercontent.com/sgl-project/sgl-test-files/refs/heads/main/audios/bird_song.mp3"
class TestOpenAIVisionServer(unittest.TestCase): class TestOpenAIVisionServer(CustomTestCase):
@classmethod @classmethod
def setUpClass(cls): def setUpClass(cls):
cls.model = "lmms-lab/llava-onevision-qwen2-0.5b-ov" cls.model = "lmms-lab/llava-onevision-qwen2-0.5b-ov"
...@@ -507,7 +508,7 @@ class TestQwen2_5_VLServer(TestOpenAIVisionServer): ...@@ -507,7 +508,7 @@ class TestQwen2_5_VLServer(TestOpenAIVisionServer):
cls.base_url += "/v1" cls.base_url += "/v1"
class TestVLMContextLengthIssue(unittest.TestCase): class TestVLMContextLengthIssue(CustomTestCase):
@classmethod @classmethod
def setUpClass(cls): def setUpClass(cls):
cls.model = "Qwen/Qwen2-VL-7B-Instruct" cls.model = "Qwen/Qwen2-VL-7B-Instruct"
......
...@@ -9,11 +9,12 @@ from sglang.test.few_shot_gsm8k import run_eval ...@@ -9,11 +9,12 @@ from sglang.test.few_shot_gsm8k import run_eval
from sglang.test.test_utils import ( from sglang.test.test_utils import (
DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH, DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH,
DEFAULT_URL_FOR_TEST, DEFAULT_URL_FOR_TEST,
CustomTestCase,
popen_launch_server, popen_launch_server,
) )
class TestW8A8(unittest.TestCase): class TestW8A8(CustomTestCase):
@classmethod @classmethod
def setUpClass(cls): def setUpClass(cls):
cls.model = "neuralmagic/Meta-Llama-3-8B-Instruct-quantized.w8a8" cls.model = "neuralmagic/Meta-Llama-3-8B-Instruct-quantized.w8a8"
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment