Unverified commit d39d676b, authored by Ayush Agarwal and committed by GitHub
Browse files

feat: added support for pythonic tool parser (#2788)


Signed-off-by: ayushag <ayushag@nvidia.com>
parent 58d2699d
......@@ -1726,13 +1726,34 @@ dependencies = [
"syn 2.0.106",
]
[[package]]
name = "derive_more"
version = "1.0.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "4a9b99b9cbbe49445b21764dc0625032a89b145a2642e67603e1c936f5458d05"
dependencies = [
"derive_more-impl 1.0.0",
]
[[package]]
name = "derive_more"
version = "2.0.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "093242cf7570c207c83073cf82f79706fe7b8317e98620a47d5be7c3d8497678"
dependencies = [
"derive_more-impl",
"derive_more-impl 2.0.1",
]
[[package]]
name = "derive_more-impl"
version = "1.0.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "cb7330aeadfbe296029522e6c40f315320aba36fc43a5b3632f3795348f3bd22"
dependencies = [
"proc-macro2",
"quote",
"syn 2.0.106",
"unicode-xid",
]
[[package]]
......@@ -2044,8 +2065,10 @@ dependencies = [
"anyhow",
"dynamo-async-openai",
"lazy_static",
"num-traits",
"openai-harmony",
"regex",
"rustpython-parser",
"serde",
"serde_json",
"tracing",
......@@ -3143,6 +3166,9 @@ name = "hashbrown"
version = "0.14.5"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "e5274423e17b7c9fc20b6e7e208532f9b19825d82dfd615708b70edd83df41f1"
dependencies = [
"ahash",
]
[[package]]
name = "hashbrown"
......@@ -3773,6 +3799,18 @@ dependencies = [
"serde",
]
[[package]]
name = "is-macro"
version = "0.3.7"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "1d57a3e447e24c22647738e4607f1df1e0ec6f72e16182c4cd199f647cdfb0e4"
dependencies = [
"heck 0.5.0",
"proc-macro2",
"quote",
"syn 2.0.106",
]
[[package]]
name = "is_terminal_polyfill"
version = "1.70.1"
......@@ -3797,6 +3835,15 @@ dependencies = [
"either",
]
[[package]]
name = "itertools"
version = "0.11.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "b1c173a5686ce8bfa551b3563d0c2170bf24ca44da99c7ca4bfdab5418c3fe57"
dependencies = [
"either",
]
[[package]]
name = "itertools"
version = "0.12.1"
......@@ -3930,6 +3977,12 @@ dependencies = [
"winapi-build",
]
[[package]]
name = "lalrpop-util"
version = "0.20.2"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "507460a910eb7b32ee961886ff48539633b788a36b65692b95f225b844c82553"
[[package]]
name = "lazy_static"
version = "1.5.0"
......@@ -4148,6 +4201,64 @@ version = "0.2.2"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "670fdfda89751bc4a84ac13eaa63e205cf0fd22b4c9a5fbfa085b63c1f1d3a30"
[[package]]
name = "malachite"
version = "0.4.22"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "2fbdf9cb251732db30a7200ebb6ae5d22fe8e11397364416617d2c2cf0c51cb5"
dependencies = [
"malachite-base",
"malachite-nz",
"malachite-q",
]
[[package]]
name = "malachite-base"
version = "0.4.22"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "5ea0ed76adf7defc1a92240b5c36d5368cfe9251640dcce5bd2d0b7c1fd87aeb"
dependencies = [
"hashbrown 0.14.5",
"itertools 0.11.0",
"libm",
"ryu",
]
[[package]]
name = "malachite-bigint"
version = "0.2.3"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "d149aaa2965d70381709d9df4c7ee1fc0de1c614a4efc2ee356f5e43d68749f8"
dependencies = [
"derive_more 1.0.0",
"malachite",
"num-integer",
"num-traits",
"paste",
]
[[package]]
name = "malachite-nz"
version = "0.4.22"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "34a79feebb2bc9aa7762047c8e5495269a367da6b5a90a99882a0aeeac1841f7"
dependencies = [
"itertools 0.11.0",
"libm",
"malachite-base",
]
[[package]]
name = "malachite-q"
version = "0.4.22"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "50f235d5747b1256b47620f5640c2a17a88c7569eebdf27cd9cb130e1a619191"
dependencies = [
"itertools 0.11.0",
"malachite-base",
"malachite-nz",
]
[[package]]
name = "malloc_buf"
version = "0.0.6"
......@@ -6589,6 +6700,63 @@ dependencies = [
"untrusted",
]
[[package]]
name = "rustpython-ast"
version = "0.4.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "4cdaf8ee5c1473b993b398c174641d3aa9da847af36e8d5eb8291930b72f31a5"
dependencies = [
"is-macro",
"malachite-bigint",
"rustpython-parser-core",
"static_assertions",
]
[[package]]
name = "rustpython-parser"
version = "0.4.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "868f724daac0caf9bd36d38caf45819905193a901e8f1c983345a68e18fb2abb"
dependencies = [
"anyhow",
"is-macro",
"itertools 0.11.0",
"lalrpop-util",
"log",
"malachite-bigint",
"num-traits",
"phf",
"phf_codegen",
"rustc-hash 1.1.0",
"rustpython-ast",
"rustpython-parser-core",
"tiny-keccak",
"unic-emoji-char",
"unic-ucd-ident",
"unicode_names2",
]
[[package]]
name = "rustpython-parser-core"
version = "0.4.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "b4b6c12fa273825edc7bccd9a734f0ad5ba4b8a2f4da5ff7efe946f066d0f4ad"
dependencies = [
"is-macro",
"memchr",
"rustpython-parser-vendored",
]
[[package]]
name = "rustpython-parser-vendored"
version = "0.4.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "04fcea49a4630a3a5d940f4d514dc4f575ed63c14c3e3ed07146634aed7f67a6"
dependencies = [
"memchr",
"once_cell",
]
[[package]]
name = "rustversion"
version = "1.0.22"
......@@ -7787,6 +7955,15 @@ dependencies = [
"time-core",
]
[[package]]
name = "tiny-keccak"
version = "2.0.2"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "2c9d3793400a45f954c52e73d068316d76b6f4e36977e3fcebb13a2721e80237"
dependencies = [
"crunchy",
]
[[package]]
name = "tinystr"
version = "0.8.1"
......@@ -8484,6 +8661,58 @@ version = "0.2.2"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "eeba86d422ce181a719445e51872fa30f1f7413b62becb52e95ec91aa262d85c"
[[package]]
name = "unic-char-property"
version = "0.9.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "a8c57a407d9b6fa02b4795eb81c5b6652060a15a7903ea981f3d723e6c0be221"
dependencies = [
"unic-char-range",
]
[[package]]
name = "unic-char-range"
version = "0.9.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "0398022d5f700414f6b899e10b8348231abf9173fa93144cbc1a43b9793c1fbc"
[[package]]
name = "unic-common"
version = "0.9.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "80d7ff825a6a654ee85a63e80f92f054f904f21e7d12da4e22f9834a4aaa35bc"
[[package]]
name = "unic-emoji-char"
version = "0.9.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "0b07221e68897210270a38bde4babb655869637af0f69407f96053a34f76494d"
dependencies = [
"unic-char-property",
"unic-char-range",
"unic-ucd-version",
]
[[package]]
name = "unic-ucd-ident"
version = "0.9.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "e230a37c0381caa9219d67cf063aa3a375ffed5bf541a452db16e744bdab6987"
dependencies = [
"unic-char-property",
"unic-char-range",
"unic-ucd-version",
]
[[package]]
name = "unic-ucd-version"
version = "0.9.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "96bd2f2237fe450fcd0a1d2f5f4e91711124f7857ba2e964247776ebeeb7b0c4"
dependencies = [
"unic-common",
]
[[package]]
name = "unicase"
version = "2.8.1"
......@@ -8523,12 +8752,40 @@ version = "0.2.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "4a1a07cc7db3810833284e8d372ccdc6da29741639ecc70c9ec107df0fa6154c"
[[package]]
name = "unicode-xid"
version = "0.2.6"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "ebc1c04c71510c7f702b52b7c350734c9ff1295c464a03335b00bb84fc54f853"
[[package]]
name = "unicode_categories"
version = "0.1.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "39ec24b3121d976906ece63c9daad25b85969647682eee313cb5779fdd69e14e"
[[package]]
name = "unicode_names2"
version = "1.3.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "d1673eca9782c84de5f81b82e4109dcfb3611c8ba0d52930ec4a9478f547b2dd"
dependencies = [
"phf",
"unicode_names2_generator",
]
[[package]]
name = "unicode_names2_generator"
version = "1.3.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "b91e5b84611016120197efd7dc93ef76774f4e084cd73c9fb3ea4a86c570c56e"
dependencies = [
"getopts",
"log",
"phf_codegen",
"rand 0.8.5",
]
[[package]]
name = "unsafe-libyaml"
version = "0.2.11"
......
......@@ -33,7 +33,10 @@ allow = [
"MPL-2.0",
"CDLA-Permissive-2.0",
"Zlib",
"NCSA"
"NCSA",
"LGPL-3.0",
"CC0-1.0",
"Unicode-DFS-2016"
]
# TODO exceptions
......
......@@ -35,3 +35,5 @@ uuid = { workspace = true }
regex = "1"
openai-harmony = "0.0.3"
lazy_static = "1.5.0"
rustpython-parser = "0.4.0"
num-traits = "0.2"
\ No newline at end of file
// SPDX-FileCopyrightText: Copyright (c) 2024-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
// SPDX-License-Identifier: Apache-2.0
/// Represents the format type for tool calls.
///
/// The variant stored in a tool-call configuration selects which parser is
/// applied to the raw model output.
#[derive(Clone, Debug, serde::Serialize, serde::Deserialize)]
pub enum ToolCallParserType {
/// JSON format: `{"name": "function", "arguments": {...}}`
Json,
/// Pythonic format: a bracketed list of calls with keyword arguments, e.g.
/// `[get_weather(location="San Francisco", unit="fahrenheit")]`
Pythonic,
/// Harmony format.
/// NOTE(review): presumably the `openai-harmony` message format — confirm.
Harmony,
/// TypeScript-style format, e.g.:
///
/// ````text
/// <function_call>```typescript
/// functions.get_current_weather({"location": "Shanghai"})
/// ```
/// ````
Typescript,
/// XML format.
/// NOTE(review): the expected XML shape is not shown in this file — confirm.
Xml,
}
/// Token and key settings consumed by the JSON tool-call parser.
///
/// Every field is a list so that several alternative spellings can be
/// configured at once (the default configuration, for instance, lists two
/// start tokens and two argument keys).
#[derive(Clone, Debug, serde::Serialize, serde::Deserialize)]
pub struct JsonParserConfig {
/// Start token for list of parallel tool calls (e.g., "<TOOLCALLS>")
pub parallel_tool_calls_start_tokens: Vec<String>,
/// End token for list of parallel tool calls (e.g., "</TOOLCALLS>")
pub parallel_tool_calls_end_tokens: Vec<String>,
/// Start token for individual tool calls (e.g., "<TOOLCALL>")
pub tool_call_start_tokens: Vec<String>,
/// End token for individual tool calls (e.g., "</TOOLCALL>")
pub tool_call_end_tokens: Vec<String>,
/// The keys that may hold the function name in the tool call,
/// e.g. for `{"name": "function", "arguments": {...}}` it would be
/// "name"
pub function_name_keys: Vec<String>,
/// The keys that may hold the arguments in the tool call,
/// e.g. for `{"name": "function", "arguments": {...}}` it would be
/// "arguments"
pub arguments_keys: Vec<String>,
}
impl Default for JsonParserConfig {
    /// Builds the baseline JSON parser settings.
    ///
    /// No parallel-call delimiters are configured. Individual calls may start
    /// with `<TOOLCALL>` or `<|python_tag|>` and end with `</TOOLCALL>` or the
    /// empty token. The function name is looked up under `name`, and the
    /// arguments under `arguments` or `parameters`.
    fn default() -> Self {
        // Small local converter so the literal tables below stay readable.
        let owned = |items: &[&str]| -> Vec<String> {
            items.iter().map(|item| String::from(*item)).collect()
        };

        Self {
            parallel_tool_calls_start_tokens: Vec::new(),
            parallel_tool_calls_end_tokens: Vec::new(),
            tool_call_start_tokens: owned(&["<TOOLCALL>", "<|python_tag|>"]),
            tool_call_end_tokens: owned(&["</TOOLCALL>", ""]),
            function_name_keys: owned(&["name"]),
            arguments_keys: owned(&["arguments", "parameters"]),
        }
    }
}
/// Configuration for parsing tool calls with different formats.
///
/// `format` names the tool-call format to parse, while `json` carries the
/// token and key settings for the JSON parser.
#[derive(Clone, Debug, serde::Serialize, serde::Deserialize)]
pub struct ToolCallConfig {
/// The format type for tool calls
pub format: ToolCallParserType,
/// The config for the JSON parser
pub json: JsonParserConfig,
}
impl Default for ToolCallConfig {
    /// Defaults to the JSON format paired with the default JSON parser settings.
    fn default() -> Self {
        let format = ToolCallParserType::Json;
        let json = JsonParserConfig::default();
        Self { format, json }
    }
}
impl ToolCallConfig {
    /// Shared constructor for the JSON-based presets.
    ///
    /// Produces a JSON-format configuration whose only individual-call
    /// delimiters are `start_token` and `end_token`; every other JSON setting
    /// keeps its default value.
    fn json_preset(start_token: &str, end_token: &str) -> Self {
        let json = JsonParserConfig {
            tool_call_start_tokens: vec![start_token.to_string()],
            tool_call_end_tokens: vec![end_token.to_string()],
            ..Default::default()
        };
        Self {
            format: ToolCallParserType::Json,
            json,
        }
    }

    /// Default configuration for hermes tool calls:
    /// `<tool_call>{"name": "get_weather", "arguments": {"location": "San Francisco, CA", "unit": "fahrenheit"}}\n</tool_call>`
    pub fn hermes() -> Self {
        Self::json_preset("<tool_call>", "\n</tool_call>")
    }

    /// Default configuration for nemotron tool calls:
    /// `<TOOLCALL>[{"name": "get_weather", "arguments": {"location": "San Francisco, CA", "unit": "fahrenheit"}}]</TOOLCALL>`
    pub fn nemotron_deci() -> Self {
        Self::json_preset("<TOOLCALL>", "</TOOLCALL>")
    }

    /// Configuration for llama3 JSON tool calls:
    /// `<|python_tag|>{ "name": "get_weather", "arguments": {"location": "San Francisco, CA", "unit": "fahrenheit"} }`
    /// or `{ "name": "get_weather", "arguments": {"location": "San Francisco, CA", "unit": "fahrenheit"} }`
    pub fn llama3_json() -> Self {
        Self::json_preset("<|python_tag|>", "")
    }

    /// Configuration for mistral tool calls: the start token is
    /// `[TOOL_CALLS]` and the end token is empty.
    pub fn mistral() -> Self {
        Self::json_preset("[TOOL_CALLS]", "")
    }

    /// Configuration for phi4 tool calls: the start token is `functools`
    /// and the end token is empty.
    pub fn phi4() -> Self {
        Self::json_preset("functools", "")
    }

    /// Configuration selecting the pythonic format.
    pub fn pythonic() -> Self {
        Self {
            format: ToolCallParserType::Pythonic,
            // The JSON settings are a no-op for this format; they are filled
            // with defaults only to keep the struct shape consistent.
            json: JsonParserConfig::default(),
        }
    }
}
......@@ -7,7 +7,7 @@ use regex::RegexBuilder;
use serde_json::Value;
use uuid::Uuid;
use super::parsers::JsonParserConfig;
use super::config::JsonParserConfig;
use super::response::{CalledFunction, ToolCallResponse, ToolCallType};
// Same as CalledFunction with named parameters
......
// SPDX-FileCopyrightText: Copyright (c) 2024-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
// SPDX-License-Identifier: Apache-2.0
pub mod config;
pub mod json_parser;
pub mod parsers;
pub mod pythonic_parser;
pub mod response;
pub mod tools;
// Re-export main types and functions for convenience
pub use json_parser::{
CalledFunctionArguments, CalledFunctionParameters, try_tool_call_parse_json,
};
pub use parsers::{
JsonParserConfig, ToolCallConfig, ToolCallParserType, detect_and_parse_tool_call,
};
pub use config::{JsonParserConfig, ToolCallConfig, ToolCallParserType};
pub use parsers::{detect_and_parse_tool_call, try_tool_call_parse};
pub use response::{CalledFunction, ToolCallResponse, ToolCallType};
pub use tools::{try_tool_call_parse_aggregate, try_tool_call_parse_stream};
// SPDX-FileCopyrightText: Copyright (c) 2024-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
// SPDX-License-Identifier: Apache-2.0
use super::config::{ToolCallConfig, ToolCallParserType};
use super::json_parser::try_tool_call_parse_json;
use super::pythonic_parser::try_tool_call_parse_pythonic;
use super::response::ToolCallResponse;
/// Represents the format type for tool calls
#[derive(Clone, Debug, serde::Serialize, serde::Deserialize)]
pub enum ToolCallParserType {
/// JSON format: `{"name": "function", "arguments": {...}}`
Json,
Pythonic,
Harmony,
/// <function_call>```typescript
/// functions.get_current_weather({"location": "Shanghai"})
/// ```
Typescript,
Xml,
}
#[derive(Clone, Debug, serde::Serialize, serde::Deserialize)]
pub struct JsonParserConfig {
/// Start token for list of parallel tool calls (e.g., "<TOOLCALLS>")
pub parallel_tool_calls_start_tokens: Vec<String>,
/// End token for list of parallel tool calls (e.g., "</TOOLCALLS>")
pub parallel_tool_calls_end_tokens: Vec<String>,
/// Start token for individual tool calls (e.g., "<TOOLCALL>")
pub tool_call_start_tokens: Vec<String>,
/// End token for individual tool calls (e.g., "</TOOLCALL>")
pub tool_call_end_tokens: Vec<String>,
/// The key for the function name in the tool call
/// i.e. `{"name": "function", "arguments": {...}}` it would be
/// "name"
pub function_name_keys: Vec<String>,
/// The key for the arguments in the tool call
/// i.e. `{"name": "function", "arguments": {...}}` it would be
/// "arguments"
pub arguments_keys: Vec<String>,
}
impl Default for JsonParserConfig {
fn default() -> Self {
Self {
parallel_tool_calls_start_tokens: vec![],
parallel_tool_calls_end_tokens: vec![],
tool_call_start_tokens: vec!["<TOOLCALL>".to_string(), "<|python_tag|>".to_string()],
tool_call_end_tokens: vec!["</TOOLCALL>".to_string(), "".to_string()],
function_name_keys: vec!["name".to_string()],
arguments_keys: vec!["arguments".to_string(), "parameters".to_string()],
}
}
}
impl Default for ToolCallConfig {
fn default() -> Self {
Self {
format: ToolCallParserType::Json,
json: JsonParserConfig::default(),
}
}
}
impl ToolCallConfig {
/// Default configuration for hermes tool calls
/// <tool_call>{"name": "get_weather", "arguments": {"location": "San Francisco, CA", "unit": "fahrenheit"}}\n</tool_call>
pub fn hermes() -> Self {
Self {
format: ToolCallParserType::Json,
json: JsonParserConfig {
tool_call_start_tokens: vec!["<tool_call>".to_string()],
tool_call_end_tokens: vec!["\n</tool_call>".to_string()],
..Default::default()
},
}
}
/// Default configuration for nemotron tool calls
/// <TOOLCALL>[{"name": "get_weather", "arguments": {"location": "San Francisco, CA", "unit": "fahrenheit"}}]</TOOLCALL>
pub fn nemotron_deci() -> Self {
Self {
format: ToolCallParserType::Json,
json: JsonParserConfig {
tool_call_start_tokens: vec!["<TOOLCALL>".to_string()],
tool_call_end_tokens: vec!["</TOOLCALL>".to_string()],
..Default::default()
},
}
}
pub fn llama3_json() -> Self {
// <|python_tag|>{ "name": "get_weather", "arguments": {"location": "San Francisco, CA", "unit": "fahrenheit"} }
// or { "name": "get_weather", "arguments": {"location": "San Francisco, CA", "unit": "fahrenheit"} }
Self {
format: ToolCallParserType::Json,
json: JsonParserConfig {
tool_call_start_tokens: vec!["<|python_tag|>".to_string()],
tool_call_end_tokens: vec!["".to_string()],
..Default::default()
},
}
}
pub fn mistral() -> Self {
Self {
format: ToolCallParserType::Json,
json: JsonParserConfig {
tool_call_start_tokens: vec!["[TOOL_CALLS]".to_string()],
tool_call_end_tokens: vec!["".to_string()],
..Default::default()
},
}
}
pub fn phi4() -> Self {
Self {
format: ToolCallParserType::Json,
json: JsonParserConfig {
tool_call_start_tokens: vec!["functools".to_string()],
tool_call_end_tokens: vec!["".to_string()],
..Default::default()
},
}
}
}
/// Configuration for parsing tool calls with different formats
#[derive(Clone, Debug, serde::Serialize, serde::Deserialize)]
pub struct ToolCallConfig {
/// The format type for tool calls
pub format: ToolCallParserType,
/// The config for the JSON parser
pub json: JsonParserConfig,
}
pub fn try_tool_call_parse(
message: &str,
config: &ToolCallConfig,
......@@ -146,7 +20,8 @@ pub fn try_tool_call_parse(
anyhow::bail!("Harmony parser not implemented");
}
ToolCallParserType::Pythonic => {
anyhow::bail!("Pythonic parser not implemented");
let (results, normal_content) = try_tool_call_parse_pythonic(message)?;
Ok((results, normal_content))
}
ToolCallParserType::Typescript => {
anyhow::bail!("Typescript parser not implemented");
......@@ -169,6 +44,7 @@ pub fn detect_and_parse_tool_call(
parser_map.insert("llama3_json", ToolCallConfig::llama3_json());
parser_map.insert("mistral", ToolCallConfig::mistral());
parser_map.insert("phi4", ToolCallConfig::phi4());
parser_map.insert("pythonic", ToolCallConfig::pythonic());
parser_map.insert("default", ToolCallConfig::default()); // Add default key
// Handle None or empty string by defaulting to "default"
......@@ -190,6 +66,7 @@ pub fn detect_and_parse_tool_call(
// cargo test postprocessor::tool_calling::parsers
#[cfg(test)]
mod tests {
use super::super::config::JsonParserConfig;
use super::*;
fn extract_name_and_args(call: ToolCallResponse) -> (String, serde_json::Value) {
......@@ -1197,4 +1074,38 @@ Remember, San Francisco weather can be quite unpredictable, particularly with it
assert_eq!(args["from"], "New York");
assert_eq!(args["to"], "Los Angeles");
}
/// A bare pythonic list with two calls yields two tool calls and empty normal text.
#[test]
fn test_pythonic_parser_basic_with_constants() {
    let input = r#"[get_weather(location="San Francisco", unit="fahrenheit"), get_weather(location="New York", unit="fahrenheit")]"#;
    let (calls, normal_text) = detect_and_parse_tool_call(input, Some("pythonic")).unwrap();

    assert_eq!(normal_text, Some("".to_string()));
    assert_eq!(calls.len(), 2);

    // Both calls target the same tool; only the location differs.
    for (call, expected_location) in calls.into_iter().zip(["San Francisco", "New York"]) {
        let (name, args) = extract_name_and_args(call);
        assert_eq!(name, "get_weather");
        assert_eq!(args["location"], expected_location);
        assert_eq!(args["unit"], "fahrenheit");
    }
}
/// Leading prose before the pythonic list should be returned as normal text.
// NOTE(review): this test is ignored without a stated reason — confirm why
// before re-enabling it.
#[test]
#[ignore]
fn test_pythonic_parser_with_constants_and_normal_text() {
    let input = r#"Hey How are you? [get_weather(location="San Francisco", unit="fahrenheit"), get_weather(location="New York", unit="fahrenheit")]"#;
    let (calls, normal_text) = detect_and_parse_tool_call(input, Some("pythonic")).unwrap();

    assert_eq!(normal_text, Some("Hey How are you?".to_string()));
    assert_eq!(calls.len(), 2);

    // Both calls target the same tool; only the location differs.
    for (call, expected_location) in calls.into_iter().zip(["San Francisco", "New York"]) {
        let (name, args) = extract_name_and_args(call);
        assert_eq!(name, "get_weather");
        assert_eq!(args["location"], expected_location);
        assert_eq!(args["unit"], "fahrenheit");
    }
}
}
// SPDX-FileCopyrightText: Copyright (c) 2024-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
// SPDX-License-Identifier: Apache-2.0
use super::response::{CalledFunction, ToolCallResponse, ToolCallType};
use rustpython_parser::{
Mode,
ast::{Constant, Expr, Mod},
parse,
};
use serde_json::{Number, Value, json};
/// Returns `message` with every `<|python_start|>` and `<|python_end|>`
/// marker removed; all other text is left untouched.
fn strip_text(message: &str) -> String {
    let mut cleaned = message.to_owned();
    // Remove the start marker first, then the end marker.
    for marker in ["<|python_start|>", "<|python_end|>"] {
        cleaned = cleaned.replace(marker, "");
    }
    cleaned
}
/// Returns every substring of `message` that looks like a pythonic tool-call list.
///
/// The expected shape is `[tool1(arg1=val1, arg2=val2), tool2(arg1=val3)]`:
/// a bracketed, comma-separated list of calls whose arguments are all passed
/// by keyword. Matches are returned in the order they occur in `message`; an
/// empty vector means no candidate was found.
fn get_regex_matches(message: &str) -> Vec<String> {
    use regex::Regex;
    use std::sync::OnceLock;

    // Format Structure: [tool1(arg1=val1, arg2=val2), tool2(arg1=val3)]
    const PATTERN: &str = r"\[([a-zA-Z]+\w*\(([a-zA-Z]+\w*=.*?,\s*)*([a-zA-Z]+\w*=.*?\s?)?\),\s*)*([a-zA-Z]+\w*\(([a-zA-Z]+\w*=.*?,\s*)*([a-zA-Z]+\w*=.*?\s*)?\)\s*)+\]";

    // Compiling a regex is far more expensive than running it, and this
    // function is called once per parsed message, so compile it only once.
    static TOOL_CALL_LIST: OnceLock<Regex> = OnceLock::new();
    let re = TOOL_CALL_LIST.get_or_init(|| {
        Regex::new(PATTERN).expect("pythonic tool-call pattern is a valid regex")
    });

    re.find_iter(message)
        .map(|found| found.as_str().to_string())
        .collect()
}
/// Parses a pythonic tool-call list such as `[foo(a=1), bar(x="y")]` into
/// tool call responses.
///
/// `src` is parsed as a single Python expression. Only a list literal is
/// considered; any other expression yields an empty vector. Within the list,
/// elements that are not calls to a plain name are skipped. For each call,
/// keyword arguments whose value converts through `const_expr` become entries
/// of the JSON `arguments` object; `**kwargs` and arguments that fail
/// conversion are skipped with a debug log.
///
/// Ids have the form `call-N`, where `N` is the 1-based position of the
/// element in the list (skipped elements still consume a position).
///
/// # Errors
///
/// Returns an error when `src` is not a valid Python expression or when the
/// arguments object cannot be serialized.
pub fn parse_tool_calls(src: &str) -> anyhow::Result<Vec<ToolCallResponse>> {
    // Expected AST shape:
    //   Expression(ModExpression { body: List(ExprList { elts: [Call(..), Call(..)], .. }) })
    let Mod::Expression(module) = parse(src, Mode::Expression, "<input>")? else {
        return Ok(vec![]);
    };
    let body = module.body;
    let Expr::List(list) = *body else {
        return Ok(vec![]);
    };
    let elements = list.elts;

    let mut responses = Vec::with_capacity(elements.len());
    for (position, element) in elements.iter().enumerate() {
        let Expr::Call(call) = element else {
            continue;
        };
        let Expr::Name(callee) = call.func.as_ref() else {
            continue;
        };
        let name = callee.id.clone();

        let mut arguments = serde_json::Map::new();
        for keyword in call.keywords.iter() {
            match keyword.arg.as_ref() {
                // A keyword without a name is a `**kwargs` expansion.
                None => {
                    tracing::debug!(
                        "Skipping **kwargs in pythonic tool call for function {}",
                        name
                    );
                }
                Some(arg_ident) => match const_expr(&keyword.value) {
                    Ok(value) => {
                        arguments.insert(arg_ident.to_string(), value);
                    }
                    Err(e) => {
                        tracing::debug!("Skipping non-constant argument {}: {}", arg_ident, e);
                    }
                },
            }
        }

        // `Value::Object` is always valid JSON, so this serialization is not
        // expected to fail in practice.
        let serialized = serde_json::to_string(&Value::Object(arguments))?;
        responses.push(ToolCallResponse {
            id: format!("call-{}", position + 1),
            tp: ToolCallType::Function,
            function: CalledFunction {
                name: name.to_string(),
                arguments: serialized,
            },
        });
    }
    Ok(responses)
}
fn const_expr(e: &Expr) -> Result<Value, Box<dyn std::error::Error>> {
match e {
Expr::Constant(constant) => Ok(match &constant.value {
Constant::Bool(b) => json!(b),
Constant::None => Value::Null,
Constant::Int(i) => {
// Try to downcast to i64/u64; fallback to string if out of range
use num_traits::ToPrimitive;
if let Some(v) = i.to_i64() {
Value::Number(Number::from(v))
} else if let Some(v) = i.to_u64() {
Value::Number(Number::from(v))
} else {
Value::String(i.to_string())
}
}
Constant::Float(f) => json!(f),
Constant::Str(s) => json!(s),
_ => return Err("unsupported constant type".into()),
}),
// Handle Python lists as expressions, not constants
Expr::List(expr_list) => {
let list_values: Result<Vec<Value>, Box<dyn std::error::Error>> =
expr_list.elts.iter().map(|e| const_expr(e)).collect();
Ok(json!(list_values?))
}
// Handle Python dictionaries as expressions, not constants
Expr::Dict(expr_dict) => {
let mut dict_map = std::collections::HashMap::new();
for (key_expr, value_expr) in expr_dict.keys.iter().zip(expr_dict.values.iter()) {
// Keys should be strings for JSON compatibility
// Handle the case where key_expr is Option<Expr>
let key = match key_expr {
Some(k) => match const_expr(k)? {
Value::String(s) => s,
other => other.to_string(),
},
None => {
return Err(
"dictionary unpacking (**kwargs) not supported in constants".into()
);
}
};
let value = const_expr(value_expr)?;
dict_map.insert(key, value);
}
Ok(json!(dict_map))
}
_ => Err("only constant values, lists, and dicts are allowed".into()),
}
}
/// Extracts pythonic tool calls from a raw model message.
///
/// The message is first cleaned with `strip_text` and trimmed. The first
/// candidate found by `get_regex_matches` is parsed with `parse_tool_calls`.
///
/// Returns the parsed tool calls together with the "normal" text:
/// - an empty message yields no calls and an empty string;
/// - a message without any tool-call list yields no calls and the whole
///   cleaned message;
/// - otherwise the normal text is whatever precedes the first tool-call
///   list, trimmed.
///
/// # Errors
///
/// Propagates any error from `parse_tool_calls`.
pub fn try_tool_call_parse_pythonic(
    message: &str,
) -> anyhow::Result<(Vec<ToolCallResponse>, Option<String>)> {
    let cleaned = strip_text(message);
    let cleaned = cleaned.trim();

    // Early exit if no content
    if cleaned.is_empty() {
        return Ok((Vec::new(), Some(String::new())));
    }

    let candidates = get_regex_matches(cleaned);
    let Some(first) = candidates.first() else {
        return Ok((Vec::new(), Some(cleaned.to_string())));
    };

    let calls = parse_tool_calls(first)?;

    // Normal text is everything before the first match. `first` was found
    // inside `cleaned`, so the fallback arm only guards against the impossible.
    let preceding = match cleaned.split_once(first.as_str()) {
        Some((before, _after)) => before,
        None => cleaned,
    };

    Ok((calls, Some(preceding.trim().to_string())))
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Splits a tool call into its function name and decoded JSON arguments.
    fn extract_name_and_args(call: ToolCallResponse) -> (String, serde_json::Value) {
        let arguments =
            serde_json::from_str::<serde_json::Value>(&call.function.arguments).unwrap();
        (call.function.name, arguments)
    }

    /// Checks the `[foo(a=1, b=2), bar(x=3)]` expectation shared by several tests.
    fn assert_foo_and_bar(calls: &[ToolCallResponse]) {
        assert!(!calls.is_empty());
        assert_eq!(calls.len(), 2);

        let (name, args) = extract_name_and_args(calls[0].clone());
        assert_eq!(name, "foo");
        assert_eq!(args["a"], 1);
        assert_eq!(args["b"], 2);

        let (name, args) = extract_name_and_args(calls[1].clone());
        assert_eq!(name, "bar");
        assert_eq!(args["x"], 3);
    }

    #[test]
    fn test_strip_text() {
        // (input, expected) pairs: no tags, both tags, start only, end only.
        let cases = [
            ("Hello, world!", "Hello, world!"),
            ("<|python_start|>foo(a=1, b=2)<|python_end|>", "foo(a=1, b=2)"),
            ("<|python_start|>foo(a=1, b=2)", "foo(a=1, b=2)"),
            ("foo(a=1, b=2)<|python_end|>", "foo(a=1, b=2)"),
        ];
        for (message, expected) in cases {
            assert_eq!(strip_text(message), expected);
        }
    }

    #[test]
    fn test_get_regex_matches_simple_case() {
        let found = get_regex_matches("[foo(a=1, b=2), bar(x=3)]");
        assert_eq!(found, vec!["[foo(a=1, b=2), bar(x=3)]"]);
    }

    #[test]
    fn test_get_regex_matches_text_before_and_after() {
        // Spacing inside an argument, plus text on both sides of the list.
        let found = get_regex_matches("Hey yo ! [foo(a=1, b=2), bar(x= 3)] Hey yo");
        assert_eq!(found, vec!["[foo(a=1, b=2), bar(x= 3)]"]);
    }

    #[test]
    fn test_get_regex_matches_new_line_in_arg_and_value() {
        // Newlines in the surrounding text and between the two calls.
        let found = get_regex_matches("Hey \n yo ! [foo(a=1,b=2), \n bar(x=3)] Hey yo");
        assert_eq!(found, vec!["[foo(a=1,b=2), \n bar(x=3)]"]);
    }

    #[test]
    fn test_get_regex_matches_no_call() {
        assert!(get_regex_matches("Hey yo !").is_empty());
    }

    #[test]
    fn test_parse_tool_call_parse_pythonic_basic() {
        let (calls, normal_text) =
            try_tool_call_parse_pythonic("[foo(a=1, b=2), bar(x=3)]").unwrap();
        assert_eq!(normal_text, Some("".to_string()));
        assert_foo_and_bar(&calls);
    }

    #[test]
    fn test_parse_tool_call_parse_pythonic_with_text() {
        let (calls, normal_text) =
            try_tool_call_parse_pythonic("Hey yo ! [foo(a=1, b=2), bar(x=3)] Hey yo").unwrap();
        assert_eq!(normal_text, Some("Hey yo !".to_string()));
        assert_foo_and_bar(&calls);
    }

    #[test]
    fn test_parse_tool_call_parse_pythonic_with_text_and_new_line() {
        let (calls, normal_text) =
            try_tool_call_parse_pythonic("Hey \n yo ! [foo(a=1, b=2), bar(x=3)] Hey yo").unwrap();
        assert_eq!(normal_text, Some("Hey \n yo !".to_string()));
        assert_foo_and_bar(&calls);
    }

    #[test]
    fn test_parse_tool_call_parse_pythonic_with_no_calls() {
        let (calls, normal_text) = try_tool_call_parse_pythonic("Hey \n yo !").unwrap();
        assert_eq!(normal_text, Some("Hey \n yo !".to_string()));
        assert!(calls.is_empty());
        assert_eq!(calls.len(), 0)
    }

    #[test]
    fn test_parse_tool_call_parse_pythonic_with_python_tags() {
        let (calls, normal_text) = try_tool_call_parse_pythonic(
            "<|python_start|>[foo(a=1, b=2), bar(x=3)]<|python_end|>",
        )
        .unwrap();
        assert_eq!(normal_text, Some("".to_string()));
        assert_foo_and_bar(&calls);
    }

    #[test]
    fn test_parse_tool_call_parse_pythonic_with_list_arg_values() {
        let (calls, _) =
            try_tool_call_parse_pythonic("[foo(a=[1, 2, 3], b=2), bar(x=[3, 4, 5])]").unwrap();
        assert!(!calls.is_empty());
        assert_eq!(calls.len(), 2);

        let (name, args) = extract_name_and_args(calls[0].clone());
        assert_eq!(name, "foo");
        assert_eq!(args["a"], json!([1, 2, 3]));
        assert_eq!(args["b"], 2);

        let (name, args) = extract_name_and_args(calls[1].clone());
        assert_eq!(name, "bar");
        assert_eq!(args["x"], json!([3, 4, 5]));
    }

    #[test]
    fn test_parse_tool_call_parse_pythonic_with_dict_arg_values() {
        let (calls, _) = try_tool_call_parse_pythonic(
            "[foo(a={'a': 1, 'b': 2}, b=2), bar(x={'x': 3, 'y': {'e': 'f'}})]",
        )
        .unwrap();
        assert!(!calls.is_empty());
        assert_eq!(calls.len(), 2);

        let (name, args) = extract_name_and_args(calls[0].clone());
        assert_eq!(name, "foo");
        assert_eq!(args["a"], json!({"a": 1, "b": 2}));
        assert_eq!(args["b"], 2);

        let (name, args) = extract_name_and_args(calls[1].clone());
        assert_eq!(name, "bar");
        assert_eq!(args["x"], json!({"x": 3, "y": {"e": "f"}}));
    }
}
// SPDX-FileCopyrightText: Copyright (c) 2024-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
// SPDX-License-Identifier: Apache-2.0
pub use super::response::*;
// Import json_parser from postprocessor module
pub use super::json_parser::*;
pub use super::parsers::{ToolCallConfig, detect_and_parse_tool_call};
pub use super::config::ToolCallConfig;
pub use super::parsers::detect_and_parse_tool_call;
/// Try parsing a string as a structured tool call, for aggregation usage.
///
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment