Unverified Commit 05b3bf5e authored by Lianmin Zheng's avatar Lianmin Zheng Committed by GitHub
Browse files

Crash the server on warnings in CI (#1772)

parent 3f5ac88d
import logging import logging
import os
from typing import Union from typing import Union
import torch import torch
...@@ -17,6 +18,11 @@ if is_flashinfer_available(): ...@@ -17,6 +18,11 @@ if is_flashinfer_available():
top_p_renorm_prob, top_p_renorm_prob,
) )
# Crash on warning if we are running CI tests
crash_on_warning = os.getenv("SGLANG_IS_IN_CI", "false") == "true"
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
...@@ -36,6 +42,7 @@ class Sampler(nn.Module): ...@@ -36,6 +42,7 @@ class Sampler(nn.Module):
logits = logits.contiguous() logits = logits.contiguous()
if self.use_nan_detectioin and torch.any(torch.isnan(logits)): if self.use_nan_detectioin and torch.any(torch.isnan(logits)):
exit(1) if crash_on_warning else None
logger.warning("Detected errors during sampling! NaN in the logits.") logger.warning("Detected errors during sampling! NaN in the logits.")
logits = torch.where( logits = torch.where(
torch.isnan(logits), torch.full_like(logits, -1e5), logits torch.isnan(logits), torch.full_like(logits, -1e5), logits
......
...@@ -116,7 +116,7 @@ class CudaGraphRunner: ...@@ -116,7 +116,7 @@ class CudaGraphRunner:
if self.model_runner.server_args.disable_cuda_graph_padding: if self.model_runner.server_args.disable_cuda_graph_padding:
self.capture_bs = list(range(1, 32)) + [64, 128] self.capture_bs = list(range(1, 32)) + [64, 128]
else: else:
self.capture_bs = [1, 2, 4] + [i * 8 for i in range(1, 21)] self.capture_bs = [1, 2, 3, 4] + [i * 8 for i in range(1, 21)]
self.capture_bs = [ self.capture_bs = [
bs for bs in self.capture_bs if bs <= model_runner.req_to_token_pool.size bs for bs in self.capture_bs if bs <= model_runner.req_to_token_pool.size
] ]
......
"""
Usage:
python -m unittest test_eval_accuracy_large.TestEvalAccuracyLarge.test_mmlu
"""
import unittest import unittest
from types import SimpleNamespace from types import SimpleNamespace
...@@ -32,12 +37,12 @@ class TestEvalAccuracyLarge(unittest.TestCase): ...@@ -32,12 +37,12 @@ class TestEvalAccuracyLarge(unittest.TestCase):
base_url=self.base_url, base_url=self.base_url,
model=self.model, model=self.model,
eval_name="mmlu", eval_name="mmlu",
num_examples=3000, num_examples=5000,
num_threads=1024, num_threads=1024,
) )
metrics = run_eval(args) metrics = run_eval(args)
assert metrics["score"] >= 0.705, f"{metrics}" assert metrics["score"] >= 0.71, f"{metrics}"
def test_human_eval(self): def test_human_eval(self):
args = SimpleNamespace( args = SimpleNamespace(
......
"""
Usage:
python -m unittest test_moe_eval_accuracy_large.TestMoEEvalAccuracyLarge.test_mmlu
"""
import unittest import unittest
from types import SimpleNamespace from types import SimpleNamespace
...@@ -11,7 +16,7 @@ from sglang.test.test_utils import ( ...@@ -11,7 +16,7 @@ from sglang.test.test_utils import (
) )
class TestEvalAccuracyLarge(unittest.TestCase): class TestMoEEvalAccuracyLarge(unittest.TestCase):
@classmethod @classmethod
def setUpClass(cls): def setUpClass(cls):
cls.model = DEFAULT_MOE_MODEL_NAME_FOR_TEST cls.model = DEFAULT_MOE_MODEL_NAME_FOR_TEST
...@@ -37,7 +42,7 @@ class TestEvalAccuracyLarge(unittest.TestCase): ...@@ -37,7 +42,7 @@ class TestEvalAccuracyLarge(unittest.TestCase):
base_url=self.base_url, base_url=self.base_url,
model=self.model, model=self.model,
eval_name="mmlu", eval_name="mmlu",
num_examples=3000, num_examples=5000,
num_threads=1024, num_threads=1024,
) )
......
import json
import unittest import unittest
from types import SimpleNamespace from types import SimpleNamespace
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment