# test_nightly_gsm8k_eval.py
import unittest
from types import SimpleNamespace

from sglang.srt.utils import kill_child_process
from sglang.test.run_eval import run_eval
from sglang.test.test_utils import (
    DEFAULT_MODEL_NAME_FOR_NIGHTLY_EVAL_FP8_TP1,
    DEFAULT_MODEL_NAME_FOR_NIGHTLY_EVAL_FP8_TP2,
9
    DEFAULT_MODEL_NAME_FOR_NIGHTLY_EVAL_QUANT_TP1,
10
11
12
13
14
15
16
17
18
19
20
21
    DEFAULT_MODEL_NAME_FOR_NIGHTLY_EVAL_TP1,
    DEFAULT_MODEL_NAME_FOR_NIGHTLY_EVAL_TP2,
    DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH,
    DEFAULT_URL_FOR_TEST,
    popen_launch_server,
)


def parse_models(model_string):
    """Split a comma-separated string into a list of model names.

    Whitespace around each name is stripped, and entries that are empty
    after stripping (e.g. from a trailing or doubled comma) are dropped.
    """
    names = []
    for chunk in model_string.split(","):
        name = chunk.strip()
        if name:
            names.append(name)
    return names


22
23
24
25
def launch_server(base_url, model, is_fp8, is_tp2):
    """Launch a server for ``model`` via ``popen_launch_server``.

    Extra command-line flags are chosen from the two boolean switches and
    from substrings of the model name.

    Args:
        base_url: URL passed through to ``popen_launch_server``.
        model: Model name/path; also inspected for substrings that select
            model-specific flags.
        is_fp8: When true, add the fp8 quantization / KV-cache flags.
        is_tp2: When true, add ``--tp 2``.

    Returns:
        The object returned by ``popen_launch_server`` (callers in this
        file read its ``.pid``).
    """
    quantization_fp8 = ["--quantization", "fp8"]
    kv_cache_fp8 = ["--kv-cache-dtype", "fp8_e5m2"]

    server_args = ["--log-level-http", "warning", "--trust-remote-code"]

    if is_fp8:
        if "Llama-3" in model or "gemma-2" in model:
            # compressed-tensors (original note): only the KV-cache dtype
            # is set for these models, no explicit --quantization flag.
            server_args += kv_cache_fp8
        elif "Qwen2-72B-Instruct-FP8" in model:
            # Original note says only "bug": the fp8 KV-cache flag is left
            # off for this model.
            server_args += quantization_fp8
        else:
            server_args += quantization_fp8 + kv_cache_fp8

    if is_tp2:
        server_args += ["--tp", "2"]

    if "DeepSeek" in model:
        server_args += ["--mem-frac", "0.85"]

    # AWQ takes precedence over GPTQ when both substrings are present.
    if "AWQ" in model:
        server_args += ["--quantization", "awq"]
    elif "GPTQ" in model:
        server_args += ["--quantization", "gptq"]

    return popen_launch_server(
        model,
        base_url,
        timeout=DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH,
        other_args=server_args,
    )


51
52
53
54
55
56
57
58
class TestEvalAccuracyLarge(unittest.TestCase):
    """Run the ``mgsm_en`` eval against every configured nightly model.

    For each model a server is launched, the eval is run against it, the
    score is checked against a loose threshold, and the server is shut
    down again before the next model starts.
    """

    @classmethod
    def setUpClass(cls):
        # Each entry is (model names, is_fp8, is_tp2); the two flags are
        # forwarded to launch_server to select the server arguments.
        cls.model_groups = [
            (parse_models(DEFAULT_MODEL_NAME_FOR_NIGHTLY_EVAL_TP1), False, False),
            (parse_models(DEFAULT_MODEL_NAME_FOR_NIGHTLY_EVAL_TP2), False, True),
            (parse_models(DEFAULT_MODEL_NAME_FOR_NIGHTLY_EVAL_FP8_TP1), True, False),
            (parse_models(DEFAULT_MODEL_NAME_FOR_NIGHTLY_EVAL_FP8_TP2), True, True),
            (parse_models(DEFAULT_MODEL_NAME_FOR_NIGHTLY_EVAL_QUANT_TP1), False, False),
        ]
        cls.base_url = DEFAULT_URL_FOR_TEST

    def setUp(self):
        # Handle of the currently running server, if any.
        self.process = None

    def tearDown(self):
        """Kill the currently tracked server process tree, if there is one."""
        if self.process:
            kill_child_process(self.process.pid, include_self=True)
            # Forget the handle so a later tearDown call (the test calls it
            # per model and unittest calls it again at the end) does not
            # try to kill the same pid twice.
            self.process = None

    def test_mgsm_en_all_models(self):
        """Evaluate every model in every group and require score > 0.5."""
        for model_group, is_fp8, is_tp2 in self.model_groups:
            for model in model_group:
                with self.subTest(model=model):
                    self.process = launch_server(self.base_url, model, is_fp8, is_tp2)
                    try:
                        args = SimpleNamespace(
                            base_url=self.base_url,
                            model=model,
                            eval_name="mgsm_en",
                            num_examples=None,
                            num_threads=1024,
                        )

                        metrics = run_eval(args)
                        print(
                            f"{'=' * 42}\n{model} - metrics={metrics} score={metrics['score']}\n{'=' * 42}\n"
                        )

                        # loosely threshold
                        self.assertGreater(
                            metrics["score"],
                            0.5,
                            f"score={metrics['score']} <= 0.5",
                        )
                    finally:
                        # subTest swallows failures and the loop continues,
                        # so the server must be stopped here even when the
                        # eval raises or the threshold check fails;
                        # otherwise it leaks and still occupies base_url
                        # when the next model is launched.
                        self.tearDown()


if __name__ == "__main__":
    unittest.main()