"""Tests for the OpenAI-compatible embeddings API of a launched server."""

import unittest

import openai

from sglang.srt.hf_transformers_utils import get_tokenizer
from sglang.srt.utils import kill_process_tree
from sglang.test.test_utils import (
    DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH,
    DEFAULT_URL_FOR_TEST,
    CustomTestCase,
    popen_launch_server,
)


class TestOpenAIServer(CustomTestCase):
    """Check the embeddings endpoint through the OpenAI Python client.

    One server process is launched for the whole class in ``setUpClass`` and
    killed in ``tearDownClass``; every test talks to that shared server.
    """

    @classmethod
    def setUpClass(cls):
        """Launch the server for the embedding model and load its tokenizer."""
        cls.model = "intfloat/e5-mistral-7b-instruct"
        cls.base_url = DEFAULT_URL_FOR_TEST
        cls.api_key = "sk-123456"
        cls.process = popen_launch_server(
            cls.model,
            cls.base_url,
            timeout=DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH,
            api_key=cls.api_key,
        )
        # Client requests are sent to the `/v1` prefix of the server URL.
        cls.base_url += "/v1"
        try:
            cls.tokenizer = get_tokenizer(cls.model)
        except BaseException:
            # unittest does not run tearDownClass when setUpClass raises, so
            # the already-launched server must be killed here to avoid
            # leaking the process.
            kill_process_tree(cls.process.pid)
            raise

    @classmethod
    def tearDownClass(cls):
        """Kill the server process launched in ``setUpClass``."""
        kill_process_tree(cls.process.pid)

    def run_embedding(self, use_list_input, token_input):
        """Request embeddings for a fixed prompt and validate the response.

        Args:
            use_list_input: If true, send a list holding the prompt twice;
                otherwise send the prompt on its own.
            token_input: If true, send the prompt as the token ids produced
                by ``self.tokenizer``; otherwise send the raw string.
        """
        client = openai.Client(api_key=self.api_key, base_url=self.base_url)
        prompt = "The capital of France is"

        # The token count is needed in both modes to check the usage report.
        token_ids = self.tokenizer.encode(prompt)
        prompt_input = token_ids if token_input else prompt
        num_prompt_tokens = len(token_ids)

        if use_list_input:
            prompt_arg = [prompt_input] * 2
            num_prompts = len(prompt_arg)
            # Usage is expected to count the tokens of every prompt sent.
            num_prompt_tokens *= num_prompts
        else:
            prompt_arg = prompt_input
            num_prompts = 1

        response = client.embeddings.create(
            input=prompt_arg,
            model=self.model,
        )

        self.assertIsInstance(response.data, list)
        self.assertEqual(len(response.data), num_prompts)
        first = response.data[0]
        self.assertTrue(first.embedding)
        self.assertIsNotNone(first.index)
        self.assertEqual(first.object, "embedding")
        self.assertEqual(response.model, self.model)
        self.assertEqual(response.object, "list")
        # The test expects total_tokens to equal prompt_tokens.
        self.assertEqual(response.usage.prompt_tokens, num_prompt_tokens)
        self.assertEqual(response.usage.total_tokens, num_prompt_tokens)

    def run_batch(self):
        """Placeholder for the batch test; currently does nothing."""
        # FIXME: not implemented

    def test_embedding(self):
        """Run ``run_embedding`` for every input-shape combination."""
        # TODO: the fields of encoding_format, dimensions, user are skipped
        for use_list_input in [False, True]:
            for token_input in [False, True]:
                # subTest reports each combination separately and lets the
                # remaining ones run after a failure.
                with self.subTest(
                    use_list_input=use_list_input, token_input=token_input
                ):
                    self.run_embedding(use_list_input, token_input)

    def test_batch(self):
        """Run the (not yet implemented) batch check."""
        self.run_batch()


if __name__ == "__main__":
    # Run this module's tests when the file is executed directly.
    unittest.main()