"""
python3 -m unittest test_srt_endpoint.TestSRTEndpoint.test_simple_decode
python3 -m unittest test_srt_endpoint.TestSRTEndpoint.test_logprob_with_chunked_prefill
python3 -m unittest test_srt_endpoint.TestTokenizeDetokenize
"""

import json
import random
import time
import unittest
from concurrent.futures import ThreadPoolExecutor
from functools import partial
from typing import Optional

import numpy as np
import requests

from sglang.srt.sampling.custom_logit_processor import CustomLogitProcessor
from sglang.srt.utils import kill_process_tree
from sglang.test.test_utils import (
    DEFAULT_SMALL_MODEL_NAME_FOR_TEST,
    DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH,
    DEFAULT_URL_FOR_TEST,
    CustomTestCase,
    popen_launch_server,
    run_logprob_check,
)


class TestSRTEndpoint(CustomTestCase):
    @classmethod
    def setUpClass(cls):
        cls.model = DEFAULT_SMALL_MODEL_NAME_FOR_TEST
        cls.base_url = DEFAULT_URL_FOR_TEST
        cls.process = popen_launch_server(
            cls.model,
            cls.base_url,
            timeout=DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH,
            other_args=(
                "--enable-custom-logit-processor",
                "--mem-fraction-static",
                "0.7",
                "--cuda-graph-max-bs",
                "8",
            ),
        )

    @classmethod
    def tearDownClass(cls):
        kill_process_tree(cls.process.pid)

    def run_decode(
        self,
        return_logprob=False,
        top_logprobs_num=0,
        return_text=False,
        n=1,
        stream=False,
        batch=False,
    ):
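        """Shared helper: post a single /generate request (optionally batched or
        streamed) and print the parsed response. logprob_start_len=0 asks the
        server to return logprobs for the whole prompt.
        """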
        if batch:
            text = ["The capital of France is"]
        else:
            text = "The capital of France is"

        response = requests.post(
            self.base_url + "/generate",
            json={
                "text": text,
                "sampling_params": {
                    "temperature": 0 if n == 1 else 0.5,
                    "max_new_tokens": 16,
                    "n": n,
                },
                "stream": stream,
                "return_logprob": return_logprob,
                "top_logprobs_num": top_logprobs_num,
                "return_text_in_logprobs": return_text,
                "logprob_start_len": 0,
            },
        )
        if not stream:
            response_json = response.json()
        else:
            response_json = []
            for line in response.iter_lines():
                if line.startswith(b"data: ") and line[6:] != b"[DONE]":
                    response_json.append(json.loads(line[6:]))

        print(json.dumps(response_json, indent=2))
        print("=" * 100)

    def test_simple_decode(self):
        self.run_decode()

    def test_simple_decode_batch(self):
        self.run_decode(batch=True)

    def test_parallel_sample(self):
        self.run_decode(n=3)

    def test_parallel_sample_stream(self):
        self.run_decode(n=3, stream=True)

    def test_logprob(self):
        self.run_decode(
            return_logprob=True,
            top_logprobs_num=5,
            return_text=True,
        )

    def test_logprob_start_len(self):
        logprob_start_len = 4
        new_tokens = 4
        prompts = [
            "I have a very good idea on",
            "Today is a sunndy day and",
        ]

        response = requests.post(
            self.base_url + "/generate",
            json={
                "text": prompts,
                "sampling_params": {
                    "temperature": 0,
                    "max_new_tokens": new_tokens,
                },
                "return_logprob": True,
                "top_logprobs_num": 5,
                "return_text_in_logprobs": True,
                "logprob_start_len": logprob_start_len,
            },
        )
        response_json = response.json()
        print(json.dumps(response_json, indent=2))

        for i, res in enumerate(response_json):
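            # With logprob_start_len=k, input_token_logprobs only covers prompt tokens
            # from position k onward, so prompt_tokens == k + len(input_token_logprobs).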
            self.assertEqual(
                res["meta_info"]["prompt_tokens"],
                logprob_start_len + len(res["meta_info"]["input_token_logprobs"]),
            )
            assert prompts[i].endswith(
                "".join([x[-1] for x in res["meta_info"]["input_token_logprobs"]])
            )

            self.assertEqual(res["meta_info"]["completion_tokens"], new_tokens)
            self.assertEqual(len(res["meta_info"]["output_token_logprobs"]), new_tokens)
            self.assertEqual(
                res["text"],
                "".join([x[-1] for x in res["meta_info"]["output_token_logprobs"]]),
            )

    def test_logprob_with_chunked_prefill(self):
        """Test a long prompt that requests output logprobs will not hit OOM."""
        new_tokens = 4
        prompts = "I have a very good idea on this. " * 8000

        response = requests.post(
            self.base_url + "/generate",
            json={
                "text": prompts,
                "sampling_params": {
                    "temperature": 0,
                    "max_new_tokens": new_tokens,
                },
                "return_logprob": True,
                "logprob_start_len": -1,
                "top_logprobs_num": 5,
            },
        )
        response_json = response.json()
        # print(json.dumps(response_json, indent=2))

        res = response_json
        self.assertEqual(res["meta_info"]["completion_tokens"], new_tokens)

        # Test that the token counts are correct
        self.assertEqual(len(res["meta_info"]["output_token_logprobs"]), new_tokens)
        self.assertEqual(len(res["meta_info"]["output_top_logprobs"]), new_tokens)

        # Test that the top-1 tokens are the same as the output tokens (because temperature = 0.0)
        for i in range(new_tokens):
            self.assertListEqual(
                res["meta_info"]["output_token_logprobs"][i],
                res["meta_info"]["output_top_logprobs"][i][0],
            )
            self.assertEqual(len(res["meta_info"]["output_top_logprobs"][i]), 5)

    def test_logprob_match(self):
        """Test the output logprobs are close to the input logprobs if we run a prefill again."""

        def run_generate(
            prompt, return_logprob=False, max_new_tokens=512, logprob_start_len=-1
        ):

            if isinstance(prompt, str):
                prompt_kwargs = {"text": prompt}
            else:
                prompt_kwargs = {"input_ids": prompt}

            response = requests.post(
                self.base_url + "/generate",
                json={
                    **prompt_kwargs,
                    "sampling_params": {
                        "temperature": 1.0,
                        "max_new_tokens": max_new_tokens,
                        "ignore_eos": True,
                    },
                    "return_logprob": return_logprob,
                    "return_text_in_logprobs": True,
                    "logprob_start_len": logprob_start_len,
                },
            )
            return response.json()

        prompt = "I have a very good idea on how to"

        gen = run_generate(prompt, return_logprob=True, logprob_start_len=0)
        output_logprobs = np.array(
            [x[0] for x in gen["meta_info"]["output_token_logprobs"]]
        )
        num_prompts_tokens = gen["meta_info"]["prompt_tokens"]

        input_tokens = [x[1] for x in gen["meta_info"]["input_token_logprobs"]]
        output_tokens = [x[1] for x in gen["meta_info"]["output_token_logprobs"]]

        new_prompt = input_tokens + output_tokens
        score = run_generate(
            new_prompt, return_logprob=True, logprob_start_len=0, max_new_tokens=0
        )
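        # In the scoring pass, input logprobs at positions >= num_prompts_tokens
        # correspond to the tokens generated above, so they should closely match
        # the decode-time output logprobs.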
        output_logprobs_score = np.array(
            [
                x[0]
                for x in score["meta_info"]["input_token_logprobs"][num_prompts_tokens:]
            ]
        )

        print(f"{output_logprobs[-10:]=}")
        print(f"{output_logprobs_score[-10:]=}")

        diff = np.abs(output_logprobs - output_logprobs_score)
        max_diff = np.max(diff)
        self.assertLess(max_diff, 0.35)

    def test_logprob_mixed(self):
        args = []
        temperature = 0
        # input_len, output_len, temperature, logprob_start_len, return_logprob, top_logprobs_num
        for input_len in [1000, 5000, 10000, 50000]:
            for output_len in [4, 8]:
                for logprob_start_len in [0, 500, 2500, 5000, 25000]:
                    for return_logprob in [True, False]:
                        for top_logprobs_num in [0, 5]:

                            if logprob_start_len >= input_len:
                                continue

                            args.append(
                                (
                                    input_len,
                                    output_len,
                                    temperature,
                                    logprob_start_len,
                                    return_logprob,
                                    top_logprobs_num,
                                )
                            )

        random.shuffle(args)

        func = partial(run_logprob_check, self)
        with ThreadPoolExecutor(8) as executor:
            list(executor.map(func, args))

    def test_logprob_grammar(self):
        prompts = "Question: Is Paris the Capital of France? Answer:"
        allowed_tokens = [" Yes", " No"]

        response = requests.post(
            self.base_url + "/generate",
            json={
                "text": prompts,
                "sampling_params": {
                    "temperature": 1.0,
                    "max_new_tokens": 1,
                    "regex": "( Yes| No)",
                },
                "return_logprob": True,
                "top_logprobs_num": 5,  # The grammar constraint allows all prefix tokens so we need to use a larger top_k.
                "return_text_in_logprobs": True,
            },
        )
        response_json = response.json()
        output_top_logprobs = response_json["meta_info"]["output_top_logprobs"][0]
        print(f"{output_top_logprobs=}")

        # Parse results: find the logprobs of the two allowed answers among the
        # top-5 entries. The grammar constraint allows all prefix tokens, so the
        # allowed answers are not necessarily the top-ranked candidates.
        logprobs = [None] * 2
        for i in range(len(output_top_logprobs)):
            try:
                idx = allowed_tokens.index(output_top_logprobs[i][2])
            except ValueError:
                # Not found
                continue
            logprobs[idx] = output_top_logprobs[i][0]

        self.assertTrue(all(x is not None for x in logprobs))

    def run_custom_logit_processor(self, target_token_id: Optional[int] = None):
        """Test custom logit processor with custom params.

        If target_token_id is None, the custom logit processor won't be passed in.
        """

        custom_params = {"token_id": target_token_id}

        class DeterministicLogitProcessor(CustomLogitProcessor):
            """A dummy logit processor that changes the logits to always
            sample the given token id.
            """

            def __call__(self, logits, custom_param_list):
                assert logits.shape[0] == len(custom_param_list)
                key = "token_id"

                for i, param_dict in enumerate(custom_param_list):
                    # Mask all other tokens
                    logits[i, :] = -float("inf")
                    # Assign highest probability to the specified token
                    logits[i, param_dict[key]] = 0.0
                return logits

        prompts = "Question: Is Paris the Capital of France? Answer:"

        # Base case json data to be posted to the server.
        base_json = {
            "text": prompts,
            "sampling_params": {"temperature": 0.0},
            "return_logprob": True,
        }

        # Custom json data with custom logit processor and params.
        custom_json = base_json.copy()
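        # Note: dict.copy() is shallow, so "sampling_params" is shared with base_json.
        # That is fine here because base_json itself is never posted.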
        # Only set the custom logit processor if target_token_id is not None.
        if target_token_id is not None:
            custom_json["custom_logit_processor"] = DeterministicLogitProcessor.to_str()
            custom_json["sampling_params"]["custom_params"] = custom_params

        custom_response = requests.post(
            self.base_url + "/generate",
            json=custom_json,
        ).json()

        output_token_logprobs = custom_response["meta_info"]["output_token_logprobs"]
        sampled_tokens = [x[1] for x in output_token_logprobs]

        # The logit processor should always sample the given token because the modified logits are deterministic.
        if target_token_id is not None:
            self.assertTrue(
                all(x == custom_params["token_id"] for x in sampled_tokens),
                # Print the detailed test case info if the test fails.
                f"{target_token_id=}\n{sampled_tokens=}\n{custom_response=}",
            )

    def run_stateful_custom_logit_processor(
        self, first_token_id: int | None, delay: int = 2
    ):
        """Test custom logit processor with custom params and state.

        The first `delay` tokens are sampled normally; after that the processor forces first_token_id and then consecutive token ids.
        If first_token_id is None, the custom logit processor won't be passed in.
        """
        custom_params = {"token_id": first_token_id, "delay": 2}

        class DeterministicStatefulLogitProcessor(CustomLogitProcessor):
            """A dummy logit processor that changes the logits to always
            sample the given token id.
            """

            def __call__(self, logits, custom_param_list):
                assert logits.shape[0] == len(custom_param_list)

                for i, param_dict in enumerate(custom_param_list):
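                    # The param dict is mutated in place to carry state across decode
                    # steps: count down "delay", then force "token_id", then force
                    # (previous output token id + 1).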
                    if param_dict["delay"] > 0:
                        param_dict["delay"] -= 1
                        continue
                    if param_dict["delay"] == 0:
                        param_dict["delay"] -= 1
                        force_token = param_dict["token_id"]
                    else:
                        output_ids = param_dict["__req__"].output_ids
                        force_token = output_ids[-1] + 1
                    # Mask all other tokens
                    logits[i, :] = -float("inf")
                    # Assign highest probability to the specified token
                    logits[i, force_token] = 0.0
                return logits

        prompts = "Question: Is Paris the Capital of France? Answer:"

        # Base case json data to be posted to the server.
        base_json = {
            "text": prompts,
            "sampling_params": {"temperature": 0.0},
            "return_logprob": True,
        }

        # Custom json data with custom logit processor and params.
        custom_json = base_json.copy()
        # Only set the custom logit processor if first_token_id is not None.
        if first_token_id is not None:
            custom_json["custom_logit_processor"] = (
                DeterministicStatefulLogitProcessor().to_str()
            )
            custom_json["sampling_params"]["custom_params"] = custom_params

        custom_response = requests.post(
            self.base_url + "/generate",
            json=custom_json,
        ).json()

        output_token_logprobs = custom_response["meta_info"]["output_token_logprobs"]
        sampled_tokens = [x[1] for x in output_token_logprobs]
        # After the delay, the processor deterministically forces the expected token at each step.
        if first_token_id is not None:
            self.assertTrue(
                all(
                    x == custom_params["token_id"] + k
                    for k, x in enumerate(sampled_tokens[custom_params["delay"] :])
                ),
                # Print the detailed test case info if the test fails.
                f"{first_token_id=}\n{sampled_tokens=}\n{custom_response=}",
            )

    def test_custom_logit_processor(self):
        """Test custom logit processor with a single request."""
        self.run_custom_logit_processor(target_token_id=5)

    def test_custom_logit_processor_batch_mixed(self):
        """Test a batch of requests mixed of requests with and without custom logit processor."""
        target_token_ids = list(range(32)) + [None] * 16
        random.shuffle(target_token_ids)
        with ThreadPoolExecutor(len(target_token_ids)) as executor:
            list(executor.map(self.run_custom_logit_processor, target_token_ids))

    @unittest.skip("Skip this test because this feature has a bug. See comments below.")
    def test_stateful_custom_logit_processor(self):
        """Test custom logit processor with a single request."""

        """
        NOTE: This feature has a race condition bug.
        These lines https://github.com/sgl-project/sglang/blob/ef8ec07b2ce4c70c2a33ec5acda4ce529bc3cda4/test/srt/test_srt_endpoint.py#L395-L396 can be accessed by two concurrent threads at the same time; the access order is not guaranteed.
        In sglang, we use two Python threads to overlap GPU computation and CPU scheduling.
        Thread 1 (the CPU scheduling thread) updates `param_dict["__req__"].output_ids`.
        Thread 2 (the GPU computation thread) calls `DeterministicStatefulLogitProcessor`, because sampling is part of the GPU computation.
        We can fix this by moving the call to DeterministicStatefulLogitProcessor into the CPU scheduling thread.
        """

        self.run_stateful_custom_logit_processor(first_token_id=5)

    @unittest.skip("Skip this test because this feature has a bug. See comments above.")
    def test_stateful_custom_logit_processor_batch_mixed(self):
        """Test a batch of requests mixed of requests with and without custom logit processor."""
        target_token_ids = list(range(32)) + [None] * 16
        random.shuffle(target_token_ids)
        with ThreadPoolExecutor(len(target_token_ids)) as executor:
            list(
                executor.map(self.run_stateful_custom_logit_processor, target_token_ids)
            )

    def test_cache_tokens(self):
        for _ in range(2):
            time.sleep(1)
            response = requests.post(self.base_url + "/flush_cache")
            assert response.status_code == 200

        def send_and_check_cached_tokens(input_ids):
            response = requests.post(
                self.base_url + "/generate",
                json={
                    "input_ids": list(input_ids),
                    "sampling_params": {
                        "max_new_tokens": 1,
                    },
                },
            )
            response_json = response.json()
            return response_json["meta_info"]["cached_tokens"]

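        # The radix (prefix) cache should serve repeated prefixes from cache;
        # "cached_tokens" reports how many prompt tokens were reused. The exact
        # counts below reflect the server's default prefix-caching behavior.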
        self.assertEqual(send_and_check_cached_tokens(range(0, 100)), 0)
        self.assertEqual(send_and_check_cached_tokens(range(0, 10000)), 100)
        self.assertEqual(send_and_check_cached_tokens(range(0, 10000)), 9999)
        self.assertEqual(send_and_check_cached_tokens(range(0, 1000)), 999)
        self.assertEqual(send_and_check_cached_tokens(range(0, 11000)), 10000)

    def test_get_server_info(self):
        response = requests.get(self.base_url + "/get_server_info")
        response_json = response.json()

        max_total_num_tokens = response_json["max_total_num_tokens"]
        self.assertIsInstance(max_total_num_tokens, int)

        version = response_json["version"]
        self.assertIsInstance(version, str)

    def test_logit_bias(self):
        """Test that a very high logit bias forces sampling of a specific token."""
        # Choose a token ID to bias
        target_token_id = 60704  # Paris for meta-llama/Llama-3.2-1B-Instruct, DEFAULT_SMALL_MODEL_NAME_FOR_TEST
        logit_bias = {str(target_token_id): 100.0}  # Very high positive bias

        response = requests.post(
            self.base_url + "/generate",
            json={
                "text": "The capital of France is",
                "sampling_params": {
                    "temperature": 1.0,  # Use high temperature to encourage exploration
                    "max_new_tokens": 4,
                    "logit_bias": logit_bias,
                },
                "return_logprob": True,
            },
        )
        response_json = response.json()

        # Extract the sampled token IDs from the output
        output_token_logprobs = response_json["meta_info"]["output_token_logprobs"]
        sampled_tokens = [x[1] for x in output_token_logprobs]

        # Verify that all sampled tokens are the target token
        self.assertTrue(
            all(x == target_token_id for x in sampled_tokens),
            f"Expected all tokens to be {target_token_id}, but got {sampled_tokens}",
        )

    def test_forbidden_token(self):
        """Test that a forbidden token (very negative logit bias) doesn't appear in the output."""
        # Choose a token ID to forbid
        forbidden_token_id = 23994  # rice for meta-llama/Llama-3.2-1B-Instruct, DEFAULT_SMALL_MODEL_NAME_FOR_TEST
        logit_bias = {
            str(forbidden_token_id): -100.0
        }  # Very negative bias to forbid the token

        response = requests.post(
            self.base_url + "/generate",
            json={
                "text": "Only output 'rice' exactly like this, in lowercase ONLY: rice",
                "sampling_params": {
                    "temperature": 1.0,  # Use high temperature to encourage diverse output
                    "max_new_tokens": 50,  # Generate enough tokens to likely include numbers
                    "logit_bias": logit_bias,
                },
                "return_logprob": True,
            },
        )
        response_json = response.json()

        # Extract the sampled token IDs from the output
        output_token_logprobs = response_json["meta_info"]["output_token_logprobs"]
        sampled_tokens = [x[1] for x in output_token_logprobs]

        # Verify that the forbidden token doesn't appear in the output
        self.assertNotIn(
            forbidden_token_id,
            sampled_tokens,
            f"Expected forbidden token {forbidden_token_id} not to be present, but it was found",
        )

    def test_logit_bias_isolation(self):
        """Test that logit_bias applied to one request doesn't affect other requests in batch."""
        # Choose a token ID to bias in first request only
        biased_token_id = 60704  # Paris for meta-llama/Llama-3.2-1B-Instruct, DEFAULT_SMALL_MODEL_NAME_FOR_TEST

        # Prepare batch requests - one with logit_bias and one without
        requests_data = [
            {
                "text": "The capital of France is",
                "sampling_params": {
                    "temperature": 1.0,
                    "max_new_tokens": 4,
                    "logit_bias": {str(biased_token_id): 100.0},  # Strong bias
                },
                "return_logprob": True,
            },
            {
                "text": "The capital of France is",
                "sampling_params": {
                    "temperature": 1.0,
                    "max_new_tokens": 4,
                },
                "return_logprob": True,
            },
        ]

        # Send both requests
        responses = []
        for req in requests_data:
            response = requests.post(self.base_url + "/generate", json=req)
            responses.append(response.json())

        # Extract token IDs from each response
        biased_tokens = [
            x[1] for x in responses[0]["meta_info"]["output_token_logprobs"]
        ]
        unbiased_tokens = [
            x[1] for x in responses[1]["meta_info"]["output_token_logprobs"]
        ]

        # Verify first response contains only biased tokens
        self.assertTrue(
            all(x == biased_token_id for x in biased_tokens),
            f"Expected all tokens to be {biased_token_id} in first response, but got {biased_tokens}",
        )

        # Verify second response contains at least some different tokens
        # (We can't guarantee exactly what tokens will be generated, but they shouldn't all be the biased token)
        self.assertTrue(
            any(x != biased_token_id for x in unbiased_tokens),
            f"Expected some tokens to be different from {biased_token_id} in second response, but got {unbiased_tokens}",
        )

    def test_get_server_info_concurrent(self):
        """Make sure the concurrent get_server_info doesn't crash the server."""
        tp = ThreadPoolExecutor(max_workers=30)

        def s():
            server_info = requests.get(self.base_url + "/get_server_info")
            server_info.json()

        futures = []
        for _ in range(4):
            futures.append(tp.submit(s))

        for f in futures:
            f.result()


# -------------------------------------------------------------------------
#    /tokenize & /detokenize Test Class: TestTokenizeDetokenize
# -------------------------------------------------------------------------


class TestTokenizeDetokenize(CustomTestCase):
    @classmethod
    def setUpClass(cls):
        cls.model = DEFAULT_SMALL_MODEL_NAME_FOR_TEST
        cls.base_url = DEFAULT_URL_FOR_TEST
        cls.tokenize_url = f"{cls.base_url}/tokenize"
        cls.detokenize_url = f"{cls.base_url}/detokenize"
        cls.session = requests.Session()
        cls.process = popen_launch_server(
            cls.model,
            cls.base_url,
            timeout=DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH,
        )

    @classmethod
    def tearDownClass(cls):
        kill_process_tree(cls.process.pid)
        cls.session.close()

    def _post_json(self, url, payload):
        r = self.session.post(url, json=payload)
        r.raise_for_status()
        return r.json()

    def test_tokenize_various_inputs(self):
        single = "Hello SGLang world! 123 😊, ಪರ್ವತದ ಮೇಲೆ ಹಿಮ."
        multi = ["First sentence.", "Second, with 中文."]
        scenarios = [
            {"prompt": single, "add_special_tokens": True},
            {"prompt": single, "add_special_tokens": False},
            {"prompt": multi, "add_special_tokens": True},
            {"prompt": multi, "add_special_tokens": False},
            {"prompt": "", "add_special_tokens": False},
        ]
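        # /tokenize returns a flat token list for a single prompt and a list of
        # lists for a batch; "count" follows the same shape.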
        for case in scenarios:
            payload = {"model": self.model, "prompt": case["prompt"]}
            if "add_special_tokens" in case:
                payload["add_special_tokens"] = case["add_special_tokens"]
            resp = self._post_json(self.tokenize_url, payload)
            tokens = resp["tokens"]
            count = resp["count"]
            self.assertIsInstance(tokens, list)
            if not tokens:
                self.assertEqual(count, 0)
            else:
                if isinstance(tokens[0], list):
                    total = sum(len(t) for t in tokens)
                    expected = sum(count) if isinstance(count, list) else count
                else:
                    total = len(tokens)
                    expected = count
                self.assertEqual(total, expected)

    def test_tokenize_invalid_type(self):
        r = self.session.post(
            self.tokenize_url, json={"model": self.model, "prompt": 12345}
        )
        self.assertEqual(r.status_code, 400)

    def test_detokenize_roundtrip(self):
        text = "Verify detokenization round trip. यह डिटोकेनाइजेशन है"
        t0 = self._post_json(
            self.tokenize_url,
            {"model": self.model, "prompt": text, "add_special_tokens": False},
        )["tokens"]
        t1 = self._post_json(
            self.tokenize_url,
            {"model": self.model, "prompt": text, "add_special_tokens": True},
        )["tokens"]
        cases = [
            {"tokens": t0, "skip_special_tokens": True, "expected": text},
            {"tokens": t1, "skip_special_tokens": True, "expected": text},
            {"tokens": t1, "skip_special_tokens": False, "expected": None},
            {"tokens": [], "skip_special_tokens": True, "expected": ""},
        ]
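        # With skip_special_tokens=False the decoded text may include special
        # tokens, so that case only checks the return type instead of exact text.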
        for case in cases:
            payload = {"model": self.model, "tokens": case["tokens"]}
            if "skip_special_tokens" in case:
                payload["skip_special_tokens"] = case["skip_special_tokens"]
            resp = self._post_json(self.detokenize_url, payload)
            text_out = resp["text"]
            if case["expected"] is not None:
                self.assertEqual(text_out, case["expected"])
            else:
                self.assertIsInstance(text_out, str)

    def test_detokenize_invalid_tokens(self):
        r = self.session.post(
            self.detokenize_url, json={"model": self.model, "tokens": ["a", "b"]}
        )
        self.assertEqual(r.status_code, 400)
        r2 = self.session.post(
            self.detokenize_url, json={"model": self.model, "tokens": [1, -1, 2]}
        )
        self.assertEqual(r2.status_code, 500)


if __name__ == "__main__":
    unittest.main()