Unverified Commit 15ddd843 authored by fzyzcjy's avatar fzyzcjy Committed by GitHub
Browse files

Add retry for flaky tests in CI (#4755)

parent 52029bd1
...@@ -33,7 +33,7 @@ jobs: ...@@ -33,7 +33,7 @@ jobs:
pip install -r docs/requirements.txt pip install -r docs/requirements.txt
apt-get update apt-get update
apt-get install -y pandoc apt-get install -y pandoc
apt-get update && apt-get install -y parallel apt-get update && apt-get install -y parallel retry
- name: Setup Jupyter Kernel - name: Setup Jupyter Kernel
run: | run: |
......
...@@ -23,7 +23,8 @@ compile: ...@@ -23,7 +23,8 @@ compile:
parallel -0 -j3 --halt soon,fail=1 ' \ parallel -0 -j3 --halt soon,fail=1 ' \
NB_NAME=$$(basename {}); \ NB_NAME=$$(basename {}); \
START_TIME=$$(date +%s); \ START_TIME=$$(date +%s); \
jupyter nbconvert --to notebook --execute --inplace "{}" \ retry --delay=0 --times=3 -- \
jupyter nbconvert --to notebook --execute --inplace "{}" \
--ExecutePreprocessor.timeout=600 \ --ExecutePreprocessor.timeout=600 \
--ExecutePreprocessor.kernel_name=python3; \ --ExecutePreprocessor.kernel_name=python3; \
RET_CODE=$$?; \ RET_CODE=$$?; \
......
...@@ -4,9 +4,10 @@ import unittest ...@@ -4,9 +4,10 @@ import unittest
import torch import torch
from sglang.srt.layers.activation import GeluAndMul from sglang.srt.layers.activation import GeluAndMul
from sglang.test.test_utils import CustomTestCase
class TestGeluAndMul(unittest.TestCase): class TestGeluAndMul(CustomTestCase):
DTYPES = [torch.half, torch.bfloat16] DTYPES = [torch.half, torch.bfloat16]
NUM_TOKENS = [7, 83, 2048] NUM_TOKENS = [7, 83, 2048]
D = [512, 4096, 5120, 13824] D = [512, 4096, 5120, 13824]
......
...@@ -11,6 +11,7 @@ from sglang.srt.layers.quantization.fp8_kernel import ( ...@@ -11,6 +11,7 @@ from sglang.srt.layers.quantization.fp8_kernel import (
static_quant_fp8, static_quant_fp8,
w8a8_block_fp8_matmul, w8a8_block_fp8_matmul,
) )
from sglang.test.test_utils import CustomTestCase
_is_cuda = torch.cuda.is_available() and torch.version.cuda _is_cuda = torch.cuda.is_available() and torch.version.cuda
...@@ -44,7 +45,7 @@ def native_per_token_group_quant_fp8( ...@@ -44,7 +45,7 @@ def native_per_token_group_quant_fp8(
return x_q, x_s return x_q, x_s
class TestPerTokenGroupQuantFP8(unittest.TestCase): class TestPerTokenGroupQuantFP8(CustomTestCase):
DTYPES = [torch.half, torch.bfloat16, torch.float32] DTYPES = [torch.half, torch.bfloat16, torch.float32]
NUM_TOKENS = [7, 83, 2048] NUM_TOKENS = [7, 83, 2048]
D = [512, 4096, 5120, 13824] D = [512, 4096, 5120, 13824]
...@@ -111,7 +112,7 @@ def native_static_quant_fp8(x, x_s, dtype=torch.float8_e4m3fn): ...@@ -111,7 +112,7 @@ def native_static_quant_fp8(x, x_s, dtype=torch.float8_e4m3fn):
return x_q, x_s return x_q, x_s
class TestStaticQuantFP8(unittest.TestCase): class TestStaticQuantFP8(CustomTestCase):
DTYPES = [torch.half, torch.bfloat16, torch.float32] DTYPES = [torch.half, torch.bfloat16, torch.float32]
NUM_TOKENS = [7, 83, 2048] NUM_TOKENS = [7, 83, 2048]
D = [512, 4096, 5120, 13824] D = [512, 4096, 5120, 13824]
...@@ -210,7 +211,7 @@ def native_w8a8_block_fp8_matmul(A, B, As, Bs, block_size, output_dtype=torch.fl ...@@ -210,7 +211,7 @@ def native_w8a8_block_fp8_matmul(A, B, As, Bs, block_size, output_dtype=torch.fl
return C return C
class TestW8A8BlockFP8Matmul(unittest.TestCase): class TestW8A8BlockFP8Matmul(CustomTestCase):
if not _is_cuda: if not _is_cuda:
OUT_DTYPES = [torch.float32, torch.half, torch.bfloat16] OUT_DTYPES = [torch.float32, torch.half, torch.bfloat16]
...@@ -331,7 +332,7 @@ def torch_w8a8_block_fp8_moe(a, w1, w2, w1_s, w2_s, score, topk, block_shape): ...@@ -331,7 +332,7 @@ def torch_w8a8_block_fp8_moe(a, w1, w2, w1_s, w2_s, score, topk, block_shape):
).sum(dim=1) ).sum(dim=1)
class TestW8A8BlockFP8FusedMoE(unittest.TestCase): class TestW8A8BlockFP8FusedMoE(CustomTestCase):
DTYPES = [torch.float32, torch.half, torch.bfloat16] DTYPES = [torch.float32, torch.half, torch.bfloat16]
M = [1, 33, 64, 222, 1024 * 128] M = [1, 33, 64, 222, 1024 * 128]
N = [128, 1024, 2048] N = [128, 1024, 2048]
......
...@@ -13,6 +13,7 @@ from sglang.srt.layers.moe.ep_moe.kernels import ( ...@@ -13,6 +13,7 @@ from sglang.srt.layers.moe.ep_moe.kernels import (
silu_and_mul_triton_kernel, silu_and_mul_triton_kernel,
) )
from sglang.srt.layers.moe.topk import select_experts from sglang.srt.layers.moe.topk import select_experts
from sglang.test.test_utils import CustomTestCase
# For test # For test
...@@ -232,7 +233,7 @@ def block_dequant( ...@@ -232,7 +233,7 @@ def block_dequant(
return x_dq_block return x_dq_block
class TestW8A8BlockFP8EPMoE(unittest.TestCase): class TestW8A8BlockFP8EPMoE(CustomTestCase):
DTYPES = [torch.half, torch.bfloat16] DTYPES = [torch.half, torch.bfloat16]
M = [1, 222, 1024, 2048] M = [1, 222, 1024, 2048]
N = [128, 1024, 2048] N = [128, 1024, 2048]
......
...@@ -3,9 +3,10 @@ import unittest ...@@ -3,9 +3,10 @@ import unittest
import torch import torch
from sglang.srt.utils import DynamicGradMode from sglang.srt.utils import DynamicGradMode
from sglang.test.test_utils import CustomTestCase
class TestDynamicGradMode(unittest.TestCase): class TestDynamicGradMode(CustomTestCase):
def test_inference(self): def test_inference(self):
# Test inference_mode # Test inference_mode
DynamicGradMode.set_inference_mode(True) DynamicGradMode.set_inference_mode(True)
......
...@@ -4,9 +4,10 @@ import unittest ...@@ -4,9 +4,10 @@ import unittest
import torch import torch
from sglang.srt.layers.layernorm import GemmaRMSNorm, RMSNorm from sglang.srt.layers.layernorm import GemmaRMSNorm, RMSNorm
from sglang.test.test_utils import CustomTestCase
class TestRMSNorm(unittest.TestCase): class TestRMSNorm(CustomTestCase):
DTYPES = [torch.half, torch.bfloat16] DTYPES = [torch.half, torch.bfloat16]
NUM_TOKENS = [7, 83, 4096] NUM_TOKENS = [7, 83, 4096]
HIDDEN_SIZES = [768, 769, 770, 771, 5120, 5124, 5125, 5126, 8192, 8199] HIDDEN_SIZES = [768, 769, 770, 771, 5120, 5124, 5125, 5126, 8192, 8199]
...@@ -56,7 +57,7 @@ class TestRMSNorm(unittest.TestCase): ...@@ -56,7 +57,7 @@ class TestRMSNorm(unittest.TestCase):
self._run_rms_norm_test(*params) self._run_rms_norm_test(*params)
class TestGemmaRMSNorm(unittest.TestCase): class TestGemmaRMSNorm(CustomTestCase):
DTYPES = [torch.half, torch.bfloat16] DTYPES = [torch.half, torch.bfloat16]
NUM_TOKENS = [7, 83, 4096] NUM_TOKENS = [7, 83, 4096]
HIDDEN_SIZES = [768, 769, 770, 771, 5120, 5124, 5125, 5126, 8192, 8199] HIDDEN_SIZES = [768, 769, 770, 771, 5120, 5124, 5125, 5126, 8192, 8199]
......
...@@ -8,6 +8,7 @@ import random ...@@ -8,6 +8,7 @@ import random
import subprocess import subprocess
import threading import threading
import time import time
import traceback
import unittest import unittest
from concurrent.futures import ThreadPoolExecutor from concurrent.futures import ThreadPoolExecutor
from dataclasses import dataclass from dataclasses import dataclass
...@@ -998,3 +999,30 @@ def run_logprob_check(self: unittest.TestCase, arg: Tuple): ...@@ -998,3 +999,30 @@ def run_logprob_check(self: unittest.TestCase, arg: Tuple):
rank += 1 rank += 1
else: else:
raise raise
class CustomTestCase(unittest.TestCase):
    """TestCase variant that re-runs flaky test methods.

    The retry budget comes from ``_get_max_retry()`` (env-configurable,
    non-zero in CI), so locally tests still fail fast on the first error.
    """

    def _callTestMethod(self, method):
        # Defer the real invocation into a thunk so the retry helper can
        # execute it more than once on failure.
        run = lambda: super(CustomTestCase, self)._callTestMethod(method)
        _retry_execution(run, max_retry=_get_max_retry())
def _get_max_retry():
    """Return the retry budget for a failing test.

    Honors the ``SGLANG_TEST_MAX_RETRY`` environment variable; otherwise
    defaults to 2 retries when running in CI and 0 elsewhere.
    """
    fallback = "2" if is_in_ci() else "0"
    return int(os.environ.get("SGLANG_TEST_MAX_RETRY", fallback))
def _retry_execution(fn, max_retry: int):
if max_retry == 0:
fn()
return
try:
fn()
except Exception as e:
print(
f"retry_execution failed once and will retry. This may be an error or a flaky test. Error: {e}"
)
traceback.print_exc()
_retry_execution(fn, max_retry=max_retry - 1)
...@@ -3,9 +3,10 @@ import unittest ...@@ -3,9 +3,10 @@ import unittest
from sglang import Anthropic, set_default_backend from sglang import Anthropic, set_default_backend
from sglang.test.test_programs import test_mt_bench, test_stream from sglang.test.test_programs import test_mt_bench, test_stream
from sglang.test.test_utils import CustomTestCase
class TestAnthropicBackend(unittest.TestCase): class TestAnthropicBackend(CustomTestCase):
backend = None backend = None
@classmethod @classmethod
......
import unittest import unittest
import sglang as sgl import sglang as sgl
from sglang.test.test_utils import DEFAULT_MODEL_NAME_FOR_TEST from sglang.test.test_utils import DEFAULT_MODEL_NAME_FOR_TEST, CustomTestCase
class TestBind(unittest.TestCase): class TestBind(CustomTestCase):
backend = None backend = None
@classmethod @classmethod
......
...@@ -7,6 +7,7 @@ from sglang.lang.choices import ( ...@@ -7,6 +7,7 @@ from sglang.lang.choices import (
token_length_normalized, token_length_normalized,
unconditional_likelihood_normalized, unconditional_likelihood_normalized,
) )
from sglang.test.test_utils import CustomTestCase
MOCK_CHOICES_INPUT_DATA = { MOCK_CHOICES_INPUT_DATA = {
"choices": [ "choices": [
...@@ -51,7 +52,7 @@ MOCK_CHOICES_INPUT_DATA = { ...@@ -51,7 +52,7 @@ MOCK_CHOICES_INPUT_DATA = {
} }
class TestChoices(unittest.TestCase): class TestChoices(CustomTestCase):
def test_token_length_normalized(self): def test_token_length_normalized(self):
"""Confirm 'antidisestablishmentarianism' is selected due to high confidences for """Confirm 'antidisestablishmentarianism' is selected due to high confidences for
......
...@@ -3,9 +3,10 @@ import unittest ...@@ -3,9 +3,10 @@ import unittest
from sglang import LiteLLM, set_default_backend from sglang import LiteLLM, set_default_backend
from sglang.test.test_programs import test_mt_bench, test_stream from sglang.test.test_programs import test_mt_bench, test_stream
from sglang.test.test_utils import CustomTestCase
class TestAnthropicBackend(unittest.TestCase): class TestAnthropicBackend(CustomTestCase):
chat_backend = None chat_backend = None
@classmethod @classmethod
......
...@@ -17,9 +17,10 @@ from sglang.test.test_programs import ( ...@@ -17,9 +17,10 @@ from sglang.test.test_programs import (
test_stream, test_stream,
test_tool_use, test_tool_use,
) )
from sglang.test.test_utils import CustomTestCase
class TestOpenAIBackend(unittest.TestCase): class TestOpenAIBackend(CustomTestCase):
instruct_backend = None instruct_backend = None
chat_backend = None chat_backend = None
chat_vision_backend = None chat_vision_backend = None
......
...@@ -22,10 +22,10 @@ from sglang.test.test_programs import ( ...@@ -22,10 +22,10 @@ from sglang.test.test_programs import (
test_stream, test_stream,
test_tool_use, test_tool_use,
) )
from sglang.test.test_utils import DEFAULT_MODEL_NAME_FOR_TEST from sglang.test.test_utils import DEFAULT_MODEL_NAME_FOR_TEST, CustomTestCase
class TestSRTBackend(unittest.TestCase): class TestSRTBackend(CustomTestCase):
backend = None backend = None
@classmethod @classmethod
......
...@@ -3,9 +3,10 @@ import unittest ...@@ -3,9 +3,10 @@ import unittest
import sglang as sgl import sglang as sgl
from sglang.lang.backend.base_backend import BaseBackend from sglang.lang.backend.base_backend import BaseBackend
from sglang.lang.chat_template import get_chat_template from sglang.lang.chat_template import get_chat_template
from sglang.test.test_utils import CustomTestCase
class TestTracing(unittest.TestCase): class TestTracing(CustomTestCase):
def test_few_shot_qa(self): def test_few_shot_qa(self):
@sgl.function @sgl.function
def few_shot_qa(s, question): def few_shot_qa(s, question):
......
...@@ -10,9 +10,10 @@ from sglang.test.test_programs import ( ...@@ -10,9 +10,10 @@ from sglang.test.test_programs import (
test_parallel_encoding, test_parallel_encoding,
test_stream, test_stream,
) )
from sglang.test.test_utils import CustomTestCase
class TestVertexAIBackend(unittest.TestCase): class TestVertexAIBackend(CustomTestCase):
backend = None backend = None
@classmethod @classmethod
......
...@@ -18,6 +18,7 @@ import unittest ...@@ -18,6 +18,7 @@ import unittest
import torch import torch
from sglang.test.runners import HFRunner, SRTRunner from sglang.test.runners import HFRunner, SRTRunner
from sglang.test.test_utils import CustomTestCase
LORA_SETS = [ LORA_SETS = [
# { # {
...@@ -70,7 +71,7 @@ What do you know about llamas? ...@@ -70,7 +71,7 @@ What do you know about llamas?
# PROMPTS.append(sample[0]["content"][:2000]) # PROMPTS.append(sample[0]["content"][:2000])
class TestLoRA(unittest.TestCase): class TestLoRA(CustomTestCase):
def inference(self, prompts, lora_set, tp_size, torch_dtype, max_new_tokens): def inference(self, prompts, lora_set, tp_size, torch_dtype, max_new_tokens):
print("=================== testing inference =======================") print("=================== testing inference =======================")
......
...@@ -21,7 +21,7 @@ import torch ...@@ -21,7 +21,7 @@ import torch
from utils import BACKENDS, TORCH_DTYPES, LoRAAdaptor, LoRAModelCase from utils import BACKENDS, TORCH_DTYPES, LoRAAdaptor, LoRAModelCase
from sglang.test.runners import HFRunner, SRTRunner from sglang.test.runners import HFRunner, SRTRunner
from sglang.test.test_utils import calculate_rouge_l, is_in_ci from sglang.test.test_utils import CustomTestCase, calculate_rouge_l, is_in_ci
CI_LORA_MODELS = [ CI_LORA_MODELS = [
LoRAModelCase( LoRAModelCase(
...@@ -67,7 +67,7 @@ PROMPTS = [ ...@@ -67,7 +67,7 @@ PROMPTS = [
] ]
class TestLoRABackend(unittest.TestCase): class TestLoRABackend(CustomTestCase):
def run_backend( def run_backend(
self, self,
prompt: str, prompt: str,
......
...@@ -21,7 +21,7 @@ import torch ...@@ -21,7 +21,7 @@ import torch
from utils import TORCH_DTYPES, LoRAAdaptor, LoRAModelCase from utils import TORCH_DTYPES, LoRAAdaptor, LoRAModelCase
from sglang.test.runners import HFRunner, SRTRunner from sglang.test.runners import HFRunner, SRTRunner
from sglang.test.test_utils import calculate_rouge_l, is_in_ci from sglang.test.test_utils import CustomTestCase, calculate_rouge_l, is_in_ci
CI_LORA_MODELS = [ CI_LORA_MODELS = [
LoRAModelCase( LoRAModelCase(
...@@ -69,7 +69,7 @@ PROMPTS = [ ...@@ -69,7 +69,7 @@ PROMPTS = [
BACKEND = "triton" BACKEND = "triton"
class TestLoRATP(unittest.TestCase): class TestLoRATP(CustomTestCase):
def run_tp( def run_tp(
self, self,
prompt: str, prompt: str,
......
...@@ -19,7 +19,7 @@ from typing import List ...@@ -19,7 +19,7 @@ from typing import List
import torch import torch
from utils import BACKENDS, TORCH_DTYPES, LoRAAdaptor, LoRAModelCase from utils import BACKENDS, TORCH_DTYPES, LoRAAdaptor, LoRAModelCase
from sglang.test.test_utils import is_in_ci from sglang.test.test_utils import CustomTestCase, is_in_ci
MULTI_LORA_MODELS = [ MULTI_LORA_MODELS = [
LoRAModelCase( LoRAModelCase(
...@@ -51,7 +51,7 @@ PROMPTS = [ ...@@ -51,7 +51,7 @@ PROMPTS = [
] ]
class TestMultiLoRABackend(unittest.TestCase): class TestMultiLoRABackend(CustomTestCase):
def run_backend_batch( def run_backend_batch(
self, self,
prompts: List[str], prompts: List[str],
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment