Unverified Commit 1ed877fa authored by Ayush Agarwal's avatar Ayush Agarwal Committed by GitHub
Browse files

feat: add mistral and phi4 tool parser (#2510)

parent c5d9d267
...@@ -74,10 +74,10 @@ fn handle_single_token_tool_calls(input: &str, start_token: &str) -> String { ...@@ -74,10 +74,10 @@ fn handle_single_token_tool_calls(input: &str, start_token: &str) -> String {
continue; continue;
} }
// Only consider segments that start like JSON // Only consider segments that start like JSON
if s.starts_with('{') || s.starts_with('[') { if s.starts_with('{') {
// Trim trailing non-JSON by cutting at the last closing brace/bracket // Trim trailing non-JSON by cutting at the last closing brace/bracket
if let Some(pos) = s.rfind(['}', ']']) { if let Some(pos) = s.rfind('}') {
let candidate = &s[..=pos]; let candidate = &s[..=pos].trim();
// Keep only valid JSON candidates // Keep only valid JSON candidates
if serde_json::from_str::<serde_json::Value>(candidate).is_ok() { if serde_json::from_str::<serde_json::Value>(candidate).is_ok() {
items.push(candidate.to_string()); items.push(candidate.to_string());
...@@ -85,10 +85,16 @@ fn handle_single_token_tool_calls(input: &str, start_token: &str) -> String { ...@@ -85,10 +85,16 @@ fn handle_single_token_tool_calls(input: &str, start_token: &str) -> String {
} }
} }
} }
if items.is_empty() { if items.is_empty() {
// Remove everything up to and including the first occurrence of the start token
if let Some(idx) = input.find(start_token) {
let rest = &input[idx + start_token.len()..];
return rest.trim_start().to_string();
} else {
// Shouldn't happen because we checked contains() above, but be defensive
return input.to_string(); return input.to_string();
} }
}
format!("[{}]", items.join(",")) format!("[{}]", items.join(","))
} }
...@@ -167,7 +173,7 @@ pub fn try_tool_call_parse_json( ...@@ -167,7 +173,7 @@ pub fn try_tool_call_parse_json(
}; };
} }
// Convert json to &str if it's a String, otherwise keep as &str // Convert json (String) to &str
let json = json.as_str(); let json = json.as_str();
// Anonymous function to attempt deserialization into a known representation // Anonymous function to attempt deserialization into a known representation
......
...@@ -99,6 +99,28 @@ impl ToolCallConfig { ...@@ -99,6 +99,28 @@ impl ToolCallConfig {
}, },
} }
} }
pub fn mistral() -> Self {
Self {
format: ToolCallParserType::Json,
json: JsonParserConfig {
tool_call_start_tokens: vec!["[TOOL_CALLS]".to_string()],
tool_call_end_tokens: vec!["".to_string()],
..Default::default()
},
}
}
pub fn phi4() -> Self {
Self {
format: ToolCallParserType::Json,
json: JsonParserConfig {
tool_call_start_tokens: vec!["functools".to_string()],
tool_call_end_tokens: vec!["".to_string()],
..Default::default()
},
}
}
} }
/// Configuration for parsing tool calls with different formats /// Configuration for parsing tool calls with different formats
...@@ -142,6 +164,8 @@ pub fn detect_and_parse_tool_call( ...@@ -142,6 +164,8 @@ pub fn detect_and_parse_tool_call(
parser_map.insert("hermes", ToolCallConfig::hermes()); parser_map.insert("hermes", ToolCallConfig::hermes());
parser_map.insert("nemotron_deci", ToolCallConfig::nemotron_deci()); parser_map.insert("nemotron_deci", ToolCallConfig::nemotron_deci());
parser_map.insert("llama3_json", ToolCallConfig::llama3_json()); parser_map.insert("llama3_json", ToolCallConfig::llama3_json());
parser_map.insert("mistral", ToolCallConfig::mistral());
parser_map.insert("phi4", ToolCallConfig::phi4());
parser_map.insert("default", ToolCallConfig::default()); // Add default key parser_map.insert("default", ToolCallConfig::default()); // Add default key
// Handle None or empty string by defaulting to "default" // Handle None or empty string by defaulting to "default"
...@@ -305,6 +329,36 @@ Okay, the user is asking for the weather in San Francisco in Fahrenheit. Let me ...@@ -305,6 +329,36 @@ Okay, the user is asking for the weather in San Francisco in Fahrenheit. Let me
assert_eq!(args["unit"], "fahrenheit"); assert_eq!(args["unit"], "fahrenheit");
} }
#[test]
fn test_nvidia_llama3_nemotron_super_49b_with_function_array_with_new_lines() {
let input = r#"<think>
Okay, the user is asking for the weather in San Francisco in Fahrenheit. Let me check the tools available.
</think>
<TOOLCALL>
[{"name": "get_weather",
"arguments": {"location": "San Francisco, CA",
"unit": "fahrenheit"}},
{"name": "get_weather",
"arguments":
{"location": "New York, NY",
"unit": "fahrenheit"}}]
</TOOLCALL>
"#;
let config = ToolCallConfig::nemotron_deci();
let result = try_tool_call_parse(input, &config).unwrap();
assert!(!result.is_empty());
assert_eq!(result.len(), 2);
let (name, args) = extract_name_and_args(result[0].clone());
assert_eq!(name, "get_weather");
assert_eq!(args["location"], "San Francisco, CA");
assert_eq!(args["unit"], "fahrenheit");
let (name, args) = extract_name_and_args(result[1].clone());
assert_eq!(name, "get_weather");
assert_eq!(args["location"], "New York, NY");
assert_eq!(args["unit"], "fahrenheit");
}
#[test] #[test]
fn test_qwen_qwq_32b_simple() { fn test_qwen_qwq_32b_simple() {
let input = r#"<tool_call> let input = r#"<tool_call>
...@@ -356,6 +410,33 @@ Okay, the user is asking for the weather in San Francisco in Fahrenheit. Let me ...@@ -356,6 +410,33 @@ Okay, the user is asking for the weather in San Francisco in Fahrenheit. Let me
assert_eq!(args["unit"], "fahrenheit"); assert_eq!(args["unit"], "fahrenheit");
} }
#[test]
fn test_qwen_qwq_32b_multiple_tool_calls_with_new_lines() {
let input = r#"<tool_call>
{"name": "get_weather",
"arguments": {"location": "San Francisco, CA",
"unit": "fahrenheit"}}
</tool_call>
<tool_call>
{"name": "get_weather", "arguments":
{"location": "New York, NY", "unit":
"fahrenheit"}}
</tool_call>
"#;
let config = ToolCallConfig::hermes();
let result = try_tool_call_parse(input, &config).unwrap();
assert!(!result.is_empty());
assert_eq!(result.len(), 2);
let (name, args) = extract_name_and_args(result[0].clone());
assert_eq!(name, "get_weather");
assert_eq!(args["location"], "San Francisco, CA");
assert_eq!(args["unit"], "fahrenheit");
let (name, args) = extract_name_and_args(result[1].clone());
assert_eq!(name, "get_weather");
assert_eq!(args["location"], "New York, NY");
assert_eq!(args["unit"], "fahrenheit");
}
#[test] #[test]
#[ignore] #[ignore]
fn test_ibm_granite_40_tiny_preview_simple() { fn test_ibm_granite_40_tiny_preview_simple() {
...@@ -379,7 +460,6 @@ Okay, the user is asking for the weather in San Francisco in Fahrenheit. Let me ...@@ -379,7 +460,6 @@ Okay, the user is asking for the weather in San Francisco in Fahrenheit. Let me
} }
#[test] #[test]
#[ignore]
fn test_mistralai_mistral_7b_instruct_v03_simple() { fn test_mistralai_mistral_7b_instruct_v03_simple() {
let input = r#" [{"name": "get_weather", "arguments": {"location": "San Francisco, CA", "unit": "fahrenheit"}}]"#; let input = r#" [{"name": "get_weather", "arguments": {"location": "San Francisco, CA", "unit": "fahrenheit"}}]"#;
let config = ToolCallConfig { let config = ToolCallConfig {
...@@ -400,6 +480,146 @@ Okay, the user is asking for the weather in San Francisco in Fahrenheit. Let me ...@@ -400,6 +480,146 @@ Okay, the user is asking for the weather in San Francisco in Fahrenheit. Let me
assert_eq!(args["unit"], "fahrenheit"); assert_eq!(args["unit"], "fahrenheit");
} }
#[test]
fn test_mistralai_mistral_7b_instruct_v03_simple_with_new_lines() {
let input = r#"
[{"name": "get_weather",
"arguments": {"location":
"San Francisco, CA",
"unit": "fahrenheit"}}]
"#;
let config = ToolCallConfig {
format: ToolCallParserType::Json,
json: JsonParserConfig {
tool_call_start_tokens: vec![],
tool_call_end_tokens: vec![],
arguments_keys: vec!["arguments".to_string()],
..Default::default()
},
};
let result = try_tool_call_parse(input, &config).unwrap();
assert!(!result.is_empty());
assert_eq!(result.len(), 1);
let (name, args) = extract_name_and_args(result[0].clone());
assert_eq!(name, "get_weather");
assert_eq!(args["location"], "San Francisco, CA");
assert_eq!(args["unit"], "fahrenheit");
}
#[test]
fn test_mistralai_mistral_7b_instruct_v03_multiple() {
let input = r#" [{"name": "get_weather", "arguments": {"location": "San Francisco, CA", "unit": "fahrenheit"}}, {"name": "get_weather", "arguments": {"location": "New York, NY", "unit": "fahrenheit"}}]"#;
let config = ToolCallConfig::mistral();
let result = try_tool_call_parse(input, &config).unwrap();
assert!(!result.is_empty());
assert_eq!(result.len(), 2);
let (name, args) = extract_name_and_args(result[0].clone());
assert_eq!(name, "get_weather");
assert_eq!(args["location"], "San Francisco, CA");
assert_eq!(args["unit"], "fahrenheit");
let (name, args) = extract_name_and_args(result[1].clone());
assert_eq!(name, "get_weather");
assert_eq!(args["location"], "New York, NY");
assert_eq!(args["unit"], "fahrenheit");
}
#[test]
fn test_mistralai_mistral_7b_instruct_v03_multiple_with_new_lines() {
let input = r#"
[{"name": "get_weather",
"arguments": {"location": "San Francisco, CA",
"unit": "fahrenheit"}}, {"name": "get_weather", "arguments":
{"location": "New York, NY", "unit": "fahrenheit"}}]
"#;
let config = ToolCallConfig::mistral();
let result = try_tool_call_parse(input, &config).unwrap();
assert!(!result.is_empty());
assert_eq!(result.len(), 2);
let (name, args) = extract_name_and_args(result[0].clone());
assert_eq!(name, "get_weather");
assert_eq!(args["location"], "San Francisco, CA");
assert_eq!(args["unit"], "fahrenheit");
let (name, args) = extract_name_and_args(result[1].clone());
assert_eq!(name, "get_weather");
assert_eq!(args["location"], "New York, NY");
assert_eq!(args["unit"], "fahrenheit");
}
#[test]
fn test_mistralai_mistral_7b_instruct_v03_single_with_start_token() {
let input = r#"[TOOL_CALLS] [{"name": "get_weather", "arguments": {"location": "San Francisco, CA", "unit": "fahrenheit"}}]"#;
let config = ToolCallConfig::mistral();
let result = try_tool_call_parse(input, &config).unwrap();
assert!(!result.is_empty());
assert_eq!(result.len(), 1);
let (name, args) = extract_name_and_args(result[0].clone());
assert_eq!(name, "get_weather");
assert_eq!(args["location"], "San Francisco, CA");
assert_eq!(args["unit"], "fahrenheit");
}
#[test]
fn test_mistralai_mistral_7b_instruct_v03_single_with_start_tokenwith_new_lines() {
let input = r#"
[TOOL_CALLS]
[{"name": "get_weather",
"arguments": {"location":
"San Francisco, CA",
"unit": "fahrenheit"}}]
"#;
let config = ToolCallConfig::mistral();
let result = try_tool_call_parse(input, &config).unwrap();
assert!(!result.is_empty());
assert_eq!(result.len(), 1);
let (name, args) = extract_name_and_args(result[0].clone());
assert_eq!(name, "get_weather");
assert_eq!(args["location"], "San Francisco, CA");
assert_eq!(args["unit"], "fahrenheit");
}
#[test]
fn test_mistralai_mistral_7b_instruct_v03_single_with_start_token_multiple() {
let input = r#"[TOOL_CALLS] [{"name": "get_weather", "arguments": {"location": "San Francisco, CA", "unit": "fahrenheit"}}, {"name": "get_weather", "arguments": {"location": "New York, NY", "unit": "fahrenheit"}}]"#;
let config = ToolCallConfig::mistral();
let result = try_tool_call_parse(input, &config).unwrap();
assert!(!result.is_empty());
assert_eq!(result.len(), 2);
let (name, args) = extract_name_and_args(result[0].clone());
assert_eq!(name, "get_weather");
assert_eq!(args["location"], "San Francisco, CA");
assert_eq!(args["unit"], "fahrenheit");
let (name, args) = extract_name_and_args(result[1].clone());
assert_eq!(name, "get_weather");
assert_eq!(args["location"], "New York, NY");
assert_eq!(args["unit"], "fahrenheit");
}
#[test]
fn test_mistralai_mistral_7b_instruct_v03_single_with_start_token_multiple_with_new_lines() {
let input = r#"
[TOOL_CALLS]
[{"name": "get_weather",
"arguments": {"location":
"San Francisco, CA",
"unit": "fahrenheit"}},
{"name": "get_weather", "arguments":
{"location": "New York, NY", "unit":
"fahrenheit"}}]
"#;
let config = ToolCallConfig::mistral();
let result = try_tool_call_parse(input, &config).unwrap();
assert!(!result.is_empty());
assert_eq!(result.len(), 2);
let (name, args) = extract_name_and_args(result[0].clone());
assert_eq!(name, "get_weather");
assert_eq!(args["location"], "San Francisco, CA");
assert_eq!(args["unit"], "fahrenheit");
let (name, args) = extract_name_and_args(result[1].clone());
assert_eq!(name, "get_weather");
assert_eq!(args["location"], "New York, NY");
assert_eq!(args["unit"], "fahrenheit");
}
#[test] #[test]
fn test_meta_llama_llama31_8b_instruct_simple() { fn test_meta_llama_llama31_8b_instruct_simple() {
let input = r#"{"name": "get_weather", "parameters": {"location": "San Francisco, CA", "unit": "fahrenheit"}}"#; let input = r#"{"name": "get_weather", "parameters": {"location": "San Francisco, CA", "unit": "fahrenheit"}}"#;
...@@ -412,6 +632,21 @@ Okay, the user is asking for the weather in San Francisco in Fahrenheit. Let me ...@@ -412,6 +632,21 @@ Okay, the user is asking for the weather in San Francisco in Fahrenheit. Let me
assert_eq!(args["unit"], "fahrenheit"); assert_eq!(args["unit"], "fahrenheit");
} }
#[test]
fn test_meta_llama_llama31_8b_instruct_with_new_lines() {
let input = r#"
{"name": "get_weather",
"parameters": {"location": "San Francisco, CA", "unit": "fahrenheit"}}
"#;
let result = detect_and_parse_tool_call(input, Some("llama3_json")).unwrap();
assert!(!result.is_empty());
assert_eq!(result.len(), 1);
let (name, args) = extract_name_and_args(result[0].clone());
assert_eq!(name, "get_weather");
assert_eq!(args["location"], "San Francisco, CA");
assert_eq!(args["unit"], "fahrenheit");
}
#[test] #[test]
fn test_meta_llama_llama31_8b_instruct_with_python_tag() { fn test_meta_llama_llama31_8b_instruct_with_python_tag() {
let input = r#"<|python_tag|>{ "name": "get_weather", "parameters": {"location": "San Francisco, CA", "unit": "fahrenheit" } }"#; let input = r#"<|python_tag|>{ "name": "get_weather", "parameters": {"location": "San Francisco, CA", "unit": "fahrenheit" } }"#;
...@@ -425,8 +660,28 @@ Okay, the user is asking for the weather in San Francisco in Fahrenheit. Let me ...@@ -425,8 +660,28 @@ Okay, the user is asking for the weather in San Francisco in Fahrenheit. Let me
} }
#[test] #[test]
fn test_meta_llama_llama31_8b_instruct_with_python_tag_multiple() { fn test_meta_llama_llama31_8b_instruct_with_python_tag_with_new_lines() {
let input = r#"<|python_tag|>{ "name": "get_weather", "parameters": {"location": "San Francisco, CA", "unit": "fahrenheit" } }<|python_tag|>{ "name": "get_weather", "parameters": {"location": "New York, NY", "unit": "fahrenheit" } }"#; let input = r#"
<|python_tag|>
{"name": "get_weather", "parameters": {"location": "San Francisco, CA", "unit": "fahrenheit"}}
"#;
let result = detect_and_parse_tool_call(input, Some("llama3_json")).unwrap();
assert!(!result.is_empty());
assert_eq!(result.len(), 1);
let (name, args) = extract_name_and_args(result[0].clone());
assert_eq!(name, "get_weather");
assert_eq!(args["location"], "San Francisco, CA");
assert_eq!(args["unit"], "fahrenheit");
}
#[test]
fn test_meta_llama_llama31_8b_instruct_with_python_tag_multiple_with_new_lines() {
let input = r#"
<|python_tag|>
{"name": "get_weather", "parameters": {"location": "San Francisco, CA", "unit": "fahrenheit" }}
<|python_tag|>
{"name": "get_weather", "parameters": {"location": "New York, NY", "unit": "fahrenheit" }}
"#;
let result = detect_and_parse_tool_call(input, Some("llama3_json")).unwrap(); let result = detect_and_parse_tool_call(input, Some("llama3_json")).unwrap();
assert!(!result.is_empty()); assert!(!result.is_empty());
assert_eq!(result.len(), 2); assert_eq!(result.len(), 2);
...@@ -564,6 +819,42 @@ Remember, San Francisco weather can be quite unpredictable, particularly with it ...@@ -564,6 +819,42 @@ Remember, San Francisco weather can be quite unpredictable, particularly with it
assert_eq!(args["unit"], "fahrenheit"); assert_eq!(args["unit"], "fahrenheit");
} }
#[test]
fn test_detect_and_parse_tool_call_default_parser_llama3_json_with_python_tag_with_new_lines() {
let input = r#"
<|python_tag|>
{"name":
"get_weather",
"arguments":
{"location": "San Francisco, CA",
"unit": "fahrenheit" }}
"#;
let result = detect_and_parse_tool_call(input, None).unwrap();
assert!(!result.is_empty());
assert_eq!(result.len(), 1);
let (name, args) = extract_name_and_args(result[0].clone());
assert_eq!(name, "get_weather");
assert_eq!(args["location"], "San Francisco, CA");
assert_eq!(args["unit"], "fahrenheit");
}
#[test]
fn test_detect_and_parse_tool_call_default_parser_llama3_json_without_python_tag_multiple_with_new_lines(
) {
let input = r#"
{"name": "get_weather", "arguments":
{"location": "San Francisco, CA",
"unit": "fahrenheit" }}
"#;
let result = detect_and_parse_tool_call(input, None).unwrap();
assert!(!result.is_empty());
assert_eq!(result.len(), 1);
let (name, args) = extract_name_and_args(result[0].clone());
assert_eq!(name, "get_weather");
assert_eq!(args["location"], "San Francisco, CA");
assert_eq!(args["unit"], "fahrenheit");
}
#[test] #[test]
fn test_detect_and_parse_tool_call_default_parser_llama3_json_without_python_tag() { fn test_detect_and_parse_tool_call_default_parser_llama3_json_without_python_tag() {
let input = r#"{ "name": "get_weather", "arguments": {"location": "San Francisco, CA", "unit": "fahrenheit" } }"#; let input = r#"{ "name": "get_weather", "arguments": {"location": "San Francisco, CA", "unit": "fahrenheit" } }"#;
...@@ -575,4 +866,59 @@ Remember, San Francisco weather can be quite unpredictable, particularly with it ...@@ -575,4 +866,59 @@ Remember, San Francisco weather can be quite unpredictable, particularly with it
assert_eq!(args["location"], "San Francisco, CA"); assert_eq!(args["location"], "San Francisco, CA");
assert_eq!(args["unit"], "fahrenheit"); assert_eq!(args["unit"], "fahrenheit");
} }
#[test]
fn test_phi4_single_function_call() {
let input =
r#"functools[{"name": "get_country_capital", "arguments": {"country": "Poland"}}]"#;
let result = detect_and_parse_tool_call(input, Some("phi4")).unwrap();
assert_eq!(result.len(), 1);
let (name, args) = extract_name_and_args(result[0].clone());
assert_eq!(name, "get_country_capital");
assert_eq!(args["country"], "Poland");
}
#[test]
fn test_phi4_multiple_function_calls_simple_arguments() {
let input = r#"functools[
{"name": "get_country_capital", "arguments": {"country": "Poland"}},
{"name": "get_population", "arguments": {"city": "Warsaw"}}
]"#;
let result = detect_and_parse_tool_call(input, Some("phi4")).unwrap();
assert_eq!(result.len(), 2);
let (name1, args1) = extract_name_and_args(result[0].clone());
assert_eq!(name1, "get_country_capital");
assert_eq!(args1["country"], "Poland");
let (name2, args2) = extract_name_and_args(result[1].clone());
assert_eq!(name2, "get_population");
assert_eq!(args2["city"], "Warsaw");
}
#[test]
fn test_phi4_single_function_call_nested_json_arguments() {
let input = r#"functools[{"name": "get_weather_forecast", "arguments":
{"location": {"city": "San Francisco",
"state": "CA"}, "date": "2023-10-05"}}]"#;
let result = detect_and_parse_tool_call(input, Some("phi4")).unwrap();
assert_eq!(result.len(), 1);
let (name, args) = extract_name_and_args(result[0].clone());
assert_eq!(name, "get_weather_forecast");
assert_eq!(args["date"], "2023-10-05");
assert_eq!(args["location"]["city"], "San Francisco");
assert_eq!(args["location"]["state"], "CA");
}
#[test]
fn test_phi4_function_call_with_parameters_instead_of_arguments() {
let input = r#"functools[{"name": "calculate_distance",
"parameters": {"from": "New York", "to": "Los Angeles"}}]"#;
let result = detect_and_parse_tool_call(input, Some("phi4")).unwrap();
assert_eq!(result.len(), 1);
let (name, args) = extract_name_and_args(result[0].clone());
assert_eq!(name, "calculate_distance");
assert_eq!(args["from"], "New York");
assert_eq!(args["to"], "Los Angeles");
}
} }
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment