# test_embedding_openai_server.py
import unittest

import openai

from sglang.srt.hf_transformers_utils import get_tokenizer
from sglang.srt.utils import kill_child_process
from sglang.test.test_utils import (
    DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH,
    DEFAULT_URL_FOR_TEST,
    popen_launch_server,
)


class TestOpenAIServer(unittest.TestCase):
    """Tests for the OpenAI-compatible embeddings endpoint.

    One server process is launched for the whole class in ``setUpClass`` and
    stopped in ``tearDownClass``; requests are sent with the ``openai`` client.
    """

    @classmethod
    def setUpClass(cls):
        cls.model = "intfloat/e5-mistral-7b-instruct"
        cls.base_url = DEFAULT_URL_FOR_TEST
        cls.api_key = "sk-123456"
        # Load the tokenizer *before* launching the server: if setUpClass
        # raises after the launch, unittest does not run tearDownClass and the
        # server process would be leaked.
        cls.tokenizer = get_tokenizer(cls.model)
        cls.process = popen_launch_server(
            cls.model,
            cls.base_url,
            timeout=DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH,
            api_key=cls.api_key,
        )
        # Point the OpenAI client at the "/v1" API prefix.
        cls.base_url += "/v1"

    @classmethod
    def tearDownClass(cls):
        # include_self=True: presumably also terminates the launched process
        # itself, not only its children (per the flag name).
        kill_child_process(cls.process.pid, include_self=True)

    def run_embedding(self, use_list_input, token_input):
        """Send one embeddings request and validate the response.

        Args:
            use_list_input: If True, send the input twice as a list (two
                inputs in one request); otherwise send a single input.
            token_input: If True, send the prompt as token ids produced by the
                model's tokenizer instead of as raw text.
        """
        client = openai.Client(api_key=self.api_key, base_url=self.base_url)
        prompt = "The capital of France is"
        token_ids = self.tokenizer.encode(prompt)
        prompt_input = token_ids if token_input else prompt
        num_prompt_tokens = len(token_ids)

        if use_list_input:
            prompt_arg = [prompt_input] * 2
            num_prompts = len(prompt_arg)
            # Usage is expected to be reported summed over all inputs.
            num_prompt_tokens *= num_prompts
        else:
            prompt_arg = prompt_input
            num_prompts = 1

        response = client.embeddings.create(
            input=prompt_arg,
            model=self.model,
        )

        # Use unittest assertions rather than bare `assert`, which is stripped
        # under `python -O` and would make this test pass vacuously.
        self.assertIsInstance(response.data, list)
        self.assertEqual(len(response.data), num_prompts)
        # Validate every returned item, not just the first, so that batched
        # (list) requests are fully checked.
        for i, item in enumerate(response.data):
            self.assertTrue(item.embedding, f"data[{i}].embedding is empty")
            self.assertIsNotNone(item.index, f"data[{i}].index is None")
            self.assertEqual(item.object, "embedding", f"data[{i}].object")
        self.assertEqual(response.model, self.model)
        self.assertEqual(response.object, "list")
        # Embedding requests consume only prompt tokens, so the expected
        # total equals the prompt-token count.
        self.assertEqual(response.usage.prompt_tokens, num_prompt_tokens)
        self.assertEqual(response.usage.total_tokens, num_prompt_tokens)

    def run_batch(self):
        # FIXME: not implemented. Skip explicitly so the test is reported as
        # skipped instead of as a pass that asserts nothing.
        self.skipTest("batch test is not implemented")

    def test_embedding(self):
        """Cover every combination of single/list and text/token input."""
        # TODO: the fields of encoding_format, dimensions, user are skipped
        for use_list_input in (False, True):
            for token_input in (False, True):
                # subTest reports exactly which combination failed.
                with self.subTest(
                    use_list_input=use_list_input, token_input=token_input
                ):
                    self.run_embedding(use_list_input, token_input)

    def test_batch(self):
        self.run_batch()


if __name__ == "__main__":
    # Run this module's tests when executed directly as a script.
    unittest.main()