"""
Unit tests for OpenAIServingChat, rewritten to use only the std-lib 'unittest'.
Run with either:
    python tests/test_serving_chat_unit.py -v
or
    python -m unittest discover -s tests -p "test_*unit.py" -v
"""

import unittest
import uuid
from typing import Optional
from unittest.mock import Mock, patch

from fastapi import Request

from sglang.srt.entrypoints.openai.protocol import ChatCompletionRequest
from sglang.srt.entrypoints.openai.serving_chat import OpenAIServingChat
from sglang.srt.managers.io_struct import GenerateReqInput


class _MockTokenizerManager:
    """Minimal mock that satisfies OpenAIServingChat."""

    def __init__(self):
        self.model_config = Mock(is_multimodal=False)
        self.server_args = Mock(
            enable_cache_report=False,
            tool_call_parser="hermes",
            reasoning_parser=None,
        )
        self.chat_template_name: Optional[str] = "llama-3"

        # tokenizer stub
        self.tokenizer = Mock()
        self.tokenizer.encode.return_value = [1, 2, 3, 4, 5]
        self.tokenizer.decode.return_value = "Test response"
        self.tokenizer.chat_template = None
        self.tokenizer.bos_token_id = 1

        # async generator stub for generate_request
        async def _mock_generate():
            yield {
                "text": "Test response",
                "meta_info": {
                    "id": f"chatcmpl-{uuid.uuid4()}",
                    "prompt_tokens": 10,
                    "completion_tokens": 5,
                    "cached_tokens": 0,
                    "finish_reason": {"type": "stop", "matched": None},
                    "output_token_logprobs": [(0.1, 1, "Test"), (0.2, 2, "response")],
                    "output_top_logprobs": None,
                },
                "index": 0,
            }

        # side_effect (rather than return_value) hands back a fresh async
        # generator on every call, so a second call never sees an exhausted one.
        self.generate_request = Mock(side_effect=lambda *args, **kwargs: _mock_generate())
        self.create_abort_task = Mock()


class _MockTemplateManager:
    """Minimal mock for TemplateManager."""

    def __init__(self):
        self.chat_template_name: Optional[str] = "llama-3"
        self.jinja_template_content_format: Optional[str] = None
        self.completion_template_name: Optional[str] = None



class ServingChatTestCase(unittest.TestCase):
    # ------------- common fixtures -------------
    def setUp(self):
        self.tm = _MockTokenizerManager()
        self.template_manager = _MockTemplateManager()
        self.chat = OpenAIServingChat(self.tm, self.template_manager)

        # frequently reused requests
        self.basic_req = ChatCompletionRequest(
            model="x",
            messages=[{"role": "user", "content": "Hi?"}],
            temperature=0.7,
            max_tokens=100,
            stream=False,
        )
        self.stream_req = ChatCompletionRequest(
            model="x",
            messages=[{"role": "user", "content": "Hi?"}],
            temperature=0.7,
            max_tokens=100,
            stream=True,
        )

        # Bare request object for handler signatures that expect a fastapi.Request.
        self.fastapi_request = Mock(spec=Request)
        self.fastapi_request.headers = {}

    # ------------- conversion tests -------------
    def test_convert_to_internal_request_single(self):
        with patch(
            "sglang.srt.entrypoints.openai.serving_chat.generate_chat_conv"
        ) as conv_mock, patch.object(self.chat, "_process_messages") as proc_mock:
            # Conversation stub returned by the patched generate_chat_conv.
            conv_ins = Mock()
            conv_ins.get_prompt.return_value = "Test prompt"
            conv_ins.image_data = conv_ins.audio_data = None
            conv_ins.modalities = []
            conv_ins.stop_str = ["</s>"]
            conv_mock.return_value = conv_ins

            # 7-tuple mirroring what _process_messages returns, presumably
            # (prompt, prompt_ids, image_data, audio_data, modalities, stop,
            #  tool_call_constraint).
            proc_mock.return_value = (
                "Test prompt",
                [1, 2, 3],
                None,
                None,
                [],
                ["</s>"],
                None,
            )

            adapted, processed = self.chat._convert_to_internal_request(self.basic_req)
            self.assertIsInstance(adapted, GenerateReqInput)
            self.assertFalse(adapted.stream)
            self.assertEqual(processed, self.basic_req)
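
    # Hedged sketch, not part of the original suite: exercises the otherwise
    # unused self.stream_req fixture, assuming _convert_to_internal_request
    # propagates the request's stream flag onto GenerateReqInput the same way
    # it handles the non-streaming request above.
    def test_convert_to_internal_request_stream(self):
        with patch.object(
            self.chat,
            "_process_messages",
            return_value=("Test prompt", [1, 2, 3], None, None, [], ["</s>"], None),
        ):
            adapted, processed = self.chat._convert_to_internal_request(self.stream_req)

            self.assertIsInstance(adapted, GenerateReqInput)
            self.assertTrue(adapted.stream)
            self.assertEqual(processed, self.stream_req)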

    # ------------- sampling-params -------------
    def test_sampling_param_build(self):
        req = ChatCompletionRequest(
            model="x",
            messages=[{"role": "user", "content": "Hi"}],
            temperature=0.8,
            max_tokens=150,
            min_tokens=5,
            top_p=0.9,
            stop=["</s>"],
        )
        with patch.object(
            self.chat,
            "_process_messages",
            return_value=("Prompt", [1], None, None, [], ["</s>"], None),
        ):
            params = self.chat._build_sampling_params(req, ["</s>"], None)
            self.assertEqual(params["temperature"], 0.8)
            self.assertEqual(params["max_new_tokens"], 150)
            self.assertEqual(params["min_new_tokens"], 5)
            self.assertEqual(params["stop"], ["</s>"])
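
    # Hedged sketch, not part of the original suite: assumes
    # _build_sampling_params forwards top_p under the "top_p" key, mirroring
    # the temperature/max_tokens passthrough asserted above.
    def test_sampling_param_top_p_passthrough(self):
        req = ChatCompletionRequest(
            model="x",
            messages=[{"role": "user", "content": "Hi"}],
            top_p=0.9,
            stop=["</s>"],
        )
        with patch.object(
            self.chat,
            "_process_messages",
            return_value=("Prompt", [1], None, None, [], ["</s>"], None),
        ):
            params = self.chat._build_sampling_params(req, ["</s>"], None)
            self.assertEqual(params["top_p"], 0.9)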


if __name__ == "__main__":
    unittest.main(verbosity=2)