/// Pythonic format parser for tool calls /// /// Handles Python function call syntax within square brackets: /// ```text /// [tool1(arg1=val1, arg2=val2), tool2(arg1=val3)] /// ``` /// /// This format is used by Llama models and uses Python literals /// rather than JSON for arguments. use async_trait::async_trait; use num_traits::ToPrimitive; use regex::Regex; use rustpython_parser::ast::{Constant, Expr, Mod, UnaryOp}; use rustpython_parser::{parse, Mode}; use serde_json::{Map, Number, Value}; use std::sync::OnceLock; use crate::tool_parser::{ errors::{ToolParserError, ToolParserResult}, state::ParseState, traits::ToolParser, types::{FunctionCall, StreamResult, ToolCall}, }; static PYTHONIC_BLOCK_REGEX: OnceLock = OnceLock::new(); /// Lazily compiled regex that locates pythonic tool call blocks. fn pythonic_block_regex() -> &'static Regex { PYTHONIC_BLOCK_REGEX.get_or_init(|| { // Matches one or more function calls inside a list. The `(?s)` flag allows // newlines inside argument lists while keeping the pattern anchored to // identifiers followed by parentheses, preventing plain lists like // `[1, 2, 3]` from matching. Regex::new(r"(?s)\[\s*[A-Za-z_]\w*\s*\(.*?\)\s*(?:,\s*[A-Za-z_]\w*\s*\(.*?\)\s*)*\]") .expect("pythonic tool call regex must compile") }) } /// Parser for Pythonic tool call format #[derive(Default)] pub struct PythonicParser; impl PythonicParser { /// Create a new Pythonic parser pub fn new() -> Self { Self } /// Extract the first pythonic tool call block and return it along with the /// surrounding "normal" content. fn extract_tool_calls(&self, text: &str) -> Option<(String, String)> { pythonic_block_regex().find(text).map(|mat| { let block = mat.as_str().to_string(); let normal = format!("{}{}", &text[..mat.start()], &text[mat.end()..]); (block, normal) }) } /// Strip special tokens that Llama models might output fn strip_special_tokens(text: &str) -> String { text.replace("<|python_start|>", "") .replace("<|python_end|>", "") } fn parse_tool_call_block(&self, block: &str) -> ToolParserResult> { let expr = parse_python_expression(block)?; match expr { Expr::List(list_expr) => list_expr .elts .into_iter() .enumerate() .map(|(idx, call_expr)| build_tool_call(call_expr, idx)) .collect(), _ => Err(ToolParserError::ParsingFailed( "Expected a list of function calls in pythonic tool call".to_string(), )), } } } #[async_trait] impl ToolParser for PythonicParser { async fn parse_complete(&self, text: &str) -> ToolParserResult<(String, Vec)> { let cleaned = Self::strip_special_tokens(text); if let Some((tool_calls_text, normal_text)) = self.extract_tool_calls(&cleaned) { let calls = self.parse_tool_call_block(&tool_calls_text)?; Ok((normal_text, calls)) } else { Ok((text.to_string(), vec![])) } } async fn parse_incremental( &self, chunk: &str, state: &mut ParseState, ) -> ToolParserResult { state.buffer.push_str(chunk); let cleaned = Self::strip_special_tokens(&state.buffer); if let Some((tool_calls_text, _)) = self.extract_tool_calls(&cleaned) { if let Ok(tools) = self.parse_tool_call_block(&tool_calls_text) { if let Some(tool) = tools.into_iter().next() { state.buffer.clear(); return Ok(StreamResult::ToolComplete(tool)); } } } Ok(StreamResult::Incomplete) } fn detect_format(&self, text: &str) -> bool { let cleaned = Self::strip_special_tokens(text); if pythonic_block_regex().is_match(&cleaned) { return true; } let trimmed = cleaned.trim(); let Some(open_idx) = trimmed.find('[') else { return false; }; let after_bracket = trimmed[open_idx + 1..].trim_start(); let mut chars = after_bracket.char_indices(); let Some((_, first_char)) = chars.next() else { return false; }; if !(first_char.is_ascii_alphabetic() || first_char == '_') { return false; } let mut ident_len = first_char.len_utf8(); for (idx, ch) in chars { if ch.is_alphanumeric() || ch == '_' { ident_len = idx + ch.len_utf8(); } else { break; } } let remaining = after_bracket[ident_len..].trim_start(); remaining.starts_with('(') } } fn parse_python_expression(source: &str) -> ToolParserResult { let module = parse(source, Mode::Expression, "") .map_err(|err| ToolParserError::ParsingFailed(err.to_string()))?; match module { Mod::Expression(expr_mod) => Ok(*expr_mod.body), _ => Err(ToolParserError::ParsingFailed( "Expected a Python expression".to_string(), )), } } fn build_tool_call(expr: Expr, index: usize) -> ToolParserResult { match expr { Expr::Call(call_expr) => { if !call_expr.args.is_empty() { return Err(ToolParserError::ParsingFailed( "Positional arguments are not supported in pythonic tool calls".to_string(), )); } let function_name = match *call_expr.func { Expr::Name(name_expr) => name_expr.id.to_string(), _ => { return Err(ToolParserError::ParsingFailed( "Unsupported function reference in pythonic tool call".to_string(), )) } }; let mut arguments_map = Map::with_capacity(call_expr.keywords.len()); for keyword in call_expr.keywords { let arg_name = keyword.arg.ok_or_else(|| { ToolParserError::ParsingFailed( "pythonic tool calls do not support **kwargs".to_string(), ) })?; let value_json = expression_to_json(&keyword.value)?; arguments_map.insert(arg_name.to_string(), value_json); } let arguments_json = Value::Object(arguments_map); let arguments_string = serde_json::to_string(&arguments_json)?; Ok(ToolCall { id: format!("call-{}", index + 1), r#type: "function".to_string(), function: FunctionCall { name: function_name, arguments: arguments_string, }, }) } _ => Err(ToolParserError::ParsingFailed( "Expected function calls inside pythonic tool call list".to_string(), )), } } fn expression_to_json(expr: &Expr) -> ToolParserResult { match expr { Expr::Constant(expr_constant) => constant_to_json(&expr_constant.value), Expr::List(list_expr) => collect_sequence(&list_expr.elts).map(Value::Array), Expr::Tuple(tuple_expr) => collect_sequence(&tuple_expr.elts).map(Value::Array), Expr::Dict(dict_expr) => { collect_dict(&dict_expr.keys, &dict_expr.values).map(Value::Object) } Expr::UnaryOp(unary_expr) => match unary_expr.op { UnaryOp::USub => match unary_expr.operand.as_ref() { Expr::Constant(const_expr) => negate_constant(&const_expr.value), _ => Err(ToolParserError::ParsingFailed( "Unsupported unary operand in pythonic tool call".to_string(), )), }, UnaryOp::UAdd => expression_to_json(unary_expr.operand.as_ref()), _ => Err(ToolParserError::ParsingFailed(format!( "Unsupported unary operator in pythonic tool call: {:?}", unary_expr.op ))), }, Expr::Name(name_expr) => Ok(Value::String(name_expr.id.to_string())), _ => Err(ToolParserError::ParsingFailed(format!( "Unsupported expression in pythonic tool call: {:?}", expr ))), } } fn constant_to_json(constant: &Constant) -> ToolParserResult { match constant { Constant::None => Ok(Value::Null), Constant::Bool(b) => Ok(Value::Bool(*b)), Constant::Int(value) => Ok(integer_constant_to_value(value, false)), Constant::Float(f) => Number::from_f64(*f).map(Value::Number).ok_or_else(|| { ToolParserError::ParsingFailed( "Invalid float literal in pythonic tool call".to_string(), ) }), Constant::Str(s) => Ok(Value::String(s.clone())), Constant::Bytes(bytes) => Ok(Value::String(String::from_utf8_lossy(bytes).into_owned())), Constant::Tuple(values) => constant_tuple_to_array(values).map(Value::Array), Constant::Ellipsis | Constant::Complex { .. } => Err(ToolParserError::ParsingFailed( "Unsupported literal in pythonic tool call".to_string(), )), } } fn negate_constant(constant: &Constant) -> ToolParserResult { match constant { Constant::Int(value) => Ok(integer_constant_to_value(value, true)), Constant::Float(f) => Number::from_f64(-f).map(Value::Number).ok_or_else(|| { ToolParserError::ParsingFailed( "Invalid float literal in pythonic tool call".to_string(), ) }), _ => Err(ToolParserError::ParsingFailed( "Unsupported unary operand in pythonic tool call".to_string(), )), } } fn value_to_key_string(value: Value) -> ToolParserResult { match value { Value::String(s) => Ok(s), Value::Number(num) => Ok(num.to_string()), Value::Bool(b) => Ok(b.to_string()), Value::Null => Ok("null".to_string()), other => Err(ToolParserError::ParsingFailed(format!( "Unsupported key type in pythonic tool call: {:?}", other ))), } } fn collect_sequence(elements: &[Expr]) -> ToolParserResult> { elements.iter().map(expression_to_json).collect() } fn collect_dict(keys: &[Option], values: &[Expr]) -> ToolParserResult> { let mut map = Map::with_capacity(keys.len()); for (key_expr, value_expr) in keys.iter().zip(values.iter()) { let key_expr = key_expr.as_ref().ok_or_else(|| { ToolParserError::ParsingFailed( "pythonic tool calls do not support **kwargs".to_string(), ) })?; let key_value = expression_to_json(key_expr)?; let key = value_to_key_string(key_value)?; let value_json = expression_to_json(value_expr)?; map.insert(key, value_json); } Ok(map) } fn constant_tuple_to_array(values: &[Constant]) -> ToolParserResult> { values.iter().map(constant_to_json).collect() } fn integer_constant_to_value(value: &T, negate: bool) -> Value where T: ToPrimitive + std::fmt::Display, { if let Some(mut i) = value.to_i64() { if negate { i = -i; } return Value::Number(Number::from(i)); } if negate { if let Some(u) = value.to_u64() { if u <= i64::MAX as u64 { return Value::Number(Number::from(-(u as i64))); } return Value::String(format!("-{}", value)); } Value::String(format!("-{}", value)) } else if let Some(u) = value.to_u64() { Value::Number(Number::from(u)) } else { Value::String(value.to_string()) } } #[cfg(test)] mod tests { use super::*; #[tokio::test] async fn test_single_function_call() { let parser = PythonicParser::new(); let input = r#"[search_web(query="Rust programming", max_results=5)]"#; let (_normal_text, tools) = parser.parse_complete(input).await.unwrap(); assert_eq!(tools.len(), 1); assert_eq!(tools[0].function.name, "search_web"); let args: Value = serde_json::from_str(&tools[0].function.arguments).unwrap(); assert_eq!(args["query"], "Rust programming"); assert_eq!(args["max_results"], 5); } #[tokio::test] async fn test_multiple_function_calls() { let parser = PythonicParser::new(); let input = r#"[get_weather(city="Tokyo"), search(query="news")]"#; let (_normal_text, tools) = parser.parse_complete(input).await.unwrap(); assert_eq!(tools.len(), 2); assert_eq!(tools[0].function.name, "get_weather"); assert_eq!(tools[1].function.name, "search"); } #[tokio::test] async fn test_python_literals() { let parser = PythonicParser::new(); let input = r#"[test(flag=True, disabled=False, optional=None)]"#; let (_normal_text, tools) = parser.parse_complete(input).await.unwrap(); assert_eq!(tools.len(), 1); let args: Value = serde_json::from_str(&tools[0].function.arguments).unwrap(); assert_eq!(args["flag"], true); assert_eq!(args["disabled"], false); assert!(args["optional"].is_null()); } #[tokio::test] async fn test_strip_special_tokens() { let parser = PythonicParser::new(); let input = "<|python_start|>[call(arg=1)]<|python_end|>"; assert!(parser.detect_format(input)); let (_normal_text, tools) = parser.parse_complete(input).await.unwrap(); assert_eq!(tools.len(), 1); } #[tokio::test] async fn test_detect_format() { let parser = PythonicParser::new(); assert!(parser.detect_format("[foo(bar=1)]")); assert!(!parser.detect_format("No python here")); } #[tokio::test] async fn test_parse_incremental() { let parser = PythonicParser::new(); let mut state = ParseState::new(); let chunk1 = "[call(arg="; let result1 = parser.parse_incremental(chunk1, &mut state).await.unwrap(); assert!(matches!(result1, StreamResult::Incomplete)); let chunk2 = "1)]"; let result2 = parser.parse_incremental(chunk2, &mut state).await.unwrap(); match result2 { StreamResult::ToolComplete(tool) => { assert_eq!(tool.function.name, "call"); } other => panic!("Expected ToolComplete, got {:?}", other), } } }