test_launch_server.py 5.98 KB
Newer Older
1
import socket
Byron Hsu's avatar
Byron Hsu committed
2
3
4
5
6
7
8
import subprocess
import time
import unittest
from types import SimpleNamespace

import requests

9
from sglang.srt.utils import kill_process_tree
Byron Hsu's avatar
Byron Hsu committed
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
from sglang.test.run_eval import run_eval
from sglang.test.test_utils import (
    DEFAULT_MODEL_NAME_FOR_TEST,
    DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH,
    DEFAULT_URL_FOR_TEST,
)


def popen_launch_router(
    model: str,
    base_url: str,
    dp_size: int,
    timeout: float,
):
    """
    Launch the router server process and wait until it reports healthy.

    Args:
        model: Model path/name
        base_url: Server base URL, in the form ``scheme://host:port``
        dp_size: Data parallel size
        timeout: Server launch timeout, in seconds

    Returns:
        The ``subprocess.Popen`` handle of the launched router.

    Raises:
        TimeoutError: If ``{base_url}/health`` does not answer with HTTP 200
            before ``timeout`` elapses. The launched process is killed via
            ``kill_process_tree`` before the error is raised.
    """
    # "scheme://host:port" splits on ":" into ("scheme", "//host", "port").
    _, host, port = base_url.split(":")
    host = host[2:]

    command = [
        "python3",
        "-m",
        "sglang_router.launch_server",
        "--model-path",
        model,
        "--host",
        host,
        "--port",
        port,
        "--dp",
        str(dp_size),  # Convert dp_size to string
        "--router-eviction-interval",
        "5",  # frequent eviction for testing
    ]

    # stdout/stderr are not redirected, so the router logs go to the
    # test runner's own output streams.
    process = subprocess.Popen(command, stdout=None, stderr=None)

    poll_interval = 10  # seconds between health probes
    request_timeout = 10  # seconds; keeps one hung probe from stalling the loop

    # A monotonic clock keeps the deadline immune to wall-clock adjustments.
    deadline = time.monotonic() + timeout
    with requests.Session() as session:
        while time.monotonic() < deadline:
            try:
                response = session.get(
                    f"{base_url}/health", timeout=request_timeout
                )
                if response.status_code == 200:
                    print(f"Router {base_url} is healthy")
                    return process
            except requests.RequestException:
                # Not accepting connections yet (or probe timed out); retry.
                pass
            time.sleep(poll_interval)

    # The router never became healthy: do not leave it running after we give up.
    kill_process_tree(process.pid)
    raise TimeoutError("Router failed to start within the timeout period.")


def find_available_port():
    """Return a TCP port number that was free on 127.0.0.1 when probed.

    A temporary socket is bound to port 0 so the OS picks the port; the
    socket is closed before returning, so the port is only known to have
    been free at the time of the call.
    """
    probe = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
    try:
        probe.bind(("127.0.0.1", 0))
        _address, chosen_port = probe.getsockname()
        return chosen_port
    finally:
        probe.close()


def popen_launch_server(
    model: str,
    base_url: str,
    timeout: float,
):
    """
    Launch a standalone worker server process and wait until it is healthy.

    Args:
        model: Model path/name
        base_url: Server base URL, in the form ``scheme://host:port``
        timeout: Server launch timeout, in seconds

    Returns:
        The ``subprocess.Popen`` handle of the launched server.

    Raises:
        TimeoutError: If ``{base_url}/health`` does not answer with HTTP 200
            before ``timeout`` elapses. The launched process is killed via
            ``kill_process_tree`` before the error is raised.
    """
    # "scheme://host:port" splits on ":" into ("scheme", "//host", "port").
    _, host, port = base_url.split(":")
    host = host[2:]

    command = [
        "python3",
        "-m",
        "sglang.launch_server",
        "--model-path",
        model,
        "--host",
        host,
        "--port",
        port,
        "--base-gpu-id",
        "1",
    ]

    # stdout/stderr are not redirected, so the server logs go to the
    # test runner's own output streams.
    process = subprocess.Popen(command, stdout=None, stderr=None)

    poll_interval = 10  # seconds between health probes
    request_timeout = 10  # seconds; keeps one hung probe from stalling the loop

    # A monotonic clock keeps the deadline immune to wall-clock adjustments.
    deadline = time.monotonic() + timeout
    with requests.Session() as session:
        while time.monotonic() < deadline:
            try:
                response = session.get(
                    f"{base_url}/health", timeout=request_timeout
                )
                if response.status_code == 200:
                    print(f"Server {base_url} is healthy")
                    return process
            except requests.RequestException:
                # Not accepting connections yet (or probe timed out); retry.
                pass
            time.sleep(poll_interval)

    # The server never became healthy: do not leave it running after we give up.
    kill_process_tree(process.pid)
    raise TimeoutError("Server failed to start within the timeout period.")


117
class TestLaunchServer(unittest.TestCase):
    """Tests that launch the router and run the MMLU eval through it.

    Each test starts its own router on ``DEFAULT_URL_FOR_TEST``. Launched
    processes are recorded on the class and killed after every test so the
    next test can start a fresh router on the same URL.
    """

    @classmethod
    def setUpClass(cls):
        cls.model = DEFAULT_MODEL_NAME_FOR_TEST
        cls.base_url = DEFAULT_URL_FOR_TEST
        # Router process started by the currently running test, if any.
        cls.process = None
        # Extra worker processes started by the currently running test.
        cls.other_process = []

    @classmethod
    def tearDownClass(cls):
        # Safety net; normally tearDown() has already cleaned everything up.
        TestLaunchServer._kill_launched_processes()

    def tearDown(self):
        # Both tests use the same base_url, so the router from one test must
        # be gone before the next test launches its own.
        TestLaunchServer._kill_launched_processes()

    @staticmethod
    def _kill_launched_processes():
        """Kill the recorded router and worker processes, then reset the records."""
        # The router may never have been recorded (e.g. its launch raised),
        # in which case there is nothing to kill.
        if TestLaunchServer.process is not None:
            kill_process_tree(TestLaunchServer.process.pid)
            TestLaunchServer.process = None
        for process in TestLaunchServer.other_process:
            kill_process_tree(process.pid)
        TestLaunchServer.other_process = []

    def _run_mmlu_and_check_score(self, threshold=0.65):
        """Run the MMLU eval against ``self.base_url`` and assert on its score.

        Args:
            threshold: Minimum acceptable value of ``metrics["score"]``.
        """
        args = SimpleNamespace(
            base_url=self.base_url,
            model=self.model,
            eval_name="mmlu",
            num_examples=64,
            num_threads=32,
            temperature=0.1,
        )

        metrics = run_eval(args)
        score = metrics["score"]
        passed = score >= threshold
        msg = f"MMLU test {'passed' if passed else 'failed'} with score {score:.3f} (threshold: {threshold})"
        self.assertGreaterEqual(score, threshold, msg)

    def test_mmlu(self):
        # DP size = 2
        TestLaunchServer.process = popen_launch_router(
            self.model,
            self.base_url,
            dp_size=2,
            timeout=DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH,
        )

        self._run_mmlu_and_check_score()

    def test_add_and_remove_worker(self):
        # DP size = 1
        TestLaunchServer.process = popen_launch_router(
            self.model,
            self.base_url,
            dp_size=1,
            timeout=DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH,
        )
        # 1. start a worker, and wait until it is healthy
        port = find_available_port()
        worker_url = f"http://127.0.0.1:{port}"
        worker_process = popen_launch_server(
            self.model, worker_url, DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH
        )
        TestLaunchServer.other_process.append(worker_process)

        # 2. use /add_worker api to add it to the router
        with requests.Session() as session:
            response = session.post(f"{self.base_url}/add_worker?url={worker_url}")
            print(f"status code: {response.status_code}, response: {response.text}")
            self.assertEqual(response.status_code, 200)

        # 3. run mmlu
        self._run_mmlu_and_check_score()

        # 4. use /remove_worker api to remove it from the router
        with requests.Session() as session:
            response = session.post(f"{self.base_url}/remove_worker?url={worker_url}")
            print(f"status code: {response.status_code}, response: {response.text}")
            self.assertEqual(response.status_code, 200)

        # 5. run mmlu again
        self._run_mmlu_and_check_score()

Byron Hsu's avatar
Byron Hsu committed
206
207
208

# Run the tests in this module when the file is executed directly.
if __name__ == "__main__":
    unittest.main()