Unverified Commit f18b068f authored by Chang Su's avatar Chang Su Committed by GitHub
Browse files

feat(tool call): Enhance Llama32Detector for improved JSON parsing in non-stream (#6784)

parent 4fac524b
...@@ -42,31 +42,41 @@ class Llama32Detector(BaseFormatDetector): ...@@ -42,31 +42,41 @@ class Llama32Detector(BaseFormatDetector):
return StreamingParseResult(normal_text=text, calls=[]) return StreamingParseResult(normal_text=text, calls=[])
if "<|python_tag|>" in text: if "<|python_tag|>" in text:
normal_text, action_text = text.split("<|python_tag|>") normal_text, action_text = text.split("<|python_tag|>", maxsplit=1)
else: else:
normal_text, action_text = "", text normal_text, action_text = "", text
# Split by semicolon and process each part decoder = json.JSONDecoder()
json_parts = [ idx = 0
part.strip() safe_idx = idx # the index of the last valid JSON object
for part in action_text.split(self.tool_call_separator)
if part.strip()
]
all_actions = [] all_actions = []
for part in json_parts: action_text_len = len(action_text)
while idx < action_text_len:
try: try:
# Parse each individual JSON object obj, end = decoder.raw_decode(action_text[idx:])
action = json.loads(part) all_actions.append(obj)
all_actions.append(action) idx += end + len(self.tool_call_separator)
safe_idx = idx
except json.JSONDecodeError as e: except json.JSONDecodeError as e:
logger.warning(f"Failed to parse JSON part: {part}") # Find where next `{"name"` appears and try again
logger.warning(f"JSON parse error: {str(e)}") logger.warning(
f"Failed to parse JSON part: {action_text[idx:]}, JSON parse error: {str(e)}"
)
next_obj_start = action_text.find('{"name":', idx + 1)
if next_obj_start == -1:
break
idx = next_obj_start
continue continue
calls = []
# Only process if we found valid JSON objects # Only process if we found valid JSON objects
if all_actions: calls = self.parse_base_json(all_actions, tools) if all_actions else []
calls = self.parse_base_json(all_actions, tools) # Use safe_idx to avoid idx containing the last part of an invalid JSON object
return StreamingParseResult(normal_text=normal_text, calls=calls) trailing_text = (
action_text[safe_idx:].strip() if safe_idx < action_text_len else ""
)
return StreamingParseResult(
normal_text=normal_text + trailing_text, calls=calls
)
def structure_info(self) -> _GetInfoFunc: def structure_info(self) -> _GetInfoFunc:
return lambda name: StructureInfo( return lambda name: StructureInfo(
......
...@@ -824,5 +824,101 @@ class TestBaseFormatDetector(unittest.TestCase): ...@@ -824,5 +824,101 @@ class TestBaseFormatDetector(unittest.TestCase):
) )
class TestLlama32Detector(unittest.TestCase):
    """Tests for ``Llama32Detector.detect_and_parse`` on complete (non-streamed) text."""

    @staticmethod
    def _make_city_tool(name, description):
        """Return a function ``Tool`` taking one required string argument, ``city``."""
        return Tool(
            type="function",
            function=Function(
                name=name,
                description=description,
                parameters={
                    "type": "object",
                    "properties": {
                        "city": {
                            "type": "string",
                            "description": "City name",
                        }
                    },
                    "required": ["city"],
                },
            ),
        )

    def setUp(self):
        """Set up test tools and detector for Llama 3.2 format testing."""
        self.tools = [
            self._make_city_tool("get_weather", "Get weather information"),
            self._make_city_tool(
                "get_tourist_attractions", "Get tourist attractions"
            ),
        ]
        self.detector = Llama32Detector()

    def test_single_json(self):
        """A bare JSON object with no tag yields one call and no normal text."""
        text = '{"name": "get_weather", "parameters": {"city": "Paris"}}'
        result = self.detector.detect_and_parse(text, self.tools)
        # unittest assertions (not bare ``assert``) so the checks survive
        # ``python -O`` and report the mismatching values on failure.
        self.assertEqual(len(result.calls), 1)
        self.assertEqual(result.calls[0].name, "get_weather")
        self.assertEqual(result.normal_text, "")

    def test_multiple_json_with_separator(self):
        """Two objects joined by ``;`` after the python tag yield two calls."""
        text = (
            '<|python_tag|>{"name": "get_weather", "parameters": {"city": "Paris"}};'
            '{"name": "get_tourist_attractions", "parameters": {"city": "Paris"}}'
        )
        result = self.detector.detect_and_parse(text, self.tools)
        self.assertEqual(len(result.calls), 2)
        self.assertEqual(result.calls[1].name, "get_tourist_attractions")
        self.assertEqual(result.normal_text, "")

    def test_multiple_json_with_separator_customized(self):
        """Two objects joined by a second ``<|python_tag|>`` yield two calls."""
        text = (
            '<|python_tag|>{"name": "get_weather", "parameters": {}}'
            '<|python_tag|>{"name": "get_tourist_attractions", "parameters": {}}'
        )
        result = self.detector.detect_and_parse(text, self.tools)
        self.assertEqual(len(result.calls), 2)
        self.assertEqual(result.calls[1].name, "get_tourist_attractions")
        self.assertEqual(result.normal_text, "")

    def test_json_with_trailing_text(self):
        """Text following a valid object is surfaced in ``normal_text``."""
        text = '{"name": "get_weather", "parameters": {}} Some follow-up text'
        result = self.detector.detect_and_parse(text, self.tools)
        self.assertEqual(len(result.calls), 1)
        self.assertIn("follow-up", result.normal_text)

    def test_invalid_then_valid_json(self):
        """A malformed object followed by a valid one yields only the valid call."""
        text = (
            '{"name": "get_weather", "parameters": {'  # malformed
            '{"name": "get_weather", "parameters": {}}'
        )
        result = self.detector.detect_and_parse(text, self.tools)
        self.assertEqual(len(result.calls), 1)
        self.assertEqual(result.calls[0].name, "get_weather")

    def test_plain_text_only(self):
        """Input without any tool call is returned unchanged as normal text."""
        text = "This is just plain explanation text."
        result = self.detector.detect_and_parse(text, self.tools)
        self.assertEqual(result.calls, [])
        self.assertEqual(result.normal_text, text)

    def test_with_python_tag_prefix(self):
        """Text before ``<|python_tag|>`` is kept as normal text alongside the call."""
        text = 'Some intro. <|python_tag|>{"name": "get_weather", "parameters": {}}'
        result = self.detector.detect_and_parse(text, self.tools)
        self.assertEqual(len(result.calls), 1)
        self.assertTrue(result.normal_text.strip().startswith("Some intro."))
if __name__ == "__main__": if __name__ == "__main__":
unittest.main() unittest.main()
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment