Unverified Commit af4ab656 authored by Chang Su's avatar Chang Su Committed by GitHub
Browse files

[router][tool call] Improve normal content extraction and error handling (non-stream) (#11050)

parent 11965b0d
...@@ -50,52 +50,58 @@ impl DeepSeekParser { ...@@ -50,52 +50,58 @@ impl DeepSeekParser {
text.contains("<|tool▁calls▁begin|>") text.contains("<|tool▁calls▁begin|>")
} }
/// Parse a single tool call block /// Parse a single tool call block - throws error if parsing fails
fn parse_tool_call(&self, block: &str) -> ToolParserResult<Option<ToolCall>> { fn parse_tool_call(&self, block: &str) -> ToolParserResult<ToolCall> {
if let Some(captures) = self.func_detail_extractor.captures(block) { let captures = self.func_detail_extractor.captures(block).ok_or_else(|| {
// Get function type (should be "function") ToolParserError::ParsingFailed("Failed to match tool call pattern".to_string())
let func_type = captures.get(1).map_or("", |m| m.as_str()); })?;
if func_type != "function" {
return Ok(None); // Get function type (should be "function")
} let func_type = captures.get(1).map_or("", |m| m.as_str());
if func_type != "function" {
return Err(ToolParserError::ParsingFailed(format!(
"Invalid function type: {}",
func_type
)));
}
// Get function name // Get function name
let func_name = captures.get(2).map_or("", |m| m.as_str()).trim(); let func_name = captures.get(2).map_or("", |m| m.as_str()).trim();
if func_name.is_empty() {
// Get JSON arguments return Err(ToolParserError::ParsingFailed(
let json_args = captures.get(3).map_or("{}", |m| m.as_str()).trim(); "Empty function name".to_string(),
));
// Parse JSON arguments
match serde_json::from_str::<Value>(json_args) {
Ok(value) => {
// Create arguments object
let args = if value.is_object() {
value
} else {
// If not an object, wrap it
serde_json::json!({ "value": value })
};
let arguments = serde_json::to_string(&args)
.map_err(|e| ToolParserError::ParsingFailed(e.to_string()))?;
// Generate ID
let id = format!("deepseek_call_{}", uuid::Uuid::new_v4());
Ok(Some(ToolCall {
id,
r#type: "function".to_string(),
function: FunctionCall {
name: func_name.to_string(),
arguments,
},
}))
}
Err(_) => Ok(None),
}
} else {
Ok(None)
} }
// Get JSON arguments
let json_args = captures.get(3).map_or("{}", |m| m.as_str()).trim();
// Parse JSON arguments
let value = serde_json::from_str::<Value>(json_args)
.map_err(|e| ToolParserError::ParsingFailed(format!("Invalid JSON: {}", e)))?;
// Create arguments object
let args = if value.is_object() {
value
} else {
// If not an object, wrap it
serde_json::json!({ "value": value })
};
let arguments = serde_json::to_string(&args)
.map_err(|e| ToolParserError::ParsingFailed(e.to_string()))?;
// Generate ID
let id = format!("deepseek_call_{}", uuid::Uuid::new_v4());
Ok(ToolCall {
id,
r#type: "function".to_string(),
function: FunctionCall {
name: func_name.to_string(),
arguments,
},
})
} }
} }
...@@ -108,39 +114,30 @@ impl Default for DeepSeekParser { ...@@ -108,39 +114,30 @@ impl Default for DeepSeekParser {
#[async_trait] #[async_trait]
impl ToolParser for DeepSeekParser { impl ToolParser for DeepSeekParser {
async fn parse_complete(&self, text: &str) -> ToolParserResult<(String, Vec<ToolCall>)> { async fn parse_complete(&self, text: &str) -> ToolParserResult<(String, Vec<ToolCall>)> {
// Check if text contains DeepSeek format
if !self.has_tool_markers(text) { if !self.has_tool_markers(text) {
return Ok((text.to_string(), vec![])); return Ok((text.to_string(), vec![]));
} }
// Collect matches with positions and parse tools in one pass // Find where tool calls begin
let matches: Vec<_> = self.tool_call_extractor.find_iter(text).collect(); let idx = text.find("<|tool▁calls▁begin|>").unwrap();
let mut tools = Vec::new(); let normal_text = text[..idx].to_string();
for mat in matches.iter() { // Try to extract tool calls, log warnings for failures
if let Some(tool) = self.parse_tool_call(mat.as_str())? { let mut tools = Vec::new();
tools.push(tool); for mat in self.tool_call_extractor.find_iter(text) {
match self.parse_tool_call(mat.as_str()) {
Ok(tool) => tools.push(tool),
Err(e) => {
tracing::warn!("Failed to parse tool call: {}", e);
continue;
}
} }
} }
// Extract normal text using first and last match positions // If no tools were successfully parsed despite having markers, return entire text as fallback
let normal_text = if tools.is_empty() || matches.is_empty() { if tools.is_empty() {
text.to_string() return Ok((text.to_string(), vec![]));
} else { }
let first_start = matches[0].start();
let last_end = matches.last().unwrap().end();
let before = if first_start > 0 {
&text[..first_start]
} else {
""
};
let after = if last_end < text.len() {
&text[last_end..]
} else {
""
};
format!("{}{}", before, after)
};
Ok((normal_text, tools)) Ok((normal_text, tools))
} }
...@@ -185,11 +182,16 @@ impl ToolParser for DeepSeekParser { ...@@ -185,11 +182,16 @@ impl ToolParser for DeepSeekParser {
// Extract and parse the complete tool call // Extract and parse the complete tool call
let tool_call_text = &state.buffer[call_start_abs..call_end_abs]; let tool_call_text = &state.buffer[call_start_abs..call_end_abs];
if let Some(tool) = self.parse_tool_call(tool_call_text)? { match self.parse_tool_call(tool_call_text) {
// Remove the processed part from buffer Ok(tool) => {
state.buffer.drain(..call_end_abs); // Remove the processed part from buffer
state.buffer.drain(..call_end_abs);
return Ok(StreamResult::ToolComplete(tool)); return Ok(StreamResult::ToolComplete(tool));
}
Err(_) => {
// Parsing failed, skip this tool call
state.buffer.drain(..call_end_abs);
}
} }
} else { } else {
// Tool call not complete yet, try to extract partial info // Tool call not complete yet, try to extract partial info
...@@ -248,51 +250,3 @@ impl ToolParser for DeepSeekParser { ...@@ -248,51 +250,3 @@ impl ToolParser for DeepSeekParser {
self.has_tool_markers(text) self.has_tool_markers(text)
} }
} }
#[cfg(test)]
mod tests {
    use super::*;

    // One DeepSeek-style call embedded in surrounding prose: the function
    // name and JSON arguments must both survive extraction.
    #[tokio::test]
    async fn test_parse_deepseek_single_tool() {
        let parser = DeepSeekParser::new();
        let text = r#"Some text
<|tool▁calls▁begin|><|tool▁call▁begin|>function<|tool▁sep|>get_weather
```json
{"location": "Tokyo", "units": "celsius"}
```<|tool▁call▁end|><|tool▁calls▁end|>More text"#;

        let (_normal_text, calls) = parser.parse_complete(text).await.unwrap();

        assert_eq!(calls.len(), 1);
        assert_eq!(calls[0].function.name, "get_weather");
        assert!(calls[0].function.arguments.contains("Tokyo"));
    }

    // Two back-to-back calls inside a single section must both be
    // returned, in order of appearance.
    #[tokio::test]
    async fn test_parse_deepseek_multiple_tools() {
        let parser = DeepSeekParser::new();
        let text = r#"<|tool▁calls▁begin|><|tool▁call▁begin|>function<|tool▁sep|>get_weather
```json
{"location": "Tokyo"}
```<|tool▁call▁end|>
<|tool▁call▁begin|>function<|tool▁sep|>get_weather
```json
{"location": "Paris"}
```<|tool▁call▁end|><|tool▁calls▁end|>"#;

        let (_normal_text, calls) = parser.parse_complete(text).await.unwrap();

        assert_eq!(calls.len(), 2);
        for call in &calls {
            assert_eq!(call.function.name, "get_weather");
        }
        assert!(calls[0].function.arguments.contains("Tokyo"));
        assert!(calls[1].function.arguments.contains("Paris"));
    }

    // Format detection: only the DeepSeek section-begin marker counts.
    #[test]
    fn test_detect_format() {
        let parser = DeepSeekParser::new();
        for (sample, expected) in [
            ("<|tool▁calls▁begin|>", true),
            ("plain text", false),
            ("[TOOL_CALLS]", false),
        ] {
            assert_eq!(parser.detect_format(sample), expected);
        }
    }
}
...@@ -136,34 +136,27 @@ impl ToolParser for Glm4MoeParser { ...@@ -136,34 +136,27 @@ impl ToolParser for Glm4MoeParser {
return Ok((text.to_string(), vec![])); return Ok((text.to_string(), vec![]));
} }
// Collect matches with positions and parse tools in one pass // Find where tool calls begin
let matches: Vec<_> = self.tool_call_extractor.find_iter(text).collect(); let idx = text.find("<tool_call>").unwrap();
let mut tools = Vec::new(); let normal_text = text[..idx].to_string();
for mat in matches.iter() { // Extract tool calls
if let Some(tool) = self.parse_tool_call(mat.as_str())? { let mut tools = Vec::new();
tools.push(tool); for mat in self.tool_call_extractor.find_iter(text) {
match self.parse_tool_call(mat.as_str()) {
Ok(Some(tool)) => tools.push(tool),
Ok(None) => continue,
Err(e) => {
tracing::warn!("Failed to parse tool call: {}", e);
continue;
}
} }
} }
// Extract normal text using first and last match positions // If no tools were successfully parsed despite having markers, return entire text as fallback
let normal_text = if tools.is_empty() { if tools.is_empty() {
text.to_string() return Ok((text.to_string(), vec![]));
} else { }
let first_start = matches[0].start();
let last_end = matches.last().unwrap().end();
let before = if first_start > 0 {
&text[..first_start]
} else {
""
};
let after = if last_end < text.len() {
&text[last_end..]
} else {
""
};
format!("{}{}", before, after)
};
Ok((normal_text, tools)) Ok((normal_text, tools))
} }
...@@ -247,80 +240,3 @@ impl ToolParser for Glm4MoeParser { ...@@ -247,80 +240,3 @@ impl ToolParser for Glm4MoeParser {
self.has_tool_markers(text) self.has_tool_markers(text)
} }
} }
#[cfg(test)]
mod tests {
    use super::*;

    // A single <tool_call> block carrying two key/value argument pairs;
    // the text on either side of the block comes back as normal text.
    #[tokio::test]
    async fn test_parse_glm4_single_tool() {
        let parser = Glm4MoeParser::new();
        let text = r#"Some text
<tool_call>get_weather
<arg_key>city</arg_key>
<arg_value>Beijing</arg_value>
<arg_key>date</arg_key>
<arg_value>2024-06-27</arg_value>
</tool_call>More text"#;

        let (normal_text, calls) = parser.parse_complete(text).await.unwrap();

        assert_eq!(calls.len(), 1);
        assert_eq!(calls[0].function.name, "get_weather");
        assert!(calls[0].function.arguments.contains("Beijing"));
        assert!(calls[0].function.arguments.contains("2024-06-27"));
        // Both the prefix and the suffix survive extraction.
        assert_eq!(normal_text, "Some text\nMore text");
    }

    // Two consecutive blocks with nothing outside them: two calls,
    // empty normal text.
    #[tokio::test]
    async fn test_parse_glm4_multiple_tools() {
        let parser = Glm4MoeParser::new();
        let text = r#"<tool_call>get_weather
<arg_key>city</arg_key>
<arg_value>Beijing</arg_value>
</tool_call>
<tool_call>get_weather
<arg_key>city</arg_key>
<arg_value>Shanghai</arg_value>
</tool_call>"#;

        let (normal_text, calls) = parser.parse_complete(text).await.unwrap();

        assert_eq!(calls.len(), 2);
        for call in &calls {
            assert_eq!(call.function.name, "get_weather");
        }
        assert!(calls[0].function.arguments.contains("Beijing"));
        assert!(calls[1].function.arguments.contains("Shanghai"));
        assert_eq!(normal_text, ""); // pure tool calls, no surrounding text
    }

    // Argument values are type-inferred: integer, boolean and string
    // inputs each come back as the matching JSON type.
    #[tokio::test]
    async fn test_parse_glm4_mixed_types() {
        let parser = Glm4MoeParser::new();
        let text = r#"<tool_call>process_data
<arg_key>count</arg_key>
<arg_value>42</arg_value>
<arg_key>active</arg_key>
<arg_value>true</arg_value>
<arg_key>name</arg_key>
<arg_value>test</arg_value>
</tool_call>"#;

        let (normal_text, calls) = parser.parse_complete(text).await.unwrap();

        assert_eq!(calls.len(), 1);
        assert_eq!(normal_text, ""); // pure tool call, no surrounding text
        assert_eq!(calls[0].function.name, "process_data");

        // Round-trip the serialized arguments to verify the value types.
        let args: serde_json::Value =
            serde_json::from_str(&calls[0].function.arguments).unwrap();
        assert_eq!(args["count"], 42);
        assert_eq!(args["active"], true);
        assert_eq!(args["name"], "test");
    }

    // Format detection: only the <tool_call> marker counts.
    #[test]
    fn test_detect_format() {
        let parser = Glm4MoeParser::new();
        for (sample, expected) in [
            ("<tool_call>", true),
            ("plain text", false),
            ("[TOOL_CALLS]", false),
        ] {
            assert_eq!(parser.detect_format(sample), expected);
        }
    }
}
...@@ -227,66 +227,3 @@ impl ToolParser for GptOssParser { ...@@ -227,66 +227,3 @@ impl ToolParser for GptOssParser {
self.has_tool_markers(text) || text.contains("<|channel|>commentary") self.has_tool_markers(text) || text.contains("<|channel|>commentary")
} }
} }
#[cfg(test)]
mod tests {
    use super::*;

    // One commentary-channel call surrounded by plain text.
    #[tokio::test]
    async fn test_parse_gpt_oss_single_tool() {
        let parser = GptOssParser::new();
        let text = r#"Some text
<|channel|>commentary to=functions.get_weather<|constrain|>json<|message|>{"location": "San Francisco"}<|call|>
More text"#;

        let (_normal_text, calls) = parser.parse_complete(text).await.unwrap();

        assert_eq!(calls.len(), 1);
        assert_eq!(calls[0].function.name, "get_weather");
        assert!(calls[0].function.arguments.contains("San Francisco"));
    }

    // Two calls in sequence, each targeting a different function.
    #[tokio::test]
    async fn test_parse_gpt_oss_multiple_tools() {
        let parser = GptOssParser::new();
        let text = r#"<|channel|>commentary to=functions.get_weather<|constrain|>json<|message|>{"location": "Paris"}<|call|>commentary
<|channel|>commentary to=functions.search<|constrain|>json<|message|>{"query": "Paris tourism"}<|call|>"#;

        let (_normal_text, calls) = parser.parse_complete(text).await.unwrap();

        assert_eq!(calls.len(), 2);
        assert_eq!(calls[0].function.name, "get_weather");
        assert_eq!(calls[1].function.name, "search");
        assert!(calls[0].function.arguments.contains("Paris"));
        assert!(calls[1].function.arguments.contains("Paris tourism"));
    }

    // An <|start|>assistant prefix before the channel marker is accepted.
    #[tokio::test]
    async fn test_parse_gpt_oss_with_prefix() {
        let parser = GptOssParser::new();
        let text = r#"<|start|>assistant<|channel|>commentary to=functions.test<|constrain|>json<|message|>{"key": "value"}<|call|>"#;

        let (_normal_text, calls) = parser.parse_complete(text).await.unwrap();

        assert_eq!(calls.len(), 1);
        assert_eq!(calls[0].function.name, "test");
    }

    // An empty argument object passes through verbatim.
    #[tokio::test]
    async fn test_parse_gpt_oss_empty_args() {
        let parser = GptOssParser::new();
        let text =
            r#"<|channel|>commentary to=functions.get_time<|constrain|>json<|message|>{}<|call|>"#;

        let (_normal_text, calls) = parser.parse_complete(text).await.unwrap();

        assert_eq!(calls.len(), 1);
        assert_eq!(calls[0].function.name, "get_time");
        assert_eq!(calls[0].function.arguments, "{}");
    }

    // Format detection: either commentary marker form is recognized.
    #[test]
    fn test_detect_format() {
        let parser = GptOssParser::new();
        for (sample, expected) in [
            ("<|channel|>commentary to=", true),
            ("<|channel|>commentary", true),
            ("plain text", false),
            ("[TOOL_CALLS]", false),
        ] {
            assert_eq!(parser.detect_format(sample), expected);
        }
    }
}
...@@ -615,155 +615,3 @@ impl ToolParser for JsonParser { ...@@ -615,155 +615,3 @@ impl ToolParser for JsonParser {
} }
} }
} }
#[cfg(test)]
mod tests {
    use super::*;

    // A bare JSON object with "name"/"arguments" is one tool call and
    // leaves no normal text behind.
    #[tokio::test]
    async fn test_parse_single_tool_call() {
        let parser = JsonParser::new();
        let payload = r#"{"name": "get_weather", "arguments": {"location": "San Francisco"}}"#;

        let (plain, calls) = parser.parse_complete(payload).await.unwrap();

        assert_eq!(calls.len(), 1);
        assert_eq!(calls[0].function.name, "get_weather");
        assert_eq!(plain, ""); // pure JSON should have no normal text
    }

    // JSON embedded in prose: the object is extracted and the prose on
    // both sides is stitched back together.
    #[tokio::test]
    async fn test_extract_json_with_normal_text() {
        let parser = JsonParser::new();
        let payload =
            r#"Here is some text before {"name": "test", "arguments": {}} and some text after."#;

        let (plain, calls) = parser.parse_complete(payload).await.unwrap();

        assert_eq!(calls.len(), 1);
        assert_eq!(calls[0].function.name, "test");
        assert_eq!(plain, "Here is some text before and some text after.");
    }

    // Same as above, but the embedded value is a JSON array of calls.
    #[tokio::test]
    async fn test_extract_json_array_with_normal_text() {
        let parser = JsonParser::new();
        let payload = r#"Prefix text [{"name": "func1", "arguments": {}}, {"name": "func2", "arguments": {}}] suffix text"#;

        let (plain, calls) = parser.parse_complete(payload).await.unwrap();

        assert_eq!(calls.len(), 2);
        assert_eq!(calls[0].function.name, "func1");
        assert_eq!(calls[1].function.name, "func2");
        assert_eq!(plain, "Prefix text suffix text");
    }

    // A top-level JSON array yields one call per element.
    #[tokio::test]
    async fn test_parse_multiple_tool_calls() {
        let parser = JsonParser::new();
        let payload = r#"[
{"name": "get_weather", "arguments": {"location": "SF"}},
{"name": "search", "arguments": {"query": "news"}}
]"#;

        let (plain, calls) = parser.parse_complete(payload).await.unwrap();

        assert_eq!(calls.len(), 2);
        assert_eq!(calls[0].function.name, "get_weather");
        assert_eq!(calls[1].function.name, "search");
        assert_eq!(plain, ""); // pure JSON should have no normal text
    }

    // "parameters" is accepted as an alias for "arguments".
    #[tokio::test]
    async fn test_parse_with_parameters_key() {
        let parser = JsonParser::new();
        let payload = r#"{"name": "calculate", "parameters": {"x": 10, "y": 20}}"#;

        let (plain, calls) = parser.parse_complete(payload).await.unwrap();

        assert_eq!(calls.len(), 1);
        assert_eq!(calls[0].function.name, "calculate");
        assert!(calls[0].function.arguments.contains("10"));
        assert_eq!(plain, ""); // pure JSON should have no normal text
    }

    // Custom start/end wrapper tokens around the JSON body.
    #[tokio::test]
    async fn test_parse_with_wrapper_tokens() {
        let parser = JsonParser::with_config(TokenConfig {
            start_tokens: vec!["<tool>".to_string()],
            end_tokens: vec!["</tool>".to_string()],
            separator: ", ".to_string(),
        });
        let payload = r#"<tool>{"name": "test", "arguments": {}}</tool>"#;

        let (plain, calls) = parser.parse_complete(payload).await.unwrap();

        assert_eq!(calls.len(), 1);
        assert_eq!(calls[0].function.name, "test");
        assert_eq!(plain, ""); // wrapper tokens with no extra text
    }

    // Garbage after the start token: no calls, the whole input is
    // returned untouched as normal text.
    #[tokio::test]
    async fn test_parse_with_start_token_invalid_json() {
        let parser = JsonParser::with_config(TokenConfig {
            start_tokens: vec!["<|python_tag|>".to_string()],
            end_tokens: vec!["".to_string()],
            separator: ";".to_string(),
        });
        let payload = r#"Hello world <|python_tag|>this is not valid json at all"#;

        let (plain, calls) = parser.parse_complete(payload).await.unwrap();

        assert_eq!(calls.len(), 0);
        assert_eq!(plain, payload); // entire original text when JSON parsing fails
    }

    // Mixed prose and JSON: prose is kept, the call is extracted.
    #[tokio::test]
    async fn test_parse_with_normal_text() {
        let parser = JsonParser::new();
        let payload = r#"Here is the weather data: {"name": "get_weather", "arguments": {"location": "SF"}} Let me know if you need more info."#;

        let (plain, calls) = parser.parse_complete(payload).await.unwrap();

        assert_eq!(calls.len(), 1);
        assert_eq!(calls[0].function.name, "get_weather");
        assert_eq!(plain, "Here is the weather data: Let me know if you need more info.");
    }

    // Format detection requires a "name" key, not just any JSON.
    #[test]
    fn test_detect_format() {
        let parser = JsonParser::new();
        for (sample, expected) in [
            (r#"{"name": "test", "arguments": {}}"#, true),
            (r#"[{"name": "test"}]"#, true),
            ("plain text", false),
            (r#"{"key": "value"}"#, false),
        ] {
            assert_eq!(parser.detect_format(sample), expected);
        }
    }

    // Feeding one complete JSON object incrementally should yield a
    // finished tool immediately.
    #[tokio::test]
    async fn test_streaming_parse() {
        // TODO simplified version, address more complex version
        let parser = JsonParser::new();
        let mut state = ParseState::new();
        let full_json = r#"{"name": "get_weather", "arguments": {"location": "SF"}}"#;

        let outcome = parser
            .parse_incremental(full_json, &mut state)
            .await
            .unwrap();

        match outcome {
            StreamResult::ToolComplete(tool) => {
                assert_eq!(tool.function.name, "get_weather");
                assert!(tool.function.arguments.contains("SF"));
            }
            _ => panic!("Expected ToolComplete for complete JSON input"),
        }
    }
}
...@@ -80,17 +80,17 @@ impl Default for KimiK2Parser { ...@@ -80,17 +80,17 @@ impl Default for KimiK2Parser {
#[async_trait] #[async_trait]
impl ToolParser for KimiK2Parser { impl ToolParser for KimiK2Parser {
async fn parse_complete(&self, text: &str) -> ToolParserResult<(String, Vec<ToolCall>)> { async fn parse_complete(&self, text: &str) -> ToolParserResult<(String, Vec<ToolCall>)> {
// Check if text contains Kimi K2 format
if !self.has_tool_markers(text) { if !self.has_tool_markers(text) {
return Ok((text.to_string(), vec![])); return Ok((text.to_string(), vec![]));
} }
// Collect matches with positions and parse tools in one pass // Find where tool calls begin
let matches: Vec<_> = self.tool_call_extractor.captures_iter(text).collect(); let idx = text.find("<|tool_calls_section_begin|>").unwrap();
let mut tools = Vec::new(); let normal_text = text[..idx].to_string();
// Extract all tool calls using collected matches // Try to extract tool calls
for captures in matches.iter() { let mut tools = Vec::new();
for captures in self.tool_call_extractor.captures_iter(text) {
if let (Some(id_match), Some(args_match)) = ( if let (Some(id_match), Some(args_match)) = (
captures.name("tool_call_id"), captures.name("tool_call_id"),
captures.name("function_arguments"), captures.name("function_arguments"),
...@@ -100,42 +100,41 @@ impl ToolParser for KimiK2Parser { ...@@ -100,42 +100,41 @@ impl ToolParser for KimiK2Parser {
// Parse function ID // Parse function ID
if let Some((func_name, _index)) = self.parse_function_id(function_id) { if let Some((func_name, _index)) = self.parse_function_id(function_id) {
// Validate JSON arguments // Try to parse JSON arguments
if serde_json::from_str::<serde_json::Value>(function_args).is_ok() { match serde_json::from_str::<serde_json::Value>(function_args) {
// Generate unique ID Ok(_) => {
let id = format!("kimi_call_{}", uuid::Uuid::new_v4()); // Generate unique ID
let id = format!("kimi_call_{}", uuid::Uuid::new_v4());
tools.push(ToolCall {
id, tools.push(ToolCall {
r#type: "function".to_string(), id,
function: FunctionCall { r#type: "function".to_string(),
name: func_name, function: FunctionCall {
arguments: function_args.to_string(), name: func_name,
}, arguments: function_args.to_string(),
}); },
});
}
Err(e) => {
tracing::warn!(
"Failed to parse JSON arguments for {}: {}",
func_name,
e
);
continue;
}
} }
} else {
tracing::warn!("Failed to parse function ID: {}", function_id);
continue;
} }
} }
} }
// Extract normal text using first and last match positions // If no tools were successfully parsed despite having markers, return entire text as fallback
let normal_text = if tools.is_empty() || matches.is_empty() { if tools.is_empty() {
text.to_string() return Ok((text.to_string(), vec![]));
} else { }
let first_start = matches[0].get(0).unwrap().start();
let last_end = matches.last().unwrap().get(0).unwrap().end();
let before = if first_start > 0 {
&text[..first_start]
} else {
""
};
let after = if last_end < text.len() {
&text[last_end..]
} else {
""
};
format!("{}{}", before, after)
};
Ok((normal_text, tools)) Ok((normal_text, tools))
} }
...@@ -248,57 +247,3 @@ impl ToolParser for KimiK2Parser { ...@@ -248,57 +247,3 @@ impl ToolParser for KimiK2Parser {
self.has_tool_markers(text) || text.contains("<|tool_call_begin|>") self.has_tool_markers(text) || text.contains("<|tool_call_begin|>")
} }
} }
#[cfg(test)]
mod tests {
    use super::*;

    // A single Kimi K2 call inside a section, with prose on both sides.
    #[tokio::test]
    async fn test_parse_kimi_single_tool() {
        let parser = KimiK2Parser::new();
        let text = r#"Some text
<|tool_calls_section_begin|>
<|tool_call_begin|>functions.get_weather:0<|tool_call_argument_begin|>{"location": "Tokyo", "units": "celsius"}<|tool_call_end|>
<|tool_calls_section_end|>More text"#;

        let (_normal_text, calls) = parser.parse_complete(text).await.unwrap();

        assert_eq!(calls.len(), 1);
        assert_eq!(calls[0].function.name, "get_weather");
        assert!(calls[0].function.arguments.contains("Tokyo"));
    }

    // Two indexed calls in one section; both names are recovered.
    #[tokio::test]
    async fn test_parse_kimi_multiple_tools() {
        let parser = KimiK2Parser::new();
        let text = r#"<|tool_calls_section_begin|>
<|tool_call_begin|>functions.search:0<|tool_call_argument_begin|>{"query": "rust"}<|tool_call_end|>
<|tool_call_begin|>functions.calculate:1<|tool_call_argument_begin|>{"expression": "2+2"}<|tool_call_end|>
<|tool_calls_section_end|>"#;

        let (_normal_text, calls) = parser.parse_complete(text).await.unwrap();

        assert_eq!(calls.len(), 2);
        assert_eq!(calls[0].function.name, "search");
        assert_eq!(calls[1].function.name, "calculate");
    }

    // Stray whitespace around the ID and argument payload is tolerated.
    #[tokio::test]
    async fn test_parse_kimi_with_whitespace() {
        let parser = KimiK2Parser::new();
        let text = r#"<|tool_calls_section_begin|>
<|tool_call_begin|> functions.test:0 <|tool_call_argument_begin|> {"key": "value"} <|tool_call_end|>
<|tool_calls_section_end|>"#;

        let (_normal_text, calls) = parser.parse_complete(text).await.unwrap();

        assert_eq!(calls.len(), 1);
        assert_eq!(calls[0].function.name, "test");
    }

    // Format detection: either section-begin or call-begin marker counts.
    #[test]
    fn test_detect_format() {
        let parser = KimiK2Parser::new();
        for (sample, expected) in [
            ("<|tool_calls_section_begin|>", true),
            ("<|tool_call_begin|>", true),
            ("plain text", false),
            ("[TOOL_CALLS]", false),
        ] {
            assert_eq!(parser.detect_format(sample), expected);
        }
    }
}
...@@ -101,70 +101,3 @@ impl ToolParser for LlamaParser { ...@@ -101,70 +101,3 @@ impl ToolParser for LlamaParser {
&& (text.contains(r#""name""#) || text.contains(r#""function""#))) && (text.contains(r#""name""#) || text.contains(r#""function""#)))
} }
} }
#[cfg(test)]
mod tests {
    use super::*;

    // JSON prefixed by <|python_tag|> is a tool call with no normal text.
    #[tokio::test]
    async fn test_parse_with_python_tag() {
        let parser = LlamaParser::new();
        let payload = r#"<|python_tag|>{"name": "search", "arguments": {"query": "weather"}}"#;

        let (plain, calls) = parser.parse_complete(payload).await.unwrap();

        assert_eq!(calls.len(), 1);
        assert_eq!(calls[0].function.name, "search");
        assert!(calls[0].function.arguments.contains("weather"));
        assert_eq!(plain, ""); // pure python_tag with JSON should have no normal text
    }

    // A bare JSON object also parses, without the tag.
    #[tokio::test]
    async fn test_parse_plain_json() {
        let parser = LlamaParser::new();
        let payload = r#"{"name": "calculate", "arguments": {"x": 5, "y": 10}}"#;

        let (plain, calls) = parser.parse_complete(payload).await.unwrap();

        assert_eq!(calls.len(), 1);
        assert_eq!(calls[0].function.name, "calculate");
        assert_eq!(plain, ""); // pure JSON should have no normal text
    }

    // Prose before the tag is preserved as normal text.
    #[tokio::test]
    async fn test_parse_with_text_before() {
        let parser = LlamaParser::new();
        let payload = r#"Let me help you with that. <|python_tag|>{"name": "get_time", "arguments": {"timezone": "UTC"}}"#;

        let (plain, calls) = parser.parse_complete(payload).await.unwrap();

        assert_eq!(calls.len(), 1);
        assert_eq!(calls[0].function.name, "get_time");
        assert_eq!(plain, "Let me help you with that. ");
    }

    // Format detection needs either the tag or a "name" field in JSON.
    #[test]
    fn test_detect_format() {
        let parser = LlamaParser::new();
        for (sample, expected) in [
            (r#"<|python_tag|>{"name": "test"}"#, true),
            (r#"{"name": "test", "arguments": {}}"#, true),
            ("plain text", false),
            (r#"{"key": "value"}"#, false), // no name field
        ] {
            assert_eq!(parser.detect_format(sample), expected);
        }
    }

    // Note: Llama 3.2 doesn't handle multiple calls well. A trailing
    // semicolon makes the payload invalid JSON, so the parser may return
    // nothing, or at best the first object — either outcome is acceptable
    // since parallel calls aren't reliably supported.
    #[tokio::test]
    async fn test_single_call_with_semicolon() {
        let parser = LlamaParser::new();
        let payload = r#"<|python_tag|>{"name": "func1", "arguments": {"x": 1}};"#;

        let (_plain, calls) = parser.parse_complete(payload).await.unwrap();

        if !calls.is_empty() {
            assert_eq!(calls[0].function.name, "func1");
        }
    }
}
...@@ -175,8 +175,9 @@ impl ToolParser for MistralParser { ...@@ -175,8 +175,9 @@ impl ToolParser for MistralParser {
match self.parse_json_array(json_array) { match self.parse_json_array(json_array) {
Ok(tools) => Ok((normal_text_before, tools)), Ok(tools) => Ok((normal_text_before, tools)),
Err(_) => { Err(e) => {
// If JSON parsing fails, return the original text as normal text // If JSON parsing fails, return the original text as normal text
tracing::warn!("Failed to parse tool call: {}", e);
Ok((text.to_string(), vec![])) Ok((text.to_string(), vec![]))
} }
} }
...@@ -309,67 +310,3 @@ impl ToolParser for MistralParser { ...@@ -309,67 +310,3 @@ impl ToolParser for MistralParser {
} }
} }
} }
#[cfg(test)]
mod tests {
    use super::*;

    // A single [TOOL_CALLS] array element becomes one tool call.
    #[tokio::test]
    async fn test_parse_mistral_format() {
        let parser = MistralParser::new();
        let payload = r#"[TOOL_CALLS] [{"name": "get_weather", "arguments": {"location": "Paris", "units": "celsius"}}]"#;

        let (_normal_text, calls) = parser.parse_complete(payload).await.unwrap();

        assert_eq!(calls.len(), 1);
        assert_eq!(calls[0].function.name, "get_weather");
        assert!(calls[0].function.arguments.contains("Paris"));
    }

    // Multiple array elements become multiple calls, in order.
    #[tokio::test]
    async fn test_parse_multiple_tools() {
        let parser = MistralParser::new();
        let payload = r#"[TOOL_CALLS] [
{"name": "search", "arguments": {"query": "rust programming"}},
{"name": "calculate", "arguments": {"expression": "2 + 2"}}
]"#;

        let (_normal_text, calls) = parser.parse_complete(payload).await.unwrap();

        assert_eq!(calls.len(), 2);
        assert_eq!(calls[0].function.name, "search");
        assert_eq!(calls[1].function.name, "calculate");
    }

    // Brackets nested inside argument values must not confuse the
    // array-boundary scan.
    #[tokio::test]
    async fn test_nested_brackets_in_json() {
        let parser = MistralParser::new();
        let payload = r#"[TOOL_CALLS] [{"name": "process", "arguments": {"data": [1, 2, [3, 4]], "config": {"nested": [5, 6]}}}]"#;

        let (_normal_text, calls) = parser.parse_complete(payload).await.unwrap();

        assert_eq!(calls.len(), 1);
        assert_eq!(calls[0].function.name, "process");
        // JSON serialization removes spaces, so check for [3,4] without spaces.
        assert!(calls[0].function.arguments.contains("[3,4]"));
    }

    // Escaped quotes (and brackets inside strings) are handled.
    #[tokio::test]
    async fn test_escaped_quotes_in_strings() {
        let parser = MistralParser::new();
        let payload = r#"[TOOL_CALLS] [{"name": "echo", "arguments": {"message": "He said \"Hello [World]\""}}]"#;

        let (_normal_text, calls) = parser.parse_complete(payload).await.unwrap();

        assert_eq!(calls.len(), 1);
        assert_eq!(calls[0].function.name, "echo");
    }

    // Format detection: the [TOOL_CALLS] marker may appear anywhere,
    // but a bare JSON object is not enough.
    #[test]
    fn test_detect_format() {
        let parser = MistralParser::new();
        for (sample, expected) in [
            (r#"[TOOL_CALLS] [{"name": "test", "arguments": {}}]"#, true),
            (
                r#"Some text [TOOL_CALLS] [{"name": "test", "arguments": {}}]"#,
                true,
            ),
            (r#"{"name": "test", "arguments": {}}"#, false),
            ("plain text", false),
        ] {
            assert_eq!(parser.detect_format(sample), expected);
        }
    }
}
...@@ -84,8 +84,21 @@ impl ToolParser for PythonicParser { ...@@ -84,8 +84,21 @@ impl ToolParser for PythonicParser {
let cleaned = Self::strip_special_tokens(text); let cleaned = Self::strip_special_tokens(text);
if let Some((tool_calls_text, normal_text)) = self.extract_tool_calls(&cleaned) { if let Some((tool_calls_text, normal_text)) = self.extract_tool_calls(&cleaned) {
let calls = self.parse_tool_call_block(&tool_calls_text)?; match self.parse_tool_call_block(&tool_calls_text) {
Ok((normal_text, calls)) Ok(calls) => {
if calls.is_empty() {
// No tools successfully parsed despite having markers
Ok((text.to_string(), vec![]))
} else {
Ok((normal_text, calls))
}
}
Err(e) => {
// Log warning and return entire text as fallback
tracing::warn!("Failed to parse pythonic tool calls: {}", e);
Ok((text.to_string(), vec![]))
}
}
} else { } else {
Ok((text.to_string(), vec![])) Ok((text.to_string(), vec![]))
} }
...@@ -329,84 +342,3 @@ where ...@@ -329,84 +342,3 @@ where
Value::String(value.to_string()) Value::String(value.to_string())
} }
} }
#[cfg(test)]
mod tests {
    use super::*;

    // One pythonic call with string and integer keyword arguments.
    #[tokio::test]
    async fn test_single_function_call() {
        let parser = PythonicParser::new();
        let text = r#"[search_web(query="Rust programming", max_results=5)]"#;

        let (_normal_text, calls) = parser.parse_complete(text).await.unwrap();

        assert_eq!(calls.len(), 1);
        assert_eq!(calls[0].function.name, "search_web");
        let args: Value = serde_json::from_str(&calls[0].function.arguments).unwrap();
        assert_eq!(args["query"], "Rust programming");
        assert_eq!(args["max_results"], 5);
    }

    // A comma-separated list of calls yields one tool per element.
    #[tokio::test]
    async fn test_multiple_function_calls() {
        let parser = PythonicParser::new();
        let text = r#"[get_weather(city="Tokyo"), search(query="news")]"#;

        let (_normal_text, calls) = parser.parse_complete(text).await.unwrap();

        assert_eq!(calls.len(), 2);
        assert_eq!(calls[0].function.name, "get_weather");
        assert_eq!(calls[1].function.name, "search");
    }

    // Python literals True/False/None map to JSON true/false/null.
    #[tokio::test]
    async fn test_python_literals() {
        let parser = PythonicParser::new();
        let text = r#"[test(flag=True, disabled=False, optional=None)]"#;

        let (_normal_text, calls) = parser.parse_complete(text).await.unwrap();

        assert_eq!(calls.len(), 1);
        let args: Value = serde_json::from_str(&calls[0].function.arguments).unwrap();
        assert_eq!(args["flag"], true);
        assert_eq!(args["disabled"], false);
        assert!(args["optional"].is_null());
    }

    // <|python_start|>/<|python_end|> wrappers are stripped before parsing.
    #[tokio::test]
    async fn test_strip_special_tokens() {
        let parser = PythonicParser::new();
        let text = "<|python_start|>[call(arg=1)]<|python_end|>";

        assert!(parser.detect_format(text));
        let (_normal_text, calls) = parser.parse_complete(text).await.unwrap();
        assert_eq!(calls.len(), 1);
    }

    // Format detection: a bracketed call expression is required.
    #[tokio::test]
    async fn test_detect_format() {
        let parser = PythonicParser::new();
        assert!(parser.detect_format("[foo(bar=1)]"));
        assert!(!parser.detect_format("No python here"));
    }

    // Streaming: an incomplete chunk buffers, and the closing chunk
    // completes the call.
    #[tokio::test]
    async fn test_parse_incremental() {
        let parser = PythonicParser::new();
        let mut state = ParseState::new();

        let first = parser.parse_incremental("[call(arg=", &mut state).await.unwrap();
        assert!(matches!(first, StreamResult::Incomplete));

        let second = parser.parse_incremental("1)]", &mut state).await.unwrap();
        match second {
            StreamResult::ToolComplete(tool) => {
                assert_eq!(tool.function.name, "call");
            }
            other => panic!("Expected ToolComplete, got {:?}", other),
        }
    }
}
...@@ -134,43 +134,35 @@ impl ToolParser for QwenParser { ...@@ -134,43 +134,35 @@ impl ToolParser for QwenParser {
return Ok((text.to_string(), vec![])); return Ok((text.to_string(), vec![]));
} }
// Collect matches with positions and parse tools in one pass // Find where the first tool call begins
let matches: Vec<_> = self.extractor.captures_iter(text).collect(); let idx = text.find("<tool_call>").unwrap(); // Safe because has_tool_markers checked
let mut tools = Vec::new(); let normal_text = text[..idx].to_string();
for (index, captures) in matches.iter().enumerate() { // Extract tool calls
let mut tools = Vec::new();
for (index, captures) in self.extractor.captures_iter(text).enumerate() {
if let Some(json_str) = captures.get(1) { if let Some(json_str) = captures.get(1) {
match serde_json::from_str::<Value>(json_str.as_str().trim()) { match serde_json::from_str::<Value>(json_str.as_str().trim()) {
Ok(value) => { Ok(value) => match self.parse_single_object(&value, index) {
if let Some(tool) = self.parse_single_object(&value, index)? { Ok(Some(tool)) => tools.push(tool),
tools.push(tool); Ok(None) => continue,
Err(e) => {
tracing::warn!("Failed to parse tool call: {}", e);
continue;
} }
} },
Err(_) => { Err(e) => {
// JSON parsing failed, might be incomplete tracing::warn!("Failed to parse JSON in tool call: {}", e);
continue;
} }
} }
} }
} }
// Extract normal text using first and last match positions // If no tools were successfully parsed despite having markers, return entire text as fallback
let normal_text = if tools.is_empty() { if tools.is_empty() {
text.to_string() return Ok((text.to_string(), vec![]));
} else { }
let first_start = matches[0].get(0).unwrap().start();
let last_end = matches.last().unwrap().get(0).unwrap().end();
let before = if first_start > 0 {
&text[..first_start]
} else {
""
};
let after = if last_end < text.len() {
&text[last_end..]
} else {
""
};
format!("{}{}", before, after)
};
Ok((normal_text, tools)) Ok((normal_text, tools))
} }
...@@ -299,140 +291,3 @@ impl ToolParser for QwenParser { ...@@ -299,140 +291,3 @@ impl ToolParser for QwenParser {
true true
} }
} }
#[cfg(test)]
mod tests {
    use super::*;

    /// A single well-formed <tool_call> block parses into one tool and
    /// leaves no surrounding normal text.
    #[tokio::test]
    async fn test_parse_qwen_format() {
        let parser = QwenParser::new();
        let input = r#"<tool_call>
{"name": "get_weather", "arguments": {"location": "Beijing", "units": "celsius"}}
</tool_call>"#;
        let (normal_text, tools) = parser.parse_complete(input).await.unwrap();
        assert_eq!(tools.len(), 1);
        assert_eq!(tools[0].function.name, "get_weather");
        assert!(tools[0].function.arguments.contains("Beijing"));
        assert_eq!(normal_text, ""); // Pure tool call, no normal text
    }

    /// Two consecutive <tool_call> blocks are both extracted, in input order.
    #[tokio::test]
    async fn test_parse_multiple_tools() {
        let parser = QwenParser::new();
        let input = r#"<tool_call>
{"name": "search", "arguments": {"query": "rust programming"}}
</tool_call>
<tool_call>
{"name": "calculate", "arguments": {"expression": "2 + 2"}}
</tool_call>"#;
        let (normal_text, tools) = parser.parse_complete(input).await.unwrap();
        assert_eq!(tools.len(), 2);
        assert_eq!(tools[0].function.name, "search");
        assert_eq!(tools[1].function.name, "calculate");
        assert_eq!(normal_text, ""); // Pure tool calls, no normal text
    }

    /// Prose surrounding the tool call comes back as normal text with the
    /// tool-call span removed.
    #[tokio::test]
    async fn test_with_normal_text() {
        let parser = QwenParser::new();
        let input = r#"Let me help you with that.
<tool_call>
{"name": "get_info", "arguments": {"topic": "Rust"}}
</tool_call>
Here are the results."#;
        let (normal_text, tools) = parser.parse_complete(input).await.unwrap();
        assert_eq!(tools.len(), 1);
        assert_eq!(tools[0].function.name, "get_info");
        assert_eq!(
            normal_text,
            "Let me help you with that.\n\nHere are the results."
        );
    }

    /// Deeply nested JSON argument structures survive parsing intact.
    #[tokio::test]
    async fn test_nested_json_structures() {
        let parser = QwenParser::new();
        let input = r#"<tool_call>
{
"name": "process_data",
"arguments": {
"data": {
"nested": {
"array": [1, 2, 3],
"object": {"key": "value"}
}
}
}
}
</tool_call>"#;
        let (normal_text, tools) = parser.parse_complete(input).await.unwrap();
        assert_eq!(tools.len(), 1);
        assert_eq!(tools[0].function.name, "process_data");
        assert!(tools[0].function.arguments.contains("nested"));
        assert_eq!(normal_text, ""); // Pure tool call, no normal text
    }

    /// Detection keys off the <tool_call> marker — even a partial marker or
    /// one surrounded by text — while bare JSON is not detected.
    #[test]
    fn test_detect_format() {
        let parser = QwenParser::new();
        assert!(parser.detect_format(
            r#"<tool_call>
{"name": "test", "arguments": {}}
</tool_call>"#
        ));
        assert!(parser.detect_format(
            r#"Text before <tool_call>
{"name": "test", "arguments": {}}
</tool_call> text after"#
        ));
        assert!(!parser.detect_format(r#"{"name": "test", "arguments": {}}"#));
        assert!(!parser.detect_format("plain text"));
        // Partial format should still be detected
        assert!(parser.detect_format("<tool_call>"));
    }

    /// Streaming a tool call chunk-by-chunk should surface the tool name
    /// and/or the completed tool once the closing tag arrives.
    #[tokio::test]
    async fn test_streaming_partial() {
        let parser = QwenParser::new();
        let mut state = ParseState::new();
        // Simulate streaming chunks
        let chunks = vec![
            "<tool_call>\n",
            r#"{"name": "search","#,
            r#" "arguments": {"query":"#,
            r#" "rust"}}"#,
            "\n</tool_call>",
        ];
        let mut found_name = false;
        let mut found_complete = false;
        for chunk in chunks {
            let result = parser.parse_incremental(chunk, &mut state).await.unwrap();
            match result {
                StreamResult::ToolName { name, .. } => {
                    assert_eq!(name, "search");
                    found_name = true;
                }
                StreamResult::ToolComplete(tool) => {
                    assert_eq!(tool.function.name, "search");
                    found_complete = true;
                }
                _ => {}
            }
        }
        assert!(found_name || found_complete); // At least one should be found
    }
}
...@@ -158,46 +158,33 @@ impl Default for Step3Parser { ...@@ -158,46 +158,33 @@ impl Default for Step3Parser {
#[async_trait] #[async_trait]
impl ToolParser for Step3Parser { impl ToolParser for Step3Parser {
async fn parse_complete(&self, text: &str) -> ToolParserResult<(String, Vec<ToolCall>)> { async fn parse_complete(&self, text: &str) -> ToolParserResult<(String, Vec<ToolCall>)> {
// Check if text contains Step3 format
if !self.has_tool_markers(text) { if !self.has_tool_markers(text) {
return Ok((text.to_string(), vec![])); return Ok((text.to_string(), vec![]));
} }
// Find the tool calls section // Find where tool calls begin
if let Some(start_pos) = text.find("<|tool_calls_begin|>") { let idx = text.find("<|tool_calls_begin|>").unwrap();
let search_from = start_pos + "<|tool_calls_begin|>".len(); let normal_text = text[..idx].to_string();
// Find the end of tool calls section // Extract tool calls
if let Some(end_pos) = text[search_from..].find("<|tool_calls_end|>") { let mut tools = Vec::new();
let tool_section = &text[search_from..search_from + end_pos]; for mat in self.tool_call_extractor.find_iter(text) {
let end_abs = search_from + end_pos + "<|tool_calls_end|>".len(); match self.parse_tool_call(mat.as_str()) {
Ok(Some(tool)) => tools.push(tool),
// Extract all tool call blocks Ok(None) => continue,
let mut tools = Vec::new(); Err(e) => {
for mat in self.tool_call_extractor.find_iter(tool_section) { tracing::warn!("Failed to parse tool call: {}", e);
if let Some(tool) = self.parse_tool_call(mat.as_str())? { continue;
tools.push(tool);
}
} }
// Extract normal text before start and after end
let before = if start_pos > 0 {
&text[..start_pos]
} else {
""
};
let after = if end_abs < text.len() {
&text[end_abs..]
} else {
""
};
let normal_text = format!("{}{}", before, after);
return Ok((normal_text, tools));
} }
} }
Ok((text.to_string(), vec![])) // If no tools were successfully parsed despite having markers, return entire text as fallback
if tools.is_empty() {
return Ok((text.to_string(), vec![]));
}
Ok((normal_text, tools))
} }
async fn parse_incremental( async fn parse_incremental(
...@@ -297,76 +284,3 @@ impl ToolParser for Step3Parser { ...@@ -297,76 +284,3 @@ impl ToolParser for Step3Parser {
self.has_tool_markers(text) self.has_tool_markers(text)
} }
} }
#[cfg(test)]
mod tests {
    use super::*;

    /// One steptml invoke inside the tool-calls section parses into a
    /// single tool whose parameters appear in the JSON arguments.
    #[tokio::test]
    async fn test_parse_step3_single_tool() {
        let parser = Step3Parser::new();
        let input = r#"Some text
<|tool_calls_begin|>
<|tool_call_begin|>function<|tool_sep|><steptml:invoke name="get_weather">
<steptml:parameter name="location">Tokyo</steptml:parameter>
<steptml:parameter name="units">celsius</steptml:parameter>
</steptml:invoke><|tool_call_end|>
<|tool_calls_end|>More text"#;
        let (_normal_text, tools) = parser.parse_complete(input).await.unwrap();
        assert_eq!(tools.len(), 1);
        assert_eq!(tools[0].function.name, "get_weather");
        assert!(tools[0].function.arguments.contains("Tokyo"));
        assert!(tools[0].function.arguments.contains("celsius"));
    }

    /// Multiple invokes within one section are all extracted, in order.
    #[tokio::test]
    async fn test_parse_step3_multiple_tools() {
        let parser = Step3Parser::new();
        let input = r#"<|tool_calls_begin|>
<|tool_call_begin|>function<|tool_sep|><steptml:invoke name="search">
<steptml:parameter name="query">rust programming</steptml:parameter>
</steptml:invoke><|tool_call_end|>
<|tool_call_begin|>function<|tool_sep|><steptml:invoke name="calculate">
<steptml:parameter name="expression">2 + 2</steptml:parameter>
</steptml:invoke><|tool_call_end|>
<|tool_calls_end|>"#;
        let (_normal_text, tools) = parser.parse_complete(input).await.unwrap();
        assert_eq!(tools.len(), 2);
        assert_eq!(tools[0].function.name, "search");
        assert_eq!(tools[1].function.name, "calculate");
    }

    /// Parameter values are coerced to native JSON types (int, bool,
    /// float, string) rather than kept as raw strings.
    #[tokio::test]
    async fn test_parse_step3_mixed_types() {
        let parser = Step3Parser::new();
        let input = r#"<|tool_calls_begin|>
<|tool_call_begin|>function<|tool_sep|><steptml:invoke name="process_data">
<steptml:parameter name="count">42</steptml:parameter>
<steptml:parameter name="active">true</steptml:parameter>
<steptml:parameter name="rate">1.5</steptml:parameter>
<steptml:parameter name="name">test</steptml:parameter>
</steptml:invoke><|tool_call_end|>
<|tool_calls_end|>"#;
        let (_normal_text, tools) = parser.parse_complete(input).await.unwrap();
        assert_eq!(tools.len(), 1);
        assert_eq!(tools[0].function.name, "process_data");
        // Parse arguments to check types
        let args: serde_json::Value = serde_json::from_str(&tools[0].function.arguments).unwrap();
        assert_eq!(args["count"], 42);
        assert_eq!(args["active"], true);
        assert_eq!(args["rate"], 1.5);
        assert_eq!(args["name"], "test");
    }

    /// Detection requires the Step3 section marker; other tool-call
    /// formats and plain text are rejected.
    #[test]
    fn test_detect_format() {
        let parser = Step3Parser::new();
        assert!(parser.detect_format("<|tool_calls_begin|>"));
        assert!(!parser.detect_format("plain text"));
        assert!(!parser.detect_format("[TOOL_CALLS]"));
    }
}
...@@ -13,8 +13,9 @@ async fn test_deepseek_complete_parsing() { ...@@ -13,8 +13,9 @@ async fn test_deepseek_complete_parsing() {
```<|tool▁call▁end|><|tool▁calls▁end|> ```<|tool▁call▁end|><|tool▁calls▁end|>
The weather in Tokyo is..."#; The weather in Tokyo is..."#;
let (_normal_text, tools) = parser.parse_complete(input).await.unwrap(); let (normal_text, tools) = parser.parse_complete(input).await.unwrap();
assert_eq!(tools.len(), 1); assert_eq!(tools.len(), 1);
assert_eq!(normal_text, "Let me help you with that.\n");
assert_eq!(tools[0].function.name, "get_weather"); assert_eq!(tools[0].function.name, "get_weather");
let args: serde_json::Value = serde_json::from_str(&tools[0].function.arguments).unwrap(); let args: serde_json::Value = serde_json::from_str(&tools[0].function.arguments).unwrap();
...@@ -140,25 +141,6 @@ async fn test_deepseek_malformed_json_handling() { ...@@ -140,25 +141,6 @@ async fn test_deepseek_malformed_json_handling() {
assert_eq!(tools[0].function.name, "valid"); assert_eq!(tools[0].function.name, "valid");
} }
/// Text preceding the DeepSeek tool-call section should be returned as
/// normal text (mirrors the Python reference parser's behavior).
#[tokio::test]
async fn test_normal_text_extraction() {
    let parser = DeepSeekParser::new();
    // Python extracts text before tool calls as normal_text
    let input = r#"Let me help you with that.
<|tool▁calls▁begin|><|tool▁call▁begin|>function<|tool▁sep|>get_weather
```json
{"location": "Tokyo"}
```<|tool▁call▁end|><|tool▁calls▁end|>"#;
    let (_normal_text, tools) = parser.parse_complete(input).await.unwrap();
    assert_eq!(tools.len(), 1);
    assert_eq!(tools[0].function.name, "get_weather");
    // TODO: Verify normal text extraction when parser returns it
    // In Python: normal_text = "Let me help you with that."
}
#[tokio::test] #[tokio::test]
async fn test_multiple_tool_calls() { async fn test_multiple_tool_calls() {
let parser = DeepSeekParser::new(); let parser = DeepSeekParser::new();
......
...@@ -111,19 +111,25 @@ async fn test_mistral_parser_invalid_format_returns_as_normal_text() { ...@@ -111,19 +111,25 @@ async fn test_mistral_parser_invalid_format_returns_as_normal_text() {
async fn test_deepseek_parser_invalid_format_returns_as_normal_text() { async fn test_deepseek_parser_invalid_format_returns_as_normal_text() {
let parser = DeepSeekParser::new(); let parser = DeepSeekParser::new();
// Invalid JSON after emoji marker // Invalid JSON in tool call
let input = r#"🤔[{"name": "test", "arguments": malformed}]"#; let input = r#"Some text<|tool▁calls▁begin|><|tool▁call▁begin|>function<|tool▁sep|>test
```json
{"name": "test", "arguments": malformed}
```<|tool▁call▁end|><|tool▁calls▁end|>"#;
let (normal_text, tools) = parser.parse_complete(input).await.unwrap(); let (normal_text, tools) = parser.parse_complete(input).await.unwrap();
assert_eq!(tools.len(), 0); assert_eq!(tools.len(), 0);
assert_eq!(normal_text, input); // Should preserve original text when parsing fails assert_eq!(normal_text, input); // Should preserve original text when parsing fails
// Emoji but no JSON array // Missing function marker
let input = "🤔 Just thinking about this problem..."; let input = r#"<|tool▁calls▁begin|><|tool▁call▁begin|>notfunction<|tool▁sep|>test
```json
{"x": 1}
```<|tool▁call▁end|><|tool▁calls▁end|>"#;
let (normal_text, tools) = parser.parse_complete(input).await.unwrap(); let (normal_text, tools) = parser.parse_complete(input).await.unwrap();
assert_eq!(tools.len(), 0); assert_eq!(tools.len(), 0);
assert_eq!(normal_text, input); // Should return original text assert_eq!(normal_text, input); // Should return original text when parsing fails
// No emoji marker at all // No tool markers at all
let input = "Regular response without any special markers."; let input = "Regular response without any special markers.";
let (normal_text, tools) = parser.parse_complete(input).await.unwrap(); let (normal_text, tools) = parser.parse_complete(input).await.unwrap();
assert_eq!(tools.len(), 0); assert_eq!(tools.len(), 0);
...@@ -148,9 +154,8 @@ That's all!"#; ...@@ -148,9 +154,8 @@ That's all!"#;
let (normal_text, tools) = parser.parse_complete(input).await.unwrap(); let (normal_text, tools) = parser.parse_complete(input).await.unwrap();
assert_eq!(tools.len(), 1); // Should extract the valid tool assert_eq!(tools.len(), 1); // Should extract the valid tool
assert_eq!(tools[0].function.name, "valid_tool"); assert_eq!(tools[0].function.name, "valid_tool");
// Normal text should contain the text around the valid tool call // Normal text should contain text before the first tool call
assert!(normal_text.contains("Let me help you")); assert_eq!(normal_text, "Let me help you with that.\n");
assert!(normal_text.contains("That's all!"));
} }
#[tokio::test] #[tokio::test]
...@@ -208,8 +213,8 @@ async fn test_unicode_and_special_chars_in_failed_parsing() { ...@@ -208,8 +213,8 @@ async fn test_unicode_and_special_chars_in_failed_parsing() {
</tool_call>"#; </tool_call>"#;
let (normal_text, tools) = parser.parse_complete(input).await.unwrap(); let (normal_text, tools) = parser.parse_complete(input).await.unwrap();
assert_eq!(tools.len(), 0); assert_eq!(tools.len(), 0);
// Should handle Unicode properly in the fallback text // Should handle Unicode properly in the fallback text - malformed content should be preserved
assert!(!normal_text.is_empty() || normal_text == input); assert_eq!(normal_text, input);
// Special characters that might confuse parsers // Special characters that might confuse parsers
let input = r#"Response: <tool_call>{"name": "test\n\t", "arguments": {"]}"}</tool_call>"#; let input = r#"Response: <tool_call>{"name": "test\n\t", "arguments": {"]}"}</tool_call>"#;
......
...@@ -15,8 +15,9 @@ async fn test_glm4_complete_parsing() { ...@@ -15,8 +15,9 @@ async fn test_glm4_complete_parsing() {
</tool_call> </tool_call>
The weather will be..."#; The weather will be..."#;
let (_normal_text, tools) = parser.parse_complete(input).await.unwrap(); let (normal_text, tools) = parser.parse_complete(input).await.unwrap();
assert_eq!(tools.len(), 1); assert_eq!(tools.len(), 1);
assert_eq!(normal_text, "Let me search for that.\n");
assert_eq!(tools[0].function.name, "get_weather"); assert_eq!(tools[0].function.name, "get_weather");
let args: serde_json::Value = serde_json::from_str(&tools[0].function.arguments).unwrap(); let args: serde_json::Value = serde_json::from_str(&tools[0].function.arguments).unwrap();
...@@ -39,8 +40,9 @@ async fn test_glm4_multiple_tools() { ...@@ -39,8 +40,9 @@ async fn test_glm4_multiple_tools() {
<arg_value>zh</arg_value> <arg_value>zh</arg_value>
</tool_call>"#; </tool_call>"#;
let (_normal_text, tools) = parser.parse_complete(input).await.unwrap(); let (normal_text, tools) = parser.parse_complete(input).await.unwrap();
assert_eq!(tools.len(), 2); assert_eq!(tools.len(), 2);
assert_eq!(normal_text, "");
assert_eq!(tools[0].function.name, "search"); assert_eq!(tools[0].function.name, "search");
assert_eq!(tools[1].function.name, "translate"); assert_eq!(tools[1].function.name, "translate");
} }
...@@ -62,8 +64,9 @@ async fn test_glm4_type_conversion() { ...@@ -62,8 +64,9 @@ async fn test_glm4_type_conversion() {
<arg_value>string value</arg_value> <arg_value>string value</arg_value>
</tool_call>"#; </tool_call>"#;
let (_normal_text, tools) = parser.parse_complete(input).await.unwrap(); let (normal_text, tools) = parser.parse_complete(input).await.unwrap();
assert_eq!(tools.len(), 1); assert_eq!(tools.len(), 1);
assert_eq!(normal_text, "");
let args: serde_json::Value = serde_json::from_str(&tools[0].function.arguments).unwrap(); let args: serde_json::Value = serde_json::from_str(&tools[0].function.arguments).unwrap();
assert_eq!(args["count"], 42); assert_eq!(args["count"], 42);
......
...@@ -3,15 +3,16 @@ ...@@ -3,15 +3,16 @@
//! Tests for the JSON parser which handles OpenAI, Claude, and generic JSON formats //! Tests for the JSON parser which handles OpenAI, Claude, and generic JSON formats
use serde_json::json; use serde_json::json;
use sglang_router_rs::tool_parser::{JsonParser, ToolParser}; use sglang_router_rs::tool_parser::{JsonParser, TokenConfig, ToolParser};
#[tokio::test] #[tokio::test]
async fn test_simple_json_tool_call() { async fn test_simple_json_tool_call() {
let parser = JsonParser::new(); let parser = JsonParser::new();
let input = r#"{"name": "get_weather", "arguments": {"location": "San Francisco"}}"#; let input = r#"{"name": "get_weather", "arguments": {"location": "San Francisco"}}"#;
let (_normal_text, tools) = parser.parse_complete(input).await.unwrap(); let (normal_text, tools) = parser.parse_complete(input).await.unwrap();
assert_eq!(tools.len(), 1); assert_eq!(tools.len(), 1);
assert_eq!(normal_text, "");
assert_eq!(tools[0].function.name, "get_weather"); assert_eq!(tools[0].function.name, "get_weather");
let args: serde_json::Value = serde_json::from_str(&tools[0].function.arguments).unwrap(); let args: serde_json::Value = serde_json::from_str(&tools[0].function.arguments).unwrap();
...@@ -21,13 +22,14 @@ async fn test_simple_json_tool_call() { ...@@ -21,13 +22,14 @@ async fn test_simple_json_tool_call() {
#[tokio::test] #[tokio::test]
async fn test_json_array_of_tools() { async fn test_json_array_of_tools() {
let parser = JsonParser::new(); let parser = JsonParser::new();
let input = r#"[ let input = r#"Hello, here are the results: [
{"name": "get_weather", "arguments": {"location": "SF"}}, {"name": "get_weather", "arguments": {"location": "SF"}},
{"name": "search", "arguments": {"query": "news"}} {"name": "search", "arguments": {"query": "news"}}
]"#; ]"#;
let (_normal_text, tools) = parser.parse_complete(input).await.unwrap(); let (normal_text, tools) = parser.parse_complete(input).await.unwrap();
assert_eq!(tools.len(), 2); assert_eq!(tools.len(), 2);
assert_eq!(normal_text, "Hello, here are the results: ");
assert_eq!(tools[0].function.name, "get_weather"); assert_eq!(tools[0].function.name, "get_weather");
assert_eq!(tools[1].function.name, "search"); assert_eq!(tools[1].function.name, "search");
} }
...@@ -37,8 +39,9 @@ async fn test_json_with_parameters_key() { ...@@ -37,8 +39,9 @@ async fn test_json_with_parameters_key() {
let parser = JsonParser::new(); let parser = JsonParser::new();
let input = r#"{"name": "calculate", "parameters": {"x": 10, "y": 20}}"#; let input = r#"{"name": "calculate", "parameters": {"x": 10, "y": 20}}"#;
let (_normal_text, tools) = parser.parse_complete(input).await.unwrap(); let (normal_text, tools) = parser.parse_complete(input).await.unwrap();
assert_eq!(tools.len(), 1); assert_eq!(tools.len(), 1);
assert_eq!(normal_text, "");
assert_eq!(tools[0].function.name, "calculate"); assert_eq!(tools[0].function.name, "calculate");
let args: serde_json::Value = serde_json::from_str(&tools[0].function.arguments).unwrap(); let args: serde_json::Value = serde_json::from_str(&tools[0].function.arguments).unwrap();
...@@ -51,8 +54,12 @@ async fn test_json_extraction_from_text() { ...@@ -51,8 +54,12 @@ async fn test_json_extraction_from_text() {
let parser = JsonParser::new(); let parser = JsonParser::new();
let input = r#"I'll help you with that. {"name": "search", "arguments": {"query": "rust"}} Let me search for that."#; let input = r#"I'll help you with that. {"name": "search", "arguments": {"query": "rust"}} Let me search for that."#;
let (_normal_text, tools) = parser.parse_complete(input).await.unwrap(); let (normal_text, tools) = parser.parse_complete(input).await.unwrap();
assert_eq!(tools.len(), 1); assert_eq!(tools.len(), 1);
assert_eq!(
normal_text,
"I'll help you with that. Let me search for that."
);
assert_eq!(tools[0].function.name, "search"); assert_eq!(tools[0].function.name, "search");
} }
...@@ -73,8 +80,9 @@ async fn test_json_with_nested_objects() { ...@@ -73,8 +80,9 @@ async fn test_json_with_nested_objects() {
} }
}"#; }"#;
let (_normal_text, tools) = parser.parse_complete(input).await.unwrap(); let (normal_text, tools) = parser.parse_complete(input).await.unwrap();
assert_eq!(tools.len(), 1); assert_eq!(tools.len(), 1);
assert_eq!(normal_text, "");
assert_eq!(tools[0].function.name, "update_config"); assert_eq!(tools[0].function.name, "update_config");
let args: serde_json::Value = serde_json::from_str(&tools[0].function.arguments).unwrap(); let args: serde_json::Value = serde_json::from_str(&tools[0].function.arguments).unwrap();
...@@ -87,8 +95,9 @@ async fn test_json_with_special_characters() { ...@@ -87,8 +95,9 @@ async fn test_json_with_special_characters() {
let parser = JsonParser::new(); let parser = JsonParser::new();
let input = r#"{"name": "echo", "arguments": {"text": "Line 1\nLine 2\tTabbed", "path": "C:\\Users\\test"}}"#; let input = r#"{"name": "echo", "arguments": {"text": "Line 1\nLine 2\tTabbed", "path": "C:\\Users\\test"}}"#;
let (_normal_text, tools) = parser.parse_complete(input).await.unwrap(); let (normal_text, tools) = parser.parse_complete(input).await.unwrap();
assert_eq!(tools.len(), 1); assert_eq!(tools.len(), 1);
assert_eq!(normal_text, "");
let args: serde_json::Value = serde_json::from_str(&tools[0].function.arguments).unwrap(); let args: serde_json::Value = serde_json::from_str(&tools[0].function.arguments).unwrap();
assert_eq!(args["text"], "Line 1\nLine 2\tTabbed"); assert_eq!(args["text"], "Line 1\nLine 2\tTabbed");
...@@ -100,8 +109,9 @@ async fn test_json_with_unicode() { ...@@ -100,8 +109,9 @@ async fn test_json_with_unicode() {
let parser = JsonParser::new(); let parser = JsonParser::new();
let input = r#"{"name": "translate", "arguments": {"text": "Hello 世界 🌍", "emoji": "😊"}}"#; let input = r#"{"name": "translate", "arguments": {"text": "Hello 世界 🌍", "emoji": "😊"}}"#;
let (_normal_text, tools) = parser.parse_complete(input).await.unwrap(); let (normal_text, tools) = parser.parse_complete(input).await.unwrap();
assert_eq!(tools.len(), 1); assert_eq!(tools.len(), 1);
assert_eq!(normal_text, "");
let args: serde_json::Value = serde_json::from_str(&tools[0].function.arguments).unwrap(); let args: serde_json::Value = serde_json::from_str(&tools[0].function.arguments).unwrap();
assert_eq!(args["text"], "Hello 世界 🌍"); assert_eq!(args["text"], "Hello 世界 🌍");
...@@ -113,8 +123,9 @@ async fn test_json_empty_arguments() { ...@@ -113,8 +123,9 @@ async fn test_json_empty_arguments() {
let parser = JsonParser::new(); let parser = JsonParser::new();
let input = r#"{"name": "ping", "arguments": {}}"#; let input = r#"{"name": "ping", "arguments": {}}"#;
let (_normal_text, tools) = parser.parse_complete(input).await.unwrap(); let (normal_text, tools) = parser.parse_complete(input).await.unwrap();
assert_eq!(tools.len(), 1); assert_eq!(tools.len(), 1);
assert_eq!(normal_text, "");
assert_eq!(tools[0].function.name, "ping"); assert_eq!(tools[0].function.name, "ping");
let args: serde_json::Value = serde_json::from_str(&tools[0].function.arguments).unwrap(); let args: serde_json::Value = serde_json::from_str(&tools[0].function.arguments).unwrap();
...@@ -127,8 +138,12 @@ async fn test_json_invalid_format() { ...@@ -127,8 +138,12 @@ async fn test_json_invalid_format() {
// Missing closing brace // Missing closing brace
let input = r#"{"name": "test", "arguments": {"key": "value""#; let input = r#"{"name": "test", "arguments": {"key": "value""#;
let (_normal_text, tools) = parser.parse_complete(input).await.unwrap(); let (normal_text, tools) = parser.parse_complete(input).await.unwrap();
assert_eq!(tools.len(), 0); assert_eq!(tools.len(), 0);
assert_eq!(
normal_text,
"{\"name\": \"test\", \"arguments\": {\"key\": \"value\""
);
// Not JSON at all // Not JSON at all
let input = "This is just plain text"; let input = "This is just plain text";
...@@ -145,3 +160,32 @@ async fn test_json_format_detection() { ...@@ -145,3 +160,32 @@ async fn test_json_format_detection() {
assert!(!parser.detect_format("plain text")); assert!(!parser.detect_format("plain text"));
assert!(!parser.detect_format(r#"{"key": "value"}"#)); // No name field assert!(!parser.detect_format(r#"{"key": "value"}"#)); // No name field
} }
/// A tool call wrapped in configured start/end tokens parses cleanly and
/// leaves no residual normal text.
#[tokio::test]
async fn test_parse_with_wrapper_tokens() {
    let config = TokenConfig {
        start_tokens: vec![String::from("<tool>")],
        end_tokens: vec![String::from("</tool>")],
        separator: String::from(", "),
    };
    let parser = JsonParser::with_config(config);
    let wrapped = r#"<tool>{"name": "test", "arguments": {}}</tool>"#;
    let (normal_text, calls) = parser.parse_complete(wrapped).await.unwrap();
    assert_eq!(normal_text, ""); // nothing outside the wrapper tokens
    assert_eq!(calls.len(), 1);
    assert_eq!(calls[0].function.name, "test");
}
/// When the payload after the start token is not valid JSON, the parser
/// yields no tools and falls back to the entire original text.
#[tokio::test]
async fn test_parse_with_start_token_invalid_json() {
    let config = TokenConfig {
        start_tokens: vec![String::from("<|python_tag|>")],
        end_tokens: vec![String::from("")],
        separator: String::from(";"),
    };
    let parser = JsonParser::with_config(config);
    let text = r#"Hello world <|python_tag|>this is not valid json at all"#;
    let (normal_text, calls) = parser.parse_complete(text).await.unwrap();
    assert!(calls.is_empty());
    assert_eq!(normal_text, text); // Should return entire original text when JSON parsing fails
}
...@@ -12,8 +12,9 @@ async fn test_kimik2_complete_parsing() { ...@@ -12,8 +12,9 @@ async fn test_kimik2_complete_parsing() {
<|tool_calls_section_end|> <|tool_calls_section_end|>
The weather in Tokyo is..."#; The weather in Tokyo is..."#;
let (_normal_text, tools) = parser.parse_complete(input).await.unwrap(); let (normal_text, tools) = parser.parse_complete(input).await.unwrap();
assert_eq!(tools.len(), 1); assert_eq!(tools.len(), 1);
assert_eq!(normal_text, "Let me help you with that.\n");
assert_eq!(tools[0].function.name, "get_weather"); assert_eq!(tools[0].function.name, "get_weather");
let args: serde_json::Value = serde_json::from_str(&tools[0].function.arguments).unwrap(); let args: serde_json::Value = serde_json::from_str(&tools[0].function.arguments).unwrap();
...@@ -30,8 +31,9 @@ async fn test_kimik2_multiple_tools() { ...@@ -30,8 +31,9 @@ async fn test_kimik2_multiple_tools() {
<|tool_call_begin|>functions.translate:1<|tool_call_argument_begin|>{"text": "Hello", "to": "ja"}<|tool_call_end|> <|tool_call_begin|>functions.translate:1<|tool_call_argument_begin|>{"text": "Hello", "to": "ja"}<|tool_call_end|>
<|tool_calls_section_end|>"#; <|tool_calls_section_end|>"#;
let (_normal_text, tools) = parser.parse_complete(input).await.unwrap(); let (normal_text, tools) = parser.parse_complete(input).await.unwrap();
assert_eq!(tools.len(), 2); assert_eq!(tools.len(), 2);
assert_eq!(normal_text, "");
assert_eq!(tools[0].function.name, "search"); assert_eq!(tools[0].function.name, "search");
assert_eq!(tools[1].function.name, "translate"); assert_eq!(tools[1].function.name, "translate");
} }
...@@ -44,8 +46,9 @@ async fn test_kimik2_with_whitespace() { ...@@ -44,8 +46,9 @@ async fn test_kimik2_with_whitespace() {
<|tool_call_begin|> functions.test:0 <|tool_call_argument_begin|> {"key": "value", "num": 42} <|tool_call_end|> <|tool_call_begin|> functions.test:0 <|tool_call_argument_begin|> {"key": "value", "num": 42} <|tool_call_end|>
<|tool_calls_section_end|>"#; <|tool_calls_section_end|>"#;
let (_normal_text, tools) = parser.parse_complete(input).await.unwrap(); let (normal_text, tools) = parser.parse_complete(input).await.unwrap();
assert_eq!(tools.len(), 1); assert_eq!(tools.len(), 1);
assert_eq!(normal_text, "");
assert_eq!(tools[0].function.name, "test"); assert_eq!(tools[0].function.name, "test");
let args: serde_json::Value = serde_json::from_str(&tools[0].function.arguments).unwrap(); let args: serde_json::Value = serde_json::from_str(&tools[0].function.arguments).unwrap();
...@@ -117,8 +120,9 @@ async fn test_kimik2_sequential_indices() { ...@@ -117,8 +120,9 @@ async fn test_kimik2_sequential_indices() {
<|tool_call_begin|>functions.third:2<|tool_call_argument_begin|>{"param": "c"}<|tool_call_end|> <|tool_call_begin|>functions.third:2<|tool_call_argument_begin|>{"param": "c"}<|tool_call_end|>
<|tool_calls_section_end|>"#; <|tool_calls_section_end|>"#;
let (_normal_text, tools) = parser.parse_complete(input).await.unwrap(); let (normal_text, tools) = parser.parse_complete(input).await.unwrap();
assert_eq!(tools.len(), 3); assert_eq!(tools.len(), 3);
assert_eq!(normal_text, "");
assert_eq!(tools[0].function.name, "first"); assert_eq!(tools[0].function.name, "first");
assert_eq!(tools[1].function.name, "second"); assert_eq!(tools[1].function.name, "second");
assert_eq!(tools[2].function.name, "third"); assert_eq!(tools[2].function.name, "third");
...@@ -134,12 +138,12 @@ async fn test_function_index_extraction() { ...@@ -134,12 +138,12 @@ async fn test_function_index_extraction() {
<|tool_call_begin|>functions.calc:1<|tool_call_argument_begin|>{"x": 10}<|tool_call_end|> <|tool_call_begin|>functions.calc:1<|tool_call_argument_begin|>{"x": 10}<|tool_call_end|>
<|tool_calls_section_end|>"#; <|tool_calls_section_end|>"#;
let (_normal_text, tools) = parser.parse_complete(input).await.unwrap(); let (normal_text, tools) = parser.parse_complete(input).await.unwrap();
assert_eq!(tools.len(), 2); assert_eq!(tools.len(), 2);
assert_eq!(normal_text, "Text before tool calls.\n");
assert_eq!(tools[0].function.name, "search"); assert_eq!(tools[0].function.name, "search");
assert_eq!(tools[1].function.name, "calc"); assert_eq!(tools[1].function.name, "calc");
// TODO: Verify indices are preserved: 0 and 1 // TODO: Verify indices are preserved: 0 and 1
// TODO: Verify normal text = "Text before tool calls."
} }
#[tokio::test] #[tokio::test]
......
...@@ -36,8 +36,9 @@ async fn test_llama_with_text_before() { ...@@ -36,8 +36,9 @@ async fn test_llama_with_text_before() {
let parser = LlamaParser::new(); let parser = LlamaParser::new();
let input = r#"Let me help you with that. <|python_tag|>{"name": "get_time", "arguments": {"timezone": "UTC"}}"#; let input = r#"Let me help you with that. <|python_tag|>{"name": "get_time", "arguments": {"timezone": "UTC"}}"#;
let (_normal_text, tools) = parser.parse_complete(input).await.unwrap(); let (normal_text, tools) = parser.parse_complete(input).await.unwrap();
assert_eq!(tools.len(), 1); assert_eq!(tools.len(), 1);
assert_eq!(normal_text, "Let me help you with that. ");
assert_eq!(tools[0].function.name, "get_time"); assert_eq!(tools[0].function.name, "get_time");
let args: serde_json::Value = serde_json::from_str(&tools[0].function.arguments).unwrap(); let args: serde_json::Value = serde_json::from_str(&tools[0].function.arguments).unwrap();
...@@ -99,8 +100,9 @@ async fn test_llama_invalid_json_after_tag() { ...@@ -99,8 +100,9 @@ async fn test_llama_invalid_json_after_tag() {
let parser = LlamaParser::new(); let parser = LlamaParser::new();
let input = r#"<|python_tag|>{"name": invalid}"#; let input = r#"<|python_tag|>{"name": invalid}"#;
let (_normal_text, tools) = parser.parse_complete(input).await.unwrap(); let (normal_text, tools) = parser.parse_complete(input).await.unwrap();
assert_eq!(tools.len(), 0); assert_eq!(tools.len(), 0);
assert_eq!(normal_text, "<|python_tag|>{\"name\": invalid}");
} }
#[tokio::test] #[tokio::test]
......
...@@ -11,8 +11,9 @@ async fn test_mistral_single_tool() { ...@@ -11,8 +11,9 @@ async fn test_mistral_single_tool() {
let input = r#"Let me search for that. let input = r#"Let me search for that.
[TOOL_CALLS] [{"name": "search_web", "arguments": {"query": "latest news", "max_results": 5}}]"#; [TOOL_CALLS] [{"name": "search_web", "arguments": {"query": "latest news", "max_results": 5}}]"#;
let (_normal_text, tools) = parser.parse_complete(input).await.unwrap(); let (normal_text, tools) = parser.parse_complete(input).await.unwrap();
assert_eq!(tools.len(), 1); assert_eq!(tools.len(), 1);
assert_eq!(normal_text, "Let me search for that.\n");
assert_eq!(tools[0].function.name, "search_web"); assert_eq!(tools[0].function.name, "search_web");
let args: serde_json::Value = serde_json::from_str(&tools[0].function.arguments).unwrap(); let args: serde_json::Value = serde_json::from_str(&tools[0].function.arguments).unwrap();
...@@ -29,8 +30,9 @@ async fn test_mistral_multiple_tools() { ...@@ -29,8 +30,9 @@ async fn test_mistral_multiple_tools() {
{"name": "search_news", "arguments": {"query": "AI developments", "limit": 10}} {"name": "search_news", "arguments": {"query": "AI developments", "limit": 10}}
]"#; ]"#;
let (_normal_text, tools) = parser.parse_complete(input).await.unwrap(); let (normal_text, tools) = parser.parse_complete(input).await.unwrap();
assert_eq!(tools.len(), 2); assert_eq!(tools.len(), 2);
assert_eq!(normal_text, "I'll help you with both tasks.\n");
assert_eq!(tools[0].function.name, "get_weather"); assert_eq!(tools[0].function.name, "get_weather");
let args0: serde_json::Value = serde_json::from_str(&tools[0].function.arguments).unwrap(); let args0: serde_json::Value = serde_json::from_str(&tools[0].function.arguments).unwrap();
...@@ -47,8 +49,9 @@ async fn test_mistral_nested_json() { ...@@ -47,8 +49,9 @@ async fn test_mistral_nested_json() {
let input = r#"Processing complex data. let input = r#"Processing complex data.
[TOOL_CALLS] [{"name": "process_data", "arguments": {"config": {"nested": {"value": [1, 2, 3]}}, "enabled": true}}]"#; [TOOL_CALLS] [{"name": "process_data", "arguments": {"config": {"nested": {"value": [1, 2, 3]}}, "enabled": true}}]"#;
let (_normal_text, tools) = parser.parse_complete(input).await.unwrap(); let (normal_text, tools) = parser.parse_complete(input).await.unwrap();
assert_eq!(tools.len(), 1); assert_eq!(tools.len(), 1);
assert_eq!(normal_text, "Processing complex data.\n");
let args: serde_json::Value = serde_json::from_str(&tools[0].function.arguments).unwrap(); let args: serde_json::Value = serde_json::from_str(&tools[0].function.arguments).unwrap();
assert_eq!(args["config"]["nested"]["value"], json!([1, 2, 3])); assert_eq!(args["config"]["nested"]["value"], json!([1, 2, 3]));
...@@ -146,8 +149,9 @@ async fn test_mistral_real_world_output() { ...@@ -146,8 +149,9 @@ async fn test_mistral_real_world_output() {
Let me execute these searches for you."#; Let me execute these searches for you."#;
let (_normal_text, tools) = parser.parse_complete(input).await.unwrap(); let (normal_text, tools) = parser.parse_complete(input).await.unwrap();
assert_eq!(tools.len(), 2); assert_eq!(tools.len(), 2);
assert_eq!(normal_text, "I'll search for information about Rust programming and check the weather in San Francisco.\n\n");
assert_eq!(tools[0].function.name, "web_search"); assert_eq!(tools[0].function.name, "web_search");
assert_eq!(tools[1].function.name, "get_weather"); assert_eq!(tools[1].function.name, "get_weather");
} }
...@@ -165,8 +165,9 @@ async fn test_pythonic_real_world_llama4() { ...@@ -165,8 +165,9 @@ async fn test_pythonic_real_world_llama4() {
These functions will provide the information you need."#; These functions will provide the information you need."#;
let (_normal_text, tools) = parser.parse_complete(input).await.unwrap(); let (normal_text, tools) = parser.parse_complete(input).await.unwrap();
assert_eq!(tools.len(), 3); assert_eq!(tools.len(), 3);
assert_eq!(normal_text, "I'll help you with multiple tasks. Let me search for information and perform calculations.\n\n\n\nThese functions will provide the information you need.");
assert_eq!(tools[0].function.name, "web_search"); assert_eq!(tools[0].function.name, "web_search");
assert_eq!(tools[1].function.name, "calculate"); assert_eq!(tools[1].function.name, "calculate");
assert_eq!(tools[2].function.name, "get_weather"); assert_eq!(tools[2].function.name, "get_weather");
......
...@@ -32,8 +32,9 @@ async fn test_qwen_multiple_sequential_tools() { ...@@ -32,8 +32,9 @@ async fn test_qwen_multiple_sequential_tools() {
{"name": "translate", "arguments": {"text": "Hello", "to": "zh"}} {"name": "translate", "arguments": {"text": "Hello", "to": "zh"}}
</tool_call>"#; </tool_call>"#;
let (_normal_text, tools) = parser.parse_complete(input).await.unwrap(); let (normal_text, tools) = parser.parse_complete(input).await.unwrap();
assert_eq!(tools.len(), 2); assert_eq!(tools.len(), 2);
assert_eq!(normal_text, "Let me help you with that.\n");
assert_eq!(tools[0].function.name, "search"); assert_eq!(tools[0].function.name, "search");
assert_eq!(tools[1].function.name, "translate"); assert_eq!(tools[1].function.name, "translate");
} }
...@@ -79,8 +80,9 @@ Now I'll translate something. ...@@ -79,8 +80,9 @@ Now I'll translate something.
</tool_call> </tool_call>
Done!"#; Done!"#;
let (_normal_text, tools) = parser.parse_complete(input).await.unwrap(); let (normal_text, tools) = parser.parse_complete(input).await.unwrap();
assert_eq!(tools.len(), 2); assert_eq!(tools.len(), 2);
assert_eq!(normal_text, "First, let me search for information.\n");
assert_eq!(tools[0].function.name, "search"); assert_eq!(tools[0].function.name, "search");
assert_eq!(tools[1].function.name, "translate"); assert_eq!(tools[1].function.name, "translate");
} }
...@@ -171,8 +173,12 @@ Let me also calculate something for you: ...@@ -171,8 +173,12 @@ Let me also calculate something for you:
These tools will provide the information you need."#; These tools will provide the information you need."#;
let (_normal_text, tools) = parser.parse_complete(input).await.unwrap(); let (normal_text, tools) = parser.parse_complete(input).await.unwrap();
assert_eq!(tools.len(), 2); assert_eq!(tools.len(), 2);
assert_eq!(
normal_text,
"I'll help you search for information and perform calculations.\n\n"
);
assert_eq!(tools[0].function.name, "web_search"); assert_eq!(tools[0].function.name, "web_search");
assert_eq!(tools[1].function.name, "calculator"); assert_eq!(tools[1].function.name, "calculator");
......
...@@ -15,8 +15,9 @@ async fn test_step3_complete_parsing() { ...@@ -15,8 +15,9 @@ async fn test_step3_complete_parsing() {
<|tool_calls_end|> <|tool_calls_end|>
Here are the results..."#; Here are the results..."#;
let (_normal_text, tools) = parser.parse_complete(input).await.unwrap(); let (normal_text, tools) = parser.parse_complete(input).await.unwrap();
assert_eq!(tools.len(), 1); assert_eq!(tools.len(), 1);
assert_eq!(normal_text, "Let me help you.\n");
assert_eq!(tools[0].function.name, "search"); assert_eq!(tools[0].function.name, "search");
let args: serde_json::Value = serde_json::from_str(&tools[0].function.arguments).unwrap(); let args: serde_json::Value = serde_json::from_str(&tools[0].function.arguments).unwrap();
...@@ -174,8 +175,9 @@ async fn test_steptml_format() { ...@@ -174,8 +175,9 @@ async fn test_steptml_format() {
</steptml:invoke><|tool_call_end|> </steptml:invoke><|tool_call_end|>
<|tool_calls_end|>Text after."#; <|tool_calls_end|>Text after."#;
let (_normal_text, tools) = parser.parse_complete(input).await.unwrap(); let (normal_text, tools) = parser.parse_complete(input).await.unwrap();
assert_eq!(tools.len(), 1); assert_eq!(tools.len(), 1);
assert_eq!(normal_text, "Text before.\n");
assert_eq!(tools[0].function.name, "search"); assert_eq!(tools[0].function.name, "search");
let args: serde_json::Value = serde_json::from_str(&tools[0].function.arguments).unwrap(); let args: serde_json::Value = serde_json::from_str(&tools[0].function.arguments).unwrap();
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment