"""
Unit tests for OpenAIServingChat, rewritten to use only the std-lib 'unittest'.
Run with either:
    python tests/test_serving_chat_unit.py -v
or
    python -m unittest discover -s tests -p "test_*unit.py" -v
"""

import unittest
import uuid
from typing import Optional
from unittest.mock import Mock, patch

from fastapi import Request

from sglang.srt.entrypoints.openai.protocol import ChatCompletionRequest
from sglang.srt.entrypoints.openai.serving_chat import OpenAIServingChat
from sglang.srt.managers.io_struct import GenerateReqInput


class _MockTokenizerManager:
    """Minimal mock that satisfies OpenAIServingChat."""

    def __init__(self):
        self.model_config = Mock(is_multimodal=False)
        self.server_args = Mock(
            enable_cache_report=False,
            tool_call_parser="hermes",
            reasoning_parser=None,
        )
        self.chat_template_name: Optional[str] = "llama-3"

        # tokenizer stub
        self.tokenizer = Mock()
        self.tokenizer.encode.return_value = [1, 2, 3, 4, 5]
        self.tokenizer.decode.return_value = "Test response"
        self.tokenizer.chat_template = None
        self.tokenizer.bos_token_id = 1

        # async generator stub for generate_request
        async def _mock_generate():
            yield {
                "text": "Test response",
                "meta_info": {
                    "id": f"chatcmpl-{uuid.uuid4()}",
                    "prompt_tokens": 10,
                    "completion_tokens": 5,
                    "cached_tokens": 0,
                    "finish_reason": {"type": "stop", "matched": None},
                    "output_token_logprobs": [(0.1, 1, "Test"), (0.2, 2, "response")],
                    "output_top_logprobs": None,
                },
                "index": 0,
            }

        self.generate_request = Mock(return_value=_mock_generate())
        self.create_abort_task = Mock()


class ServingChatTestCase(unittest.TestCase):
    # ------------- common fixtures -------------
    def setUp(self):
        self.tm = _MockTokenizerManager()
        self.chat = OpenAIServingChat(self.tm)

        # frequently reused requests
        self.basic_req = ChatCompletionRequest(
            model="x",
            messages=[{"role": "user", "content": "Hi?"}],
            temperature=0.7,
            max_tokens=100,
            stream=False,
        )
        self.stream_req = ChatCompletionRequest(
            model="x",
            messages=[{"role": "user", "content": "Hi?"}],
            temperature=0.7,
            max_tokens=100,
            stream=True,
        )

        self.fastapi_request = Mock(spec=Request)
        self.fastapi_request.headers = {}

    # ------------- conversion tests -------------
    def test_convert_to_internal_request_single(self):
        with patch(
            "sglang.srt.entrypoints.openai.serving_chat.generate_chat_conv"
        ) as conv_mock, patch.object(self.chat, "_process_messages") as proc_mock:
            conv_ins = Mock()
            conv_ins.get_prompt.return_value = "Test prompt"
            conv_ins.image_data = conv_ins.audio_data = None
            conv_ins.modalities = []
            conv_ins.stop_str = ["</s>"]
            conv_mock.return_value = conv_ins

            proc_mock.return_value = (
                "Test prompt",
                [1, 2, 3],
                None,
                None,
                [],
                ["</s>"],
                None,
            )

            adapted, processed = self.chat._convert_to_internal_request(self.basic_req)
            self.assertIsInstance(adapted, GenerateReqInput)
            self.assertFalse(adapted.stream)
            self.assertEqual(processed, self.basic_req)
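
    # A streaming counterpart to the test above. This is a sketch rather than a
    # definitive check: it assumes _convert_to_internal_request simply carries
    # request.stream over to the GenerateReqInput, mirroring the stream=False
    # assertion made for the non-streaming request.
    def test_convert_to_internal_request_stream(self):
        with patch(
            "sglang.srt.entrypoints.openai.serving_chat.generate_chat_conv"
        ) as conv_mock, patch.object(self.chat, "_process_messages") as proc_mock:
            conv_ins = Mock()
            conv_ins.get_prompt.return_value = "Test prompt"
            conv_ins.image_data = conv_ins.audio_data = None
            conv_ins.modalities = []
            conv_ins.stop_str = ["</s>"]
            conv_mock.return_value = conv_ins

            proc_mock.return_value = (
                "Test prompt",
                [1, 2, 3],
                None,
                None,
                [],
                ["</s>"],
                None,
            )

            adapted, _ = self.chat._convert_to_internal_request(self.stream_req)

            self.assertIsInstance(adapted, GenerateReqInput)
            self.assertTrue(adapted.stream)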

    # # ------------- tool-call branch -------------
    # def test_tool_call_request_conversion(self):
    #     req = ChatCompletionRequest(
    #         model="x",
    #         messages=[{"role": "user", "content": "Weather?"}],
    #         tools=[
    #             {
    #                 "type": "function",
    #                 "function": {
    #                     "name": "get_weather",
    #                     "parameters": {"type": "object", "properties": {}},
    #                 },
    #             }
    #         ],
    #         tool_choice="auto",
    #     )

    #     with patch.object(
    #         self.chat,
    #         "_process_messages",
    #         return_value=("Prompt", [1, 2, 3], None, None, [], ["</s>"], None),
    #     ):
    #         adapted, _ = self.chat._convert_to_internal_request(req, "rid")
    #         self.assertEqual(adapted.rid, "rid")

    # def test_tool_choice_none(self):
    #     req = ChatCompletionRequest(
    #         model="x",
    #         messages=[{"role": "user", "content": "Hi"}],
    #         tools=[{"type": "function", "function": {"name": "noop"}}],
    #         tool_choice="none",
    #     )
    #     with patch.object(
    #         self.chat,
    #         "_process_messages",
    #         return_value=("Prompt", [1, 2, 3], None, None, [], ["</s>"], None),
    #     ):
    #         adapted, _ = self.chat._convert_to_internal_request(req, "rid")
    #         self.assertEqual(adapted.rid, "rid")

    # ------------- multimodal branch -------------
    def test_multimodal_request_with_images(self):
        self.tm.model_config.is_multimodal = True

        req = ChatCompletionRequest(
            model="x",
            messages=[
                {
                    "role": "user",
                    "content": [
                        {"type": "text", "text": "What's in the image?"},
                        {
                            "type": "image_url",
                            "image_url": {"url": "data:image/jpeg;base64,"},
                        },
                    ],
                }
            ],
        )

        with patch.object(
            self.chat,
            "_apply_jinja_template",
            return_value=("prompt", [1, 2], ["img"], None, [], []),
        ), patch.object(
            self.chat,
            "_apply_conversation_template",
            return_value=("prompt", ["img"], None, [], []),
        ):
            out = self.chat._process_messages(req, True)
            _, _, image_data, *_ = out
            self.assertEqual(image_data, ["img"])

    # ------------- template handling -------------
    def test_jinja_template_processing(self):
        req = ChatCompletionRequest(
            model="x", messages=[{"role": "user", "content": "Hello"}]
        )
        self.tm.chat_template_name = None
        self.tm.tokenizer.chat_template = "<jinja>"

        with patch.object(
            self.chat,
            "_apply_jinja_template",
            return_value=("processed", [1], None, None, [], ["</s>"]),
        ), patch("builtins.hasattr", return_value=True):
            prompt, prompt_ids, *_ = self.chat._process_messages(req, False)
            self.assertEqual(prompt, "processed")
            self.assertEqual(prompt_ids, [1])

    # ------------- sampling-params -------------
    def test_sampling_param_build(self):
        req = ChatCompletionRequest(
            model="x",
            messages=[{"role": "user", "content": "Hi"}],
            temperature=0.8,
            max_tokens=150,
            min_tokens=5,
            top_p=0.9,
            stop=["</s>"],
        )
        with patch.object(
            self.chat,
            "_process_messages",
            return_value=("Prompt", [1], None, None, [], ["</s>"], None),
        ):
            params = self.chat._build_sampling_params(req, ["</s>"], None)
            self.assertEqual(params["temperature"], 0.8)
            self.assertEqual(params["max_new_tokens"], 150)
            self.assertEqual(params["min_new_tokens"], 5)
            self.assertEqual(params["stop"], ["</s>"])


if __name__ == "__main__":
    unittest.main(verbosity=2)