import json
import unittest

import openai

from sglang.srt.hf_transformers_utils import get_tokenizer
from sglang.srt.utils import kill_child_process
from sglang.test.test_utils import MODEL_NAME_FOR_TEST, popen_launch_server


class TestOpenAIServer(unittest.TestCase):
    """End-to-end tests for the OpenAI-compatible HTTP API.

    A single server process is launched for the whole class; each test drives
    it through the official ``openai`` client (completions, chat completions,
    streaming variants, and regex-constrained decoding).
    """

    @classmethod
    def setUpClass(cls):
        # Launch one shared server for the whole class; torn down in
        # tearDownClass. The API key must match what the client sends.
        cls.model = MODEL_NAME_FOR_TEST
        cls.base_url = "http://localhost:8157"
        cls.api_key = "sk-123456"
        cls.process = popen_launch_server(
            cls.model, cls.base_url, timeout=300, api_key=cls.api_key
        )
        # The OpenAI client expects the /v1 prefix.
        cls.base_url += "/v1"
        cls.tokenizer = get_tokenizer(MODEL_NAME_FOR_TEST)

    @classmethod
    def tearDownClass(cls):
        kill_child_process(cls.process.pid)

    def run_completion(
        self, echo, logprobs, use_list_input, parallel_sample_num, token_input
    ):
        """Exercise /v1/completions for one combination of request options.

        Args:
            echo: ask the server to echo the prompt back in the output.
            logprobs: number of top logprobs to request, or None.
            use_list_input: send the prompt as a 2-element batch.
            parallel_sample_num: value for the ``n`` parameter.
            token_input: send the prompt as token ids instead of a string.
        """
        client = openai.Client(api_key=self.api_key, base_url=self.base_url)
        prompt = "The capital of France is"

        if token_input:
            prompt_input = self.tokenizer.encode(prompt)
            num_prompt_tokens = len(prompt_input)
        else:
            prompt_input = prompt
            num_prompt_tokens = len(self.tokenizer.encode(prompt))

        if use_list_input:
            prompt_arg = [prompt_input, prompt_input]
            num_choices = len(prompt_arg)
            num_prompt_tokens *= 2
        else:
            prompt_arg = prompt_input
            num_choices = 1

        if parallel_sample_num:
            # FIXME: This is wrong. We should not count the prompt tokens
            # multiple times for parallel sampling.
            num_prompt_tokens *= parallel_sample_num

        response = client.completions.create(
            model=self.model,
            prompt=prompt_arg,
            temperature=0,
            max_tokens=32,
            echo=echo,
            logprobs=logprobs,
            n=parallel_sample_num,
        )

        assert len(response.choices) == num_choices * parallel_sample_num

        if echo:
            text = response.choices[0].text
            assert text.startswith(prompt)

        if logprobs:
            assert response.choices[0].logprobs
            assert isinstance(response.choices[0].logprobs.tokens[0], str)
            assert isinstance(response.choices[0].logprobs.top_logprobs[1], dict)
            ret_num_top_logprobs = len(response.choices[0].logprobs.top_logprobs[1])
            # FIXME: Sometimes, some top_logprobs are missing in the return
            # value. The reason is that some output id maps to the same output
            # token and duplicates in the map.
            # assert ret_num_top_logprobs == logprobs, f"{ret_num_top_logprobs} vs {logprobs}"
            assert ret_num_top_logprobs > 0

            if echo:
                # The echoed prompt's first token has no logprob.
                assert response.choices[0].logprobs.token_logprobs[0] is None
            else:
                assert response.choices[0].logprobs.token_logprobs[0] is not None

        assert response.id
        assert response.created
        assert (
            response.usage.prompt_tokens == num_prompt_tokens
        ), f"{response.usage.prompt_tokens} vs {num_prompt_tokens}"
        assert response.usage.completion_tokens > 0
        assert response.usage.total_tokens > 0

    def run_completion_stream(self, echo, logprobs, token_input):
        """Exercise /v1/completions in streaming mode and check every chunk."""
        client = openai.Client(api_key=self.api_key, base_url=self.base_url)
        prompt = "The capital of France is"

        if token_input:
            prompt_arg = self.tokenizer.encode(prompt)
        else:
            prompt_arg = prompt
        generator = client.completions.create(
            model=self.model,
            prompt=prompt_arg,
            temperature=0,
            max_tokens=32,
            echo=echo,
            logprobs=logprobs,
            stream=True,
        )

        first = True
        for response in generator:
            if logprobs:
                assert response.choices[0].logprobs
                assert isinstance(response.choices[0].logprobs.tokens[0], str)
                # With echo, the first chunk is the prompt and carries no
                # top_logprobs, so skip the top_logprobs checks for it.
                if not (first and echo):
                    assert isinstance(
                        response.choices[0].logprobs.top_logprobs[0], dict
                    )
                    ret_num_top_logprobs = len(
                        response.choices[0].logprobs.top_logprobs[0]
                    )
                    # FIXME: Sometimes, some top_logprobs are missing in the
                    # return value. The reason is that some output id maps to
                    # the same output token and duplicates in the map.
                    # assert ret_num_top_logprobs == logprobs, f"{ret_num_top_logprobs} vs {logprobs}"
                    assert ret_num_top_logprobs > 0

            if first:
                if echo:
                    assert response.choices[0].text.startswith(
                        prompt
                    ), f"{response.choices[0].text} and all args {echo} {logprobs} {token_input} {first}"
                first = False

            assert response.id
            assert response.created
            assert response.usage.prompt_tokens > 0
            assert response.usage.completion_tokens > 0
            assert response.usage.total_tokens > 0

    def run_chat_completion(self, logprobs, parallel_sample_num):
        """Exercise /v1/chat/completions and validate choices and usage."""
        client = openai.Client(api_key=self.api_key, base_url=self.base_url)
        response = client.chat.completions.create(
            model=self.model,
            messages=[
                {"role": "system", "content": "You are a helpful AI assistant"},
                {"role": "user", "content": "What is the capital of France?"},
            ],
            temperature=0,
            max_tokens=32,
            # The chat API takes a boolean `logprobs` plus an integer
            # `top_logprobs`, unlike the completions API.
            logprobs=logprobs is not None and logprobs > 0,
            top_logprobs=logprobs,
            n=parallel_sample_num,
        )
        if logprobs:
            assert isinstance(
                response.choices[0].logprobs.content[0].top_logprobs[0].token, str
            )

            ret_num_top_logprobs = len(
                response.choices[0].logprobs.content[0].top_logprobs
            )
            assert (
                ret_num_top_logprobs == logprobs
            ), f"{ret_num_top_logprobs} vs {logprobs}"

        assert len(response.choices) == parallel_sample_num
        assert response.choices[0].message.role == "assistant"
        assert isinstance(response.choices[0].message.content, str)
        assert response.id
        assert response.created
        assert response.usage.prompt_tokens > 0
        assert response.usage.completion_tokens > 0
        assert response.usage.total_tokens > 0

    def run_chat_completion_stream(self, logprobs):
        """Exercise /v1/chat/completions in streaming mode."""
        client = openai.Client(api_key=self.api_key, base_url=self.base_url)
        generator = client.chat.completions.create(
            model=self.model,
            messages=[
                {"role": "system", "content": "You are a helpful AI assistant"},
                {"role": "user", "content": "What is the capital of France?"},
            ],
            temperature=0,
            max_tokens=32,
            logprobs=logprobs is not None and logprobs > 0,
            top_logprobs=logprobs,
            stream=True,
        )

        is_first = True
        for response in generator:
            data = response.choices[0].delta
            if is_first:
                # BUG FIX: this was `data.role == "assistant"` — a bare
                # comparison whose result was discarded, so the role was
                # never actually checked. It was clearly meant as an assert.
                assert data.role == "assistant"
                is_first = False
                continue

            if logprobs:
                assert response.choices[0].logprobs
                assert isinstance(
                    response.choices[0].logprobs.content[0].top_logprobs[0].token, str
                )
                assert isinstance(
                    response.choices[0].logprobs.content[0].top_logprobs, list
                )
                ret_num_top_logprobs = len(
                    response.choices[0].logprobs.content[0].top_logprobs
                )
                assert (
                    ret_num_top_logprobs == logprobs
                ), f"{ret_num_top_logprobs} vs {logprobs}"

            assert isinstance(data.content, str)
            assert response.id
            assert response.created

    def test_completion(self):
        # Sweep the full cross product of request options.
        for echo in [False, True]:
            for logprobs in [None, 5]:
                for use_list_input in [True, False]:
                    for parallel_sample_num in [1, 2]:
                        for token_input in [False, True]:
                            self.run_completion(
                                echo,
                                logprobs,
                                use_list_input,
                                parallel_sample_num,
                                token_input,
                            )

    def test_completion_stream(self):
        # parallel sampling and list input are not supported in streaming mode
        for echo in [False, True]:
            for logprobs in [None, 5]:
                for token_input in [False, True]:
                    self.run_completion_stream(echo, logprobs, token_input)

    def test_chat_completion(self):
        for logprobs in [None, 5]:
            for parallel_sample_num in [1, 2]:
                self.run_chat_completion(logprobs, parallel_sample_num)

    def test_chat_completion_stream(self):
        for logprobs in [None, 5]:
            self.run_chat_completion_stream(logprobs)

    def test_regex(self):
        """Constrained decoding via the server-specific `regex` extra_body
        parameter must yield output parseable as the expected JSON shape."""
        client = openai.Client(api_key=self.api_key, base_url=self.base_url)

        regex = (
            r"""\{\n"""
            + r"""   "name": "[\w]+",\n"""
            + r"""   "population": [\d]+\n"""
            + r"""\}"""
        )

        response = client.chat.completions.create(
            model=self.model,
            messages=[
                {"role": "system", "content": "You are a helpful AI assistant"},
                {"role": "user", "content": "Introduce the capital of France."},
            ],
            temperature=0,
            max_tokens=128,
            extra_body={"regex": regex},
        )
        text = response.choices[0].message.content

        try:
            js_obj = json.loads(text)
        except (TypeError, json.decoder.JSONDecodeError):
            # Print the offending payload before re-raising for debuggability.
            print("JSONDecodeError", text)
            raise
        assert isinstance(js_obj["name"], str)
        assert isinstance(js_obj["population"], int)

# Script entry point: run the suite directly with `python test_openai_server.py`.
# (Removed leftover commented-out scratch code that manually instantiated the
# test class.)
if __name__ == "__main__":
    unittest.main(warnings="ignore")