"""
python3 -m unittest test_srt_endpoint.TestSRTEndpoint.test_simple_decode
python3 -m unittest test_srt_endpoint.TestSRTEndpoint.test_parallel_sample
"""

import json
import unittest

import requests

from sglang.srt.utils import kill_child_process
from sglang.test.test_utils import (
    DEFAULT_MODEL_NAME_FOR_TEST,
    DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH,
    DEFAULT_URL_FOR_TEST,
    popen_launch_server,
)


class TestSRTEndpoint(unittest.TestCase):
    @classmethod
    def setUpClass(cls):
        cls.model = DEFAULT_MODEL_NAME_FOR_TEST
        cls.base_url = DEFAULT_URL_FOR_TEST
        cls.process = popen_launch_server(
            cls.model, cls.base_url, timeout=DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH
        )

    @classmethod
    def tearDownClass(cls):
        kill_child_process(cls.process.pid, include_self=True)

    def run_decode(
        self,
        return_logprob=False,
        top_logprobs_num=0,
        return_text=False,
        n=1,
        stream=False,
        batch=False,
    ):
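        """Send one /generate request with the given options and print the (streamed or non-streamed) response."""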
        if batch:
            text = ["The capital of France is"]
        else:
            text = "The capital of France is"

        response = requests.post(
            self.base_url + "/generate",
            json={
                "text": text,
                "sampling_params": {
                    "temperature": 0 if n == 1 else 0.5,
                    "max_new_tokens": 16,
                    "n": n,
                },
                "stream": stream,
                "return_logprob": return_logprob,
                "top_logprobs_num": top_logprobs_num,
                "return_text_in_logprobs": return_text,
                "logprob_start_len": 0,
            },
        )
        if not stream:
            response_json = response.json()
        else:
            response_json = []
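            # Streamed output arrives as server-sent events ("data: {...}");
            # collect each JSON payload and skip the terminating "[DONE]" marker.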
            for line in response.iter_lines():
                if line.startswith(b"data: ") and line[6:] != b"[DONE]":
                    response_json.append(json.loads(line[6:]))

        print(json.dumps(response_json, indent=2))
        print("=" * 100)

    def test_simple_decode(self):
        self.run_decode()

    def test_simple_decode_batch(self):
        self.run_decode(batch=True)

    def test_parallel_sample(self):
        self.run_decode(n=3)

    def test_parallel_sample_stream(self):
        self.run_decode(n=3, stream=True)

    def test_logprob(self):
        self.run_decode(
            return_logprob=True,
            top_logprobs_num=5,
            return_text=True,
        )

    def test_logprob_start_len(self):
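        # Ask for logprobs starting at prompt position `logprob_start_len` and verify
        # that the returned token counts and token texts are consistent with the request.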
        logprob_start_len = 4
        new_tokens = 4
        prompts = [
            "I have a very good idea on",
            "Today is a sunndy day and",
        ]

        response = requests.post(
            self.base_url + "/generate",
            json={
                "text": prompts,
                "sampling_params": {
                    "temperature": 0,
                    "max_new_tokens": new_tokens,
                },
                "return_logprob": True,
                "top_logprobs_num": 5,
                "return_text_in_logprobs": True,
                "logprob_start_len": logprob_start_len,
            },
        )
        response_json = response.json()
        print(json.dumps(response_json, indent=2))

        for i, res in enumerate(response_json):
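            # Logprobs are only returned for prompt tokens after position
            # `logprob_start_len`, so the counts must add up to the prompt length.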
            assert res["meta_info"]["prompt_tokens"] == logprob_start_len + 1 + len(
                res["meta_info"]["input_token_logprobs"]
            )
            assert prompts[i].endswith(
                "".join([x[-1] for x in res["meta_info"]["input_token_logprobs"]])
            )

            assert res["meta_info"]["completion_tokens"] == new_tokens
            assert len(res["meta_info"]["output_token_logprobs"]) == new_tokens
            res["text"] == "".join(
                [x[-1] for x in res["meta_info"]["output_token_logprobs"]]
            )

    def test_get_memory_pool_size(self):
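        # The endpoint reports the server's current memory pool size as a bare integer.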
        response = requests.post(self.base_url + "/get_memory_pool_size")
        assert isinstance(response.json(), int)


if __name__ == "__main__":
    unittest.main()