"research/textsum/seq2seq_attention_model.py" did not exist on "66eb101896a1b7c1691c7416668776ea60ddb78f"
test_openai_server.py 18.3 KB
Newer Older
1
import json
import time
import unittest

import openai

from sglang.srt.hf_transformers_utils import get_tokenizer
from sglang.srt.utils import kill_child_process
from sglang.test.test_utils import (
    DEFAULT_MODEL_NAME_FOR_TEST,
    DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH,
    DEFAULT_URL_FOR_TEST,
    popen_launch_server,
)


class TestOpenAIServer(unittest.TestCase):
    @classmethod
    def setUpClass(cls):
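        # Launch a local server once for the whole test class, point the
        # OpenAI-compatible base_url at its /v1 prefix, and load the tokenizer
        # used to build token-id prompts and count prompt tokens.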
        cls.model = DEFAULT_MODEL_NAME_FOR_TEST
        cls.base_url = DEFAULT_URL_FOR_TEST
        cls.api_key = "sk-123456"
        cls.process = popen_launch_server(
            cls.model,
            cls.base_url,
            timeout=DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH,
            api_key=cls.api_key,
        )
        cls.base_url += "/v1"
        cls.tokenizer = get_tokenizer(DEFAULT_MODEL_NAME_FOR_TEST)

    @classmethod
    def tearDownClass(cls):
        kill_child_process(cls.process.pid)

    def run_completion(
        self, echo, logprobs, use_list_input, parallel_sample_num, token_input
    ):
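        # Exercise /v1/completions with combinations of echo, logprobs,
        # string vs. token-id prompts, single vs. list inputs, and parallel
        # sampling (n), then validate choice count, logprobs, and usage.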
        client = openai.Client(api_key=self.api_key, base_url=self.base_url)
        prompt = "The capital of France is"
        if token_input:
            prompt_input = self.tokenizer.encode(prompt)
            num_prompt_tokens = len(prompt_input)
        else:
            prompt_input = prompt
            num_prompt_tokens = len(self.tokenizer.encode(prompt))

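        # With a list input, the same prompt is sent twice, so both the
        # expected number of choices and the expected prompt-token count double.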
        if use_list_input:
            prompt_arg = [prompt_input, prompt_input]
            num_choices = len(prompt_arg)
            num_prompt_tokens *= 2
        else:
            prompt_arg = prompt_input
            num_choices = 1

        response = client.completions.create(
            model=self.model,
            prompt=prompt_arg,
            temperature=0,
            max_tokens=32,
            echo=echo,
            logprobs=logprobs,
            n=parallel_sample_num,
        )

        assert len(response.choices) == num_choices * parallel_sample_num

        if echo:
            text = response.choices[0].text
            assert text.startswith(prompt)

        if logprobs:
            assert response.choices[0].logprobs
            assert isinstance(response.choices[0].logprobs.tokens[0], str)
            assert isinstance(response.choices[0].logprobs.top_logprobs[1], dict)
            ret_num_top_logprobs = len(response.choices[0].logprobs.top_logprobs[1])

            # FIXME: Sometimes some top_logprobs are missing from the return value because several output ids map to the same output token and collide in the map.
            # assert ret_num_top_logprobs == logprobs, f"{ret_num_top_logprobs} vs {logprobs}"
            assert ret_num_top_logprobs > 0

            assert response.choices[0].logprobs.token_logprobs[0]

        assert response.id
        assert response.created
        assert (
            response.usage.prompt_tokens == num_prompt_tokens
        ), f"{response.usage.prompt_tokens} vs {num_prompt_tokens}"
        assert response.usage.completion_tokens > 0
        assert response.usage.total_tokens > 0

    def run_completion_stream(
        self, echo, logprobs, use_list_input, parallel_sample_num, token_input
    ):
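        # Same parameter matrix as run_completion, but with stream=True:
        # iterate over the streamed chunks, check logprobs and echo per chunk,
        # and validate the usage-only chunk requested via stream_options.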
        client = openai.Client(api_key=self.api_key, base_url=self.base_url)
        prompt = "The capital of France is"
        if token_input:
            prompt_input = self.tokenizer.encode(prompt)
            num_prompt_tokens = len(prompt_input)
        else:
            prompt_input = prompt
            num_prompt_tokens = len(self.tokenizer.encode(prompt))

        if use_list_input:
            prompt_arg = [prompt_input, prompt_input]
            num_choices = len(prompt_arg)
            num_prompt_tokens *= 2
        else:
            prompt_arg = prompt_input
            num_choices = 1

        generator = client.completions.create(
            model=self.model,
            prompt=prompt_arg,
            temperature=0,
            max_tokens=32,
            echo=echo,
            logprobs=logprobs,
            stream=True,
            stream_options={"include_usage": True},
            n=parallel_sample_num,
        )

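        # Track, per choice index, whether its first chunk has been seen; the
        # first chunk is where echo text is checked, and the map doubles as a
        # check that every expected choice actually streamed something.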
        is_firsts = {}
        for response in generator:
            usage = response.usage
            if usage is not None:
                assert usage.prompt_tokens > 0
                assert usage.completion_tokens > 0
                assert usage.total_tokens > 0
                continue

            index = response.choices[0].index
            is_first = is_firsts.get(index, True)

            if logprobs:
                assert response.choices[0].logprobs
                assert isinstance(response.choices[0].logprobs.tokens[0], str)
                if not (is_first and echo):
                    assert isinstance(
                        response.choices[0].logprobs.top_logprobs[0], dict
                    )
                    ret_num_top_logprobs = len(
                        response.choices[0].logprobs.top_logprobs[0]
                    )
                    # FIXME: Sometimes some top_logprobs are missing from the return value because several output ids map to the same output token and collide in the map.
                    # assert ret_num_top_logprobs == logprobs, f"{ret_num_top_logprobs} vs {logprobs}"
                    assert ret_num_top_logprobs > 0

            if is_first:
                if echo:
                    assert response.choices[0].text.startswith(
                        prompt
                    ), f"{response.choices[0].text} and all args {echo} {logprobs} {token_input} {is_first}"
                is_firsts[index] = False
            assert response.id
            assert response.created

        for index in range(parallel_sample_num * num_choices):
            assert not is_firsts.get(
                index, True
            ), f"index {index} is not found in the response"

    def run_chat_completion(self, logprobs, parallel_sample_num):
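        # Non-streaming chat completion: optionally request top_logprobs and
        # parallel samples, then validate the logprobs payload, choice count,
        # message fields, and usage accounting.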
        client = openai.Client(api_key=self.api_key, base_url=self.base_url)
        response = client.chat.completions.create(
            model=self.model,
            messages=[
                {"role": "system", "content": "You are a helpful AI assistant"},
                {
                    "role": "user",
                    "content": "What is the capital of France? Answer in a few words.",
                },
            ],
            temperature=0,
            logprobs=logprobs is not None and logprobs > 0,
            top_logprobs=logprobs,
            n=parallel_sample_num,
        )

        if logprobs:
            assert isinstance(
                response.choices[0].logprobs.content[0].top_logprobs[0].token, str
            )

            ret_num_top_logprobs = len(
                response.choices[0].logprobs.content[0].top_logprobs
            )
            assert (
                ret_num_top_logprobs == logprobs
            ), f"{ret_num_top_logprobs} vs {logprobs}"

        assert len(response.choices) == parallel_sample_num
        assert response.choices[0].message.role == "assistant"
        assert isinstance(response.choices[0].message.content, str)
        assert response.id
        assert response.created
        assert response.usage.prompt_tokens > 0
        assert response.usage.completion_tokens > 0
        assert response.usage.total_tokens > 0

    def run_chat_completion_stream(self, logprobs, parallel_sample_num=1):
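        # Streaming chat completion: the first delta of each choice index must
        # carry role="assistant"; later deltas carry content (and top_logprobs
        # when requested), plus a final usage-only chunk.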
        client = openai.Client(api_key=self.api_key, base_url=self.base_url)
        generator = client.chat.completions.create(
            model=self.model,
            messages=[
                {"role": "system", "content": "You are a helpful AI assistant"},
                {"role": "user", "content": "What is the capital of France?"},
            ],
            temperature=0,
            logprobs=logprobs is not None and logprobs > 0,
            top_logprobs=logprobs,
            stream=True,
            stream_options={"include_usage": True},
            n=parallel_sample_num,
        )

        is_firsts = {}
        for response in generator:
            usage = response.usage
            if usage is not None:
                assert usage.prompt_tokens > 0
                assert usage.completion_tokens > 0
                assert usage.total_tokens > 0
                continue

            index = response.choices[0].index
            data = response.choices[0].delta

            if is_firsts.get(index, True):
                assert data.role == "assistant"
                is_firsts[index] = False
                continue

            if logprobs:
                assert response.choices[0].logprobs
                assert isinstance(
                    response.choices[0].logprobs.content[0].top_logprobs[0].token, str
                )
                assert isinstance(
                    response.choices[0].logprobs.content[0].top_logprobs, list
                )
                ret_num_top_logprobs = len(
                    response.choices[0].logprobs.content[0].top_logprobs
                )
                assert (
                    ret_num_top_logprobs == logprobs
                ), f"{ret_num_top_logprobs} vs {logprobs}"

            assert isinstance(data.content, str)
            assert response.id
            assert response.created

        for index in range(parallel_sample_num):
            assert not is_firsts.get(
                index, True
            ), f"index {index} is not found in the response"

    def _create_batch(self, mode, client):
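        # Write a small JSONL batch input (completion or chat requests),
        # upload it with purpose="batch", and create a batch job against the
        # matching endpoint; returns the job, request list, and uploaded file.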
        if mode == "completion":
            input_file_path = "complete_input.jsonl"
            # write content to input file
            content = [
                {
                    "custom_id": "request-1",
                    "method": "POST",
                    "url": "/v1/completions",
                    "body": {
                        "model": "gpt-3.5-turbo-instruct",
                        "prompt": "List 3 names of famous soccer player: ",
                        "max_tokens": 20,
                    },
                },
                {
                    "custom_id": "request-2",
                    "method": "POST",
                    "url": "/v1/completions",
                    "body": {
                        "model": "gpt-3.5-turbo-instruct",
                        "prompt": "List 6 names of famous basketball player:  ",
                        "max_tokens": 40,
                    },
                },
                {
                    "custom_id": "request-3",
                    "method": "POST",
                    "url": "/v1/completions",
                    "body": {
                        "model": "gpt-3.5-turbo-instruct",
                        "prompt": "List 6 names of famous tenniss player:  ",
                        "max_tokens": 40,
                    },
                },
            ]

        else:
            input_file_path = "chat_input.jsonl"
            content = [
                {
                    "custom_id": "request-1",
                    "method": "POST",
                    "url": "/v1/chat/completions",
                    "body": {
                        "model": "gpt-3.5-turbo-0125",
                        "messages": [
                            {
                                "role": "system",
                                "content": "You are a helpful assistant.",
                            },
                            {
                                "role": "user",
                                "content": "Hello! List 3 NBA players and tell a story",
                            },
                        ],
                        "max_tokens": 30,
                    },
                },
                {
                    "custom_id": "request-2",
                    "method": "POST",
                    "url": "/v1/chat/completions",
                    "body": {
                        "model": "gpt-3.5-turbo-0125",
                        "messages": [
                            {"role": "system", "content": "You are an assistant. "},
                            {
                                "role": "user",
                                "content": "Hello! List three capital and tell a story",
                            },
                        ],
                        "max_tokens": 50,
                    },
                },
            ]

        with open(input_file_path, "w") as file:
            for line in content:
                file.write(json.dumps(line) + "\n")

        with open(input_file_path, "rb") as file:
            uploaded_file = client.files.create(file=file, purpose="batch")
        if mode == "completion":
            endpoint = "/v1/completions"
        elif mode == "chat":
            endpoint = "/v1/chat/completions"
        completion_window = "24h"
        batch_job = client.batches.create(
            input_file_id=uploaded_file.id,
            endpoint=endpoint,
            completion_window=completion_window,
        )

        return batch_job, content, uploaded_file

    def run_batch(self, mode):
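        # Submit a batch job, poll every 3 seconds until it reaches a terminal
        # state, then check request counts, read back the output file, and
        # delete both the input and output files.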
        client = openai.Client(api_key=self.api_key, base_url=self.base_url)
        batch_job, content, uploaded_file = self._create_batch(mode=mode, client=client)

        while batch_job.status not in ["completed", "failed", "cancelled"]:
            time.sleep(3)
            print(
                f"Batch job status: {batch_job.status}...trying again in 3 seconds..."
            )
            batch_job = client.batches.retrieve(batch_job.id)
        assert (
            batch_job.status == "completed"
        ), f"Batch job status is not completed: {batch_job.status}"
        assert batch_job.request_counts.completed == len(content)
        assert batch_job.request_counts.failed == 0
        assert batch_job.request_counts.total == len(content)

        result_file_id = batch_job.output_file_id
        file_response = client.files.content(result_file_id)
        result_content = file_response.read().decode("utf-8")  # Decode bytes to string
        results = [
            json.loads(line)
            for line in result_content.split("\n")
            if line.strip() != ""
        ]
        assert len(results) == len(content)
        for delete_fid in [uploaded_file.id, result_file_id]:
            del_response = client.files.delete(delete_fid)
            assert del_response.deleted

    def run_cancel_batch(self, mode):
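        # Submit a batch job and cancel it right away, then poll until the job
        # settles in the "cancelled" state and clean up the uploaded input file.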
        client = openai.Client(api_key=self.api_key, base_url=self.base_url)
        batch_job, _, uploaded_file = self._create_batch(mode=mode, client=client)

        assert batch_job.status not in ["cancelling", "cancelled"]

        batch_job = client.batches.cancel(batch_id=batch_job.id)
        assert batch_job.status == "cancelling"

        while batch_job.status not in ["failed", "cancelled"]:
            batch_job = client.batches.retrieve(batch_job.id)
            print(
                f"Batch job status: {batch_job.status}...trying again in 3 seconds..."
            )
            time.sleep(3)

        assert batch_job.status == "cancelled"
        del_response = client.files.delete(uploaded_file.id)
        assert del_response.deleted

    def test_completion(self):
        for echo in [False, True]:
            for logprobs in [None, 5]:
                for use_list_input in [True, False]:
                    for parallel_sample_num in [1, 2]:
                        for token_input in [False, True]:
                            self.run_completion(
                                echo,
                                logprobs,
                                use_list_input,
                                parallel_sample_num,
                                token_input,
                            )

    def test_completion_stream(self):
        # parallel sampling and list input are not supported in streaming mode
        for echo in [False, True]:
            for logprobs in [None, 5]:
                for use_list_input in [True, False]:
                    for parallel_sample_num in [1, 2]:
                        for token_input in [False, True]:
                            self.run_completion_stream(
                                echo,
                                logprobs,
                                use_list_input,
                                parallel_sample_num,
                                token_input,
                            )

    def test_chat_completion(self):
        for logprobs in [None, 5]:
            for parallel_sample_num in [1, 2]:
                self.run_chat_completion(logprobs, parallel_sample_num)

    def test_chat_completion_stream(self):
        for logprobs in [None, 5]:
            for parallel_sample_num in [1, 2]:
                self.run_chat_completion_stream(logprobs, parallel_sample_num)

    def test_batch(self):
        for mode in ["completion", "chat"]:
            self.run_batch(mode)

    def test_cancel_batch(self):
        for mode in ["completion", "chat"]:
            self.run_cancel_batch(mode)

    def test_regex(self):
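        # Constrained decoding: the "regex" field passed via extra_body forces
        # the completion to match a small JSON pattern, which is then parsed
        # and type-checked.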
        client = openai.Client(api_key=self.api_key, base_url=self.base_url)

        regex = (
            r"""\{\n"""
            + r"""   "name": "[\w]+",\n"""
            + r"""   "population": [\d]+\n"""
            + r"""\}"""
        )

        response = client.chat.completions.create(
            model=self.model,
            messages=[
                {"role": "system", "content": "You are a helpful AI assistant"},
                {"role": "user", "content": "Introduce the capital of France."},
            ],
            temperature=0,
            max_tokens=128,
            extra_body={"regex": regex},
        )
        text = response.choices[0].message.content

        try:
            js_obj = json.loads(text)
        except (TypeError, json.decoder.JSONDecodeError):
            print("JSONDecodeError", text)
            raise
        assert isinstance(js_obj["name"], str)
        assert isinstance(js_obj["population"], int)

    def test_penalty(self):
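        # Smoke test: a request with frequency_penalty should be accepted and
        # still return a string completion.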
        client = openai.Client(api_key=self.api_key, base_url=self.base_url)

        response = client.chat.completions.create(
            model=self.model,
            messages=[
                {"role": "system", "content": "You are a helpful AI assistant"},
                {"role": "user", "content": "Introduce the capital of France."},
            ],
            temperature=0,
            max_tokens=32,
            frequency_penalty=1.0,
        )
        text = response.choices[0].message.content
        assert isinstance(text, str)


if __name__ == "__main__":
    unittest.main()