# test_fim_completion.py
import unittest

import openai

from sglang.srt.utils import kill_process_tree
6
from sglang.srt.utils.hf_transformers_utils import get_tokenizer
7
8
9
from sglang.test.test_utils import (
    DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH,
    DEFAULT_URL_FOR_TEST,
10
    CustomTestCase,
11
12
13
14
    popen_launch_server,
)


15
class TestFimCompletion(CustomTestCase):
    """End-to-end test of fill-in-the-middle (FIM) text completions.

    Launches a server for ``deepseek-ai/deepseek-coder-1.3b-base`` with the
    ``deepseek_coder`` completion template. It then calls the OpenAI-compatible
    completions endpoint with both a ``prompt`` and a ``suffix``, and checks
    the shape and token accounting of the response.
    """

    @classmethod
    def setUpClass(cls):
        """Launch the server once for every test in this class."""
        cls.model = "deepseek-ai/deepseek-coder-1.3b-base"
        cls.base_url = DEFAULT_URL_FOR_TEST
        cls.api_key = "sk-123456"
        other_args = ["--completion-template", "deepseek_coder"]
        cls.process = popen_launch_server(
            cls.model,
            cls.base_url,
            timeout=DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH,
            api_key=cls.api_key,
            other_args=other_args,
        )
        # The OpenAI client is pointed at the "/v1" API root.
        cls.base_url += "/v1"
        # Used locally to predict the prompt token count the server reports.
        cls.tokenizer = get_tokenizer(cls.model)

    @classmethod
    def tearDownClass(cls):
        """Kill the launched server process tree."""
        kill_process_tree(cls.process.pid)

    def run_fim_completion(self, number_of_completion):
        """Request a FIM completion with ``n=number_of_completion`` and validate it.

        Args:
            number_of_completion: Number of choices (``n``) to request; the
                response must contain exactly this many choices.
        """
        client = openai.Client(api_key=self.api_key, base_url=self.base_url)
        prompt = "function sum(a: number, b: number): number{\n"
        suffix = "}"

        prompt_input = self.tokenizer.encode(prompt) + self.tokenizer.encode(suffix)
        # NOTE(review): the "+ 2" presumably accounts for the extra tokens the
        # deepseek_coder FIM template wraps around prompt/suffix, beyond any
        # special tokens the two separate encode() calls already contribute.
        # Confirm against the template definition if this expectation drifts.
        num_prompt_tokens = len(prompt_input) + 2

        response = client.completions.create(
            model=self.model,
            prompt=prompt,
            suffix=suffix,
            temperature=0.3,
            max_tokens=32,
            stream=False,
            n=number_of_completion,
        )

        # Kept so CI logs show the raw response when an assertion below fails.
        print(response)

        # unittest assertions (unlike bare `assert`) are not stripped under -O
        # and report both operands on failure; `case` identifies which n failed.
        case = f"n={number_of_completion}"
        self.assertEqual(len(response.choices), number_of_completion, case)
        self.assertTrue(response.id, case)
        self.assertTrue(response.created, case)
        self.assertEqual(response.object, "text_completion", case)
        self.assertEqual(response.usage.prompt_tokens, num_prompt_tokens, case)
        self.assertGreater(response.usage.completion_tokens, 0, case)
        self.assertGreater(response.usage.total_tokens, 0, case)

    def test_fim_completion(self):
        """Exercise FIM completion with a single choice and with several choices."""
        for number_of_completion in [1, 3]:
            self.run_fim_completion(number_of_completion)


if __name__ == "__main__":
    # Discover and run the TestCase classes defined in this module when it is
    # executed directly as a script (module="__main__" is unittest's default).
    unittest.main(module="__main__")