import json
import time
import unittest

import openai

from sglang.srt.hf_transformers_utils import get_tokenizer
from sglang.srt.utils import kill_child_process
from sglang.test.test_utils import (
    DEFAULT_MODEL_NAME_FOR_TEST,
    DEFAULT_URL_FOR_UNIT_TEST,
    popen_launch_server,
)


class TestOpenAIServer(unittest.TestCase):
    @classmethod
    def setUpClass(cls):
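        # Launch one sglang server process for the whole class; every test
        # talks to it through the OpenAI-compatible /v1 endpoints.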
        cls.model = DEFAULT_MODEL_NAME_FOR_TEST
        cls.base_url = DEFAULT_URL_FOR_UNIT_TEST
        cls.api_key = "sk-123456"
        cls.process = popen_launch_server(
            cls.model, cls.base_url, timeout=300, api_key=cls.api_key
        )
        cls.base_url += "/v1"
        cls.tokenizer = get_tokenizer(DEFAULT_MODEL_NAME_FOR_TEST)

    @classmethod
    def tearDownClass(cls):
        kill_child_process(cls.process.pid)

    def run_completion(
        self, echo, logprobs, use_list_input, parallel_sample_num, token_input
    ):
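        # Exercise /v1/completions across echo, logprobs, string vs. token-id
        # prompts, single vs. list prompts, and n parallel samples.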
        client = openai.Client(api_key=self.api_key, base_url=self.base_url)
        prompt = "The capital of France is"
        if token_input:
            prompt_input = self.tokenizer.encode(prompt)
            num_prompt_tokens = len(prompt_input)
        else:
            prompt_input = prompt
            num_prompt_tokens = len(self.tokenizer.encode(prompt))

        if use_list_input:
            prompt_arg = [prompt_input, prompt_input]
            num_choices = len(prompt_arg)
            num_prompt_tokens *= 2
        else:
            prompt_arg = prompt_input
            num_choices = 1

        response = client.completions.create(
            model=self.model,
            prompt=prompt_arg,
            temperature=0,
            max_tokens=32,
            echo=echo,
            logprobs=logprobs,
            n=parallel_sample_num,
        )

        assert len(response.choices) == num_choices * parallel_sample_num

        if echo:
            text = response.choices[0].text
            assert text.startswith(prompt)

        if logprobs:
            assert response.choices[0].logprobs
            assert isinstance(response.choices[0].logprobs.tokens[0], str)
            assert isinstance(response.choices[0].logprobs.top_logprobs[1], dict)
            ret_num_top_logprobs = len(response.choices[0].logprobs.top_logprobs[1])

            # FIXME: Sometimes some top_logprobs entries are missing from the return
            # value, because multiple output ids can map to the same output token and
            # collapse into a single entry in the map.
            # assert ret_num_top_logprobs == logprobs, f"{ret_num_top_logprobs} vs {logprobs}"

            assert ret_num_top_logprobs > 0
            assert response.choices[0].logprobs.token_logprobs[0] is not None

        assert response.id
        assert response.created
        assert (
            response.usage.prompt_tokens == num_prompt_tokens
        ), f"{response.usage.prompt_tokens} vs {num_prompt_tokens}"
        assert response.usage.completion_tokens > 0
        assert response.usage.total_tokens > 0

    def run_completion_stream(
        self, echo, logprobs, use_list_input, parallel_sample_num, token_input
    ):
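        # Same matrix as run_completion, but with stream=True and a final
        # usage-only chunk requested via stream_options.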
        client = openai.Client(api_key=self.api_key, base_url=self.base_url)
        prompt = "The capital of France is"
        if token_input:
            prompt_input = self.tokenizer.encode(prompt)
            num_prompt_tokens = len(prompt_input)
        else:
            prompt_input = prompt
            num_prompt_tokens = len(self.tokenizer.encode(prompt))

        if use_list_input:
            prompt_arg = [prompt_input, prompt_input]
            num_choices = len(prompt_arg)
            num_prompt_tokens *= 2
        else:
            prompt_arg = prompt_input
            num_choices = 1

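        # With include_usage, the final streamed chunk carries only usage
        # stats; the loop below detects it via response.usage and skips it.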
        generator = client.completions.create(
            model=self.model,
            prompt=prompt_arg,
            temperature=0,
            max_tokens=32,
            echo=echo,
            logprobs=logprobs,
            stream=True,
            stream_options={"include_usage": True},
            n=parallel_sample_num,
        )

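        # is_firsts[index] records whether the next chunk for a choice index
        # is still its first; the tail check below uses it to make sure every
        # expected choice actually streamed something.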
        is_firsts = {}
        for response in generator:
            usage = response.usage
            if usage is not None:
                assert usage.prompt_tokens > 0
                assert usage.completion_tokens > 0
                assert usage.total_tokens > 0
                continue

            index = response.choices[0].index
            is_first = is_firsts.get(index, True)

            if logprobs:
                assert response.choices[0].logprobs
                assert isinstance(response.choices[0].logprobs.tokens[0], str)
                if not (is_first and echo):
                    assert isinstance(
                        response.choices[0].logprobs.top_logprobs[0], dict
                    )
                    ret_num_top_logprobs = len(
                        response.choices[0].logprobs.top_logprobs[0]
                    )
                    # FIXME: Sometimes some top_logprobs entries are missing from the
                    # return value, because multiple output ids can map to the same
                    # output token and collapse into a single entry in the map.
                    # assert ret_num_top_logprobs == logprobs, f"{ret_num_top_logprobs} vs {logprobs}"
                    assert ret_num_top_logprobs > 0

            if is_first:
                if echo:
                    assert response.choices[0].text.startswith(
                        prompt
                    ), f"{response.choices[0].text} and all args {echo} {logprobs} {token_input} {is_first}"
                is_firsts[index] = False
            assert response.id
            assert response.created

        for index in range(parallel_sample_num * num_choices):
            assert not is_firsts.get(
                index, True
            ), f"no chunk was received for choice index {index}"

    def run_chat_completion(self, logprobs, parallel_sample_num):
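        # Exercise /v1/chat/completions with optional top logprobs and n
        # parallel samples.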
        client = openai.Client(api_key=self.api_key, base_url=self.base_url)
        response = client.chat.completions.create(
            model=self.model,
            messages=[
                {"role": "system", "content": "You are a helpful AI assistant"},
                {
                    "role": "user",
                    "content": "What is the capital of France? Answer in a few words.",
                },
            ],
            temperature=0,
            logprobs=logprobs is not None and logprobs > 0,
            top_logprobs=logprobs,
            n=parallel_sample_num,
        )

        if logprobs:
            assert isinstance(
                response.choices[0].logprobs.content[0].top_logprobs[0].token, str
            )

            ret_num_top_logprobs = len(
                response.choices[0].logprobs.content[0].top_logprobs
            )
            assert (
                ret_num_top_logprobs == logprobs
            ), f"{ret_num_top_logprobs} vs {logprobs}"

        assert len(response.choices) == parallel_sample_num
        assert response.choices[0].message.role == "assistant"
        assert isinstance(response.choices[0].message.content, str)
        assert response.id
        assert response.created
        assert response.usage.prompt_tokens > 0
        assert response.usage.completion_tokens > 0
        assert response.usage.total_tokens > 0

    def run_chat_completion_stream(self, logprobs, parallel_sample_num=1):
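        # Streaming chat: the first chunk of each choice must carry the
        # assistant role; later chunks carry content (and logprobs if asked).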
        client = openai.Client(api_key=self.api_key, base_url=self.base_url)
        generator = client.chat.completions.create(
            model=self.model,
            messages=[
                {"role": "system", "content": "You are a helpful AI assistant"},
                {"role": "user", "content": "What is the capital of France?"},
            ],
            temperature=0,
            logprobs=logprobs is not None and logprobs > 0,
            top_logprobs=logprobs,
            stream=True,
            stream_options={"include_usage": True},
            n=parallel_sample_num,
        )

        is_firsts = {}
        for response in generator:
            usage = response.usage
            if usage is not None:
                assert usage.prompt_tokens > 0
                assert usage.completion_tokens > 0
                assert usage.total_tokens > 0
                continue

            index = response.choices[0].index
            data = response.choices[0].delta

            if is_firsts.get(index, True):
                assert data.role == "assistant"
                is_firsts[index] = False
                continue

            if logprobs:
                assert response.choices[0].logprobs
                assert isinstance(
                    response.choices[0].logprobs.content[0].top_logprobs[0].token, str
                )
                assert isinstance(
                    response.choices[0].logprobs.content[0].top_logprobs, list
                )
                ret_num_top_logprobs = len(
                    response.choices[0].logprobs.content[0].top_logprobs
                )
                assert (
                    ret_num_top_logprobs == logprobs
                ), f"{ret_num_top_logprobs} vs {logprobs}"

            assert isinstance(data.content, str)
            assert response.id
            assert response.created

        for index in range(parallel_sample_num):
            assert not is_firsts.get(
                index, True
            ), f"no chunk was received for choice index {index}"

    def run_batch(self, mode):
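        # Batch flow: write the requests to a JSONL file, upload it with
        # purpose="batch", create a batch job against the chosen endpoint,
        # poll until the job reaches a terminal state, then fetch and check
        # the output file.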
        client = openai.Client(api_key=self.api_key, base_url=self.base_url)
        if mode == "completion":
            input_file_path = "complete_input.jsonl"
            # write content to input file
            content = [
                {
                    "custom_id": "request-1",
                    "method": "POST",
                    "url": "/v1/completions",
                    "body": {
                        "model": "gpt-3.5-turbo-instruct",
                        "prompt": "List 3 names of famous soccer player: ",
                        "max_tokens": 20,
                    },
                },
                {
                    "custom_id": "request-2",
                    "method": "POST",
                    "url": "/v1/completions",
                    "body": {
                        "model": "gpt-3.5-turbo-instruct",
                        "prompt": "List 6 names of famous basketball player:  ",
                        "max_tokens": 40,
                    },
                },
                {
                    "custom_id": "request-3",
                    "method": "POST",
                    "url": "/v1/completions",
                    "body": {
                        "model": "gpt-3.5-turbo-instruct",
                        "prompt": "List 6 names of famous tenniss player:  ",
                        "max_tokens": 40,
                    },
                },
            ]

        else:
            input_file_path = "chat_input.jsonl"
            content = [
                {
                    "custom_id": "request-1",
                    "method": "POST",
                    "url": "/v1/chat/completions",
                    "body": {
                        "model": "gpt-3.5-turbo-0125",
                        "messages": [
                            {
                                "role": "system",
                                "content": "You are a helpful assistant.",
                            },
                            {
                                "role": "user",
                                "content": "Hello! List 3 NBA players and tell a story",
                            },
                        ],
                        "max_tokens": 30,
                    },
                },
                {
                    "custom_id": "request-2",
                    "method": "POST",
                    "url": "/v1/chat/completions",
                    "body": {
                        "model": "gpt-3.5-turbo-0125",
                        "messages": [
                            {"role": "system", "content": "You are an assistant. "},
                            {
                                "role": "user",
                                "content": "Hello! List three capital and tell a story",
                            },
                        ],
                        "max_tokens": 50,
                    },
                },
            ]
        with open(input_file_path, "w") as file:
            for line in content:
                file.write(json.dumps(line) + "\n")
        with open(input_file_path, "rb") as file:
            uploaded_file = client.files.create(file=file, purpose="batch")
        if mode == "completion":
            endpoint = "/v1/completions"
        elif mode == "chat":
            endpoint = "/v1/chat/completions"
        completion_window = "24h"
        batch_job = client.batches.create(
            input_file_id=uploaded_file.id,
            endpoint=endpoint,
            completion_window=completion_window,
        )
        while batch_job.status not in ["completed", "failed", "cancelled"]:
            print(
                f"Batch job status: {batch_job.status}...trying again in 3 seconds..."
            )
            time.sleep(3)
            batch_job = client.batches.retrieve(batch_job.id)
        assert (
            batch_job.status == "completed"
        ), f"Batch job status is not completed: {batch_job.status}"
        assert batch_job.request_counts.completed == len(content)
        assert batch_job.request_counts.failed == 0
        assert batch_job.request_counts.total == len(content)

        result_file_id = batch_job.output_file_id
        file_response = client.files.content(result_file_id)
        result_content = file_response.read().decode("utf-8")  # Decode bytes to string
        results = [
            json.loads(line)
            for line in result_content.split("\n")
            if line.strip() != ""
        ]
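        # Each output line should be a JSON object echoing the request's
        # custom_id next to the response body (OpenAI batch output format;
        # the exact fields are assumed here and not asserted).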
        assert len(results) == len(content)

    def test_completion(self):
        for echo in [False, True]:
            for logprobs in [None, 5]:
                for use_list_input in [True, False]:
                    for parallel_sample_num in [1, 2]:
                        for token_input in [False, True]:
                            self.run_completion(
                                echo,
                                logprobs,
                                use_list_input,
                                parallel_sample_num,
                                token_input,
                            )

    def test_completion_stream(self):
        # parallel sampling and list input are not supported in streaming mode
        for echo in [False, True]:
            for logprobs in [None, 5]:
                for use_list_input in [True, False]:
                    for parallel_sample_num in [1, 2]:
                        for token_input in [False, True]:
                            self.run_completion_stream(
                                echo,
                                logprobs,
                                use_list_input,
                                parallel_sample_num,
                                token_input,
                            )

    def test_chat_completion(self):
        for logprobs in [None, 5]:
            for parallel_sample_num in [1, 2]:
                self.run_chat_completion(logprobs, parallel_sample_num)

    def test_chat_completion_stream(self):
        for logprobs in [None, 5]:
            for parallel_sample_num in [1, 2]:
                self.run_chat_completion_stream(logprobs, parallel_sample_num)

    def test_batch(self):
        for mode in ["completion", "chat"]:
            self.run_batch(mode)

    def test_regex(self):
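        # Regex-constrained decoding: "regex" is an sglang sampling parameter
        # passed through the OpenAI client's extra_body escape hatch.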
        client = openai.Client(api_key=self.api_key, base_url=self.base_url)

        regex = (
            r"""\{\n"""
            + r"""   "name": "[\w]+",\n"""
            + r"""   "population": [\d]+\n"""
            + r"""\}"""
        )

        response = client.chat.completions.create(
            model=self.model,
            messages=[
                {"role": "system", "content": "You are a helpful AI assistant"},
                {"role": "user", "content": "Introduce the capital of France."},
            ],
            temperature=0,
            max_tokens=128,
            extra_body={"regex": regex},
        )
        text = response.choices[0].message.content

        try:
            js_obj = json.loads(text)
        except (TypeError, json.decoder.JSONDecodeError):
            print("JSONDecodeError", text)
            raise
        assert isinstance(js_obj["name"], str)
        assert isinstance(js_obj["population"], int)


if __name__ == "__main__":
    unittest.main()