"ts/nni_manager/common/trainingService.ts" did not exist on "c2a4ce6cf7cbbf462529531fa59a557d3ed58aa9"
test_openai_server.py 15 KB
Newer Older
1
import json
import os
import time
import unittest

import openai

from sglang.srt.hf_transformers_utils import get_tokenizer
from sglang.srt.utils import kill_child_process
from sglang.test.test_utils import (
    DEFAULT_MODEL_NAME_FOR_TEST,
    DEFAULT_URL_FOR_TEST,
    popen_launch_server,
)
14
15
16
17
18
19


class TestOpenAIServer(unittest.TestCase):

    @classmethod
    def setUpClass(cls):
        """Prepare the class-level fixtures shared by every test method.

        Sets ``cls.model``, ``cls.api_key``, ``cls.tokenizer``, ``cls.process``
        (the value returned by ``popen_launch_server``) and ``cls.base_url``
        (the launch URL with ``/v1`` appended).
        """
        cls.model = DEFAULT_MODEL_NAME_FOR_TEST
        cls.base_url = DEFAULT_URL_FOR_TEST
        cls.api_key = "sk-123456"
        # Load the tokenizer before launching the server: unittest does not
        # run tearDownClass when setUpClass raises, so a tokenizer failure
        # after the launch would leave the launched process running.
        cls.tokenizer = get_tokenizer(DEFAULT_MODEL_NAME_FOR_TEST)
        cls.process = popen_launch_server(
            cls.model, cls.base_url, timeout=300, api_key=cls.api_key
        )
        # The launch uses the bare URL; the clients built in this class use
        # the "/v1" prefix.
        cls.base_url += "/v1"
28
29
30
31
32

    @classmethod
    def tearDownClass(cls):
        """Pass the pid of ``cls.process`` to ``kill_child_process``."""
        server_pid = cls.process.pid
        kill_child_process(server_pid)

yichuan~'s avatar
yichuan~ committed
33
34
35
    def run_completion(
        self, echo, logprobs, use_list_input, parallel_sample_num, token_input
    ):
        """Send one non-streaming completions request and validate the reply.

        Args:
            echo: Forwarded as ``echo``. When true, the first choice's text
                must start with the prompt.
            logprobs: Forwarded as ``logprobs``. When truthy, the first
                choice's logprob fields are checked.
            use_list_input: When true, the prompt is sent twice as a
                two-element list, which doubles the expected choice count
                and the expected prompt token count.
            parallel_sample_num: Forwarded as ``n``.
            token_input: When true, the prompt is sent as the token ids
                produced by ``self.tokenizer.encode`` instead of as text.
        """
        client = openai.Client(api_key=self.api_key, base_url=self.base_url)
        prompt = "The capital of France is"

        if token_input:
            prompt_input = self.tokenizer.encode(prompt)
            num_prompt_tokens = len(prompt_input)
        else:
            prompt_input = prompt
            num_prompt_tokens = len(self.tokenizer.encode(prompt))

        if use_list_input:
            prompt_arg = [prompt_input, prompt_input]
            num_choices = len(prompt_arg)
            num_prompt_tokens *= 2
        else:
            prompt_arg = prompt_input
            num_choices = 1

        response = client.completions.create(
            model=self.model,
            prompt=prompt_arg,
            temperature=0,
            max_tokens=32,
            echo=echo,
            logprobs=logprobs,
            n=parallel_sample_num,
        )

        assert len(response.choices) == num_choices * parallel_sample_num

        if echo:
            text = response.choices[0].text
            assert text.startswith(prompt)

        if logprobs:
            choice_logprobs = response.choices[0].logprobs
            assert choice_logprobs
            assert isinstance(choice_logprobs.tokens[0], str)
            assert isinstance(choice_logprobs.top_logprobs[1], dict)
            ret_num_top_logprobs = len(choice_logprobs.top_logprobs[1])
            # FIXME: some top_logprobs entries can be missing from the return
            # value because several output ids map to the same output token
            # and collide in the map, so the exact check
            #   ret_num_top_logprobs == logprobs
            # cannot be asserted yet.
            assert ret_num_top_logprobs > 0
            # The first token_logprobs entry is expected to be None exactly
            # when the prompt is echoed.
            if echo:
                assert choice_logprobs.token_logprobs[0] is None
            else:
                assert choice_logprobs.token_logprobs[0] is not None

        assert response.id
        assert response.created
        assert (
            response.usage.prompt_tokens == num_prompt_tokens
        ), f"{response.usage.prompt_tokens} vs {num_prompt_tokens}"
        assert response.usage.completion_tokens > 0
        assert response.usage.total_tokens > 0

yichuan~'s avatar
yichuan~ committed
90
    def run_completion_stream(self, echo, logprobs, token_input):
        """Stream one completions request and validate each received chunk.

        Args:
            echo: Forwarded as ``echo``. When true, the first content chunk's
                text must start with the prompt.
            logprobs: Forwarded as ``logprobs``. When truthy, each content
                chunk's logprob fields are checked.
            token_input: When true, the prompt is sent as the token ids
                produced by ``self.tokenizer.encode`` instead of as text.
        """
        client = openai.Client(api_key=self.api_key, base_url=self.base_url)
        prompt = "The capital of France is"
        prompt_arg = self.tokenizer.encode(prompt) if token_input else prompt

        stream = client.completions.create(
            model=self.model,
            prompt=prompt_arg,
            temperature=0,
            max_tokens=32,
            echo=echo,
            logprobs=logprobs,
            stream=True,
            stream_options={"include_usage": True},
        )

        awaiting_first_chunk = True
        for chunk in stream:
            # Chunks that carry usage are checked for token counts only and
            # are otherwise skipped.
            chunk_usage = chunk.usage
            if chunk_usage is not None:
                assert chunk_usage.prompt_tokens > 0
                assert chunk_usage.completion_tokens > 0
                assert chunk_usage.total_tokens > 0
                continue

            if logprobs:
                chunk_logprobs = chunk.choices[0].logprobs
                assert chunk_logprobs
                assert isinstance(chunk_logprobs.tokens[0], str)
                # top_logprobs is not inspected on the first chunk of an
                # echoed response.
                if not awaiting_first_chunk or not echo:
                    top_entry = chunk_logprobs.top_logprobs[0]
                    assert isinstance(top_entry, dict)
                    ret_num_top_logprobs = len(top_entry)
                    # FIXME: some top_logprobs entries can be missing from
                    # the return value because several output ids map to the
                    # same output token and collide in the map, so the exact
                    # check ret_num_top_logprobs == logprobs is not asserted.
                    assert ret_num_top_logprobs > 0

            if awaiting_first_chunk:
                if echo:
                    assert chunk.choices[0].text.startswith(
                        prompt
                    ), f"{chunk.choices[0].text} and all args {echo} {logprobs} {token_input} {awaiting_first_chunk}"
                awaiting_first_chunk = False
            assert chunk.id
            assert chunk.created

yichuan~'s avatar
yichuan~ committed
139
    def run_chat_completion(self, logprobs, parallel_sample_num):
        """Send one non-streaming chat completion and validate the reply.

        Args:
            logprobs: Forwarded as ``top_logprobs``; ``logprobs`` is enabled
                on the request only when this is a positive number. When
                truthy, the first choice must return exactly this many top
                logprobs for its first content token.
            parallel_sample_num: Forwarded as ``n``; also the expected number
                of choices.
        """
        client = openai.Client(api_key=self.api_key, base_url=self.base_url)
        messages = [
            {"role": "system", "content": "You are a helpful AI assistant"},
            {
                "role": "user",
                "content": "What is the capital of France? Answer in a few words.",
            },
        ]
        want_logprobs = logprobs is not None and logprobs > 0
        response = client.chat.completions.create(
            model=self.model,
            messages=messages,
            temperature=0,
            logprobs=want_logprobs,
            top_logprobs=logprobs,
            n=parallel_sample_num,
        )

        if logprobs:
            top_entries = response.choices[0].logprobs.content[0].top_logprobs
            assert isinstance(top_entries[0].token, str)

            ret_num_top_logprobs = len(top_entries)
            assert (
                ret_num_top_logprobs == logprobs
            ), f"{ret_num_top_logprobs} vs {logprobs}"

        assert len(response.choices) == parallel_sample_num

        first_message = response.choices[0].message
        assert first_message.role == "assistant"
        assert isinstance(first_message.content, str)
        assert response.id
        assert response.created

        usage = response.usage
        assert usage.prompt_tokens > 0
        assert usage.completion_tokens > 0
        assert usage.total_tokens > 0

    def run_chat_completion_stream(self, logprobs):
        """Stream one chat completion and validate each received chunk.

        Args:
            logprobs: Forwarded as ``top_logprobs``; ``logprobs`` is enabled
                on the request only when this is a positive number. When
                truthy, every content chunk after the first must return
                exactly this many top logprobs for its first content token.
        """
        client = openai.Client(api_key=self.api_key, base_url=self.base_url)
        generator = client.chat.completions.create(
            model=self.model,
            messages=[
                {"role": "system", "content": "You are a helpful AI assistant"},
                {"role": "user", "content": "What is the capital of France?"},
            ],
            temperature=0,
            logprobs=logprobs is not None and logprobs > 0,
            top_logprobs=logprobs,
            stream=True,
            stream_options={"include_usage": True},
        )

        is_first = True
        for response in generator:
            # Chunks that carry usage are checked for token counts only and
            # are otherwise skipped.
            usage = response.usage
            if usage is not None:
                assert usage.prompt_tokens > 0
                assert usage.completion_tokens > 0
                assert usage.total_tokens > 0
                continue

            data = response.choices[0].delta

            if is_first:
                # The original compared the role without asserting it, so
                # the check never ran. The first chunk is checked for its
                # role only; its content is not inspected.
                assert data.role == "assistant"
                is_first = False
                continue

            if logprobs:
                assert response.choices[0].logprobs
                top_entries = response.choices[0].logprobs.content[0].top_logprobs
                assert isinstance(top_entries[0].token, str)
                assert isinstance(top_entries, list)
                ret_num_top_logprobs = len(top_entries)
                assert (
                    ret_num_top_logprobs == logprobs
                ), f"{ret_num_top_logprobs} vs {logprobs}"

            assert isinstance(data.content, str)
            assert response.id
            assert response.created

227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
    def run_batch(self, mode, max_wait_seconds=600):
        """Submit a small batch job and check that every request completes.

        The requests are written to a JSONL file, uploaded with
        ``client.files.create``, submitted with ``client.batches.create`` and
        polled until the job reaches a terminal status. The output file must
        then contain one JSON line per submitted request.

        Args:
            mode: ``"completion"`` to batch ``/v1/completions`` requests or
                ``"chat"`` to batch ``/v1/chat/completions`` requests.
            max_wait_seconds: Upper bound on the time spent polling before
                the test fails instead of hanging.

        Raises:
            ValueError: If ``mode`` is neither ``"completion"`` nor ``"chat"``.
        """
        # Reject unknown modes up front; previously they fell through to the
        # chat payload and then failed on an unbound ``endpoint``.
        if mode not in ("completion", "chat"):
            raise ValueError(f"Unsupported batch mode: {mode!r}")

        client = openai.Client(api_key=self.api_key, base_url=self.base_url)
        if mode == "completion":
            input_file_path = "complete_input.jsonl"
            endpoint = "/v1/completions"
            content = [
                {
                    "custom_id": "request-1",
                    "method": "POST",
                    "url": "/v1/completions",
                    "body": {
                        "model": "gpt-3.5-turbo-instruct",
                        "prompt": "List 3 names of famous soccer player: ",
                        "max_tokens": 20,
                    },
                },
                {
                    "custom_id": "request-2",
                    "method": "POST",
                    "url": "/v1/completions",
                    "body": {
                        "model": "gpt-3.5-turbo-instruct",
                        "prompt": "List 6 names of famous basketball player:  ",
                        "max_tokens": 40,
                    },
                },
                {
                    "custom_id": "request-3",
                    "method": "POST",
                    "url": "/v1/completions",
                    "body": {
                        "model": "gpt-3.5-turbo-instruct",
                        "prompt": "List 6 names of famous tenniss player:  ",
                        "max_tokens": 40,
                    },
                },
            ]
        else:
            input_file_path = "chat_input.jsonl"
            endpoint = "/v1/chat/completions"
            content = [
                {
                    "custom_id": "request-1",
                    "method": "POST",
                    "url": "/v1/chat/completions",
                    "body": {
                        "model": "gpt-3.5-turbo-0125",
                        "messages": [
                            {
                                "role": "system",
                                "content": "You are a helpful assistant.",
                            },
                            {
                                "role": "user",
                                "content": "Hello! List 3 NBA players and tell a story",
                            },
                        ],
                        "max_tokens": 30,
                    },
                },
                {
                    "custom_id": "request-2",
                    "method": "POST",
                    "url": "/v1/chat/completions",
                    "body": {
                        "model": "gpt-3.5-turbo-0125",
                        "messages": [
                            {"role": "system", "content": "You are an assistant. "},
                            {
                                "role": "user",
                                "content": "Hello! List three capital and tell a story",
                            },
                        ],
                        "max_tokens": 50,
                    },
                },
            ]

        # Write one JSON request per line, upload the file, and always remove
        # it afterwards so the test leaves nothing in the working directory.
        try:
            with open(input_file_path, "w", encoding="utf-8") as file:
                file.writelines(json.dumps(request) + "\n" for request in content)
            with open(input_file_path, "rb") as file:
                uploaded_file = client.files.create(file=file, purpose="batch")
        finally:
            if os.path.exists(input_file_path):
                os.remove(input_file_path)

        completion_window = "24h"
        batch_job = client.batches.create(
            input_file_id=uploaded_file.id,
            endpoint=endpoint,
            completion_window=completion_window,
        )

        # "expired" is also treated as terminal so the loop cannot spin
        # forever on it; the deadline bounds every other stalled status.
        terminal_statuses = ("completed", "failed", "expired", "cancelled")
        deadline = time.monotonic() + max_wait_seconds
        while batch_job.status not in terminal_statuses:
            assert (
                time.monotonic() < deadline
            ), f"Batch job still {batch_job.status} after {max_wait_seconds} seconds"
            # Print before sleeping so the reported status is the one that
            # was just observed.
            print(
                f"Batch job status: {batch_job.status}...trying again in 3 seconds..."
            )
            time.sleep(3)
            batch_job = client.batches.retrieve(batch_job.id)

        assert batch_job.status == "completed"
        assert batch_job.request_counts.completed == len(content)
        assert batch_job.request_counts.failed == 0
        assert batch_job.request_counts.total == len(content)

        result_file_id = batch_job.output_file_id
        file_response = client.files.content(result_file_id)
        result_content = file_response.read().decode("utf-8")  # Decode bytes to string
        results = [
            json.loads(line)
            for line in result_content.split("\n")
            if line.strip() != ""
        ]
        assert len(results) == len(content)

340
341
342
    def test_completion(self):
        """Call ``run_completion`` for every combination of its five arguments."""
        # The nesting order matches the argument order, so the last argument
        # varies fastest.
        combinations = [
            (echo, logprobs, use_list_input, parallel_sample_num, token_input)
            for echo in (False, True)
            for logprobs in (None, 5)
            for use_list_input in (True, False)
            for parallel_sample_num in (1, 2)
            for token_input in (False, True)
        ]
        for combination in combinations:
            self.run_completion(*combination)
353
354

    def test_completion_stream(self):
        """Call ``run_completion_stream`` for every argument combination."""
        # Parallel sampling and list input are not supported in streaming
        # mode, so only echo, logprobs and token_input are varied.
        combinations = [
            (echo, logprobs, token_input)
            for echo in (False, True)
            for logprobs in (None, 5)
            for token_input in (False, True)
        ]
        for combination in combinations:
            self.run_completion_stream(*combination)
360

361
362
    def test_chat_completion(self):
        """Call ``run_chat_completion`` with and without logprobs, for n in 1, 2."""
        combinations = [
            (logprobs, parallel_sample_num)
            for logprobs in (None, 5)
            for parallel_sample_num in (1, 2)
        ]
        for logprobs, parallel_sample_num in combinations:
            self.run_chat_completion(logprobs, parallel_sample_num)
365
366
367
368
369

    def test_chat_completion_stream(self):
        """Call ``run_chat_completion_stream`` without and then with logprobs."""
        self.run_chat_completion_stream(None)
        self.run_chat_completion_stream(5)

370
371
372
373
    def test_batch(self):
        """Call ``run_batch`` in completion mode and then in chat mode."""
        self.run_batch("completion")
        self.run_batch("chat")

374
    def test_regex(self):
        """Request a chat completion with a ``regex`` entry in ``extra_body``.

        The reply text must parse as JSON whose ``name`` is a string and
        whose ``population`` is an integer.
        """
        client = openai.Client(api_key=self.api_key, base_url=self.base_url)

        # Pattern pieces for a two-field JSON object, joined unchanged.
        pattern_parts = [
            r"""\{\n""",
            r"""   "name": "[\w]+",\n""",
            r"""   "population": [\d]+\n""",
            r"""\}""",
        ]
        regex = "".join(pattern_parts)

        messages = [
            {"role": "system", "content": "You are a helpful AI assistant"},
            {"role": "user", "content": "Introduce the capital of France."},
        ]
        response = client.chat.completions.create(
            model=self.model,
            messages=messages,
            temperature=0,
            max_tokens=128,
            extra_body={"regex": regex},
        )
        text = response.choices[0].message.content

        try:
            js_obj = json.loads(text)
        except (TypeError, json.decoder.JSONDecodeError):
            # Show the offending text before re-raising the error.
            print("JSONDecodeError", text)
            raise

        name = js_obj["name"]
        population = js_obj["population"]
        assert isinstance(name, str)
        assert isinstance(population, int)

404

405
# Run the tests in this module when the file is executed directly.
if __name__ == "__main__":
    unittest.main()