Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
change
sglang
Commits
f18b068f
Unverified
Commit
f18b068f
authored
May 30, 2025
by
Chang Su
Committed by
GitHub
May 30, 2025
Browse files
feat(tool call): Enhance Llama32Detector for improved JSON parsing in non-stream (#6784)
parent
4fac524b
Changes
2
Show whitespace changes
Inline
Side-by-side
Showing
2 changed files
with
123 additions
and
17 deletions
+123
-17
python/sglang/srt/function_call/llama32_detector.py
python/sglang/srt/function_call/llama32_detector.py
+27
-17
test/srt/test_function_call_parser.py
test/srt/test_function_call_parser.py
+96
-0
No files found.
python/sglang/srt/function_call/llama32_detector.py
View file @
f18b068f
...
@@ -42,31 +42,41 @@ class Llama32Detector(BaseFormatDetector):
...
@@ -42,31 +42,41 @@ class Llama32Detector(BaseFormatDetector):
return
StreamingParseResult
(
normal_text
=
text
,
calls
=
[])
return
StreamingParseResult
(
normal_text
=
text
,
calls
=
[])
if
"<|python_tag|>"
in
text
:
if
"<|python_tag|>"
in
text
:
normal_text
,
action_text
=
text
.
split
(
"<|python_tag|>"
)
normal_text
,
action_text
=
text
.
split
(
"<|python_tag|>"
,
maxsplit
=
1
)
else
:
else
:
normal_text
,
action_text
=
""
,
text
normal_text
,
action_text
=
""
,
text
# Split by semicolon and process each part
decoder
=
json
.
JSONDecoder
()
json_parts
=
[
idx
=
0
part
.
strip
()
safe_idx
=
idx
# the index of the last valid JSON object
for
part
in
action_text
.
split
(
self
.
tool_call_separator
)
if
part
.
strip
()
]
all_actions
=
[]
all_actions
=
[]
for
part
in
json_parts
:
action_text_len
=
len
(
action_text
)
while
idx
<
action_text_len
:
try
:
try
:
# Parse each individual JSON object
obj
,
end
=
decoder
.
raw_decode
(
action_text
[
idx
:])
action
=
json
.
loads
(
part
)
all_actions
.
append
(
obj
)
all_actions
.
append
(
action
)
idx
+=
end
+
len
(
self
.
tool_call_separator
)
safe_idx
=
idx
except
json
.
JSONDecodeError
as
e
:
except
json
.
JSONDecodeError
as
e
:
logger
.
warning
(
f
"Failed to parse JSON part:
{
part
}
"
)
# Find where next `{"name"` appears and try again
logger
.
warning
(
f
"JSON parse error:
{
str
(
e
)
}
"
)
logger
.
warning
(
f
"Failed to parse JSON part:
{
action_text
[
idx
:]
}
, JSON parse error:
{
str
(
e
)
}
"
)
next_obj_start
=
action_text
.
find
(
'{"name":'
,
idx
+
1
)
if
next_obj_start
==
-
1
:
break
idx
=
next_obj_start
continue
continue
calls
=
[]
# Only process if we found valid JSON objects
# Only process if we found valid JSON objects
if
all_actions
:
calls
=
self
.
parse_base_json
(
all_actions
,
tools
)
if
all_actions
else
[]
calls
=
self
.
parse_base_json
(
all_actions
,
tools
)
# Use safe_idx to avoid idx containing the last part of an invalid JSON object
return
StreamingParseResult
(
normal_text
=
normal_text
,
calls
=
calls
)
trailing_text
=
(
action_text
[
safe_idx
:].
strip
()
if
safe_idx
<
action_text_len
else
""
)
return
StreamingParseResult
(
normal_text
=
normal_text
+
trailing_text
,
calls
=
calls
)
def
structure_info
(
self
)
->
_GetInfoFunc
:
def
structure_info
(
self
)
->
_GetInfoFunc
:
return
lambda
name
:
StructureInfo
(
return
lambda
name
:
StructureInfo
(
...
...
test/srt/test_function_call_parser.py
View file @
f18b068f
...
@@ -824,5 +824,101 @@ class TestBaseFormatDetector(unittest.TestCase):
...
@@ -824,5 +824,101 @@ class TestBaseFormatDetector(unittest.TestCase):
)
)
class TestLlama32Detector(unittest.TestCase):
    """Non-streaming parsing tests for ``Llama32Detector.detect_and_parse``."""

    def setUp(self):
        """Set up test tools and detector for Llama 3.2 format testing."""
        # Both tools share the same single-required-"city" parameter schema.
        city_parameters = {
            "type": "object",
            "properties": {
                "city": {
                    "type": "string",
                    "description": "City name",
                }
            },
            "required": ["city"],
        }
        self.tools = [
            Tool(
                type="function",
                function=Function(
                    name="get_weather",
                    description="Get weather information",
                    parameters=dict(city_parameters),
                ),
            ),
            Tool(
                type="function",
                function=Function(
                    name="get_tourist_attractions",
                    description="Get tourist attractions",
                    parameters=dict(city_parameters),
                ),
            ),
        ]
        self.detector = Llama32Detector()

    def test_single_json(self):
        """A bare JSON object yields one call and no normal text."""
        text = '{"name": "get_weather", "parameters": {"city": "Paris"}}'
        result = self.detector.detect_and_parse(text, self.tools)
        # unittest assertions (not bare ``assert``) so the checks survive
        # ``python -O`` and report both operands on failure.
        self.assertEqual(len(result.calls), 1)
        self.assertEqual(result.calls[0].name, "get_weather")
        self.assertEqual(result.normal_text, "")

    def test_multiple_json_with_separator(self):
        """Two ``;``-separated objects after the python tag yield two calls."""
        text = (
            '<|python_tag|>{"name": "get_weather", "parameters": {"city": "Paris"}};'
            '{"name": "get_tourist_attractions", "parameters": {"city": "Paris"}}'
        )
        result = self.detector.detect_and_parse(text, self.tools)
        self.assertEqual(len(result.calls), 2)
        self.assertEqual(result.calls[1].name, "get_tourist_attractions")
        self.assertEqual(result.normal_text, "")

    def test_multiple_json_with_separator_customized(self):
        """Objects separated by a repeated ``<|python_tag|>`` yield two calls."""
        text = (
            '<|python_tag|>{"name": "get_weather", "parameters": {}}'
            '<|python_tag|>{"name": "get_tourist_attractions", "parameters": {}}'
        )
        result = self.detector.detect_and_parse(text, self.tools)
        self.assertEqual(len(result.calls), 2)
        self.assertEqual(result.calls[1].name, "get_tourist_attractions")
        self.assertEqual(result.normal_text, "")

    def test_json_with_trailing_text(self):
        """Text following a valid call is reported in ``normal_text``."""
        text = '{"name": "get_weather", "parameters": {}} Some follow-up text'
        result = self.detector.detect_and_parse(text, self.tools)
        self.assertEqual(len(result.calls), 1)
        self.assertIn("follow-up", result.normal_text)

    def test_invalid_then_valid_json(self):
        """A malformed object is skipped and the following valid one is parsed."""
        text = (
            '{"name": "get_weather", "parameters": {'  # malformed
            '{"name": "get_weather", "parameters": {}}'
        )
        result = self.detector.detect_and_parse(text, self.tools)
        self.assertEqual(len(result.calls), 1)
        self.assertEqual(result.calls[0].name, "get_weather")

    def test_plain_text_only(self):
        """Plain text produces no calls and is returned unchanged."""
        text = "This is just plain explanation text."
        result = self.detector.detect_and_parse(text, self.tools)
        self.assertEqual(result.calls, [])
        self.assertEqual(result.normal_text, text)

    def test_with_python_tag_prefix(self):
        """Text before ``<|python_tag|>`` is kept as the leading normal text."""
        text = 'Some intro. <|python_tag|>{"name": "get_weather", "parameters": {}}'
        result = self.detector.detect_and_parse(text, self.tools)
        self.assertEqual(len(result.calls), 1)
        self.assertTrue(result.normal_text.strip().startswith("Some intro."))
if
__name__
==
"__main__"
:
if
__name__
==
"__main__"
:
unittest
.
main
()
unittest
.
main
()
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment