Unverified Commit 15ddd843 authored by fzyzcjy, committed by GitHub

Add retry for flaky tests in CI (#4755)

parent 52029bd1
@@ -33,7 +33,7 @@ jobs:
         pip install -r docs/requirements.txt
         apt-get update
         apt-get install -y pandoc
-        apt-get update && apt-get install -y parallel
+        apt-get update && apt-get install -y parallel retry
     - name: Setup Jupyter Kernel
       run: |
...
@@ -23,7 +23,8 @@ compile:
 	parallel -0 -j3 --halt soon,fail=1 ' \
 		NB_NAME=$$(basename {}); \
 		START_TIME=$$(date +%s); \
-		jupyter nbconvert --to notebook --execute --inplace "{}" \
+		retry --delay=0 --times=3 -- \
+		jupyter nbconvert --to notebook --execute --inplace "{}" \
 			--ExecutePreprocessor.timeout=600 \
 			--ExecutePreprocessor.kernel_name=python3; \
 		RET_CODE=$$?; \
...
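For readers who do not have the `retry` utility installed: the wrapper runs the command and, on a non-zero exit, reruns it, up to three attempts with no pause between them if I read `--times=3 --delay=0` correctly. Below is a minimal Python sketch of that behavior; `run_with_retry` and the example notebook path are illustrative names, not part of this commit.

# Hedged sketch of what `retry --delay=0 --times=3 -- CMD` does above.
# `run_with_retry` is an illustrative name, not part of this commit.
import subprocess
import time


def run_with_retry(cmd, times=3, delay=0.0):
    """Run `cmd`, retrying on non-zero exit, for at most `times` attempts."""
    for attempt in range(1, times + 1):
        result = subprocess.run(cmd)
        if result.returncode == 0:
            return 0
        if attempt < times:
            time.sleep(delay)  # --delay=0: retry immediately
    return result.returncode


# Example, mirroring the Makefile command (hypothetical notebook path):
# run_with_retry(["jupyter", "nbconvert", "--to", "notebook",
#                 "--execute", "--inplace", "docs/example.ipynb"])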
@@ -4,9 +4,10 @@ import unittest
 import torch

 from sglang.srt.layers.activation import GeluAndMul
+from sglang.test.test_utils import CustomTestCase


-class TestGeluAndMul(unittest.TestCase):
+class TestGeluAndMul(CustomTestCase):
     DTYPES = [torch.half, torch.bfloat16]
     NUM_TOKENS = [7, 83, 2048]
     D = [512, 4096, 5120, 13824]
...
@@ -11,6 +11,7 @@ from sglang.srt.layers.quantization.fp8_kernel import (
     static_quant_fp8,
     w8a8_block_fp8_matmul,
 )
+from sglang.test.test_utils import CustomTestCase

 _is_cuda = torch.cuda.is_available() and torch.version.cuda
@@ -44,7 +45,7 @@ def native_per_token_group_quant_fp8(
     return x_q, x_s


-class TestPerTokenGroupQuantFP8(unittest.TestCase):
+class TestPerTokenGroupQuantFP8(CustomTestCase):
     DTYPES = [torch.half, torch.bfloat16, torch.float32]
     NUM_TOKENS = [7, 83, 2048]
     D = [512, 4096, 5120, 13824]
@@ -111,7 +112,7 @@ def native_static_quant_fp8(x, x_s, dtype=torch.float8_e4m3fn):
     return x_q, x_s


-class TestStaticQuantFP8(unittest.TestCase):
+class TestStaticQuantFP8(CustomTestCase):
     DTYPES = [torch.half, torch.bfloat16, torch.float32]
     NUM_TOKENS = [7, 83, 2048]
     D = [512, 4096, 5120, 13824]
@@ -210,7 +211,7 @@ def native_w8a8_block_fp8_matmul(A, B, As, Bs, block_size, output_dtype=torch.fl
     return C


-class TestW8A8BlockFP8Matmul(unittest.TestCase):
+class TestW8A8BlockFP8Matmul(CustomTestCase):
     if not _is_cuda:
         OUT_DTYPES = [torch.float32, torch.half, torch.bfloat16]
@@ -331,7 +332,7 @@ def torch_w8a8_block_fp8_moe(a, w1, w2, w1_s, w2_s, score, topk, block_shape):
     ).sum(dim=1)


-class TestW8A8BlockFP8FusedMoE(unittest.TestCase):
+class TestW8A8BlockFP8FusedMoE(CustomTestCase):
     DTYPES = [torch.float32, torch.half, torch.bfloat16]
     M = [1, 33, 64, 222, 1024 * 128]
     N = [128, 1024, 2048]
...
@@ -13,6 +13,7 @@ from sglang.srt.layers.moe.ep_moe.kernels import (
     silu_and_mul_triton_kernel,
 )
 from sglang.srt.layers.moe.topk import select_experts
+from sglang.test.test_utils import CustomTestCase

 # For test
@@ -232,7 +233,7 @@ def block_dequant(
     return x_dq_block


-class TestW8A8BlockFP8EPMoE(unittest.TestCase):
+class TestW8A8BlockFP8EPMoE(CustomTestCase):
     DTYPES = [torch.half, torch.bfloat16]
     M = [1, 222, 1024, 2048]
     N = [128, 1024, 2048]
...
@@ -3,9 +3,10 @@ import unittest
 import torch

 from sglang.srt.utils import DynamicGradMode
+from sglang.test.test_utils import CustomTestCase


-class TestDynamicGradMode(unittest.TestCase):
+class TestDynamicGradMode(CustomTestCase):
     def test_inference(self):
         # Test inference_mode
         DynamicGradMode.set_inference_mode(True)
...
@@ -4,9 +4,10 @@ import unittest
 import torch

 from sglang.srt.layers.layernorm import GemmaRMSNorm, RMSNorm
+from sglang.test.test_utils import CustomTestCase


-class TestRMSNorm(unittest.TestCase):
+class TestRMSNorm(CustomTestCase):
     DTYPES = [torch.half, torch.bfloat16]
     NUM_TOKENS = [7, 83, 4096]
     HIDDEN_SIZES = [768, 769, 770, 771, 5120, 5124, 5125, 5126, 8192, 8199]
@@ -56,7 +57,7 @@ class TestRMSNorm(unittest.TestCase):
         self._run_rms_norm_test(*params)


-class TestGemmaRMSNorm(unittest.TestCase):
+class TestGemmaRMSNorm(CustomTestCase):
     DTYPES = [torch.half, torch.bfloat16]
     NUM_TOKENS = [7, 83, 4096]
     HIDDEN_SIZES = [768, 769, 770, 771, 5120, 5124, 5125, 5126, 8192, 8199]
...
@@ -8,6 +8,7 @@ import random
 import subprocess
 import threading
 import time
+import traceback
 import unittest
 from concurrent.futures import ThreadPoolExecutor
 from dataclasses import dataclass
@@ -998,3 +999,30 @@ def run_logprob_check(self: unittest.TestCase, arg: Tuple):
             rank += 1
         else:
             raise
+
+
+class CustomTestCase(unittest.TestCase):
+    def _callTestMethod(self, method):
+        _retry_execution(
+            lambda: super(CustomTestCase, self)._callTestMethod(method),
+            max_retry=_get_max_retry(),
+        )
+
+
+def _get_max_retry():
+    return int(os.environ.get("SGLANG_TEST_MAX_RETRY", "2" if is_in_ci() else "0"))
+
+
+def _retry_execution(fn, max_retry: int):
+    if max_retry == 0:
+        fn()
+        return
+
+    try:
+        fn()
+    except Exception as e:
+        print(
+            f"retry_execution failed once and will retry. This may be an error or a flaky test. Error: {e}"
+        )
+        traceback.print_exc()
+        _retry_execution(fn, max_retry=max_retry - 1)
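The overridden hook, `_callTestMethod`, is what `unittest.TestCase.run()` uses in CPython 3.8+ to invoke the test body between `setUp` and `tearDown`, so only the test method itself is rerun on failure; fixtures are not repeated. A hedged usage sketch follows: the test class and method names are hypothetical, and only `CustomTestCase` and `SGLANG_TEST_MAX_RETRY` come from this commit.

# Hedged usage sketch; TestFlakyKernel / test_usually_passes are
# hypothetical names for illustration, not part of this commit.
import os
import random
import unittest

from sglang.test.test_utils import CustomTestCase


class TestFlakyKernel(CustomTestCase):
    def test_usually_passes(self):
        # Simulates a nondeterministic check; with SGLANG_TEST_MAX_RETRY > 0,
        # an occasional failure is retried instead of failing the whole run.
        self.assertLess(random.random(), 0.99)


if __name__ == "__main__":
    # Per the diff, the default is "2" retries when is_in_ci() and "0"
    # locally; the env var can override it before the tests execute.
    os.environ.setdefault("SGLANG_TEST_MAX_RETRY", "2")
    unittest.main()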
@@ -3,9 +3,10 @@ import unittest

 from sglang import Anthropic, set_default_backend
 from sglang.test.test_programs import test_mt_bench, test_stream
+from sglang.test.test_utils import CustomTestCase


-class TestAnthropicBackend(unittest.TestCase):
+class TestAnthropicBackend(CustomTestCase):
     backend = None

     @classmethod
...
 import unittest

 import sglang as sgl
-from sglang.test.test_utils import DEFAULT_MODEL_NAME_FOR_TEST
+from sglang.test.test_utils import DEFAULT_MODEL_NAME_FOR_TEST, CustomTestCase


-class TestBind(unittest.TestCase):
+class TestBind(CustomTestCase):
     backend = None

     @classmethod
...
@@ -7,6 +7,7 @@ from sglang.lang.choices import (
     token_length_normalized,
     unconditional_likelihood_normalized,
 )
+from sglang.test.test_utils import CustomTestCase

 MOCK_CHOICES_INPUT_DATA = {
     "choices": [
@@ -51,7 +52,7 @@ MOCK_CHOICES_INPUT_DATA = {
 }


-class TestChoices(unittest.TestCase):
+class TestChoices(CustomTestCase):
     def test_token_length_normalized(self):
         """Confirm 'antidisestablishmentarianism' is selected due to high confidences for
...
@@ -3,9 +3,10 @@ import unittest

 from sglang import LiteLLM, set_default_backend
 from sglang.test.test_programs import test_mt_bench, test_stream
+from sglang.test.test_utils import CustomTestCase


-class TestAnthropicBackend(unittest.TestCase):
+class TestAnthropicBackend(CustomTestCase):
     chat_backend = None

     @classmethod
...
@@ -17,9 +17,10 @@ from sglang.test.test_programs import (
     test_stream,
     test_tool_use,
 )
+from sglang.test.test_utils import CustomTestCase


-class TestOpenAIBackend(unittest.TestCase):
+class TestOpenAIBackend(CustomTestCase):
     instruct_backend = None
     chat_backend = None
     chat_vision_backend = None
...
@@ -22,10 +22,10 @@ from sglang.test.test_programs import (
     test_stream,
     test_tool_use,
 )
-from sglang.test.test_utils import DEFAULT_MODEL_NAME_FOR_TEST
+from sglang.test.test_utils import DEFAULT_MODEL_NAME_FOR_TEST, CustomTestCase


-class TestSRTBackend(unittest.TestCase):
+class TestSRTBackend(CustomTestCase):
     backend = None

     @classmethod
...
@@ -3,9 +3,10 @@ import unittest

 import sglang as sgl
 from sglang.lang.backend.base_backend import BaseBackend
 from sglang.lang.chat_template import get_chat_template
+from sglang.test.test_utils import CustomTestCase


-class TestTracing(unittest.TestCase):
+class TestTracing(CustomTestCase):
     def test_few_shot_qa(self):
         @sgl.function
         def few_shot_qa(s, question):
...
@@ -10,9 +10,10 @@ from sglang.test.test_programs import (
     test_parallel_encoding,
     test_stream,
 )
+from sglang.test.test_utils import CustomTestCase


-class TestVertexAIBackend(unittest.TestCase):
+class TestVertexAIBackend(CustomTestCase):
     backend = None

     @classmethod
...
@@ -18,6 +18,7 @@ import unittest

 import torch

 from sglang.test.runners import HFRunner, SRTRunner
+from sglang.test.test_utils import CustomTestCase

 LORA_SETS = [
     # {
@@ -70,7 +71,7 @@ What do you know about llamas?
 # PROMPTS.append(sample[0]["content"][:2000])


-class TestLoRA(unittest.TestCase):
+class TestLoRA(CustomTestCase):
     def inference(self, prompts, lora_set, tp_size, torch_dtype, max_new_tokens):
         print("=================== testing inference =======================")
...
@@ -21,7 +21,7 @@ import torch
 from utils import BACKENDS, TORCH_DTYPES, LoRAAdaptor, LoRAModelCase

 from sglang.test.runners import HFRunner, SRTRunner
-from sglang.test.test_utils import calculate_rouge_l, is_in_ci
+from sglang.test.test_utils import CustomTestCase, calculate_rouge_l, is_in_ci

 CI_LORA_MODELS = [
     LoRAModelCase(
@@ -67,7 +67,7 @@ PROMPTS = [
 ]


-class TestLoRABackend(unittest.TestCase):
+class TestLoRABackend(CustomTestCase):
     def run_backend(
         self,
         prompt: str,
...
@@ -21,7 +21,7 @@ import torch
 from utils import TORCH_DTYPES, LoRAAdaptor, LoRAModelCase

 from sglang.test.runners import HFRunner, SRTRunner
-from sglang.test.test_utils import calculate_rouge_l, is_in_ci
+from sglang.test.test_utils import CustomTestCase, calculate_rouge_l, is_in_ci

 CI_LORA_MODELS = [
     LoRAModelCase(
@@ -69,7 +69,7 @@ PROMPTS = [
 BACKEND = "triton"


-class TestLoRATP(unittest.TestCase):
+class TestLoRATP(CustomTestCase):
     def run_tp(
         self,
         prompt: str,
...
@@ -19,7 +19,7 @@ from typing import List

 import torch
 from utils import BACKENDS, TORCH_DTYPES, LoRAAdaptor, LoRAModelCase

-from sglang.test.test_utils import is_in_ci
+from sglang.test.test_utils import CustomTestCase, is_in_ci

 MULTI_LORA_MODELS = [
     LoRAModelCase(
@@ -51,7 +51,7 @@ PROMPTS = [
 ]


-class TestMultiLoRABackend(unittest.TestCase):
+class TestMultiLoRABackend(CustomTestCase):
     def run_backend_batch(
         self,
         prompts: List[str],
...