# test_openai_server.py
import subprocess
import time
import unittest

import openai
import requests

from sglang.srt.utils import kill_child_process


class TestOpenAIServer(unittest.TestCase):
    """Exercise the OpenAI-compatible completions endpoint of a local server.

    ``setUpClass`` launches ``sglang.launch_server`` as a subprocess and polls
    ``/v1/models`` until it answers with HTTP 200; ``tearDownClass`` kills the
    process. The tests talk to the server through the ``openai`` client.
    """

    @classmethod
    def setUpClass(cls):
        """Start the server subprocess and block until it is reachable.

        Raises:
            TimeoutError: if ``/v1/models`` does not return HTTP 200 within
                the startup timeout. The server process is killed first.
        """
        model = "meta-llama/Meta-Llama-3.1-8B-Instruct"
        port = 30000
        timeout = 300  # seconds allowed for the server to come up
        poll_interval = 10  # seconds between readiness probes

        command = [
            "python3", "-m", "sglang.launch_server",
            "--model-path", model,
            "--host", "localhost",
            "--port", str(port),
        ]
        # stdout/stderr are inherited so server logs show up in the test output.
        cls.process = subprocess.Popen(command, stdout=None, stderr=None)
        cls.base_url = f"http://localhost:{port}/v1"
        cls.model = model

        # A monotonic clock keeps the deadline immune to wall-clock adjustments.
        deadline = time.monotonic() + timeout
        while time.monotonic() < deadline:
            try:
                # Bound each probe so one hung connection cannot overrun the
                # overall startup deadline.
                response = requests.get(
                    f"{cls.base_url}/models", timeout=poll_interval
                )
                if response.status_code == 200:
                    return
            except requests.RequestException:
                # Server is not accepting connections yet; keep polling.
                pass
            time.sleep(poll_interval)

        # unittest does not call tearDownClass when setUpClass raises, so the
        # server has to be stopped here or it would outlive the test run.
        kill_child_process(cls.process.pid)
        raise TimeoutError("Server failed to start within the timeout period.")

    @classmethod
    def tearDownClass(cls):
        """Kill the server process started by ``setUpClass``."""
        kill_child_process(cls.process.pid)

    def run_completion(self, echo, logprobs):
        """Issue one non-streaming completion request and validate the reply.

        Args:
            echo: passed through as the ``echo`` request parameter; when truthy
                the returned text must start with the prompt.
            logprobs: passed through as the ``logprobs`` request parameter;
                when truthy the logprob fields of the first choice are checked.
        """
        client = openai.Client(api_key="EMPTY", base_url=self.base_url)
        prompt = "The capital of France is"
        response = client.completions.create(
            model=self.model,
            prompt=prompt,
            temperature=0.1,
            max_tokens=32,
            echo=echo,
            logprobs=logprobs,
        )
        choice = response.choices[0]

        if echo:
            self.assertTrue(choice.text.startswith(prompt))

        if logprobs:
            self.assertTrue(choice.logprobs)
            self.assertIsInstance(choice.logprobs.tokens[0], str)
            # Position 1 is inspected rather than 0 -- presumably because the
            # first position carries no logprob when the prompt is echoed (see
            # the token_logprobs[0] check below). TODO confirm.
            self.assertIsInstance(choice.logprobs.top_logprobs[1], dict)
            self.assertEqual(len(choice.logprobs.top_logprobs[1]), logprobs)
            if echo:
                self.assertIsNone(choice.logprobs.token_logprobs[0])
            else:
                self.assertIsNotNone(choice.logprobs.token_logprobs[0])

        self.assertTrue(response.id)
        self.assertTrue(response.created)
        self.assertGreater(response.usage.prompt_tokens, 0)
        self.assertGreater(response.usage.completion_tokens, 0)
        self.assertGreater(response.usage.total_tokens, 0)

    def run_completion_stream(self, echo, logprobs):
        """Issue one streaming completion request and validate every chunk.

        Args:
            echo: passed through as the ``echo`` request parameter; when truthy
                the first chunk's text must start with the prompt.
            logprobs: passed through as the ``logprobs`` request parameter;
                when truthy the logprob fields of each chunk are checked.
        """
        client = openai.Client(api_key="EMPTY", base_url=self.base_url)
        prompt = "The capital of France is"
        generator = client.completions.create(
            model=self.model,
            prompt=prompt,
            temperature=0.1,
            max_tokens=32,
            echo=echo,
            logprobs=logprobs,
            stream=True,
        )

        first = True
        for response in generator:
            choice = response.choices[0]

            if logprobs:
                self.assertTrue(choice.logprobs)
                self.assertIsInstance(choice.logprobs.tokens[0], str)
                # The top_logprobs check is skipped for the first chunk of an
                # echoed request.
                if not (first and echo):
                    self.assertIsInstance(choice.logprobs.top_logprobs[0], dict)
                    # TODO: the per-chunk length check below is disabled in
                    # this test; re-enable it once it is expected to hold.
                    # self.assertEqual(
                    #     len(choice.logprobs.top_logprobs[0]), logprobs
                    # )

            if first:
                if echo:
                    self.assertTrue(choice.text.startswith(prompt))
                first = False

            self.assertTrue(response.id)
            self.assertTrue(response.created)
            self.assertGreater(response.usage.prompt_tokens, 0)
            self.assertGreater(response.usage.completion_tokens, 0)
            self.assertGreater(response.usage.total_tokens, 0)

    def test_completion(self):
        """Run the non-streaming check for every echo/logprobs combination."""
        for echo in [False, True]:
            for logprobs in [None, 5]:
                # subTest labels a failure with the parameters that caused it.
                with self.subTest(echo=echo, logprobs=logprobs):
                    self.run_completion(echo, logprobs)

    def test_completion_stream(self):
        """Run the streaming check (currently only echo=True, logprobs=5)."""
        for echo in [True]:
            for logprobs in [5]:
                with self.subTest(echo=echo, logprobs=logprobs):
                    self.run_completion_stream(echo, logprobs)


if __name__ == "__main__":
    # unittest.main(warnings="ignore")

    # Manual run of the streaming test only, bypassing the unittest runner.
    t = TestOpenAIServer()
    t.setUpClass()
    try:
        t.test_completion_stream()
    finally:
        # Always stop the server, even when the test raises; otherwise the
        # launched process would keep running after this script exits.
        t.tearDownClass()