test_embedding_openai_server.py 2.68 KB
Newer Older
Ying Sheng's avatar
Ying Sheng committed
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
import json
import time
import unittest

import openai

from sglang.srt.hf_transformers_utils import get_tokenizer
from sglang.srt.openai_api.protocol import EmbeddingObject
from sglang.srt.utils import kill_child_process
from sglang.test.test_utils import popen_launch_server


class TestOpenAIServer(unittest.TestCase):
    """End-to-end tests for the server's OpenAI-compatible embeddings endpoint.

    One server process for an embedding model is launched in ``setUpClass``
    and shared by every test method; it is killed in ``tearDownClass``.
    """

    @classmethod
    def setUpClass(cls):
        cls.model = "intfloat/e5-mistral-7b-instruct"
        cls.base_url = "http://127.0.0.1:8157"
        cls.api_key = "sk-123456"
        # Load the tokenizer BEFORE launching the server: if setUpClass raises
        # after the server has started, unittest does not run tearDownClass
        # and the child process would be leaked.
        cls.tokenizer = get_tokenizer(cls.model)
        cls.process = popen_launch_server(
            cls.model, cls.base_url, timeout=300, api_key=cls.api_key
        )
        # The OpenAI client talks to the versioned API root.
        cls.base_url += "/v1"

    @classmethod
    def tearDownClass(cls):
        kill_child_process(cls.process.pid)

    def run_embedding(self, use_list_input, token_input):
        """Request embeddings for a fixed prompt and validate the response.

        Args:
            use_list_input: If True, send the prompt twice as a list (a batch
                of two inputs); otherwise send a single input.
            token_input: If True, send the prompt as pre-tokenized ids
                (from ``self.tokenizer.encode``) instead of raw text.
        """
        client = openai.Client(api_key=self.api_key, base_url=self.base_url)
        prompt = "The capital of France is"
        prompt_ids = self.tokenizer.encode(prompt)
        prompt_input = prompt_ids if token_input else prompt
        num_prompt_tokens = len(prompt_ids)

        if use_list_input:
            prompt_arg = [prompt_input, prompt_input]
            num_prompts = len(prompt_arg)
        else:
            prompt_arg = prompt_input
            num_prompts = 1

        response = client.embeddings.create(
            input=prompt_arg,
            model=self.model,
        )

        self.assertIsInstance(response.data, list)
        self.assertEqual(len(response.data), num_prompts)
        # Validate every returned embedding, not just the first one.
        for item in response.data:
            self.assertTrue(item.embedding)
            self.assertIsNotNone(item.index)
            self.assertEqual(item.object, "embedding")
        self.assertEqual(response.model, self.model)
        self.assertEqual(response.object, "list")
        # Usage counts the whole request, so a batch of N copies of the
        # prompt consumes N times the tokens of a single prompt.
        expected_tokens = num_prompt_tokens * num_prompts
        self.assertEqual(response.usage.prompt_tokens, expected_tokens)
        self.assertEqual(response.usage.total_tokens, expected_tokens)

    def run_batch(self):
        # FIXME: the batch test is not implemented yet. Report a skip rather
        # than letting test_batch pass without checking anything.
        self.skipTest("batch test not implemented")

    def test_embedding(self):
        # TODO: the encoding_format, dimensions and user fields are not exercised.
        # TODO: add use_list_input=True once list input is supported.
        for use_list_input in [False]:
            for token_input in [False, True]:
                with self.subTest(
                    use_list_input=use_list_input, token_input=token_input
                ):
                    self.run_embedding(use_list_input, token_input)

    def test_batch(self):
        self.run_batch()


if __name__ == "__main__":
    # Run the suite directly; warnings="ignore" sets the warnings filter so
    # warnings raised during the tests are not shown.
    unittest.main(warnings="ignore")