"vscode:/vscode.git/clone" did not exist on "4a408d5e076451df6eed2fa44646062eec86c611"
Unverified Commit 44da7377 authored by soaringk's avatar soaringk Committed by GitHub
Browse files

[fix] Handle escaped characters in GLM tool call parser to prevent double serialization (#12456)

parent fb9582c4
......@@ -24,13 +24,23 @@ def get_argument_type(func_name: str, arg_key: str, defined_tools: list):
def parse_arguments(json_value):
try:
try:
parsed_value = json.loads(json_value)
except:
parsed_value = ast.literal_eval(json_value)
parsed_value = json.loads(json_value)
return parsed_value, True
except:
return json_value, False
# If that fails, try wrapping it to unescape JSON characters
try:
# Wrap the value as a JSON string field
wrapped = json.loads('{"tmp": "' + json_value + '"}')
# parse the unescaped value
parsed_value = json.loads(wrapped["tmp"])
return parsed_value, True
except:
# Final fallback to ast.literal_eval
try:
parsed_value = ast.literal_eval(json_value)
return parsed_value, True
except:
return json_value, False
class Glm4MoeDetector(BaseFormatDetector):
......@@ -45,8 +55,13 @@ class Glm4MoeDetector(BaseFormatDetector):
self.bot_token = "<tool_call>"
self.eot_token = "</tool_call>"
self.func_call_regex = r"<tool_call>.*?</tool_call>"
self.func_detail_regex = r"<tool_call>([^\n]*)\n(.*)</tool_call>"
self.func_arg_regex = r"<arg_key>(.*?)</arg_key>\s*<arg_value>(.*?)</arg_value>"
self.func_detail_regex = re.compile(
r"<tool_call>(.*?)(?:\\n|\n)(.*)</tool_call>", re.DOTALL
)
self.func_arg_regex = re.compile(
r"<arg_key>(.*?)</arg_key>(?:\\n|\s)*<arg_value>(.*?)</arg_value>",
re.DOTALL,
)
def has_tool_call(self, text: str) -> bool:
"""Check if the text contains a glm-4.5 / glm-4.6 format tool call."""
......@@ -69,14 +84,10 @@ class Glm4MoeDetector(BaseFormatDetector):
try:
for match_result in match_result_list:
# Get function name
func_detail = re.search(self.func_detail_regex, match_result, re.DOTALL)
func_detail = self.func_detail_regex.search(match_result)
func_name = func_detail.group(1)
func_args = func_detail.group(2)
pairs = re.findall(
r"<arg_key>(.*?)</arg_key>\s*<arg_value>(.*?)</arg_value>",
func_args,
re.DOTALL,
)
pairs = self.func_arg_regex.findall(func_args)
arguments = {}
for arg_key, arg_value in pairs:
arg_key = arg_key.strip()
......
......@@ -2191,6 +2191,109 @@ class TestGlm4MoeDetector(unittest.TestCase):
)
self.assertEqual(self.detector._buffer, "")
def test_array_argument_with_escaped_json(self):
"""Test that array arguments with escaped JSON are properly handled without double-escaping."""
# Add a tool with array parameter
tools_with_array = [
Tool(
type="function",
function=Function(
name="todo_write",
description="Write todos",
parameters={
"type": "object",
"properties": {
"todos": {
"type": "array",
"description": "The updated todo list",
}
},
"required": ["todos"],
},
),
),
]
def check_params(result):
self.assertEqual(1, len(result.calls))
self.assertEqual("todo_write", result.calls[0].name)
params = json.loads(result.calls[0].parameters)
self.assertIsInstance(params["todos"], list)
self.assertEqual(4, len(params["todos"]))
self.assertEqual("1", params["todos"][0]["id"])
self.assertEqual(
"Check for hard-coded issues in the backend code",
params["todos"][0]["task"],
)
self.assertEqual("in_progress", params["todos"][0]["status"])
self.assertEqual("2", params["todos"][1]["id"])
self.assertEqual(
"Check for hard-coded issues in the frontend code",
params["todos"][1]["task"],
)
self.assertEqual("pending", params["todos"][1]["status"])
self.assertEqual("3", params["todos"][2]["id"])
self.assertEqual(
"Check for code violating the Single Responsibility Principle",
params["todos"][2]["task"],
)
self.assertEqual("pending", params["todos"][2]["status"])
self.assertEqual("4", params["todos"][3]["id"])
self.assertEqual(
"Generate a rectification proposal report", params["todos"][3]["task"]
)
self.assertEqual("pending", params["todos"][3]["status"])
# Simulate the raw response from GLM-4.6 model with normal and escaped JSON in XML
result = self.detector.detect_and_parse(
"""<tool_call>todo_write\n<arg_key>todos</arg_key>\n<arg_value>[{\"id\": \"1\", \"task\": \"Check for hard-coded issues in the backend code\", \"status\": \"in_progress\"}, {\"id\": \"2\", \"task\": \"Check for hard-coded issues in the frontend code\", \"status\": \"pending\"}, {\"id\": \"3\", \"task\": \"Check for code violating the Single Responsibility Principle\", \"status\": \"pending\"}, {\"id\": \"4\", \"task\": \"Generate a rectification proposal report\", \"status\": \"pending\"}]</arg_value>
</tool_call>""",
tools_with_array,
)
check_params(result)
result = self.detector.detect_and_parse(
r"""<tool_call>todo_write\n<arg_key>todos</arg_key>\n<arg_value>[{\"id\": \"1\", \"task\": \"Check for hard-coded issues in the backend code\", \"status\": \"in_progress\"}, {\"id\": \"2\", \"task\": \"Check for hard-coded issues in the frontend code\", \"status\": \"pending\"}, {\"id\": \"3\", \"task\": \"Check for code violating the Single Responsibility Principle\", \"status\": \"pending\"}, {\"id\": \"4\", \"task\": \"Generate a rectification proposal report\", \"status\": \"pending\"}]</arg_value>
</tool_call>""",
tools_with_array,
)
check_params(result)
def check_single_todos(tool_result, expected):
self.assertEqual(1, len(tool_result.calls))
self.assertEqual("todo_write", tool_result.calls[0].name)
params = json.loads(tool_result.calls[0].parameters)
self.assertIsInstance(params["todos"], list)
self.assertEqual(1, len(params["todos"]))
self.assertEqual("1", params["todos"][0]["id"])
self.assertEqual(expected, params["todos"][0]["task"])
self.assertEqual("pending", params["todos"][0]["status"])
# Test with escaped JSON containing backslashes in content (e.g., Windows paths)
expected_path = r"Check file at C:\Users\test.txt"
result = self.detector.detect_and_parse(
"""<tool_call>todo_write\n<arg_key>todos</arg_key>\n<arg_value>[{\"id\": \"1\", \"task\": \"Check file at C:\\\\Users\\\\test.txt\", \"status\": \"pending\"}]</arg_value></tool_call>""",
tools_with_array,
)
check_single_todos(result, expected_path)
result = self.detector.detect_and_parse(
r"""<tool_call>todo_write\n<arg_key>todos</arg_key>\n<arg_value>[{\"id\": \"1\", \"task\": \"Check file at C:\\\\Users\\\\test.txt\", \"status\": \"pending\"}]</arg_value></tool_call>""",
tools_with_array,
)
check_single_todos(result, expected_path)
# Should contain literal \n, not actual newline
expected_output = r"Print \n to see newline"
result = self.detector.detect_and_parse(
"""<tool_call>todo_write\n<arg_key>todos</arg_key>\n<arg_value>[{\"id\": \"1\", \"task\": \"Print \\\\n to see newline\",\"status\": \"pending\"}]</arg_value></tool_call>""",
tools_with_array,
)
check_single_todos(result, expected_output)
result = self.detector.detect_and_parse(
r"""<tool_call>todo_write\n<arg_key>todos</arg_key>\n<arg_value>[{\"id\": \"1\", \"task\": \"Print \\\\n to see newline\",\"status\": \"pending\"}]</arg_value></tool_call>""",
tools_with_array,
)
check_single_todos(result, expected_output)
class TestJsonArrayParser(unittest.TestCase):
def setUp(self):
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment