"""
python3 -m unittest test_openai_server.TestOpenAIServer.test_batch
python3 -m unittest test_openai_server.TestOpenAIServer.test_completion
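python3 -m unittest test_openai_server  # run the whole suite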

"""

import json
import time
import unittest

import openai

from sglang.srt.hf_transformers_utils import get_tokenizer
from sglang.srt.utils import kill_child_process
from sglang.test.test_utils import (
    DEFAULT_SMALL_MODEL_NAME_FOR_TEST,
    DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH,
    DEFAULT_URL_FOR_TEST,
    popen_launch_server,
)


class TestOpenAIServer(unittest.TestCase):
    @classmethod
    def setUpClass(cls):
        cls.model = DEFAULT_SMALL_MODEL_NAME_FOR_TEST
        cls.base_url = DEFAULT_URL_FOR_TEST
        cls.api_key = "sk-123456"
        cls.process = popen_launch_server(
            cls.model,
            cls.base_url,
            timeout=DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH,
            api_key=cls.api_key,
        )
        cls.base_url += "/v1"
        cls.tokenizer = get_tokenizer(DEFAULT_SMALL_MODEL_NAME_FOR_TEST)

    @classmethod
    def tearDownClass(cls):
        kill_child_process(cls.process.pid, include_self=True)

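    # The completion helpers below exercise /v1/completions across the full
    # parameter matrix: echo, logprobs, string vs. list prompts, token-id
    # inputs, and parallel sampling (n > 1).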
    def run_completion(
        self, echo, logprobs, use_list_input, parallel_sample_num, token_input
    ):
        client = openai.Client(api_key=self.api_key, base_url=self.base_url)
        prompt = "The capital of France is"
        if token_input:
            prompt_input = self.tokenizer.encode(prompt)
            num_prompt_tokens = len(prompt_input)
        else:
            prompt_input = prompt
            num_prompt_tokens = len(self.tokenizer.encode(prompt))

        if use_list_input:
            prompt_arg = [prompt_input, prompt_input]
            num_choices = len(prompt_arg)
            num_prompt_tokens *= 2
        else:
            prompt_arg = prompt_input
            num_choices = 1

        response = client.completions.create(
            model=self.model,
            prompt=prompt_arg,
            temperature=0,
            max_tokens=32,
            echo=echo,
            logprobs=logprobs,
            n=parallel_sample_num,
        )

        assert len(response.choices) == num_choices * parallel_sample_num

        if echo:
            text = response.choices[0].text
            assert text.startswith(prompt)

        if logprobs:
            assert response.choices[0].logprobs
            assert isinstance(response.choices[0].logprobs.tokens[0], str)
            assert isinstance(response.choices[0].logprobs.top_logprobs[1], dict)
            ret_num_top_logprobs = len(response.choices[0].logprobs.top_logprobs[1])

            # FIXME: Sometimes some top_logprobs entries are missing from the return
            # value, because multiple output ids can map to the same output token and
            # collide as duplicate keys in the map.
            # assert ret_num_top_logprobs == logprobs, f"{ret_num_top_logprobs} vs {logprobs}"
            assert ret_num_top_logprobs > 0

            # When echo=True and request.logprobs > 0, logprob_start_len is 0, so the
            # first token's logprob is None.
            if not echo:
                assert response.choices[0].logprobs.token_logprobs[0]

        assert response.id
        assert response.created
        assert (
            response.usage.prompt_tokens == num_prompt_tokens
        ), f"{response.usage.prompt_tokens} vs {num_prompt_tokens}"
        assert response.usage.completion_tokens > 0
        assert response.usage.total_tokens > 0

    def run_completion_stream(
        self, echo, logprobs, use_list_input, parallel_sample_num, token_input
    ):
        client = openai.Client(api_key=self.api_key, base_url=self.base_url)
        prompt = "The capital of France is"
        if token_input:
            prompt_input = self.tokenizer.encode(prompt)
            num_prompt_tokens = len(prompt_input)
        else:
            prompt_input = prompt
            num_prompt_tokens = len(self.tokenizer.encode(prompt))

        if use_list_input:
            prompt_arg = [prompt_input, prompt_input]
            num_choices = len(prompt_arg)
            num_prompt_tokens *= 2
        else:
            prompt_arg = prompt_input
            num_choices = 1

        generator = client.completions.create(
            model=self.model,
            prompt=prompt_arg,
            temperature=0,
            max_tokens=32,
            echo=echo,
            logprobs=logprobs,
            stream=True,
            stream_options={"include_usage": True},
            n=parallel_sample_num,
        )

        is_firsts = {}
        for response in generator:
            # With stream_options={"include_usage": True}, the final chunk
            # carries usage stats and an empty choices list, so handle it
            # before indexing into choices.
            usage = response.usage
            if usage is not None:
                assert usage.prompt_tokens > 0
                assert usage.completion_tokens > 0
                assert usage.total_tokens > 0
                continue

            index = response.choices[0].index
            is_first = is_firsts.get(index, True)

            if logprobs:
                assert response.choices[0].logprobs
                assert isinstance(response.choices[0].logprobs.tokens[0], str)
                if not (is_first and echo):
                    assert isinstance(
                        response.choices[0].logprobs.top_logprobs[0], dict
                    )
                    ret_num_top_logprobs = len(
                        response.choices[0].logprobs.top_logprobs[0]
                    )
                    # FIXME: Sometimes some top_logprobs entries are missing from the
                    # return value, because multiple output ids can map to the same
                    # output token and collide as duplicate keys in the map.
                    # assert ret_num_top_logprobs == logprobs, f"{ret_num_top_logprobs} vs {logprobs}"
                    assert ret_num_top_logprobs > 0

            if is_first:
                if echo:
                    assert response.choices[0].text.startswith(
                        prompt
                    ), f"{response.choices[0].text} and all args {echo} {logprobs} {token_input} {is_first}"
                is_firsts[index] = False
            assert response.id
            assert response.created

        for index in range(parallel_sample_num * num_choices):
            assert not is_firsts.get(
                index, True
            ), f"index {index} was not found in the response"

    def run_chat_completion(self, logprobs, parallel_sample_num):
        client = openai.Client(api_key=self.api_key, base_url=self.base_url)
        response = client.chat.completions.create(
            model=self.model,
            messages=[
                {"role": "system", "content": "You are a helpful AI assistant"},
                {
                    "role": "user",
                    "content": "What is the capital of France? Answer in a few words.",
                },
            ],
            temperature=0,
            logprobs=logprobs is not None and logprobs > 0,
            top_logprobs=logprobs,
            n=parallel_sample_num,
        )

        if logprobs:
            assert isinstance(
                response.choices[0].logprobs.content[0].top_logprobs[0].token, str
            )

            ret_num_top_logprobs = len(
                response.choices[0].logprobs.content[0].top_logprobs
            )
            assert (
                ret_num_top_logprobs == logprobs
            ), f"{ret_num_top_logprobs} vs {logprobs}"

        assert len(response.choices) == parallel_sample_num
        assert response.choices[0].message.role == "assistant"
        assert isinstance(response.choices[0].message.content, str)
        assert response.id
        assert response.created
        assert response.usage.prompt_tokens > 0
        assert response.usage.completion_tokens > 0
        assert response.usage.total_tokens > 0

    def run_chat_completion_stream(self, logprobs, parallel_sample_num=1):
        client = openai.Client(api_key=self.api_key, base_url=self.base_url)
        generator = client.chat.completions.create(
            model=self.model,
            messages=[
                {"role": "system", "content": "You are a helpful AI assistant"},
                {"role": "user", "content": "What is the capital of France?"},
            ],
            temperature=0,
            logprobs=logprobs is not None and logprobs > 0,
            top_logprobs=logprobs,
            stream=True,
            stream_options={"include_usage": True},
            n=parallel_sample_num,
        )

        is_firsts = {}
        for response in generator:
            usage = response.usage
            if usage is not None:
                assert usage.prompt_tokens > 0
                assert usage.completion_tokens > 0
                assert usage.total_tokens > 0
                continue

            index = response.choices[0].index
            data = response.choices[0].delta

            # The first delta of each choice only carries the role.
            if is_firsts.get(index, True):
                assert data.role == "assistant"
                is_firsts[index] = False
                continue

            if logprobs:
                assert response.choices[0].logprobs
                assert isinstance(
                    response.choices[0].logprobs.content[0].top_logprobs[0].token, str
                )
                assert isinstance(
                    response.choices[0].logprobs.content[0].top_logprobs, list
                )
                ret_num_top_logprobs = len(
                    response.choices[0].logprobs.content[0].top_logprobs
                )
                assert (
                    ret_num_top_logprobs == logprobs
                ), f"{ret_num_top_logprobs} vs {logprobs}"

            assert isinstance(data.content, str)
            assert response.id
            assert response.created

        for index in range(parallel_sample_num):
            assert not is_firsts.get(
                index, True
            ), f"index {index} was not found in the response"

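    # Batch jobs follow the OpenAI Files + Batches flow: write the requests to
    # a JSONL file, upload it with purpose="batch", then create a batch job
    # that references the uploaded file id.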
    def _create_batch(self, mode, client):
        if mode == "completion":
            input_file_path = "complete_input.jsonl"
            # write content to input file
            content = [
                {
                    "custom_id": "request-1",
                    "method": "POST",
                    "url": "/v1/completions",
                    "body": {
                        "model": "gpt-3.5-turbo-instruct",
                        "prompt": "List 3 names of famous soccer player: ",
                        "max_tokens": 20,
                    },
                },
                {
                    "custom_id": "request-2",
                    "method": "POST",
                    "url": "/v1/completions",
                    "body": {
                        "model": "gpt-3.5-turbo-instruct",
                        "prompt": "List 6 names of famous basketball player:  ",
                        "max_tokens": 40,
                    },
                },
                {
                    "custom_id": "request-3",
                    "method": "POST",
                    "url": "/v1/completions",
                    "body": {
                        "model": "gpt-3.5-turbo-instruct",
                        "prompt": "List 6 names of famous tenniss player:  ",
                        "max_tokens": 40,
                    },
                },
            ]

        else:
            input_file_path = "chat_input.jsonl"
            content = [
                {
                    "custom_id": "request-1",
                    "method": "POST",
                    "url": "/v1/chat/completions",
                    "body": {
                        "model": "gpt-3.5-turbo-0125",
                        "messages": [
                            {
                                "role": "system",
                                "content": "You are a helpful assistant.",
                            },
                            {
                                "role": "user",
                                "content": "Hello! List 3 NBA players and tell a story",
                            },
                        ],
                        "max_tokens": 30,
                    },
                },
                {
                    "custom_id": "request-2",
                    "method": "POST",
                    "url": "/v1/chat/completions",
                    "body": {
                        "model": "gpt-3.5-turbo-0125",
                        "messages": [
                            {"role": "system", "content": "You are an assistant. "},
                            {
                                "role": "user",
                                "content": "Hello! List three capital and tell a story",
                            },
                        ],
                        "max_tokens": 50,
                    },
                },
            ]

        with open(input_file_path, "w") as file:
            for line in content:
                file.write(json.dumps(line) + "\n")

        with open(input_file_path, "rb") as file:
            uploaded_file = client.files.create(file=file, purpose="batch")
        if mode == "completion":
            endpoint = "/v1/completions"
        elif mode == "chat":
            endpoint = "/v1/chat/completions"
        completion_window = "24h"
        batch_job = client.batches.create(
            input_file_id=uploaded_file.id,
            endpoint=endpoint,
            completion_window=completion_window,
        )

        return batch_job, content, uploaded_file

    def run_batch(self, mode):
        client = openai.Client(api_key=self.api_key, base_url=self.base_url)
        batch_job, content, uploaded_file = self._create_batch(mode=mode, client=client)

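        # Poll until the batch job reaches a terminal state.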
        while batch_job.status not in ["completed", "failed", "cancelled"]:
            time.sleep(3)
            print(
                f"Batch job status: {batch_job.status}...trying again in 3 seconds..."
            )
            batch_job = client.batches.retrieve(batch_job.id)
        assert (
            batch_job.status == "completed"
        ), f"Batch job status is not completed: {batch_job.status}"
        assert batch_job.request_counts.completed == len(content)
        assert batch_job.request_counts.failed == 0
        assert batch_job.request_counts.total == len(content)

        result_file_id = batch_job.output_file_id
        file_response = client.files.content(result_file_id)
        result_content = file_response.read().decode("utf-8")  # Decode bytes to string
        results = [
            json.loads(line)
            for line in result_content.split("\n")
            if line.strip() != ""
        ]
        assert len(results) == len(content)
        for delete_fid in [uploaded_file.id, result_file_id]:
            del_response = client.files.delete(delete_fid)
            assert del_response.deleted

    def run_cancel_batch(self, mode):
        client = openai.Client(api_key=self.api_key, base_url=self.base_url)
        batch_job, _, uploaded_file = self._create_batch(mode=mode, client=client)

        assert batch_job.status not in ["cancelling", "cancelled"]

        batch_job = client.batches.cancel(batch_id=batch_job.id)
        assert batch_job.status == "cancelling"

        while batch_job.status not in ["failed", "cancelled"]:
            batch_job = client.batches.retrieve(batch_job.id)
            print(
                f"Batch job status: {batch_job.status}...trying again in 3 seconds..."
            )
            time.sleep(3)

        assert batch_job.status == "cancelled"
        del_response = client.files.delete(uploaded_file.id)
        assert del_response.deleted

    def test_completion(self):
        for echo in [False, True]:
            for logprobs in [None, 5]:
                for use_list_input in [True, False]:
                    for parallel_sample_num in [1, 2]:
                        for token_input in [False, True]:
                            self.run_completion(
                                echo,
                                logprobs,
                                use_list_input,
                                parallel_sample_num,
                                token_input,
                            )

    def test_completion_stream(self):
        # Parallel sampling and list input are not supported in streaming mode.
        for echo in [False, True]:
            for logprobs in [None, 5]:
                for use_list_input in [True, False]:
                    for parallel_sample_num in [1, 2]:
                        for token_input in [False, True]:
                            self.run_completion_stream(
                                echo,
                                logprobs,
                                use_list_input,
                                parallel_sample_num,
                                token_input,
                            )

    def test_chat_completion(self):
        for logprobs in [None, 5]:
            for parallel_sample_num in [1, 2]:
                self.run_chat_completion(logprobs, parallel_sample_num)

    def test_chat_completion_stream(self):
        for logprobs in [None, 5]:
            for parallel_sample_num in [1, 2]:
                self.run_chat_completion_stream(logprobs, parallel_sample_num)

    def test_batch(self):
        for mode in ["completion", "chat"]:
            self.run_batch(mode)

    def test_cancel_batch(self):
        for mode in ["completion", "chat"]:
            self.run_cancel_batch(mode)

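    # SGLang supports regex-constrained decoding through the non-standard
    # "regex" sampling parameter, passed via the OpenAI client's extra_body.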
    def test_regex(self):
        client = openai.Client(api_key=self.api_key, base_url=self.base_url)

        regex = (
            r"""\{\n"""
            + r"""   "name": "[\w]+",\n"""
            + r"""   "population": [\d]+\n"""
            + r"""\}"""
        )

        response = client.chat.completions.create(
            model=self.model,
            messages=[
                {"role": "system", "content": "You are a helpful AI assistant"},
                {"role": "user", "content": "Introduce the capital of France."},
            ],
            temperature=0,
            max_tokens=128,
            extra_body={"regex": regex},
        )
        text = response.choices[0].message.content

        try:
            js_obj = json.loads(text)
        except (TypeError, json.decoder.JSONDecodeError):
            print("JSONDecodeError", text)
            raise
        assert isinstance(js_obj["name"], str)
        assert isinstance(js_obj["population"], int)

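    # frequency_penalty > 0 penalizes tokens proportionally to how often they
    # have already appeared; here we only check that the request succeeds.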
    def test_penalty(self):
        client = openai.Client(api_key=self.api_key, base_url=self.base_url)

        response = client.chat.completions.create(
            model=self.model,
            messages=[
                {"role": "system", "content": "You are a helpful AI assistant"},
                {"role": "user", "content": "Introduce the capital of France."},
            ],
            temperature=0,
            max_tokens=32,
            frequency_penalty=1.0,
        )
        text = response.choices[0].message.content
        assert isinstance(text, str)

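    # A trailing assistant message acts as a response prefill: the model
    # continues from the given prefix instead of starting a fresh reply.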
    def test_response_prefill(self):
        client = openai.Client(api_key=self.api_key, base_url=self.base_url)

        response = client.chat.completions.create(
            model="meta-llama/Llama-3.1-8B-Instruct",
            messages=[
                {"role": "system", "content": "You are a helpful AI assistant"},
                {
                    "role": "user",
                    "content": """
Extract the name, size, price, and color from this product description as a JSON object:

<description>
The SmartHome Mini is a compact smart home assistant available in black or white for only $49.99. At just 5 inches wide, it lets you control lights, thermostats, and other connected devices via voice or app—no matter where you place it in your home. This affordable little hub brings convenient hands-free control to your smart devices.
</description>
""",
                },
                {
                    "role": "assistant",
                    "content": "{\n",
                },
            ],
            temperature=0,
        )

        assert (
            response.choices[0]
            .message.content.strip()
            .startswith('"name": "SmartHome Mini",')
        )


if __name__ == "__main__":
    unittest.main()