Unverified commit 5c197b39, authored by Ayush Agarwal and committed by GitHub
Browse files

feat: add deepseek_v3_1 tool parser with lib refactoring (#2832)


Signed-off-by: ayushag <ayushag@nvidia.com>
parent c6becbc8
// SPDX-FileCopyrightText: Copyright (c) 2024-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
// SPDX-License-Identifier: Apache-2.0
use super::json::JsonParserType;
/// Represents the format type for tool calls
#[derive(Clone, Debug, serde::Serialize, serde::Deserialize)]
pub enum ToolCallParserType {
......@@ -33,6 +35,10 @@ pub struct JsonParserConfig {
/// i.e. `{"name": "function", "arguments": {...}}` it would be
/// "arguments"
pub arguments_keys: Vec<String>,
/// The type of JSON parser to use
#[serde(default)]
pub parser_type: JsonParserType,
}
impl Default for JsonParserConfig {
......@@ -44,6 +50,7 @@ impl Default for JsonParserConfig {
tool_call_end_tokens: vec!["</TOOLCALL>".to_string(), "".to_string()],
function_name_keys: vec!["name".to_string()],
arguments_keys: vec!["arguments".to_string(), "parameters".to_string()],
parser_type: JsonParserType::Basic,
}
}
}
......@@ -145,4 +152,16 @@ impl ToolCallConfig {
},
}
}
pub fn deepseek_v3_1() -> Self {
Self {
format: ToolCallParserType::Json,
json: JsonParserConfig {
tool_call_start_tokens: vec!["<|tool▁calls▁begin|>".to_string()],
tool_call_end_tokens: vec!["<|tool▁calls▁end|>".to_string()],
parser_type: JsonParserType::DeepseekV31,
..Default::default()
},
}
}
}
// SPDX-FileCopyrightText: Copyright (c) 2024-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
// SPDX-License-Identifier: Apache-2.0
pub mod harmony_parser;
pub use super::{config, response};
pub use harmony_parser::parse_tool_calls_harmony;
......@@ -149,7 +149,7 @@ fn try_parse_normal_text(input: &str, start_token: &str) -> String {
/// let result = try_tool_call_parse_json(input)?;
/// assert!(result.is_some());
/// ```
pub fn try_tool_call_parse_json(
pub fn try_tool_call_parse_basic_json(
message: &str,
config: &JsonParserConfig,
) -> anyhow::Result<(Vec<ToolCallResponse>, Option<String>)> {
......
// SPDX-FileCopyrightText: Copyright (c) 2024-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
// SPDX-License-Identifier: Apache-2.0
use regex::Regex;
use serde_json::Value;
use std::sync::OnceLock;
use super::config::JsonParserConfig;
use super::response::{CalledFunction, ToolCallResponse, ToolCallType};
static DEEPSEEK_V3_1_OUTER_REGEX: OnceLock<Regex> = OnceLock::new();
static DEEPSEEK_V3_1_INNER_REGEX: OnceLock<Regex> = OnceLock::new();
/// Returns the lazily compiled regex that matches one complete tool-call block,
/// from `<|tool▁call▁begin|>` up to the nearest following `<|tool▁call▁end|>`.
///
/// The regex is compiled once on first use and cached for the process lifetime.
///
/// # Panics
///
/// Panics on first use if the hard-coded pattern fails to compile.
pub fn get_deepseek_v3_1_outer_regex() -> &'static Regex {
    DEEPSEEK_V3_1_OUTER_REGEX.get_or_init(|| {
        // `|` is the alternation operator in regex syntax. Left unescaped, the
        // pattern splits into unrelated alternatives (`<`, `tool▁call▁begin`, ...)
        // instead of matching the delimiter tokens, so every bar is escaped.
        // `(?s)` lets `.` cross newlines; `.*?` is lazy so each block stops at
        // its own end token rather than the last one in the message.
        // NOTE(review): assumes the model emits ASCII `|` in these tokens, as the
        // configured start/end tokens in this crate do — confirm against the tokenizer.
        Regex::new(r"(?s)<\|tool▁call▁begin\|>.*?<\|tool▁call▁end\|>")
            .expect("Failed to compile deepseek v3.1 outer regex pattern")
    })
}
/// Returns the lazily compiled regex that dissects a single tool-call block.
///
/// Capture group 1 is the text between `<|tool▁call▁begin|>` and `<|tool▁sep|>`
/// (the function name); capture group 2 is the text between `<|tool▁sep|>` and
/// `<|tool▁call▁end|>` (the raw arguments).
///
/// The regex is compiled once on first use and cached for the process lifetime.
///
/// # Panics
///
/// Panics on first use if the hard-coded pattern fails to compile.
pub fn get_deepseek_v3_1_inner_regex() -> &'static Regex {
    DEEPSEEK_V3_1_INNER_REGEX.get_or_init(|| {
        // `|` is the alternation operator in regex syntax. Left unescaped, the two
        // capture groups land in different alternatives and can never both
        // participate in one match, so every bar in the tokens is escaped.
        // `(?s)` lets `.` cross newlines so multi-line JSON arguments are captured.
        // NOTE(review): assumes the model emits ASCII `|` in these tokens, as the
        // configured start/end tokens in this crate do — confirm against the tokenizer.
        Regex::new(r"(?s)<\|tool▁call▁begin\|>(.*?)<\|tool▁sep\|>(.*?)<\|tool▁call▁end\|>")
            .expect("Failed to compile deepseek v3.1 inner regex pattern")
    })
}
/// Parses DeepSeek V3.1 style tool calls out of a model response.
///
/// Expected layout (the `<|tool▁call▁begin|>…<|tool▁call▁end|>` section may repeat):
///
/// `<|tool▁calls▁begin|><|tool▁call▁begin|>{function_name}<|tool▁sep|>{json_arguments}<|tool▁call▁end|><|tool▁calls▁end|><|end▁of▁sentence|>`
///
/// Only the first entry of `config.tool_call_start_tokens` is consulted.
///
/// # Returns
///
/// * `(vec![], Some(trimmed_message))` when the message is empty, no start
///   token is configured, the start token is absent, or no call could be
///   parsed (for example because the arguments are not valid JSON).
/// * `(calls, Some(prefix))` otherwise, where `prefix` is the text that
///   precedes the first start token and each call is given the id `call-N`,
///   numbered from 1 in order of appearance.
///
/// # Errors
///
/// Returns an error only if re-serializing already parsed arguments fails.
pub fn parse_tool_calls_deepseek_v3_1(
    message: &str,
    config: &JsonParserConfig,
) -> anyhow::Result<(Vec<ToolCallResponse>, Option<String>)> {
    let trimmed = message.trim();

    // A single guard replaces the former `is_empty()` check, the unreachable
    // `else` branch, and the later `unwrap()` on the same token.
    let Some(start_token) = config.tool_call_start_tokens.first() else {
        return Ok((vec![], Some(trimmed.to_string())));
    };

    // No content, or no start token in the content: everything is normal text.
    if trimmed.is_empty() || !trimmed.contains(start_token.as_str()) {
        return Ok((vec![], Some(trimmed.to_string())));
    }

    let outer_re = get_deepseek_v3_1_outer_regex();
    let inner_re = get_deepseek_v3_1_inner_regex();
    let mut tool_calls: Vec<ToolCallResponse> = Vec::new();

    // Two passes: the outer regex isolates each tool-call block, the inner
    // regex splits a block into function name and arguments.
    for block in outer_re.find_iter(trimmed) {
        for caps in inner_re.captures_iter(block.as_str()) {
            let (Some(name_match), Some(args_match)) = (caps.get(1), caps.get(2)) else {
                continue; // Skip blocks missing either the name or the arguments.
            };
            let Ok(arguments) = serde_json::from_str::<Value>(args_match.as_str()) else {
                continue; // Skip calls whose arguments are not valid JSON.
            };

            // Ids count only the calls that were actually accepted.
            let id = format!("call-{}", tool_calls.len() + 1);
            tool_calls.push(ToolCallResponse {
                id,
                tp: ToolCallType::Function,
                function: CalledFunction {
                    name: name_match.as_str().to_string(),
                    arguments: serde_json::to_string(&arguments)?,
                },
            });
        }
    }

    // Nothing parsed (invalid JSON or malformed blocks): hand the whole message
    // back as normal text rather than dropping it.
    if tool_calls.is_empty() {
        return Ok((vec![], Some(trimmed.to_string())));
    }

    // Normal text is whatever precedes the first start token.
    let normal_text = trimmed
        .split_once(start_token.as_str())
        .map_or(trimmed, |(before, _)| before);

    Ok((tool_calls, Some(normal_text.to_string())))
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Splits a parsed call into its function name and decoded JSON arguments.
    fn extract_name_and_args(call: ToolCallResponse) -> (String, serde_json::Value) {
        let decoded: serde_json::Value = serde_json::from_str(&call.function.arguments).unwrap();
        (call.function.name, decoded)
    }

    /// Parser configuration shared by every test: DeepSeek V3.1 start/end
    /// tokens, with all remaining fields left at their defaults.
    fn deepseek_config() -> JsonParserConfig {
        JsonParserConfig {
            tool_call_start_tokens: vec!["<|tool▁calls▁begin|>".to_string()],
            tool_call_end_tokens: vec!["<|tool▁calls▁end|>".to_string()],
            ..Default::default()
        }
    }

    /// Runs the parser under test with the shared configuration.
    fn parse(text: &str) -> (Vec<ToolCallResponse>, Option<String>) {
        parse_tool_calls_deepseek_v3_1(text, &deepseek_config()).unwrap()
    }

    #[test]
    fn test_parse_tool_calls_deepseek_v3_1_basic() {
        let text = r#"<|tool▁calls▁begin|><|tool▁call▁begin|>get_current_weather<|tool▁sep|>{"location": "Tokyo"}<|tool▁call▁end|><|tool▁call▁begin|>get_current_weather<|tool▁sep|>{"location": "Paris"}<|tool▁call▁end|><|tool▁calls▁end|><|end▁of▁sentence|>"#;

        let (calls, normal_text) = parse(text);
        assert_eq!(normal_text, Some("".to_string()));
        assert_eq!(calls.len(), 2);

        let mut calls = calls.into_iter();

        let (name, args) = extract_name_and_args(calls.next().unwrap());
        assert_eq!(name, "get_current_weather");
        assert_eq!(args["location"], "Tokyo");

        let (name, args) = extract_name_and_args(calls.next().unwrap());
        assert_eq!(name, "get_current_weather");
        assert_eq!(args["location"], "Paris");
    }

    #[test]
    fn test_parse_tool_calls_deepseek_v3_1_with_normal_text() {
        let text = r#"The following tool call retrieves weather information: <|tool▁calls▁begin|><|tool▁call▁begin|>get_current_weather<|tool▁sep|>{"location": "New York"}<|tool▁call▁end|><|tool▁calls▁end|><|end▁of▁sentence|>"#;

        let (calls, normal_text) = parse(text);
        assert_eq!(
            normal_text,
            Some("The following tool call retrieves weather information: ".to_string())
        );
        assert_eq!(calls.len(), 1);

        let mut calls = calls.into_iter();

        let (name, args) = extract_name_and_args(calls.next().unwrap());
        assert_eq!(name, "get_current_weather");
        assert_eq!(args["location"], "New York");
    }

    #[test]
    fn test_parse_tool_calls_deepseek_v3_1_without_tool_call_start_token() {
        // The calls-begin token is absent, so the input must come back untouched.
        let text = r#"<|tool▁call▁begin|>get_current_weather宽带}{location": "Tokyo"}<|tool▁call▁end|><|tool▁calls▁end|>"#;

        let (calls, normal_text) = parse(text);
        assert_eq!(normal_text, Some(text.to_string()));
        assert_eq!(calls.len(), 0);
    }

    #[test]
    fn test_parse_tool_calls_deepseek_v3_1_with_multi_tool_calls_with_multiple_args() {
        let text = r#"<|tool▁calls▁begin|><|tool▁call▁begin|>get_current_weather<|tool▁sep|>{"location": "Berlin", "units": "metric"}<|tool▁call▁end|><|tool▁call▁begin|>get_weather_forecast<|tool▁sep|>{"location": "Berlin", "days": 7, "units": "imperial"}<|tool▁call▁end|><|tool▁call▁begin|>get_air_quality<|tool▁sep|>{"location": "Berlin", "radius": 50}<|tool▁call▁end|><|tool▁calls▁end|><|end▁of▁sentence|>"#;

        let (calls, normal_text) = parse(text);
        assert_eq!(normal_text, Some("".to_string()));
        assert_eq!(calls.len(), 3);

        let mut calls = calls.into_iter();

        let (name, args) = extract_name_and_args(calls.next().unwrap());
        assert_eq!(name, "get_current_weather");
        assert_eq!(args["location"], "Berlin");
        assert_eq!(args["units"], "metric");

        let (name, args) = extract_name_and_args(calls.next().unwrap());
        assert_eq!(name, "get_weather_forecast");
        assert_eq!(args["location"], "Berlin");
        assert_eq!(args["days"], 7);
        assert_eq!(args["units"], "imperial");

        let (name, args) = extract_name_and_args(calls.next().unwrap());
        assert_eq!(name, "get_air_quality");
        assert_eq!(args["location"], "Berlin");
        assert_eq!(args["radius"], 50);
    }

    #[test]
    fn test_parse_tool_calls_deepseek_v3_1_with_invalid_json() {
        // Invalid JSON arguments: the whole message is treated as normal text.
        let text = r#"<|tool▁calls▁begin|><|tool▁call▁begin|>get_current_weather}{location": "Tokyo"}<|tool▁call▁end|><|tool▁calls▁end|>"#;

        let (calls, normal_text) = parse(text);
        assert_eq!(normal_text, Some(text.trim().to_string()));
        assert_eq!(calls.len(), 0);
    }

    #[test]
    fn test_parse_tool_calls_deepseek_v3_1_with_multi_tool_calls_with_normal_text() {
        // Every call is malformed, so the whole message (leading prose included)
        // is treated as normal text.
        let text = r#"The following tool calls retrieve weather information: <|tool▁calls▁begin|><|tool▁call▁begin|>get_current_weather宽带}{location": "Tokyo"}<|tool▁call▁end|><|tool▁call▁begin|>get_weather_forecast宽带}{location": "Berlin", "days": 7, "units": "imperial"}<|tool▁call▁end|><|tool▁call▁begin|>get_air_quality宽带}{location": "Berlin", "radius": 50}<|tool▁call▁end|><|tool▁calls▁end|>"#;

        let (calls, normal_text) = parse(text);
        assert_eq!(normal_text, Some(text.trim().to_string()));
        assert_eq!(calls.len(), 0);
    }
}
// SPDX-FileCopyrightText: Copyright (c) 2024-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
// SPDX-License-Identifier: Apache-2.0
pub mod base_json_parser;
pub mod deepseek_parser;
pub use super::{config, response};
pub use base_json_parser::try_tool_call_parse_basic_json;
pub use deepseek_parser::parse_tool_calls_deepseek_v3_1;
pub use super::config::JsonParserConfig;
pub use super::response::ToolCallResponse;
/// Selects which JSON tool-call parser implementation handles a message.
///
/// `Default` is derived rather than hand-written; `Basic` is marked as the
/// default variant, which is what the previous manual impl returned.
#[derive(Clone, Debug, Default, serde::Serialize, serde::Deserialize)]
pub enum JsonParserType {
    /// Generic JSON parser which can handle most of the cases.
    #[default]
    Basic,
    /// Model-specific JSON parser for DeepSeek V3.1 output.
    DeepseekV31,
}
pub fn try_tool_call_parse_json(
message: &str,
config: &JsonParserConfig,
) -> anyhow::Result<(Vec<ToolCallResponse>, Option<String>)> {
match config.parser_type {
JsonParserType::Basic => try_tool_call_parse_basic_json(message, config),
JsonParserType::DeepseekV31 => parse_tool_calls_deepseek_v3_1(message, config),
}
}
......@@ -2,15 +2,18 @@
// SPDX-License-Identifier: Apache-2.0
pub mod config;
pub mod harmony_parser;
pub mod json_parser;
pub mod harmony;
pub mod json;
pub mod parsers;
pub mod pythonic_parser;
pub mod pythonic;
pub mod response;
pub mod tools;
// Re-export main types and functions for convenience
pub use config::{JsonParserConfig, ToolCallConfig, ToolCallParserType};
pub use harmony::parse_tool_calls_harmony;
pub use json::try_tool_call_parse_json;
pub use parsers::{detect_and_parse_tool_call, try_tool_call_parse};
pub use pythonic::try_tool_call_parse_pythonic;
pub use response::{CalledFunction, ToolCallResponse, ToolCallType};
pub use tools::{try_tool_call_parse_aggregate, try_tool_call_parse_stream};
......@@ -2,9 +2,9 @@
// SPDX-License-Identifier: Apache-2.0
use super::config::{ToolCallConfig, ToolCallParserType};
use super::harmony_parser::parse_tool_calls_harmony;
use super::json_parser::try_tool_call_parse_json;
use super::pythonic_parser::try_tool_call_parse_pythonic;
use super::harmony::parse_tool_calls_harmony;
use super::json::try_tool_call_parse_json;
use super::pythonic::try_tool_call_parse_pythonic;
use super::response::ToolCallResponse;
use std::collections::HashMap;
use std::sync::OnceLock;
......@@ -22,6 +22,7 @@ pub fn get_tool_parser_map() -> &'static HashMap<&'static str, ToolCallConfig> {
map.insert("phi4", ToolCallConfig::phi4());
map.insert("pythonic", ToolCallConfig::pythonic());
map.insert("harmony", ToolCallConfig::harmony());
map.insert("deepseek_v3_1", ToolCallConfig::deepseek_v3_1());
map.insert("default", ToolCallConfig::default());
map
})
......@@ -111,6 +112,7 @@ mod tests {
"phi4",
"default",
"pythonic",
"deepseek_v3_1",
];
for parser in available_parsers {
assert!(parsers.contains(&parser));
......@@ -1170,4 +1172,18 @@ Remember, San Francisco weather can be quite unpredictable, particularly with it
assert_eq!(args["location"], "San Francisco");
assert_eq!(args["unit"], "fahrenheit");
}
#[test]
fn test_deepseek_v3_1_parser_basic() {
    // Two back-to-back calls to the same function, dispatched through the
    // registry entry named "deepseek_v3_1".
    let input = r#"<|tool▁calls▁begin|><|tool▁call▁begin|>get_current_weather<|tool▁sep|>{"location": "Tokyo"}<|tool▁call▁end|><|tool▁call▁begin|>get_current_weather<|tool▁sep|>{"location": "Paris"}<|tool▁call▁end|><|tool▁calls▁end|><|end▁of▁sentence|>"#;

    let (calls, normal_text) = detect_and_parse_tool_call(input, Some("deepseek_v3_1")).unwrap();

    // Nothing precedes the calls-begin token, so no normal text remains.
    assert_eq!(normal_text, Some("".to_string()));
    assert_eq!(calls.len(), 2);

    // Calls must come back in order of appearance.
    for (call, expected_location) in calls.into_iter().zip(["Tokyo", "Paris"]) {
        let (name, args) = extract_name_and_args(call);
        assert_eq!(name, "get_current_weather");
        assert_eq!(args["location"], expected_location);
    }
}
}
// SPDX-FileCopyrightText: Copyright (c) 2024-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
// SPDX-License-Identifier: Apache-2.0
pub mod pythonic_parser;
pub use super::{config, response};
pub use pythonic_parser::try_tool_call_parse_pythonic;
......@@ -22,6 +22,7 @@ fn get_pythonic_regex() -> &'static Regex {
Regex::new(pattern).expect("Failed to compile pythonic regex pattern")
})
}
fn strip_text(message: &str) -> String {
// Remove unexpected python tags if any
message
......@@ -31,7 +32,6 @@ fn strip_text(message: &str) -> String {
fn get_regex_matches(message: &str) -> Vec<String> {
let re = get_pythonic_regex();
let mut matches = Vec::new();
for cap in re.find_iter(message) {
matches.push(cap.as_str().to_string());
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment