Unverified Commit ac3fae84 authored by DarkSharpness's avatar DarkSharpness Committed by GitHub
Browse files

[Feature] Support "strict" in function calling (#4310)

parent 2d1b83e5
import json import json
import logging import logging
import re import re
from abc import ABC, abstractmethod
from dataclasses import dataclass
from json import JSONDecodeError, JSONDecoder from json import JSONDecodeError, JSONDecoder
from typing import Any, Dict, List, Optional, Tuple from typing import Any, Callable, Dict, List, Optional, Set, Tuple, Type
import partial_json_parser import partial_json_parser
from partial_json_parser.core.exceptions import MalformedJSON
from partial_json_parser.core.options import Allow from partial_json_parser.core.options import Allow
from pydantic import BaseModel, Field from pydantic import BaseModel
from sglang.srt.openai_api.protocol import (
StructuralTagResponseFormat,
StructuresResponseFormat,
Tool,
)
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
...@@ -19,14 +28,6 @@ TOOLS_TAG_LIST = [ ...@@ -19,14 +28,6 @@ TOOLS_TAG_LIST = [
] ]
class Function(BaseModel):
"""Function Tool Template."""
description: Optional[str] = Field(default=None, examples=[None])
name: Optional[str] = None
parameters: Optional[object] = None
class ToolCallItem(BaseModel): class ToolCallItem(BaseModel):
"""Simple encapsulation of the parsed ToolCall result for easier usage in streaming contexts.""" """Simple encapsulation of the parsed ToolCall result for easier usage in streaming contexts."""
...@@ -74,7 +75,22 @@ class StreamingParseResult: ...@@ -74,7 +75,22 @@ class StreamingParseResult:
self.calls = calls or [] self.calls = calls or []
class BaseFormatDetector: @dataclass
class StructureInfo:
begin: str
end: str
trigger: str
_GetInfoFunc = Callable[[str], StructureInfo]
"""
helper alias of function
ususally it is a function that takes a name string and returns a StructureInfo object,
which can be used to construct a structural_tag object
"""
class BaseFormatDetector(ABC):
"""Base class providing two sets of interfaces: one-time and streaming incremental.""" """Base class providing two sets of interfaces: one-time and streaming incremental."""
def __init__(self): def __init__(self):
...@@ -90,26 +106,12 @@ class BaseFormatDetector: ...@@ -90,26 +106,12 @@ class BaseFormatDetector:
self.bot_token = "" self.bot_token = ""
self.eot_token = "" self.eot_token = ""
def parse_base_json(self, action: Any, tools: List[Function]) -> List[ToolCallItem]: def parse_base_json(self, action: Any, tools: List[Tool]) -> List[ToolCallItem]:
tool_indices = { tool_indices = {
tool.function.name: i for i, tool in enumerate(tools) if tool.function.name tool.function.name: i for i, tool in enumerate(tools) if tool.function.name
} }
if not isinstance(action, list): if not isinstance(action, list):
name = action.get("name") action = [action]
if not name or name not in tool_indices:
logger.warning(f"Model attempted to call undefined function: {name}")
return []
return [
ToolCallItem(
tool_index=tool_indices[name],
name=name,
parameters=json.dumps(
action.get("parameters") or action.get("arguments", {}),
ensure_ascii=False,
),
)
]
results = [] results = []
for act in action: for act in action:
...@@ -125,12 +127,13 @@ class BaseFormatDetector: ...@@ -125,12 +127,13 @@ class BaseFormatDetector:
), ),
) )
) )
else:
logger.warning(f"Model attempted to call undefined function: {name}")
return results return results
def detect_and_parse( @abstractmethod
self, text: str, tools: List[Function] def detect_and_parse(self, text: str, tools: List[Tool]) -> StreamingParseResult:
) -> StreamingParseResult:
""" """
Parses the text in one go. Returns success=True if the format matches, otherwise False. Parses the text in one go. Returns success=True if the format matches, otherwise False.
Note that leftover_text here represents "content that this parser will not consume further". Note that leftover_text here represents "content that this parser will not consume further".
...@@ -139,7 +142,7 @@ class BaseFormatDetector: ...@@ -139,7 +142,7 @@ class BaseFormatDetector:
return StreamingParseResult(calls=self.parse_base_json(action, tools)) return StreamingParseResult(calls=self.parse_base_json(action, tools))
def parse_streaming_increment( def parse_streaming_increment(
self, new_text: str, tools: List[Function] self, new_text: str, tools: List[Tool]
) -> StreamingParseResult: ) -> StreamingParseResult:
""" """
Streaming incremental parsing with tool validation. Streaming incremental parsing with tool validation.
...@@ -198,7 +201,7 @@ class BaseFormatDetector: ...@@ -198,7 +201,7 @@ class BaseFormatDetector:
obj["arguments"] = obj["parameters"] obj["arguments"] = obj["parameters"]
tool_call_arr.append(obj) tool_call_arr.append(obj)
except partial_json_parser.core.exceptions.MalformedJSON: except MalformedJSON:
return StreamingParseResult() return StreamingParseResult()
if len(tool_call_arr) == 0: if len(tool_call_arr) == 0:
...@@ -304,6 +307,14 @@ class BaseFormatDetector: ...@@ -304,6 +307,14 @@ class BaseFormatDetector:
logger.error(f"Error in parse_streaming_increment: {e}") logger.error(f"Error in parse_streaming_increment: {e}")
return StreamingParseResult() return StreamingParseResult()
@abstractmethod
def has_tool_call(self, text: str) -> bool:
raise NotImplementedError()
@abstractmethod
def structure_info(self) -> _GetInfoFunc:
raise NotImplementedError()
class Qwen25Detector(BaseFormatDetector): class Qwen25Detector(BaseFormatDetector):
""" """
...@@ -324,9 +335,7 @@ class Qwen25Detector(BaseFormatDetector): ...@@ -324,9 +335,7 @@ class Qwen25Detector(BaseFormatDetector):
"""Check if the text contains a Qwen 2.5 format tool call.""" """Check if the text contains a Qwen 2.5 format tool call."""
return self.bot_token in text return self.bot_token in text
def detect_and_parse( def detect_and_parse(self, text: str, tools: List[Tool]) -> StreamingParseResult:
self, text: str, tools: List[Function]
) -> StreamingParseResult:
""" """
One-time parsing: Detects and parses tool calls in the provided text. One-time parsing: Detects and parses tool calls in the provided text.
...@@ -346,6 +355,13 @@ class Qwen25Detector(BaseFormatDetector): ...@@ -346,6 +355,13 @@ class Qwen25Detector(BaseFormatDetector):
calls.extend(self.parse_base_json(match_result, tools)) calls.extend(self.parse_base_json(match_result, tools))
return StreamingParseResult(normal_text=normal_text, calls=calls) return StreamingParseResult(normal_text=normal_text, calls=calls)
def structure_info(self) -> _GetInfoFunc:
return lambda name: StructureInfo(
begin='<tool_call>{"name":"' + name + '", "arguments":',
end="}</tool_call>",
trigger="<tool_call>",
)
class MistralDetector(BaseFormatDetector): class MistralDetector(BaseFormatDetector):
""" """
...@@ -380,9 +396,7 @@ class MistralDetector(BaseFormatDetector): ...@@ -380,9 +396,7 @@ class MistralDetector(BaseFormatDetector):
else: else:
return "" return ""
def detect_and_parse( def detect_and_parse(self, text: str, tools: List[Tool]) -> StreamingParseResult:
self, text: str, tools: List[Function]
) -> StreamingParseResult:
""" """
One-time parsing: Detects and parses tool calls in the provided text. One-time parsing: Detects and parses tool calls in the provided text.
...@@ -403,6 +417,13 @@ class MistralDetector(BaseFormatDetector): ...@@ -403,6 +417,13 @@ class MistralDetector(BaseFormatDetector):
calls.extend(self.parse_base_json(match_result, tools)) calls.extend(self.parse_base_json(match_result, tools))
return StreamingParseResult(normal_text=normal_text, calls=calls) return StreamingParseResult(normal_text=normal_text, calls=calls)
def structure_info(self) -> _GetInfoFunc:
return lambda name: StructureInfo(
begin='[TOOL_CALLS] [{"name":"' + name + '", "arguments":',
end="}]",
trigger="[TOOL_CALLS]",
)
class Llama32Detector(BaseFormatDetector): class Llama32Detector(BaseFormatDetector):
""" """
...@@ -421,15 +442,15 @@ class Llama32Detector(BaseFormatDetector): ...@@ -421,15 +442,15 @@ class Llama32Detector(BaseFormatDetector):
# prefix the output with the <|python_tag|> token # prefix the output with the <|python_tag|> token
return "<|python_tag|>" in text or text.startswith("{") return "<|python_tag|>" in text or text.startswith("{")
def detect_and_parse(self, text: str, tools: List[Function]) -> List[ToolCallItem]: def detect_and_parse(self, text: str, tools: List[Tool]) -> StreamingParseResult:
"""Parse function calls from text, handling multiple JSON objects.""" """Parse function calls from text, handling multiple JSON objects."""
if "<|python_tag|>" not in text and not text.startswith("{"): if "<|python_tag|>" not in text and not text.startswith("{"):
return StreamingParseResult(normal_text=text, calls=[]) return StreamingParseResult(normal_text=text, calls=[])
if "<|python_tag|>" in text: if "<|python_tag|>" in text:
_, action_text = text.split("<|python_tag|>") normal_text, action_text = text.split("<|python_tag|>")
else: else:
action_text = text normal_text, action_text = "", text
# Split by semicolon and process each part # Split by semicolon and process each part
json_parts = [part.strip() for part in action_text.split(";") if part.strip()] json_parts = [part.strip() for part in action_text.split(";") if part.strip()]
...@@ -449,6 +470,13 @@ class Llama32Detector(BaseFormatDetector): ...@@ -449,6 +470,13 @@ class Llama32Detector(BaseFormatDetector):
calls = self.parse_base_json(all_actions, tools) calls = self.parse_base_json(all_actions, tools)
return StreamingParseResult(normal_text=normal_text, calls=calls) return StreamingParseResult(normal_text=normal_text, calls=calls)
def structure_info(self) -> _GetInfoFunc:
return lambda name: StructureInfo(
begin='<|python_tag|>{"name":"' + name + '", "arguments":',
end="}",
trigger="<|python_tag|>",
)
class MultiFormatParser: class MultiFormatParser:
def __init__(self, detectors: List[BaseFormatDetector]): def __init__(self, detectors: List[BaseFormatDetector]):
...@@ -458,7 +486,7 @@ class MultiFormatParser: ...@@ -458,7 +486,7 @@ class MultiFormatParser:
self.detectors = detectors self.detectors = detectors
def parse_once( def parse_once(
self, text: str, tools: List[Function] self, text: str, tools: List[Tool]
) -> Tuple[str, list[ToolCallItem]]: ) -> Tuple[str, list[ToolCallItem]]:
""" """
One-time parsing: Loop through detectors until there are no new matches or text is exhausted One-time parsing: Loop through detectors until there are no new matches or text is exhausted
...@@ -480,7 +508,7 @@ class MultiFormatParser: ...@@ -480,7 +508,7 @@ class MultiFormatParser:
return final_normal_text, final_calls return final_normal_text, final_calls
def parse_streaming_increment( def parse_streaming_increment(
self, new_text: str, tools: List[Function] self, new_text: str, tools: List[Tool]
) -> Tuple[str, list[ToolCallItem]]: ) -> Tuple[str, list[ToolCallItem]]:
""" """
Streaming incremental parsing: Feed new_text to each detector's parse_streaming_increment Streaming incremental parsing: Feed new_text to each detector's parse_streaming_increment
...@@ -512,13 +540,13 @@ class FunctionCallParser: ...@@ -512,13 +540,13 @@ class FunctionCallParser:
and returns the resulting normal_text and calls to the upper layer (or SSE). and returns the resulting normal_text and calls to the upper layer (or SSE).
""" """
ToolCallParserEnum: Dict[str, BaseFormatDetector] = { ToolCallParserEnum: Dict[str, Type[BaseFormatDetector]] = {
"llama3": Llama32Detector, "llama3": Llama32Detector,
"qwen25": Qwen25Detector, "qwen25": Qwen25Detector,
"mistral": MistralDetector, "mistral": MistralDetector,
} }
def __init__(self, tools: List[Function], tool_call_parser: str = None): def __init__(self, tools: List[Tool], tool_call_parser: str):
detectors = [] detectors = []
if tool_call_parser: if tool_call_parser:
detector_class = self.ToolCallParserEnum.get(tool_call_parser) detector_class = self.ToolCallParserEnum.get(tool_call_parser)
...@@ -563,3 +591,40 @@ class FunctionCallParser: ...@@ -563,3 +591,40 @@ class FunctionCallParser:
chunk_text, self.tools chunk_text, self.tools
) )
return normal_text, calls return normal_text, calls
def structure_infos(self) -> List[_GetInfoFunc]:
"""
Returns a list of structure_info functions for each detector
"""
return [
detector.structure_info() for detector in self.multi_format_parser.detectors
]
def get_structure_tag(self) -> StructuralTagResponseFormat:
tool_structures: List[StructuresResponseFormat] = list()
tool_trigger_set: Set[str] = set()
for wrapper in self.structure_infos():
for tool in self.tools:
function = tool.function
name = function.name
assert name is not None
info = wrapper(name)
# accept all if not strict, otherwise only accept the schema
schema = function.parameters if function.strict else {}
tool_structures.append(
StructuresResponseFormat(
begin=info.begin,
schema=schema, # type: ignore
end=info.end,
)
)
tool_trigger_set.add(info.trigger)
return StructuralTagResponseFormat(
type="structural_tag",
structures=tool_structures,
triggers=list(tool_trigger_set),
)
...@@ -20,7 +20,7 @@ import os ...@@ -20,7 +20,7 @@ import os
import time import time
import uuid import uuid
from http import HTTPStatus from http import HTTPStatus
from typing import Dict, List from typing import Any, Dict, List, Set
from fastapi import HTTPException, Request, UploadFile from fastapi import HTTPException, Request, UploadFile
from fastapi.responses import ORJSONResponse, StreamingResponse from fastapi.responses import ORJSONResponse, StreamingResponse
...@@ -38,7 +38,7 @@ from sglang.srt.conversation import ( ...@@ -38,7 +38,7 @@ from sglang.srt.conversation import (
generate_embedding_convs, generate_embedding_convs,
register_conv_template, register_conv_template,
) )
from sglang.srt.function_call_parser import TOOLS_TAG_LIST, FunctionCallParser from sglang.srt.function_call_parser import FunctionCallParser
from sglang.srt.managers.io_struct import EmbeddingReqInput, GenerateReqInput from sglang.srt.managers.io_struct import EmbeddingReqInput, GenerateReqInput
from sglang.srt.openai_api.protocol import ( from sglang.srt.openai_api.protocol import (
BatchRequest, BatchRequest,
...@@ -915,6 +915,7 @@ def v1_chat_generate_request( ...@@ -915,6 +915,7 @@ def v1_chat_generate_request(
# - image_data: None or a list of image strings (URLs or base64 strings). # - image_data: None or a list of image strings (URLs or base64 strings).
# - audio_data: None or a list of audio strings (URLs). # - audio_data: None or a list of audio strings (URLs).
# None skips any image processing in GenerateReqInput. # None skips any image processing in GenerateReqInput.
strict_tag = None
if not isinstance(request.messages, str): if not isinstance(request.messages, str):
# Apply chat template and its stop strings. # Apply chat template and its stop strings.
tools = None tools = None
...@@ -929,6 +930,10 @@ def v1_chat_generate_request( ...@@ -929,6 +930,10 @@ def v1_chat_generate_request(
else: else:
tools = [item.function.model_dump() for item in request.tools] tools = [item.function.model_dump() for item in request.tools]
tool_call_parser = tokenizer_manager.server_args.tool_call_parser
parser = FunctionCallParser(request.tools, tool_call_parser)
strict_tag = parser.get_structure_tag()
if chat_template_name is None: if chat_template_name is None:
openai_compatible_messages = [] openai_compatible_messages = []
for message in request.messages: for message in request.messages:
...@@ -1036,6 +1041,22 @@ def v1_chat_generate_request( ...@@ -1036,6 +1041,22 @@ def v1_chat_generate_request(
sampling_params["structural_tag"] = convert_json_schema_to_str( sampling_params["structural_tag"] = convert_json_schema_to_str(
request.response_format.model_dump(by_alias=True) request.response_format.model_dump(by_alias=True)
) )
if strict_tag is not None:
if (
sampling_params.get("regex")
or sampling_params.get("ebnf")
or sampling_params.get("structural_tag")
or sampling_params.get("json_schema")
):
logger.warning(
"Constrained decoding is not compatible with tool calls."
)
else:
sampling_params["structural_tag"] = convert_json_schema_to_str(
strict_tag.model_dump(by_alias=True)
)
sampling_params_list.append(sampling_params) sampling_params_list.append(sampling_params)
image_data_list.append(image_data) image_data_list.append(image_data)
......
...@@ -287,6 +287,7 @@ class Function(BaseModel): ...@@ -287,6 +287,7 @@ class Function(BaseModel):
description: Optional[str] = Field(default=None, examples=[None]) description: Optional[str] = Field(default=None, examples=[None])
name: Optional[str] = None name: Optional[str] = None
parameters: Optional[object] = None parameters: Optional[object] = None
strict: bool = False
class Tool(BaseModel): class Tool(BaseModel):
......
...@@ -237,12 +237,61 @@ class TestOpenAIServerFunctionCalling(unittest.TestCase): ...@@ -237,12 +237,61 @@ class TestOpenAIServerFunctionCalling(unittest.TestCase):
self.assertIn("a", args_obj, "Missing parameter 'a'") self.assertIn("a", args_obj, "Missing parameter 'a'")
self.assertIn("b", args_obj, "Missing parameter 'b'") self.assertIn("b", args_obj, "Missing parameter 'b'")
self.assertEqual( self.assertEqual(str(args_obj["a"]), "5", "Parameter a should be 5")
args_obj["a"], self.assertEqual(str(args_obj["b"]), "7", "Parameter b should be 7")
5,
"Parameter a should be 5", def test_function_call_strict(self):
"""
Test: Whether the strict mode of function calling works as expected.
- When strict mode is enabled, the AI should not return a function call if the function name is not recognized.
"""
client = openai.Client(api_key=self.api_key, base_url=self.base_url)
tools = [
{
"type": "function",
"function": {
"name": "sub",
"description": "Compute the difference of two integers",
"parameters": {
"type": "object",
"properties": {
"int_a": {
"type": "int",
"description": "First integer",
},
"int_b": {
"type": "int",
"description": "Second integer",
},
},
"required": ["int_a", "int_b"],
},
"strict": True,
},
}
]
messages = [
{"role": "user", "content": "Please compute 5 - 7, using your tool."}
]
response = client.chat.completions.create(
model=self.model,
messages=messages,
temperature=0.8,
top_p=0.8,
stream=False,
tools=tools,
) )
self.assertEqual(args_obj["b"], 7, "Parameter b should be 7")
tool_calls = response.choices[0].message.tool_calls
function_name = tool_calls[0].function.name
arguments = tool_calls[0].function.arguments
args_obj = json.loads(arguments)
self.assertEqual(function_name, "sub", "Function name should be 'sub'")
self.assertEqual(str(args_obj["int_a"]), "5", "Parameter int_a should be 5")
self.assertEqual(str(args_obj["int_b"]), "7", "Parameter int_b should be 7")
if __name__ == "__main__": if __name__ == "__main__":
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment