Unverified Commit 0936c766 authored by Xiaotong Jiang's avatar Xiaotong Jiang Committed by GitHub
Browse files

Fix kimi k2 function calling format (#9606)

parent 0ef583b7
...@@ -835,15 +835,23 @@ class OpenAIServingChat(OpenAIServingBase): ...@@ -835,15 +835,23 @@ class OpenAIServingChat(OpenAIServingBase):
finish_reason["matched"] = None finish_reason["matched"] = None
try: try:
text, call_info_list = parser.parse_non_stream(text) text, call_info_list = parser.parse_non_stream(text)
tool_calls = [ tool_calls = []
ToolCall( for call_info in call_info_list:
id=f"call_{uuid.uuid4().hex[:24]}", # For Kimi-K2, align tool_call_id with the model format: functions.{name}:{index}
function=FunctionResponse( if tool_call_parser == "kimi_k2" and call_info.name is not None:
name=call_info.name, arguments=call_info.parameters tool_id = f"functions.{call_info.name}:{call_info.tool_index}"
), else:
tool_id = f"call_{uuid.uuid4().hex[:24]}"
tool_calls.append(
ToolCall(
id=tool_id,
index=getattr(call_info, "tool_index", None),
function=FunctionResponse(
name=call_info.name, arguments=call_info.parameters
),
)
) )
for call_info in call_info_list
]
return tool_calls, text, finish_reason return tool_calls, text, finish_reason
except Exception as e: except Exception as e:
logger.error(f"Tool call parsing error: {e}") logger.error(f"Tool call parsing error: {e}")
...@@ -954,7 +962,11 @@ class OpenAIServingChat(OpenAIServingBase): ...@@ -954,7 +962,11 @@ class OpenAIServingChat(OpenAIServingBase):
# Tool call ID should be generated only once per tool call # Tool call ID should be generated only once per tool call
if call_item.name: if call_item.name:
# First chunk: include ID and function name # First chunk: include ID and function name
tool_call_id = f"call_{uuid.uuid4().hex[:24]}" if self.tokenizer_manager.server_args.tool_call_parser == "kimi_k2":
# Align with Kimi-K2 format: functions.{name}:{index}
tool_call_id = f"functions.{call_item.name}:{call_item.tool_index}"
else:
tool_call_id = f"call_{uuid.uuid4().hex[:24]}"
function_name = call_item.name function_name = call_item.name
else: else:
# Subsequent chunks: null ID and name for argument deltas # Subsequent chunks: null ID and name for argument deltas
......
...@@ -6,6 +6,8 @@ or ...@@ -6,6 +6,8 @@ or
python -m unittest discover -s tests -p "test_*unit.py" -v python -m unittest discover -s tests -p "test_*unit.py" -v
""" """
import asyncio
import json
import unittest import unittest
import uuid import uuid
from typing import Optional from typing import Optional
...@@ -325,6 +327,100 @@ class ServingChatTestCase(unittest.TestCase): ...@@ -325,6 +327,100 @@ class ServingChatTestCase(unittest.TestCase):
result, "Should return None when parser has no tool call data" result, "Should return None when parser has no tool call data"
) )
# ------------- kimi_k2 tool_call_id formatting -------------
def test_kimi_k2_non_streaming_tool_call_id_format(self):
    """Ensure non-streaming tool_call.id matches functions.{name}:{index} for kimi_k2 parser."""
    # The special id format is only applied when the kimi_k2 parser is active.
    self.tm.server_args.tool_call_parser = "kimi_k2"

    # Replace FunctionCallParser so we control exactly one parsed tool call.
    with patch(
        "sglang.srt.entrypoints.openai.serving_chat.FunctionCallParser"
    ) as mock_parser_cls:
        mocked_parser = mock_parser_cls.return_value

        # Minimal stand-in for a parsed ToolCallItem-like object.
        fake_call = Mock()
        fake_call.name = "get_weather"
        fake_call.parameters = '{"city":"Paris"}'
        fake_call.tool_index = 0

        mocked_parser.has_tool_call.return_value = True
        mocked_parser.parse_non_stream.return_value = ("", [fake_call])

        tool_calls, _remaining_text, _ = self.chat._process_tool_calls(
            text="<|tool_calls_section_begin|>...",
            tools=[{"type": "function", "function": {"name": "get_weather"}}],
            tool_call_parser="kimi_k2",
            finish_reason={"type": "stop", "matched": None},
        )

        self.assertIsNotNone(tool_calls)
        self.assertEqual(len(tool_calls), 1)
        # kimi_k2 ids must be functions.{name}:{index}, not call_{uuid}.
        self.assertEqual(tool_calls[0].id, "functions.get_weather:0")
        self.assertEqual(tool_calls[0].function.name, "get_weather")
def test_kimi_k2_streaming_tool_call_id_format(self):
    """Ensure streaming first chunk tool_call.id matches functions.{name}:{index} for kimi_k2 parser."""
    # The kimi_k2-specific id format is keyed off the server-args parser name.
    self.tm.server_args.tool_call_parser = "kimi_k2"

    # Streaming request exposing a single tool.
    req = ChatCompletionRequest(
        model="x",
        messages=[{"role": "user", "content": "Hi?"}],
        tools=[{"type": "function", "function": {"name": "get_weather"}}],
        stream=True,
    )

    # Patch FunctionCallParser used inside _process_tool_call_stream.
    with patch(
        "sglang.srt.entrypoints.openai.serving_chat.FunctionCallParser"
    ) as ParserMock:
        parser_instance = ParserMock.return_value

        # First parse_stream_chunk call returns a chunk that carries the
        # function name (the "first chunk", which must include the id);
        # the second call returns nothing so the generator can finish.
        first_chunk_call = Mock()
        first_chunk_call.tool_index = 0
        first_chunk_call.name = "get_weather"
        first_chunk_call.parameters = ""
        parser_instance.parse_stream_chunk.side_effect = [
            ("", [first_chunk_call]),
            ("", []),
        ]

        async def collect_first_tool_chunk():
            """Drive the stream generator and return its first SSE line."""
            gen = self.chat._process_tool_call_stream(
                index=0,
                delta="irrelevant",
                parser_dict={},
                content={"meta_info": {"id": "chatcmpl-test"}},
                request=req,
                has_tool_calls={},
            )
            async for emitted in gen:
                return emitted
            return None

        # asyncio.run creates and tears down a fresh event loop. The previous
        # asyncio.get_event_loop()/run_until_complete pattern is deprecated
        # since Python 3.10 and raises when the thread has no current loop
        # (the usual situation under test runners).
        line = asyncio.run(collect_first_tool_chunk())

        self.assertIsNotNone(line)
        self.assertTrue(line.startswith("data: "))
        payload = json.loads(line[len("data: ") :])
        tool_calls = payload["choices"][0]["delta"]["tool_calls"]
        self.assertEqual(tool_calls[0]["id"], "functions.get_weather:0")
if __name__ == "__main__": if __name__ == "__main__":
unittest.main(verbosity=2) unittest.main(verbosity=2)
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment