import os
import unittest
from types import SimpleNamespace

from sglang.bench_serving import run_benchmark
from sglang.srt.server_args import ServerArgs
from sglang.srt.utils import kill_child_process
from sglang.test.test_utils import (
    DEFAULT_MODEL_NAME_FOR_TEST,
    DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH,
    DEFAULT_URL_FOR_TEST,
    popen_launch_server,
)


class TestServingThroughput(unittest.TestCase):
    """Throughput checks for a locally launched sglang server.

    Every test starts a server with a particular combination of flags,
    runs ``run_benchmark`` against it on the ``random`` dataset, and (only
    when the ``SGLANG_IS_IN_CI`` environment variable is ``"true"``) requires
    the reported ``output_throughput`` to exceed a fixed threshold.
    """

    def run_test(self, disable_radix_cache, attention_backend, chunked_prefill_size):
        """Launch a server, benchmark it, and return the benchmark result.

        Args:
            disable_radix_cache: When truthy, ``--disable-radix-cache`` is
                passed to the server.
            attention_backend: When truthy, passed to the server as
                ``--attention-backend <value>``; otherwise the flag is omitted.
            chunked_prefill_size: Always passed (stringified) as
                ``--chunked-prefill-size <value>``.

        Returns:
            The object returned by ``run_benchmark``; it is indexed here by
            ``"completed"`` and by ``"output_throughput"`` in the callers.

        Raises:
            AssertionError: If ``res["completed"]`` differs from the number of
                prompts requested.
        """
        # Translate the test parameters into server command-line flags.
        other_args = []
        if disable_radix_cache:
            other_args.append("--disable-radix-cache")
        if attention_backend:
            other_args.extend(["--attention-backend", attention_backend])
        other_args.extend(["--chunked-prefill-size", str(chunked_prefill_size)])

        model = DEFAULT_MODEL_NAME_FOR_TEST
        base_url = DEFAULT_URL_FOR_TEST

        # Build the benchmark configuration *before* starting the server so
        # that nothing runs between the launch and the try/finally that is
        # responsible for tearing the server down.
        num_prompts = 500
        args = SimpleNamespace(
            backend="sglang",
            base_url=base_url,
            host=None,
            port=None,
            dataset_name="random",
            dataset_path="",
            model=None,
            tokenizer=None,
            num_prompts=num_prompts,
            sharegpt_output_len=None,
            random_input_len=4096,
            random_output_len=2048,
            random_range_ratio=0.0,
            request_rate=float("inf"),
            multi=None,
            seed=0,
            output_file=None,
            disable_tqdm=False,
            disable_stream=False,
            disable_ignore_eos=False,
            extra_request_body=None,
        )

        # Launch the server
        process = popen_launch_server(
            model,
            base_url,
            timeout=DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH,
            other_args=other_args,
        )

        # Run benchmark; always kill the server, even if the benchmark raises.
        try:
            res = run_benchmark(args)
        finally:
            kill_child_process(process.pid)

        # unittest assertion instead of a bare ``assert``: it still runs under
        # ``python -O`` and reports both values on failure.
        self.assertEqual(
            res["completed"],
            num_prompts,
            "benchmark did not complete every requested prompt",
        )
        return res

    def _assert_ci_throughput(self, res, min_output_throughput):
        """Require ``res["output_throughput"]`` to exceed a threshold in CI.

        The check is skipped unless the ``SGLANG_IS_IN_CI`` environment
        variable is exactly ``"true"`` (it defaults to ``"false"``).
        NOTE(review): thresholds are presumably tuned for the CI hardware —
        confirm before changing them.
        """
        if os.getenv("SGLANG_IS_IN_CI", "false") == "true":
            self.assertGreater(res["output_throughput"], min_output_throughput)

    def test_default(self):
        """Benchmark with the ``ServerArgs`` class-level values for all options."""
        res = self.run_test(
            disable_radix_cache=ServerArgs.disable_radix_cache,
            attention_backend=ServerArgs.attention_backend,
            chunked_prefill_size=ServerArgs.chunked_prefill_size,
        )

        self._assert_ci_throughput(res, 2400)

    def test_default_without_radix_cache(self):
        """Benchmark with ``--disable-radix-cache`` and otherwise default options."""
        res = self.run_test(
            disable_radix_cache=True,
            attention_backend=ServerArgs.attention_backend,
            chunked_prefill_size=ServerArgs.chunked_prefill_size,
        )

        self._assert_ci_throughput(res, 2800)

    def test_default_without_chunked_prefill(self):
        """Benchmark with ``--chunked-prefill-size -1`` and otherwise default options."""
        res = self.run_test(
            disable_radix_cache=ServerArgs.disable_radix_cache,
            attention_backend=ServerArgs.attention_backend,
            chunked_prefill_size=-1,
        )

        self._assert_ci_throughput(res, 2400)

    def test_default_with_triton_attention_backend(self):
        """Benchmark with ``--attention-backend triton`` and ``--chunked-prefill-size -1``."""
        res = self.run_test(
            disable_radix_cache=ServerArgs.disable_radix_cache,
            attention_backend="triton",
            chunked_prefill_size=-1,
        )

        self._assert_ci_throughput(res, 2400)

# Run every test in this module when the file is executed directly.
if __name__ == "__main__":
    unittest.main()