test_openai_server.py 19.7 KB
Newer Older
1
2
3
4
5
"""
python3 -m unittest test_openai_server.TestOpenAIServer.test_batch
python3 -m unittest test_openai_server.TestOpenAIServer.test_completion

"""
Chayenne's avatar
Chayenne committed
6

7
import json
8
import time
9
import unittest
10
11

import openai
12

yichuan~'s avatar
yichuan~ committed
13
from sglang.srt.hf_transformers_utils import get_tokenizer
14
from sglang.srt.utils import kill_child_process
15
16
from sglang.test.test_utils import (
    DEFAULT_MODEL_NAME_FOR_TEST,
17
18
    DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH,
    DEFAULT_URL_FOR_TEST,
19
20
    popen_launch_server,
)
21
22
23
24
25


class TestOpenAIServer(unittest.TestCase):
    @classmethod
    def setUpClass(cls):
        # Launch one sglang server for the whole test class; every test talks
        # to it through the OpenAI-compatible HTTP API.
        cls.model = DEFAULT_MODEL_NAME_FOR_TEST
        cls.base_url = DEFAULT_URL_FOR_TEST
        # Dummy key: the server is launched with this key, so requests must
        # present it for auth to succeed.
        cls.api_key = "sk-123456"
        cls.process = popen_launch_server(
            cls.model,
            cls.base_url,
            timeout=DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH,
            api_key=cls.api_key,
        )
        # OpenAI-compatible endpoints are served under /v1.
        cls.base_url += "/v1"
        # Tokenizer is used to pre-tokenize prompts and count prompt tokens.
        cls.tokenizer = get_tokenizer(DEFAULT_MODEL_NAME_FOR_TEST)

    @classmethod
    def tearDownClass(cls):
        # Kill the server process tree started in setUpClass.
        # NOTE(review): include_self=True presumably also terminates the
        # launched process itself, not only its children — confirm against
        # sglang.srt.utils.kill_child_process.
        kill_child_process(cls.process.pid, include_self=True)

yichuan~'s avatar
yichuan~ committed
42
43
44
    def run_completion(
        self, echo, logprobs, use_list_input, parallel_sample_num, token_input
    ):
        """Issue one non-streaming /v1/completions request and check invariants.

        Args:
            echo: ask the server to echo the prompt in the returned text.
            logprobs: number of top logprobs to request (None disables).
            use_list_input: send the prompt twice as a list (batch of two).
            parallel_sample_num: the `n` parameter (samples per prompt).
            token_input: send pre-tokenized ids instead of the raw string.
        """
        client = openai.Client(api_key=self.api_key, base_url=self.base_url)
        prompt = "The capital of France is"
        # The prompt may be sent as text or as token ids; either way compute
        # the expected prompt-token count for the usage assertion below.
        if token_input:
            prompt_input = self.tokenizer.encode(prompt)
            num_prompt_tokens = len(prompt_input)
        else:
            prompt_input = prompt
            num_prompt_tokens = len(self.tokenizer.encode(prompt))

        if use_list_input:
            prompt_arg = [prompt_input, prompt_input]
            num_choices = len(prompt_arg)
            # Two identical prompts -> double the expected prompt tokens.
            num_prompt_tokens *= 2
        else:
            prompt_arg = prompt_input
            num_choices = 1

        response = client.completions.create(
            model=self.model,
            prompt=prompt_arg,
            temperature=0,
            max_tokens=32,
            echo=echo,
            logprobs=logprobs,
            n=parallel_sample_num,
        )

        # One choice per (prompt, sample) pair.
        assert len(response.choices) == num_choices * parallel_sample_num

        if echo:
            text = response.choices[0].text
            assert text.startswith(prompt)

        if logprobs:
            assert response.choices[0].logprobs
            assert isinstance(response.choices[0].logprobs.tokens[0], str)
            # NOTE(review): index 1 (not 0) — presumably position 0 can lack
            # top_logprobs when the prompt is echoed; confirm server behavior.
            assert isinstance(response.choices[0].logprobs.top_logprobs[1], dict)
            ret_num_top_logprobs = len(response.choices[0].logprobs.top_logprobs[1])

            # FIXME: Sometimes, some top_logprobs are missing in the return value. The reason is that some output id maps to the same output token and duplicate in the map
            # assert ret_num_top_logprobs == logprobs, f"{ret_num_top_logprobs} vs {logprobs}"
            assert ret_num_top_logprobs > 0

            assert response.choices[0].logprobs.token_logprobs[0]

        assert response.id
        assert response.created
        assert (
            response.usage.prompt_tokens == num_prompt_tokens
        ), f"{response.usage.prompt_tokens} vs {num_prompt_tokens}"
        assert response.usage.completion_tokens > 0
        assert response.usage.total_tokens > 0

98
99
100
    def run_completion_stream(
        self, echo, logprobs, use_list_input, parallel_sample_num, token_input
    ):
        """Issue one streaming /v1/completions request and check per-chunk invariants.

        Verifies per-chunk logprobs, the echoed prompt on the first chunk of
        each choice, and the trailing usage-only chunk produced by
        stream_options={"include_usage": True}.

        Args:
            echo: ask the server to echo the prompt in the returned text.
            logprobs: number of top logprobs to request (None disables).
            use_list_input: send the prompt twice as a list (batch of two).
            parallel_sample_num: the `n` parameter (samples per prompt).
            token_input: send pre-tokenized ids instead of the raw string.
        """
        client = openai.Client(api_key=self.api_key, base_url=self.base_url)
        prompt = "The capital of France is"
        # The prompt may be sent as text or as pre-tokenized ids.
        # (The non-streaming variant also tracks prompt-token counts; no usage
        # equality check is made here, so that dead computation was removed.)
        if token_input:
            prompt_input = self.tokenizer.encode(prompt)
        else:
            prompt_input = prompt

        if use_list_input:
            prompt_arg = [prompt_input, prompt_input]
            num_choices = len(prompt_arg)
        else:
            prompt_arg = prompt_input
            num_choices = 1

        generator = client.completions.create(
            model=self.model,
            prompt=prompt_arg,
            temperature=0,
            max_tokens=32,
            echo=echo,
            logprobs=logprobs,
            stream=True,
            stream_options={"include_usage": True},
            n=parallel_sample_num,
        )

        # Tracks, per choice index, whether we are still waiting for its first chunk.
        is_firsts = {}
        for response in generator:
            usage = response.usage
            if usage is not None:
                # Final chunk carries only usage statistics (include_usage=True).
                assert usage.prompt_tokens > 0
                assert usage.completion_tokens > 0
                assert usage.total_tokens > 0
                continue

            index = response.choices[0].index
            is_first = is_firsts.get(index, True)

            if logprobs:
                assert response.choices[0].logprobs
                assert isinstance(response.choices[0].logprobs.tokens[0], str)
                # The first echoed chunk may not carry top_logprobs; skip it.
                if not (is_first and echo):
                    assert isinstance(
                        response.choices[0].logprobs.top_logprobs[0], dict
                    )
                    ret_num_top_logprobs = len(
                        response.choices[0].logprobs.top_logprobs[0]
                    )
                    # FIXME: Sometimes, some top_logprobs are missing in the return value. The reason is that some output id maps to the same output token and duplicate in the map
                    # assert ret_num_top_logprobs == logprobs, f"{ret_num_top_logprobs} vs {logprobs}"
                    assert ret_num_top_logprobs > 0

            if is_first:
                if echo:
                    # With echo, each choice's first chunk starts with the prompt.
                    assert response.choices[0].text.startswith(
                        prompt
                    ), f"{response.choices[0].text} and all args {echo} {logprobs} {token_input} {is_first}"
                is_firsts[index] = False
            assert response.id
            assert response.created

        # Every expected choice must have streamed at least one chunk.
        # (Was: `for index in [i for i in range(...)]` — a throwaway list.)
        for index in range(parallel_sample_num * num_choices):
            assert not is_firsts.get(
                index, True
            ), f"index {index} is not found in the response"

yichuan~'s avatar
yichuan~ committed
170
    def run_chat_completion(self, logprobs, parallel_sample_num):
        """Issue one non-streaming /v1/chat/completions request and check invariants.

        Args:
            logprobs: number of top logprobs to request (None disables).
            parallel_sample_num: the `n` parameter (samples per prompt).
        """
        client = openai.Client(api_key=self.api_key, base_url=self.base_url)
        response = client.chat.completions.create(
            model=self.model,
            messages=[
                {"role": "system", "content": "You are a helpful AI assistant"},
                {
                    "role": "user",
                    "content": "What is the capital of France? Answer in a few words.",
                },
            ],
            temperature=0,
            # Chat API: `logprobs` is a bool, count goes in `top_logprobs`.
            logprobs=logprobs is not None and logprobs > 0,
            top_logprobs=logprobs,
            n=parallel_sample_num,
        )

        if logprobs:
            assert isinstance(
                response.choices[0].logprobs.content[0].top_logprobs[0].token, str
            )

            ret_num_top_logprobs = len(
                response.choices[0].logprobs.content[0].top_logprobs
            )
            # Unlike the completions endpoint, the exact count is asserted here.
            assert (
                ret_num_top_logprobs == logprobs
            ), f"{ret_num_top_logprobs} vs {logprobs}"

        assert len(response.choices) == parallel_sample_num
        assert response.choices[0].message.role == "assistant"
        assert isinstance(response.choices[0].message.content, str)
        assert response.id
        assert response.created
        assert response.usage.prompt_tokens > 0
        assert response.usage.completion_tokens > 0
        assert response.usage.total_tokens > 0

208
    def run_chat_completion_stream(self, logprobs, parallel_sample_num=1):
        """Issue one streaming /v1/chat/completions request and check per-chunk invariants.

        Verifies the role-only first delta of each choice, per-chunk logprobs,
        and the trailing usage-only chunk from stream_options={"include_usage": True}.

        Args:
            logprobs: number of top logprobs to request (None disables).
            parallel_sample_num: the `n` parameter (samples per prompt).
        """
        client = openai.Client(api_key=self.api_key, base_url=self.base_url)
        generator = client.chat.completions.create(
            model=self.model,
            messages=[
                {"role": "system", "content": "You are a helpful AI assistant"},
                {"role": "user", "content": "What is the capital of France?"},
            ],
            temperature=0,
            # Chat API: `logprobs` is a bool, count goes in `top_logprobs`.
            logprobs=logprobs is not None and logprobs > 0,
            top_logprobs=logprobs,
            stream=True,
            stream_options={"include_usage": True},
            n=parallel_sample_num,
        )

        # Tracks, per choice index, whether we are still waiting for its first chunk.
        is_firsts = {}
        for response in generator:
            usage = response.usage
            if usage is not None:
                # Final chunk carries only usage statistics (include_usage=True).
                assert usage.prompt_tokens > 0
                assert usage.completion_tokens > 0
                assert usage.total_tokens > 0
                continue

            index = response.choices[0].index
            data = response.choices[0].delta

            if is_firsts.get(index, True):
                # The first delta of each choice only announces the role.
                assert data.role == "assistant"
                is_firsts[index] = False
                continue

            if logprobs:
                assert response.choices[0].logprobs
                assert isinstance(
                    response.choices[0].logprobs.content[0].top_logprobs[0].token, str
                )
                assert isinstance(
                    response.choices[0].logprobs.content[0].top_logprobs, list
                )
                ret_num_top_logprobs = len(
                    response.choices[0].logprobs.content[0].top_logprobs
                )
                assert (
                    ret_num_top_logprobs == logprobs
                ), f"{ret_num_top_logprobs} vs {logprobs}"

            assert isinstance(data.content, str)
            assert response.id
            assert response.created

        # Every expected choice must have streamed at least one chunk.
        # (Was: `for index in [i for i in range(...)]` — a throwaway list.)
        for index in range(parallel_sample_num):
            assert not is_firsts.get(
                index, True
            ), f"index {index} is not found in the response"

265
    def _create_batch(self, mode, client):
        """Write a JSONL batch input file, upload it, and create a batch job.

        Args:
            mode: "completion" for /v1/completions requests, "chat" for
                /v1/chat/completions requests.
            client: an `openai.Client` pointed at the test server.

        Returns:
            A tuple ``(batch_job, content, uploaded_file)`` where `content` is
            the list of request dicts written to the input file.

        Raises:
            ValueError: if `mode` is unknown. (Previously an unknown mode fell
            through both `if mode == ...` branches and crashed later with a
            NameError on `endpoint`.)
        """
        # Validate mode up front and fix the endpoint in one place.
        if mode == "completion":
            endpoint = "/v1/completions"
        elif mode == "chat":
            endpoint = "/v1/chat/completions"
        else:
            raise ValueError(f"Unsupported batch mode: {mode}")

        if mode == "completion":
            input_file_path = "complete_input.jsonl"
            # Three plain completion requests in OpenAI batch-input format.
            content = [
                {
                    "custom_id": "request-1",
                    "method": "POST",
                    "url": "/v1/completions",
                    "body": {
                        "model": "gpt-3.5-turbo-instruct",
                        "prompt": "List 3 names of famous soccer player: ",
                        "max_tokens": 20,
                    },
                },
                {
                    "custom_id": "request-2",
                    "method": "POST",
                    "url": "/v1/completions",
                    "body": {
                        "model": "gpt-3.5-turbo-instruct",
                        "prompt": "List 6 names of famous basketball player:  ",
                        "max_tokens": 40,
                    },
                },
                {
                    "custom_id": "request-3",
                    "method": "POST",
                    "url": "/v1/completions",
                    "body": {
                        "model": "gpt-3.5-turbo-instruct",
                        "prompt": "List 6 names of famous tenniss player:  ",
                        "max_tokens": 40,
                    },
                },
            ]

        else:
            input_file_path = "chat_input.jsonl"
            # Two chat-completion requests in OpenAI batch-input format.
            content = [
                {
                    "custom_id": "request-1",
                    "method": "POST",
                    "url": "/v1/chat/completions",
                    "body": {
                        "model": "gpt-3.5-turbo-0125",
                        "messages": [
                            {
                                "role": "system",
                                "content": "You are a helpful assistant.",
                            },
                            {
                                "role": "user",
                                "content": "Hello! List 3 NBA players and tell a story",
                            },
                        ],
                        "max_tokens": 30,
                    },
                },
                {
                    "custom_id": "request-2",
                    "method": "POST",
                    "url": "/v1/chat/completions",
                    "body": {
                        "model": "gpt-3.5-turbo-0125",
                        "messages": [
                            {"role": "system", "content": "You are an assistant. "},
                            {
                                "role": "user",
                                "content": "Hello! List three capital and tell a story",
                            },
                        ],
                        "max_tokens": 50,
                    },
                },
            ]

        # One JSON object per line, as required by the batch files API.
        with open(input_file_path, "w") as file:
            for line in content:
                file.write(json.dumps(line) + "\n")

        with open(input_file_path, "rb") as file:
            uploaded_file = client.files.create(file=file, purpose="batch")
        completion_window = "24h"
        batch_job = client.batches.create(
            input_file_id=uploaded_file.id,
            endpoint=endpoint,
            completion_window=completion_window,
        )

        return batch_job, content, uploaded_file
360
361
362

    def run_batch(self, mode):
        """Submit a batch job, poll it to completion, and verify its results.

        Args:
            mode: "completion" or "chat", forwarded to _create_batch.
        """
        client = openai.Client(api_key=self.api_key, base_url=self.base_url)
        batch_job, content, uploaded_file = self._create_batch(mode=mode, client=client)

        # Poll every 3 seconds until the job reaches a terminal state.
        terminal_states = ["completed", "failed", "cancelled"]
        while batch_job.status not in terminal_states:
            time.sleep(3)
            print(
                f"Batch job status: {batch_job.status}...trying again in 3 seconds..."
            )
            batch_job = client.batches.retrieve(batch_job.id)

        assert (
            batch_job.status == "completed"
        ), f"Batch job status is not completed: {batch_job.status}"

        expected_count = len(content)
        request_counts = batch_job.request_counts
        assert request_counts.completed == expected_count
        assert request_counts.failed == 0
        assert request_counts.total == expected_count

        # Fetch the JSONL output file and parse one result per non-empty line.
        result_file_id = batch_job.output_file_id
        raw_output = client.files.content(result_file_id).read().decode("utf-8")
        results = []
        for line in raw_output.split("\n"):
            if line.strip() != "":
                results.append(json.loads(line))
        assert len(results) == expected_count

        # Clean up both the uploaded input file and the generated output file.
        for file_id in (uploaded_file.id, result_file_id):
            delete_response = client.files.delete(file_id)
            assert delete_response.deleted
390

391
392
    def run_cancel_batch(self, mode):
        """Submit a batch job, cancel it, and verify it reaches 'cancelled'.

        Args:
            mode: "completion" or "chat", forwarded to _create_batch.
        """
        client = openai.Client(api_key=self.api_key, base_url=self.base_url)
        batch_job, _, uploaded_file = self._create_batch(mode=mode, client=client)

        # The job must not already be terminating before we cancel it.
        assert batch_job.status not in ["cancelling", "cancelled"]

        batch_job = client.batches.cancel(batch_id=batch_job.id)
        assert batch_job.status == "cancelling"

        # Poll every 3 seconds until the cancellation takes effect.
        while batch_job.status not in ["failed", "cancelled"]:
            batch_job = client.batches.retrieve(batch_job.id)
            print(
                f"Batch job status: {batch_job.status}...trying again in 3 seconds..."
            )
            time.sleep(3)

        assert batch_job.status == "cancelled"
        # Clean up the uploaded input file.
        del_response = client.files.delete(uploaded_file.id)
        assert del_response.deleted
410

411
412
413
    def test_completion(self):
        """Sweep the full combination space of non-streaming completion options."""
        for echo in [False, True]:
            for logprobs in [None, 5]:
                for use_list_input in [True, False]:
                    for parallel_sample_num in [1, 2]:
                        for token_input in [False, True]:
                            self.run_completion(
                                echo,
                                logprobs,
                                use_list_input,
                                parallel_sample_num,
                                token_input,
                            )
424
425

    def test_completion_stream(self):
        """Sweep the full combination space of streaming completion options."""
        # NOTE(review): an earlier comment claimed parallel sampling and list
        # input are not supported in streaming mode, yet the loops below do
        # exercise both — confirm which statement is current and drop the other.
        for echo in [False, True]:
            for logprobs in [None, 5]:
                for use_list_input in [True, False]:
                    for parallel_sample_num in [1, 2]:
                        for token_input in [False, True]:
                            self.run_completion_stream(
                                echo,
                                logprobs,
                                use_list_input,
                                parallel_sample_num,
                                token_input,
                            )
439

440
441
    def test_chat_completion(self):
        """Sweep logprobs and `n` for non-streaming chat completions."""
        for logprobs in [None, 5]:
            for parallel_sample_num in [1, 2]:
                self.run_chat_completion(logprobs, parallel_sample_num)
444
445
446

    def test_chat_completion_stream(self):
        """Sweep logprobs and `n` for streaming chat completions."""
        for logprobs in [None, 5]:
            for parallel_sample_num in [1, 2]:
                self.run_chat_completion_stream(logprobs, parallel_sample_num)
449

450
451
452
453
    def test_batch(self):
        """Run the batch workflow for both completion and chat endpoints."""
        for mode in ["completion", "chat"]:
            self.run_batch(mode)

454
    def test_cancel_batch(self):
        """Run the batch-cancellation workflow for both endpoints."""
        for mode in ["completion", "chat"]:
            self.run_cancel_batch(mode)

458
    def test_regex(self):
        """Constrained decoding via the sglang `regex` extra_body parameter.

        The regex forces the output to be a small JSON object with a string
        "name" and an integer "population"; the test then parses and type-checks it.
        """
        client = openai.Client(api_key=self.api_key, base_url=self.base_url)

        # Pattern for: {"name": "<word>", "population": <digits>} with fixed
        # newlines/indentation.
        regex = (
            r"""\{\n"""
            + r"""   "name": "[\w]+",\n"""
            + r"""   "population": [\d]+\n"""
            + r"""\}"""
        )

        response = client.chat.completions.create(
            model=self.model,
            messages=[
                {"role": "system", "content": "You are a helpful AI assistant"},
                {"role": "user", "content": "Introduce the capital of France."},
            ],
            temperature=0,
            max_tokens=128,
            # sglang-specific extension carried through the OpenAI client.
            extra_body={"regex": regex},
        )
        text = response.choices[0].message.content

        # The constrained output must parse as JSON with the expected field types.
        try:
            js_obj = json.loads(text)
        except (TypeError, json.decoder.JSONDecodeError):
            print("JSONDecodeError", text)
            raise
        assert isinstance(js_obj["name"], str)
        assert isinstance(js_obj["population"], int)

488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
    def test_penalty(self):
        """A request with frequency_penalty set should still return valid text.

        Smoke test only: it checks the server accepts the parameter and
        responds, not the penalty's effect on token frequencies.
        """
        client = openai.Client(api_key=self.api_key, base_url=self.base_url)

        response = client.chat.completions.create(
            model=self.model,
            messages=[
                {"role": "system", "content": "You are a helpful AI assistant"},
                {"role": "user", "content": "Introduce the capital of France."},
            ],
            temperature=0,
            max_tokens=32,
            frequency_penalty=1.0,
        )
        text = response.choices[0].message.content
        assert isinstance(text, str)

504
505
506
507
    def test_response_prefill(self):
        client = openai.Client(api_key=self.api_key, base_url=self.base_url)

        response = client.chat.completions.create(
508
            model="meta-llama/Llama-3.1-8B-Instruct",
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
            messages=[
                {"role": "system", "content": "You are a helpful AI assistant"},
                {
                    "role": "user",
                    "content": """
Extract the name, size, price, and color from this product description as a JSON object:

<description>
The SmartHome Mini is a compact smart home assistant available in black or white for only $49.99. At just 5 inches wide, it lets you control lights, thermostats, and other connected devices via voice or app—no matter where you place it in your home. This affordable little hub brings convenient hands-free control to your smart devices.
</description>
""",
                },
                {
                    "role": "assistant",
                    "content": "{\n",
                },
            ],
            temperature=0,
        )

        assert (
            response.choices[0]
            .message.content.strip()
            .startswith('"name": "SmartHome Mini",')
        )

535

536
# Allow running this file directly, e.g. `python3 test_openai_server.py`.
if __name__ == "__main__":
    unittest.main()