test_openai_server.py 17.5 KB
Newer Older
1
import json
2
import time
3
import unittest
4
5

import openai
6

yichuan~'s avatar
yichuan~ committed
7
from sglang.srt.hf_transformers_utils import get_tokenizer
8
from sglang.srt.utils import kill_child_process
9
10
from sglang.test.test_utils import (
    DEFAULT_MODEL_NAME_FOR_TEST,
11
12
    DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH,
    DEFAULT_URL_FOR_TEST,
13
14
    popen_launch_server,
)
15
16
17
18
19


class TestOpenAIServer(unittest.TestCase):
    # End-to-end tests that exercise the OpenAI-compatible HTTP API exposed
    # by an sglang server launched as a child process.

    @classmethod
    def setUpClass(cls):
        """Launch the server once for the whole test class.

        Stores the model name, base URL, API key, the server process handle,
        and a tokenizer (used to build token-id prompts and to count prompt
        tokens in the completion tests).
        """
        cls.model = DEFAULT_MODEL_NAME_FOR_TEST
        cls.base_url = DEFAULT_URL_FOR_TEST
        cls.api_key = "sk-123456"
        cls.process = popen_launch_server(
            cls.model,
            cls.base_url,
            timeout=DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH,
            api_key=cls.api_key,
        )
        # The OpenAI client expects the versioned endpoint root.
        cls.base_url += "/v1"
        cls.tokenizer = get_tokenizer(DEFAULT_MODEL_NAME_FOR_TEST)

    @classmethod
    def tearDownClass(cls):
        # Terminate the server process tree launched in setUpClass.
        kill_child_process(cls.process.pid)

yichuan~'s avatar
yichuan~ committed
36
37
38
    def run_completion(
        self, echo, logprobs, use_list_input, parallel_sample_num, token_input
    ):
        """Issue one non-streaming /v1/completions request and validate it.

        Args:
            echo: Ask the server to echo the prompt before the completion.
            logprobs: Number of top logprobs to request, or None/0 to disable.
            use_list_input: Send the prompt twice as a list (batch of 2).
            parallel_sample_num: Value for the ``n`` sampling parameter.
            token_input: Send the prompt as token ids instead of a string.
        """
        client = openai.Client(api_key=self.api_key, base_url=self.base_url)
        prompt = "The capital of France is"

        if token_input:
            prompt_input = self.tokenizer.encode(prompt)
            num_prompt_tokens = len(prompt_input)
        else:
            prompt_input = prompt
            num_prompt_tokens = len(self.tokenizer.encode(prompt))

        if use_list_input:
            prompt_arg = [prompt_input, prompt_input]
            num_choices = len(prompt_arg)
            num_prompt_tokens *= 2
        else:
            prompt_arg = prompt_input
            num_choices = 1

        response = client.completions.create(
            model=self.model,
            prompt=prompt_arg,
            temperature=0,
            max_tokens=32,
            echo=echo,
            logprobs=logprobs,
            n=parallel_sample_num,
        )

        # One choice per (prompt in the batch) x (parallel sample).
        assert len(response.choices) == num_choices * parallel_sample_num

        if echo:
            text = response.choices[0].text
            assert text.startswith(prompt)

        if logprobs:
            assert response.choices[0].logprobs
            assert isinstance(response.choices[0].logprobs.tokens[0], str)
            assert isinstance(response.choices[0].logprobs.top_logprobs[1], dict)
            ret_num_top_logprobs = len(response.choices[0].logprobs.top_logprobs[1])
            # FIXME: Sometimes, some top_logprobs are missing in the return value. The reason is that some out_put id maps to the same output token and duplicate in the map
            # assert ret_num_top_logprobs == logprobs, f"{ret_num_top_logprobs} vs {logprobs}"
            assert ret_num_top_logprobs > 0
            # BUGFIX: use identity comparison with None instead of `!= None`.
            assert response.choices[0].logprobs.token_logprobs[0] is not None

        assert response.id
        assert response.created
        assert (
            response.usage.prompt_tokens == num_prompt_tokens
        ), f"{response.usage.prompt_tokens} vs {num_prompt_tokens}"
        assert response.usage.completion_tokens > 0
        assert response.usage.total_tokens > 0

92
93
94
    def run_completion_stream(
        self, echo, logprobs, use_list_input, parallel_sample_num, token_input
    ):
        """Issue one streaming /v1/completions request and validate each chunk.

        Args:
            echo: Ask the server to echo the prompt before the completion.
            logprobs: Number of top logprobs to request, or None/0 to disable.
            use_list_input: Send the prompt twice as a list (batch of 2).
            parallel_sample_num: Value for the ``n`` sampling parameter.
            token_input: Send the prompt as token ids instead of a string.
        """
        client = openai.Client(api_key=self.api_key, base_url=self.base_url)
        prompt = "The capital of France is"
        if token_input:
            prompt_input = self.tokenizer.encode(prompt)
            num_prompt_tokens = len(prompt_input)
        else:
            prompt_input = prompt
            num_prompt_tokens = len(self.tokenizer.encode(prompt))

        if use_list_input:
            prompt_arg = [prompt_input, prompt_input]
            num_choices = len(prompt_arg)
            num_prompt_tokens *= 2
        else:
            prompt_arg = prompt_input
            num_choices = 1

        generator = client.completions.create(
            model=self.model,
            prompt=prompt_arg,
            temperature=0,
            max_tokens=32,
            echo=echo,
            logprobs=logprobs,
            stream=True,
            stream_options={"include_usage": True},
            n=parallel_sample_num,
        )

        # Tracks, per choice index, whether we are still waiting for its
        # first content chunk.
        is_firsts = {}
        for response in generator:
            usage = response.usage
            if usage is not None:
                # The final chunk carries only usage statistics.
                assert usage.prompt_tokens > 0
                assert usage.completion_tokens > 0
                assert usage.total_tokens > 0
                continue

            index = response.choices[0].index
            is_first = is_firsts.get(index, True)

            if logprobs:
                assert response.choices[0].logprobs
                assert isinstance(response.choices[0].logprobs.tokens[0], str)
                # The first echoed chunk has no top_logprobs to inspect.
                if not (is_first and echo):
                    assert isinstance(
                        response.choices[0].logprobs.top_logprobs[0], dict
                    )
                    ret_num_top_logprobs = len(
                        response.choices[0].logprobs.top_logprobs[0]
                    )
                    # FIXME: Sometimes, some top_logprobs are missing in the return value. The reason is that some out_put id maps to the same output token and duplicate in the map
                    # assert ret_num_top_logprobs == logprobs, f"{ret_num_top_logprobs} vs {logprobs}"
                    assert ret_num_top_logprobs > 0

            if is_first:
                if echo:
                    assert response.choices[0].text.startswith(
                        prompt
                    ), f"{response.choices[0].text} and all args {echo} {logprobs} {token_input} {is_first}"
                is_firsts[index] = False
            assert response.id
            assert response.created

        # IDIOM FIX: iterate range() directly instead of a redundant
        # [i for i in range(...)] list comprehension.
        for index in range(parallel_sample_num * num_choices):
            assert not is_firsts.get(
                index, True
            ), f"index {index} is not found in the response"

yichuan~'s avatar
yichuan~ committed
164
    def run_chat_completion(self, logprobs, parallel_sample_num):
        """Issue one non-streaming chat completion and validate the response."""
        api = openai.Client(api_key=self.api_key, base_url=self.base_url)
        conversation = [
            {"role": "system", "content": "You are a helpful AI assistant"},
            {
                "role": "user",
                "content": "What is the capital of France? Answer in a few words.",
            },
        ]
        response = api.chat.completions.create(
            model=self.model,
            messages=conversation,
            temperature=0,
            logprobs=logprobs is not None and logprobs > 0,
            top_logprobs=logprobs,
            n=parallel_sample_num,
        )

        first = response.choices[0]
        if logprobs:
            top = first.logprobs.content[0].top_logprobs
            assert isinstance(top[0].token, str)
            # Chat completions must return exactly the requested count.
            assert len(top) == logprobs, f"{len(top)} vs {logprobs}"

        assert len(response.choices) == parallel_sample_num
        assert first.message.role == "assistant"
        assert isinstance(first.message.content, str)
        assert response.id
        assert response.created
        assert response.usage.prompt_tokens > 0
        assert response.usage.completion_tokens > 0
        assert response.usage.total_tokens > 0

202
    def run_chat_completion_stream(self, logprobs, parallel_sample_num=1):
        """Issue one streaming chat completion and validate every chunk.

        Args:
            logprobs: Number of top logprobs to request, or None/0 to disable.
            parallel_sample_num: Value for the ``n`` sampling parameter.
        """
        client = openai.Client(api_key=self.api_key, base_url=self.base_url)
        generator = client.chat.completions.create(
            model=self.model,
            messages=[
                {"role": "system", "content": "You are a helpful AI assistant"},
                {"role": "user", "content": "What is the capital of France?"},
            ],
            temperature=0,
            logprobs=logprobs is not None and logprobs > 0,
            top_logprobs=logprobs,
            stream=True,
            stream_options={"include_usage": True},
            n=parallel_sample_num,
        )

        # Tracks, per choice index, whether its role-announcing first chunk
        # has been seen yet.
        is_firsts = {}
        for response in generator:
            usage = response.usage
            if usage is not None:
                # The final chunk carries only usage statistics.
                assert usage.prompt_tokens > 0
                assert usage.completion_tokens > 0
                assert usage.total_tokens > 0
                continue

            index = response.choices[0].index
            data = response.choices[0].delta

            if is_firsts.get(index, True):
                # The first chunk of each choice only announces the role.
                assert data.role == "assistant"
                is_firsts[index] = False
                continue

            if logprobs:
                assert response.choices[0].logprobs
                assert isinstance(
                    response.choices[0].logprobs.content[0].top_logprobs[0].token, str
                )
                assert isinstance(
                    response.choices[0].logprobs.content[0].top_logprobs, list
                )
                ret_num_top_logprobs = len(
                    response.choices[0].logprobs.content[0].top_logprobs
                )
                assert (
                    ret_num_top_logprobs == logprobs
                ), f"{ret_num_top_logprobs} vs {logprobs}"

            assert isinstance(data.content, str)
            assert response.id
            assert response.created

        # IDIOM FIX: iterate range() directly instead of a redundant
        # [i for i in range(...)] list comprehension.
        for index in range(parallel_sample_num):
            assert not is_firsts.get(
                index, True
            ), f"index {index} is not found in the response"

259
    def _create_batch(self, mode, client):
        """Write a JSONL batch input file, upload it, and create a batch job.

        Args:
            mode: Either "completion" (text completions) or "chat"
                (chat completions); selects the request payloads and endpoint.
            client: An ``openai.Client`` pointed at the server under test.

        Returns:
            A tuple ``(batch_job, content)`` where ``content`` is the list of
            request dicts that were written to the input file.

        Raises:
            ValueError: If ``mode`` is neither "completion" nor "chat".
        """
        if mode == "completion":
            endpoint = "/v1/completions"
            input_file_path = "complete_input.jsonl"
            # write content to input file
            content = [
                {
                    "custom_id": "request-1",
                    "method": "POST",
                    "url": "/v1/completions",
                    "body": {
                        "model": "gpt-3.5-turbo-instruct",
                        "prompt": "List 3 names of famous soccer player: ",
                        "max_tokens": 20,
                    },
                },
                {
                    "custom_id": "request-2",
                    "method": "POST",
                    "url": "/v1/completions",
                    "body": {
                        "model": "gpt-3.5-turbo-instruct",
                        "prompt": "List 6 names of famous basketball player:  ",
                        "max_tokens": 40,
                    },
                },
                {
                    "custom_id": "request-3",
                    "method": "POST",
                    "url": "/v1/completions",
                    "body": {
                        "model": "gpt-3.5-turbo-instruct",
                        "prompt": "List 6 names of famous tenniss player:  ",
                        "max_tokens": 40,
                    },
                },
            ]
        elif mode == "chat":
            endpoint = "/v1/chat/completions"
            input_file_path = "chat_input.jsonl"
            content = [
                {
                    "custom_id": "request-1",
                    "method": "POST",
                    "url": "/v1/chat/completions",
                    "body": {
                        "model": "gpt-3.5-turbo-0125",
                        "messages": [
                            {
                                "role": "system",
                                "content": "You are a helpful assistant.",
                            },
                            {
                                "role": "user",
                                "content": "Hello! List 3 NBA players and tell a story",
                            },
                        ],
                        "max_tokens": 30,
                    },
                },
                {
                    "custom_id": "request-2",
                    "method": "POST",
                    "url": "/v1/chat/completions",
                    "body": {
                        "model": "gpt-3.5-turbo-0125",
                        "messages": [
                            {"role": "system", "content": "You are an assistant. "},
                            {
                                "role": "user",
                                "content": "Hello! List three capital and tell a story",
                            },
                        ],
                        "max_tokens": 50,
                    },
                },
            ]
        else:
            # ROBUSTNESS FIX: the original left `endpoint` unbound for any
            # other mode, crashing later with a NameError. Fail fast instead.
            raise ValueError(f"unknown batch mode: {mode!r}")

        # One JSON request object per line (JSONL format).
        with open(input_file_path, "w") as file:
            for line in content:
                file.write(json.dumps(line) + "\n")

        with open(input_file_path, "rb") as file:
            uploaded_file = client.files.create(file=file, purpose="batch")
        completion_window = "24h"
        batch_job = client.batches.create(
            input_file_id=uploaded_file.id,
            endpoint=endpoint,
            completion_window=completion_window,
        )

        return batch_job, content

    def run_batch(self, mode):
        """Submit a batch job, poll until it finishes, and check its output."""
        client = openai.Client(api_key=self.api_key, base_url=self.base_url)
        batch_job, content = self._create_batch(mode=mode, client=client)

        terminal_states = ["completed", "failed", "cancelled"]
        while batch_job.status not in terminal_states:
            time.sleep(3)
            print(
                f"Batch job status: {batch_job.status}...trying again in 3 seconds..."
            )
            batch_job = client.batches.retrieve(batch_job.id)

        assert (
            batch_job.status == "completed"
        ), f"Batch job status is not completed: {batch_job.status}"

        counts = batch_job.request_counts
        expected = len(content)
        assert counts.completed == expected
        assert counts.failed == 0
        assert counts.total == expected

        # Fetch the result file and parse one JSON object per non-empty line.
        raw = client.files.content(batch_job.output_file_id).read().decode("utf-8")
        results = []
        for line in raw.split("\n"):
            if line.strip() != "":
                results.append(json.loads(line))
        assert len(results) == expected

382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
    def run_cancel_batch(self, mode):
        """Create a batch job, cancel it, and verify it ends up cancelled."""
        client = openai.Client(api_key=self.api_key, base_url=self.base_url)
        batch_job, _ = self._create_batch(mode=mode, client=client)

        # A freshly created job must not already be on its way out.
        assert batch_job.status not in ["cancelling", "cancelled"]

        batch_job = client.batches.cancel(batch_id=batch_job.id)
        assert batch_job.status == "cancelling"

        done_states = ["failed", "cancelled"]
        while batch_job.status not in done_states:
            batch_job = client.batches.retrieve(batch_job.id)
            print(
                f"Batch job status: {batch_job.status}...trying again in 3 seconds..."
            )
            time.sleep(3)

        assert batch_job.status == "cancelled"

400
401
402
    def test_completion(self):
        """Sweep run_completion over echo/logprobs/list-input/n/token-input."""
        for echo in (False, True):
            for logprobs in (None, 5):
                for use_list_input in (True, False):
                    for parallel_sample_num in (1, 2):
                        for token_input in (False, True):
                            self.run_completion(
                                echo,
                                logprobs,
                                use_list_input,
                                parallel_sample_num,
                                token_input,
                            )
413
414

    def test_completion_stream(self):
        # parallel sampling and list input are not supported in streaming mode
        # NOTE(review): the loops below do exercise list input and n>1, which
        # contradicts the comment above — confirm which statement is current.
        for echo in [False, True]:
            for logprobs in [None, 5]:
                for use_list_input in [True, False]:
                    for parallel_sample_num in [1, 2]:
                        for token_input in [False, True]:
                            self.run_completion_stream(
                                echo,
                                logprobs,
                                use_list_input,
                                parallel_sample_num,
                                token_input,
                            )
428

429
430
    def test_chat_completion(self):
        """Check non-streaming chat completions with and without logprobs."""
        for logprobs in (None, 5):
            for parallel_sample_num in (1, 2):
                self.run_chat_completion(logprobs, parallel_sample_num)
433
434
435

    def test_chat_completion_stream(self):
        """Check streaming chat completions with and without logprobs."""
        for logprobs in (None, 5):
            for parallel_sample_num in (1, 2):
                self.run_chat_completion_stream(logprobs, parallel_sample_num)
438

439
440
441
442
    def test_batch(self):
        """Run the full batch workflow for both supported endpoints."""
        for mode in ("completion", "chat"):
            self.run_batch(mode)

443
444
445
446
    def test_calcel_batch(self):
        # NOTE(review): "calcel" is a typo for "cancel"; the name is kept
        # unchanged here because unittest discovers tests by method name and
        # renaming would change the reported test identity.
        for mode in ["completion", "chat"]:
            self.run_cancel_batch(mode)

447
    def test_regex(self):
        """Constrain chat output with a regex and verify the JSON it yields."""
        client = openai.Client(api_key=self.api_key, base_url=self.base_url)

        # Pattern forcing a two-field JSON object: string name, integer population.
        pattern = (
            r"""\{\n"""
            + r"""   "name": "[\w]+",\n"""
            + r"""   "population": [\d]+\n"""
            + r"""\}"""
        )

        response = client.chat.completions.create(
            model=self.model,
            messages=[
                {"role": "system", "content": "You are a helpful AI assistant"},
                {"role": "user", "content": "Introduce the capital of France."},
            ],
            temperature=0,
            max_tokens=128,
            extra_body={"regex": pattern},
        )
        answer = response.choices[0].message.content

        try:
            parsed = json.loads(answer)
        except (TypeError, json.decoder.JSONDecodeError):
            print("JSONDecodeError", answer)
            raise
        assert isinstance(parsed["name"], str)
        assert isinstance(parsed["population"], int)

477

478
if __name__ == "__main__":
    # Run the full suite when this file is executed directly.
    unittest.main()