"""
Unit tests for the refactored completions-serving handler (no pytest).

Run with:
    python -m unittest tests.test_serving_completions -v
"""

7
import unittest
8
9
from unittest.mock import AsyncMock, Mock, patch

10
from sglang.srt.entrypoints.openai.protocol import CompletionRequest
11
12
13
14
from sglang.srt.entrypoints.openai.serving_completions import OpenAIServingCompletion
from sglang.srt.managers.tokenizer_manager import TokenizerManager


15
16
class ServingCompletionTestCase(unittest.TestCase):
    """Prompt-conversion and echo-handling tests for OpenAIServingCompletion.

    Each test runs against a fresh ``OpenAIServingCompletion`` built around a
    mocked ``TokenizerManager`` (see ``setUp``), so no real model or tokenizer
    is loaded.
    """

    # Module whose helpers are patched in the completion-template test.
    _SC_MODULE = "sglang.srt.entrypoints.openai.serving_completions"

    # ---------- shared test fixtures ----------
    def setUp(self):
        # unittest calls setUp before *each* test, so every test gets its own
        # fresh mock and may override return values (e.g. decode) without
        # leaking into other tests.
        tm = Mock(spec=TokenizerManager)

        # Stub tokenizer: encode/decode return fixed values for any input.
        tm.tokenizer = Mock()
        tm.tokenizer.encode.return_value = [1, 2, 3, 4]
        tm.tokenizer.decode.return_value = "decoded text"
        tm.tokenizer.bos_token_id = 1

        # Config flags the handler presumably consults; set explicitly so the
        # spec'd Mock does not hand back truthy auto-generated attributes.
        tm.model_config = Mock(is_multimodal=False)
        tm.server_args = Mock(enable_cache_report=False)

        # AsyncMock: generate_request is presumably a coroutine on the real
        # manager.
        tm.generate_request = AsyncMock()
        tm.create_abort_task = Mock()

        self.sc = OpenAIServingCompletion(tm)

    # ---------- prompt-handling ----------
    def test_single_string_prompt(self):
        """A plain string prompt is carried through as ``text``."""
        req = CompletionRequest(model="x", prompt="Hello world", max_tokens=100)
        internal, _ = self.sc._convert_to_internal_request([req], ["id"])
        self.assertEqual(internal.text, "Hello world")

    def test_single_token_ids_prompt(self):
        """A list of token IDs is carried through as ``input_ids``."""
        req = CompletionRequest(model="x", prompt=[1, 2, 3, 4], max_tokens=100)
        internal, _ = self.sc._convert_to_internal_request([req], ["id"])
        self.assertEqual(internal.input_ids, [1, 2, 3, 4])

    def test_completion_template_handling(self):
        """With a completion template defined, the generated prompt is used."""
        req = CompletionRequest(
            model="x", prompt="def f():", suffix="return 1", max_tokens=100
        )
        with patch(
            f"{self._SC_MODULE}.is_completion_template_defined",
            return_value=True,
        ), patch(
            f"{self._SC_MODULE}.generate_completion_prompt_from_request",
            return_value="processed_prompt",
        ):
            internal, _ = self.sc._convert_to_internal_request([req], ["id"])
            self.assertEqual(internal.text, "processed_prompt")

    # ---------- echo-handling ----------
    def test_echo_with_string_prompt_streaming(self):
        """Echo text for a single string prompt is the prompt itself."""
        req = CompletionRequest(model="x", prompt="Hello", max_tokens=1, echo=True)
        self.assertEqual(self.sc._get_echo_text(req, 0), "Hello")

    def test_echo_with_list_of_strings_streaming(self):
        """Echo text for a list of string prompts is selected by index."""
        req = CompletionRequest(
            model="x", prompt=["A", "B"], max_tokens=1, echo=True, n=1
        )
        self.assertEqual(self.sc._get_echo_text(req, 0), "A")
        self.assertEqual(self.sc._get_echo_text(req, 1), "B")

    def test_echo_with_token_ids_streaming(self):
        """Echo text for a token-ID prompt comes from the tokenizer's decode."""
        req = CompletionRequest(model="x", prompt=[1, 2, 3], max_tokens=1, echo=True)
        self.sc.tokenizer_manager.tokenizer.decode.return_value = "decoded_prompt"
        self.assertEqual(self.sc._get_echo_text(req, 0), "decoded_prompt")

    def test_echo_with_multiple_token_ids_streaming(self):
        """Echo text for a batch of token-ID prompts is decoded per prompt."""
        req = CompletionRequest(
            model="x", prompt=[[1, 2], [3, 4]], max_tokens=1, echo=True, n=1
        )
        self.sc.tokenizer_manager.tokenizer.decode.return_value = "decoded"
        self.assertEqual(self.sc._get_echo_text(req, 0), "decoded")

    def test_prepare_echo_prompts_non_streaming(self):
        """Non-streaming echo prompts for each supported prompt shape."""
        # subTest reports each scenario separately instead of stopping at
        # the first failing one.
        with self.subTest("single string"):
            req = CompletionRequest(model="x", prompt="Hi", echo=True)
            self.assertEqual(self.sc._prepare_echo_prompts(req), ["Hi"])

        with self.subTest("list of strings"):
            req = CompletionRequest(model="x", prompt=["Hi", "Yo"], echo=True)
            self.assertEqual(self.sc._prepare_echo_prompts(req), ["Hi", "Yo"])

        with self.subTest("token IDs"):
            # decode is stubbed, so the expected echo is the stub's output.
            req = CompletionRequest(model="x", prompt=[1, 2, 3], echo=True)
            self.sc.tokenizer_manager.tokenizer.decode.return_value = "decoded"
            self.assertEqual(self.sc._prepare_echo_prompts(req), ["decoded"])


if __name__ == "__main__":
    # Allow running this file directly; verbosity=2 lists each test by name.
    unittest.main(verbosity=2)