use async_trait::async_trait;
use regex::Regex;
use serde_json::Value;
use std::collections::HashMap;
use crate::protocols::spec::Tool;
use crate::tool_parser::{
errors::{ToolParserError, ToolParserResult},
parsers::helpers,
traits::ToolParser,
types::{FunctionCall, StreamingParseResult, ToolCall, ToolCallItem},
};
/// Step3 format parser for tool calls
///
/// Handles the Step3 specific format with steptml XML:
/// `<|tool_calls_begin|><|tool_call_begin|>function<|tool_sep|>{v}<|tool_call_end|><|tool_calls_end|>`
///
/// Features:
/// - Unicode token delimiters
/// - StepTML XML format for invocations
/// - Support for multiple sequential tool calls
pub struct Step3Parser {
/// Regex for extracting tool call blocks
tool_call_extractor: Regex,
/// Regex for extracting steptml invocations
invoke_extractor: Regex,
/// Regex for extracting parameters
param_extractor: Regex,
/// Buffer for accumulating chunks
buffer: String,
/// Token configuration
bot_token: &'static str,
eot_token: &'static str,
tool_call_begin: &'static str,
tool_call_end: &'static str,
tool_sep: &'static str,
/// Streaming state variables (mirrors Python's Step3Detector)
in_tool_block: bool,
tool_block_finished: bool,
current_function_name: String,
current_parameters: serde_json::Map,
in_tool_call: bool,
function_name_sent: bool,
/// Standard state machine fields
prev_tool_call_arr: Vec,
current_tool_id: i32,
streamed_args_for_tool: Vec,
}
impl Step3Parser {
/// Create a new Step3 parser
pub fn new() -> Self {
// Pattern for individual tool calls
let tool_call_pattern = r"(?s)<|tool_call_begin|>.*?<|tool_call_end|>";
let tool_call_extractor = Regex::new(tool_call_pattern).expect("Valid regex pattern");
// Pattern for steptml invocations
let invoke_pattern = r#"(?s)(.+?)"#;
let invoke_extractor = Regex::new(invoke_pattern).expect("Valid regex pattern");
// Pattern for steptml parameters - using non-greedy match for values to handle < characters
let param_pattern = r#"(?s)(.+?)"#;
let param_extractor = Regex::new(param_pattern).expect("Valid regex pattern");
Self {
tool_call_extractor,
invoke_extractor,
param_extractor,
buffer: String::new(),
bot_token: "<|tool_calls_begin|>",
eot_token: "<|tool_calls_end|>",
tool_call_begin: "<|tool_call_begin|>",
tool_call_end: "<|tool_call_end|>",
tool_sep: "<|tool_sep|>",
// Streaming state variables
in_tool_block: false,
tool_block_finished: false,
current_function_name: String::new(),
current_parameters: serde_json::Map::new(),
in_tool_call: false,
function_name_sent: false,
// Standard state machine fields
prev_tool_call_arr: Vec::new(),
current_tool_id: -1,
streamed_args_for_tool: Vec::new(),
}
}
/// Clear the per-call streaming state so the next tool call starts fresh.
///
/// Only the current function name, the current parameter map and the two
/// per-call flags are reset; the buffer and the tool-level bookkeeping are
/// left untouched.
fn reset_streaming_state(&mut self) {
    self.current_parameters.clear();
    self.current_function_name.clear();
    (self.in_tool_call, self.function_name_sent) = (false, false);
}
/// Parse partial tool call for streaming scenarios (mirrors Python's _parse_partial_tool_call)
fn parse_partial_tool_call(
&mut self,
tool_indices: &HashMap,
) -> ToolParserResult {
let mut calls = Vec::new();
// Check if we have tool_sep (means we're past the type declaration)
if !self.buffer.contains(self.tool_sep) {
return Ok(StreamingParseResult {
normal_text: String::new(),
calls,
});
}
// Clone the buffer to avoid borrow conflicts
let buffer_clone = self.buffer.clone();
let parts: Vec<&str> = buffer_clone.splitn(2, self.tool_sep).collect();
if parts.len() != 2 {
return Ok(StreamingParseResult {
normal_text: String::new(),
calls,
});
}
let type_part = parts[0].trim();
let invoke_part = parts[1];
// Check if it's a function type
if type_part != "function" {
// Invalid tool type, skip this tool call
self.reset_streaming_state();
return Ok(StreamingParseResult {
normal_text: String::new(),
calls,
});
}
// Try to extract function name if not sent yet
if !self.function_name_sent {
if let Some(captures) = self.invoke_extractor.captures(invoke_part) {
let func_name = captures.get(1).map_or("", |m| m.as_str()).trim();
// Validate function name
if tool_indices.contains_key(func_name) {
self.current_function_name = func_name.to_string();
self.function_name_sent = true;
// Initialize tool tracking
if self.current_tool_id == -1 {
self.current_tool_id = 0;
}
// Ensure tracking arrays are large enough
helpers::ensure_capacity(
self.current_tool_id,
&mut self.prev_tool_call_arr,
&mut self.streamed_args_for_tool,
);
// Store tool call info
let tool_id = self.current_tool_id as usize;
self.prev_tool_call_arr[tool_id] = serde_json::json!({
"name": func_name,
"arguments": {},
});
// Send tool name with empty parameters
calls.push(ToolCallItem {
tool_index: self.current_tool_id as usize,
name: Some(func_name.to_string()),
parameters: String::new(),
});
} else {
// Invalid function name
tracing::warn!("Invalid function name: {}", func_name);
self.reset_streaming_state();
return Ok(StreamingParseResult {
normal_text: String::new(),
calls,
});
}
} else {
// Function name not complete yet
return Ok(StreamingParseResult {
normal_text: String::new(),
calls,
});
}
}
// Parse parameters incrementally
if self.function_name_sent {
// Extract all complete parameters
let mut new_params = serde_json::Map::new();
for capture in self.param_extractor.captures_iter(invoke_part) {
let param_name = capture.get(1).map_or("", |m| m.as_str()).trim();
let param_value_str = capture.get(2).map_or("", |m| m.as_str()).trim();
// Try to parse the value as JSON first, fallback to string
let param_value =
if let Ok(json_val) = serde_json::from_str::(param_value_str) {
json_val
} else {
// Try parsing as Python literal
if param_value_str == "true" || param_value_str == "True" {
Value::Bool(true)
} else if param_value_str == "false" || param_value_str == "False" {
Value::Bool(false)
} else if param_value_str == "null" || param_value_str == "None" {
Value::Null
} else if let Ok(num) = param_value_str.parse::() {
Value::Number(num.into())
} else if let Ok(num) = param_value_str.parse::() {
if let Some(n) = serde_json::Number::from_f64(num) {
Value::Number(n)
} else {
Value::String(param_value_str.to_string())
}
} else {
Value::String(param_value_str.to_string())
}
};
new_params.insert(param_name.to_string(), param_value);
}
// Check if we have new parameters to stream
if new_params != self.current_parameters {
// Build the JSON content without the closing brace for streaming
let diff = if self.current_parameters.is_empty() {
// First parameters - send opening brace and content
let params_content =
serde_json::to_string(&new_params).unwrap_or_else(|_| "{}".to_string());
if params_content.len() > 2 {
// Send everything except the closing brace
params_content[..params_content.len() - 1].to_string()
} else {
"{".to_string()
}
} else {
// Subsequent parameters - calculate the incremental diff
let old_json = serde_json::to_string(&self.current_parameters)
.unwrap_or_else(|_| "{}".to_string());
let new_json =
serde_json::to_string(&new_params).unwrap_or_else(|_| "{}".to_string());
// Remove closing braces for comparison
let old_without_brace = &old_json[..old_json.len() - 1];
let new_without_brace = &new_json[..new_json.len() - 1];
// The new content should extend the old content
new_without_brace
.strip_prefix(old_without_brace)
.map(|s| s.to_string())
.unwrap_or_default()
};
if !diff.is_empty() {
calls.push(ToolCallItem {
tool_index: self.current_tool_id as usize,
name: None,
parameters: diff.clone(),
});
let tool_id = self.current_tool_id as usize;
if tool_id < self.streamed_args_for_tool.len() {
self.streamed_args_for_tool[tool_id].push_str(&diff);
}
}
// Update current state
self.current_parameters = new_params.clone();
let tool_id = self.current_tool_id as usize;
if tool_id < self.prev_tool_call_arr.len() {
if let Some(obj) = self.prev_tool_call_arr[tool_id].as_object_mut() {
obj.insert("arguments".to_string(), Value::Object(new_params));
}
}
}
// Check if tool call is complete
if self.buffer.contains(self.tool_call_end) {
// Send closing brace if we've sent any parameters
let tool_id = self.current_tool_id as usize;
if tool_id < self.streamed_args_for_tool.len()
&& !self.streamed_args_for_tool[tool_id].is_empty()
{
calls.push(ToolCallItem {
tool_index: self.current_tool_id as usize,
name: None,
parameters: "}".to_string(),
});
self.streamed_args_for_tool[tool_id].push('}');
}
// Find the end position
if let Some(end_idx) = self.buffer.find(self.tool_call_end) {
// Remove the processed tool call from buffer
self.buffer = self.buffer[end_idx + self.tool_call_end.len()..].to_string();
}
// Reset state for next tool call
self.reset_streaming_state();
self.current_tool_id += 1;
}
}
Ok(StreamingParseResult {
normal_text: String::new(),
calls,
})
}
/// Parse parameters from steptml format
fn parse_steptml_parameters(
&self,
params_text: &str,
) -> ToolParserResult> {
let mut parameters = serde_json::Map::new();
for capture in self.param_extractor.captures_iter(params_text) {
let param_name = capture.get(1).map_or("", |m| m.as_str()).trim();
let param_value_str = capture.get(2).map_or("", |m| m.as_str()).trim();
// Try to parse the value as JSON first, fallback to string
let param_value = if let Ok(json_val) = serde_json::from_str::(param_value_str) {
json_val
} else {
// Try parsing as Python literal
if param_value_str == "true" || param_value_str == "True" {
Value::Bool(true)
} else if param_value_str == "false" || param_value_str == "False" {
Value::Bool(false)
} else if param_value_str == "null" || param_value_str == "None" {
Value::Null
} else if let Ok(num) = param_value_str.parse::() {
Value::Number(num.into())
} else if let Ok(num) = param_value_str.parse::() {
if let Some(n) = serde_json::Number::from_f64(num) {
Value::Number(n)
} else {
Value::String(param_value_str.to_string())
}
} else {
Value::String(param_value_str.to_string())
}
};
parameters.insert(param_name.to_string(), param_value);
}
Ok(parameters)
}
/// Parse a single tool call block
fn parse_tool_call(&self, block: &str) -> ToolParserResult