# test_large_max_new_tokens.py
import json
import os
import time
import unittest
from concurrent.futures import ThreadPoolExecutor

import openai

from sglang.srt.hf_transformers_utils import get_tokenizer
from sglang.srt.utils import kill_child_process
from sglang.test.test_utils import DEFAULT_MODEL_NAME_FOR_TEST, popen_launch_server


class TestOpenAIServer(unittest.TestCase):
    """Check that several long-output chat requests run on the server at once.

    The class launches one server process for all tests with
    ``--max-total-token 1024`` and ``SGLANG_CLIP_MAX_NEW_TOKENS=256``, sends
    a batch of chat completion requests through the OpenAI-compatible
    endpoint, and scans the server's stderr for a line reporting that all
    of them are running simultaneously.

    NOTE(review): presumably the environment variable clips the default
    generation length so the batch fits in the small token budget; confirm
    against the server implementation.
    """

    @classmethod
    def setUpClass(cls):
        """Launch the server under test and store its connection settings."""
        cls.model = DEFAULT_MODEL_NAME_FOR_TEST
        cls.base_url = "http://127.0.0.1:8157"
        cls.api_key = "sk-123456"
        cls.process = popen_launch_server(
            cls.model,
            cls.base_url,
            timeout=300,
            api_key=cls.api_key,
            other_args=("--max-total-token", "1024"),
            # os.environ is unpacked last, so a value already exported in the
            # caller's environment overrides the "256" default given here.
            env={"SGLANG_CLIP_MAX_NEW_TOKENS": "256", **os.environ},
            return_stdout_stderr=True,
        )
        # The OpenAI client talks to the "/v1" prefix of the server address.
        cls.base_url += "/v1"
        try:
            cls.tokenizer = get_tokenizer(DEFAULT_MODEL_NAME_FOR_TEST)
        except Exception:
            # unittest does not call tearDownClass when setUpClass raises, so
            # stop the already-running server here instead of leaking it.
            kill_child_process(cls.process.pid)
            raise

    @classmethod
    def tearDownClass(cls):
        """Stop the server process started in ``setUpClass``."""
        kill_child_process(cls.process.pid)

    def run_chat_completion(self):
        """Send one chat completion request and return the client's response.

        The request does not pass ``max_tokens``; the output length is left
        to whatever limit the server applies.
        """
        client = openai.Client(api_key=self.api_key, base_url=self.base_url)
        response = client.chat.completions.create(
            model=self.model,
            messages=[
                {"role": "system", "content": "You are a helpful AI assistant"},
                {
                    "role": "user",
                    "content": "Please repeat the world 'hello' for 10000 times.",
                },
            ],
            temperature=0,
        )
        return response

    def test_chat_completion(self):
        """Submit the requests concurrently and wait for the running marker.

        The test passes once a stderr line contains
        ``#running-req: <num_requests>`` and fails if the stream ends
        without such a line.
        """
        num_requests = 4
        marker = f"#running-req: {num_requests}"
        all_requests_running = False

        with ThreadPoolExecutor(16) as executor:
            # The futures' results are not inspected; leaving the ``with``
            # block only waits for the submitted requests to finish.
            futures = [
                executor.submit(self.run_chat_completion)
                for _ in range(num_requests)
            ]

            # Iterating the stream stops at EOF for both text and binary
            # pipes, unlike a readline loop with a "" sentinel.
            for raw_line in self.process.stderr:
                if isinstance(raw_line, bytes):
                    line = raw_line.decode(errors="replace")
                else:
                    line = str(raw_line)
                print(line, end="")
                if marker in line:
                    all_requests_running = True
                    break

        self.assertTrue(
            all_requests_running,
            f"server stderr never reported '{marker}'",
        )


# Run the tests in this module when the file is executed as a script.
if __name__ == "__main__":
    unittest.main()