"""
Unit tests for OpenAIServingChat, rewritten to use only the standard-library 'unittest' framework.
Run with either:
    python tests/test_serving_chat_unit.py -v
or
    python -m unittest discover -s tests -p "test_*unit.py" -v
"""

import unittest
import uuid
from typing import Optional
from unittest.mock import Mock, patch

from fastapi import Request

from sglang.srt.entrypoints.openai.protocol import ChatCompletionRequest
from sglang.srt.entrypoints.openai.serving_chat import OpenAIServingChat
from sglang.srt.managers.io_struct import GenerateReqInput


class _MockTokenizerManager:
    """Minimal mock that satisfies OpenAIServingChat."""

    def __init__(self):
        self.model_config = Mock(is_multimodal=False)
        self.server_args = Mock(
            enable_cache_report=False,
            tool_call_parser="hermes",
            reasoning_parser=None,
        )
        self.chat_template_name: Optional[str] = "llama-3"

        # tokenizer stub
        self.tokenizer = Mock()
        self.tokenizer.encode.return_value = [1, 2, 3, 4, 5]
        self.tokenizer.decode.return_value = "Test response"
        self.tokenizer.chat_template = None
        self.tokenizer.bos_token_id = 1

        # async generator stub for generate_request
        async def _mock_generate():
            yield {
                "text": "Test response",
                "meta_info": {
                    "id": f"chatcmpl-{uuid.uuid4()}",
                    "prompt_tokens": 10,
                    "completion_tokens": 5,
                    "cached_tokens": 0,
                    "finish_reason": {"type": "stop", "matched": None},
                    "output_token_logprobs": [(0.1, 1, "Test"), (0.2, 2, "response")],
                    "output_top_logprobs": None,
                },
                "index": 0,
            }

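        # NOTE: Mock(return_value=...) hands back the same pre-built async
        # generator object on every call, so it can only be consumed once.
        # That is sufficient here because none of these tests drive
        # generate_request more than once.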
        self.generate_request = Mock(return_value=_mock_generate())
        self.create_abort_task = Mock()


class ServingChatTestCase(unittest.TestCase):
    # ------------- common fixtures -------------
    def setUp(self):
        self.tm = _MockTokenizerManager()
        self.chat = OpenAIServingChat(self.tm)

        # frequently reused requests
        self.basic_req = ChatCompletionRequest(
            model="x",
            messages=[{"role": "user", "content": "Hi?"}],
            temperature=0.7,
            max_tokens=100,
            stream=False,
        )
        self.stream_req = ChatCompletionRequest(
            model="x",
            messages=[{"role": "user", "content": "Hi?"}],
            temperature=0.7,
            max_tokens=100,
            stream=True,
        )

        self.fastapi_request = Mock(spec=Request)
        self.fastapi_request.headers = {}

    # ------------- conversion tests -------------
    def test_convert_to_internal_request_single(self):
        with patch(
            "sglang.srt.entrypoints.openai.serving_chat.generate_chat_conv"
        ) as conv_mock, patch.object(self.chat, "_process_messages") as proc_mock:
            conv_ins = Mock()
            conv_ins.get_prompt.return_value = "Test prompt"
            conv_ins.image_data = conv_ins.audio_data = None
            conv_ins.modalities = []
            conv_ins.stop_str = ["</s>"]
            conv_mock.return_value = conv_ins

            proc_mock.return_value = (
                "Test prompt",
                [1, 2, 3],
                None,
                None,
                [],
                ["</s>"],
                None,
            )

            adapted, processed = self.chat._convert_to_internal_request(
                [self.basic_req], ["rid"]
            )
            self.assertIsInstance(adapted, GenerateReqInput)
            self.assertFalse(adapted.stream)
            self.assertEqual(processed, self.basic_req)
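
    # Hedged companion to the test above: it mirrors the single-request
    # conversion but feeds the streaming fixture, on the assumption that
    # _convert_to_internal_request simply carries request.stream over to
    # GenerateReqInput.stream. Treat it as a sketch, not a spec of the API.
    def test_convert_to_internal_request_streaming(self):
        with patch(
            "sglang.srt.entrypoints.openai.serving_chat.generate_chat_conv"
        ) as conv_mock, patch.object(self.chat, "_process_messages") as proc_mock:
            conv_ins = Mock()
            conv_ins.get_prompt.return_value = "Test prompt"
            conv_ins.image_data = conv_ins.audio_data = None
            conv_ins.modalities = []
            conv_ins.stop_str = ["</s>"]
            conv_mock.return_value = conv_ins

            proc_mock.return_value = (
                "Test prompt",
                [1, 2, 3],
                None,
                None,
                [],
                ["</s>"],
                None,
            )

            adapted, _ = self.chat._convert_to_internal_request(
                [self.stream_req], ["rid"]
            )
            self.assertIsInstance(adapted, GenerateReqInput)
            self.assertTrue(adapted.stream)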

    # ------------- tool-call branch -------------
    def test_tool_call_request_conversion(self):
        req = ChatCompletionRequest(
            model="x",
            messages=[{"role": "user", "content": "Weather?"}],
            tools=[
                {
                    "type": "function",
                    "function": {
                        "name": "get_weather",
                        "parameters": {"type": "object", "properties": {}},
                    },
                }
            ],
            tool_choice="auto",
        )

        with patch.object(
            self.chat,
            "_process_messages",
            return_value=("Prompt", [1, 2, 3], None, None, [], ["</s>"], None),
        ):
            adapted, _ = self.chat._convert_to_internal_request([req], ["rid"])
            self.assertEqual(adapted.rid, "rid")

    def test_tool_choice_none(self):
        req = ChatCompletionRequest(
            model="x",
            messages=[{"role": "user", "content": "Hi"}],
            tools=[{"type": "function", "function": {"name": "noop"}}],
            tool_choice="none",
        )
        with patch.object(
            self.chat,
            "_process_messages",
            return_value=("Prompt", [1, 2, 3], None, None, [], ["</s>"], None),
        ):
            adapted, _ = self.chat._convert_to_internal_request([req], ["rid"])
            self.assertEqual(adapted.rid, "rid")

    # ------------- multimodal branch -------------
    def test_multimodal_request_with_images(self):
        self.tm.model_config.is_multimodal = True

        req = ChatCompletionRequest(
            model="x",
            messages=[
                {
                    "role": "user",
                    "content": [
                        {"type": "text", "text": "What's in the image?"},
                        {
                            "type": "image_url",
                            "image_url": {"url": "data:image/jpeg;base64,"},
                        },
                    ],
                }
            ],
        )

        with patch.object(
            self.chat,
            "_apply_jinja_template",
            return_value=("prompt", [1, 2], ["img"], None, [], []),
        ), patch.object(
            self.chat,
            "_apply_conversation_template",
            return_value=("prompt", ["img"], None, [], []),
        ):
            out = self.chat._process_messages(req, True)
            _, _, image_data, *_ = out
            self.assertEqual(image_data, ["img"])

    # ------------- template handling -------------
    def test_jinja_template_processing(self):
        req = ChatCompletionRequest(
            model="x", messages=[{"role": "user", "content": "Hello"}]
        )
        self.tm.chat_template_name = None
        self.tm.tokenizer.chat_template = "<jinja>"

        with patch.object(
            self.chat,
            "_apply_jinja_template",
            return_value=("processed", [1], None, None, [], ["</s>"]),
        ), patch("builtins.hasattr", return_value=True):
            prompt, prompt_ids, *_ = self.chat._process_messages(req, False)
            self.assertEqual(prompt, "processed")
            self.assertEqual(prompt_ids, [1])
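
    # Hedged sketch of the named-template path: setUp leaves chat_template_name
    # set to "llama-3", and this assumes _process_messages then routes through
    # _apply_conversation_template (patched here with the same 5-tuple shape
    # used in the multimodal test above) rather than the tokenizer's Jinja
    # template. The exact return shape is an assumption, not a documented API.
    def test_conversation_template_processing(self):
        req = ChatCompletionRequest(
            model="x", messages=[{"role": "user", "content": "Hello"}]
        )

        with patch.object(
            self.chat,
            "_apply_conversation_template",
            return_value=("conv prompt", None, None, [], ["</s>"]),
        ):
            prompt, *_ = self.chat._process_messages(req, False)
            self.assertEqual(prompt, "conv prompt")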

    # ------------- sampling-params -------------
    def test_sampling_param_build(self):
        req = ChatCompletionRequest(
            model="x",
            messages=[{"role": "user", "content": "Hi"}],
            temperature=0.8,
            max_tokens=150,
            min_tokens=5,
            top_p=0.9,
            stop=["</s>"],
        )
        with patch.object(
            self.chat,
            "_process_messages",
            return_value=("Prompt", [1], None, None, [], ["</s>"], None),
        ):
            params = self.chat._build_sampling_params(req, ["</s>"], None)
            self.assertEqual(params["temperature"], 0.8)
            self.assertEqual(params["max_new_tokens"], 150)
            self.assertEqual(params["min_new_tokens"], 5)
            self.assertEqual(params["stop"], ["</s>"])


if __name__ == "__main__":
    unittest.main(verbosity=2)