Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
change
sglang
Commits
ac3fae84
Unverified
Commit
ac3fae84
authored
Mar 25, 2025
by
DarkSharpness
Committed by
GitHub
Mar 24, 2025
Browse files
[Feature] Support "strict" in function calling (#4310)
parent
2d1b83e5
Changes
4
Hide whitespace changes
Inline
Side-by-side
Showing
4 changed files
with
188 additions
and
52 deletions
+188
-52
python/sglang/srt/function_call_parser.py
python/sglang/srt/function_call_parser.py
+110
-45
python/sglang/srt/openai_api/adapter.py
python/sglang/srt/openai_api/adapter.py
+23
-2
python/sglang/srt/openai_api/protocol.py
python/sglang/srt/openai_api/protocol.py
+1
-0
test/srt/test_function_calling.py
test/srt/test_function_calling.py
+54
-5
No files found.
python/sglang/srt/function_call_parser.py
View file @
ac3fae84
import
json
import
logging
import
re
from
abc
import
ABC
,
abstractmethod
from
dataclasses
import
dataclass
from
json
import
JSONDecodeError
,
JSONDecoder
from
typing
import
Any
,
Dict
,
List
,
Optional
,
Tupl
e
from
typing
import
Any
,
Callable
,
Dict
,
List
,
Optional
,
Set
,
Tuple
,
Typ
e
import
partial_json_parser
from
partial_json_parser.core.exceptions
import
MalformedJSON
from
partial_json_parser.core.options
import
Allow
from
pydantic
import
BaseModel
,
Field
from
pydantic
import
BaseModel
from
sglang.srt.openai_api.protocol
import
(
StructuralTagResponseFormat
,
StructuresResponseFormat
,
Tool
,
)
logger
=
logging
.
getLogger
(
__name__
)
...
...
@@ -19,14 +28,6 @@ TOOLS_TAG_LIST = [
]
class
Function
(
BaseModel
):
"""Function Tool Template."""
description
:
Optional
[
str
]
=
Field
(
default
=
None
,
examples
=
[
None
])
name
:
Optional
[
str
]
=
None
parameters
:
Optional
[
object
]
=
None
class
ToolCallItem
(
BaseModel
):
"""Simple encapsulation of the parsed ToolCall result for easier usage in streaming contexts."""
...
...
@@ -74,7 +75,22 @@ class StreamingParseResult:
self
.
calls
=
calls
or
[]
class
BaseFormatDetector
:
@
dataclass
class
StructureInfo
:
begin
:
str
end
:
str
trigger
:
str
_GetInfoFunc
=
Callable
[[
str
],
StructureInfo
]
"""
helper alias of function
ususally it is a function that takes a name string and returns a StructureInfo object,
which can be used to construct a structural_tag object
"""
class
BaseFormatDetector
(
ABC
):
"""Base class providing two sets of interfaces: one-time and streaming incremental."""
def
__init__
(
self
):
...
...
@@ -90,26 +106,12 @@ class BaseFormatDetector:
self
.
bot_token
=
""
self
.
eot_token
=
""
def
parse_base_json
(
self
,
action
:
Any
,
tools
:
List
[
Function
])
->
List
[
ToolCallItem
]:
def
parse_base_json
(
self
,
action
:
Any
,
tools
:
List
[
Tool
])
->
List
[
ToolCallItem
]:
tool_indices
=
{
tool
.
function
.
name
:
i
for
i
,
tool
in
enumerate
(
tools
)
if
tool
.
function
.
name
}
if
not
isinstance
(
action
,
list
):
name
=
action
.
get
(
"name"
)
if
not
name
or
name
not
in
tool_indices
:
logger
.
warning
(
f
"Model attempted to call undefined function:
{
name
}
"
)
return
[]
return
[
ToolCallItem
(
tool_index
=
tool_indices
[
name
],
name
=
name
,
parameters
=
json
.
dumps
(
action
.
get
(
"parameters"
)
or
action
.
get
(
"arguments"
,
{}),
ensure_ascii
=
False
,
),
)
]
action
=
[
action
]
results
=
[]
for
act
in
action
:
...
...
@@ -125,12 +127,13 @@ class BaseFormatDetector:
),
)
)
else
:
logger
.
warning
(
f
"Model attempted to call undefined function:
{
name
}
"
)
return
results
def
detect_and_parse
(
self
,
text
:
str
,
tools
:
List
[
Function
]
)
->
StreamingParseResult
:
@
abstractmethod
def
detect_and_parse
(
self
,
text
:
str
,
tools
:
List
[
Tool
])
->
StreamingParseResult
:
"""
Parses the text in one go. Returns success=True if the format matches, otherwise False.
Note that leftover_text here represents "content that this parser will not consume further".
...
...
@@ -139,7 +142,7 @@ class BaseFormatDetector:
return
StreamingParseResult
(
calls
=
self
.
parse_base_json
(
action
,
tools
))
def
parse_streaming_increment
(
self
,
new_text
:
str
,
tools
:
List
[
Function
]
self
,
new_text
:
str
,
tools
:
List
[
Tool
]
)
->
StreamingParseResult
:
"""
Streaming incremental parsing with tool validation.
...
...
@@ -198,7 +201,7 @@ class BaseFormatDetector:
obj
[
"arguments"
]
=
obj
[
"parameters"
]
tool_call_arr
.
append
(
obj
)
except
partial_json_parser
.
core
.
exceptions
.
MalformedJSON
:
except
MalformedJSON
:
return
StreamingParseResult
()
if
len
(
tool_call_arr
)
==
0
:
...
...
@@ -304,6 +307,14 @@ class BaseFormatDetector:
logger
.
error
(
f
"Error in parse_streaming_increment:
{
e
}
"
)
return
StreamingParseResult
()
@
abstractmethod
def
has_tool_call
(
self
,
text
:
str
)
->
bool
:
raise
NotImplementedError
()
@
abstractmethod
def
structure_info
(
self
)
->
_GetInfoFunc
:
raise
NotImplementedError
()
class
Qwen25Detector
(
BaseFormatDetector
):
"""
...
...
@@ -324,9 +335,7 @@ class Qwen25Detector(BaseFormatDetector):
"""Check if the text contains a Qwen 2.5 format tool call."""
return
self
.
bot_token
in
text
def
detect_and_parse
(
self
,
text
:
str
,
tools
:
List
[
Function
]
)
->
StreamingParseResult
:
def
detect_and_parse
(
self
,
text
:
str
,
tools
:
List
[
Tool
])
->
StreamingParseResult
:
"""
One-time parsing: Detects and parses tool calls in the provided text.
...
...
@@ -346,6 +355,13 @@ class Qwen25Detector(BaseFormatDetector):
calls
.
extend
(
self
.
parse_base_json
(
match_result
,
tools
))
return
StreamingParseResult
(
normal_text
=
normal_text
,
calls
=
calls
)
def
structure_info
(
self
)
->
_GetInfoFunc
:
return
lambda
name
:
StructureInfo
(
begin
=
'<tool_call>{"name":"'
+
name
+
'", "arguments":'
,
end
=
"}</tool_call>"
,
trigger
=
"<tool_call>"
,
)
class
MistralDetector
(
BaseFormatDetector
):
"""
...
...
@@ -380,9 +396,7 @@ class MistralDetector(BaseFormatDetector):
else
:
return
""
def
detect_and_parse
(
self
,
text
:
str
,
tools
:
List
[
Function
]
)
->
StreamingParseResult
:
def
detect_and_parse
(
self
,
text
:
str
,
tools
:
List
[
Tool
])
->
StreamingParseResult
:
"""
One-time parsing: Detects and parses tool calls in the provided text.
...
...
@@ -403,6 +417,13 @@ class MistralDetector(BaseFormatDetector):
calls
.
extend
(
self
.
parse_base_json
(
match_result
,
tools
))
return
StreamingParseResult
(
normal_text
=
normal_text
,
calls
=
calls
)
def
structure_info
(
self
)
->
_GetInfoFunc
:
return
lambda
name
:
StructureInfo
(
begin
=
'[TOOL_CALLS] [{"name":"'
+
name
+
'", "arguments":'
,
end
=
"}]"
,
trigger
=
"[TOOL_CALLS]"
,
)
class
Llama32Detector
(
BaseFormatDetector
):
"""
...
...
@@ -421,15 +442,15 @@ class Llama32Detector(BaseFormatDetector):
# prefix the output with the <|python_tag|> token
return
"<|python_tag|>"
in
text
or
text
.
startswith
(
"{"
)
def
detect_and_parse
(
self
,
text
:
str
,
tools
:
List
[
Function
])
->
List
[
ToolCallItem
]
:
def
detect_and_parse
(
self
,
text
:
str
,
tools
:
List
[
Tool
])
->
StreamingParseResult
:
"""Parse function calls from text, handling multiple JSON objects."""
if
"<|python_tag|>"
not
in
text
and
not
text
.
startswith
(
"{"
):
return
StreamingParseResult
(
normal_text
=
text
,
calls
=
[])
if
"<|python_tag|>"
in
text
:
_
,
action_text
=
text
.
split
(
"<|python_tag|>"
)
normal_text
,
action_text
=
text
.
split
(
"<|python_tag|>"
)
else
:
action_text
=
text
normal_text
,
action_text
=
""
,
text
# Split by semicolon and process each part
json_parts
=
[
part
.
strip
()
for
part
in
action_text
.
split
(
";"
)
if
part
.
strip
()]
...
...
@@ -449,6 +470,13 @@ class Llama32Detector(BaseFormatDetector):
calls
=
self
.
parse_base_json
(
all_actions
,
tools
)
return
StreamingParseResult
(
normal_text
=
normal_text
,
calls
=
calls
)
def
structure_info
(
self
)
->
_GetInfoFunc
:
return
lambda
name
:
StructureInfo
(
begin
=
'<|python_tag|>{"name":"'
+
name
+
'", "arguments":'
,
end
=
"}"
,
trigger
=
"<|python_tag|>"
,
)
class
MultiFormatParser
:
def
__init__
(
self
,
detectors
:
List
[
BaseFormatDetector
]):
...
...
@@ -458,7 +486,7 @@ class MultiFormatParser:
self
.
detectors
=
detectors
def
parse_once
(
self
,
text
:
str
,
tools
:
List
[
Function
]
self
,
text
:
str
,
tools
:
List
[
Tool
]
)
->
Tuple
[
str
,
list
[
ToolCallItem
]]:
"""
One-time parsing: Loop through detectors until there are no new matches or text is exhausted
...
...
@@ -480,7 +508,7 @@ class MultiFormatParser:
return
final_normal_text
,
final_calls
def
parse_streaming_increment
(
self
,
new_text
:
str
,
tools
:
List
[
Function
]
self
,
new_text
:
str
,
tools
:
List
[
Tool
]
)
->
Tuple
[
str
,
list
[
ToolCallItem
]]:
"""
Streaming incremental parsing: Feed new_text to each detector's parse_streaming_increment
...
...
@@ -512,13 +540,13 @@ class FunctionCallParser:
and returns the resulting normal_text and calls to the upper layer (or SSE).
"""
ToolCallParserEnum
:
Dict
[
str
,
BaseFormatDetector
]
=
{
ToolCallParserEnum
:
Dict
[
str
,
Type
[
BaseFormatDetector
]
]
=
{
"llama3"
:
Llama32Detector
,
"qwen25"
:
Qwen25Detector
,
"mistral"
:
MistralDetector
,
}
def
__init__
(
self
,
tools
:
List
[
Function
],
tool_call_parser
:
str
=
None
):
def
__init__
(
self
,
tools
:
List
[
Tool
],
tool_call_parser
:
str
):
detectors
=
[]
if
tool_call_parser
:
detector_class
=
self
.
ToolCallParserEnum
.
get
(
tool_call_parser
)
...
...
@@ -563,3 +591,40 @@ class FunctionCallParser:
chunk_text
,
self
.
tools
)
return
normal_text
,
calls
def
structure_infos
(
self
)
->
List
[
_GetInfoFunc
]:
"""
Returns a list of structure_info functions for each detector
"""
return
[
detector
.
structure_info
()
for
detector
in
self
.
multi_format_parser
.
detectors
]
def
get_structure_tag
(
self
)
->
StructuralTagResponseFormat
:
tool_structures
:
List
[
StructuresResponseFormat
]
=
list
()
tool_trigger_set
:
Set
[
str
]
=
set
()
for
wrapper
in
self
.
structure_infos
():
for
tool
in
self
.
tools
:
function
=
tool
.
function
name
=
function
.
name
assert
name
is
not
None
info
=
wrapper
(
name
)
# accept all if not strict, otherwise only accept the schema
schema
=
function
.
parameters
if
function
.
strict
else
{}
tool_structures
.
append
(
StructuresResponseFormat
(
begin
=
info
.
begin
,
schema
=
schema
,
# type: ignore
end
=
info
.
end
,
)
)
tool_trigger_set
.
add
(
info
.
trigger
)
return
StructuralTagResponseFormat
(
type
=
"structural_tag"
,
structures
=
tool_structures
,
triggers
=
list
(
tool_trigger_set
),
)
python/sglang/srt/openai_api/adapter.py
View file @
ac3fae84
...
...
@@ -20,7 +20,7 @@ import os
import
time
import
uuid
from
http
import
HTTPStatus
from
typing
import
Dict
,
List
from
typing
import
Any
,
Dict
,
List
,
Set
from
fastapi
import
HTTPException
,
Request
,
UploadFile
from
fastapi.responses
import
ORJSONResponse
,
StreamingResponse
...
...
@@ -38,7 +38,7 @@ from sglang.srt.conversation import (
generate_embedding_convs
,
register_conv_template
,
)
from
sglang.srt.function_call_parser
import
TOOLS_TAG_LIST
,
FunctionCallParser
from
sglang.srt.function_call_parser
import
FunctionCallParser
from
sglang.srt.managers.io_struct
import
EmbeddingReqInput
,
GenerateReqInput
from
sglang.srt.openai_api.protocol
import
(
BatchRequest
,
...
...
@@ -915,6 +915,7 @@ def v1_chat_generate_request(
# - image_data: None or a list of image strings (URLs or base64 strings).
# - audio_data: None or a list of audio strings (URLs).
# None skips any image processing in GenerateReqInput.
strict_tag
=
None
if
not
isinstance
(
request
.
messages
,
str
):
# Apply chat template and its stop strings.
tools
=
None
...
...
@@ -929,6 +930,10 @@ def v1_chat_generate_request(
else
:
tools
=
[
item
.
function
.
model_dump
()
for
item
in
request
.
tools
]
tool_call_parser
=
tokenizer_manager
.
server_args
.
tool_call_parser
parser
=
FunctionCallParser
(
request
.
tools
,
tool_call_parser
)
strict_tag
=
parser
.
get_structure_tag
()
if
chat_template_name
is
None
:
openai_compatible_messages
=
[]
for
message
in
request
.
messages
:
...
...
@@ -1036,6 +1041,22 @@ def v1_chat_generate_request(
sampling_params
[
"structural_tag"
]
=
convert_json_schema_to_str
(
request
.
response_format
.
model_dump
(
by_alias
=
True
)
)
if
strict_tag
is
not
None
:
if
(
sampling_params
.
get
(
"regex"
)
or
sampling_params
.
get
(
"ebnf"
)
or
sampling_params
.
get
(
"structural_tag"
)
or
sampling_params
.
get
(
"json_schema"
)
):
logger
.
warning
(
"Constrained decoding is not compatible with tool calls."
)
else
:
sampling_params
[
"structural_tag"
]
=
convert_json_schema_to_str
(
strict_tag
.
model_dump
(
by_alias
=
True
)
)
sampling_params_list
.
append
(
sampling_params
)
image_data_list
.
append
(
image_data
)
...
...
python/sglang/srt/openai_api/protocol.py
View file @
ac3fae84
...
...
@@ -287,6 +287,7 @@ class Function(BaseModel):
description
:
Optional
[
str
]
=
Field
(
default
=
None
,
examples
=
[
None
])
name
:
Optional
[
str
]
=
None
parameters
:
Optional
[
object
]
=
None
strict
:
bool
=
False
class
Tool
(
BaseModel
):
...
...
test/srt/test_function_calling.py
View file @
ac3fae84
...
...
@@ -237,12 +237,61 @@ class TestOpenAIServerFunctionCalling(unittest.TestCase):
self
.
assertIn
(
"a"
,
args_obj
,
"Missing parameter 'a'"
)
self
.
assertIn
(
"b"
,
args_obj
,
"Missing parameter 'b'"
)
self
.
assertEqual
(
args_obj
[
"a"
],
5
,
"Parameter a should be 5"
,
self
.
assertEqual
(
str
(
args_obj
[
"a"
]),
"5"
,
"Parameter a should be 5"
)
self
.
assertEqual
(
str
(
args_obj
[
"b"
]),
"7"
,
"Parameter b should be 7"
)
def
test_function_call_strict
(
self
):
"""
Test: Whether the strict mode of function calling works as expected.
- When strict mode is enabled, the AI should not return a function call if the function name is not recognized.
"""
client
=
openai
.
Client
(
api_key
=
self
.
api_key
,
base_url
=
self
.
base_url
)
tools
=
[
{
"type"
:
"function"
,
"function"
:
{
"name"
:
"sub"
,
"description"
:
"Compute the difference of two integers"
,
"parameters"
:
{
"type"
:
"object"
,
"properties"
:
{
"int_a"
:
{
"type"
:
"int"
,
"description"
:
"First integer"
,
},
"int_b"
:
{
"type"
:
"int"
,
"description"
:
"Second integer"
,
},
},
"required"
:
[
"int_a"
,
"int_b"
],
},
"strict"
:
True
,
},
}
]
messages
=
[
{
"role"
:
"user"
,
"content"
:
"Please compute 5 - 7, using your tool."
}
]
response
=
client
.
chat
.
completions
.
create
(
model
=
self
.
model
,
messages
=
messages
,
temperature
=
0.8
,
top_p
=
0.8
,
stream
=
False
,
tools
=
tools
,
)
self
.
assertEqual
(
args_obj
[
"b"
],
7
,
"Parameter b should be 7"
)
tool_calls
=
response
.
choices
[
0
].
message
.
tool_calls
function_name
=
tool_calls
[
0
].
function
.
name
arguments
=
tool_calls
[
0
].
function
.
arguments
args_obj
=
json
.
loads
(
arguments
)
self
.
assertEqual
(
function_name
,
"sub"
,
"Function name should be 'sub'"
)
self
.
assertEqual
(
str
(
args_obj
[
"int_a"
]),
"5"
,
"Parameter int_a should be 5"
)
self
.
assertEqual
(
str
(
args_obj
[
"int_b"
]),
"7"
,
"Parameter int_b should be 7"
)
if
__name__
==
"__main__"
:
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment