Unverified Commit 086ea4f0 authored by Elyas Mehtabuddin's avatar Elyas Mehtabuddin Committed by GitHub
Browse files

chore: Refactor tool calling for wider support in the future (#2393)

parent c12c2578
......@@ -27,6 +27,7 @@ pub mod mocker;
pub mod model_card;
pub mod model_type;
pub mod perf;
pub mod postprocessor;
pub mod preprocessor;
pub mod protocols;
pub mod recorder;
......
// SPDX-FileCopyrightText: Copyright (c) 2024-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
// SPDX-License-Identifier: Apache-2.0
pub mod tool_calling;
// SPDX-FileCopyrightText: Copyright (c) 2024-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
// SPDX-License-Identifier: Apache-2.0
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
use std::collections::HashMap;
use serde_json::Value;
use uuid::Uuid;
mod request;
mod response;
pub use request::*;
pub use response::*;
/// Matches and processes tool calling patterns in LLM responses
///
/// Supports multiple formats for tool calls:
/// - Single/multiple function calls with parameters/arguments
/// - Auto or user selected tool usage
pub struct ToolCallingMatcher {
tool_choice: ToolChoice,
}
use super::parsers::JsonParserConfig;
use super::response::{CalledFunction, ToolCallResponse, ToolCallType};
// Same as CalledFunction with named parameters
#[derive(Clone, Debug, serde::Serialize, serde::Deserialize)]
......@@ -47,75 +23,6 @@ pub struct CalledFunctionArguments {
pub arguments: HashMap<String, Value>,
}
impl ToolCallingMatcher {
pub fn new(tool_choice: ToolChoice) -> anyhow::Result<Self> {
Ok(Self { tool_choice })
}
pub fn get_call(&self, message: &str) -> anyhow::Result<Vec<ToolCallResponse>> {
if matches!(self.tool_choice, ToolChoice::None) {
return Ok(Vec::new());
}
if let Ok(deser) = serde_json::from_str::<CalledFunctionParameters>(message) {
let id = format!("call-{}", Uuid::new_v4());
Ok(vec![ToolCallResponse {
id,
tp: ToolCallType::Function,
function: CalledFunction {
name: deser.name,
arguments: serde_json::to_string(&deser.parameters)?,
},
}])
} else if let Ok(deser) = serde_json::from_str::<Vec<CalledFunctionParameters>>(message) {
Ok(deser
.into_iter()
.map(|deser| {
let id = format!("call-{}", Uuid::new_v4());
Ok(ToolCallResponse {
id,
tp: ToolCallType::Function,
function: CalledFunction {
name: deser.name,
arguments: serde_json::to_string(&deser.parameters)?,
},
})
})
.collect::<anyhow::Result<Vec<_>>>()?)
} else if let Ok(deser) = serde_json::from_str::<CalledFunctionArguments>(message) {
let id = format!("call-{}", Uuid::new_v4());
Ok(vec![ToolCallResponse {
id,
tp: ToolCallType::Function,
function: CalledFunction {
name: deser.name,
arguments: serde_json::to_string(&deser.arguments)?,
},
}])
} else if let Ok(deser) = serde_json::from_str::<Vec<CalledFunctionArguments>>(message) {
Ok(deser
.into_iter()
.map(|deser| {
let id = format!("call-{}", Uuid::new_v4());
Ok(ToolCallResponse {
id,
tp: ToolCallType::Function,
function: CalledFunction {
name: deser.name,
arguments: serde_json::to_string(&deser.arguments)?,
},
})
})
.collect::<anyhow::Result<Vec<_>>>()?)
} else {
if matches!(self.tool_choice, ToolChoice::Tool(_)) {
anyhow::bail!("Tool choice was required but no tools were called.")
}
Ok(Vec::new())
}
}
}
/// Attempts to parse a tool call from a raw LLM message string into a unified [`ToolCallResponse`] format.
///
/// This is a flexible helper that handles a variety of potential formats emitted by LLMs for function/tool calls,
......@@ -154,16 +61,25 @@ impl ToolCallingMatcher {
///
/// ```ignore
/// let input = r#"<TOOLCALL>[{ "name": "search", "parameters": { "query": "rust" } }]</TOOLCALL>"#;
/// let result = try_parse_call_common(input)?;
/// let result = try_tool_call_parse_json(input)?;
/// assert!(result.is_some());
/// ```
pub fn try_parse_call_common(message: &str) -> anyhow::Result<Option<ToolCallResponse>> {
pub fn try_tool_call_parse_json(
message: &str,
config: &JsonParserConfig,
) -> anyhow::Result<Option<ToolCallResponse>> {
// Log the config we are using
tracing::debug!("Using JSON parser config: {:?}", config);
let trimmed = message.trim();
// Support <TOOLCALL>[ ... ] or <tool_call>[ ... ]
let json = if trimmed.starts_with("<TOOLCALL>[") && trimmed.ends_with("]</TOOLCALL>") {
let json = if let Some(stripped) = trimmed.strip_prefix("<TOOLCALL>[") {
if let Some(stripped) = stripped.strip_suffix("]</TOOLCALL>") {
tracing::debug!("Stripping <TOOLCALL> wrapper from tool call payload");
&trimmed["<TOOLCALL>[".len()..trimmed.len() - "]</TOOLCALL>".len()]
stripped
} else {
trimmed
}
// Support custom/LLM-formatted `<|python_tag|>` preamble
} else if let Some(stripped) = trimmed.strip_prefix("<|python_tag|>") {
......@@ -199,7 +115,7 @@ pub fn try_parse_call_common(message: &str) -> anyhow::Result<Option<ToolCallRes
if let Ok(single) = serde_json::from_str::<CalledFunctionParameters>(json) {
return parse(single.name, single.parameters).map(Some);
// CalledFunctinoArguments: Single { name, arguments }
// CalledFunctionArguments: Single { name, arguments }
// Example:
// {
// "name": "summarize",
......@@ -243,129 +159,3 @@ pub fn try_parse_call_common(message: &str) -> anyhow::Result<Option<ToolCallRes
Ok(None)
}
/// Try parsing a string as a structured tool call, for aggregation usage.
///
/// If successful, returns a `ChatCompletionMessageToolCall`.
pub fn try_parse_tool_call_aggregate(
message: &str,
) -> anyhow::Result<Option<async_openai::types::ChatCompletionMessageToolCall>> {
let parsed = try_parse_call_common(message)?;
if let Some(parsed) = parsed {
Ok(Some(async_openai::types::ChatCompletionMessageToolCall {
id: parsed.id,
r#type: async_openai::types::ChatCompletionToolType::Function,
function: async_openai::types::FunctionCall {
name: parsed.function.name,
arguments: parsed.function.arguments,
},
}))
} else {
Ok(None)
}
}
/// Try parsing a string as a structured tool call, for streaming (delta) usage.
///
/// If successful, returns a `ChatCompletionMessageToolCallChunk`.
pub fn try_parse_tool_call_stream(
message: &str,
) -> anyhow::Result<Option<async_openai::types::ChatCompletionMessageToolCallChunk>> {
let parsed = try_parse_call_common(message)?;
if let Some(parsed) = parsed {
Ok(Some(
async_openai::types::ChatCompletionMessageToolCallChunk {
index: 0,
id: Some(parsed.id),
r#type: Some(async_openai::types::ChatCompletionToolType::Function),
function: Some(async_openai::types::FunctionCallStream {
name: Some(parsed.function.name),
arguments: Some(parsed.function.arguments),
}),
},
))
} else {
Ok(None)
}
}
#[cfg(test)]
mod tests {
use super::*;
fn extract_name_and_args(call: ToolCallResponse) -> (String, serde_json::Value) {
let args: serde_json::Value = serde_json::from_str(&call.function.arguments).unwrap();
(call.function.name, args)
}
#[test]
fn parses_single_parameters_object() {
let input = r#"{ "name": "hello", "parameters": { "x": 1, "y": 2 } }"#;
let result = try_parse_call_common(input).unwrap().unwrap();
let (name, args) = extract_name_and_args(result);
assert_eq!(name, "hello");
assert_eq!(args["x"], 1);
assert_eq!(args["y"], 2);
}
#[test]
fn parses_single_arguments_object() {
let input = r#"{ "name": "world", "arguments": { "a": "abc", "b": 42 } }"#;
let result = try_parse_call_common(input).unwrap().unwrap();
let (name, args) = extract_name_and_args(result);
assert_eq!(name, "world");
assert_eq!(args["a"], "abc");
assert_eq!(args["b"], 42);
}
#[test]
fn parses_vec_of_parameters_and_takes_last() {
let input = r#"[{ "name": "first", "parameters": { "a": 1 } }, { "name": "second", "parameters": { "b": 2 } }]"#;
let result = try_parse_call_common(input).unwrap().unwrap();
let (name, args) = extract_name_and_args(result);
assert_eq!(name, "second");
assert_eq!(args["b"], 2);
}
#[test]
fn parses_vec_of_arguments_and_takes_last() {
let input = r#"[{ "name": "alpha", "arguments": { "a": "x" } }, { "name": "omega", "arguments": { "z": "y" } }]"#;
let result = try_parse_call_common(input).unwrap().unwrap();
let (name, args) = extract_name_and_args(result);
assert_eq!(name, "omega");
assert_eq!(args["z"], "y");
}
#[test]
fn parses_toolcall_wrapped_payload() {
let input =
r#"<TOOLCALL>[{ "name": "wrapped", "parameters": { "foo": "bar" } }]</TOOLCALL>"#;
let result = try_parse_call_common(input).unwrap().unwrap();
let (name, args) = extract_name_and_args(result);
assert_eq!(name, "wrapped");
assert_eq!(args["foo"], "bar");
}
#[test]
fn parses_python_tag_prefixed_payload() {
let input = r#"<|python_tag|>{ "name": "pyfunc", "arguments": { "k": "v" } }"#;
let result = try_parse_call_common(input).unwrap().unwrap();
let (name, args) = extract_name_and_args(result);
assert_eq!(name, "pyfunc");
assert_eq!(args["k"], "v");
}
#[test]
fn returns_none_on_invalid_input() {
let input = r#"not even json"#;
let result = try_parse_call_common(input).unwrap();
assert!(result.is_none());
}
#[test]
fn returns_none_on_valid_json_wrong_shape() {
let input = r#"{ "foo": "bar" }"#;
let result = try_parse_call_common(input).unwrap();
assert!(result.is_none());
}
}
// SPDX-FileCopyrightText: Copyright (c) 2024-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
// SPDX-License-Identifier: Apache-2.0
pub mod json_parser;
pub mod parsers;
pub mod response;
pub mod tools;
// SPDX-FileCopyrightText: Copyright (c) 2024-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
// SPDX-License-Identifier: Apache-2.0
use super::json_parser::try_tool_call_parse_json;
use super::response::ToolCallResponse;
/// Represents the format type for tool calls
#[derive(Clone, Debug, serde::Serialize, serde::Deserialize)]
pub enum ToolCallParserType {
/// JSON format: `{"name": "function", "arguments": {...}}`
Json,
Pythonic,
Harmony,
/// <function_call>```typescript
/// functions.get_current_weather({"location": "Shanghai"})
/// ```
Typescript,
Xml,
}
#[derive(Clone, Debug, serde::Serialize, serde::Deserialize)]
pub struct JsonParserConfig {
/// Start token for list of parallel tool calls (e.g., "<TOOLCALLS>")
pub parallel_tool_calls_start_tokens: Vec<String>,
/// End token for list of parallel tool calls (e.g., "</TOOLCALLS>")
pub parallel_tool_calls_end_tokens: Vec<String>,
/// Start token for individual tool calls (e.g., "<TOOLCALL>")
pub tool_call_start_tokens: Vec<String>,
/// End token for individual tool calls (e.g., "</TOOLCALL>")
pub tool_call_end_tokens: Vec<String>,
/// The key for the function name in the tool call
/// i.e. `{"name": "function", "arguments": {...}}` it would be
/// "name"
pub function_name_keys: Vec<String>,
/// The key for the arguments in the tool call
/// i.e. `{"name": "function", "arguments": {...}}` it would be
/// "arguments"
pub arguments_keys: Vec<String>,
}
impl Default for JsonParserConfig {
fn default() -> Self {
Self {
parallel_tool_calls_start_tokens: vec![],
parallel_tool_calls_end_tokens: vec![],
tool_call_start_tokens: vec!["<TOOLCALL>".to_string(), "<|python_tag|>".to_string()],
tool_call_end_tokens: vec!["</TOOLCALL>".to_string()],
function_name_keys: vec!["name".to_string()],
arguments_keys: vec!["arguments".to_string(), "parameters".to_string()],
}
}
}
/// Configuration for parsing tool calls with different formats
#[derive(Clone, Debug, serde::Serialize, serde::Deserialize)]
pub struct ToolCallConfig {
/// The format type for tool calls
pub format: ToolCallParserType,
/// The config for the JSON parser
pub json: JsonParserConfig,
}
impl Default for ToolCallConfig {
fn default() -> Self {
Self {
format: ToolCallParserType::Json,
json: JsonParserConfig::default(),
}
}
}
pub fn try_tool_call_parse(
message: &str,
config: &ToolCallConfig,
) -> anyhow::Result<Option<ToolCallResponse>> {
// Use match statement (Rust's switch statement) to call the appropriate parser
match config.format {
ToolCallParserType::Json => try_tool_call_parse_json(message, &config.json),
ToolCallParserType::Harmony => {
anyhow::bail!("Harmony parser not implemented");
}
ToolCallParserType::Pythonic => {
anyhow::bail!("Pythonic parser not implemented");
}
ToolCallParserType::Typescript => {
anyhow::bail!("Typescript parser not implemented");
}
ToolCallParserType::Xml => {
anyhow::bail!("Xml parser not implemented");
}
}
}
// Tests
// cargo test postprocessor::tool_calling::parsers
#[cfg(test)]
mod tests {
use super::*;
fn extract_name_and_args(call: ToolCallResponse) -> (String, serde_json::Value) {
let args: serde_json::Value = serde_json::from_str(&call.function.arguments).unwrap();
(call.function.name, args)
}
#[test]
fn parses_single_parameters_object() {
let input = r#"{ "name": "hello", "parameters": { "x": 1, "y": 2 } }"#;
let result = try_tool_call_parse(input, &ToolCallConfig::default())
.unwrap()
.unwrap();
let (name, args) = extract_name_and_args(result);
assert_eq!(name, "hello");
assert_eq!(args["x"], 1);
assert_eq!(args["y"], 2);
}
#[test]
fn parses_single_arguments_object() {
let input = r#"{ "name": "world", "arguments": { "a": "abc", "b": 42 } }"#;
let result = try_tool_call_parse(input, &ToolCallConfig::default())
.unwrap()
.unwrap();
let (name, args) = extract_name_and_args(result);
assert_eq!(name, "world");
assert_eq!(args["a"], "abc");
assert_eq!(args["b"], 42);
}
#[test]
fn parses_vec_of_parameters_and_takes_last() {
let input = r#"[{ "name": "first", "parameters": { "a": 1 } }, { "name": "second", "parameters": { "b": 2 } }]"#;
let result = try_tool_call_parse(input, &ToolCallConfig::default())
.unwrap()
.unwrap();
let (name, args) = extract_name_and_args(result);
assert_eq!(name, "second");
assert_eq!(args["b"], 2);
}
#[test]
fn parses_vec_of_arguments_and_takes_last() {
let input = r#"[{ "name": "alpha", "arguments": { "a": "x" } }, { "name": "omega", "arguments": { "z": "y" } }]"#;
let result = try_tool_call_parse(input, &ToolCallConfig::default())
.unwrap()
.unwrap();
let (name, args) = extract_name_and_args(result);
assert_eq!(name, "omega");
assert_eq!(args["z"], "y");
}
#[test]
fn parses_toolcall_wrapped_payload() {
let input =
r#"<TOOLCALL>[{ "name": "wrapped", "parameters": { "foo": "bar" } }]</TOOLCALL>"#;
let result = try_tool_call_parse(input, &ToolCallConfig::default())
.unwrap()
.unwrap();
let (name, args) = extract_name_and_args(result);
assert_eq!(name, "wrapped");
assert_eq!(args["foo"], "bar");
}
#[test]
fn parses_python_tag_prefixed_payload() {
let input = r#"<|python_tag|>{ "name": "pyfunc", "arguments": { "k": "v" } }"#;
let result = try_tool_call_parse(input, &ToolCallConfig::default())
.unwrap()
.unwrap();
let (name, args) = extract_name_and_args(result);
assert_eq!(name, "pyfunc");
assert_eq!(args["k"], "v");
}
#[test]
fn returns_none_on_invalid_input() {
let input = r#"not even json"#;
let result = try_tool_call_parse(input, &ToolCallConfig::default()).unwrap();
assert!(result.is_none());
}
#[test]
fn returns_none_on_valid_json_wrong_shape() {
let input = r#"{ "foo": "bar" }"#;
let result = try_tool_call_parse(input, &ToolCallConfig::default()).unwrap();
assert!(result.is_none());
}
// Tests for real model outputs - disabled by default
#[test]
#[ignore]
fn test_nvidia_llama3_nemotron_super_49b_simple() {
let input = r#"<think>
Okay, the user is asking for the weather in San Francisco in Fahrenheit. Let me check the tools available.
</think>
<TOOLCALL>[{"name": "get_weather", "arguments": {"location": "San Francisco, CA", "unit": "fahrenheit"}}]</TOOLCALL>"#;
let result = try_tool_call_parse(input, &ToolCallConfig::default())
.unwrap()
.unwrap();
let (name, args) = extract_name_and_args(result);
assert_eq!(name, "get_weather");
assert_eq!(args["location"], "San Francisco, CA");
assert_eq!(args["unit"], "fahrenheit");
}
#[test]
#[ignore]
fn test_qwen_qwq_32b_simple() {
let input = r#"<tool_call>
{"name": "get_weather", "arguments": {"location": "San Francisco, CA", "unit": "fahrenheit"}}
</tool_call>"#;
let config = ToolCallConfig {
format: ToolCallParserType::Json,
json: JsonParserConfig {
tool_call_start_tokens: vec!["<tool_call>".to_string()],
tool_call_end_tokens: vec!["</tool_call>".to_string()],
arguments_keys: vec!["arguments".to_string()],
..Default::default()
},
};
let result = try_tool_call_parse(input, &config).unwrap().unwrap();
let (name, args) = extract_name_and_args(result);
assert_eq!(name, "get_weather");
assert_eq!(args["location"], "San Francisco, CA");
assert_eq!(args["unit"], "fahrenheit");
}
#[test]
#[ignore]
fn test_nousresearch_hermes3_llama31_8b_simple() {
let input = r#"<tool_call>
{"name": "get_weather", "arguments": {"location": "San Francisco, CA", "unit": "fahrenheit"}}
</tool_call>"#;
let config = ToolCallConfig {
format: ToolCallParserType::Json,
json: JsonParserConfig {
tool_call_start_tokens: vec!["<tool_call>".to_string()],
tool_call_end_tokens: vec!["</tool_call>".to_string()],
arguments_keys: vec!["arguments".to_string()],
..Default::default()
},
};
let result = try_tool_call_parse(input, &config).unwrap().unwrap();
let (name, args) = extract_name_and_args(result);
assert_eq!(name, "get_weather");
assert_eq!(args["location"], "San Francisco, CA");
assert_eq!(args["unit"], "fahrenheit");
}
#[test]
#[ignore]
fn test_ibm_granite_40_tiny_preview_simple() {
let input = r#"[{"arguments": {"location": "San Francisco, CA", "unit": "fahrenheit"}, "name": "get_weather"}]"#;
let config = ToolCallConfig {
format: ToolCallParserType::Json,
json: JsonParserConfig {
tool_call_start_tokens: vec![],
tool_call_end_tokens: vec![],
arguments_keys: vec!["arguments".to_string()],
..Default::default()
},
};
let result = try_tool_call_parse(input, &config).unwrap().unwrap();
let (name, args) = extract_name_and_args(result);
assert_eq!(name, "get_weather");
assert_eq!(args["location"], "San Francisco, CA");
assert_eq!(args["unit"], "fahrenheit");
}
#[test]
#[ignore]
fn test_mistralai_mistral_7b_instruct_v03_simple() {
let input = r#" [{"name": "get_weather", "arguments": {"location": "San Francisco, CA", "unit": "fahrenheit"}}]"#;
let config = ToolCallConfig {
format: ToolCallParserType::Json,
json: JsonParserConfig {
tool_call_start_tokens: vec![],
tool_call_end_tokens: vec![],
arguments_keys: vec!["arguments".to_string()],
..Default::default()
},
};
let result = try_tool_call_parse(input, &config).unwrap().unwrap();
let (name, args) = extract_name_and_args(result);
assert_eq!(name, "get_weather");
assert_eq!(args["location"], "San Francisco, CA");
assert_eq!(args["unit"], "fahrenheit");
}
#[test]
#[ignore]
fn test_meta_llama_llama31_8b_instruct_simple() {
let input = r#"{"name": "get_weather", "parameters": {"location": "San Francisco, CA", "unit": "fahrenheit"}}"#;
let config = ToolCallConfig {
format: ToolCallParserType::Json,
json: JsonParserConfig {
tool_call_start_tokens: vec![],
tool_call_end_tokens: vec![],
arguments_keys: vec!["parameters".to_string()],
..Default::default()
},
};
let result = try_tool_call_parse(input, &config).unwrap().unwrap();
let (name, args) = extract_name_and_args(result);
assert_eq!(name, "get_weather");
assert_eq!(args["location"], "San Francisco, CA");
assert_eq!(args["unit"], "fahrenheit");
}
#[test]
#[ignore]
fn test_internlm_internlm2_5_7b_chat_simple() {
let input = r#"San Francisco's weather is known for its mild climate with plenty of fog, especially along the coast. Here's an overview of the weather in Fahrenheit:
- **Summer (June to August)**: Average highs range from the mid-60s to low 70s Fahrenheit, with cooler mornings and evenings. Coastal areas may be cooler than inland spots.
Remember, San Francisco weather can be quite unpredictable, particularly with its famous fog, which can significantly lower temperatures. Always check a local weather forecast for the most accurate and up-to-date information."#;
let result = try_tool_call_parse(input, &ToolCallConfig::default()).unwrap();
assert!(result.is_none()); // This model doesn't produce tool calls
}
#[test]
#[ignore]
fn test_ai21labs_ai21_jamba_15_mini_simple() {
let input = r#" [
{"name": "get_weather", "arguments": {"location": "San Francisco, CA", "unit": "fahrenheit"}}
]"#;
let config = ToolCallConfig {
format: ToolCallParserType::Json,
json: JsonParserConfig {
tool_call_start_tokens: vec![],
tool_call_end_tokens: vec![],
arguments_keys: vec!["arguments".to_string()],
..Default::default()
},
};
let result = try_tool_call_parse(input, &config).unwrap().unwrap();
let (name, args) = extract_name_and_args(result);
assert_eq!(name, "get_weather");
assert_eq!(args["location"], "San Francisco, CA");
assert_eq!(args["unit"], "fahrenheit");
}
#[test]
#[ignore]
fn test_salesforce_llama_xlam_2_8b_fc_r_simple() {
let input = r#"[{"name": "get_weather", "arguments": {"location": "San Francisco, CA", "unit": "fahrenheit"}}]"#;
let config = ToolCallConfig {
format: ToolCallParserType::Json,
json: JsonParserConfig {
tool_call_start_tokens: vec![],
tool_call_end_tokens: vec![],
arguments_keys: vec!["arguments".to_string()],
..Default::default()
},
};
let result = try_tool_call_parse(input, &config).unwrap().unwrap();
let (name, args) = extract_name_and_args(result);
assert_eq!(name, "get_weather");
assert_eq!(args["location"], "San Francisco, CA");
assert_eq!(args["unit"], "fahrenheit");
}
}
// SPDX-FileCopyrightText: Copyright (c) 2024-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
// SPDX-License-Identifier: Apache-2.0
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
// #[cfg_attr(feature = "pyo3_macros", pyo3::pyclass(eq, eq_int))]
// #[cfg_attr(feature = "pyo3_macros", pyo3(get_all))]
......
// SPDX-FileCopyrightText: Copyright (c) 2024-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
// SPDX-License-Identifier: Apache-2.0
pub use super::response::*;
pub use crate::preprocessor::tools::request::*;
// Import json_parser from postprocessor module
pub use super::json_parser::*;
pub use super::parsers::{try_tool_call_parse, ToolCallConfig};
/// Try parsing a string as a structured tool call, for aggregation usage.
///
/// If successful, returns a `ChatCompletionMessageToolCall`.
pub fn try_tool_call_parse_aggregate(
message: &str,
) -> anyhow::Result<Option<async_openai::types::ChatCompletionMessageToolCall>> {
let config = ToolCallConfig::default();
let parsed = try_tool_call_parse(message, &config)?;
if let Some(parsed) = parsed {
Ok(Some(async_openai::types::ChatCompletionMessageToolCall {
id: parsed.id,
r#type: async_openai::types::ChatCompletionToolType::Function,
function: async_openai::types::FunctionCall {
name: parsed.function.name,
arguments: parsed.function.arguments,
},
}))
} else {
Ok(None)
}
}
/// Try parsing a string as a structured tool call, for streaming (delta) usage.
///
/// If successful, returns a `ChatCompletionMessageToolCallChunk`.
pub fn try_tool_call_parse_stream(
message: &str,
) -> anyhow::Result<Option<async_openai::types::ChatCompletionMessageToolCallChunk>> {
let config = ToolCallConfig::default();
let parsed = try_tool_call_parse(message, &config)?;
if let Some(parsed) = parsed {
Ok(Some(
async_openai::types::ChatCompletionMessageToolCallChunk {
index: 0,
id: Some(parsed.id),
r#type: Some(async_openai::types::ChatCompletionToolType::Function),
function: Some(async_openai::types::FunctionCallStream {
name: Some(parsed.function.name),
arguments: Some(parsed.function.arguments),
}),
},
))
} else {
Ok(None)
}
}
// SPDX-FileCopyrightText: Copyright (c) 2024-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
// SPDX-License-Identifier: Apache-2.0
pub mod request;
// SPDX-FileCopyrightText: Copyright (c) 2024-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
// SPDX-License-Identifier: Apache-2.0
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
use std::collections::HashMap;
......
......@@ -164,7 +164,9 @@ impl DeltaAggregator {
for choice in aggregator.choices.values_mut() {
if choice.tool_calls.is_none() {
if let Ok(Some(tool_call)) =
crate::preprocessor::tools::try_parse_tool_call_aggregate(&choice.text)
crate::postprocessor::tool_calling::tools::try_tool_call_parse_aggregate(
&choice.text,
)
{
tracing::debug!(
tool_call_id = %tool_call.id,
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment