"""
python3 -m unittest test_srt_endpoint.TestSRTEndpoint.test_simple_decode
python3 -m unittest test_srt_endpoint.TestSRTEndpoint.test_logprob_with_chunked_prefill
"""
import json
import unittest

import requests

from sglang.srt.utils import kill_child_process
from sglang.test.test_utils import (
    DEFAULT_SMALL_MODEL_NAME_FOR_TEST,
    DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH,
    DEFAULT_URL_FOR_TEST,
    popen_launch_server,
)


class TestSRTEndpoint(unittest.TestCase):
    """End-to-end tests for the SRT server's /generate and /get_server_info endpoints.

    A single server process is launched once for the whole class and shared by
    every test; each test issues plain HTTP requests against it.
    """

    @classmethod
    def setUpClass(cls):
        # Launch one shared server for all tests in this class.
        cls.model = DEFAULT_SMALL_MODEL_NAME_FOR_TEST
        cls.base_url = DEFAULT_URL_FOR_TEST
        cls.process = popen_launch_server(
            cls.model, cls.base_url, timeout=DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH
        )

    @classmethod
    def tearDownClass(cls):
        # include_self=True also terminates the launcher process itself.
        kill_child_process(cls.process.pid, include_self=True)

    def run_decode(
        self,
        return_logprob=False,
        top_logprobs_num=0,
        return_text=False,
        n=1,
        stream=False,
        batch=False,
    ):
        """Send one /generate request and pretty-print the response.

        Args:
            return_logprob: Ask the server to return token logprobs.
            top_logprobs_num: How many top logprobs to return per position.
            return_text: Include decoded token text alongside each logprob.
            n: Number of parallel samples. Temperature is 0 only when n == 1
                so that multiple samples can actually differ.
            stream: Read the response as server-sent events instead of one
                JSON body.
            batch: Send the prompt as a one-element list instead of a bare
                string, exercising the batched request path.
        """
        if batch:
            text = ["The capital of France is"]
        else:
            text = "The capital of France is"

        response = requests.post(
            self.base_url + "/generate",
            json={
                "text": text,
                "sampling_params": {
                    "temperature": 0 if n == 1 else 0.5,
                    "max_new_tokens": 16,
                    "n": n,
                },
                "stream": stream,
                "return_logprob": return_logprob,
                "top_logprobs_num": top_logprobs_num,
                "return_text_in_logprobs": return_text,
                "logprob_start_len": 0,
            },
        )

        if not stream:
            response_json = response.json()
        else:
            # SSE frames look like b"data: {...}"; the final frame carries
            # b"data: [DONE]" and is skipped.
            response_json = []
            for line in response.iter_lines():
                if line.startswith(b"data: ") and line[6:] != b"[DONE]":
                    response_json.append(json.loads(line[6:]))

        print(json.dumps(response_json, indent=2))
        print("=" * 100)

    def test_simple_decode(self):
        self.run_decode()

    def test_simple_decode_batch(self):
        self.run_decode(batch=True)

    def test_parallel_sample(self):
        self.run_decode(n=3)

    def test_parallel_sample_stream(self):
        self.run_decode(n=3, stream=True)

    def test_logprob(self):
        self.run_decode(
            return_logprob=True,
            top_logprobs_num=5,
            return_text=True,
        )

    def test_logprob_start_len(self):
        """Check that logprob_start_len trims the returned input-token logprobs."""
        logprob_start_len = 4
        new_tokens = 4
        # Prompt content is arbitrary here (the "sunndy" typo is harmless);
        # the assertions below only check internal consistency.
        prompts = [
            "I have a very good idea on",
            "Today is a sunndy day and",
        ]

        response = requests.post(
            self.base_url + "/generate",
            json={
                "text": prompts,
                "sampling_params": {
                    "temperature": 0,
                    "max_new_tokens": new_tokens,
                },
                "return_logprob": True,
                "top_logprobs_num": 5,
                "return_text_in_logprobs": True,
                "logprob_start_len": logprob_start_len,
            },
        )
        response_json = response.json()
        print(json.dumps(response_json, indent=2))

        for i, res in enumerate(response_json):
            # prompt_tokens = skipped prefix (logprob_start_len) + 1 (the first
            # token has no logprob) + the logprobs actually returned.
            self.assertEqual(
                res["meta_info"]["prompt_tokens"],
                logprob_start_len + 1 + len(res["meta_info"]["input_token_logprobs"]),
            )
            # The returned input-token texts must reassemble a suffix of the prompt.
            self.assertTrue(
                prompts[i].endswith(
                    "".join([x[-1] for x in res["meta_info"]["input_token_logprobs"]])
                )
            )

            self.assertEqual(res["meta_info"]["completion_tokens"], new_tokens)
            self.assertEqual(len(res["meta_info"]["output_token_logprobs"]), new_tokens)
            self.assertEqual(
                res["text"],
                "".join([x[-1] for x in res["meta_info"]["output_token_logprobs"]]),
            )

    def test_logprob_with_chunked_prefill(self):
        """A very long prompt (~8000 repeated sentences) exercises chunked prefill.

        NOTE(review): logprob_start_len=-1 presumably requests logprobs only for
        generated tokens — confirm against the server API.
        """
        new_tokens = 4
        prompts = "I have a very good idea on this. " * 8000

        response = requests.post(
            self.base_url + "/generate",
            json={
                "text": prompts,
                "sampling_params": {
                    "temperature": 0,
                    "max_new_tokens": new_tokens,
                },
                "return_logprob": True,
                "logprob_start_len": -1,
            },
        )
        response_json = response.json()
        print(json.dumps(response_json, indent=2))

        res = response_json
        self.assertEqual(res["meta_info"]["completion_tokens"], new_tokens)
        self.assertEqual(len(res["meta_info"]["output_token_logprobs"]), new_tokens)

    def test_get_server_info(self):
        """Smoke-test /get_server_info: required fields exist with expected types."""
        response = requests.get(self.base_url + "/get_server_info")
        response_json = response.json()

        max_total_num_tokens = response_json["max_total_num_tokens"]
        self.assertIsInstance(max_total_num_tokens, int)

        memory_pool_size = response_json["memory_pool_size"]
        self.assertIsInstance(memory_pool_size, int)

        attention_backend = response_json["attention_backend"]
        self.assertIsInstance(attention_backend, str)

# Allow running this file directly: `python3 test_srt_endpoint.py`.
if __name__ == "__main__":
    unittest.main()