# test_disaggregation_basic.py
import json
import os
import unittest
from types import SimpleNamespace
from unittest import mock

import requests

from sglang.test.few_shot_gsm8k import run_eval as run_eval_few_shot_gsm8k
from sglang.test.test_disaggregation_utils import TestDisaggregationBase
from sglang.test.test_utils import (
    DEFAULT_EAGLE_DRAFT_MODEL_FOR_TEST,
    DEFAULT_EAGLE_TARGET_MODEL_FOR_TEST,
    DEFAULT_MODEL_NAME_FOR_TEST,
    DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH,
    popen_launch_pd_server,
)


class TestDisaggregationAccuracy(TestDisaggregationBase):
    """Accuracy / API checks against a prefill + decode disaggregated pair.

    One prefill server and one decode server (decode on ``--base-gpu-id 1``)
    are launched with ``--tp 1``; tests send their requests to the ``lb``
    endpoint brought up by ``launch_lb()``.
    """

    @classmethod
    def setUpClass(cls):
        super().setUpClass()
        cls.model = DEFAULT_MODEL_NAME_FOR_TEST

        # Non blocking start servers
        cls.start_prefill()
        cls.start_decode()

        # Block until both
        cls.wait_server_ready(cls.prefill_url + "/health")
        cls.wait_server_ready(cls.decode_url + "/health")

        cls.launch_lb()

    @classmethod
    def start_prefill(cls):
        """Launch the prefill server (TP=1) and store it on ``process_prefill``."""
        prefill_args = [
            "--trust-remote-code",
            "--disaggregation-mode",
            "prefill",
            "--tp",
            "1",
        ]
        prefill_args += cls.transfer_backend + cls.rdma_devices
        cls.process_prefill = popen_launch_pd_server(
            cls.model,
            cls.prefill_url,
            timeout=DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH,
            other_args=prefill_args,
        )

    @classmethod
    def start_decode(cls):
        """Launch the decode server (TP=1, GPU offset 1) on ``process_decode``."""
        decode_args = [
            "--trust-remote-code",
            "--disaggregation-mode",
            "decode",
            "--tp",
            "1",
            "--base-gpu-id",
            "1",
        ]
        decode_args += cls.transfer_backend + cls.rdma_devices
        cls.process_decode = popen_launch_pd_server(
            cls.model,
            cls.decode_url,
            timeout=DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH,
            other_args=decode_args,
        )

    def test_gsm8k(self):
        """5-shot GSM8K (200 questions) through the lb must exceed 0.62."""
        args = SimpleNamespace(
            num_shots=5,
            data_path=None,
            num_questions=200,
            max_new_tokens=512,
            parallel=128,
            host=f"http://{self.base_host}",
            port=int(self.lb_port),
        )
        metrics = run_eval_few_shot_gsm8k(args)
        print(f"Evaluation metrics: {metrics}")

        self.assertGreater(metrics["accuracy"], 0.62)

    def test_logprob(self):
        """Greedy /generate returns input and output logprobs of consistent size."""
        prompt = "The capital of france is "
        response = requests.post(
            self.lb_url + "/generate",
            json={
                "text": prompt,
                "sampling_params": {"temperature": 0},
                "return_logprob": True,
                "return_input_logprob": True,
                "logprob_start_len": 0,
            },
        )
        # Report server-side errors directly instead of as a KeyError below.
        self.assertEqual(response.status_code, 200, response.text)

        meta_info = response.json()["meta_info"]
        completion_tokens = meta_info["completion_tokens"]
        input_logprobs = meta_info["input_token_logprobs"]
        output_logprobs = meta_info["output_token_logprobs"]

        # unittest assertions are used instead of bare ``assert`` so the checks
        # are not stripped when Python runs with -O.
        self.assertEqual(
            len(output_logprobs),
            completion_tokens,
            f"output_logprobs and completion_tokens should have the same length, but got {len(output_logprobs)} and {completion_tokens}",
        )
        self.assertGreater(
            len(input_logprobs),
            0,
            f"input_logprobs should have at least one token, but got {len(input_logprobs)}",
        )

    def test_structured_output(self):
        """JSON-schema constrained generation yields a JSON object with the required keys."""
        json_schema = json.dumps(
            {
                "type": "object",
                "properties": {
                    "name": {"type": "string", "pattern": "^[\\w]+$"},
                    "population": {"type": "integer"},
                },
                "required": ["name", "population"],
            }
        )

        # JSON
        response = requests.post(
            f"{self.lb_url}/generate",
            json={
                "text": "Here is the information of the capital of France in the JSON format.\n",
                "sampling_params": {
                    "temperature": 0,
                    "max_new_tokens": 64,
                    "json_schema": json_schema,
                },
            },
        )
        self.assertEqual(response.status_code, 200, response.text)
        output = response.json()["text"]

        # ensure the output is a valid JSON (reported as a test failure, not an error)
        try:
            parsed = json.loads(output)
        except json.JSONDecodeError as err:
            self.fail(f"Output is not valid JSON ({err}): {output!r}")

        # The schema above declares both keys as required.
        self.assertIsInstance(parsed, dict)
        for key in ("name", "population"):
            self.assertIn(key, parsed)

class TestDisaggregationMooncakeFailure(TestDisaggregationBase):
    """Disaggregated serving must keep working while failures are injected.

    The servers are launched while ``DISAGGREGATION_TEST_FAILURE_PROB=0.05`` is
    set in this process's environment. NOTE(review): presumably the spawned
    servers inherit it and use it to inject transfer failures; confirm.
    Individual requests may fail, but both servers must stay healthy.
    """

    @classmethod
    def setUpClass(cls):
        super().setUpClass()

        # set DISAGGREGATION_TEST_FAILURE_PROB to simulate failure.
        # patch.dict restores the previous environment (including any value
        # that existed before) when stopped. A class cleanup still runs if the
        # server launch below raises; in that case unittest skips
        # tearDownClass, so a manual set/pop pair would leak the variable into
        # later test classes.
        env_patch = mock.patch.dict(
            os.environ, {"DISAGGREGATION_TEST_FAILURE_PROB": "0.05"}
        )
        env_patch.start()
        cls.addClassCleanup(env_patch.stop)

        cls.model = DEFAULT_MODEL_NAME_FOR_TEST

        # Non blocking start servers
        cls.start_prefill()
        cls.start_decode()

        # Block until both
        cls.wait_server_ready(cls.prefill_url + "/health")
        cls.wait_server_ready(cls.decode_url + "/health")

        cls.launch_lb()

    @classmethod
    def start_prefill(cls):
        """Launch the prefill server (TP=1) and store it on ``process_prefill``."""
        prefill_args = [
            "--trust-remote-code",
            "--disaggregation-mode",
            "prefill",
            "--tp",
            "1",
        ]
        prefill_args += cls.transfer_backend + cls.rdma_devices
        cls.process_prefill = popen_launch_pd_server(
            cls.model,
            cls.prefill_url,
            timeout=DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH,
            other_args=prefill_args,
        )

    @classmethod
    def start_decode(cls):
        """Launch the decode server (TP=1, GPU offset 1) on ``process_decode``."""
        decode_args = [
            "--trust-remote-code",
            "--disaggregation-mode",
            "decode",
            "--tp",
            "1",
            "--base-gpu-id",
            "1",
        ]
        decode_args += cls.transfer_backend + cls.rdma_devices
        cls.process_decode = popen_launch_pd_server(
            cls.model,
            cls.decode_url,
            timeout=DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH,
            other_args=decode_args,
        )

    def _assert_servers_healthy(self):
        """Fail unless GET /health_generate returns 200 on both servers."""
        for url in (self.prefill_url, self.decode_url):
            response = requests.get(url + "/health_generate")
            self.assertEqual(
                response.status_code,
                200,
                f"{url}/health_generate returned {response.status_code}: {response.text}",
            )

    def test_gsm8k(self):
        """Run GSM8K under fault injection; errors are tolerated, crashes are not."""
        args = SimpleNamespace(
            num_shots=5,
            data_path=None,
            num_questions=200,
            max_new_tokens=512,
            parallel=128,
            host=f"http://{self.base_host}",
            port=int(self.lb_port),
        )

        # Expect lots of failure but the server cannot crash
        eval_error = None
        try:
            metrics = run_eval_few_shot_gsm8k(args)
            print(f"Evaluation metrics: {metrics}")
        except Exception as e:
            print(f"Test encountered expected errors: {e}")
            eval_error = e

        # Check health whether or not the eval raised. Failed requests may not
        # surface as an exception from the eval, so the success path is
        # verified as well.
        try:
            self._assert_servers_healthy()
        except Exception as health_check_error:
            if eval_error is None:
                raise
            # If health check fails, re-raise the original exception
            raise eval_error from health_check_error


class TestDisaggregationMooncakeSpec(TestDisaggregationBase):
    """Disaggregated prefill/decode with EAGLE speculative decoding enabled.

    Both servers receive the same speculative flags: the EAGLE draft model,
    3 steps, top-k 4, 16 draft tokens, and a CUDA-graph max batch size of 8.
    """

    @classmethod
    def setUpClass(cls):
        super().setUpClass()

        cls.model = DEFAULT_EAGLE_TARGET_MODEL_FOR_TEST
        cls.draft_model = DEFAULT_EAGLE_DRAFT_MODEL_FOR_TEST

        # Option -> value pairs, flattened in insertion order into argv tokens.
        speculative_options = {
            "--speculative-algorithm": "EAGLE",
            "--speculative-draft-model-path": cls.draft_model,
            "--speculative-num-steps": "3",
            "--speculative-eagle-topk": "4",
            "--speculative-num-draft-tokens": "16",
            "--cuda-graph-max-bs": "8",
        }
        cls.spec_args = [
            token for option in speculative_options.items() for token in option
        ]
        print(f"{cls.base_host=} {cls.lb_port=} {cls.prefill_port=} {cls.decode_port=}")

        # Start both servers first, then block on each /health endpoint
        # (prefill, then decode) before launching the lb.
        cls.start_prefill()
        cls.start_decode()
        for server_url in (cls.prefill_url, cls.decode_url):
            cls.wait_server_ready(server_url + "/health")

        cls.launch_lb()

    @classmethod
    def start_prefill(cls):
        """Spawn the prefill server (TP=1) with the shared speculative flags."""
        launch_args = [
            "--trust-remote-code",
            "--disaggregation-mode",
            "prefill",
            "--tp",
            "1",
            *cls.spec_args,
            *cls.transfer_backend,
            *cls.rdma_devices,
        ]
        cls.process_prefill = popen_launch_pd_server(
            cls.model,
            cls.prefill_url,
            timeout=DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH,
            other_args=launch_args,
        )

    @classmethod
    def start_decode(cls):
        """Spawn the decode server (TP=1, GPU offset 1) with the speculative flags."""
        launch_args = [
            "--trust-remote-code",
            "--disaggregation-mode",
            "decode",
            "--tp",
            "1",
            "--base-gpu-id",
            "1",
            *cls.spec_args,
            *cls.transfer_backend,
            *cls.rdma_devices,
        ]
        cls.process_decode = popen_launch_pd_server(
            cls.model,
            cls.decode_url,
            timeout=DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH,
            other_args=launch_args,
        )

    def test_gsm8k(self):
        """GSM8K via the lb with parallelism 2; accuracy must exceed 0.20."""
        eval_config = SimpleNamespace(
            num_shots=5,
            data_path=None,
            num_questions=200,
            max_new_tokens=512,
            parallel=2,
            host=f"http://{self.base_host}",
            port=int(self.lb_port),
        )
        metrics = run_eval_few_shot_gsm8k(eval_config)
        print(f"Evaluation metrics: {metrics}")

        accuracy = metrics["accuracy"]
        self.assertGreater(accuracy, 0.20)


class TestDisaggregationSimulatedRetract(TestDisaggregationBase):
    """GSM8K accuracy with ``SGLANG_TEST_RETRACT=true`` set while servers launch.

    NOTE(review): the flag presumably makes the spawned servers simulate
    request retraction; verify against the server code.
    """

    @classmethod
    def setUpClass(cls):
        super().setUpClass()

        # patch.dict restores the previous environment (including any value
        # that existed before) when stopped. A class cleanup still runs if the
        # server launch below raises; in that case unittest skips
        # tearDownClass, so a manual set/pop pair would leak the flag into
        # later test classes.
        env_patch = mock.patch.dict(os.environ, {"SGLANG_TEST_RETRACT": "true"})
        env_patch.start()
        cls.addClassCleanup(env_patch.stop)

        cls.model = DEFAULT_MODEL_NAME_FOR_TEST

        # Non blocking start servers
        cls.start_prefill()
        cls.start_decode()

        # Block until both
        cls.wait_server_ready(cls.prefill_url + "/health")
        cls.wait_server_ready(cls.decode_url + "/health")

        cls.launch_lb()

    @classmethod
    def start_prefill(cls):
        """Launch the prefill server (TP=1) and store it on ``process_prefill``."""
        prefill_args = [
            "--trust-remote-code",
            "--disaggregation-mode",
            "prefill",
            "--tp",
            "1",
        ]
        prefill_args += cls.transfer_backend + cls.rdma_devices
        cls.process_prefill = popen_launch_pd_server(
            cls.model,
            cls.prefill_url,
            timeout=DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH,
            other_args=prefill_args,
        )

    @classmethod
    def start_decode(cls):
        """Launch the decode server (TP=1, GPU offset 1) on ``process_decode``."""
        decode_args = [
            "--trust-remote-code",
            "--disaggregation-mode",
            "decode",
            "--tp",
            "1",
            "--base-gpu-id",
            "1",
        ]
        decode_args += cls.transfer_backend + cls.rdma_devices
        cls.process_decode = popen_launch_pd_server(
            cls.model,
            cls.decode_url,
            timeout=DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH,
            other_args=decode_args,
        )

    def test_gsm8k(self):
        """5-shot GSM8K (200 questions) through the lb must exceed 0.62."""
        args = SimpleNamespace(
            num_shots=5,
            data_path=None,
            num_questions=200,
            max_new_tokens=512,
            parallel=128,
            host=f"http://{self.base_host}",
            port=int(self.lb_port),
        )
        metrics = run_eval_few_shot_gsm8k(args)
        print(f"Evaluation metrics: {metrics}")

        self.assertGreater(metrics["accuracy"], 0.62)


if __name__ == "__main__":
    # Executed as a script: discover and run every TestCase in this module.
    unittest.main(module=__name__)