Unverified Commit 65f18884 authored by William Zhang's avatar William Zhang Committed by GitHub
Browse files

feat: Xml coder tool parser (#4415)

This commit adds a parser implementation for XML-style
tool calls, currently based on the Qwen3 Coder format.

Follow-up work is needed to make it more generic
and parameterizable.
parent 6b432625
...@@ -2776,6 +2776,7 @@ dependencies = [ ...@@ -2776,6 +2776,7 @@ dependencies = [
"num-traits", "num-traits",
"openai-harmony", "openai-harmony",
"regex", "regex",
"rstest 0.25.0",
"rustpython-parser", "rustpython-parser",
"serde", "serde",
"serde_json", "serde_json",
...@@ -8761,6 +8762,18 @@ dependencies = [ ...@@ -8761,6 +8762,18 @@ dependencies = [
"rustc_version", "rustc_version",
] ]
[[package]]
name = "rstest"
version = "0.25.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "6fc39292f8613e913f7df8fa892b8944ceb47c247b78e1b1ae2f09e019be789d"
dependencies = [
"futures-timer",
"futures-util",
"rstest_macros 0.25.0",
"rustc_version",
]
[[package]] [[package]]
name = "rstest_macros" name = "rstest_macros"
version = "0.18.2" version = "0.18.2"
...@@ -8796,6 +8809,24 @@ dependencies = [ ...@@ -8796,6 +8809,24 @@ dependencies = [
"unicode-ident", "unicode-ident",
] ]
[[package]]
name = "rstest_macros"
version = "0.25.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "1f168d99749d307be9de54d23fd226628d99768225ef08f6ffb52e0182a27746"
dependencies = [
"cfg-if 1.0.4",
"glob",
"proc-macro-crate",
"proc-macro2",
"quote",
"regex",
"relative-path",
"rustc_version",
"syn 2.0.110",
"unicode-ident",
]
[[package]] [[package]]
name = "rstest_reuse" name = "rstest_reuse"
version = "0.7.0" version = "0.7.0"
......
...@@ -38,3 +38,6 @@ openai-harmony = "0.0.3" ...@@ -38,3 +38,6 @@ openai-harmony = "0.0.3"
lazy_static = "1.5.0" lazy_static = "1.5.0"
rustpython-parser = "0.4.0" rustpython-parser = "0.4.0"
num-traits = "0.2" num-traits = "0.2"
[dev-dependencies]
rstest = "0.25"
...@@ -55,6 +55,7 @@ impl Default for JsonParserConfig { ...@@ -55,6 +55,7 @@ impl Default for JsonParserConfig {
} }
/// Configuration for parsing tool calls with different formats /// Configuration for parsing tool calls with different formats
// TODO(2ez4bz): refactor to allow other parser configs than `JsonParserConfig`.
#[derive(Clone, Debug, serde::Serialize, serde::Deserialize)] #[derive(Clone, Debug, serde::Serialize, serde::Deserialize)]
pub struct ToolCallConfig { pub struct ToolCallConfig {
/// The format type for tool calls /// The format type for tool calls
...@@ -192,4 +193,12 @@ impl ToolCallConfig { ...@@ -192,4 +193,12 @@ impl ToolCallConfig {
}, },
} }
} }
/// Build the tool-call configuration used for the `qwen3_coder` parser.
///
/// The format is set to `ToolCallParserType::Xml`; the `json` field is filled with
/// `JsonParserConfig::default()` only because the struct requires it.
pub fn qwen3_coder() -> Self {
// <tool_call><function=name><parameter=key>value</parameter></function></tool_call>
Self {
format: ToolCallParserType::Xml,
json: JsonParserConfig::default(), // Not used for qwen3_coder but kept for consistency.
}
}
} }
...@@ -10,6 +10,7 @@ pub mod response; ...@@ -10,6 +10,7 @@ pub mod response;
#[cfg(test)] #[cfg(test)]
pub mod tests; pub mod tests;
pub mod tools; pub mod tools;
pub mod xml;
// Re-export main types and functions for convenience // Re-export main types and functions for convenience
pub use config::{JsonParserConfig, ToolCallConfig, ToolCallParserType}; pub use config::{JsonParserConfig, ToolCallConfig, ToolCallParserType};
...@@ -22,3 +23,4 @@ pub use parsers::{ ...@@ -22,3 +23,4 @@ pub use parsers::{
pub use pythonic::try_tool_call_parse_pythonic; pub use pythonic::try_tool_call_parse_pythonic;
pub use response::{CalledFunction, ToolCallResponse, ToolCallType}; pub use response::{CalledFunction, ToolCallResponse, ToolCallType};
pub use tools::{try_tool_call_parse_aggregate, try_tool_call_parse_stream}; pub use tools::{try_tool_call_parse_aggregate, try_tool_call_parse_stream};
pub use xml::try_tool_call_parse_xml;
...@@ -14,6 +14,9 @@ use super::pythonic::{ ...@@ -14,6 +14,9 @@ use super::pythonic::{
try_tool_call_parse_pythonic, try_tool_call_parse_pythonic,
}; };
use super::response::ToolCallResponse; use super::response::ToolCallResponse;
use super::xml::{
detect_tool_call_start_xml, find_tool_call_end_position_xml, try_tool_call_parse_xml,
};
use std::collections::HashMap; use std::collections::HashMap;
use std::sync::OnceLock; use std::sync::OnceLock;
...@@ -32,6 +35,7 @@ pub fn get_tool_parser_map() -> &'static HashMap<&'static str, ToolCallConfig> { ...@@ -32,6 +35,7 @@ pub fn get_tool_parser_map() -> &'static HashMap<&'static str, ToolCallConfig> {
map.insert("harmony", ToolCallConfig::harmony()); map.insert("harmony", ToolCallConfig::harmony());
map.insert("deepseek_v3", ToolCallConfig::deepseek_v3()); map.insert("deepseek_v3", ToolCallConfig::deepseek_v3());
map.insert("deepseek_v3_1", ToolCallConfig::deepseek_v3_1()); map.insert("deepseek_v3_1", ToolCallConfig::deepseek_v3_1());
map.insert("qwen3_coder", ToolCallConfig::qwen3_coder());
map.insert("default", ToolCallConfig::default()); map.insert("default", ToolCallConfig::default());
map map
}) })
...@@ -64,7 +68,8 @@ pub async fn try_tool_call_parse( ...@@ -64,7 +68,8 @@ pub async fn try_tool_call_parse(
anyhow::bail!("Typescript parser not implemented"); anyhow::bail!("Typescript parser not implemented");
} }
ToolCallParserType::Xml => { ToolCallParserType::Xml => {
anyhow::bail!("Xml parser not implemented"); let (results, normal_content) = try_tool_call_parse_xml(message)?;
Ok((results, normal_content))
} }
} }
} }
...@@ -113,9 +118,7 @@ pub fn detect_tool_call_start(chunk: &str, parser_str: Option<&str>) -> anyhow:: ...@@ -113,9 +118,7 @@ pub fn detect_tool_call_start(chunk: &str, parser_str: Option<&str>) -> anyhow::
ToolCallParserType::Typescript => { ToolCallParserType::Typescript => {
anyhow::bail!("Typescript parser not implemented"); anyhow::bail!("Typescript parser not implemented");
} }
ToolCallParserType::Xml => { ToolCallParserType::Xml => Ok(detect_tool_call_start_xml(chunk)),
anyhow::bail!("Xml parser not implemented");
}
}, },
None => anyhow::bail!( None => anyhow::bail!(
"Parser '{}' is not implemented. Available parsers: {:?}", "Parser '{}' is not implemented. Available parsers: {:?}",
...@@ -149,10 +152,7 @@ pub fn find_tool_call_end_position(chunk: &str, parser_str: Option<&str>) -> usi ...@@ -149,10 +152,7 @@ pub fn find_tool_call_end_position(chunk: &str, parser_str: Option<&str>) -> usi
// Typescript parser not implemented // Typescript parser not implemented
chunk.len() chunk.len()
} }
ToolCallParserType::Xml => { ToolCallParserType::Xml => find_tool_call_end_position_xml(chunk),
// Xml parser not implemented
chunk.len()
}
}, },
None => { None => {
// Unknown parser, return full content length // Unknown parser, return full content length
...@@ -188,6 +188,7 @@ mod tests { ...@@ -188,6 +188,7 @@ mod tests {
"pythonic", "pythonic",
"deepseek_v3", "deepseek_v3",
"deepseek_v3_1", "deepseek_v3_1",
"qwen3_coder",
]; ];
for parser in available_parsers { for parser in available_parsers {
assert!(parsers.contains(&parser)); assert!(parsers.contains(&parser));
...@@ -1681,13 +1682,13 @@ mod parallel_tool_calling_tests { ...@@ -1681,13 +1682,13 @@ mod parallel_tool_calling_tests {
validate_weather_tool_calls(&result, &[("Dallas", "TX"), ("Orlando", "FL")]); validate_weather_tool_calls(&result, &[("Dallas", "TX"), ("Orlando", "FL")]);
} }
// ============================================================================= // =================================================
// 2. QWEN3CODER TOOL PARSER FORMAT (XML-style tags) - Testing via hermes parser // 2. QWEN3CODER TOOL PARSER FORMAT (XML-style tags)
// ============================================================================= // =================================================
#[tokio::test] #[tokio::test]
async fn test_parallel_qwen3coder_format_two_cities() { async fn test_parallel_qwen3coder_format_two_cities() {
let _input = r#"<tool_call> let input = r#"<tool_call>
<function=get_current_weather> <function=get_current_weather>
<parameter=city> <parameter=city>
Dallas Dallas
...@@ -1714,12 +1715,7 @@ fahrenheit ...@@ -1714,12 +1715,7 @@ fahrenheit
</function> </function>
</tool_call>"#; </tool_call>"#;
// Note: This format would need a specialized parser, but for now we test with hermes let (result, content) = detect_and_parse_tool_call(input, Some("qwen3_coder"))
// which handles multiple <tool_call> tags
let input_hermes_format = r#"<tool_call>{"name": "get_current_weather", "arguments": {"city": "Dallas", "state": "TX", "unit": "fahrenheit"}}</tool_call>
<tool_call>{"name": "get_current_weather", "arguments": {"city": "Orlando", "state": "FL", "unit": "fahrenheit"}}</tool_call>"#;
let (result, content) = detect_and_parse_tool_call(input_hermes_format, Some("hermes"))
.await .await
.unwrap(); .unwrap();
...@@ -2471,4 +2467,335 @@ mod detect_parser_tests { ...@@ -2471,4 +2467,335 @@ mod detect_parser_tests {
let result = detect_tool_call_start(text, Some("deepseek_v3_1")).unwrap(); let result = detect_tool_call_start(text, Some("deepseek_v3_1")).unwrap();
assert!(result); assert!(result);
} }
#[test]
fn test_e2e_detect_tool_call_start_xml() {
let text = r#"<tool_call><function=get_weather><parameter=city>Dallas</parameter></function></tool_call>"#;
let result = detect_tool_call_start(text, Some("qwen3_coder")).unwrap();
assert!(result);
}
#[test]
fn test_e2e_detect_tool_call_start_xml_partial() {
let text = r#"<tool_c"#; // Partial start token
let result = detect_tool_call_start(text, Some("qwen3_coder")).unwrap();
assert!(result);
}
}
// Xml parser tests
#[cfg(test)]
mod xml_parser_tests {
use super::*;
fn extract_name_and_args(call: ToolCallResponse) -> (String, serde_json::Value) {
let args: serde_json::Value = serde_json::from_str(&call.function.arguments).unwrap();
(call.function.name, args)
}
#[tokio::test]
async fn test_qwen3_coder_simple_tool_call() {
let input = r#"<tool_call>
<function=execute_bash>
<parameter=command>
pwd && ls
</parameter>
</function>
</tool_call>"#;
let (result, content) = detect_and_parse_tool_call(input, Some("qwen3_coder"))
.await
.unwrap();
assert_eq!(content, Some("".to_string()));
assert_eq!(result.len(), 1);
let (name, args) = extract_name_and_args(result[0].clone());
assert_eq!(name, "execute_bash");
assert_eq!(args["command"], "pwd && ls");
}
#[tokio::test]
async fn test_qwen3_coder_multiple_parameters() {
let input = r#"<tool_call>
<function=get_current_weather>
<parameter=city>
Dallas
</parameter>
<parameter=state>
TX
</parameter>
<parameter=unit>
fahrenheit
</parameter>
</function>
</tool_call>"#;
let (result, content) = detect_and_parse_tool_call(input, Some("qwen3_coder"))
.await
.unwrap();
assert_eq!(content, Some("".to_string()));
assert_eq!(result.len(), 1);
let (name, args) = extract_name_and_args(result[0].clone());
assert_eq!(name, "get_current_weather");
assert_eq!(args["city"], "Dallas");
assert_eq!(args["state"], "TX");
assert_eq!(args["unit"], "fahrenheit");
}
#[tokio::test]
async fn test_qwen3_coder_with_normal_text() {
let input = r#"I'll help you check the weather. <tool_call>
<function=get_current_weather>
<parameter=city>
San Francisco
</parameter>
<parameter=unit>
fahrenheit
</parameter>
</function>
</tool_call> Let me get that information for you."#;
let (result, content) = detect_and_parse_tool_call(input, Some("qwen3_coder"))
.await
.unwrap();
assert_eq!(
content,
Some(
"I'll help you check the weather. Let me get that information for you."
.to_string()
)
);
assert_eq!(result.len(), 1);
let (name, args) = extract_name_and_args(result[0].clone());
assert_eq!(name, "get_current_weather");
assert_eq!(args["city"], "San Francisco");
assert_eq!(args["unit"], "fahrenheit");
}
#[tokio::test]
async fn test_qwen3_coder_parallel_tool_calls() {
let input = r#"<tool_call>
<function=get_current_weather>
<parameter=city>
Dallas
</parameter>
<parameter=state>
TX
</parameter>
<parameter=unit>
fahrenheit
</parameter>
</function>
</tool_call>
<tool_call>
<function=get_current_weather>
<parameter=city>
Orlando
</parameter>
<parameter=state>
FL
</parameter>
<parameter=unit>
fahrenheit
</parameter>
</function>
</tool_call>"#;
let (result, content) = detect_and_parse_tool_call(input, Some("qwen3_coder"))
.await
.unwrap();
assert_eq!(content, Some("".to_string()));
assert_eq!(result.len(), 2);
let (name1, args1) = extract_name_and_args(result[0].clone());
assert_eq!(name1, "get_current_weather");
assert_eq!(args1["city"], "Dallas");
assert_eq!(args1["state"], "TX");
assert_eq!(args1["unit"], "fahrenheit");
let (name2, args2) = extract_name_and_args(result[1].clone());
assert_eq!(name2, "get_current_weather");
assert_eq!(args2["city"], "Orlando");
assert_eq!(args2["state"], "FL");
assert_eq!(args2["unit"], "fahrenheit");
}
#[tokio::test]
async fn test_qwen3_coder_json_parameter_value() {
let input = r#"<tool_call>
<function=process_data>
<parameter=config>
{"timeout": 30, "retries": 3}
</parameter>
</function>
</tool_call>"#;
let (result, content) = detect_and_parse_tool_call(input, Some("qwen3_coder"))
.await
.unwrap();
assert_eq!(content, Some("".to_string()));
assert_eq!(result.len(), 1);
let (name, args) = extract_name_and_args(result[0].clone());
assert_eq!(name, "process_data");
assert!(args["config"].is_object());
assert_eq!(args["config"]["timeout"], 30);
assert_eq!(args["config"]["retries"], 3);
}
#[tokio::test]
async fn test_qwen3_coder_numeric_parameters() {
let input = r#"<tool_call>
<function=calculate>
<parameter=x>
42
</parameter>
<parameter=y>
3.15
</parameter>
<parameter=enabled>
true
</parameter>
</function>
</tool_call>"#;
let (result, _) = detect_and_parse_tool_call(input, Some("qwen3_coder"))
.await
.unwrap();
assert_eq!(result.len(), 1);
let (name, args) = extract_name_and_args(result[0].clone());
assert_eq!(name, "calculate");
assert_eq!(args["x"], 42);
assert_eq!(args["y"], 3.15);
assert_eq!(args["enabled"], true);
}
#[tokio::test]
async fn test_qwen3_coder_no_tool_calls() {
let input = "This is just normal text without any tool calls.";
let (result, content) = detect_and_parse_tool_call(input, Some("qwen3_coder"))
.await
.unwrap();
assert_eq!(result.len(), 0);
assert_eq!(content, Some(input.to_string()));
}
#[tokio::test]
async fn test_qwen3_coder_compact_format() {
let input = r#"<tool_call><function=search><parameter=query>rust programming</parameter><parameter=limit>10</parameter></function></tool_call>"#;
let (result, content) = detect_and_parse_tool_call(input, Some("qwen3_coder"))
.await
.unwrap();
assert_eq!(content, Some("".to_string()));
assert_eq!(result.len(), 1);
let (name, args) = extract_name_and_args(result[0].clone());
assert_eq!(name, "search");
assert_eq!(args["query"], "rust programming");
assert_eq!(args["limit"], 10);
}
#[tokio::test]
async fn test_qwen3_coder_html_entities() {
let input = r#"<tool_call>
<function=print_message>
<parameter=text>
&lt;div&gt;Hello &amp; Welcome&lt;/div&gt;
</parameter>
</function>
</tool_call>"#;
let (result, _) = detect_and_parse_tool_call(input, Some("qwen3_coder"))
.await
.unwrap();
assert_eq!(result.len(), 1);
let (name, args) = extract_name_and_args(result[0].clone());
assert_eq!(name, "print_message");
assert_eq!(args["text"], "<div>Hello & Welcome</div>");
}
#[tokio::test]
async fn test_qwen3_coder_three_parallel_calls() {
let input = r#"<tool_call>
<function=get_current_weather>
<parameter=city>
Dallas
</parameter>
</function>
</tool_call>
<tool_call>
<function=get_current_weather>
<parameter=city>
Orlando
</parameter>
</function>
</tool_call>
<tool_call>
<function=get_current_weather>
<parameter=city>
Seattle
</parameter>
</function>
</tool_call>"#;
let (result, content) = detect_and_parse_tool_call(input, Some("qwen3_coder"))
.await
.unwrap();
assert_eq!(content, Some("".to_string()));
assert_eq!(result.len(), 3);
let cities = ["Dallas", "Orlando", "Seattle"];
for (i, expected_city) in cities.iter().enumerate() {
let (name, args) = extract_name_and_args(result[i].clone());
assert_eq!(name, "get_current_weather");
assert_eq!(args["city"], *expected_city);
}
}
#[tokio::test]
async fn test_qwen3_coder_mixed_tool_types() {
let input = r#"<tool_call>
<function=get_current_weather>
<parameter=city>
Dallas
</parameter>
<parameter=unit>
fahrenheit
</parameter>
</function>
</tool_call>
<tool_call>
<function=web_search>
<parameter=query>
weather forecasting
</parameter>
<parameter=max_results>
5
</parameter>
</function>
</tool_call>"#;
let (result, content) = detect_and_parse_tool_call(input, Some("qwen3_coder"))
.await
.unwrap();
assert_eq!(content, Some("".to_string()));
assert_eq!(result.len(), 2);
let (name1, args1) = extract_name_and_args(result[0].clone());
assert_eq!(name1, "get_current_weather");
assert_eq!(args1["city"], "Dallas");
assert_eq!(args1["unit"], "fahrenheit");
let (name2, args2) = extract_name_and_args(result[1].clone());
assert_eq!(name2, "web_search");
assert_eq!(args2["query"], "weather forecasting");
assert_eq!(args2["max_results"], 5);
}
#[tokio::test]
async fn test_qwen3_coder_array_parameter_value() {
let input = r#"<tool_call>
<function=process_list>
<parameter=items>
[1, 2, 3, 4, 5]
</parameter>
</function>
</tool_call>"#;
let (result, _) = detect_and_parse_tool_call(input, Some("qwen3_coder"))
.await
.unwrap();
assert_eq!(result.len(), 1);
let (name, args) = extract_name_and_args(result[0].clone());
assert_eq!(name, "process_list");
assert!(args["items"].is_array());
assert_eq!(args["items"], serde_json::json!([1, 2, 3, 4, 5]));
}
} }
// SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
// SPDX-License-Identifier: Apache-2.0
mod parser;
pub use super::response;
pub use parser::{
detect_tool_call_start_xml, find_tool_call_end_position_xml, try_tool_call_parse_xml,
};
// SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
// SPDX-License-Identifier: Apache-2.0
// Reference implementation:
// https://github.com/sgl-project/sglang/blob/44da737770e4bcd9bfa27751f0a0751c9b5c06e1/python/sglang/srt/function_call/qwen3_coder_detector.py
use std::collections::HashMap;
use std::sync::OnceLock;
use regex::Regex;
use uuid::Uuid;
use super::response::{CalledFunction, ToolCallResponse, ToolCallType};
/// Report whether `chunk` contains, or ends partway through, the start of an XML-style tool call.
/// Format: <tool_call><function=name><parameter=foo>...</parameter></function></tool_call>
///
/// Returns `true` when the complete `<tool_call>` start token appears anywhere in `chunk`,
/// or when `chunk` ends with a non-empty proper prefix of that token (for example `<tool_c`),
/// which is how a token split across streamed chunks shows up.
// TODO(2ez4bz): Add a parser config struct that allows parameterizing:
// * the tool call start / end tokens
// * the function start / end tokens
// * the parameter start / end tokens
pub fn detect_tool_call_start_xml(chunk: &str) -> bool {
    const START_TOKEN: &str = "<tool_call>";

    // A full token anywhere wins; otherwise look for a truncated token at the very end
    // of the chunk (streaming case). The token is ASCII, so byte slicing is safe.
    chunk.contains(START_TOKEN)
        || (1..START_TOKEN.len()).any(|prefix_len| chunk.ends_with(&START_TOKEN[..prefix_len]))
}
/// Locate the byte offset just past the first `</tool_call>` end token in `chunk`.
///
/// Returns the offset immediately after the first end token, or `chunk.len()` when the
/// end token does not occur.
pub fn find_tool_call_end_position_xml(chunk: &str) -> usize {
    const END_TOKEN: &str = "</tool_call>";

    chunk
        .find(END_TOKEN)
        .map_or(chunk.len(), |token_start| token_start + END_TOKEN.len())
}
/// Parse Qwen3Coder-formatted tool calls out of a complete message.
/// Format: <tool_call><function=name><parameter=key>value</parameter></function></tool_call>
///
/// Returns `(parsed_tool_calls, normal_text_content)`. The normal text is always `Some`;
/// it is the empty string when nothing remains outside the tool-call blocks.
pub fn try_tool_call_parse_xml(
    message: &str,
) -> anyhow::Result<(Vec<ToolCallResponse>, Option<String>)> {
    // The extractor yields (text, calls); callers expect (calls, Some(text)).
    extract_tool_calls(message).map(|(plain_text, calls)| (calls, Some(plain_text)))
}
/// Split `text` into its tool calls and the surrounding normal text.
///
/// Every `<tool_call>...</tool_call>` span is handed to `parse_tool_call_block`; everything
/// outside those spans is concatenated and trimmed to form the normal text. A start token
/// without a matching end token is kept, together with the rest of the input, as normal
/// text. A block that fails to parse contributes no calls and is dropped from the output.
fn extract_tool_calls(text: &str) -> anyhow::Result<(String, Vec<ToolCallResponse>)> {
    const START_TOKEN: &str = "<tool_call>";
    const END_TOKEN: &str = "</tool_call>";

    let mut plain_segments: Vec<&str> = Vec::new();
    let mut tool_calls = Vec::new();
    // Unconsumed tail of the input; shrinks as blocks are consumed.
    let mut rest = text;

    while !rest.is_empty() {
        // No further start token: the remainder is all normal text.
        let Some(open_at) = rest.find(START_TOKEN) else {
            plain_segments.push(rest);
            break;
        };
        plain_segments.push(&rest[..open_at]);

        // Unterminated block: keep it (start token included) as normal text.
        let from_open = &rest[open_at..];
        let Some(close_at) = from_open.find(END_TOKEN) else {
            plain_segments.push(from_open);
            break;
        };

        let (block, remainder) = from_open.split_at(close_at + END_TOKEN.len());
        // Best effort: a block that cannot be parsed is skipped rather than failing the message.
        if let Ok(parsed) = parse_tool_call_block(block) {
            tool_calls.extend(parsed);
        }
        rest = remainder;
    }

    let normal_text = plain_segments.concat().trim().to_string();
    Ok((normal_text, tool_calls))
}
/// Parse one `<tool_call>` block into zero or more tool call responses.
/// Format: <tool_call><function=name><parameter=key>value</parameter>...</function></tool_call>
///
/// Each `<function=name>` section yields one response whose `arguments` field is the JSON
/// serialization of its parameters. Missing `</function>` / `</parameter>` closing tags are
/// tolerated: the capture then runs to the end of the text being searched. Functions and
/// parameters whose name is empty after trimming are ignored; a repeated parameter name
/// keeps the last value.
///
/// Returns an error only if serializing the collected parameters fails.
fn parse_tool_call_block(block: &str) -> anyhow::Result<Vec<ToolCallResponse>> {
    static FUNCTION_REGEX: OnceLock<Regex> = OnceLock::new();
    static PARAMETER_REGEX: OnceLock<Regex> = OnceLock::new();

    // Both patterns use (?s) so `.` spans newlines, and accept end-of-input in place of
    // the closing tag so partially emitted blocks still match.
    let function_re = FUNCTION_REGEX
        .get_or_init(|| Regex::new(r"(?s)<function=([^>]+)>(.*?)(?:</function>|$)").unwrap());
    let parameter_re = PARAMETER_REGEX
        .get_or_init(|| Regex::new(r"(?s)<parameter=([^>]+)>(.*?)(?:</parameter>|$)").unwrap());

    let mut tool_calls = Vec::new();

    for function_match in function_re.captures_iter(block) {
        let name = function_match.get(1).map_or("", |m| m.as_str().trim());
        if name.is_empty() {
            continue;
        }
        let body = function_match.get(2).map_or("", |m| m.as_str());

        // Gather the named parameters found inside this function body.
        let parameters: HashMap<String, serde_json::Value> = parameter_re
            .captures_iter(body)
            .filter_map(|param_match| {
                let key = param_match.get(1).map_or("", |m| m.as_str().trim());
                if key.is_empty() {
                    return None;
                }
                let raw_value = param_match.get(2).map_or("", |m| m.as_str());
                Some((key.to_string(), safe_parse_value(raw_value)))
            })
            .collect();

        let arguments = serde_json::to_string(&parameters)?;
        tool_calls.push(ToolCallResponse {
            id: format!("call-{}", Uuid::new_v4()),
            tp: ToolCallType::Function,
            function: CalledFunction {
                name: name.to_string(),
                arguments,
            },
        });
    }

    Ok(tool_calls)
}
/// Safely parse a value - tries JSON, then falls back to string.
/// Mimics SGLang's `_safe_val` function in spirit.
fn safe_parse_value(raw: &str) -> serde_json::Value {
// HTML unescape
let unescaped = html_unescape(raw.trim());
if let Ok(value) = serde_json::from_str::<serde_json::Value>(&unescaped) {
return value;
}
if let Ok(num) = unescaped.parse::<i64>() {
return serde_json::Value::Number(num.into());
}
if let Ok(num) = unescaped.parse::<f64>()
&& let Some(num_val) = serde_json::Number::from_f64(num)
{
return serde_json::Value::Number(num_val);
}
match unescaped.to_lowercase().as_str() {
"true" => return serde_json::Value::Bool(true),
"false" => return serde_json::Value::Bool(false),
"null" | "none" => return serde_json::Value::Null,
_ => {}
}
// Default to string, stripping newlines from start and end.
serde_json::Value::String(unescaped.trim_matches('\n').to_string())
}
/// Simple HTML unescape for common entities.
///
/// Decodes `&lt;`, `&gt;`, `&amp;`, `&quot;`, `&#x27;` and `&#39;`; any other text, including
/// unrecognized `&...;` sequences, is copied through unchanged.
///
/// The input is decoded in a single left-to-right pass so that the output of one
/// replacement is never re-interpreted as the start of another entity. Chained
/// `str::replace` calls would decode `&amp;quot;` twice and yield `"` instead of the
/// intended literal text `&quot;`.
fn html_unescape(s: &str) -> String {
    const ENTITIES: [(&str, char); 6] = [
        ("&lt;", '<'),
        ("&gt;", '>'),
        ("&amp;", '&'),
        ("&quot;", '"'),
        ("&#x27;", '\''),
        ("&#39;", '\''),
    ];

    let mut decoded = String::with_capacity(s.len());
    let mut rest = s;

    while let Some(amp_at) = rest.find('&') {
        decoded.push_str(&rest[..amp_at]);
        let tail = &rest[amp_at..];

        match ENTITIES.iter().find(|(entity, _)| tail.starts_with(*entity)) {
            Some((entity, replacement)) => {
                decoded.push(*replacement);
                rest = &tail[entity.len()..];
            }
            None => {
                // A bare '&' that does not start a known entity is kept verbatim.
                decoded.push('&');
                rest = &tail[1..];
            }
        }
    }

    decoded.push_str(rest);
    decoded
}
#[cfg(test)]
mod tests {
use super::*;
use rstest::rstest;
#[test]
fn test_detect_tool_call_start() {
assert!(detect_tool_call_start_xml("<tool_call>"));
assert!(detect_tool_call_start_xml("text <tool_call>"));
assert!(detect_tool_call_start_xml("<tool_c")); // Partial match
assert!(detect_tool_call_start_xml("<")); // Partial match
assert!(!detect_tool_call_start_xml("no tool call here"));
assert!(!detect_tool_call_start_xml("toolcall"));
}
// Checks both outcomes of `find_tool_call_end_position_xml`: the offset just past the
// first `</tool_call>` when it is present, and the full chunk length when it is absent.
#[test]
fn test_find_tool_call_end_position() {
let text = "<tool_call><function=test></function></tool_call>more text";
let pos = find_tool_call_end_position_xml(text);
assert_eq!(pos, 49); // Position after </tool_call>
// Slicing at the returned offset must leave only the trailing normal text.
assert_eq!(&text[pos..], "more text");
// Without an end token the whole chunk length is returned.
let text_no_end = "<tool_call><function=test>";
let pos = find_tool_call_end_position_xml(text_no_end);
assert_eq!(pos, text_no_end.len());
}
#[rstest]
#[case(r#"{"key": "value"}"#, serde_json::json!({"key": "value"}), "JSON object")]
#[case(r#"[1, 2, 3]"#, serde_json::json!([1, 2, 3]), "JSON array")]
#[case("42", serde_json::json!(42), "integer")]
#[case("3.15", serde_json::json!(3.15), "float")]
#[case("true", serde_json::json!(true), "boolean true")]
#[case("false", serde_json::json!(false), "boolean false")]
#[case("null", serde_json::json!(null), "null")]
#[case("hello", serde_json::json!("hello"), "unquoted string")]
#[case(" text ", serde_json::json!("text"), "trimmed string")]
fn test_safe_parse_value(
#[case] input: &str,
#[case] expected: serde_json::Value,
#[case] _description: &str,
) {
assert_eq!(safe_parse_value(input), expected);
}
#[rstest]
#[case("&lt;div&gt;", "<div>", "HTML tags")]
#[case("a &amp; b", "a & b", "ampersand")]
#[case("&quot;quoted&quot;", "\"quoted\"", "quotes")]
fn test_html_unescape(#[case] input: &str, #[case] expected: &str, #[case] _description: &str) {
assert_eq!(html_unescape(input), expected);
}
#[test]
fn test_parse_simple_tool_call() {
let input = r#"<tool_call>
<function=execute_bash>
<parameter=command>
pwd && ls
</parameter>
</function>
</tool_call>"#;
let (calls, normal) = try_tool_call_parse_xml(input).unwrap();
assert_eq!(calls.len(), 1);
assert_eq!(calls[0].function.name, "execute_bash");
assert_eq!(normal, Some("".to_string()));
let args: serde_json::Value = serde_json::from_str(&calls[0].function.arguments).unwrap();
assert_eq!(args["command"], "pwd && ls");
}
#[test]
fn test_parse_multiple_parameters() {
let input = r#"<tool_call>
<function=get_weather>
<parameter=city>
San Francisco
</parameter>
<parameter=state>
CA
</parameter>
<parameter=unit>
fahrenheit
</parameter>
</function>
</tool_call>"#;
let (calls, _) = try_tool_call_parse_xml(input).unwrap();
assert_eq!(calls.len(), 1);
assert_eq!(calls[0].function.name, "get_weather");
let args: serde_json::Value = serde_json::from_str(&calls[0].function.arguments).unwrap();
assert_eq!(args["city"], "San Francisco");
assert_eq!(args["state"], "CA");
assert_eq!(args["unit"], "fahrenheit");
}
#[test]
fn test_parse_with_normal_text() {
let input = r#"I'll help you with that. <tool_call>
<function=get_weather>
<parameter=city>
Dallas
</parameter>
</function>
</tool_call> Let me check that for you."#;
let (calls, normal) = try_tool_call_parse_xml(input).unwrap();
assert_eq!(calls.len(), 1);
assert_eq!(calls[0].function.name, "get_weather");
assert_eq!(
normal,
Some("I'll help you with that. Let me check that for you.".to_string())
);
}
#[test]
fn test_parse_multiple_tool_calls() {
let input = r#"<tool_call>
<function=get_weather>
<parameter=city>
Dallas
</parameter>
</function>
</tool_call>
<tool_call>
<function=get_weather>
<parameter=city>
Orlando
</parameter>
</function>
</tool_call>"#;
let (calls, _) = try_tool_call_parse_xml(input).unwrap();
assert_eq!(calls.len(), 2);
assert_eq!(calls[0].function.name, "get_weather");
assert_eq!(calls[1].function.name, "get_weather");
let args0: serde_json::Value = serde_json::from_str(&calls[0].function.arguments).unwrap();
let args1: serde_json::Value = serde_json::from_str(&calls[1].function.arguments).unwrap();
assert_eq!(args0["city"], "Dallas");
assert_eq!(args1["city"], "Orlando");
}
#[test]
fn test_parse_json_parameter_value() {
let input = r#"<tool_call>
<function=process_data>
<parameter=config>
{"setting": "value", "count": 42}
</parameter>
</function>
</tool_call>"#;
let (calls, _) = try_tool_call_parse_xml(input).unwrap();
assert_eq!(calls.len(), 1);
let args: serde_json::Value = serde_json::from_str(&calls[0].function.arguments).unwrap();
assert!(args["config"].is_object());
assert_eq!(args["config"]["setting"], "value");
assert_eq!(args["config"]["count"], 42);
}
#[test]
fn test_parse_no_tool_calls() {
let input = "This is just normal text without any tool calls.";
let (calls, normal) = try_tool_call_parse_xml(input).unwrap();
assert_eq!(calls.len(), 0);
assert_eq!(normal, Some(input.to_string()));
}
// Smoke test: a block missing both `</parameter>` and `</function>` must not make the
// parser return an error. Only `is_ok()` is asserted; the parsed contents are not checked here.
#[test]
fn test_parse_malformed_tool_call() {
let input = r#"<tool_call>
<function=incomplete>
<parameter=test>
value
</tool_call>"#;
// Should handle gracefully - might parse or return empty
let result = try_tool_call_parse_xml(input);
assert!(result.is_ok());
}
#[test]
fn test_parse_missing_parameter_closing_tag() {
let input = r#"<tool_call>
<function=execute_bash>
<parameter=command>
ls -la
</function>
</tool_call>"#;
let (calls, _) = try_tool_call_parse_xml(input).unwrap();
assert_eq!(calls.len(), 1);
assert_eq!(calls[0].function.name, "execute_bash");
let args: serde_json::Value = serde_json::from_str(&calls[0].function.arguments).unwrap();
assert_eq!(args["command"], "ls -la");
}
#[test]
fn test_parse_missing_function_closing_tag() {
let input = r#"<tool_call>
<function=get_weather>
<parameter=city>
Boston
</parameter>
</tool_call>"#;
let (calls, _) = try_tool_call_parse_xml(input).unwrap();
assert_eq!(calls.len(), 1);
assert_eq!(calls[0].function.name, "get_weather");
let args: serde_json::Value = serde_json::from_str(&calls[0].function.arguments).unwrap();
assert_eq!(args["city"], "Boston");
}
// With neither `</parameter>` nor `</function>` present, both regex captures run to the end
// of the block, so the parameter value absorbs the trailing `</tool_call>` token.
#[test]
fn test_parse_missing_both_closing_tags() {
let input = r#"<tool_call>
<function=run_query>
<parameter=sql>
SELECT * FROM users
</tool_call>"#;
let (calls, _) = try_tool_call_parse_xml(input).unwrap();
assert_eq!(calls.len(), 1);
assert_eq!(calls[0].function.name, "run_query");
let args: serde_json::Value = serde_json::from_str(&calls[0].function.arguments).unwrap();
// This matches the original SGLang python implementation.
assert_eq!(args["sql"], "SELECT * FROM users\n</tool_call>");
}
// The first parameter has no `</parameter>`, so its capture runs to the end of the function
// body and absorbs the second `<parameter=limit>` tag and its value as part of `query`.
#[test]
fn test_parse_multiple_parameters_missing_closing_tags() {
let input = r#"<tool_call>
<function=search>
<parameter=query>
rust programming
<parameter=limit>
10
</function>
</tool_call>"#;
let (calls, _) = try_tool_call_parse_xml(input).unwrap();
assert_eq!(calls.len(), 1);
assert_eq!(calls[0].function.name, "search");
let args: serde_json::Value = serde_json::from_str(&calls[0].function.arguments).unwrap();
// This matches the original SGLang python implementation.
assert_eq!(args["query"], "rust programming\n<parameter=limit>\n10");
}
}
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment