# test_chunked_prefill.py
import unittest
from types import SimpleNamespace

from sglang.srt.utils import kill_child_process
from sglang.test.run_eval import run_eval
from sglang.test.test_utils import (
    DEFAULT_MODEL_NAME_FOR_TEST,
    DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH,
    DEFAULT_URL_FOR_TEST,
    popen_launch_server,
    run_bench_serving,
)


class TestChunkedPrefill(unittest.TestCase):
    """End-to-end tests of the server under chunked-prefill configurations.

    Each test launches a server (via ``popen_launch_server`` or
    ``run_bench_serving``) with a particular combination of the
    ``--chunked-prefill-size``, ``--disable-radix-cache`` and
    ``--enable-mixed-chunk`` flags and checks that it still produces
    acceptable results.
    """

    def run_mmlu(
        self, disable_radix_cache, enable_mixed_chunk, chunked_prefill_size=32
    ):
        """Launch a server with the given options and check its MMLU score.

        Args:
            disable_radix_cache: If true, pass ``--disable-radix-cache``.
            enable_mixed_chunk: If true, pass ``--enable-mixed-chunk``.
            chunked_prefill_size: Value passed to ``--chunked-prefill-size``
                (the ``test_no_chunked_prefill`` case passes ``-1``).

        Raises:
            AssertionError: If the MMLU score over 64 examples is below 0.65.
        """
        other_args = ["--chunked-prefill-size", str(chunked_prefill_size)]
        if disable_radix_cache:
            other_args.append("--disable-radix-cache")
        if enable_mixed_chunk:
            other_args.append("--enable-mixed-chunk")

        model = DEFAULT_MODEL_NAME_FOR_TEST
        base_url = DEFAULT_URL_FOR_TEST
        process = popen_launch_server(
            model,
            base_url,
            timeout=DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH,
            other_args=other_args,
        )

        # Enter the try right after launching so the server process is
        # always killed, whatever fails afterwards.
        try:
            args = SimpleNamespace(
                base_url=base_url,
                model=model,
                eval_name="mmlu",
                num_examples=64,
                num_threads=32,
            )
            metrics = run_eval(args)
            # Use the TestCase assertion: a bare ``assert`` is removed under
            # ``python -O``, which would let this check pass vacuously.
            self.assertGreaterEqual(metrics["score"], 0.65)
        finally:
            kill_child_process(process.pid)

    def test_chunked_prefill(self):
        """Chunked prefill with the radix cache enabled."""
        self.run_mmlu(disable_radix_cache=False, enable_mixed_chunk=False)

    def test_mixed_chunked_prefill(self):
        """Mixed-chunk prefill with the radix cache enabled."""
        self.run_mmlu(disable_radix_cache=False, enable_mixed_chunk=True)

    def test_chunked_prefill_without_radix_cache(self):
        """Chunked prefill with the radix cache disabled."""
        self.run_mmlu(disable_radix_cache=True, enable_mixed_chunk=False)

    def test_mixed_chunked_prefill_without_radix_cache(self):
        """Mixed-chunk prefill with the radix cache disabled."""
        self.run_mmlu(disable_radix_cache=True, enable_mixed_chunk=True)

    def test_no_chunked_prefill(self):
        """``--chunked-prefill-size -1`` with the radix cache enabled."""
        self.run_mmlu(
            disable_radix_cache=False, enable_mixed_chunk=False, chunked_prefill_size=-1
        )

    def test_no_chunked_prefill_without_radix_cache(self):
        """Serve a burst of prompts with chunked prefill and radix cache off.

        All prompts are sent at once (infinite request rate); every one of
        them must complete.
        """
        num_prompts = 10
        res = run_bench_serving(
            model=DEFAULT_MODEL_NAME_FOR_TEST,
            num_prompts=num_prompts,
            request_rate=float("inf"),
            other_server_args=["--disable-radix-cache", "--chunked-prefill-size", "-1"],
        )

        self.assertEqual(res["completed"], num_prompts)


if __name__ == "__main__":
    # Run the test suite when the file is executed directly.
    unittest.main()