"vscode:/vscode.git/clone" did not exist on "8ba0f328af92fcf91c0807b26dff81093add1eeb"
Unverified Commit ab926dd6 authored by Chang Su's avatar Chang Su Committed by GitHub
Browse files

[router][grpc] Fix streaming bugs: empty tool names, state pollution, and panics (#11373)

parent a4b424c6
...@@ -4,7 +4,7 @@ use serde_json::Value; ...@@ -4,7 +4,7 @@ use serde_json::Value;
use crate::protocols::spec::Tool; use crate::protocols::spec::Tool;
use crate::tool_parser::{ use crate::tool_parser::{
errors::{ToolParserError, ToolParserResult}, errors::{ParserError, ParserResult},
parsers::helpers, parsers::helpers,
partial_json::PartialJson, partial_json::PartialJson,
traits::ToolParser, traits::ToolParser,
...@@ -70,7 +70,7 @@ impl LlamaParser { ...@@ -70,7 +70,7 @@ impl LlamaParser {
} }
/// Parse a single JSON object into a ToolCall (Llama format: name + parameters) /// Parse a single JSON object into a ToolCall (Llama format: name + parameters)
fn parse_single_object(&self, obj: &Value) -> ToolParserResult<Option<ToolCall>> { fn parse_single_object(&self, obj: &Value) -> ParserResult<Option<ToolCall>> {
// Llama format only: {"name": "function_name", "parameters": {...}} // Llama format only: {"name": "function_name", "parameters": {...}}
let name = obj.get("name").and_then(|v| v.as_str()); let name = obj.get("name").and_then(|v| v.as_str());
...@@ -81,7 +81,7 @@ impl LlamaParser { ...@@ -81,7 +81,7 @@ impl LlamaParser {
// Convert parameters to JSON string // Convert parameters to JSON string
let arguments = serde_json::to_string(parameters) let arguments = serde_json::to_string(parameters)
.map_err(|e| ToolParserError::ParsingFailed(e.to_string()))?; .map_err(|e| ParserError::ParsingFailed(e.to_string()))?;
Ok(Some(ToolCall { Ok(Some(ToolCall {
function: FunctionCall { function: FunctionCall {
...@@ -95,7 +95,7 @@ impl LlamaParser { ...@@ -95,7 +95,7 @@ impl LlamaParser {
} }
/// Parse semicolon-separated JSON objects /// Parse semicolon-separated JSON objects
fn parse_semicolon_separated(&self, content: &str) -> ToolParserResult<Vec<ToolCall>> { fn parse_semicolon_separated(&self, content: &str) -> ParserResult<Vec<ToolCall>> {
let mut all_tools = Vec::new(); let mut all_tools = Vec::new();
// Split by semicolon and parse each JSON object // Split by semicolon and parse each JSON object
...@@ -131,7 +131,7 @@ impl Default for LlamaParser { ...@@ -131,7 +131,7 @@ impl Default for LlamaParser {
#[async_trait] #[async_trait]
impl ToolParser for LlamaParser { impl ToolParser for LlamaParser {
async fn parse_complete(&self, text: &str) -> ToolParserResult<(String, Vec<ToolCall>)> { async fn parse_complete(&self, text: &str) -> ParserResult<(String, Vec<ToolCall>)> {
// Extract normal text and JSON content // Extract normal text and JSON content
let (normal_text, json_content) = let (normal_text, json_content) =
if let Some((normal, json)) = self.extract_content_after_python_tag(text) { if let Some((normal, json)) = self.extract_content_after_python_tag(text) {
...@@ -149,7 +149,7 @@ impl ToolParser for LlamaParser { ...@@ -149,7 +149,7 @@ impl ToolParser for LlamaParser {
} else { } else {
// Try single JSON object // Try single JSON object
let parsed = serde_json::from_str::<Value>(json_content.trim()) let parsed = serde_json::from_str::<Value>(json_content.trim())
.map_err(|e| ToolParserError::ParsingFailed(e.to_string())) .map_err(|e| ParserError::ParsingFailed(e.to_string()))
.and_then(|v| { .and_then(|v| {
self.parse_single_object(&v) self.parse_single_object(&v)
.map(|opt| opt.map_or_else(Vec::new, |tool| vec![tool])) .map(|opt| opt.map_or_else(Vec::new, |tool| vec![tool]))
...@@ -173,7 +173,7 @@ impl ToolParser for LlamaParser { ...@@ -173,7 +173,7 @@ impl ToolParser for LlamaParser {
&mut self, &mut self,
chunk: &str, chunk: &str,
tools: &[Tool], tools: &[Tool],
) -> ToolParserResult<StreamingParseResult> { ) -> ParserResult<StreamingParseResult> {
// Append new text to buffer // Append new text to buffer
self.buffer.push_str(chunk); self.buffer.push_str(chunk);
let current_text = &self.buffer.clone(); let current_text = &self.buffer.clone();
...@@ -231,4 +231,14 @@ impl ToolParser for LlamaParser { ...@@ -231,4 +231,14 @@ impl ToolParser for LlamaParser {
fn get_unstreamed_tool_args(&self) -> Option<Vec<crate::tool_parser::types::ToolCallItem>> { fn get_unstreamed_tool_args(&self) -> Option<Vec<crate::tool_parser::types::ToolCallItem>> {
helpers::get_unstreamed_args(&self.prev_tool_call_arr, &self.streamed_args_for_tool) helpers::get_unstreamed_args(&self.prev_tool_call_arr, &self.streamed_args_for_tool)
} }
fn reset(&mut self) {
helpers::reset_parser_state(
&mut self.buffer,
&mut self.prev_tool_call_arr,
&mut self.current_tool_id,
&mut self.current_tool_name_sent,
&mut self.streamed_args_for_tool,
);
}
} }
...@@ -4,7 +4,7 @@ use serde_json::Value; ...@@ -4,7 +4,7 @@ use serde_json::Value;
use crate::protocols::spec::Tool; use crate::protocols::spec::Tool;
use crate::tool_parser::{ use crate::tool_parser::{
errors::{ToolParserError, ToolParserResult}, errors::{ParserError, ParserResult},
parsers::helpers, parsers::helpers,
partial_json::PartialJson, partial_json::PartialJson,
traits::ToolParser, traits::ToolParser,
...@@ -111,9 +111,9 @@ impl MistralParser { ...@@ -111,9 +111,9 @@ impl MistralParser {
} }
/// Parse tool calls from a JSON array /// Parse tool calls from a JSON array
fn parse_json_array(&self, json_str: &str) -> ToolParserResult<Vec<ToolCall>> { fn parse_json_array(&self, json_str: &str) -> ParserResult<Vec<ToolCall>> {
let value: Value = serde_json::from_str(json_str) let value: Value = serde_json::from_str(json_str)
.map_err(|e| ToolParserError::ParsingFailed(e.to_string()))?; .map_err(|e| ParserError::ParsingFailed(e.to_string()))?;
let mut tools = Vec::new(); let mut tools = Vec::new();
...@@ -134,7 +134,7 @@ impl MistralParser { ...@@ -134,7 +134,7 @@ impl MistralParser {
} }
/// Parse a single JSON object into a ToolCall /// Parse a single JSON object into a ToolCall
fn parse_single_object(&self, obj: &Value) -> ToolParserResult<Option<ToolCall>> { fn parse_single_object(&self, obj: &Value) -> ParserResult<Option<ToolCall>> {
let name = obj.get("name").and_then(|v| v.as_str()); let name = obj.get("name").and_then(|v| v.as_str());
if let Some(name) = name { if let Some(name) = name {
...@@ -144,7 +144,7 @@ impl MistralParser { ...@@ -144,7 +144,7 @@ impl MistralParser {
// Convert arguments to JSON string // Convert arguments to JSON string
let arguments = serde_json::to_string(args) let arguments = serde_json::to_string(args)
.map_err(|e| ToolParserError::ParsingFailed(e.to_string()))?; .map_err(|e| ParserError::ParsingFailed(e.to_string()))?;
Ok(Some(ToolCall { Ok(Some(ToolCall {
function: FunctionCall { function: FunctionCall {
...@@ -166,7 +166,7 @@ impl Default for MistralParser { ...@@ -166,7 +166,7 @@ impl Default for MistralParser {
#[async_trait] #[async_trait]
impl ToolParser for MistralParser { impl ToolParser for MistralParser {
async fn parse_complete(&self, text: &str) -> ToolParserResult<(String, Vec<ToolCall>)> { async fn parse_complete(&self, text: &str) -> ParserResult<(String, Vec<ToolCall>)> {
// Check if text contains Mistral format // Check if text contains Mistral format
if !self.has_tool_markers(text) { if !self.has_tool_markers(text) {
return Ok((text.to_string(), vec![])); return Ok((text.to_string(), vec![]));
...@@ -199,7 +199,7 @@ impl ToolParser for MistralParser { ...@@ -199,7 +199,7 @@ impl ToolParser for MistralParser {
&mut self, &mut self,
chunk: &str, chunk: &str,
tools: &[Tool], tools: &[Tool],
) -> ToolParserResult<StreamingParseResult> { ) -> ParserResult<StreamingParseResult> {
// Append new text to buffer // Append new text to buffer
self.buffer.push_str(chunk); self.buffer.push_str(chunk);
let current_text = &self.buffer.clone(); let current_text = &self.buffer.clone();
...@@ -256,4 +256,14 @@ impl ToolParser for MistralParser { ...@@ -256,4 +256,14 @@ impl ToolParser for MistralParser {
fn get_unstreamed_tool_args(&self) -> Option<Vec<crate::tool_parser::types::ToolCallItem>> { fn get_unstreamed_tool_args(&self) -> Option<Vec<crate::tool_parser::types::ToolCallItem>> {
helpers::get_unstreamed_args(&self.prev_tool_call_arr, &self.streamed_args_for_tool) helpers::get_unstreamed_args(&self.prev_tool_call_arr, &self.streamed_args_for_tool)
} }
fn reset(&mut self) {
helpers::reset_parser_state(
&mut self.buffer,
&mut self.prev_tool_call_arr,
&mut self.current_tool_id,
&mut self.current_tool_name_sent,
&mut self.streamed_args_for_tool,
);
}
} }
...@@ -18,7 +18,7 @@ use std::sync::OnceLock; ...@@ -18,7 +18,7 @@ use std::sync::OnceLock;
use crate::protocols::spec::Tool; use crate::protocols::spec::Tool;
use crate::tool_parser::{ use crate::tool_parser::{
errors::{ToolParserError, ToolParserResult}, errors::{ParserError, ParserResult},
parsers::helpers, parsers::helpers,
traits::ToolParser, traits::ToolParser,
types::{FunctionCall, StreamingParseResult, ToolCall, ToolCallItem}, types::{FunctionCall, StreamingParseResult, ToolCall, ToolCallItem},
...@@ -74,7 +74,7 @@ impl PythonicParser { ...@@ -74,7 +74,7 @@ impl PythonicParser {
.replace("<|python_end|>", "") .replace("<|python_end|>", "")
} }
fn parse_tool_call_block(&self, block: &str) -> ToolParserResult<Vec<ToolCall>> { fn parse_tool_call_block(&self, block: &str) -> ParserResult<Vec<ToolCall>> {
let expr = parse_python_expression(block)?; let expr = parse_python_expression(block)?;
match expr { match expr {
Expr::List(list_expr) => list_expr Expr::List(list_expr) => list_expr
...@@ -83,7 +83,7 @@ impl PythonicParser { ...@@ -83,7 +83,7 @@ impl PythonicParser {
.enumerate() .enumerate()
.map(|(idx, call_expr)| build_tool_call(call_expr, idx)) .map(|(idx, call_expr)| build_tool_call(call_expr, idx))
.collect(), .collect(),
_ => Err(ToolParserError::ParsingFailed( _ => Err(ParserError::ParsingFailed(
"Expected a list of function calls in pythonic tool call".to_string(), "Expected a list of function calls in pythonic tool call".to_string(),
)), )),
} }
...@@ -92,7 +92,7 @@ impl PythonicParser { ...@@ -92,7 +92,7 @@ impl PythonicParser {
#[async_trait] #[async_trait]
impl ToolParser for PythonicParser { impl ToolParser for PythonicParser {
async fn parse_complete(&self, text: &str) -> ToolParserResult<(String, Vec<ToolCall>)> { async fn parse_complete(&self, text: &str) -> ParserResult<(String, Vec<ToolCall>)> {
let cleaned = Self::strip_special_tokens(text); let cleaned = Self::strip_special_tokens(text);
if let Some((tool_calls_text, normal_text)) = self.extract_tool_calls(&cleaned) { if let Some((tool_calls_text, normal_text)) = self.extract_tool_calls(&cleaned) {
...@@ -120,7 +120,7 @@ impl ToolParser for PythonicParser { ...@@ -120,7 +120,7 @@ impl ToolParser for PythonicParser {
&mut self, &mut self,
chunk: &str, chunk: &str,
tools: &[Tool], tools: &[Tool],
) -> ToolParserResult<StreamingParseResult> { ) -> ParserResult<StreamingParseResult> {
self.buffer.push_str(chunk); self.buffer.push_str(chunk);
let cleaned = Self::strip_special_tokens(&self.buffer); let cleaned = Self::strip_special_tokens(&self.buffer);
...@@ -232,23 +232,23 @@ fn find_matching_bracket(buffer: &str, start: usize) -> Option<usize> { ...@@ -232,23 +232,23 @@ fn find_matching_bracket(buffer: &str, start: usize) -> Option<usize> {
None // No matching bracket found None // No matching bracket found
} }
fn parse_python_expression(source: &str) -> ToolParserResult<Expr> { fn parse_python_expression(source: &str) -> ParserResult<Expr> {
let module = parse(source, Mode::Expression, "<pythonic_tool_call>") let module = parse(source, Mode::Expression, "<pythonic_tool_call>")
.map_err(|err| ToolParserError::ParsingFailed(err.to_string()))?; .map_err(|err| ParserError::ParsingFailed(err.to_string()))?;
match module { match module {
Mod::Expression(expr_mod) => Ok(*expr_mod.body), Mod::Expression(expr_mod) => Ok(*expr_mod.body),
_ => Err(ToolParserError::ParsingFailed( _ => Err(ParserError::ParsingFailed(
"Expected a Python expression".to_string(), "Expected a Python expression".to_string(),
)), )),
} }
} }
fn build_tool_call(expr: Expr, _index: usize) -> ToolParserResult<ToolCall> { fn build_tool_call(expr: Expr, _index: usize) -> ParserResult<ToolCall> {
match expr { match expr {
Expr::Call(call_expr) => { Expr::Call(call_expr) => {
if !call_expr.args.is_empty() { if !call_expr.args.is_empty() {
return Err(ToolParserError::ParsingFailed( return Err(ParserError::ParsingFailed(
"Positional arguments are not supported in pythonic tool calls".to_string(), "Positional arguments are not supported in pythonic tool calls".to_string(),
)); ));
} }
...@@ -256,7 +256,7 @@ fn build_tool_call(expr: Expr, _index: usize) -> ToolParserResult<ToolCall> { ...@@ -256,7 +256,7 @@ fn build_tool_call(expr: Expr, _index: usize) -> ToolParserResult<ToolCall> {
let function_name = match *call_expr.func { let function_name = match *call_expr.func {
Expr::Name(name_expr) => name_expr.id.to_string(), Expr::Name(name_expr) => name_expr.id.to_string(),
_ => { _ => {
return Err(ToolParserError::ParsingFailed( return Err(ParserError::ParsingFailed(
"Unsupported function reference in pythonic tool call".to_string(), "Unsupported function reference in pythonic tool call".to_string(),
)) ))
} }
...@@ -265,7 +265,7 @@ fn build_tool_call(expr: Expr, _index: usize) -> ToolParserResult<ToolCall> { ...@@ -265,7 +265,7 @@ fn build_tool_call(expr: Expr, _index: usize) -> ToolParserResult<ToolCall> {
let mut arguments_map = Map::with_capacity(call_expr.keywords.len()); let mut arguments_map = Map::with_capacity(call_expr.keywords.len());
for keyword in call_expr.keywords { for keyword in call_expr.keywords {
let arg_name = keyword.arg.ok_or_else(|| { let arg_name = keyword.arg.ok_or_else(|| {
ToolParserError::ParsingFailed( ParserError::ParsingFailed(
"pythonic tool calls do not support **kwargs".to_string(), "pythonic tool calls do not support **kwargs".to_string(),
) )
})?; })?;
...@@ -283,13 +283,13 @@ fn build_tool_call(expr: Expr, _index: usize) -> ToolParserResult<ToolCall> { ...@@ -283,13 +283,13 @@ fn build_tool_call(expr: Expr, _index: usize) -> ToolParserResult<ToolCall> {
}, },
}) })
} }
_ => Err(ToolParserError::ParsingFailed( _ => Err(ParserError::ParsingFailed(
"Expected function calls inside pythonic tool call list".to_string(), "Expected function calls inside pythonic tool call list".to_string(),
)), )),
} }
} }
fn expression_to_json(expr: &Expr) -> ToolParserResult<Value> { fn expression_to_json(expr: &Expr) -> ParserResult<Value> {
match expr { match expr {
Expr::Constant(expr_constant) => constant_to_json(&expr_constant.value), Expr::Constant(expr_constant) => constant_to_json(&expr_constant.value),
Expr::List(list_expr) => collect_sequence(&list_expr.elts).map(Value::Array), Expr::List(list_expr) => collect_sequence(&list_expr.elts).map(Value::Array),
...@@ -300,81 +300,75 @@ fn expression_to_json(expr: &Expr) -> ToolParserResult<Value> { ...@@ -300,81 +300,75 @@ fn expression_to_json(expr: &Expr) -> ToolParserResult<Value> {
Expr::UnaryOp(unary_expr) => match unary_expr.op { Expr::UnaryOp(unary_expr) => match unary_expr.op {
UnaryOp::USub => match unary_expr.operand.as_ref() { UnaryOp::USub => match unary_expr.operand.as_ref() {
Expr::Constant(const_expr) => negate_constant(&const_expr.value), Expr::Constant(const_expr) => negate_constant(&const_expr.value),
_ => Err(ToolParserError::ParsingFailed( _ => Err(ParserError::ParsingFailed(
"Unsupported unary operand in pythonic tool call".to_string(), "Unsupported unary operand in pythonic tool call".to_string(),
)), )),
}, },
UnaryOp::UAdd => expression_to_json(unary_expr.operand.as_ref()), UnaryOp::UAdd => expression_to_json(unary_expr.operand.as_ref()),
_ => Err(ToolParserError::ParsingFailed(format!( _ => Err(ParserError::ParsingFailed(format!(
"Unsupported unary operator in pythonic tool call: {:?}", "Unsupported unary operator in pythonic tool call: {:?}",
unary_expr.op unary_expr.op
))), ))),
}, },
Expr::Name(name_expr) => Ok(Value::String(name_expr.id.to_string())), Expr::Name(name_expr) => Ok(Value::String(name_expr.id.to_string())),
_ => Err(ToolParserError::ParsingFailed(format!( _ => Err(ParserError::ParsingFailed(format!(
"Unsupported expression in pythonic tool call: {:?}", "Unsupported expression in pythonic tool call: {:?}",
expr expr
))), ))),
} }
} }
fn constant_to_json(constant: &Constant) -> ToolParserResult<Value> { fn constant_to_json(constant: &Constant) -> ParserResult<Value> {
match constant { match constant {
Constant::None => Ok(Value::Null), Constant::None => Ok(Value::Null),
Constant::Bool(b) => Ok(Value::Bool(*b)), Constant::Bool(b) => Ok(Value::Bool(*b)),
Constant::Int(value) => Ok(integer_constant_to_value(value, false)), Constant::Int(value) => Ok(integer_constant_to_value(value, false)),
Constant::Float(f) => Number::from_f64(*f).map(Value::Number).ok_or_else(|| { Constant::Float(f) => Number::from_f64(*f).map(Value::Number).ok_or_else(|| {
ToolParserError::ParsingFailed( ParserError::ParsingFailed("Invalid float literal in pythonic tool call".to_string())
"Invalid float literal in pythonic tool call".to_string(),
)
}), }),
Constant::Str(s) => Ok(Value::String(s.clone())), Constant::Str(s) => Ok(Value::String(s.clone())),
Constant::Bytes(bytes) => Ok(Value::String(String::from_utf8_lossy(bytes).into_owned())), Constant::Bytes(bytes) => Ok(Value::String(String::from_utf8_lossy(bytes).into_owned())),
Constant::Tuple(values) => constant_tuple_to_array(values).map(Value::Array), Constant::Tuple(values) => constant_tuple_to_array(values).map(Value::Array),
Constant::Ellipsis | Constant::Complex { .. } => Err(ToolParserError::ParsingFailed( Constant::Ellipsis | Constant::Complex { .. } => Err(ParserError::ParsingFailed(
"Unsupported literal in pythonic tool call".to_string(), "Unsupported literal in pythonic tool call".to_string(),
)), )),
} }
} }
fn negate_constant(constant: &Constant) -> ToolParserResult<Value> { fn negate_constant(constant: &Constant) -> ParserResult<Value> {
match constant { match constant {
Constant::Int(value) => Ok(integer_constant_to_value(value, true)), Constant::Int(value) => Ok(integer_constant_to_value(value, true)),
Constant::Float(f) => Number::from_f64(-f).map(Value::Number).ok_or_else(|| { Constant::Float(f) => Number::from_f64(-f).map(Value::Number).ok_or_else(|| {
ToolParserError::ParsingFailed( ParserError::ParsingFailed("Invalid float literal in pythonic tool call".to_string())
"Invalid float literal in pythonic tool call".to_string(),
)
}), }),
_ => Err(ToolParserError::ParsingFailed( _ => Err(ParserError::ParsingFailed(
"Unsupported unary operand in pythonic tool call".to_string(), "Unsupported unary operand in pythonic tool call".to_string(),
)), )),
} }
} }
fn value_to_key_string(value: Value) -> ToolParserResult<String> { fn value_to_key_string(value: Value) -> ParserResult<String> {
match value { match value {
Value::String(s) => Ok(s), Value::String(s) => Ok(s),
Value::Number(num) => Ok(num.to_string()), Value::Number(num) => Ok(num.to_string()),
Value::Bool(b) => Ok(b.to_string()), Value::Bool(b) => Ok(b.to_string()),
Value::Null => Ok("null".to_string()), Value::Null => Ok("null".to_string()),
other => Err(ToolParserError::ParsingFailed(format!( other => Err(ParserError::ParsingFailed(format!(
"Unsupported key type in pythonic tool call: {:?}", "Unsupported key type in pythonic tool call: {:?}",
other other
))), ))),
} }
} }
fn collect_sequence(elements: &[Expr]) -> ToolParserResult<Vec<Value>> { fn collect_sequence(elements: &[Expr]) -> ParserResult<Vec<Value>> {
elements.iter().map(expression_to_json).collect() elements.iter().map(expression_to_json).collect()
} }
fn collect_dict(keys: &[Option<Expr>], values: &[Expr]) -> ToolParserResult<Map<String, Value>> { fn collect_dict(keys: &[Option<Expr>], values: &[Expr]) -> ParserResult<Map<String, Value>> {
let mut map = Map::with_capacity(keys.len()); let mut map = Map::with_capacity(keys.len());
for (key_expr, value_expr) in keys.iter().zip(values.iter()) { for (key_expr, value_expr) in keys.iter().zip(values.iter()) {
let key_expr = key_expr.as_ref().ok_or_else(|| { let key_expr = key_expr.as_ref().ok_or_else(|| {
ToolParserError::ParsingFailed( ParserError::ParsingFailed("pythonic tool calls do not support **kwargs".to_string())
"pythonic tool calls do not support **kwargs".to_string(),
)
})?; })?;
let key_value = expression_to_json(key_expr)?; let key_value = expression_to_json(key_expr)?;
let key = value_to_key_string(key_value)?; let key = value_to_key_string(key_value)?;
...@@ -384,7 +378,7 @@ fn collect_dict(keys: &[Option<Expr>], values: &[Expr]) -> ToolParserResult<Map< ...@@ -384,7 +378,7 @@ fn collect_dict(keys: &[Option<Expr>], values: &[Expr]) -> ToolParserResult<Map<
Ok(map) Ok(map)
} }
fn constant_tuple_to_array(values: &[Constant]) -> ToolParserResult<Vec<Value>> { fn constant_tuple_to_array(values: &[Constant]) -> ParserResult<Vec<Value>> {
values.iter().map(constant_to_json).collect() values.iter().map(constant_to_json).collect()
} }
......
...@@ -5,7 +5,7 @@ use serde_json::Value; ...@@ -5,7 +5,7 @@ use serde_json::Value;
use crate::protocols::spec::Tool; use crate::protocols::spec::Tool;
use crate::tool_parser::{ use crate::tool_parser::{
errors::{ToolParserError, ToolParserResult}, errors::{ParserError, ParserResult},
parsers::helpers, parsers::helpers,
partial_json::PartialJson, partial_json::PartialJson,
traits::ToolParser, traits::ToolParser,
...@@ -76,7 +76,7 @@ impl QwenParser { ...@@ -76,7 +76,7 @@ impl QwenParser {
} }
/// Parse a single JSON object into a ToolCall /// Parse a single JSON object into a ToolCall
fn parse_single_object(&self, obj: &Value) -> ToolParserResult<Option<ToolCall>> { fn parse_single_object(&self, obj: &Value) -> ParserResult<Option<ToolCall>> {
let name = obj.get("name").and_then(|v| v.as_str()); let name = obj.get("name").and_then(|v| v.as_str());
if let Some(name) = name { if let Some(name) = name {
...@@ -86,7 +86,7 @@ impl QwenParser { ...@@ -86,7 +86,7 @@ impl QwenParser {
// Convert arguments to JSON string // Convert arguments to JSON string
let arguments = serde_json::to_string(args) let arguments = serde_json::to_string(args)
.map_err(|e| ToolParserError::ParsingFailed(e.to_string()))?; .map_err(|e| ParserError::ParsingFailed(e.to_string()))?;
Ok(Some(ToolCall { Ok(Some(ToolCall {
function: FunctionCall { function: FunctionCall {
...@@ -108,7 +108,7 @@ impl Default for QwenParser { ...@@ -108,7 +108,7 @@ impl Default for QwenParser {
#[async_trait] #[async_trait]
impl ToolParser for QwenParser { impl ToolParser for QwenParser {
async fn parse_complete(&self, text: &str) -> ToolParserResult<(String, Vec<ToolCall>)> { async fn parse_complete(&self, text: &str) -> ParserResult<(String, Vec<ToolCall>)> {
// Check if text contains Qwen format // Check if text contains Qwen format
if !self.has_tool_markers(text) { if !self.has_tool_markers(text) {
return Ok((text.to_string(), vec![])); return Ok((text.to_string(), vec![]));
...@@ -123,7 +123,7 @@ impl ToolParser for QwenParser { ...@@ -123,7 +123,7 @@ impl ToolParser for QwenParser {
for captures in self.extractor.captures_iter(text) { for captures in self.extractor.captures_iter(text) {
if let Some(json_str) = captures.get(1) { if let Some(json_str) = captures.get(1) {
let parsed = serde_json::from_str::<Value>(json_str.as_str().trim()) let parsed = serde_json::from_str::<Value>(json_str.as_str().trim())
.map_err(|e| ToolParserError::ParsingFailed(e.to_string())) .map_err(|e| ParserError::ParsingFailed(e.to_string()))
.and_then(|v| self.parse_single_object(&v)); .and_then(|v| self.parse_single_object(&v));
match parsed { match parsed {
...@@ -149,7 +149,7 @@ impl ToolParser for QwenParser { ...@@ -149,7 +149,7 @@ impl ToolParser for QwenParser {
&mut self, &mut self,
chunk: &str, chunk: &str,
tools: &[Tool], tools: &[Tool],
) -> ToolParserResult<StreamingParseResult> { ) -> ParserResult<StreamingParseResult> {
// Append new text to buffer // Append new text to buffer
self.buffer.push_str(chunk); self.buffer.push_str(chunk);
let current_text = &self.buffer.clone(); let current_text = &self.buffer.clone();
...@@ -240,4 +240,14 @@ impl ToolParser for QwenParser { ...@@ -240,4 +240,14 @@ impl ToolParser for QwenParser {
fn get_unstreamed_tool_args(&self) -> Option<Vec<crate::tool_parser::types::ToolCallItem>> { fn get_unstreamed_tool_args(&self) -> Option<Vec<crate::tool_parser::types::ToolCallItem>> {
helpers::get_unstreamed_args(&self.prev_tool_call_arr, &self.streamed_args_for_tool) helpers::get_unstreamed_args(&self.prev_tool_call_arr, &self.streamed_args_for_tool)
} }
fn reset(&mut self) {
helpers::reset_parser_state(
&mut self.buffer,
&mut self.prev_tool_call_arr,
&mut self.current_tool_id,
&mut self.current_tool_name_sent,
&mut self.streamed_args_for_tool,
);
}
} }
...@@ -6,7 +6,7 @@ use std::collections::HashMap; ...@@ -6,7 +6,7 @@ use std::collections::HashMap;
use crate::protocols::spec::Tool; use crate::protocols::spec::Tool;
use crate::tool_parser::{ use crate::tool_parser::{
errors::{ToolParserError, ToolParserResult}, errors::{ParserError, ParserResult},
parsers::helpers, parsers::helpers,
traits::ToolParser, traits::ToolParser,
types::{FunctionCall, StreamingParseResult, ToolCall, ToolCallItem}, types::{FunctionCall, StreamingParseResult, ToolCall, ToolCallItem},
...@@ -108,7 +108,7 @@ impl Step3Parser { ...@@ -108,7 +108,7 @@ impl Step3Parser {
fn parse_partial_tool_call( fn parse_partial_tool_call(
&mut self, &mut self,
tool_indices: &HashMap<String, usize>, tool_indices: &HashMap<String, usize>,
) -> ToolParserResult<StreamingParseResult> { ) -> ParserResult<StreamingParseResult> {
let mut calls = Vec::new(); let mut calls = Vec::new();
// Check if we have tool_sep (means we're past the type declaration) // Check if we have tool_sep (means we're past the type declaration)
...@@ -321,7 +321,7 @@ impl Step3Parser { ...@@ -321,7 +321,7 @@ impl Step3Parser {
fn parse_steptml_parameters( fn parse_steptml_parameters(
&self, &self,
params_text: &str, params_text: &str,
) -> ToolParserResult<serde_json::Map<String, Value>> { ) -> ParserResult<serde_json::Map<String, Value>> {
let mut parameters = serde_json::Map::new(); let mut parameters = serde_json::Map::new();
for capture in self.param_extractor.captures_iter(params_text) { for capture in self.param_extractor.captures_iter(params_text) {
...@@ -359,7 +359,7 @@ impl Step3Parser { ...@@ -359,7 +359,7 @@ impl Step3Parser {
} }
/// Parse a single tool call block /// Parse a single tool call block
fn parse_tool_call(&self, block: &str) -> ToolParserResult<Option<ToolCall>> { fn parse_tool_call(&self, block: &str) -> ParserResult<Option<ToolCall>> {
// Check if it contains function marker and tool separator // Check if it contains function marker and tool separator
if !block.contains("function") || !block.contains("<|tool_sep|>") { if !block.contains("function") || !block.contains("<|tool_sep|>") {
return Ok(None); return Ok(None);
...@@ -393,7 +393,7 @@ impl Step3Parser { ...@@ -393,7 +393,7 @@ impl Step3Parser {
let parameters = self.parse_steptml_parameters(params_text)?; let parameters = self.parse_steptml_parameters(params_text)?;
let arguments_str = serde_json::to_string(&parameters) let arguments_str = serde_json::to_string(&parameters)
.map_err(|e| ToolParserError::ParsingFailed(e.to_string()))?; .map_err(|e| ParserError::ParsingFailed(e.to_string()))?;
Ok(Some(ToolCall { Ok(Some(ToolCall {
function: FunctionCall { function: FunctionCall {
...@@ -415,7 +415,7 @@ impl Default for Step3Parser { ...@@ -415,7 +415,7 @@ impl Default for Step3Parser {
#[async_trait] #[async_trait]
impl ToolParser for Step3Parser { impl ToolParser for Step3Parser {
async fn parse_complete(&self, text: &str) -> ToolParserResult<(String, Vec<ToolCall>)> { async fn parse_complete(&self, text: &str) -> ParserResult<(String, Vec<ToolCall>)> {
if !self.has_tool_markers(text) { if !self.has_tool_markers(text) {
return Ok((text.to_string(), vec![])); return Ok((text.to_string(), vec![]));
} }
...@@ -449,7 +449,7 @@ impl ToolParser for Step3Parser { ...@@ -449,7 +449,7 @@ impl ToolParser for Step3Parser {
&mut self, &mut self,
chunk: &str, chunk: &str,
tools: &[Tool], tools: &[Tool],
) -> ToolParserResult<StreamingParseResult> { ) -> ParserResult<StreamingParseResult> {
self.buffer.push_str(chunk); self.buffer.push_str(chunk);
// Build tool indices for validation // Build tool indices for validation
...@@ -555,4 +555,20 @@ impl ToolParser for Step3Parser { ...@@ -555,4 +555,20 @@ impl ToolParser for Step3Parser {
fn get_unstreamed_tool_args(&self) -> Option<Vec<ToolCallItem>> { fn get_unstreamed_tool_args(&self) -> Option<Vec<ToolCallItem>> {
helpers::get_unstreamed_args(&self.prev_tool_call_arr, &self.streamed_args_for_tool) helpers::get_unstreamed_args(&self.prev_tool_call_arr, &self.streamed_args_for_tool)
} }
fn reset(&mut self) {
// Reset standard state
self.buffer.clear();
self.prev_tool_call_arr.clear();
self.current_tool_id = -1;
self.streamed_args_for_tool.clear();
// Reset Step3-specific fields
self.in_tool_block = false;
self.tool_block_finished = false;
self.current_function_name.clear();
self.current_parameters.clear();
self.in_tool_call = false;
self.function_name_sent = false;
}
} }
use crate::tool_parser::{ use crate::tool_parser::{
errors::{ToolParserError, ToolParserResult}, errors::{ParserError, ParserResult},
traits::PartialJsonParser, traits::PartialJsonParser,
}; };
use serde_json::{Map, Value}; use serde_json::{Map, Value};
...@@ -22,8 +22,22 @@ impl PartialJson { ...@@ -22,8 +22,22 @@ impl PartialJson {
} }
/// Parse potentially incomplete JSON, returning parsed value and consumed bytes /// Parse potentially incomplete JSON, returning parsed value and consumed bytes
pub fn parse_value(&self, input: &str) -> ToolParserResult<(Value, usize)> { ///
let mut parser = Parser::new(input, self.max_depth, self.allow_incomplete); /// # Arguments
/// * `input` - The JSON string to parse
/// * `allow_partial_strings` - When false, incomplete strings cause parsing to stop
/// (matches Python's Allow.ALL & ~Allow.STR behavior)
pub fn parse_value(
&self,
input: &str,
allow_partial_strings: bool,
) -> ParserResult<(Value, usize)> {
let mut parser = Parser::new(
input,
self.max_depth,
self.allow_incomplete,
allow_partial_strings,
);
let value = parser.parse_value(0)?; let value = parser.parse_value(0)?;
Ok((value, parser.position)) Ok((value, parser.position))
} }
...@@ -36,8 +50,9 @@ impl Default for PartialJson { ...@@ -36,8 +50,9 @@ impl Default for PartialJson {
} }
impl PartialJsonParser for PartialJson { impl PartialJsonParser for PartialJson {
fn parse(&self, input: &str) -> ToolParserResult<(Value, usize)> { fn parse(&self, input: &str) -> ParserResult<(Value, usize)> {
self.parse_value(input) // Default to allowing partial strings
self.parse_value(input, true)
} }
fn is_complete(&self, input: &str) -> bool { fn is_complete(&self, input: &str) -> bool {
...@@ -56,15 +71,22 @@ struct Parser<'a> { ...@@ -56,15 +71,22 @@ struct Parser<'a> {
position: usize, position: usize,
max_depth: usize, max_depth: usize,
allow_incomplete: bool, allow_incomplete: bool,
allow_partial_strings: bool,
} }
impl<'a> Parser<'a> { impl<'a> Parser<'a> {
fn new(input: &'a str, max_depth: usize, allow_incomplete: bool) -> Self { fn new(
input: &'a str,
max_depth: usize,
allow_incomplete: bool,
allow_partial_strings: bool,
) -> Self {
Self { Self {
chars: input.chars().peekable(), chars: input.chars().peekable(),
position: 0, position: 0,
max_depth, max_depth,
allow_incomplete, allow_incomplete,
allow_partial_strings,
} }
} }
...@@ -88,9 +110,9 @@ impl<'a> Parser<'a> { ...@@ -88,9 +110,9 @@ impl<'a> Parser<'a> {
} }
} }
fn parse_value(&mut self, depth: usize) -> ToolParserResult<Value> { fn parse_value(&mut self, depth: usize) -> ParserResult<Value> {
if depth > self.max_depth { if depth > self.max_depth {
return Err(ToolParserError::DepthExceeded(self.max_depth)); return Err(ParserError::DepthExceeded(self.max_depth));
} }
self.skip_whitespace(); self.skip_whitespace();
...@@ -106,17 +128,15 @@ impl<'a> Parser<'a> { ...@@ -106,17 +128,15 @@ impl<'a> Parser<'a> {
if self.allow_incomplete { if self.allow_incomplete {
Ok(Value::Null) Ok(Value::Null)
} else { } else {
Err(ToolParserError::ParsingFailed( Err(ParserError::ParsingFailed("Unexpected character".into()))
"Unexpected character".into(),
))
} }
} }
} }
} }
fn parse_object(&mut self, depth: usize) -> ToolParserResult<Value> { fn parse_object(&mut self, depth: usize) -> ParserResult<Value> {
if depth > self.max_depth { if depth > self.max_depth {
return Err(ToolParserError::DepthExceeded(self.max_depth)); return Err(ParserError::DepthExceeded(self.max_depth));
} }
let mut object = Map::new(); let mut object = Map::new();
...@@ -140,7 +160,7 @@ impl<'a> Parser<'a> { ...@@ -140,7 +160,7 @@ impl<'a> Parser<'a> {
return Ok(Value::Object(object)); return Ok(Value::Object(object));
} }
Err(e) => return Err(e), Err(e) => return Err(e),
_ => return Err(ToolParserError::ParsingFailed("Expected string key".into())), _ => return Err(ParserError::ParsingFailed("Expected string key".into())),
}; };
self.skip_whitespace(); self.skip_whitespace();
...@@ -152,7 +172,7 @@ impl<'a> Parser<'a> { ...@@ -152,7 +172,7 @@ impl<'a> Parser<'a> {
object.insert(key, Value::Null); object.insert(key, Value::Null);
return Ok(Value::Object(object)); return Ok(Value::Object(object));
} }
return Err(ToolParserError::ParsingFailed("Expected ':'".into())); return Err(ParserError::ParsingFailed("Expected ':'".into()));
} }
self.advance(); self.advance();
self.skip_whitespace(); self.skip_whitespace();
...@@ -161,8 +181,13 @@ impl<'a> Parser<'a> { ...@@ -161,8 +181,13 @@ impl<'a> Parser<'a> {
let value = match self.parse_value(depth) { let value = match self.parse_value(depth) {
Ok(v) => v, Ok(v) => v,
Err(_) if self.allow_incomplete => { Err(_) if self.allow_incomplete => {
// When allow_partial_strings is false, don't add the key with Null
// Just return the object without this incomplete key-value pair
// This matches Python's behavior: Allow.ALL & ~Allow.STR
if self.allow_partial_strings {
// Add null for incomplete value // Add null for incomplete value
object.insert(key, Value::Null); object.insert(key, Value::Null);
}
return Ok(Value::Object(object)); return Ok(Value::Object(object));
} }
Err(e) => return Err(e), Err(e) => return Err(e),
...@@ -192,15 +217,15 @@ impl<'a> Parser<'a> { ...@@ -192,15 +217,15 @@ impl<'a> Parser<'a> {
if self.allow_incomplete { if self.allow_incomplete {
return Ok(Value::Object(object)); return Ok(Value::Object(object));
} }
return Err(ToolParserError::ParsingFailed("Expected ',' or '}'".into())); return Err(ParserError::ParsingFailed("Expected ',' or '}'".into()));
} }
} }
} }
} }
fn parse_array(&mut self, depth: usize) -> ToolParserResult<Value> { fn parse_array(&mut self, depth: usize) -> ParserResult<Value> {
if depth > self.max_depth { if depth > self.max_depth {
return Err(ToolParserError::DepthExceeded(self.max_depth)); return Err(ParserError::DepthExceeded(self.max_depth));
} }
let mut array = Vec::new(); let mut array = Vec::new();
...@@ -249,15 +274,15 @@ impl<'a> Parser<'a> { ...@@ -249,15 +274,15 @@ impl<'a> Parser<'a> {
if self.allow_incomplete { if self.allow_incomplete {
return Ok(Value::Array(array)); return Ok(Value::Array(array));
} }
return Err(ToolParserError::ParsingFailed("Expected ',' or ']'".into())); return Err(ParserError::ParsingFailed("Expected ',' or ']'".into()));
} }
} }
} }
} }
fn parse_string(&mut self) -> ToolParserResult<Value> { fn parse_string(&mut self) -> ParserResult<Value> {
if self.peek() != Some('"') { if self.peek() != Some('"') {
return Err(ToolParserError::ParsingFailed("Expected '\"'".into())); return Err(ParserError::ParsingFailed("Expected '\"'".into()));
} }
// Consume opening quote // Consume opening quote
...@@ -301,14 +326,14 @@ impl<'a> Parser<'a> { ...@@ -301,14 +326,14 @@ impl<'a> Parser<'a> {
} }
// Incomplete string // Incomplete string
if self.allow_incomplete { if self.allow_incomplete && self.allow_partial_strings {
Ok(Value::String(string)) Ok(Value::String(string))
} else { } else {
Err(ToolParserError::ParsingFailed("Unterminated string".into())) Err(ParserError::ParsingFailed("Unterminated string".into()))
} }
} }
fn parse_unicode_escape(&mut self) -> ToolParserResult<char> { fn parse_unicode_escape(&mut self) -> ParserResult<char> {
let mut hex = String::new(); let mut hex = String::new();
for _ in 0..4 { for _ in 0..4 {
if let Some(ch) = self.peek() { if let Some(ch) = self.peek() {
...@@ -327,17 +352,17 @@ impl<'a> Parser<'a> { ...@@ -327,17 +352,17 @@ impl<'a> Parser<'a> {
u32::from_str_radix(&hex, 16) u32::from_str_radix(&hex, 16)
.ok() .ok()
.and_then(char::from_u32) .and_then(char::from_u32)
.ok_or_else(|| ToolParserError::ParsingFailed("Invalid unicode escape".into())) .ok_or_else(|| ParserError::ParsingFailed("Invalid unicode escape".into()))
} else if self.allow_incomplete { } else if self.allow_incomplete {
Ok('\u{FFFD}') // Replacement character Ok('\u{FFFD}') // Replacement character
} else { } else {
Err(ToolParserError::ParsingFailed( Err(ParserError::ParsingFailed(
"Incomplete unicode escape".into(), "Incomplete unicode escape".into(),
)) ))
} }
} }
fn parse_number(&mut self) -> ToolParserResult<Value> { fn parse_number(&mut self) -> ParserResult<Value> {
let mut number = String::new(); let mut number = String::new();
// Handle negative sign // Handle negative sign
...@@ -410,11 +435,11 @@ impl<'a> Parser<'a> { ...@@ -410,11 +435,11 @@ impl<'a> Parser<'a> {
} else if self.allow_incomplete { } else if self.allow_incomplete {
Ok(Value::Number(serde_json::Number::from(0))) Ok(Value::Number(serde_json::Number::from(0)))
} else { } else {
Err(ToolParserError::ParsingFailed("Invalid number".into())) Err(ParserError::ParsingFailed("Invalid number".into()))
} }
} }
fn parse_bool(&mut self) -> ToolParserResult<Value> { fn parse_bool(&mut self) -> ParserResult<Value> {
let mut word = String::new(); let mut word = String::new();
// Peek at upcoming characters to validate it looks like a boolean // Peek at upcoming characters to validate it looks like a boolean
...@@ -435,7 +460,7 @@ impl<'a> Parser<'a> { ...@@ -435,7 +460,7 @@ impl<'a> Parser<'a> {
|| (self.allow_incomplete && ("true".starts_with(&word) || "false".starts_with(&word))); || (self.allow_incomplete && ("true".starts_with(&word) || "false".starts_with(&word)));
if !is_valid { if !is_valid {
return Err(ToolParserError::ParsingFailed("Invalid boolean".into())); return Err(ParserError::ParsingFailed("Invalid boolean".into()));
} }
// Now actually consume the characters // Now actually consume the characters
...@@ -458,14 +483,14 @@ impl<'a> Parser<'a> { ...@@ -458,14 +483,14 @@ impl<'a> Parser<'a> {
} else if "false".starts_with(partial) { } else if "false".starts_with(partial) {
Ok(Value::Bool(false)) Ok(Value::Bool(false))
} else { } else {
Err(ToolParserError::ParsingFailed("Invalid boolean".into())) Err(ParserError::ParsingFailed("Invalid boolean".into()))
} }
} }
_ => Err(ToolParserError::ParsingFailed("Invalid boolean".into())), _ => Err(ParserError::ParsingFailed("Invalid boolean".into())),
} }
} }
fn parse_null(&mut self) -> ToolParserResult<Value> { fn parse_null(&mut self) -> ParserResult<Value> {
let mut word = String::new(); let mut word = String::new();
// Peek at upcoming characters to validate it looks like "null" // Peek at upcoming characters to validate it looks like "null"
...@@ -484,7 +509,7 @@ impl<'a> Parser<'a> { ...@@ -484,7 +509,7 @@ impl<'a> Parser<'a> {
let is_valid = word == "null" || (self.allow_incomplete && "null".starts_with(&word)); let is_valid = word == "null" || (self.allow_incomplete && "null".starts_with(&word));
if !is_valid { if !is_valid {
return Err(ToolParserError::ParsingFailed("Invalid null".into())); return Err(ParserError::ParsingFailed("Invalid null".into()));
} }
// Now actually consume the characters // Now actually consume the characters
...@@ -501,7 +526,7 @@ impl<'a> Parser<'a> { ...@@ -501,7 +526,7 @@ impl<'a> Parser<'a> {
if word == "null" || (self.allow_incomplete && "null".starts_with(&word)) { if word == "null" || (self.allow_incomplete && "null".starts_with(&word)) {
Ok(Value::Null) Ok(Value::Null)
} else { } else {
Err(ToolParserError::ParsingFailed("Invalid null".into())) Err(ParserError::ParsingFailed("Invalid null".into()))
} }
} }
} }
......
...@@ -7,7 +7,7 @@ use crate::tool_parser::traits::ToolParser; ...@@ -7,7 +7,7 @@ use crate::tool_parser::traits::ToolParser;
#[tokio::test] #[tokio::test]
async fn test_tool_parser_factory() { async fn test_tool_parser_factory() {
let factory = ToolParserFactory::new(); let factory = ParserFactory::new();
// Test that we can get a pooled parser // Test that we can get a pooled parser
let pooled_parser = factory.get_pooled("gpt-4"); let pooled_parser = factory.get_pooled("gpt-4");
...@@ -17,7 +17,7 @@ async fn test_tool_parser_factory() { ...@@ -17,7 +17,7 @@ async fn test_tool_parser_factory() {
#[tokio::test] #[tokio::test]
async fn test_tool_parser_factory_model_mapping() { async fn test_tool_parser_factory_model_mapping() {
let factory = ToolParserFactory::new(); let factory = ParserFactory::new();
// Test model mapping // Test model mapping
factory.registry().map_model("test-model", "json"); factory.registry().map_model("test-model", "json");
...@@ -54,22 +54,22 @@ fn test_partial_json_parser() { ...@@ -54,22 +54,22 @@ fn test_partial_json_parser() {
let parser = PartialJson::default(); let parser = PartialJson::default();
let input = r#"{"name": "test", "value": 42}"#; let input = r#"{"name": "test", "value": 42}"#;
let (value, consumed) = parser.parse_value(input).unwrap(); let (value, consumed) = parser.parse_value(input, true).unwrap();
assert_eq!(value["name"], "test"); assert_eq!(value["name"], "test");
assert_eq!(value["value"], 42); assert_eq!(value["value"], 42);
assert_eq!(consumed, input.len()); assert_eq!(consumed, input.len());
let input = r#"{"name": "test", "value": "#; let input = r#"{"name": "test", "value": "#;
let (value, _consumed) = parser.parse_value(input).unwrap(); let (value, _consumed) = parser.parse_value(input, true).unwrap();
assert_eq!(value["name"], "test"); assert_eq!(value["name"], "test");
assert!(value["value"].is_null()); assert!(value["value"].is_null());
let input = r#"{"name": "tes"#; let input = r#"{"name": "tes"#;
let (value, _consumed) = parser.parse_value(input).unwrap(); let (value, _consumed) = parser.parse_value(input, true).unwrap();
assert_eq!(value["name"], "tes"); assert_eq!(value["name"], "tes");
let input = r#"[1, 2, "#; let input = r#"[1, 2, "#;
let (value, _consumed) = parser.parse_value(input).unwrap(); let (value, _consumed) = parser.parse_value(input, true).unwrap();
assert!(value.is_array()); assert!(value.is_array());
assert_eq!(value[0], 1); assert_eq!(value[0], 1);
assert_eq!(value[1], 2); assert_eq!(value[1], 2);
...@@ -83,17 +83,17 @@ fn test_partial_json_depth_limit() { ...@@ -83,17 +83,17 @@ fn test_partial_json_depth_limit() {
// This should work (simple object) // This should work (simple object)
let input = r#"{"a": 1}"#; let input = r#"{"a": 1}"#;
let result = parser.parse_value(input); let result = parser.parse_value(input, true);
assert!(result.is_ok()); assert!(result.is_ok());
// This should work (nested to depth 3) // This should work (nested to depth 3)
let input = r#"{"a": {"b": {"c": 1}}}"#; let input = r#"{"a": {"b": {"c": 1}}}"#;
let result = parser.parse_value(input); let result = parser.parse_value(input, true);
assert!(result.is_ok()); assert!(result.is_ok());
// This should fail (nested to depth 4, exceeds limit) // This should fail (nested to depth 4, exceeds limit)
let input = r#"{"a": {"b": {"c": {"d": 1}}}}"#; let input = r#"{"a": {"b": {"c": {"d": 1}}}}"#;
let result = parser.parse_value(input); let result = parser.parse_value(input, true);
assert!(result.is_err()); assert!(result.is_err());
} }
...@@ -244,7 +244,7 @@ fn test_json_parser_format_detection() { ...@@ -244,7 +244,7 @@ fn test_json_parser_format_detection() {
#[tokio::test] #[tokio::test]
async fn test_factory_with_json_parser() { async fn test_factory_with_json_parser() {
let factory = ToolParserFactory::new(); let factory = ParserFactory::new();
// Should get JSON parser for OpenAI models // Should get JSON parser for OpenAI models
let pooled_parser = factory.get_pooled("gpt-4-turbo"); let pooled_parser = factory.get_pooled("gpt-4-turbo");
......
use crate::protocols::spec::Tool; use crate::protocols::spec::Tool;
use crate::tool_parser::{ use crate::tool_parser::{
errors::ToolParserResult, errors::ParserResult,
types::{StreamingParseResult, ToolCall}, types::{StreamingParseResult, ToolCall},
}; };
use async_trait::async_trait; use async_trait::async_trait;
...@@ -10,7 +10,7 @@ use async_trait::async_trait; ...@@ -10,7 +10,7 @@ use async_trait::async_trait;
pub trait ToolParser: Send + Sync { pub trait ToolParser: Send + Sync {
/// Parse complete tool calls from final output /// Parse complete tool calls from final output
/// Returns (remaining_normal_text, tool_calls) tuple /// Returns (remaining_normal_text, tool_calls) tuple
async fn parse_complete(&self, output: &str) -> ToolParserResult<(String, Vec<ToolCall>)>; async fn parse_complete(&self, output: &str) -> ParserResult<(String, Vec<ToolCall>)>;
/// Parse tool calls from model output (streaming) /// Parse tool calls from model output (streaming)
/// Parsers now maintain internal state, so self is mutable /// Parsers now maintain internal state, so self is mutable
...@@ -22,7 +22,7 @@ pub trait ToolParser: Send + Sync { ...@@ -22,7 +22,7 @@ pub trait ToolParser: Send + Sync {
&mut self, &mut self,
chunk: &str, chunk: &str,
tools: &[Tool], tools: &[Tool],
) -> ToolParserResult<StreamingParseResult>; ) -> ParserResult<StreamingParseResult>;
/// Check if text contains tool calls in this parser's format /// Check if text contains tool calls in this parser's format
fn has_tool_markers(&self, text: &str) -> bool; fn has_tool_markers(&self, text: &str) -> bool;
...@@ -38,12 +38,18 @@ pub trait ToolParser: Send + Sync { ...@@ -38,12 +38,18 @@ pub trait ToolParser: Send + Sync {
fn get_unstreamed_tool_args(&self) -> Option<Vec<crate::tool_parser::types::ToolCallItem>> { fn get_unstreamed_tool_args(&self) -> Option<Vec<crate::tool_parser::types::ToolCallItem>> {
None None
} }
/// Reset the parser state for reuse across requests.
/// This should clear all buffers and reset state to initial values.
fn reset(&mut self) {
// Default no-op implementation
}
} }
/// Trait for partial JSON parsing /// Trait for partial JSON parsing
pub trait PartialJsonParser: Send + Sync { pub trait PartialJsonParser: Send + Sync {
/// Parse potentially incomplete JSON /// Parse potentially incomplete JSON
fn parse(&self, input: &str) -> ToolParserResult<(serde_json::Value, usize)>; fn parse(&self, input: &str) -> ParserResult<(serde_json::Value, usize)>;
/// Check if JSON is complete /// Check if JSON is complete
fn is_complete(&self, input: &str) -> bool; fn is_complete(&self, input: &str) -> bool;
...@@ -55,10 +61,7 @@ pub trait PartialJsonParser: Send + Sync { ...@@ -55,10 +61,7 @@ pub trait PartialJsonParser: Send + Sync {
#[async_trait] #[async_trait]
pub trait TokenToolParser: ToolParser { pub trait TokenToolParser: ToolParser {
/// Parse complete tool calls when provided with raw token IDs. /// Parse complete tool calls when provided with raw token IDs.
async fn parse_complete_tokens( async fn parse_complete_tokens(&self, tokens: &[u32]) -> ParserResult<(String, Vec<ToolCall>)>;
&self,
tokens: &[u32],
) -> ToolParserResult<(String, Vec<ToolCall>)>;
/// Streaming parser entrypoint for token chunks. /// Streaming parser entrypoint for token chunks.
/// Parsers maintain internal state, so self is mutable /// Parsers maintain internal state, so self is mutable
...@@ -66,5 +69,5 @@ pub trait TokenToolParser: ToolParser { ...@@ -66,5 +69,5 @@ pub trait TokenToolParser: ToolParser {
&mut self, &mut self,
tokens: &[u32], tokens: &[u32],
tools: &[Tool], tools: &[Tool],
) -> ToolParserResult<StreamingParseResult>; ) -> ParserResult<StreamingParseResult>;
} }
...@@ -4,6 +4,7 @@ ...@@ -4,6 +4,7 @@
pub mod mock_mcp_server; pub mod mock_mcp_server;
pub mod mock_openai_server; pub mod mock_openai_server;
pub mod mock_worker; pub mod mock_worker;
pub mod streaming_helpers;
pub mod test_app; pub mod test_app;
use serde_json::json; use serde_json::json;
......
//! Streaming Test Helpers
//!
//! Utilities for creating realistic streaming chunks that simulate
//! how LLM tokens actually arrive (1-5 characters at a time).
/// Split `input` into small char-level chunks (2-3 chars each) that mimic
/// how LLM tokens actually arrive during streaming. Deterministic so tests
/// are reproducible: an alphanumeric position starts a 3-char chunk,
/// anything else a 2-char chunk.
pub fn create_realistic_chunks(input: &str) -> Vec<String> {
    let symbols: Vec<char> = input.chars().collect();
    let mut pieces = Vec::new();
    let mut start = 0;
    while start < symbols.len() {
        // Alphanumeric runs get slightly longer chunks; punctuation shorter.
        let width = if symbols[start].is_ascii_alphanumeric() && start + 3 <= symbols.len() {
            3
        } else {
            2
        };
        let stop = usize::min(start + width, symbols.len());
        pieces.push(symbols[start..stop].iter().collect());
        start = stop;
    }
    pieces
}
/// Split `input` so chunks end right after structurally interesting
/// characters (quotes, colons, commas, braces, brackets), after a space that
/// immediately follows a quote, or once a chunk reaches 5 bytes. Exercises
/// parser edge cases where tokens break at critical JSON positions.
pub fn create_strategic_chunks(input: &str) -> Vec<String> {
    let mut out: Vec<String> = Vec::new();
    let mut buf = String::new();
    let mut prev: Option<char> = None;

    for ch in input.chars() {
        buf.push(ch);

        // Decide whether this character closes the current chunk.
        let structural = matches!(ch, '"' | ':' | ',' | '{' | '}' | '[' | ']');
        let space_after_quote = ch == ' ' && prev == Some('"');
        if structural || space_after_quote || buf.len() >= 5 {
            out.push(std::mem::take(&mut buf));
        }
        prev = Some(ch);
    }

    if !buf.is_empty() {
        out.push(buf);
    }
    out
}
/// Canned chunk sequence reproducing the original bug report: the tool-call
/// JSON `{"name": "search", "arguments": {"query": "test query"}}` delivered
/// one tiny token at a time, with the opening quote of the name arriving in
/// its own chunk — the exact point where the parser used to emit an empty
/// tool name.
pub fn create_bug_scenario_chunks() -> Vec<&'static str> {
    vec![
        "{", "\"", "name", "\"", ":", " ",
        // The parser now holds `{"name": "` — this is where the bug fired.
        "\"",
        // Valid tool name follows.
        "search", "\"", ",", " ",
        "\"", "arguments", "\"", ":", " ",
        "{", "\"", "query", "\"", ":", " ",
        "\"", "test query", "\"", "}", "}",
    ]
}
#[cfg(test)]
mod tests {
    #[allow(unused_imports)]
    use super::*;

    /// Realistic chunking must fragment the input and round-trip losslessly.
    #[test]
    fn test_realistic_chunks() {
        let source = r#"{"name": "test"}"#;
        let pieces = create_realistic_chunks(source);

        assert!(pieces.len() > 3);
        assert_eq!(pieces.concat(), source);
    }

    /// Strategic chunking must break right after quotes and colons while
    /// still reconstructing the original text exactly.
    #[test]
    fn test_strategic_chunks_breaks_after_quotes() {
        let source = r#"{"name": "value"}"#;
        let pieces = create_strategic_chunks(source);

        assert!(pieces.iter().any(|p| p.ends_with('"')));
        assert!(pieces.iter().any(|p| p.ends_with(':')));
        assert_eq!(pieces.concat(), source);
    }

    /// The canned bug-scenario chunks must rebuild the problematic JSON and
    /// keep the space/quote split that originally triggered the bug.
    #[test]
    fn test_bug_scenario_chunks() {
        let pieces = create_bug_scenario_chunks();
        let rebuilt: String = pieces.concat();

        assert!(rebuilt.contains(r#"{"name": "search""#));
        let marked = pieces.join("|");
        assert!(marked.contains(r#" |"#));
    }
}
...@@ -126,28 +126,6 @@ fn test_glm4_format_detection() { ...@@ -126,28 +126,6 @@ fn test_glm4_format_detection() {
assert!(!parser.has_tool_markers("plain text")); assert!(!parser.has_tool_markers("plain text"));
} }
#[tokio::test]
async fn test_glm4_python_literal_values() {
let parser = Glm4MoeParser::new();
let input = r#"<tool_call>config
<arg_key>debug</arg_key>
<arg_value>True</arg_value>
<arg_key>verbose</arg_key>
<arg_value>False</arg_value>
<arg_key>optional</arg_key>
<arg_value>None</arg_value>
</tool_call>"#;
let (_normal_text, tools) = parser.parse_complete(input).await.unwrap();
assert_eq!(tools.len(), 1);
let args: serde_json::Value = serde_json::from_str(&tools[0].function.arguments).unwrap();
assert_eq!(args["debug"], true);
assert_eq!(args["verbose"], false);
assert_eq!(args["optional"], serde_json::Value::Null);
}
#[tokio::test] #[tokio::test]
async fn test_python_literals() { async fn test_python_literals() {
let parser = Glm4MoeParser::new(); let parser = Glm4MoeParser::new();
...@@ -172,7 +150,7 @@ async fn test_python_literals() { ...@@ -172,7 +150,7 @@ async fn test_python_literals() {
} }
#[tokio::test] #[tokio::test]
async fn test_nested_values() { async fn test_glm4_nested_json_in_arg_values() {
let parser = Glm4MoeParser::new(); let parser = Glm4MoeParser::new();
let input = r#"<tool_call>process let input = r#"<tool_call>process
......
//! Partial JSON Parser Tests
//!
//! Tests for the partial JSON parser with allow_partial_strings flag behavior
use sglang_router_rs::tool_parser::partial_json::PartialJson;
#[test]
fn test_partial_string_flag_disallows_incomplete_strings() {
    // Scenario from the bug report: the buffer ends inside an unterminated
    // string. With allow_partial_strings=false, parsing must stop before
    // the incomplete string rather than emit a key for it (matching
    // Python's Allow.ALL & ~Allow.STR).
    let parser = PartialJson::new(32, true);
    let input = r#"{"name": ""#;
    let result = parser.parse_value(input, false);
    assert!(result.is_ok());

    let (value, used) = result.unwrap();
    assert!(value.is_object());

    let fields = value.as_object().unwrap();
    // Either nothing was parsed, or at least the dangling key was dropped.
    assert!(
        fields.is_empty() || !fields.contains_key("name"),
        "Should not parse incomplete string key, got: {:?}",
        fields
    );
    // Consumption never runs past the end of the buffer.
    assert!(used <= input.len());
}
#[test]
fn test_partial_string_flag_allows_incomplete_strings() {
    // Same truncated buffer, but with allow_partial_strings=true the
    // incomplete string is accepted and surfaced as a value.
    let parser = PartialJson::new(32, true);
    let input = r#"{"name": ""#;
    let result = parser.parse_value(input, true);
    assert!(result.is_ok());

    let (value, used) = result.unwrap();
    assert!(value.is_object());

    let fields = value.as_object().unwrap();
    assert!(
        fields.contains_key("name"),
        "Should parse incomplete string with allow_partial_strings=true"
    );
    // The entire buffer should have been consumed.
    assert_eq!(used, input.len());
}
#[test]
fn test_partial_string_flag_complete_json() {
    // For well-formed JSON the flag must make no difference at all.
    let input = r#"{"name": "test"}"#;
    let parser = PartialJson::new(32, true);

    let strict = parser.parse_value(input, false);
    assert!(strict.is_ok());
    let (strict_value, strict_used) = strict.unwrap();

    let lenient = parser.parse_value(input, true);
    assert!(lenient.is_ok());
    let (lenient_value, lenient_used) = lenient.unwrap();

    // Identical parse results and consumption either way.
    assert_eq!(strict_value, lenient_value);
    assert_eq!(strict_used, lenient_used);
    assert_eq!(strict_used, input.len());

    // And the content itself matches the document.
    assert!(strict_value.is_object());
    let fields = strict_value.as_object().unwrap();
    assert_eq!(fields.get("name").and_then(|v| v.as_str()), Some("test"));
}
#[test]
fn test_backward_compatibility_default() {
    // A default-constructed PartialJson must keep accepting partial
    // strings so existing callers observe unchanged behavior.
    let parser = PartialJson::default();
    let input = r#"{"name": ""#;
    let result = parser.parse_value(input, true);
    assert!(result.is_ok());

    let (value, _) = result.unwrap();
    assert!(value.is_object());

    let fields = value.as_object().unwrap();
    assert!(
        fields.contains_key("name"),
        "Default should allow partial strings for backward compatibility"
    );
}
#[test]
fn test_partial_string_in_nested_object() {
    // The flag must also apply one level down: a truncated string inside
    // a nested object should not be materialized when disallowed.
    let parser = PartialJson::new(32, true);
    let input = r#"{"tool": {"name": ""#;
    let result = parser.parse_value(input, false);
    assert!(result.is_ok());

    let (value, _) = result.unwrap();
    assert!(value.is_object());

    let fields = value.as_object().unwrap();
    if let Some(inner) = fields.get("tool").and_then(|tool| tool.as_object()) {
        assert!(
            !inner.contains_key("name")
                || inner.get("name").and_then(|v| v.as_str()).is_none(),
            "Should not parse incomplete nested string"
        );
    }
}
#[test]
fn test_bug_fix_exact_scenario() {
    // Exact reproduction of the reported bug:
    //   buffer = "{\"name\": \""
    //   flags  = Allow.ALL & ~Allow.STR  (Python reference behavior)
    // The Python implementation yields: parsed object {}, consumed length 10.
    let parser = PartialJson::new(32, true);
    let input = r#"{"name": ""#;
    let result = parser.parse_value(input, false);
    assert!(result.is_ok());

    let (value, used) = result.unwrap();
    // Must be an empty object (not {"name": null} or {"name": ""}).
    assert!(value.is_object());
    let fields = value.as_object().unwrap();
    assert!(
        fields.is_empty(),
        "Expected empty object, got: {:?}. This matches Python behavior with Allow.ALL & ~Allow.STR",
        fields
    );
    assert_eq!(used, 10, "Should consume all 10 characters");
}
//! Streaming Parser Tests //! Realistic Streaming Parser Tests
//! //!
//! Tests for incremental/streaming parsing capabilities across all parsers //! Tests incremental parsing with realistic char-level chunks (2-5 chars)
//! that simulate how LLM tokens actually arrive.
//!
//! These tests are designed to catch bugs like `{"name": "` being parsed
//! as an empty tool name.
use sglang_router_rs::tool_parser::{ use sglang_router_rs::tool_parser::{JsonParser, LlamaParser, QwenParser, ToolParser};
JsonParser, LlamaParser, MistralParser, PythonicParser, QwenParser, ToolParser,
};
mod common; mod common;
use common::create_test_tools; use common::{create_test_tools, streaming_helpers::*};
// =============================================================================
// THE BUG SCENARIO - Most Critical Test
// =============================================================================
#[tokio::test] #[tokio::test]
async fn test_json_streaming_simple() { async fn test_json_bug_incomplete_tool_name_string() {
let tools = create_test_tools(); let tools = create_test_tools();
let mut parser = JsonParser::new(); let mut parser = JsonParser::new();
let full_json = r#"{"name": "get_weather", "arguments": {"location": "San Francisco"}}"#; // This exact sequence triggered the bug:
// Parser receives {"name": " and must NOT parse it as empty name
let chunks = vec![
r#"{"#,
r#"""#,
r#"name"#,
r#"""#,
r#":"#,
r#" "#,
r#"""#, // ← Critical moment: parser has {"name": "
// At this point, partial_json should NOT allow incomplete strings
// when current_tool_name_sent=false
r#"search"#, // Use valid tool name from create_test_tools()
r#"""#,
r#", "#,
r#"""#,
r#"arguments"#,
r#"""#,
r#": {"#,
r#"""#,
r#"query"#,
r#"""#,
r#": "#,
r#"""#,
r#"rust programming"#,
r#"""#,
r#"}}"#,
];
let mut got_tool_name = false;
let mut saw_empty_name = false;
for chunk in chunks.iter() {
let result = parser.parse_incremental(chunk, &tools).await.unwrap();
let result = parser.parse_incremental(full_json, &tools).await.unwrap(); for call in result.calls {
if let Some(name) = &call.name {
if name.is_empty() {
saw_empty_name = true;
}
if name == "search" {
got_tool_name = true;
}
}
}
}
assert!(!result.calls.is_empty(), "Should have parsed a tool call"); assert!(
assert_eq!(result.calls[0].name, Some("get_weather".to_string())); !saw_empty_name,
"Parser should NEVER return empty tool name"
);
assert!(got_tool_name, "Should have parsed tool name correctly");
} }
// =============================================================================
// JSON PARSER REALISTIC STREAMING
// =============================================================================
#[tokio::test] #[tokio::test]
async fn test_json_streaming_array() { async fn test_json_realistic_chunks_simple_tool() {
let tools = create_test_tools(); let tools = create_test_tools();
let mut parser = JsonParser::new(); let mut parser = JsonParser::new();
let chunks = vec![ let input = r#"{"name": "get_weather", "arguments": {"city": "Paris"}}"#;
r#"["#, let chunks = create_realistic_chunks(input);
r#"{"name": "tool1", "#,
r#""arguments": {}}, "#,
r#"{"name": "tool2", "#,
r#""arguments": {"x": 1"#,
r#"}}]"#,
];
let mut tool_count = 0; assert!(chunks.len() > 10, "Should have many small chunks");
let mut got_tool_name = false;
for chunk in chunks { for chunk in chunks {
let result = parser.parse_incremental(chunk, &tools).await.unwrap(); let result = parser.parse_incremental(&chunk, &tools).await.unwrap();
for call in result.calls { for call in result.calls {
if call.name.is_some() { if let Some(name) = call.name {
tool_count += 1; assert_eq!(name, "get_weather");
got_tool_name = true;
} }
} }
} }
// Current implementation may handle this differently assert!(got_tool_name, "Should have parsed tool name");
assert!(tool_count <= 2, "Should parse at most 2 tools");
} }
#[tokio::test] #[tokio::test]
async fn test_mistral_streaming() { async fn test_json_strategic_chunks_with_quotes() {
let tools = create_test_tools(); let tools = create_test_tools();
let mut parser = JsonParser::new();
let mut parser = MistralParser::new(); let input = r#"{"name": "search", "arguments": {"query": "rust programming"}}"#;
let chunks = create_strategic_chunks(input);
let chunks = vec![ // Strategic chunks break after quotes and colons
r#"Here is the result: "#, assert!(chunks.iter().any(|c| c.ends_with('"')));
r#"[TOOL_CALLS] ["#,
r#"{"name": "#,
r#""search", "#,
r#""arguments": "#,
r#"{"query": "#,
r#""rust lang""#,
r#"}}]"#,
];
let mut got_tool_name = false; let mut got_tool_name = false;
for chunk in chunks { for chunk in chunks {
let result = parser.parse_incremental(chunk, &tools).await.unwrap(); let result = parser.parse_incremental(&chunk, &tools).await.unwrap();
for call in result.calls { for call in result.calls {
if let Some(name) = call.name { if call.name.is_some() {
assert_eq!(name, "search");
got_tool_name = true; got_tool_name = true;
} }
} }
} }
assert!(got_tool_name, "Should have found tool name"); assert!(got_tool_name, "Should have parsed tool name");
} }
#[tokio::test] #[tokio::test]
async fn test_pythonic_streaming() { async fn test_json_incremental_arguments_streaming() {
let tools = create_test_tools(); let tools = create_test_tools();
let mut parser = JsonParser::new();
let mut parser = PythonicParser::new(); let input = r#"{"name": "search", "arguments": {"query": "test", "limit": 10}}"#;
let chunks = create_realistic_chunks(input);
let full_input = r#"[get_weather(city="London", units="celsius")]"#; let mut tool_name_sent = false;
let mut got_arguments = false;
let result = parser.parse_incremental(full_input, &tools).await.unwrap(); for chunk in chunks {
let result = parser.parse_incremental(&chunk, &tools).await.unwrap();
for call in result.calls {
if call.name.is_some() {
tool_name_sent = true;
}
if tool_name_sent && !call.parameters.is_empty() {
got_arguments = true;
}
}
}
assert!(!result.calls.is_empty(), "Should have parsed a tool call"); assert!(tool_name_sent, "Should have sent tool name");
assert_eq!(result.calls[0].name, Some("get_weather".to_string())); assert!(got_arguments, "Should have sent arguments");
let args: serde_json::Value = serde_json::from_str(&result.calls[0].parameters).unwrap();
assert_eq!(args["city"], "London");
} }
// =============================================================================
// LLAMA PARSER REALISTIC STREAMING
// =============================================================================
#[tokio::test] #[tokio::test]
async fn test_llama_streaming_with_python_tag() { async fn test_llama_realistic_chunks_with_python_tag() {
let tools = create_test_tools(); let tools = create_test_tools();
let mut parser = LlamaParser::new(); let mut parser = LlamaParser::new();
let chunks = vec![ let input = r#"<|python_tag|>{"name": "calculate", "parameters": {"x": 10, "y": 20}}"#;
r#"Let me help. "#, let chunks = create_realistic_chunks(input);
r#"<|python"#,
r#"_tag|>"#, assert!(chunks.len() > 15, "Should have many small chunks");
r#"{"name": "#,
r#""calculate", "#,
r#""arguments": "#,
r#"{"x": 10}"#,
r#"}"#,
];
let mut got_tool_name = false; let mut got_tool_name = false;
for chunk in chunks { for chunk in chunks {
let result = parser.parse_incremental(chunk, &tools).await.unwrap(); let result = parser.parse_incremental(&chunk, &tools).await.unwrap();
for call in result.calls { for call in result.calls {
if let Some(name) = call.name { if let Some(name) = call.name {
assert_eq!(name, "calculate"); assert_eq!(name, "calculate");
...@@ -130,185 +181,142 @@ async fn test_llama_streaming_with_python_tag() { ...@@ -130,185 +181,142 @@ async fn test_llama_streaming_with_python_tag() {
} }
} }
assert!(got_tool_name, "Should have found tool name"); assert!(got_tool_name, "Should have parsed tool name");
}
#[tokio::test]
async fn test_qwen_streaming() {
let tools = create_test_tools();
let mut parser = QwenParser::new();
// Note: Parser expects newline after both tags
let full_input = "<tool_call>\n{\"name\": \"translate\", \"arguments\": {\"text\": \"hello\", \"to\": \"zh\"}}\n</tool_call>";
let result = parser.parse_incremental(full_input, &tools).await.unwrap();
assert!(!result.calls.is_empty(), "Should have parsed a tool call");
assert_eq!(result.calls[0].name, Some("translate".to_string()));
} }
#[tokio::test] #[tokio::test]
async fn test_streaming_incomplete_stays_incomplete() { async fn test_llama_python_tag_arrives_in_parts() {
let tools = create_test_tools(); let tools = create_test_tools();
let mut parser = LlamaParser::new();
let mut parser = JsonParser::new(); // Python tag itself arrives in small chunks
let chunks = vec![
"<|p", "yth", "on_", "tag", "|>{", r#"""#, "na", r#"me""#, ": ", r#"""#, "sea", "rch",
r#"""#, ", ", r#"""#, "par", "ame", "ter", "s", r#"""#, ": {", r#"""#, "q", r#"""#, ": ",
r#"""#, "tes", "t", r#"""#, "}}",
];
let chunks = vec![r#"{"na"#, r#"me": "#]; let mut got_tool_name = false;
for chunk in chunks { for chunk in chunks {
let result = parser.parse_incremental(chunk, &tools).await.unwrap(); let result = parser.parse_incremental(chunk, &tools).await.unwrap();
assert!( for call in result.calls {
result.calls.is_empty(), if let Some(name) = call.name {
"Should return empty calls for partial JSON, got: {:?}", assert_eq!(name, "search");
result got_tool_name = true;
);
} }
}
}
assert!(got_tool_name, "Should have parsed tool name");
} }
// =============================================================================
// QWEN PARSER REALISTIC STREAMING
// =============================================================================
#[tokio::test] #[tokio::test]
async fn test_streaming_buffer_accumulation() { async fn test_qwen_realistic_chunks_with_xml_tags() {
let tools = create_test_tools(); let tools = create_test_tools();
let mut parser = QwenParser::new();
let mut parser = JsonParser::new(); let input = "<tool_call>\n{\"name\": \"get_weather\", \"arguments\": {\"city\": \"Tokyo\"}}\n</tool_call>";
let chunks = create_realistic_chunks(input);
let result1 = parser.parse_incremental(r#"{"na"#, &tools).await.unwrap(); assert!(chunks.len() > 20, "Should have many small chunks");
assert!(result1.calls.is_empty(), "Should not parse incomplete JSON"); let mut got_tool_name = false;
let result2 = parser for chunk in chunks {
.parse_incremental(r#"me": "test", "arguments": {}}"#, &tools) let result = parser.parse_incremental(&chunk, &tools).await.unwrap();
.await for call in result.calls {
.unwrap(); if let Some(name) = call.name {
assert_eq!(name, "get_weather");
got_tool_name = true;
}
}
}
assert!( assert!(got_tool_name, "Should have parsed tool name");
!result2.calls.is_empty(),
"Should parse complete JSON after buffering"
);
assert_eq!(result2.calls[0].name, Some("test".to_string()));
} }
#[tokio::test] #[tokio::test]
async fn test_streaming_multiple_tools_sequential() { async fn test_qwen_xml_tag_arrives_in_parts() {
let tools = create_test_tools(); let tools = create_test_tools();
let mut parser = QwenParser::new(); let mut parser = QwenParser::new();
let full_input = r#"<tool_call> let chunks = vec![
{"name": "tool1", "arguments": {}} "<to", "ol_", "cal", "l>\n", "{", r#"""#, "na", "me", r#"""#, ": ", r#"""#, "tra", "nsl",
</tool_call>"#; "ate", r#"""#, ", ", r#"""#, "arg", "ume", "nts", r#"""#, ": {", r#"""#, "tex", "t",
r#"""#, ": ", r#"""#, "hel", "lo", r#"""#, "}}\n", "</t", "ool", "_ca", "ll>",
let result = parser.parse_incremental(full_input, &tools).await.unwrap(); ];
assert!(!result.calls.is_empty(), "Should have parsed a tool call");
assert_eq!(result.calls[0].name, Some("tool1".to_string()));
}
#[tokio::test]
async fn test_streaming_reset_after_error() {
let tools = create_test_tools();
let mut parser1 = JsonParser::new();
let _ = parser1 let mut got_tool_name = false;
.parse_incremental(r#"{"name": invalid}"#, &tools)
.await;
// Use a new parser instance for clean state for chunk in chunks {
let mut parser2 = JsonParser::new(); let result = parser.parse_incremental(chunk, &tools).await.unwrap();
let result = parser2 for call in result.calls {
.parse_incremental(r#"{"name": "test", "arguments": {}}"#, &tools) if let Some(name) = call.name {
.await assert_eq!(name, "translate");
.unwrap(); got_tool_name = true;
}
}
}
assert!(!result.calls.is_empty(), "Should parse valid JSON"); assert!(got_tool_name, "Should have parsed tool name");
assert_eq!(result.calls[0].name, Some("test".to_string()));
} }
// =============================================================================
// EDGE CASES WITH REALISTIC CHUNKS
// =============================================================================
#[tokio::test] #[tokio::test]
async fn test_streaming_with_unicode_chunks() { async fn test_json_very_long_url_in_arguments() {
let tools = create_test_tools(); let tools = create_test_tools();
let mut parser = JsonParser::new(); let mut parser = JsonParser::new();
let full_input = r#"{"name": "translate", "arguments": {"text": "Hello 世界 🌍"}}"#; // Simulate long URL arriving in many chunks
let long_url = "https://example.com/very/long/path/".to_string() + &"segment/".repeat(50);
let input = format!(
r#"{{"name": "search", "arguments": {{"query": "{}"}}}}"#,
long_url
);
let chunks = create_realistic_chunks(&input);
let result = parser.parse_incremental(full_input, &tools).await.unwrap(); assert!(chunks.len() > 100, "Long URL should create many chunks");
assert!(!result.calls.is_empty(), "Should have parsed a tool call"); let mut got_tool_name = false;
// Check if we got the tool name for chunk in chunks {
if let Some(name) = &result.calls[0].name { let result = parser.parse_incremental(&chunk, &tools).await.unwrap();
assert_eq!(name, "translate"); for call in result.calls {
if call.name.is_some() {
got_tool_name = true;
}
} }
// In streaming mode, need to make another call to get parameters
let result2 = parser.parse_incremental("", &tools).await.unwrap();
// Parameters should be in either result.calls[1] or result2.calls[0]
let params = if result.calls.len() > 1 {
&result.calls[1].parameters
} else if !result2.calls.is_empty() {
&result2.calls[0].parameters
} else {
&result.calls[0].parameters
};
if !params.is_empty() {
let args: serde_json::Value = serde_json::from_str(params).unwrap();
assert!(args["text"].as_str().unwrap().contains("世界"));
} }
assert!(got_tool_name, "Should have parsed tool name");
} }
#[tokio::test] #[tokio::test]
async fn test_streaming_with_partial_chunks() { async fn test_json_unicode_arrives_byte_by_byte() {
let mut parser = JsonParser::new();
let tools = create_test_tools(); let tools = create_test_tools();
let mut parser = JsonParser::new();
let partial = r#"{"#; let input = r#"{"name": "search", "arguments": {"query": "Hello 世界 🌍"}}"#;
let result = parser.parse_incremental(partial, &tools).await.unwrap(); let chunks = create_realistic_chunks(input);
assert!(
result.calls.is_empty(),
"Should return empty calls for just opening brace"
);
let mut parser2 = JsonParser::new();
let complete = r#"{"name": "get_weather", "arguments": {"location": "SF"}}"#;
let result = parser2.parse_incremental(complete, &tools).await.unwrap();
assert!(
!result.calls.is_empty(),
"Expected tool call for complete JSON"
);
assert_eq!(result.calls[0].name.as_ref().unwrap(), "get_weather");
// In streaming mode, need to make another call to get parameters
let result2 = parser2.parse_incremental("", &tools).await.unwrap();
// Parameters should be in either result.calls[1] or result2.calls[0] let mut got_tool_name = false;
let params = if result.calls.len() > 1 {
&result.calls[1].parameters
} else if !result2.calls.is_empty() {
&result2.calls[0].parameters
} else {
&result.calls[0].parameters
};
if !params.is_empty() { for chunk in chunks {
let args: serde_json::Value = serde_json::from_str(params).unwrap(); let result = parser.parse_incremental(&chunk, &tools).await.unwrap();
assert_eq!(args["location"], "SF"); for call in result.calls {
if call.name.is_some() {
got_tool_name = true;
} }
// The PartialJson parser can complete partial JSON by filling in missing values
let mut parser3 = JsonParser::new();
let partial_with_name = r#"{"name": "test", "argum"#;
let result = parser3
.parse_incremental(partial_with_name, &tools)
.await
.unwrap();
// Parser behavior may vary - either complete with partial data or wait for more
if !result.calls.is_empty() {
assert_eq!(result.calls[0].name.as_ref().unwrap(), "test");
} }
}
assert!(got_tool_name, "Should have parsed with unicode");
} }
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment