"git@developer.sourcefind.cn:zhaoyu6/sglang.git" did not exist on "03c039c48ee350bb7513584bfe23fd58bd016a7e"
test_disaggregation_basic.py 13.1 KB
Newer Older
1
import json
2
import os
3
4
5
import unittest
from types import SimpleNamespace

6
import openai
7
import requests
8
from transformers import AutoTokenizer
9
10

from sglang.test.few_shot_gsm8k import run_eval as run_eval_few_shot_gsm8k
11
from sglang.test.test_disaggregation_utils import TestDisaggregationBase
12
from sglang.test.test_utils import (
Byron Hsu's avatar
Byron Hsu committed
13
14
    DEFAULT_EAGLE_DRAFT_MODEL_FOR_TEST,
    DEFAULT_EAGLE_TARGET_MODEL_FOR_TEST,
15
16
17
18
19
20
    DEFAULT_MODEL_NAME_FOR_TEST,
    DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH,
    popen_launch_pd_server,
)


21
class TestDisaggregationAccuracy(TestDisaggregationBase):
    """Accuracy and protocol checks against a prefill/decode disaggregated deployment."""

    @classmethod
    def setUpClass(cls):
        super().setUpClass()
        cls.model = DEFAULT_MODEL_NAME_FOR_TEST

        # Launch both halves without blocking, then wait for each /health endpoint.
        cls.start_prefill()
        cls.start_decode()
        for base_url in (cls.prefill_url, cls.decode_url):
            cls.wait_server_ready(base_url + "/health")

        cls.launch_lb()

    @classmethod
    def start_prefill(cls):
        """Spawn the prefill-side server process (non-blocking)."""
        launch_args = [
            "--trust-remote-code",
            "--disaggregation-mode",
            "prefill",
            "--tp",
            "1",
            *cls.transfer_backend,
            *cls.rdma_devices,
        ]
        cls.process_prefill = popen_launch_pd_server(
            cls.model,
            cls.prefill_url,
            timeout=DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH,
            other_args=launch_args,
        )

    @classmethod
    def start_decode(cls):
        """Spawn the decode-side server process (non-blocking).

        `--base-gpu-id 1` offsets the decode server onto a different GPU than
        the prefill server.
        """
        launch_args = [
            "--trust-remote-code",
            "--disaggregation-mode",
            "decode",
            "--tp",
            "1",
            "--base-gpu-id",
            "1",
            *cls.transfer_backend,
            *cls.rdma_devices,
        ]
        cls.process_decode = popen_launch_pd_server(
            cls.model,
            cls.decode_url,
            timeout=DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH,
            other_args=launch_args,
        )

    def test_gsm8k(self):
        """Few-shot GSM8K accuracy through the load balancer must exceed 0.62."""
        eval_args = SimpleNamespace(
            num_shots=5,
            data_path=None,
            num_questions=200,
            max_new_tokens=512,
            parallel=128,
            host=f"http://{self.base_host}",
            port=int(self.lb_port),
        )
        metrics = run_eval_few_shot_gsm8k(eval_args)
        print(f"Evaluation metrics: {metrics}")

        self.assertGreater(metrics["accuracy"], 0.62)

    def test_logprob(self):
        """Input/output token logprobs must be returned with consistent lengths."""
        prompt = "The capital of france is "
        payload = {
            "text": prompt,
            "sampling_params": {"temperature": 0},
            "return_logprob": True,
            "return_input_logprob": True,
            "logprob_start_len": 0,
        }
        response = requests.post(self.lb_url + "/generate", json=payload)

        meta = response.json()["meta_info"]
        completion_tokens = meta["completion_tokens"]
        input_logprobs = meta["input_token_logprobs"]
        output_logprobs = meta["output_token_logprobs"]

        assert (
            len(output_logprobs) == completion_tokens
        ), f"output_logprobs and completion_tokens should have the same length, but got {len(output_logprobs)} and {completion_tokens}"
        assert (
            len(input_logprobs) > 0
        ), f"input_logprobs should have at least one token, but got {len(input_logprobs)}"

    def test_structured_output(self):
        """Schema-constrained generation must produce parseable JSON."""
        json_schema = json.dumps(
            {
                "type": "object",
                "properties": {
                    "name": {"type": "string", "pattern": "^[\\w]+$"},
                    "population": {"type": "integer"},
                },
                "required": ["name", "population"],
            }
        )

        # JSON
        resp = requests.post(
            f"{self.lb_url}/generate",
            json={
                "text": "Here is the information of the capital of France in the JSON format.\n",
                "sampling_params": {
                    "temperature": 0,
                    "max_new_tokens": 64,
                    "json_schema": json_schema,
                },
            },
        )
        generated = resp.json()["text"]
        # ensure the output is a valid JSON
        json.loads(generated)

    def test_first_token_finish(self):
        """Finish conditions triggered by the very first generated token.

        Biases the sampler toward a chosen token (EOS or a stop string's first
        token) and checks that generation stops — or continues under
        ``ignore_eos`` — exactly as expected.
        """
        client = openai.Client(api_key="empty", base_url=f"{self.lb_url}/v1")
        tokenizer = AutoTokenizer.from_pretrained(self.model)
        eos_token = tokenizer.eos_token_id
        prompt = "The best programming language for AI is"

        # First token EOS
        res = client.completions.create(
            model="dummy", prompt=prompt, logit_bias={eos_token: 42}
        ).model_dump()
        print(f"{res=}")

        assert res["usage"]["completion_tokens"] == 1, (
            "Expected completion_tokens to be 1 when first token is EOS, "
            f"but got {res['usage']['completion_tokens']}"
        )

        # First token EOS with ignore_eos
        res = client.completions.create(
            model="dummy",
            prompt=prompt,
            logit_bias={eos_token: 42},
            extra_body={"ignore_eos": True},
        ).model_dump()
        print(f"{res=}")

        assert res["usage"]["completion_tokens"] > 1, (
            "Expected completion_tokens to be greater than 1 when ignore_eos is True, "
            f"but got {res['usage']['completion_tokens']}"
        )

        # First token with specified stop token
        stop_token_id = tokenizer.encode(" hello", add_special_tokens=False)[0]
        res = client.completions.create(
            model="dummy",
            prompt=prompt,
            logit_bias={stop_token_id: 42},
            stop=[" hello"],
        ).model_dump()
        print(f"{res=}")

        assert res["usage"]["completion_tokens"] == 1, (
            "Expected completion_tokens to be 1 when first token is stop token, "
            f"but got {res['usage']['completion_tokens']}"
        )


class TestDisaggregationMooncakeFailure(TestDisaggregationBase):
    """Fault-injection test: random transfer failures must not crash the servers."""

    @classmethod
    def setUpClass(cls):
        super().setUpClass()

        # set DISAGGREGATION_TEST_FAILURE_PROB to simulate failure
        # (set before the servers are spawned so the child processes inherit it)
        os.environ["DISAGGREGATION_TEST_FAILURE_PROB"] = "0.05"

        cls.model = DEFAULT_MODEL_NAME_FOR_TEST

        # Non blocking start servers
        cls.start_prefill()
        cls.start_decode()

        # Block until both
        cls.wait_server_ready(cls.prefill_url + "/health")
        cls.wait_server_ready(cls.decode_url + "/health")

        cls.launch_lb()

    @classmethod
    def tearDownClass(cls):
        # Pop with a default so teardown never raises KeyError (which would
        # mask the real failure) if the variable is absent for any reason.
        os.environ.pop("DISAGGREGATION_TEST_FAILURE_PROB", None)
        super().tearDownClass()

    @classmethod
    def start_prefill(cls):
        """Spawn the prefill-side server process (non-blocking)."""
        prefill_args = [
            "--trust-remote-code",
            "--disaggregation-mode",
            "prefill",
            "--tp",
            "1",
        ]
        prefill_args += cls.transfer_backend + cls.rdma_devices
        cls.process_prefill = popen_launch_pd_server(
            cls.model,
            cls.prefill_url,
            timeout=DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH,
            other_args=prefill_args,
        )

    @classmethod
    def start_decode(cls):
        """Spawn the decode-side server process (non-blocking) on a separate GPU."""
        decode_args = [
            "--trust-remote-code",
            "--disaggregation-mode",
            "decode",
            "--tp",
            "1",
            "--base-gpu-id",
            "1",
        ]
        decode_args += cls.transfer_backend + cls.rdma_devices
        cls.process_decode = popen_launch_pd_server(
            cls.model,
            cls.decode_url,
            timeout=DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH,
            other_args=decode_args,
        )

    def test_gsm8k(self):
        """Run GSM8K under injected failures.

        Individual requests are allowed to fail; the test only requires that
        both servers remain healthy afterwards.
        """
        args = SimpleNamespace(
            num_shots=5,
            data_path=None,
            num_questions=200,
            max_new_tokens=512,
            parallel=128,
            host=f"http://{self.base_host}",
            port=int(self.lb_port),
        )

        # Expect lots of failure but the server cannot crash
        try:
            metrics = run_eval_few_shot_gsm8k(args)
            print(f"Evaluation metrics: {metrics}")
        except Exception as e:
            print(f"Test encountered expected errors: {e}")
            # Check if servers are still healthy
            try:
                response = requests.get(self.prefill_url + "/health_generate")
                assert response.status_code == 200
                response = requests.get(self.decode_url + "/health_generate")
                assert response.status_code == 200
            except Exception as health_check_error:
                # If health check fails, re-raise the original exception
                raise e from health_check_error


class TestDisaggregationMooncakeSpec(TestDisaggregationBase):
    """Disaggregated serving combined with EAGLE speculative decoding."""

    @classmethod
    def setUpClass(cls):
        super().setUpClass()
        cls.model = DEFAULT_EAGLE_TARGET_MODEL_FOR_TEST
        cls.draft_model = DEFAULT_EAGLE_DRAFT_MODEL_FOR_TEST
        # Speculative-decoding flags shared by both the prefill and decode servers.
        cls.spec_args = [
            "--speculative-algorithm",
            "EAGLE",
            "--speculative-draft-model-path",
            cls.draft_model,
            "--speculative-num-steps",
            "3",
            "--speculative-eagle-topk",
            "4",
            "--speculative-num-draft-tokens",
            "16",
            "--cuda-graph-max-bs",
            "8",
        ]
        print(f"{cls.base_host=} {cls.lb_port=} {cls.prefill_port=} {cls.decode_port=}")

        # Launch both halves without blocking, then wait for each /health endpoint.
        cls.start_prefill()
        cls.start_decode()
        for base_url in (cls.prefill_url, cls.decode_url):
            cls.wait_server_ready(base_url + "/health")

        cls.launch_lb()

    @classmethod
    def start_prefill(cls):
        """Spawn the prefill-side server with speculative decoding enabled."""
        launch_args = [
            "--trust-remote-code",
            "--disaggregation-mode",
            "prefill",
            "--tp",
            "1",
            *cls.spec_args,
            *cls.transfer_backend,
            *cls.rdma_devices,
        ]
        cls.process_prefill = popen_launch_pd_server(
            cls.model,
            cls.prefill_url,
            timeout=DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH,
            other_args=launch_args,
        )

    @classmethod
    def start_decode(cls):
        """Spawn the decode-side server with speculative decoding enabled."""
        launch_args = [
            "--trust-remote-code",
            "--disaggregation-mode",
            "decode",
            "--tp",
            "1",
            "--base-gpu-id",
            "1",
            *cls.spec_args,
            *cls.transfer_backend,
            *cls.rdma_devices,
        ]
        cls.process_decode = popen_launch_pd_server(
            cls.model,
            cls.decode_url,
            timeout=DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH,
            other_args=launch_args,
        )

    def test_gsm8k(self):
        """GSM8K accuracy with the EAGLE target/draft pair must exceed 0.20."""
        eval_args = SimpleNamespace(
            num_shots=5,
            data_path=None,
            num_questions=200,
            max_new_tokens=512,
            parallel=2,
            host=f"http://{self.base_host}",
            port=int(self.lb_port),
        )
        metrics = run_eval_few_shot_gsm8k(eval_args)
        print(f"Evaluation metrics: {metrics}")

        self.assertGreater(metrics["accuracy"], 0.20)


class TestDisaggregationSimulatedRetract(TestDisaggregationBase):
    """Accuracy under simulated request retraction (SGLANG_TEST_RETRACT)."""

    @classmethod
    def setUpClass(cls):
        super().setUpClass()
        # Enable retract simulation before spawning servers so they inherit it.
        os.environ["SGLANG_TEST_RETRACT"] = "true"
        cls.model = DEFAULT_MODEL_NAME_FOR_TEST

        # Non blocking start servers
        cls.start_prefill()
        cls.start_decode()

        # Block until both
        cls.wait_server_ready(cls.prefill_url + "/health")
        cls.wait_server_ready(cls.decode_url + "/health")

        cls.launch_lb()

    @classmethod
    def tearDownClass(cls):
        # Pop with a default so teardown never raises KeyError (which would
        # mask the real failure) if the variable is absent for any reason.
        os.environ.pop("SGLANG_TEST_RETRACT", None)
        super().tearDownClass()

    @classmethod
    def start_prefill(cls):
        """Spawn the prefill-side server process (non-blocking)."""
        prefill_args = [
            "--trust-remote-code",
            "--disaggregation-mode",
            "prefill",
            "--tp",
            "1",
        ]
        prefill_args += cls.transfer_backend + cls.rdma_devices
        cls.process_prefill = popen_launch_pd_server(
            cls.model,
            cls.prefill_url,
            timeout=DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH,
            other_args=prefill_args,
        )

    @classmethod
    def start_decode(cls):
        """Spawn the decode-side server process (non-blocking) on a separate GPU."""
        decode_args = [
            "--trust-remote-code",
            "--disaggregation-mode",
            "decode",
            "--tp",
            "1",
            "--base-gpu-id",
            "1",
        ]
        decode_args += cls.transfer_backend + cls.rdma_devices
        cls.process_decode = popen_launch_pd_server(
            cls.model,
            cls.decode_url,
            timeout=DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH,
            other_args=decode_args,
        )

    def test_gsm8k(self):
        """GSM8K accuracy must stay above 0.62 even with retraction enabled."""
        args = SimpleNamespace(
            num_shots=5,
            data_path=None,
            num_questions=200,
            max_new_tokens=512,
            parallel=128,
            host=f"http://{self.base_host}",
            port=int(self.lb_port),
        )
        metrics = run_eval_few_shot_gsm8k(args)
        print(f"Evaluation metrics: {metrics}")

        self.assertGreater(metrics["accuracy"], 0.62)


# Allow running this test module directly with `python test_disaggregation_basic.py`.
if __name__ == "__main__":
    unittest.main()