# test_disaggregation_basic.py
import json
2
import os
3
4
5
6
7
8
9
import time
import unittest
from types import SimpleNamespace

import requests

from sglang.test.few_shot_gsm8k import run_eval as run_eval_few_shot_gsm8k
10
from sglang.test.test_disaggregation_utils import TestDisaggregationBase
11
from sglang.test.test_utils import (
Byron Hsu's avatar
Byron Hsu committed
12
13
    DEFAULT_EAGLE_DRAFT_MODEL_FOR_TEST,
    DEFAULT_EAGLE_TARGET_MODEL_FOR_TEST,
14
15
16
17
18
19
    DEFAULT_MODEL_NAME_FOR_TEST,
    DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH,
    popen_launch_pd_server,
)


20
class TestDisaggregationAccuracy(TestDisaggregationBase):
    """Accuracy and API checks against a prefill/decode server pair.

    ``setUpClass`` launches one prefill server and one decode server for
    ``DEFAULT_MODEL_NAME_FOR_TEST``, waits for both ``/health`` endpoints and
    then calls ``launch_lb()``. The tests address ``lb_url`` / ``lb_port``.
    """

    @classmethod
    def setUpClass(cls):
        """Start both servers, wait until they are ready, then launch the LB."""
        super().setUpClass()
        cls.model = DEFAULT_MODEL_NAME_FOR_TEST

        # Non blocking start servers
        cls.start_prefill()
        cls.start_decode()

        # Block until both
        cls.wait_server_ready(cls.prefill_url + "/health")
        cls.wait_server_ready(cls.decode_url + "/health")

        cls.launch_lb()

    @classmethod
    def start_prefill(cls):
        """Launch the prefill server and store its handle in ``process_prefill``."""
        prefill_args = [
            "--trust-remote-code",
            "--disaggregation-mode",
            "prefill",
            "--tp",
            "1",
        ]
        # Transfer-backend and RDMA flags are supplied as class attributes.
        prefill_args += cls.transfer_backend + cls.rdma_devices
        cls.process_prefill = popen_launch_pd_server(
            cls.model,
            cls.prefill_url,
            timeout=DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH,
            other_args=prefill_args,
        )

    @classmethod
    def start_decode(cls):
        """Launch the decode server and store its handle in ``process_decode``."""
        decode_args = [
            "--trust-remote-code",
            "--disaggregation-mode",
            "decode",
            "--tp",
            "1",
            # Unlike the prefill server, decode is started with --base-gpu-id 1.
            "--base-gpu-id",
            "1",
        ]
        decode_args += cls.transfer_backend + cls.rdma_devices
        cls.process_decode = popen_launch_pd_server(
            cls.model,
            cls.decode_url,
            timeout=DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH,
            other_args=decode_args,
        )

    def test_gsm8k(self):
        """Run 5-shot GSM8K on 200 questions and require accuracy above 0.62."""
        args = SimpleNamespace(
            num_shots=5,
            data_path=None,
            num_questions=200,
            max_new_tokens=512,
            parallel=128,
            host=f"http://{self.base_host}",
            port=int(self.lb_port),
        )
        metrics = run_eval_few_shot_gsm8k(args)
        print(f"Evaluation metrics: {metrics}")

        self.assertGreater(metrics["accuracy"], 0.62)

    def test_logprob(self):
        """Check that ``/generate`` returns input and output token logprobs."""
        prompt = "The capital of france is "
        response = requests.post(
            self.lb_url + "/generate",
            json={
                "text": prompt,
                "sampling_params": {"temperature": 0},
                "return_logprob": True,
                "return_input_logprob": True,
                "logprob_start_len": 0,
            },
        )
        # Fail with the server's own message instead of a KeyError further down.
        self.assertEqual(response.status_code, 200, response.text)

        meta_info = response.json()["meta_info"]
        completion_tokens = meta_info["completion_tokens"]
        input_logprobs = meta_info["input_token_logprobs"]
        output_logprobs = meta_info["output_token_logprobs"]

        # unittest assertions are used because bare ``assert`` statements are
        # removed when Python runs with -O.
        self.assertEqual(
            len(output_logprobs),
            completion_tokens,
            f"output_logprobs and completion_tokens should have the same length, but got {len(output_logprobs)} and {completion_tokens}",
        )
        self.assertGreater(
            len(input_logprobs),
            0,
            f"input_logprobs should have at least one token, but got {len(input_logprobs)}",
        )

    def test_structured_output(self):
        """Check that generation constrained by a JSON schema parses as JSON."""
        json_schema = json.dumps(
            {
                "type": "object",
                "properties": {
                    "name": {"type": "string", "pattern": "^[\\w]+$"},
                    "population": {"type": "integer"},
                },
                "required": ["name", "population"],
            }
        )

        # JSON
        response = requests.post(
            f"{self.lb_url}/generate",
            json={
                "text": "Here is the information of the capital of France in the JSON format.\n",
                "sampling_params": {
                    "temperature": 0,
                    "max_new_tokens": 64,
                    "json_schema": json_schema,
                },
            },
        )
        # Fail with the server's own message instead of a KeyError further down.
        self.assertEqual(response.status_code, 200, response.text)

        output = response.json()["text"]
        # ensure the output is a valid JSON
        json.loads(output)

140

141
class TestDisaggregationMooncakeFailure(TestDisaggregationBase):
    """Run GSM8K with simulated failures and check the servers stay healthy.

    ``DISAGGREGATION_TEST_FAILURE_PROB`` is set to ``"0.05"`` in this process
    before the servers are launched and removed again on teardown.
    NOTE(review): presumably the launched servers inherit the variable from
    this process's environment; confirm against ``popen_launch_pd_server``.
    """

    # Environment variable used to simulate failures.
    _FAILURE_PROB_ENV_VAR = "DISAGGREGATION_TEST_FAILURE_PROB"
    # Upper bound for one health probe so a hung server fails the test instead
    # of blocking it forever.
    _HEALTH_CHECK_TIMEOUT_SECS = 60

    @classmethod
    def setUpClass(cls):
        """Set the failure-injection variable and bring up servers and LB."""
        super().setUpClass()
        # set DISAGGREGATION_TEST_FAILURE_PROB to simulate failure
        os.environ[cls._FAILURE_PROB_ENV_VAR] = "0.05"

        try:
            cls.model = DEFAULT_MODEL_NAME_FOR_TEST

            # Non blocking start servers
            cls.start_prefill()
            cls.start_decode()

            # Block until both
            cls.wait_server_ready(cls.prefill_url + "/health")
            cls.wait_server_ready(cls.decode_url + "/health")

            cls.launch_lb()
        except BaseException:
            # unittest does not call tearDownClass when setUpClass raises, so
            # undo the environment change here to keep it from leaking into
            # test classes that run later in the same process.
            os.environ.pop(cls._FAILURE_PROB_ENV_VAR, None)
            raise

    @classmethod
    def tearDownClass(cls):
        """Remove the failure-injection variable, then run the base teardown."""
        # A default is passed so a missing key cannot raise KeyError and skip
        # the base-class teardown.
        os.environ.pop(cls._FAILURE_PROB_ENV_VAR, None)
        super().tearDownClass()

    @classmethod
    def start_prefill(cls):
        """Launch the prefill server and store its handle in ``process_prefill``."""
        prefill_args = [
            "--trust-remote-code",
            "--disaggregation-mode",
            "prefill",
            "--tp",
            "1",
        ]
        prefill_args += cls.transfer_backend + cls.rdma_devices
        cls.process_prefill = popen_launch_pd_server(
            cls.model,
            cls.prefill_url,
            timeout=DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH,
            other_args=prefill_args,
        )

    @classmethod
    def start_decode(cls):
        """Launch the decode server and store its handle in ``process_decode``."""
        decode_args = [
            "--trust-remote-code",
            "--disaggregation-mode",
            "decode",
            "--tp",
            "1",
            "--base-gpu-id",
            "1",
        ]
        decode_args += cls.transfer_backend + cls.rdma_devices
        cls.process_decode = popen_launch_pd_server(
            cls.model,
            cls.decode_url,
            timeout=DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH,
            other_args=decode_args,
        )

    def test_gsm8k(self):
        """Run GSM8K; tolerate eval errors as long as both servers stay healthy.

        No accuracy threshold is enforced. If the evaluation raises, both
        servers' ``/health_generate`` endpoints are probed and the original
        evaluation error is re-raised only when a probe fails.
        """
        args = SimpleNamespace(
            num_shots=5,
            data_path=None,
            num_questions=200,
            max_new_tokens=512,
            parallel=128,
            host=f"http://{self.base_host}",
            port=int(self.lb_port),
        )

        # Expect lots of failure but the server cannot crash
        try:
            metrics = run_eval_few_shot_gsm8k(args)
            print(f"Evaluation metrics: {metrics}")
        except Exception as e:
            print(f"Test encountered expected errors: {e}")
            # Check if servers are still healthy
            try:
                for server_url in (self.prefill_url, self.decode_url):
                    response = requests.get(
                        server_url + "/health_generate",
                        timeout=self._HEALTH_CHECK_TIMEOUT_SECS,
                    )
                    # Explicit check: a bare ``assert`` is removed under -O.
                    if response.status_code != 200:
                        raise RuntimeError(
                            f"{server_url}/health_generate returned HTTP "
                            f"{response.status_code}"
                        )
            except Exception as health_check_error:
                # If health check fails, re-raise the original exception
                raise e from health_check_error
227
228


229
class TestDisaggregationMooncakeSpec(TestDisaggregationBase):
    """GSM8K check with EAGLE speculative-decoding flags on both servers.

    The same ``spec_args`` list is appended to the prefill and the decode
    launch command lines.
    """

    @classmethod
    def setUpClass(cls):
        """Build the speculative flags, start both servers and launch the LB."""
        super().setUpClass()
        cls.model = DEFAULT_EAGLE_TARGET_MODEL_FOR_TEST
        cls.draft_model = DEFAULT_EAGLE_DRAFT_MODEL_FOR_TEST

        # Flag -> value pairs, flattened below into a flat argv-style list in
        # insertion order.
        speculative_options = {
            "--speculative-algorithm": "EAGLE",
            "--speculative-draft-model-path": cls.draft_model,
            "--speculative-num-steps": "3",
            "--speculative-eagle-topk": "4",
            "--speculative-num-draft-tokens": "16",
            "--cuda-graph-max-bs": "8",
        }
        cls.spec_args = [
            token for pair in speculative_options.items() for token in pair
        ]
        print(f"{cls.base_host=} {cls.lb_port=} {cls.prefill_port=} {cls.decode_port=}")

        # Kick off both servers first, then wait on each health endpoint.
        for launch in (cls.start_prefill, cls.start_decode):
            launch()
        for server_url in (cls.prefill_url, cls.decode_url):
            cls.wait_server_ready(server_url + "/health")

        cls.launch_lb()

    @classmethod
    def start_prefill(cls):
        """Launch the prefill server and store its handle in ``process_prefill``."""
        launch_flags = [
            "--trust-remote-code",
            "--disaggregation-mode",
            "prefill",
            "--tp",
            "1",
            *cls.spec_args,
            *cls.transfer_backend,
            *cls.rdma_devices,
        ]
        cls.process_prefill = popen_launch_pd_server(
            cls.model,
            cls.prefill_url,
            timeout=DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH,
            other_args=launch_flags,
        )

    @classmethod
    def start_decode(cls):
        """Launch the decode server and store its handle in ``process_decode``."""
        launch_flags = [
            "--trust-remote-code",
            "--disaggregation-mode",
            "decode",
            "--tp",
            "1",
            "--base-gpu-id",
            "1",
            *cls.spec_args,
            *cls.transfer_backend,
            *cls.rdma_devices,
        ]
        cls.process_decode = popen_launch_pd_server(
            cls.model,
            cls.decode_url,
            timeout=DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH,
            other_args=launch_flags,
        )

    def test_gsm8k(self):
        """Run 5-shot GSM8K on 200 questions and require accuracy above 0.20."""
        eval_config = SimpleNamespace(
            num_shots=5,
            data_path=None,
            num_questions=200,
            max_new_tokens=512,
            parallel=2,
            host=f"http://{self.base_host}",
            port=int(self.lb_port),
        )
        metrics = run_eval_few_shot_gsm8k(eval_config)
        print(f"Evaluation metrics: {metrics}")

        accuracy = metrics["accuracy"]
        self.assertGreater(accuracy, 0.20)


314
class TestDisaggregationSimulatedRetract(TestDisaggregationBase):
    """GSM8K accuracy check with ``SGLANG_TEST_RETRACT`` set to ``"true"``.

    The variable is set in this process before the servers are launched and
    removed again on teardown.
    NOTE(review): presumably it makes the servers simulate request retraction
    (inferred from the class and variable names); confirm in the server code.
    """

    _RETRACT_ENV_VAR = "SGLANG_TEST_RETRACT"

    @classmethod
    def setUpClass(cls):
        """Set the retract variable and bring up both servers and the LB."""
        super().setUpClass()
        os.environ[cls._RETRACT_ENV_VAR] = "true"

        try:
            cls.model = DEFAULT_MODEL_NAME_FOR_TEST

            # Non blocking start servers
            cls.start_prefill()
            cls.start_decode()

            # Block until both
            cls.wait_server_ready(cls.prefill_url + "/health")
            cls.wait_server_ready(cls.decode_url + "/health")

            cls.launch_lb()
        except BaseException:
            # unittest does not call tearDownClass when setUpClass raises, so
            # undo the environment change here to keep it from leaking into
            # test classes that run later in the same process.
            os.environ.pop(cls._RETRACT_ENV_VAR, None)
            raise

    @classmethod
    def tearDownClass(cls):
        """Remove the retract variable, then run the base-class teardown."""
        # A default is passed so a missing key cannot raise KeyError and skip
        # the base-class teardown.
        os.environ.pop(cls._RETRACT_ENV_VAR, None)
        super().tearDownClass()

    @classmethod
    def start_prefill(cls):
        """Launch the prefill server and store its handle in ``process_prefill``."""
        prefill_args = [
            "--trust-remote-code",
            "--disaggregation-mode",
            "prefill",
            "--tp",
            "1",
        ]
        prefill_args += cls.transfer_backend + cls.rdma_devices
        cls.process_prefill = popen_launch_pd_server(
            cls.model,
            cls.prefill_url,
            timeout=DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH,
            other_args=prefill_args,
        )

    @classmethod
    def start_decode(cls):
        """Launch the decode server and store its handle in ``process_decode``."""
        decode_args = [
            "--trust-remote-code",
            "--disaggregation-mode",
            "decode",
            "--tp",
            "1",
            "--base-gpu-id",
            "1",
        ]
        decode_args += cls.transfer_backend + cls.rdma_devices
        cls.process_decode = popen_launch_pd_server(
            cls.model,
            cls.decode_url,
            timeout=DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH,
            other_args=decode_args,
        )

    def test_gsm8k(self):
        """Run 5-shot GSM8K on 200 questions and require accuracy above 0.62."""
        args = SimpleNamespace(
            num_shots=5,
            data_path=None,
            num_questions=200,
            max_new_tokens=512,
            parallel=128,
            host=f"http://{self.base_host}",
            port=int(self.lb_port),
        )
        metrics = run_eval_few_shot_gsm8k(args)
        print(f"Evaluation metrics: {metrics}")

        self.assertGreater(metrics["accuracy"], 0.62)


388
389
# Run every test case in this module when the file is executed as a script.
if __name__ == "__main__":
    unittest.main()