"docs/source/vscode:/vscode.git/clone" did not exist on "bc8f8b0b7a90d565ff20d08088e1830a61c19639"
Unverified commit 0ac61146, authored by eraser00, committed by GitHub

Replace the Kimi-K2-generated tool call idx with the history tool call count (#10612)


Co-authored-by: eraser00 <eraser00@github.com>
parent 7dcd689b
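
In short: Kimi-K2 expects tool call ids of the form functions.{name}:{index} with an index that keeps increasing across the whole conversation, but the parser only knows each call's local position within the current assistant message. A minimal, standalone sketch of the fix (hypothetical helper, not the actual SGLang API):

# Sketch of the bug this commit fixes (hypothetical, self-contained).
# Before: ids used only the parser's local index, so a second assistant
# turn reused index 0 and collided with the id already in the history.
# After: the local index is offset by the number of historical tool calls.

def tool_call_id(name: str, local_index: int, history_cnt: int = 0) -> str:
    # Kimi-K2 id format: functions.{name}:{index}, index global per conversation
    return f"functions.{name}:{history_cnt + local_index}"

assert tool_call_id("get_weather", 0) == "functions.get_weather:0"  # first turn
# Second turn with one call already in history: index 1, no collision.
assert tool_call_id("get_weather", 0, history_cnt=1) == "functions.get_weather:1"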
@@ -33,6 +33,7 @@ from sglang.srt.entrypoints.openai.utils import (
     process_hidden_states_from_ret,
     to_openai_style_logprobs,
 )
+from sglang.srt.function_call.core_types import ToolCallItem
 from sglang.srt.function_call.function_call_parser import FunctionCallParser
 from sglang.srt.managers.io_struct import GenerateReqInput
 from sglang.srt.parser.conversation import generate_chat_conv
@@ -749,8 +750,9 @@ class OpenAIServingChat(OpenAIServingBase):
             and request.tools
             and self.tool_call_parser
         ):
+            history_tool_calls_cnt = self._get_history_tool_calls_cnt(request)
             tool_calls, text, finish_reason = self._process_tool_calls(
-                text, request.tools, finish_reason
+                text, request.tools, finish_reason, history_tool_calls_cnt
             )

         choice_data = ChatCompletionResponseChoice(
@@ -840,11 +842,32 @@ class OpenAIServingChat(OpenAIServingBase):
         token_logprobs = self._process_logprobs_tokens(logprobs, use_token_index=True)
         return ChoiceLogprobs(content=token_logprobs)

+    def _process_tool_call_id(
+        self,
+        call_item: ToolCallItem,
+        history_tool_calls_cnt: int,
+    ) -> str:
+        """Generate a new, unique `tool_call_id`."""
+        if self.tool_call_parser != "kimi_k2":
+            # A simple uuid is sufficient for all models except Kimi-K2.
+            tool_call_id = f"call_{uuid.uuid4().hex[:24]}"
+            return tool_call_id
+        else:
+            # Align with the Kimi-K2 format: functions.{name}:{index}.
+            # Kimi-K2 allows multiple tool calls in one message, and SGLang sets
+            # call_item.tool_index to the *local* position inside that message.
+            # Offset it by `history_tool_calls_cnt` so the index is globally
+            # unique and properly ordered.
+            tool_call_id = f"functions.{call_item.name}:{history_tool_calls_cnt + call_item.tool_index}"
+            logger.debug(
+                f"Process tool call idx, parser: {self.tool_call_parser}, tool_call_id: {tool_call_id}, history_cnt: {history_tool_calls_cnt}"
+            )
+            return tool_call_id
+
     def _process_tool_calls(
         self,
         text: str,
         tools: List[Any],
         finish_reason: Dict[str, Any],
+        history_tool_calls_cnt: int = 0,
     ) -> tuple[Optional[List[ToolCall]], str, Dict[str, Any]]:
         """Process tool calls in the response"""
         parser = FunctionCallParser(tools, self.tool_call_parser)
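
For illustration, a hedged, standalone sketch of what the new helper returns under each parser (SimpleNamespace stands in for SGLang's ToolCallItem; the free function and parser name "llama3" are illustrative, not the real API):

# Standalone sketch of the two id formats produced by _process_tool_call_id.
import uuid
from types import SimpleNamespace

def process_tool_call_id(parser: str, call_item, history_tool_calls_cnt: int) -> str:
    if parser != "kimi_k2":
        # Any non-Kimi model: a random, opaque id is fine.
        return f"call_{uuid.uuid4().hex[:24]}"
    # Kimi-K2: deterministic, globally ordered id.
    return f"functions.{call_item.name}:{history_tool_calls_cnt + call_item.tool_index}"

item = SimpleNamespace(name="get_weather", tool_index=0)
print(process_tool_call_id("llama3", item, 1))   # e.g. call_9f86d081884c7d659a2f
print(process_tool_call_id("kimi_k2", item, 1))  # functions.get_weather:1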
@@ -856,15 +879,9 @@ class OpenAIServingChat(OpenAIServingBase):
         text, call_info_list = parser.parse_non_stream(text)
         tool_calls = []
         for call_info in call_info_list:
-            # For Kimi-K2, align tool_call_id with the model format: functions.{name}:{index}
-            if (
-                self.tool_call_parser == "kimi_k2"
-                and call_info.name is not None
-            ):
-                tool_id = f"functions.{call_info.name}:{call_info.tool_index}"
-            else:
-                tool_id = f"call_{uuid.uuid4().hex[:24]}"
+            tool_id = self._process_tool_call_id(
+                call_info, history_tool_calls_cnt
+            )
             tool_calls.append(
                 ToolCall(
                     id=tool_id,
@@ -920,6 +937,26 @@ class OpenAIServingChat(OpenAIServingBase):
             reasoning_parser = reasoning_parser_dict[index]
         return reasoning_parser.parse_stream_chunk(delta)

+    def _get_history_tool_calls_cnt(self, request: ChatCompletionRequest) -> int:
+        """Count the number of tool calls in the request's message history.
+
+        NOTE: This is only useful for models whose tool call ids embed a
+        self-incrementing, history-wide index, such as Kimi-K2.
+
+        Args:
+            request: The chat completion request object.
+
+        Returns:
+            The total number of tool calls in the history, or 0 if not applicable.
+        """
+        messages = getattr(request, "messages", [])
+        idx = 0
+        for msg in messages:
+            if msg.role == "assistant":
+                tool_calls = getattr(msg, "tool_calls", None)
+                idx += len(list(tool_calls)) if tool_calls is not None else 0  # noqa
+        return idx
+
     def _get_enable_thinking_from_request(self, request: ChatCompletionRequest) -> bool:
         """Extracts the 'enable_thinking' flag from request chat_template_kwargs.
@@ -977,6 +1014,7 @@ class OpenAIServingChat(OpenAIServingBase):
                     yield f"data: {chunk.model_dump_json()}\n\n"

         # Yield tool calls
+        history_tool_calls_cnt = self._get_history_tool_calls_cnt(request)
         for call_item in calls:
             # Mark that this choice has tool calls
             has_tool_calls[index] = True
@@ -984,11 +1022,9 @@ class OpenAIServingChat(OpenAIServingBase):
             # Tool call ID should be generated only once per tool call
             if call_item.name:
                 # First chunk: include ID and function name
-                if self.tool_call_parser == "kimi_k2":
-                    # Align with Kimi-K2 format: functions.{name}:{index}
-                    tool_call_id = f"functions.{call_item.name}:{call_item.tool_index}"
-                else:
-                    tool_call_id = f"call_{uuid.uuid4().hex[:24]}"
+                tool_call_id = self._process_tool_call_id(
+                    call_item, history_tool_calls_cnt
+                )
                 function_name = call_item.name
             else:
                 # Subsequent chunks: null ID and name for argument deltas
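
In the streaming path the id is emitted only once, in the first chunk of each tool call; later chunks carry argument deltas with a null id. A hedged sketch of the delta sequence a client would now see (values illustrative, field layout per the OpenAI streaming format):

# Illustrative tool-call deltas for one streamed call (values hypothetical).
first_chunk_delta = {
    "tool_calls": [{
        "index": 0,
        "id": "functions.get_weather:1",  # generated once, history-aware
        "function": {"name": "get_weather", "arguments": ""},
    }]
}
followup_chunk_delta = {
    "tool_calls": [{
        "index": 0,
        "id": None,  # null on argument-delta chunks
        "function": {"arguments": '{"city": "Tokyo"}'},
    }]
}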
@@ -420,6 +420,181 @@ class ServingChatTestCase(unittest.TestCase):
         tool_calls = payload["choices"][0]["delta"]["tool_calls"]
         self.assertEqual(tool_calls[0]["id"], "functions.get_weather:0")
+    def test_kimi_k2_non_streaming_tool_call_id_with_history(self):
+        """Ensure non-streaming tool_call.id increases with the tool call history for the kimi_k2 parser."""
+        # Force the kimi_k2 parser
+        self.chat.tool_call_parser = "kimi_k2"
+
+        # Prepare a request with tool call history
+        req = ChatCompletionRequest(
+            model="x",
+            messages=[
+                {"role": "user", "content": "What's the weather today in Paris?"},
+                {
+                    "role": "assistant",
+                    "content": "Let me do some search first.",
+                    "tool_calls": [
+                        {
+                            "id": "functions.get_weather:0",
+                            "type": "function",
+                            "function": {
+                                "name": "get_weather",
+                                "arguments": '{"city": "Paris"}',
+                            },
+                        }
+                    ],
+                },
+                {
+                    "role": "tool",
+                    "content": "It's rainy in Paris now.",
+                    "tool_call_id": "functions.get_weather:0",
+                },
+                {
+                    "role": "assistant",
+                    "content": "It's rainy now.",
+                },
+                {
+                    "role": "user",
+                    "content": "What about LA and Tokyo?",
+                },
+            ],
+            tools=[{"type": "function", "function": {"name": "get_weather"}}],
+            stream=False,
+        )
+
+        # Mock FunctionCallParser.parse_non_stream to return two tool calls
+        with patch(
+            "sglang.srt.entrypoints.openai.serving_chat.FunctionCallParser"
+        ) as ParserMock:
+            parser_instance = ParserMock.return_value
+            # Build mock ToolCallItem-like objects
+            call_info = Mock()
+            call_info.name = "get_weather"
+            call_info.parameters = '{"city": "Los Angeles"}'
+            # Kimi-K2 series models may emit a fixed-base tool_index that
+            # ignores the tool call history, corrupting all following tool calls
+            call_info.tool_index = 0
+            call_info2 = Mock()
+            call_info2.name = "get_weather"
+            call_info2.parameters = '{"city": "Tokyo"}'
+            call_info2.tool_index = 1
+            parser_instance.has_tool_call.return_value = True
+            parser_instance.parse_non_stream.return_value = (
+                "",
+                [call_info, call_info2],
+            )
+
+            finish_reason = {"type": "stop", "matched": None}
+            tools = [
+                {"type": "function", "function": {"name": "get_weather"}},
+            ]
+            history_tool_calls_cnt = self.chat._get_history_tool_calls_cnt(req)
+            tool_calls, remaining_text, _ = self.chat._process_tool_calls(
+                text="<|tool_calls_section_begin|>...",
+                tools=tools,
+                finish_reason=finish_reason,
+                history_tool_calls_cnt=history_tool_calls_cnt,
+            )
+
+            self.assertEqual(history_tool_calls_cnt, 1)
+            self.assertIsNotNone(tool_calls)
+            self.assertEqual(len(tool_calls), 2)
+            self.assertEqual(tool_calls[0].id, "functions.get_weather:1")
+            self.assertEqual(tool_calls[0].function.name, "get_weather")
+            self.assertEqual(tool_calls[1].id, "functions.get_weather:2")
+            self.assertEqual(tool_calls[1].function.name, "get_weather")
+
+    def test_kimi_k2_streaming_tool_call_id_with_history(self):
+        """Ensure the first streamed chunk's tool_call.id increases with the tool call history for the kimi_k2 parser."""
+        # Force the kimi_k2 parser
+        self.chat.tool_call_parser = "kimi_k2"
+
+        # Prepare a request with tool call history
+        req = ChatCompletionRequest(
+            model="x",
+            messages=[
+                {"role": "user", "content": "What's the weather today in Paris?"},
+                {
+                    "role": "assistant",
+                    "content": "Let me do some search first.",
+                    "tool_calls": [
+                        {
+                            "id": "functions.get_weather:0",
+                            "type": "function",
+                            "function": {
+                                "name": "get_weather",
+                                "arguments": '{"city": "Paris"}',
+                            },
+                        }
+                    ],
+                },
+                {
+                    "role": "tool",
+                    "content": "It's rainy in Paris now.",
+                    "tool_call_id": "functions.get_weather:0",
+                },
+                {
+                    "role": "assistant",
+                    "content": "It's rainy now.",
+                },
+                {
+                    "role": "user",
+                    "content": "What about LA?",
+                },
+            ],
+            tools=[{"type": "function", "function": {"name": "get_weather"}}],
+            stream=True,
+        )
+
+        # Patch FunctionCallParser used inside _process_tool_call_stream
+        with patch(
+            "sglang.srt.entrypoints.openai.serving_chat.FunctionCallParser"
+        ) as ParserMock:
+            parser_instance = ParserMock.return_value
+            # First call returns one ToolCallItem-like chunk (with name)
+            first_chunk_call = Mock()
+            # Kimi-K2 series models may emit a fixed-base tool_index that
+            # ignores the tool call history, corrupting all following tool calls
+            first_chunk_call.tool_index = 0
+            first_chunk_call.name = "get_weather"
+            first_chunk_call.parameters = ""
+            parser_instance.parse_stream_chunk.side_effect = [
+                ("", [first_chunk_call]),
+                ("", []),
+            ]
+
+            async def collect_first_tool_chunk():
+                gen = self.chat._process_tool_call_stream(
+                    index=0,
+                    delta="irrelevant",
+                    parser_dict={},
+                    content={"meta_info": {"id": "chatcmpl-test"}},
+                    request=req,
+                    has_tool_calls={},
+                )
+                # Get the first yielded SSE line
+                line = None
+                async for emitted in gen:
+                    line = emitted
+                    break
+                return line
+
+            loop = asyncio.get_event_loop()
+            line = loop.run_until_complete(collect_first_tool_chunk())
+
+            self.assertIsNotNone(line)
+            self.assertTrue(line.startswith("data: "))
+            payload = json.loads(line[len("data: ") :])
+            tool_calls = payload["choices"][0]["delta"]["tool_calls"]
+            self.assertEqual(tool_calls[0]["id"], "functions.get_weather:1")


 if __name__ == "__main__":
     unittest.main(verbosity=2)