// SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
// SPDX-License-Identifier: Apache-2.0

// Reference implementation:
// https://github.com/sgl-project/sglang/blob/44da737770e4bcd9bfa27751f0a0751c9b5c06e1/python/sglang/srt/function_call/qwen3_coder_detector.py

use std::collections::HashMap;
use std::sync::OnceLock;

use regex::Regex;
use uuid::Uuid;

use super::response::{CalledFunction, ToolCallResponse, ToolCallType};

/// Check if a chunk contains the start of a xml-style tool call.
///
/// Format:
/// `<tool_call><function=name><parameter=foo>...</parameter></function></tool_call>`
///
/// Returns `true` when `chunk` contains the complete `<tool_call>` start token, or
/// when `chunk` ends with a non-empty proper prefix of it (so a streaming caller can
/// hold the text back until the next chunk arrives).
// TODO(2ez4bz): Add a parser config struct that allows parameterizing:
// * the tool call start / end tokens
// * the function start / end tokens
// * the parameter start / end tokens
pub fn detect_tool_call_start_xml(chunk: &str) -> bool {
    let start_token = "<tool_call>";

    // Complete start token anywhere in the chunk.
    if chunk.contains(start_token) {
        return true;
    }

    // Partial match at the end of the chunk (for streaming). The token is ASCII, so
    // slicing it at any byte index is valid.
    (1..start_token.len()).any(|i| chunk.ends_with(&start_token[..i]))
}

/// Find the end position of a Qwen3Coder tool call.
///
/// Returns the byte position just after the first `</tool_call>`, or the length of
/// the chunk if no end token is found.
pub fn find_tool_call_end_position_xml(chunk: &str) -> usize {
    let end_token = "</tool_call>";
    chunk
        .find(end_token)
        .map_or(chunk.len(), |pos| pos + end_token.len())
}

/// Try to parse Qwen3Coder formatted tool calls from a message.
///
/// Format:
/// `<tool_call><function=name><parameter=key>value</parameter></function></tool_call>`
///
/// Returns `(parsed_tool_calls, normal_text_content)`. The normal text is always
/// `Some`; it is an empty string when the message has no text outside tool calls.
pub fn try_tool_call_parse_xml(
    message: &str,
) -> anyhow::Result<(Vec<ToolCallResponse>, Option<String>)> {
    let (normal_text, tool_calls) = extract_tool_calls(message)?;
    Ok((tool_calls, Some(normal_text)))
}

/// Extract tool calls and normal text from message.
fn extract_tool_calls(text: &str) -> anyhow::Result<(String, Vec)> { let mut normal_parts = Vec::new(); let mut calls = Vec::new(); let mut cursor = 0; let start_token = ""; let end_token = ""; while cursor < text.len() { // Find next tool call start. if let Some(start_pos) = text[cursor..].find(start_token) { let abs_start = cursor + start_pos; // Add text before tool call to normal parts. normal_parts.push(&text[cursor..abs_start]); // Find the corresponding end token. if let Some(end_pos) = text[abs_start..].find(end_token) { let abs_end = abs_start + end_pos + end_token.len(); let block = &text[abs_start..abs_end]; // Parse this tool call block. if let Ok(mut parsed_calls) = parse_tool_call_block(block) { calls.append(&mut parsed_calls); } cursor = abs_end; } else { // No end token found -> treat the rest as normal text. normal_parts.push(&text[abs_start..]); break; } } else { // No more tool calls. normal_parts.push(&text[cursor..]); break; } } let normal_text = normal_parts.join("").trim().to_string(); Ok((normal_text, calls)) } /// Parse a single tool call block /// Format: value... fn parse_tool_call_block(block: &str) -> anyhow::Result> { static FUNCTION_REGEX: OnceLock = OnceLock::new(); static PARAMETER_REGEX: OnceLock = OnceLock::new(); let function_regex = FUNCTION_REGEX.get_or_init(|| { // Match content or partial content // (?s) makes . match newlines Regex::new(r"(?s)]+)>(.*?)(?:|$)").unwrap() }); let parameter_regex = PARAMETER_REGEX.get_or_init(|| { // Match value or partial value // (?s) makes . match newlines Regex::new(r"(?s)]+)>(.*?)(?:|$)").unwrap() }); let mut results = Vec::new(); // Find all function blocks. for func_cap in function_regex.captures_iter(block) { let function_name = func_cap.get(1).map(|m| m.as_str().trim()).unwrap_or(""); let function_body = func_cap.get(2).map(|m| m.as_str()).unwrap_or(""); if function_name.is_empty() { continue; } // Parse parameters from the function body. 
let mut parameters: HashMap = HashMap::new(); for param_cap in parameter_regex.captures_iter(function_body) { let param_name = param_cap.get(1).map(|m| m.as_str().trim()).unwrap_or(""); let param_value = param_cap.get(2).map(|m| m.as_str()).unwrap_or(""); if !param_name.is_empty() { let parsed_value = safe_parse_value(param_value); parameters.insert(param_name.to_string(), parsed_value); } } // Create tool call response. let arguments_json = serde_json::to_string(¶meters)?; let tool_call = ToolCallResponse { id: format!("call-{}", Uuid::new_v4()), tp: ToolCallType::Function, function: CalledFunction { name: function_name.to_string(), arguments: arguments_json, }, }; results.push(tool_call); } Ok(results) } /// Safely parse a value - tries JSON, then falls back to string. /// Mimics SGLang's `_safe_val` function in spirit. fn safe_parse_value(raw: &str) -> serde_json::Value { // HTML unescape let unescaped = html_unescape(raw.trim()); if let Ok(value) = serde_json::from_str::(&unescaped) { return value; } if let Ok(num) = unescaped.parse::() { return serde_json::Value::Number(num.into()); } if let Ok(num) = unescaped.parse::() && let Some(num_val) = serde_json::Number::from_f64(num) { return serde_json::Value::Number(num_val); } match unescaped.to_lowercase().as_str() { "true" => return serde_json::Value::Bool(true), "false" => return serde_json::Value::Bool(false), "null" | "none" => return serde_json::Value::Null, _ => {} } // Default to string, stripping newlines from start and end. serde_json::Value::String(unescaped.trim_matches('\n').to_string()) } /// Simple HTML unescape for common entities. 
fn html_unescape(s: &str) -> String {
    // `&amp;` is decoded last so that an escaped entity such as `&amp;quot;` decodes
    // one level to `&quot;` instead of being unescaped twice into `"`.
    s.replace("&lt;", "<")
        .replace("&gt;", ">")
        .replace("&quot;", "\"")
        .replace("&#x27;", "'")
        .replace("&#39;", "'")
        .replace("&amp;", "&")
}

#[cfg(test)]
mod tests {
    use super::*;
    use rstest::rstest;

    #[test]
    fn test_detect_tool_call_start() {
        assert!(detect_tool_call_start_xml("<tool_call>"));
        assert!(detect_tool_call_start_xml("text <tool_call>"));
        assert!(detect_tool_call_start_xml("<tool_c"));
        assert!(detect_tool_call_start_xml("<"));
        assert!(!detect_tool_call_start_xml("no tool call here"));
        assert!(!detect_tool_call_start_xml("toolcall"));
    }

    #[test]
    fn test_find_tool_call_end_position() {
        let text = "<tool_call><function=test></function></tool_call>more text";
        let pos = find_tool_call_end_position_xml(text);
        assert_eq!(&text[pos..], "more text");

        let text_no_end = "<tool_call><function=test>";
        let pos = find_tool_call_end_position_xml(text_no_end);
        assert_eq!(pos, text_no_end.len());
    }

    #[rstest]
    #[case(r#"{"key": "value"}"#, serde_json::json!({"key": "value"}), "JSON object")]
    #[case(r#"[1, 2, 3]"#, serde_json::json!([1, 2, 3]), "JSON array")]
    #[case("42", serde_json::json!(42), "integer")]
    #[case("3.15", serde_json::json!(3.15), "float")]
    #[case("true", serde_json::json!(true), "boolean true")]
    #[case("false", serde_json::json!(false), "boolean false")]
    #[case("null", serde_json::json!(null), "null")]
    #[case("hello", serde_json::json!("hello"), "unquoted string")]
    #[case("  text  ", serde_json::json!("text"), "trimmed string")]
    fn test_safe_parse_value(
        #[case] input: &str,
        #[case] expected: serde_json::Value,
        #[case] _description: &str,
    ) {
        assert_eq!(safe_parse_value(input), expected);
    }

    #[rstest]
    #[case("&lt;div&gt;", "<div>", "HTML tags")]
    #[case("a &amp; b", "a & b", "ampersand")]
    #[case("&quot;quoted&quot;", "\"quoted\"", "quotes")]
    #[case("&amp;quot;", "&quot;", "escaped entity is decoded only once")]
    fn test_html_unescape(#[case] input: &str, #[case] expected: &str, #[case] _description: &str) {
        assert_eq!(html_unescape(input), expected);
    }

    #[test]
    fn test_parse_simple_tool_call() {
        let input = r#"<tool_call>
<function=execute_bash>
<parameter=command>
pwd && ls
</parameter>
</function>
</tool_call>"#;

        let (calls, normal) = try_tool_call_parse_xml(input).unwrap();
        assert_eq!(calls.len(), 1);
        assert_eq!(calls[0].function.name, "execute_bash");
        assert_eq!(normal, Some("".to_string()));

        let args: serde_json::Value = serde_json::from_str(&calls[0].function.arguments).unwrap();
        assert_eq!(args["command"], "pwd && ls");
    }

    #[test]
    fn test_parse_multiple_parameters() {
        let input = r#"<tool_call>
<function=get_weather>
<parameter=city>
San Francisco
</parameter>
<parameter=state>
CA
</parameter>
<parameter=unit>
fahrenheit
</parameter>
</function>
</tool_call>"#;

        let (calls, _) = try_tool_call_parse_xml(input).unwrap();
        assert_eq!(calls.len(), 1);
        assert_eq!(calls[0].function.name, "get_weather");

        let args: serde_json::Value = serde_json::from_str(&calls[0].function.arguments).unwrap();
        assert_eq!(args["city"], "San Francisco");
        assert_eq!(args["state"], "CA");
        assert_eq!(args["unit"], "fahrenheit");
    }

    #[test]
    fn test_parse_with_normal_text() {
        let input = r#"I'll help you with that. <tool_call>
<function=get_weather>
<parameter=city>
Dallas
</parameter>
</function>
</tool_call> Let me check that for you."#;

        let (calls, normal) = try_tool_call_parse_xml(input).unwrap();
        assert_eq!(calls.len(), 1);
        assert_eq!(calls[0].function.name, "get_weather");
        // The text on either side of the tool call is concatenated as-is, so both
        // surrounding spaces survive.
        assert_eq!(
            normal,
            Some("I'll help you with that.  Let me check that for you.".to_string())
        );
    }

    #[test]
    fn test_parse_multiple_tool_calls() {
        let input = r#"<tool_call>
<function=get_weather>
<parameter=city>
Dallas
</parameter>
</function>
</tool_call>
<tool_call>
<function=get_weather>
<parameter=city>
Orlando
</parameter>
</function>
</tool_call>"#;

        let (calls, _) = try_tool_call_parse_xml(input).unwrap();
        assert_eq!(calls.len(), 2);
        assert_eq!(calls[0].function.name, "get_weather");
        assert_eq!(calls[1].function.name, "get_weather");

        let args0: serde_json::Value = serde_json::from_str(&calls[0].function.arguments).unwrap();
        let args1: serde_json::Value = serde_json::from_str(&calls[1].function.arguments).unwrap();
        assert_eq!(args0["city"], "Dallas");
        assert_eq!(args1["city"], "Orlando");
    }

    #[test]
    fn test_parse_json_parameter_value() {
        let input = r#"<tool_call>
<function=process_data>
<parameter=config>
{"setting": "value", "count": 42}
</parameter>
</function>
</tool_call>"#;

        let (calls, _) = try_tool_call_parse_xml(input).unwrap();
        assert_eq!(calls.len(), 1);

        let args: serde_json::Value = serde_json::from_str(&calls[0].function.arguments).unwrap();
        assert!(args["config"].is_object());
        assert_eq!(args["config"]["setting"], "value");
        assert_eq!(args["config"]["count"], 42);
    }

    #[test]
    fn test_parse_no_tool_calls() {
        let input = "This is just normal text without any tool calls.";
        let (calls, normal) = try_tool_call_parse_xml(input).unwrap();
        assert_eq!(calls.len(), 0);
        assert_eq!(normal, Some(input.to_string()));
    }

    #[test]
    fn test_parse_malformed_tool_call() {
        let input = r#"<tool_call>
<function=test>
<parameter=arg>value
</tool_call>"#;

        // Should handle gracefully - might parse or return empty
        let result = try_tool_call_parse_xml(input);
        assert!(result.is_ok());
    }

    #[test]
    fn test_parse_missing_parameter_closing_tag() {
        let input = r#"<tool_call>
<function=execute_bash>
<parameter=command>
ls -la
</function>
</tool_call>"#;

        let (calls, _) = try_tool_call_parse_xml(input).unwrap();
        assert_eq!(calls.len(), 1);
        assert_eq!(calls[0].function.name, "execute_bash");

        let args: serde_json::Value = serde_json::from_str(&calls[0].function.arguments).unwrap();
        assert_eq!(args["command"], "ls -la");
    }

    #[test]
    fn test_parse_missing_function_closing_tag() {
        let input = r#"<tool_call>
<function=get_weather>
<parameter=city>
Boston
</parameter>
</tool_call>"#;

        let (calls, _) = try_tool_call_parse_xml(input).unwrap();
        assert_eq!(calls.len(), 1);
        assert_eq!(calls[0].function.name, "get_weather");

        let args: serde_json::Value = serde_json::from_str(&calls[0].function.arguments).unwrap();
        assert_eq!(args["city"], "Boston");
    }

    #[test]
    fn test_parse_missing_both_closing_tags() {
        let input = r#"<tool_call>
<function=run_query>
<parameter=sql>
SELECT * FROM users
</tool_call>"#;

        let (calls, _) = try_tool_call_parse_xml(input).unwrap();
        assert_eq!(calls.len(), 1);
        assert_eq!(calls[0].function.name, "run_query");

        let args: serde_json::Value = serde_json::from_str(&calls[0].function.arguments).unwrap();
        // With neither closing tag present, both regexes run to the end of the
        // block, so the trailing `</tool_call>` ends up inside the value.
        // This matches the original SGLang python implementation.
        assert_eq!(args["sql"], "SELECT * FROM users\n</tool_call>");
    }

    #[test]
    fn test_parse_multiple_parameters_missing_closing_tags() {
        let input = r#"<tool_call>
<function=search>
<parameter=query>
rust programming
<parameter=limit>
10
</parameter>
</function>
</tool_call>"#;

        let (calls, _) = try_tool_call_parse_xml(input).unwrap();
        assert_eq!(calls.len(), 1);
        assert_eq!(calls[0].function.name, "search");

        let args: serde_json::Value = serde_json::from_str(&calls[0].function.arguments).unwrap();
        // The first parameter has no closing tag, so its value runs up to the next
        // `</parameter>` and swallows the second parameter's opening tag.
        // This matches the original SGLang python implementation.
        assert_eq!(args["query"], "rust programming\n<parameter=limit>\n10");
    }
}