""" Unit-tests for OpenAIServingChat — rewritten to use only the std-lib 'unittest'. Run with either: python tests/test_serving_chat_unit.py -v or python -m unittest discover -s tests -p "test_*unit.py" -v """ import asyncio import json import unittest import uuid from typing import Optional from unittest.mock import Mock, patch from fastapi import Request from sglang.srt.entrypoints.openai.protocol import ( ChatCompletionRequest, MessageProcessingResult, ) from sglang.srt.entrypoints.openai.serving_chat import OpenAIServingChat from sglang.srt.managers.io_struct import GenerateReqInput class _MockTokenizerManager: """Minimal mock that satisfies OpenAIServingChat.""" def __init__(self): self.model_config = Mock(is_multimodal=False) self.server_args = Mock( enable_cache_report=False, tool_call_parser="hermes", reasoning_parser=None, ) self.chat_template_name: Optional[str] = "llama-3" # tokenizer stub self.tokenizer = Mock() self.tokenizer.encode.return_value = [1, 2, 3, 4, 5] self.tokenizer.decode.return_value = "Test response" self.tokenizer.chat_template = None self.tokenizer.bos_token_id = 1 # async generator stub for generate_request async def _mock_generate(): yield { "text": "Test response", "meta_info": { "id": f"chatcmpl-{uuid.uuid4()}", "prompt_tokens": 10, "completion_tokens": 5, "cached_tokens": 0, "finish_reason": {"type": "stop", "matched": None}, "output_token_logprobs": [(0.1, 1, "Test"), (0.2, 2, "response")], "output_top_logprobs": None, }, "index": 0, } self.generate_request = Mock(return_value=_mock_generate()) self.create_abort_task = Mock() class _MockTemplateManager: """Minimal mock for TemplateManager.""" def __init__(self): self.chat_template_name: Optional[str] = "llama-3" self.jinja_template_content_format: Optional[str] = None self.completion_template_name: Optional[str] = None class ServingChatTestCase(unittest.TestCase): # ------------- common fixtures ------------- def setUp(self): self.tm = _MockTokenizerManager() self.template_manager = _MockTemplateManager() self.chat = OpenAIServingChat(self.tm, self.template_manager) # frequently reused requests self.basic_req = ChatCompletionRequest( model="x", messages=[{"role": "user", "content": "Hi?"}], temperature=0.7, max_tokens=100, stream=False, ) self.stream_req = ChatCompletionRequest( model="x", messages=[{"role": "user", "content": "Hi?"}], temperature=0.7, max_tokens=100, stream=True, ) self.fastapi_request = Mock(spec=Request) self.fastapi_request.headers = {} # ------------- conversion tests ------------- def test_convert_to_internal_request_single(self): with patch( "sglang.srt.entrypoints.openai.serving_chat.generate_chat_conv" ) as conv_mock, patch.object(self.chat, "_process_messages") as proc_mock: conv_ins = Mock() conv_ins.get_prompt.return_value = "Test prompt" conv_ins.image_data = conv_ins.audio_data = None conv_ins.modalities = [] conv_ins.stop_str = [""] conv_mock.return_value = conv_ins proc_mock.return_value = MessageProcessingResult( "Test prompt", [1, 2, 3], None, None, [], [""], None, ) adapted, processed = self.chat._convert_to_internal_request(self.basic_req) self.assertIsInstance(adapted, GenerateReqInput) self.assertFalse(adapted.stream) self.assertEqual(processed, self.basic_req) def test_stop_str_isolation_between_requests(self): """Test that stop strings from one request don't affect subsequent requests. This tests the fix for the bug where conv.stop_str was being mutated globally, causing stop strings from one request to persist in subsequent requests. 
""" # Mock conversation template with initial stop_str initial_stop_str = ["\n"] with patch( "sglang.srt.entrypoints.openai.serving_chat.generate_chat_conv" ) as conv_mock: # Create a mock conversation object that will be returned by generate_chat_conv conv_ins = Mock() conv_ins.get_prompt.return_value = "Test prompt" conv_ins.image_data = None conv_ins.audio_data = None conv_ins.modalities = [] conv_ins.stop_str = ( initial_stop_str.copy() ) # Template's default stop strings conv_mock.return_value = conv_ins # First request with additional stop string req1 = ChatCompletionRequest( model="x", messages=[{"role": "user", "content": "First request"}], stop=["CUSTOM_STOP"], ) # Call the actual _apply_conversation_template method (not mocked) result1 = self.chat._apply_conversation_template(req1, is_multimodal=False) # Verify first request has both stop strings expected_stop1 = initial_stop_str + ["CUSTOM_STOP"] self.assertEqual(result1.stop, expected_stop1) # Verify the original template's stop_str wasn't mutated after first request self.assertEqual(conv_ins.stop_str, initial_stop_str) # Second request without additional stop string req2 = ChatCompletionRequest( model="x", messages=[{"role": "user", "content": "Second request"}], # No custom stop strings ) result2 = self.chat._apply_conversation_template(req2, is_multimodal=False) # Verify second request only has original stop strings (no CUSTOM_STOP from req1) self.assertEqual(result2.stop, initial_stop_str) self.assertNotIn("CUSTOM_STOP", result2.stop) self.assertEqual(conv_ins.stop_str, initial_stop_str) # ------------- sampling-params ------------- def test_sampling_param_build(self): req = ChatCompletionRequest( model="x", messages=[{"role": "user", "content": "Hi"}], temperature=0.8, max_tokens=150, min_tokens=5, top_p=0.9, stop=[""], ) with patch.object( self.chat, "_process_messages", return_value=("Prompt", [1], None, None, [], [""], None), ): params = self.chat._build_sampling_params(req, [""], None) self.assertEqual(params["temperature"], 0.8) self.assertEqual(params["max_new_tokens"], 150) self.assertEqual(params["min_new_tokens"], 5) self.assertEqual(params["stop"], [""]) async def test_unstreamed_tool_args_completion(self): """Test that remaining tool call arguments are sent when generation finishes.""" # Mock FunctionCallParser with detector that has partial tool call data mock_parser = Mock() mock_detector = Mock() # Simulate a tool call that was partially streamed mock_detector.prev_tool_call_arr = [ { "name": "get_weather", "arguments": {"location": "San Francisco", "unit": "celsius"}, } ] mock_detector.streamed_args_for_tool = [ '{"location": "San Francisco"' # Partial arguments streamed so far ] mock_parser.detector = mock_detector content = { "meta_info": { "id": "chatcmpl-test123", } } request = ChatCompletionRequest( model="test", messages=[{"role": "user", "content": "What's the weather?"}], tools=[{"type": "function", "function": {"name": "get_weather"}}], ) # Test the completion method result = self.chat._check_for_unstreamed_tool_args( parser=mock_parser, content=content, request=request, finish_reason_type="stop", index=0, ) # Should return a chunk with remaining arguments self.assertIsNotNone(result, "Should return chunk with remaining arguments") self.assertIn('"arguments":', result, "Should contain arguments field") self.assertIn( ', "unit": "celsius"}', result, "Should contain remaining arguments" ) self.assertIn( '"finish_reason":null', result, "Should not include finish_reason in completion chunk", ) async def 
    def test_unstreamed_tool_args_no_completion_needed(self):
        """Test that no completion chunk is sent when all arguments were already streamed."""
        # Mock FunctionCallParser with a detector holding complete tool call data
        mock_parser = Mock()
        mock_detector = Mock()

        # Simulate a tool call that was completely streamed
        mock_detector.prev_tool_call_arr = [
            {"name": "get_weather", "arguments": {"location": "San Francisco"}}
        ]
        mock_detector.streamed_args_for_tool = [
            '{"location": "San Francisco"}'  # all arguments already streamed
        ]
        mock_parser.detector = mock_detector

        content = {
            "meta_info": {
                "id": "chatcmpl-test123",
            }
        }

        request = ChatCompletionRequest(
            model="test",
            messages=[{"role": "user", "content": "What's the weather?"}],
            tools=[{"type": "function", "function": {"name": "get_weather"}}],
        )

        # Test the completion method
        result = self.chat._check_for_unstreamed_tool_args(
            parser=mock_parser,
            content=content,
            request=request,
            finish_reason_type="stop",
            index=0,
        )

        # Should return None since no completion is needed
        self.assertIsNone(result, "Should return None when no completion is needed")

    def test_unstreamed_tool_args_no_parser_data(self):
        """Test that no completion chunk is sent when the parser has no tool call data."""
        # Mock FunctionCallParser with an empty detector
        mock_parser = Mock()
        mock_detector = Mock()
        mock_detector.prev_tool_call_arr = []
        mock_detector.streamed_args_for_tool = []
        mock_parser.detector = mock_detector

        content = {
            "meta_info": {
                "id": "chatcmpl-test123",
            }
        }

        request = ChatCompletionRequest(
            model="test",
            messages=[{"role": "user", "content": "What's the weather?"}],
            tools=[{"type": "function", "function": {"name": "get_weather"}}],
        )

        # Test the completion method
        result = self.chat._check_for_unstreamed_tool_args(
            parser=mock_parser,
            content=content,
            request=request,
            finish_reason_type="stop",
            index=0,
        )

        # Should return None since there's no parser data
        self.assertIsNone(
            result, "Should return None when parser has no tool call data"
        )

    # ------------- kimi_k2 tool_call_id formatting -------------
    def test_kimi_k2_non_streaming_tool_call_id_format(self):
        """Ensure non-streaming tool_call.id matches functions.{name}:{index} for the kimi_k2 parser."""
        # Force kimi_k2 parser
        self.chat.tool_call_parser = "kimi_k2"

        # Mock FunctionCallParser.parse_non_stream to return one tool call
        with patch(
            "sglang.srt.entrypoints.openai.serving_chat.FunctionCallParser"
        ) as ParserMock:
            parser_instance = ParserMock.return_value

            # Build a mock ToolCallItem-like object
            call_info = Mock()
            call_info.name = "get_weather"
            call_info.parameters = '{"city":"Paris"}'
            call_info.tool_index = 0

            parser_instance.has_tool_call.return_value = True
            parser_instance.parse_non_stream.return_value = ("", [call_info])

            finish_reason = {"type": "stop", "matched": None}
            tools = [
                {"type": "function", "function": {"name": "get_weather"}},
            ]

            tool_calls, remaining_text, finish_reason = self.chat._process_tool_calls(
                text="<|tool_calls_section_begin|>...",
                tools=tools,
                finish_reason=finish_reason,
            )

            self.assertIsNotNone(tool_calls)
            self.assertEqual(len(tool_calls), 1)
            self.assertEqual(tool_calls[0].id, "functions.get_weather:0")
            self.assertEqual(tool_calls[0].function.name, "get_weather")
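    # kimi_k2 tool_call ids follow the "functions.{name}:{index}" scheme asserted
    # above: the first get_weather call in a fresh conversation gets the id
    # "functions.get_weather:0", and later calls increment the trailing index.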
tools=[{"type": "function", "function": {"name": "get_weather"}}], stream=True, ) # Patch FunctionCallParser used inside _process_tool_call_stream with patch( "sglang.srt.entrypoints.openai.serving_chat.FunctionCallParser" ) as ParserMock: parser_instance = ParserMock.return_value # First call returns one ToolCallItem-like chunk (with name) first_chunk_call = Mock() first_chunk_call.tool_index = 0 first_chunk_call.name = "get_weather" first_chunk_call.parameters = "" parser_instance.parse_stream_chunk.side_effect = [ ("", [first_chunk_call]), ("", []), ] async def collect_first_tool_chunk(): gen = self.chat._process_tool_call_stream( index=0, delta="irrelevant", parser_dict={}, content={"meta_info": {"id": "chatcmpl-test"}}, request=req, has_tool_calls={}, ) # Get first yielded SSE line line = None async for emitted in gen: line = emitted break return line loop = asyncio.get_event_loop() line = loop.run_until_complete(collect_first_tool_chunk()) self.assertIsNotNone(line) self.assertTrue(line.startswith("data: ")) payload = json.loads(line[len("data: ") :]) tool_calls = payload["choices"][0]["delta"]["tool_calls"] self.assertEqual(tool_calls[0]["id"], "functions.get_weather:0") def test_kimi_k2_non_streaming_tool_call_id_with_history(self): """Ensure non-streaming tool_call.id increase with tool calls history for kimi_k2 parser.""" # Force kimi_k2 parser self.chat.tool_call_parser = "kimi_k2" # Prepare request with tool calls history req = ChatCompletionRequest( model="x", messages=[ {"role": "user", "content": "What's the weather today in paris?"}, { "role": "assistant", "content": "Let me do some search first.", "tool_calls": [ { "id": "functions.get_weather:0", "type": "function", "function": { "name": "get_weather", "arguments": '{"city": "Paris"}', }, } ], }, { "role": "tool", "content": "It's rainy in paris now.", "tool_call_id": "functions.get_weather:0", }, { "role": "assistant", "content": "It's rainy now.", }, { "role": "user", "content": "What about LA and Tokyo?", }, ], tools=[{"type": "function", "function": {"name": "get_weather"}}], stream=False, ) # Mock FunctionCallParser.parse_non_stream to return one tool call with patch( "sglang.srt.entrypoints.openai.serving_chat.FunctionCallParser" ) as ParserMock: parser_instance = ParserMock.return_value # Build a mock ToolCallItem-like object call_info = Mock() call_info.name = "get_weather" call_info.parameters = '{"city":"Loa Angeles"}' # Kimi-K2 series models might generate fixed number tool_indx, # ignoring the tool calls history and mess up all the following tool calls call_info.tool_index = 0 call_info2 = Mock() call_info2.name = "get_weather" call_info2.parameters = '{"city":"Tokyo"}' call_info2.tool_index = 1 parser_instance.has_tool_call.return_value = True parser_instance.parse_non_stream.return_value = ( "", [call_info, call_info2], ) finish_reason = {"type": "stop", "matched": None} tools = [ {"type": "function", "function": {"name": "get_weather"}}, ] history_tool_calls_cnt = self.chat._get_history_tool_calls_cnt(req) tool_calls, remaining_text, _ = self.chat._process_tool_calls( text="<|tool_calls_section_begin|>...", tools=tools, finish_reason=finish_reason, history_tool_calls_cnt=history_tool_calls_cnt, ) self.assertEqual(history_tool_calls_cnt, 1) self.assertIsNotNone(tool_calls) self.assertEqual(len(tool_calls), 2) self.assertEqual(tool_calls[0].id, "functions.get_weather:1") self.assertEqual(tool_calls[0].function.name, "get_weather") self.assertEqual(tool_calls[1].id, "functions.get_weather:2") 
    def test_kimi_k2_streaming_tool_call_id_with_history(self):
        """Ensure the first streamed chunk's tool_call.id keeps increasing across tool-call history for the kimi_k2 parser."""
        # Force kimi_k2 parser
        self.chat.tool_call_parser = "kimi_k2"

        # Prepare request with tool-call history
        req = ChatCompletionRequest(
            model="x",
            messages=[
                {"role": "user", "content": "What's the weather today in Paris?"},
                {
                    "role": "assistant",
                    "content": "Let me do some search first.",
                    "tool_calls": [
                        {
                            "id": "functions.get_weather:0",
                            "type": "function",
                            "function": {
                                "name": "get_weather",
                                "arguments": '{"city": "Paris"}',
                            },
                        }
                    ],
                },
                {
                    "role": "tool",
                    "content": "It's rainy in Paris now.",
                    "tool_call_id": "functions.get_weather:0",
                },
                {
                    "role": "assistant",
                    "content": "It's rainy now.",
                },
                {
                    "role": "user",
                    "content": "What about LA?",
                },
            ],
            tools=[{"type": "function", "function": {"name": "get_weather"}}],
            stream=True,
        )

        # Patch FunctionCallParser used inside _process_tool_call_stream
        with patch(
            "sglang.srt.entrypoints.openai.serving_chat.FunctionCallParser"
        ) as ParserMock:
            parser_instance = ParserMock.return_value

            # First call returns one ToolCallItem-like chunk (with name)
            first_chunk_call = Mock()
            # Kimi-K2 series models may emit a fixed-base tool_index that ignores
            # the tool-call history, which would misnumber all following tool calls
            first_chunk_call.tool_index = 0
            first_chunk_call.name = "get_weather"
            first_chunk_call.parameters = ""

            parser_instance.parse_stream_chunk.side_effect = [
                ("", [first_chunk_call]),
                ("", []),
            ]

            async def collect_first_tool_chunk():
                gen = self.chat._process_tool_call_stream(
                    index=0,
                    delta="irrelevant",
                    parser_dict={},
                    content={"meta_info": {"id": "chatcmpl-test"}},
                    request=req,
                    has_tool_calls={},
                )
                # Get the first yielded SSE line
                line = None
                async for emitted in gen:
                    line = emitted
                    break
                return line

            line = asyncio.run(collect_first_tool_chunk())

            self.assertIsNotNone(line)
            self.assertTrue(line.startswith("data: "))
            payload = json.loads(line[len("data: ") :])
            tool_calls = payload["choices"][0]["delta"]["tool_calls"]
            self.assertEqual(tool_calls[0]["id"], "functions.get_weather:1")


if __name__ == "__main__":
    unittest.main(verbosity=2)
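# A single test can also be run directly (assuming tests/ is importable as a
# package), e.g.:
#
#     python -m unittest tests.test_serving_chat_unit.ServingChatTestCase.test_sampling_param_build -v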