Unverified commit f3b5db6e, authored by Xinyuan Tong and committed by GitHub
Browse files

Feat: support disable tool parser (#10184)

parent 2286e85e
...@@ -53,6 +53,7 @@ class OpenAIServingChat(OpenAIServingBase): ...@@ -53,6 +53,7 @@ class OpenAIServingChat(OpenAIServingBase):
): ):
super().__init__(tokenizer_manager) super().__init__(tokenizer_manager)
self.template_manager = template_manager self.template_manager = template_manager
self.tool_call_parser = self.tokenizer_manager.server_args.tool_call_parser
def _request_id_prefix(self) -> str: def _request_id_prefix(self) -> str:
return "chatcmpl-" return "chatcmpl-"
...@@ -172,10 +173,11 @@ class OpenAIServingChat(OpenAIServingBase): ...@@ -172,10 +173,11 @@ class OpenAIServingChat(OpenAIServingBase):
] ]
else: else:
tools = [item.function.model_dump() for item in request.tools] tools = [item.function.model_dump() for item in request.tools]
if self.tool_call_parser:
tool_call_parser = self.tokenizer_manager.server_args.tool_call_parser parser = FunctionCallParser(request.tools, self.tool_call_parser)
parser = FunctionCallParser(request.tools, tool_call_parser) tool_call_constraint = parser.get_structure_constraint(
tool_call_constraint = parser.get_structure_constraint(request.tool_choice) request.tool_choice
)
# Use chat template # Use chat template
if self.template_manager.chat_template_name is None: if self.template_manager.chat_template_name is None:
...@@ -537,7 +539,11 @@ class OpenAIServingChat(OpenAIServingBase): ...@@ -537,7 +539,11 @@ class OpenAIServingChat(OpenAIServingBase):
yield f"data: {chunk.model_dump_json()}\n\n" yield f"data: {chunk.model_dump_json()}\n\n"
# Handle tool calls # Handle tool calls
if request.tool_choice != "none" and request.tools: if (
request.tool_choice != "none"
and request.tools
and self.tool_call_parser
):
async for chunk in self._process_tool_call_stream( async for chunk in self._process_tool_call_stream(
index, index,
delta, delta,
...@@ -727,10 +733,13 @@ class OpenAIServingChat(OpenAIServingBase): ...@@ -727,10 +733,13 @@ class OpenAIServingChat(OpenAIServingBase):
# Handle tool calls # Handle tool calls
tool_calls = None tool_calls = None
if request.tool_choice != "none" and request.tools: if (
tool_call_parser = self.tokenizer_manager.server_args.tool_call_parser request.tool_choice != "none"
and request.tools
and self.tool_call_parser
):
tool_calls, text, finish_reason = self._process_tool_calls( tool_calls, text, finish_reason = self._process_tool_calls(
text, request.tools, tool_call_parser, finish_reason text, request.tools, finish_reason
) )
choice_data = ChatCompletionResponseChoice( choice_data = ChatCompletionResponseChoice(
...@@ -824,11 +833,10 @@ class OpenAIServingChat(OpenAIServingBase): ...@@ -824,11 +833,10 @@ class OpenAIServingChat(OpenAIServingBase):
self, self,
text: str, text: str,
tools: List[Any], tools: List[Any],
tool_call_parser: Optional[str],
finish_reason: Dict[str, Any], finish_reason: Dict[str, Any],
) -> tuple[Optional[List[ToolCall]], str, Dict[str, Any]]: ) -> tuple[Optional[List[ToolCall]], str, Dict[str, Any]]:
"""Process tool calls in the response""" """Process tool calls in the response"""
parser = FunctionCallParser(tools, tool_call_parser) parser = FunctionCallParser(tools, self.tool_call_parser)
if parser.has_tool_call(text): if parser.has_tool_call(text):
if finish_reason["type"] == "stop": if finish_reason["type"] == "stop":
finish_reason["type"] = "tool_calls" finish_reason["type"] = "tool_calls"
...@@ -838,7 +846,10 @@ class OpenAIServingChat(OpenAIServingBase): ...@@ -838,7 +846,10 @@ class OpenAIServingChat(OpenAIServingBase):
tool_calls = [] tool_calls = []
for call_info in call_info_list: for call_info in call_info_list:
# For Kimi-K2, align tool_call_id with the model format: functions.{name}:{index} # For Kimi-K2, align tool_call_id with the model format: functions.{name}:{index}
if tool_call_parser == "kimi_k2" and call_info.name is not None: if (
self.tool_call_parser == "kimi_k2"
and call_info.name is not None
):
tool_id = f"functions.{call_info.name}:{call_info.tool_index}" tool_id = f"functions.{call_info.name}:{call_info.tool_index}"
else: else:
tool_id = f"call_{uuid.uuid4().hex[:24]}" tool_id = f"call_{uuid.uuid4().hex[:24]}"
...@@ -933,7 +944,7 @@ class OpenAIServingChat(OpenAIServingBase): ...@@ -933,7 +944,7 @@ class OpenAIServingChat(OpenAIServingBase):
if index not in parser_dict: if index not in parser_dict:
parser_dict[index] = FunctionCallParser( parser_dict[index] = FunctionCallParser(
tools=request.tools, tools=request.tools,
tool_call_parser=self.tokenizer_manager.server_args.tool_call_parser, tool_call_parser=self.tool_call_parser,
) )
parser = parser_dict[index] parser = parser_dict[index]
...@@ -962,7 +973,7 @@ class OpenAIServingChat(OpenAIServingBase): ...@@ -962,7 +973,7 @@ class OpenAIServingChat(OpenAIServingBase):
# Tool call ID should be generated only once per tool call # Tool call ID should be generated only once per tool call
if call_item.name: if call_item.name:
# First chunk: include ID and function name # First chunk: include ID and function name
if self.tokenizer_manager.server_args.tool_call_parser == "kimi_k2": if self.tool_call_parser == "kimi_k2":
# Align with Kimi-K2 format: functions.{name}:{index} # Align with Kimi-K2 format: functions.{name}:{index}
tool_call_id = f"functions.{call_item.name}:{call_item.tool_index}" tool_call_id = f"functions.{call_item.name}:{call_item.tool_index}"
else: else:
......
...@@ -332,7 +332,7 @@ class ServingChatTestCase(unittest.TestCase): ...@@ -332,7 +332,7 @@ class ServingChatTestCase(unittest.TestCase):
"""Ensure non-streaming tool_call.id matches functions.{name}:{index} for kimi_k2 parser.""" """Ensure non-streaming tool_call.id matches functions.{name}:{index} for kimi_k2 parser."""
# Force kimi_k2 parser # Force kimi_k2 parser
self.tm.server_args.tool_call_parser = "kimi_k2" self.chat.tool_call_parser = "kimi_k2"
# Mock FunctionCallParser.parse_non_stream to return one tool call # Mock FunctionCallParser.parse_non_stream to return one tool call
with patch( with patch(
...@@ -357,7 +357,6 @@ class ServingChatTestCase(unittest.TestCase): ...@@ -357,7 +357,6 @@ class ServingChatTestCase(unittest.TestCase):
tool_calls, remaining_text, _ = self.chat._process_tool_calls( tool_calls, remaining_text, _ = self.chat._process_tool_calls(
text="<|tool_calls_section_begin|>...", text="<|tool_calls_section_begin|>...",
tools=tools, tools=tools,
tool_call_parser="kimi_k2",
finish_reason=finish_reason, finish_reason=finish_reason,
) )
...@@ -370,7 +369,7 @@ class ServingChatTestCase(unittest.TestCase): ...@@ -370,7 +369,7 @@ class ServingChatTestCase(unittest.TestCase):
"""Ensure streaming first chunk tool_call.id matches functions.{name}:{index} for kimi_k2 parser.""" """Ensure streaming first chunk tool_call.id matches functions.{name}:{index} for kimi_k2 parser."""
# Force kimi_k2 parser # Force kimi_k2 parser
self.tm.server_args.tool_call_parser = "kimi_k2" self.chat.tool_call_parser = "kimi_k2"
# Prepare request with tools # Prepare request with tools
req = ChatCompletionRequest( req = ChatCompletionRequest(
......
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment