import json
import time
import unittest

import openai

from sglang.srt.hf_transformers_utils import get_tokenizer
from sglang.srt.utils import kill_child_process
from sglang.test.test_utils import (
    DEFAULT_MODEL_NAME_FOR_TEST,
    DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH,
    DEFAULT_URL_FOR_TEST,
    popen_launch_server,
)


class TestOpenAIServer(unittest.TestCase):
    @classmethod
    def setUpClass(cls):
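        # Launch one server for the whole test class; every test talks to it
        # through the OpenAI client against the /v1 endpoint.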
        cls.model = DEFAULT_MODEL_NAME_FOR_TEST
        cls.base_url = DEFAULT_URL_FOR_TEST
        cls.api_key = "sk-123456"
        cls.process = popen_launch_server(
            cls.model,
            cls.base_url,
            timeout=DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH,
            api_key=cls.api_key,
        )
        cls.base_url += "/v1"
        cls.tokenizer = get_tokenizer(DEFAULT_MODEL_NAME_FOR_TEST)

    @classmethod
    def tearDownClass(cls):
        kill_child_process(cls.process.pid)

    def run_completion(
        self, echo, logprobs, use_list_input, parallel_sample_num, token_input
    ):
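        """Exercise /v1/completions and validate choices, logprobs, and usage.

        Covers echo, token-id prompts, list-of-prompts input, and parallel
        sampling (n > 1).
        """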
        client = openai.Client(api_key=self.api_key, base_url=self.base_url)
        prompt = "The capital of France is"
        if token_input:
            prompt_input = self.tokenizer.encode(prompt)
            num_prompt_tokens = len(prompt_input)
        else:
            prompt_input = prompt
            num_prompt_tokens = len(self.tokenizer.encode(prompt))

        if use_list_input:
            prompt_arg = [prompt_input, prompt_input]
            num_choices = len(prompt_arg)
            num_prompt_tokens *= 2
        else:
            prompt_arg = prompt_input
            num_choices = 1

        response = client.completions.create(
            model=self.model,
            prompt=prompt_arg,
            temperature=0,
            max_tokens=32,
            echo=echo,
            logprobs=logprobs,
            n=parallel_sample_num,
        )

        assert len(response.choices) == num_choices * parallel_sample_num

        if echo:
            text = response.choices[0].text
            assert text.startswith(prompt)

        if logprobs:
            assert response.choices[0].logprobs
            assert isinstance(response.choices[0].logprobs.tokens[0], str)
            assert isinstance(response.choices[0].logprobs.top_logprobs[1], dict)
            ret_num_top_logprobs = len(response.choices[0].logprobs.top_logprobs[1])

            # FIXME: Sometimes some top_logprobs are missing in the return value.
            # This happens when several output ids map to the same output token,
            # producing duplicate keys in the map.
            # assert ret_num_top_logprobs == logprobs, f"{ret_num_top_logprobs} vs {logprobs}"

            assert ret_num_top_logprobs > 0
            assert response.choices[0].logprobs.token_logprobs[0] is not None

        assert response.id
        assert response.created
        assert (
            response.usage.prompt_tokens == num_prompt_tokens
        ), f"{response.usage.prompt_tokens} vs {num_prompt_tokens}"
        assert response.usage.completion_tokens > 0
        assert response.usage.total_tokens > 0

    def run_completion_stream(
        self, echo, logprobs, use_list_input, parallel_sample_num, token_input
    ):
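        """Same coverage as run_completion, but over a streamed response."""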
        client = openai.Client(api_key=self.api_key, base_url=self.base_url)
        prompt = "The capital of France is"
        if token_input:
            prompt_input = self.tokenizer.encode(prompt)
            num_prompt_tokens = len(prompt_input)
        else:
            prompt_input = prompt
            num_prompt_tokens = len(self.tokenizer.encode(prompt))

        if use_list_input:
            prompt_arg = [prompt_input, prompt_input]
            num_choices = len(prompt_arg)
            num_prompt_tokens *= 2
        else:
            prompt_arg = prompt_input
            num_choices = 1

        generator = client.completions.create(
            model=self.model,
            prompt=prompt_arg,
            temperature=0,
            max_tokens=32,
            echo=echo,
            logprobs=logprobs,
            stream=True,
            stream_options={"include_usage": True},
            n=parallel_sample_num,
        )

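        # Track, per choice index, whether the next chunk is the first one seen.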
        is_firsts = {}
        for response in generator:
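            # With include_usage, the final chunk carries only usage statistics
            # and no choices.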
            usage = response.usage
            if usage is not None:
                assert usage.prompt_tokens > 0
                assert usage.completion_tokens > 0
                assert usage.total_tokens > 0
                continue

            index = response.choices[0].index
            is_first = is_firsts.get(index, True)

            if logprobs:
                assert response.choices[0].logprobs
                assert isinstance(response.choices[0].logprobs.tokens[0], str)
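                # The first chunk of an echoed stream replays the prompt, whose
                # tokens carry no top_logprobs, so skip the check there.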
                if not (is_first and echo):
                    assert isinstance(
                        response.choices[0].logprobs.top_logprobs[0], dict
                    )
                    ret_num_top_logprobs = len(
                        response.choices[0].logprobs.top_logprobs[0]
                    )
                    # FIXME: Sometimes some top_logprobs are missing in the return value.
                    # This happens when several output ids map to the same output token,
                    # producing duplicate keys in the map.
                    # assert ret_num_top_logprobs == logprobs, f"{ret_num_top_logprobs} vs {logprobs}"
                    assert ret_num_top_logprobs > 0

            if is_first:
                if echo:
                    assert response.choices[0].text.startswith(
                        prompt
                    ), f"{response.choices[0].text} and all args {echo} {logprobs} {token_input} {is_first}"
                is_firsts[index] = False
            assert response.id
            assert response.created

        for index in range(parallel_sample_num * num_choices):
            assert not is_firsts.get(
                index, True
            ), f"index {index} was not found in the response"

    def run_chat_completion(self, logprobs, parallel_sample_num):
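        """Exercise /v1/chat/completions and validate the message, logprobs, and usage."""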
        client = openai.Client(api_key=self.api_key, base_url=self.base_url)
        response = client.chat.completions.create(
            model=self.model,
            messages=[
                {"role": "system", "content": "You are a helpful AI assistant"},
                {
                    "role": "user",
                    "content": "What is the capital of France? Answer in a few words.",
                },
            ],
            temperature=0,
            logprobs=logprobs is not None and logprobs > 0,
            top_logprobs=logprobs,
            n=parallel_sample_num,
        )

        if logprobs:
            assert isinstance(
                response.choices[0].logprobs.content[0].top_logprobs[0].token, str
            )

            ret_num_top_logprobs = len(
                response.choices[0].logprobs.content[0].top_logprobs
            )
            assert (
                ret_num_top_logprobs == logprobs
            ), f"{ret_num_top_logprobs} vs {logprobs}"

        assert len(response.choices) == parallel_sample_num
        assert response.choices[0].message.role == "assistant"
        assert isinstance(response.choices[0].message.content, str)
        assert response.id
        assert response.created
        assert response.usage.prompt_tokens > 0
        assert response.usage.completion_tokens > 0
        assert response.usage.total_tokens > 0

    def run_chat_completion_stream(self, logprobs, parallel_sample_num=1):
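        """Same coverage as run_chat_completion, but over a streamed response."""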
        client = openai.Client(api_key=self.api_key, base_url=self.base_url)
        generator = client.chat.completions.create(
            model=self.model,
            messages=[
                {"role": "system", "content": "You are a helpful AI assistant"},
                {"role": "user", "content": "What is the capital of France?"},
            ],
            temperature=0,
            logprobs=logprobs is not None and logprobs > 0,
            top_logprobs=logprobs,
            stream=True,
            stream_options={"include_usage": True},
            n=parallel_sample_num,
        )

        is_firsts = {}
        for response in generator:
            usage = response.usage
            if usage is not None:
                assert usage.prompt_tokens > 0
                assert usage.completion_tokens > 0
                assert usage.total_tokens > 0
                continue

            index = response.choices[0].index
            data = response.choices[0].delta

            if is_firsts.get(index, True):
                assert data.role == "assistant"
                is_firsts[index] = False
                continue

            if logprobs:
                assert response.choices[0].logprobs
                assert isinstance(
                    response.choices[0].logprobs.content[0].top_logprobs[0].token, str
                )
                assert isinstance(
                    response.choices[0].logprobs.content[0].top_logprobs, list
                )
                ret_num_top_logprobs = len(
                    response.choices[0].logprobs.content[0].top_logprobs
                )
                assert (
                    ret_num_top_logprobs == logprobs
                ), f"{ret_num_top_logprobs} vs {logprobs}"

            assert isinstance(data.content, str)
            assert response.id
            assert response.created

        for index in range(parallel_sample_num):
            assert not is_firsts.get(
                index, True
            ), f"index {index} was not found in the response"

    def run_batch(self, mode):
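        """Upload a JSONL request file, run it through /v1/batches, and validate the results."""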
        client = openai.Client(api_key=self.api_key, base_url=self.base_url)
        if mode == "completion":
            input_file_path = "complete_input.jsonl"
            # write content to input file
            content = [
                {
                    "custom_id": "request-1",
                    "method": "POST",
                    "url": "/v1/completions",
                    "body": {
                        "model": "gpt-3.5-turbo-instruct",
                        "prompt": "List 3 names of famous soccer player: ",
                        "max_tokens": 20,
                    },
                },
                {
                    "custom_id": "request-2",
                    "method": "POST",
                    "url": "/v1/completions",
                    "body": {
                        "model": "gpt-3.5-turbo-instruct",
                        "prompt": "List 6 names of famous basketball player:  ",
                        "max_tokens": 40,
                    },
                },
                {
                    "custom_id": "request-3",
                    "method": "POST",
                    "url": "/v1/completions",
                    "body": {
                        "model": "gpt-3.5-turbo-instruct",
                        "prompt": "List 6 names of famous tenniss player:  ",
                        "max_tokens": 40,
                    },
                },
            ]

        else:
            input_file_path = "chat_input.jsonl"
            content = [
                {
                    "custom_id": "request-1",
                    "method": "POST",
                    "url": "/v1/chat/completions",
                    "body": {
                        "model": "gpt-3.5-turbo-0125",
                        "messages": [
                            {
                                "role": "system",
                                "content": "You are a helpful assistant.",
                            },
                            {
                                "role": "user",
                                "content": "Hello! List 3 NBA players and tell a story",
                            },
                        ],
                        "max_tokens": 30,
                    },
                },
                {
                    "custom_id": "request-2",
                    "method": "POST",
                    "url": "/v1/chat/completions",
                    "body": {
                        "model": "gpt-3.5-turbo-0125",
                        "messages": [
                            {"role": "system", "content": "You are an assistant. "},
                            {
                                "role": "user",
                                "content": "Hello! List three capital and tell a story",
                            },
                        ],
                        "max_tokens": 50,
                    },
                },
            ]
        with open(input_file_path, "w") as file:
            for line in content:
                file.write(json.dumps(line) + "\n")
        with open(input_file_path, "rb") as file:
            uploaded_file = client.files.create(file=file, purpose="batch")
        if mode == "completion":
            endpoint = "/v1/completions"
        elif mode == "chat":
            endpoint = "/v1/chat/completions"
        completion_window = "24h"
        batch_job = client.batches.create(
            input_file_id=uploaded_file.id,
            endpoint=endpoint,
            completion_window=completion_window,
        )
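        # Poll until the batch job reaches a terminal state.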
        while batch_job.status not in ["completed", "failed", "cancelled"]:
            time.sleep(3)
            print(
                f"Batch job status: {batch_job.status}...trying again in 3 seconds..."
            )
            batch_job = client.batches.retrieve(batch_job.id)
        assert (
            batch_job.status == "completed"
        ), f"Batch job status is not completed: {batch_job.status}"
        assert batch_job.request_counts.completed == len(content)
        assert batch_job.request_counts.failed == 0
        assert batch_job.request_counts.total == len(content)

        result_file_id = batch_job.output_file_id
        file_response = client.files.content(result_file_id)
        result_content = file_response.read().decode("utf-8")  # Decode bytes to string
        results = [
            json.loads(line)
            for line in result_content.split("\n")
            if line.strip() != ""
        ]
        assert len(results) == len(content)

    def test_completion(self):
        for echo in [False, True]:
            for logprobs in [None, 5]:
                for use_list_input in [True, False]:
                    for parallel_sample_num in [1, 2]:
                        for token_input in [False, True]:
                            self.run_completion(
                                echo,
                                logprobs,
                                use_list_input,
                                parallel_sample_num,
                                token_input,
                            )

    def test_completion_stream(self):
        # parallel sampling and list input are not supported in streaming mode
        for echo in [False, True]:
            for logprobs in [None, 5]:
                for use_list_input in [True, False]:
                    for parallel_sample_num in [1, 2]:
                        for token_input in [False, True]:
                            self.run_completion_stream(
                                echo,
                                logprobs,
                                use_list_input,
                                parallel_sample_num,
                                token_input,
                            )

    def test_chat_completion(self):
        for logprobs in [None, 5]:
            for parallel_sample_num in [1, 2]:
                self.run_chat_completion(logprobs, parallel_sample_num)

    def test_chat_completion_stream(self):
        for logprobs in [None, 5]:
            for parallel_sample_num in [1, 2]:
                self.run_chat_completion_stream(logprobs, parallel_sample_num)

    def test_batch(self):
        for mode in ["completion", "chat"]:
            self.run_batch(mode)

    def test_regex(self):
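        # Constrained decoding: the regex passed via extra_body (an sglang
        # extension) forces the completion to be a small JSON object.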
        client = openai.Client(api_key=self.api_key, base_url=self.base_url)

        regex = (
            r"""\{\n"""
            + r"""   "name": "[\w]+",\n"""
            + r"""   "population": [\d]+\n"""
            + r"""\}"""
        )

        response = client.chat.completions.create(
            model=self.model,
            messages=[
                {"role": "system", "content": "You are a helpful AI assistant"},
                {"role": "user", "content": "Introduce the capital of France."},
            ],
            temperature=0,
            max_tokens=128,
            extra_body={"regex": regex},
        )
        text = response.choices[0].message.content

        try:
            js_obj = json.loads(text)
        except (TypeError, json.decoder.JSONDecodeError):
            print("JSONDecodeError", text)
            raise
        assert isinstance(js_obj["name"], str)
        assert isinstance(js_obj["population"], int)


if __name__ == "__main__":
    unittest.main()