Unverified Commit 44e8600a authored by William Zhang's avatar William Zhang Committed by GitHub
Browse files

refactor: New config types for tool calls (#4575)

* Why?

We would like the ability to configure different parser types. Prior to
this commit, only the JSON parser could be configured.

* What?

This commit refactors the tool parser config in the following ways:
- the `format` and `json` fields of `ToolParserConfig` are merged into
  a single `config` field that is a "discriminated union" type. Each
  parser type can declare its own configuration options.
- a `XmlParserConfig` is defined with a default factory method that
  corresponds to the Qwen3 coder configuration.
- affected calls and tests are adjusted.
parent 262cce76
...@@ -895,14 +895,14 @@ impl JailedStreamBuilder { ...@@ -895,14 +895,14 @@ impl JailedStreamBuilder {
if let Some(config) = parser_map.get(parser_name.as_str()) { if let Some(config) = parser_map.get(parser_name.as_str()) {
// Auto-populate start sequences if none configured // Auto-populate start sequences if none configured
if self.jail_start_sequences.is_empty() { if self.jail_start_sequences.is_empty() {
self.jail_start_sequences = config.json.tool_call_start_tokens.clone(); self.jail_start_sequences = config.parser_config.tool_call_start_tokens();
} }
// Auto-populate end sequences if none configured // Auto-populate end sequences if none configured
if self.jail_end_sequences.is_empty() { if self.jail_end_sequences.is_empty() {
self.jail_end_sequences = config self.jail_end_sequences = config
.json .parser_config
.tool_call_end_tokens .tool_call_end_tokens()
.iter() .iter()
.filter(|&s| !s.is_empty()) .filter(|&s| !s.is_empty())
.cloned() .cloned()
...@@ -922,7 +922,7 @@ impl JailedStreamBuilder { ...@@ -922,7 +922,7 @@ impl JailedStreamBuilder {
let parser_map = get_tool_parser_map(); let parser_map = get_tool_parser_map();
if let Some(config) = parser_map.get(parser_name.as_str()) { if let Some(config) = parser_map.get(parser_name.as_str()) {
// Add start tokens from the parser config // Add start tokens from the parser config
all_patterns.extend(config.json.tool_call_start_tokens.clone()); all_patterns.extend(config.parser_config.tool_call_start_tokens());
} }
} }
......
...@@ -1989,6 +1989,239 @@ mod tests { ...@@ -1989,6 +1989,239 @@ mod tests {
} }
} }
#[tokio::test]
async fn test_jailed_stream_qwen3_coder_parser() {
// Input:
// "I'll call a function. "
// + "<tool_call><function=get_weather><parameter=location>San Francisco</parameter><parameter=unit>celsius</parameter></function></tool_call>"
// + " Done."
// Expected output: 3 chunks [Content(), ToolCall(), Content()]
let chunks = vec![
create_mock_response_chunk("I'll call a function. ".to_string(), 0),
create_mock_response_chunk("<tool_call>".to_string(), 0),
create_mock_response_chunk("<function=get_weather>".to_string(), 0),
create_mock_response_chunk(
"<parameter=location>San Francisco</parameter>".to_string(),
0,
),
create_mock_response_chunk("<parameter=unit>celsius</parameter>".to_string(), 0),
create_mock_response_chunk("</function>".to_string(), 0),
create_mock_response_chunk("</tool_call>".to_string(), 0),
create_mock_response_chunk(" Done.".to_string(), 0),
];
let input_stream = stream::iter(chunks);
let jail = JailedStream::builder()
.tool_call_parser("qwen3_coder")
.build();
let results: Vec<_> = jail.apply_with_finish_reason(input_stream).collect().await;
assert_eq!(
results.len(),
3,
"Should have content, tool call, and trailing content"
);
// Verify exact output structure: [Content(), ToolCall(), Content()].
test_utils::assert_content(&results[0], "I'll call a function. ");
test_utils::assert_tool_call(
&results[1],
"get_weather",
serde_json::json!({"location": "San Francisco", "unit": "celsius"}),
);
test_utils::assert_content(&results[2], " Done.");
// Verify content reconstruction excludes tool calls.
let reconstructed = test_utils::reconstruct_content(&results);
assert_eq!(reconstructed, "I'll call a function. Done.");
}
#[tokio::test]
async fn test_jailed_stream_qwen3_coder_multiple_params() {
let chunks = vec![
create_mock_response_chunk("Let me search for that. ".to_string(), 0),
create_mock_response_chunk(
"<tool_call><function=web_search><parameter=query>Rust programming</parameter><parameter=max_results>10</parameter><parameter=filter>recent</parameter></function></tool_call>".to_string(),
0,
),
create_mock_response_chunk(" Searching now.".to_string(), 0),
];
let input_stream = stream::iter(chunks);
let jail = JailedStream::builder()
.tool_call_parser("qwen3_coder")
.build();
let results: Vec<_> = jail.apply_with_finish_reason(input_stream).collect().await;
assert_eq!(results.len(), 3, "Should have 3 chunks");
test_utils::assert_content(&results[0], "Let me search for that. ");
test_utils::assert_tool_call(
&results[1],
"web_search",
serde_json::json!({
"query": "Rust programming",
"max_results": 10,
"filter": "recent"
}),
);
test_utils::assert_content(&results[2], " Searching now.");
}
#[tokio::test]
async fn test_jailed_stream_xml_parser_config_tokens_auto_population() {
// Tests that parser config tokens are auto-populated when using `.tool_call_parser()`.
// This verifies the jail system reads `tool_call_start_token` and `tool_call_end_token`
// from the `qwen3_coder` parser config.
let chunks = vec![
create_mock_response_chunk("Before tool call. ".to_string(), 0),
create_mock_response_chunk("<tool_call>".to_string(), 0), // Default qwen3_coder token
create_mock_response_chunk("<function=get_weather>".to_string(), 0),
create_mock_response_chunk("<parameter=city>Seattle</parameter>".to_string(), 0),
create_mock_response_chunk("</function>".to_string(), 0),
create_mock_response_chunk("</tool_call>".to_string(), 0), // Default qwen3_coder token
create_mock_response_chunk(" After tool call.".to_string(), 0),
];
let input_stream = stream::iter(chunks);
// Create JailedStream using ONLY `.tool_call_parser()`.
// This should auto-populate jail sequences from the qwen3_coder config
let jail = JailedStream::builder()
.tool_call_parser("qwen3_coder")
.build();
let results: Vec<_> = jail.apply_with_finish_reason(input_stream).collect().await;
assert_eq!(
results.len(),
3,
"Should have content, tool call, and trailing content"
);
test_utils::assert_content(&results[0], "Before tool call. ");
test_utils::assert_tool_call(
&results[1],
"get_weather",
serde_json::json!({"city": "Seattle"}),
);
test_utils::assert_content(&results[2], " After tool call.");
let reconstructed = test_utils::reconstruct_content(&results);
assert_eq!(reconstructed, "Before tool call. After tool call.");
}
#[tokio::test]
async fn test_jailed_stream_xml_manual_sequences_prevent_auto_population() {
// Tests that manually setting jail sequences prevents auto-population.
// This verifies the builder respects manual configuration over auto-population.
//
// When custom sequences are set, the default parser tokens (<tool_call>) should
// NOT trigger jailing and should pass through as regular content.
let chunks = vec![
create_mock_response_chunk("Text with ".to_string(), 0),
// Default qwen3_coder token - should NOT trigger jailing.
create_mock_response_chunk("<tool_call>".to_string(), 0),
create_mock_response_chunk("should not jail".to_string(), 0),
create_mock_response_chunk("</tool_call>".to_string(), 0),
create_mock_response_chunk(" because custom ".to_string(), 0),
// Custom marker - this SHOULD trigger jailing since we register it below.
create_mock_response_chunk("[[START]]".to_string(), 0),
create_mock_response_chunk("jailed content".to_string(), 0),
create_mock_response_chunk("[[END]]".to_string(), 0),
create_mock_response_chunk(" text.".to_string(), 0),
];
let input_stream = stream::iter(chunks);
// Set custom jail sequences - this should prevent auto-population.
// The default <tool_call> tokens should NOT trigger jailing.
let jail = JailedStream::builder()
.jail_start_sequence("[[START]]")
.jail_end_sequence("[[END]]")
.tool_call_parser("qwen3_coder")
.build();
let results: Vec<_> = jail.apply_with_finish_reason(input_stream).collect().await;
// The exact number of chunks depends on emission mode (packed vs single-choice-per-chunk)
// but we can verify the key behaviors:
// 1. Default <tool_call> tokens pass through as content (not jailed)
// 2. Custom [[START]]/[[END]] markers trigger jailing
// 3. No tool calls are extracted (because jailed content isn't valid XML)
// Find chunk(s) containing the default tokens that passed through.
let default_token_chunks: Vec<_> = results
.iter()
.filter_map(|r| {
r.data
.as_ref()
.and_then(|d| d.choices.first())
.and_then(|c| c.delta.content.as_ref())
})
.filter(|content| {
content.contains("<tool_call>") || content.contains("should not jail")
})
.collect();
assert!(
!default_token_chunks.is_empty(),
"Default <tool_call> should pass through as content when manual sequences are set"
);
// Find chunk containing the jailed content that was released.
let jailed_chunk = results
.iter()
.filter_map(|r| {
r.data
.as_ref()
.and_then(|d| d.choices.first())
.and_then(|c| c.delta.content.as_ref())
})
.find(|content| content.contains("[[START]]") && content.contains("jailed content"));
assert!(
jailed_chunk.is_some(),
"Custom markers should trigger jailing and accumulated content should be released"
);
// Since the custom markers include non-XML content, the parser should not extract tool calls.
// The accumulated content "[[START]]jailed content[[END]]", although compatible with the
// way we configured `jail` above, is not consistent with what `qwen_coder` expects, and
// there is (at time of writing) no way to pass a parser instance - only a string that
// internally gets mapped to default way of instantiating a particular parser.
let tool_call_count = results
.iter()
.filter(|r| {
r.data
.as_ref()
.and_then(|d| d.choices.first())
.and_then(|c| c.delta.tool_calls.as_ref())
.map(|tc| !tc.is_empty())
.unwrap_or(false)
})
.count();
assert_eq!(
tool_call_count, 0,
"Should have 0 tool calls because jailed content doesn't match XML format"
);
// Verify content reconstruction - all original content should be preserved.
let reconstructed = test_utils::reconstruct_content(&results);
assert!(
reconstructed.contains("<tool_call>") && reconstructed.contains("should not jail"),
"Reconstructed content should include default tokens that passed through"
);
assert!(
reconstructed.contains("[[START]]") && reconstructed.contains("jailed content"),
"Reconstructed content should include jailed content with custom markers"
);
}
#[tokio::test] #[tokio::test]
async fn test_jailed_stream_mistral_false_positive_curly() { async fn test_jailed_stream_mistral_false_positive_curly() {
// Curly brace in normal text should not trigger tool call detection for mistral // Curly brace in normal text should not trigger tool call detection for mistral
......
...@@ -3,20 +3,6 @@ ...@@ -3,20 +3,6 @@
use super::json::JsonParserType; use super::json::JsonParserType;
/// Represents the format type for tool calls
#[derive(Clone, Debug, serde::Serialize, serde::Deserialize)]
pub enum ToolCallParserType {
/// JSON format: `{"name": "function", "arguments": {...}}`
Json,
Pythonic,
Harmony,
/// <function_call>```typescript
/// functions.get_current_weather({"location": "Shanghai"})
/// ```
Typescript,
Xml,
}
#[derive(Clone, Debug, serde::Serialize, serde::Deserialize)] #[derive(Clone, Debug, serde::Serialize, serde::Deserialize)]
pub struct JsonParserConfig { pub struct JsonParserConfig {
/// Start token for individual tool calls (e.g., "<TOOLCALL>") /// Start token for individual tool calls (e.g., "<TOOLCALL>")
...@@ -54,21 +40,83 @@ impl Default for JsonParserConfig { ...@@ -54,21 +40,83 @@ impl Default for JsonParserConfig {
} }
} }
#[derive(Clone, Debug, serde::Serialize, serde::Deserialize)]
pub struct XmlParserConfig {
/// Start token for individual tool calls (e.g., "<tool_call>")
pub tool_call_start_token: String,
/// End token for individual tool calls (e.g., "</tool_call>")
pub tool_call_end_token: String,
/// Start token for function name (e.g., "<function=")
pub function_start_token: String,
/// End token for function (e.g., "</function>")
pub function_end_token: String,
/// Start token for parameter (e.g., "<parameter=")
pub parameter_start_token: String,
/// End token for parameter (e.g., "</parameter>")
pub parameter_end_token: String,
}
impl Default for XmlParserConfig {
fn default() -> Self {
Self {
tool_call_start_token: "<tool_call>".to_string(),
tool_call_end_token: "</tool_call>".to_string(),
function_start_token: "<function=".to_string(),
function_end_token: "</function>".to_string(),
parameter_start_token: "<parameter=".to_string(),
parameter_end_token: "</parameter>".to_string(),
}
}
}
/// Parser-specific configuration
#[derive(Clone, Debug, serde::Serialize, serde::Deserialize)]
#[serde(tag = "type", rename_all = "snake_case")]
pub enum ParserConfig {
Json(JsonParserConfig),
Xml(XmlParserConfig),
Pythonic,
Harmony(JsonParserConfig),
Typescript,
}
impl ParserConfig {
/// Get the tool call start tokens for this parser configuration
/// Returns a vector of start tokens that indicate the beginning of a tool call
pub fn tool_call_start_tokens(&self) -> Vec<String> {
match self {
ParserConfig::Json(config) => config.tool_call_start_tokens.clone(),
ParserConfig::Harmony(config) => config.tool_call_start_tokens.clone(),
ParserConfig::Xml(config) => vec![config.tool_call_start_token.clone()],
ParserConfig::Pythonic => vec![],
ParserConfig::Typescript => vec![],
}
}
/// Get the tool call end tokens for this parser configuration
/// Returns a vector of end tokens that indicate the end of a tool call
pub fn tool_call_end_tokens(&self) -> Vec<String> {
match self {
ParserConfig::Json(config) => config.tool_call_end_tokens.clone(),
ParserConfig::Harmony(config) => config.tool_call_end_tokens.clone(),
ParserConfig::Xml(config) => vec![config.tool_call_end_token.clone()],
ParserConfig::Pythonic => vec![],
ParserConfig::Typescript => vec![],
}
}
}
/// Configuration for parsing tool calls with different formats /// Configuration for parsing tool calls with different formats
// TODO(2ez4bz): refactor to allow other parser configs than `JsonParserConfig`.
#[derive(Clone, Debug, serde::Serialize, serde::Deserialize)] #[derive(Clone, Debug, serde::Serialize, serde::Deserialize)]
pub struct ToolCallConfig { pub struct ToolCallConfig {
/// The format type for tool calls /// Parser-specific configuration.
pub format: ToolCallParserType, pub parser_config: ParserConfig,
/// The config for the JSON parser
pub json: JsonParserConfig,
} }
impl Default for ToolCallConfig { impl Default for ToolCallConfig {
fn default() -> Self { fn default() -> Self {
Self { Self {
format: ToolCallParserType::Json, parser_config: ParserConfig::Json(JsonParserConfig::default()),
json: JsonParserConfig::default(),
} }
} }
} }
...@@ -78,12 +126,11 @@ impl ToolCallConfig { ...@@ -78,12 +126,11 @@ impl ToolCallConfig {
/// <tool_call>{"name": "get_weather", "arguments": {"location": "San Francisco, CA", "unit": "fahrenheit"}}\n</tool_call> /// <tool_call>{"name": "get_weather", "arguments": {"location": "San Francisco, CA", "unit": "fahrenheit"}}\n</tool_call>
pub fn hermes() -> Self { pub fn hermes() -> Self {
Self { Self {
format: ToolCallParserType::Json, parser_config: ParserConfig::Json(JsonParserConfig {
json: JsonParserConfig {
tool_call_start_tokens: vec!["<tool_call>".to_string()], tool_call_start_tokens: vec!["<tool_call>".to_string()],
tool_call_end_tokens: vec!["</tool_call>".to_string()], tool_call_end_tokens: vec!["</tool_call>".to_string()],
..Default::default() ..Default::default()
}, }),
} }
} }
...@@ -91,12 +138,11 @@ impl ToolCallConfig { ...@@ -91,12 +138,11 @@ impl ToolCallConfig {
/// <TOOLCALL>[{"name": "get_weather", "arguments": {"location": "San Francisco, CA", "unit": "fahrenheit"}}]</TOOLCALL> /// <TOOLCALL>[{"name": "get_weather", "arguments": {"location": "San Francisco, CA", "unit": "fahrenheit"}}]</TOOLCALL>
pub fn nemotron_deci() -> Self { pub fn nemotron_deci() -> Self {
Self { Self {
format: ToolCallParserType::Json, parser_config: ParserConfig::Json(JsonParserConfig {
json: JsonParserConfig {
tool_call_start_tokens: vec!["<TOOLCALL>".to_string()], tool_call_start_tokens: vec!["<TOOLCALL>".to_string()],
tool_call_end_tokens: vec!["</TOOLCALL>".to_string()], tool_call_end_tokens: vec!["</TOOLCALL>".to_string()],
..Default::default() ..Default::default()
}, }),
} }
} }
...@@ -104,52 +150,47 @@ impl ToolCallConfig { ...@@ -104,52 +150,47 @@ impl ToolCallConfig {
// <|python_tag|>{ "name": "get_weather", "arguments": {"location": "San Francisco, CA", "unit": "fahrenheit"} } // <|python_tag|>{ "name": "get_weather", "arguments": {"location": "San Francisco, CA", "unit": "fahrenheit"} }
// or { "name": "get_weather", "arguments": {"location": "San Francisco, CA", "unit": "fahrenheit"} } // or { "name": "get_weather", "arguments": {"location": "San Francisco, CA", "unit": "fahrenheit"} }
Self { Self {
format: ToolCallParserType::Json, parser_config: ParserConfig::Json(JsonParserConfig {
json: JsonParserConfig {
tool_call_start_tokens: vec!["<|python_tag|>".to_string()], tool_call_start_tokens: vec!["<|python_tag|>".to_string()],
tool_call_end_tokens: vec!["".to_string()], tool_call_end_tokens: vec!["".to_string()],
..Default::default() ..Default::default()
}, }),
} }
} }
pub fn mistral() -> Self { pub fn mistral() -> Self {
Self { Self {
format: ToolCallParserType::Json, parser_config: ParserConfig::Json(JsonParserConfig {
json: JsonParserConfig {
tool_call_start_tokens: vec!["[TOOL_CALLS]".to_string()], tool_call_start_tokens: vec!["[TOOL_CALLS]".to_string()],
tool_call_end_tokens: vec!["[/TOOL_CALLS]".to_string(), "".to_string()], tool_call_end_tokens: vec!["[/TOOL_CALLS]".to_string(), "".to_string()],
..Default::default() ..Default::default()
}, }),
} }
} }
pub fn phi4() -> Self { pub fn phi4() -> Self {
Self { Self {
format: ToolCallParserType::Json, parser_config: ParserConfig::Json(JsonParserConfig {
json: JsonParserConfig {
tool_call_start_tokens: vec!["functools".to_string()], tool_call_start_tokens: vec!["functools".to_string()],
tool_call_end_tokens: vec!["".to_string()], tool_call_end_tokens: vec!["".to_string()],
..Default::default() ..Default::default()
}, }),
} }
} }
pub fn pythonic() -> Self { pub fn pythonic() -> Self {
Self { Self {
format: ToolCallParserType::Pythonic, parser_config: ParserConfig::Pythonic,
json: JsonParserConfig::default(), // This is noop here, but we keep it for consistency
} }
} }
pub fn harmony() -> Self { pub fn harmony() -> Self {
Self { Self {
format: ToolCallParserType::Harmony, parser_config: ParserConfig::Harmony(JsonParserConfig {
json: JsonParserConfig {
tool_call_start_tokens: vec!["<|start|>assistant<|channel|>commentary".to_string()], tool_call_start_tokens: vec!["<|start|>assistant<|channel|>commentary".to_string()],
tool_call_end_tokens: vec!["<|call|>".to_string()], tool_call_end_tokens: vec!["<|call|>".to_string()],
..Default::default() ..Default::default()
}, }),
} }
} }
...@@ -161,8 +202,7 @@ impl ToolCallConfig { ...@@ -161,8 +202,7 @@ impl ToolCallConfig {
// so the tool parser can properly consume all tool call tokens. // so the tool parser can properly consume all tool call tokens.
// https://huggingface.co/deepseek-ai/DeepSeek-V3.1#toolcall // https://huggingface.co/deepseek-ai/DeepSeek-V3.1#toolcall
Self { Self {
format: ToolCallParserType::Json, parser_config: ParserConfig::Json(JsonParserConfig {
json: JsonParserConfig {
tool_call_start_tokens: vec![ tool_call_start_tokens: vec![
"<|tool▁calls▁begin|>".to_string(), "<|tool▁calls▁begin|>".to_string(),
// "<|tool▁call▁begin|>".to_string(), // "<|tool▁call▁begin|>".to_string(),
...@@ -174,7 +214,7 @@ impl ToolCallConfig { ...@@ -174,7 +214,7 @@ impl ToolCallConfig {
tool_call_separator_tokens: vec!["<|tool▁sep|>".to_string()], tool_call_separator_tokens: vec!["<|tool▁sep|>".to_string()],
parser_type: JsonParserType::DeepseekV31, parser_type: JsonParserType::DeepseekV31,
..Default::default() ..Default::default()
}, }),
} }
} }
...@@ -183,22 +223,20 @@ impl ToolCallConfig { ...@@ -183,22 +223,20 @@ impl ToolCallConfig {
// <|tool▁calls▁begin|><|tool▁call▁begin|>{type}<|tool▁sep|>{function_name}\n```json\n{arguments}\n```<|tool▁call▁end|><|tool▁calls▁end|> // <|tool▁calls▁begin|><|tool▁call▁begin|>{type}<|tool▁sep|>{function_name}\n```json\n{arguments}\n```<|tool▁call▁end|><|tool▁calls▁end|>
// There are some differences between DeepSeek V3 and DeepSeek V3.1 // There are some differences between DeepSeek V3 and DeepSeek V3.1
Self { Self {
format: ToolCallParserType::Json, parser_config: ParserConfig::Json(JsonParserConfig {
json: JsonParserConfig {
tool_call_start_tokens: vec!["<|tool▁calls▁begin|>".to_string()], tool_call_start_tokens: vec!["<|tool▁calls▁begin|>".to_string()],
tool_call_end_tokens: vec!["<|tool▁calls▁end|>".to_string()], tool_call_end_tokens: vec!["<|tool▁calls▁end|>".to_string()],
tool_call_separator_tokens: vec!["<|tool▁sep|>".to_string()], tool_call_separator_tokens: vec!["<|tool▁sep|>".to_string()],
parser_type: JsonParserType::DeepseekV3, parser_type: JsonParserType::DeepseekV3,
..Default::default() ..Default::default()
}, }),
} }
} }
pub fn qwen3_coder() -> Self { pub fn qwen3_coder() -> Self {
// <tool_call><function=name><parameter=key>value</parameter></function></tool_call> // <tool_call><function=name><parameter=key>value</parameter></function></tool_call>
Self { Self {
format: ToolCallParserType::Xml, parser_config: ParserConfig::Xml(XmlParserConfig::default()),
json: JsonParserConfig::default(), // Not used for qwen3_coder but kept for consistency.
} }
} }
} }
...@@ -271,7 +271,10 @@ mod tests { ...@@ -271,7 +271,10 @@ mod tests {
#[test] #[test]
fn test_parse_tool_calls_deepseek_v3_1_basic() { fn test_parse_tool_calls_deepseek_v3_1_basic() {
let text = r#"<|tool▁calls▁begin|><|tool▁call▁begin|>get_current_weather<|tool▁sep|>{"location": "Tokyo"}<|tool▁call▁end|><|tool▁call▁begin|>get_current_weather<|tool▁sep|>{"location": "Paris"}<|tool▁call▁end|><|tool▁calls▁end|><|end▁of▁sentence|>"#; let text = r#"<|tool▁calls▁begin|><|tool▁call▁begin|>get_current_weather<|tool▁sep|>{"location": "Tokyo"}<|tool▁call▁end|><|tool▁call▁begin|>get_current_weather<|tool▁sep|>{"location": "Paris"}<|tool▁call▁end|><|tool▁calls▁end|><|end▁of▁sentence|>"#;
let config = ToolCallConfig::deepseek_v3_1().json; let config = match ToolCallConfig::deepseek_v3_1().parser_config {
super::super::config::ParserConfig::Json(cfg) => cfg,
_ => panic!("Expected JSON parser config"),
};
let (result, content) = parse_tool_calls_deepseek_v3_1(text, &config).unwrap(); let (result, content) = parse_tool_calls_deepseek_v3_1(text, &config).unwrap();
assert_eq!(content, Some("".to_string())); assert_eq!(content, Some("".to_string()));
assert_eq!(result.len(), 2); assert_eq!(result.len(), 2);
...@@ -286,7 +289,10 @@ mod tests { ...@@ -286,7 +289,10 @@ mod tests {
#[test] #[test]
fn test_parse_tool_calls_deepseek_v3_1_with_normal_text() { fn test_parse_tool_calls_deepseek_v3_1_with_normal_text() {
let text = r#"The following tool call retrieves weather information: <|tool▁calls▁begin|><|tool▁call▁begin|>get_current_weather<|tool▁sep|>{"location": "New York"}<|tool▁call▁end|><|tool▁calls▁end|><|end▁of▁sentence|>"#; let text = r#"The following tool call retrieves weather information: <|tool▁calls▁begin|><|tool▁call▁begin|>get_current_weather<|tool▁sep|>{"location": "New York"}<|tool▁call▁end|><|tool▁calls▁end|><|end▁of▁sentence|>"#;
let config = ToolCallConfig::deepseek_v3_1().json; let config = match ToolCallConfig::deepseek_v3_1().parser_config {
super::super::config::ParserConfig::Json(cfg) => cfg,
_ => panic!("Expected JSON parser config"),
};
let (result, content) = parse_tool_calls_deepseek_v3_1(text, &config).unwrap(); let (result, content) = parse_tool_calls_deepseek_v3_1(text, &config).unwrap();
assert_eq!( assert_eq!(
content, content,
...@@ -301,7 +307,10 @@ mod tests { ...@@ -301,7 +307,10 @@ mod tests {
#[test] #[test]
fn test_parse_tool_calls_deepseek_v3_1_without_tool_call_start_token() { fn test_parse_tool_calls_deepseek_v3_1_without_tool_call_start_token() {
let text = r#"<|tool▁call▁begin|>get_current_weather宽带}{location": "Tokyo"}<|tool▁call▁end|><|tool▁calls▁end|>"#; let text = r#"<|tool▁call▁begin|>get_current_weather宽带}{location": "Tokyo"}<|tool▁call▁end|><|tool▁calls▁end|>"#;
let config = ToolCallConfig::deepseek_v3_1().json; let config = match ToolCallConfig::deepseek_v3_1().parser_config {
super::super::config::ParserConfig::Json(cfg) => cfg,
_ => panic!("Expected JSON parser config"),
};
let (result, content) = parse_tool_calls_deepseek_v3_1(text, &config).unwrap(); let (result, content) = parse_tool_calls_deepseek_v3_1(text, &config).unwrap();
assert_eq!(content, Some(text.to_string())); assert_eq!(content, Some(text.to_string()));
assert_eq!(result.len(), 0); assert_eq!(result.len(), 0);
...@@ -310,7 +319,10 @@ mod tests { ...@@ -310,7 +319,10 @@ mod tests {
#[test] #[test]
fn test_parse_tool_calls_deepseek_v3_1_with_multi_tool_calls_with_multiple_args() { fn test_parse_tool_calls_deepseek_v3_1_with_multi_tool_calls_with_multiple_args() {
let text = r#"<|tool▁calls▁begin|><|tool▁call▁begin|>get_current_weather<|tool▁sep|>{"location": "Berlin", "units": "metric"}<|tool▁call▁end|><|tool▁call▁begin|>get_weather_forecast<|tool▁sep|>{"location": "Berlin", "days": 7, "units": "imperial"}<|tool▁call▁end|><|tool▁call▁begin|>get_air_quality<|tool▁sep|>{"location": "Berlin", "radius": 50}<|tool▁call▁end|><|tool▁calls▁end|><|end▁of▁sentence|>"#; let text = r#"<|tool▁calls▁begin|><|tool▁call▁begin|>get_current_weather<|tool▁sep|>{"location": "Berlin", "units": "metric"}<|tool▁call▁end|><|tool▁call▁begin|>get_weather_forecast<|tool▁sep|>{"location": "Berlin", "days": 7, "units": "imperial"}<|tool▁call▁end|><|tool▁call▁begin|>get_air_quality<|tool▁sep|>{"location": "Berlin", "radius": 50}<|tool▁call▁end|><|tool▁calls▁end|><|end▁of▁sentence|>"#;
let config = ToolCallConfig::deepseek_v3_1().json; let config = match ToolCallConfig::deepseek_v3_1().parser_config {
super::super::config::ParserConfig::Json(cfg) => cfg,
_ => panic!("Expected JSON parser config"),
};
let (result, content) = parse_tool_calls_deepseek_v3_1(text, &config).unwrap(); let (result, content) = parse_tool_calls_deepseek_v3_1(text, &config).unwrap();
assert_eq!(content, Some("".to_string())); assert_eq!(content, Some("".to_string()));
assert_eq!(result.len(), 3); assert_eq!(result.len(), 3);
...@@ -333,7 +345,10 @@ mod tests { ...@@ -333,7 +345,10 @@ mod tests {
fn test_parse_tool_calls_deepseek_v3_1_with_invalid_json() { fn test_parse_tool_calls_deepseek_v3_1_with_invalid_json() {
// Everything is normal text in case of invalid json // Everything is normal text in case of invalid json
let text = r#"<|tool▁calls▁begin|><|tool▁call▁begin|>get_current_weather}{location": "Tokyo"}<|tool▁call▁end|><|tool▁calls▁end|>"#; let text = r#"<|tool▁calls▁begin|><|tool▁call▁begin|>get_current_weather}{location": "Tokyo"}<|tool▁call▁end|><|tool▁calls▁end|>"#;
let config = ToolCallConfig::deepseek_v3_1().json; let config = match ToolCallConfig::deepseek_v3_1().parser_config {
super::super::config::ParserConfig::Json(cfg) => cfg,
_ => panic!("Expected JSON parser config"),
};
let (result, content) = parse_tool_calls_deepseek_v3_1(text, &config).unwrap(); let (result, content) = parse_tool_calls_deepseek_v3_1(text, &config).unwrap();
assert_eq!(content, Some(text.trim().to_string())); assert_eq!(content, Some(text.trim().to_string()));
assert_eq!(result.len(), 0); assert_eq!(result.len(), 0);
...@@ -343,7 +358,10 @@ mod tests { ...@@ -343,7 +358,10 @@ mod tests {
fn test_parse_tool_calls_deepseek_v3_1_with_multi_tool_calls_with_normal_text() { fn test_parse_tool_calls_deepseek_v3_1_with_multi_tool_calls_with_normal_text() {
// Everything is normal text in case of invalid json // Everything is normal text in case of invalid json
let text = r#"The following tool calls retrieve weather information: <|tool▁calls▁begin|><|tool▁call▁begin|>get_current_weather宽带}{location": "Tokyo"}<|tool▁call▁end|><|tool▁call▁begin|>get_weather_forecast宽带}{location": "Berlin", "days": 7, "units": "imperial"}<|tool▁call▁end|><|tool▁call▁begin|>get_air_quality宽带}{location": "Berlin", "radius": 50}<|tool▁call▁end|><|tool▁calls▁end|>"#; let text = r#"The following tool calls retrieve weather information: <|tool▁calls▁begin|><|tool▁call▁begin|>get_current_weather宽带}{location": "Tokyo"}<|tool▁call▁end|><|tool▁call▁begin|>get_weather_forecast宽带}{location": "Berlin", "days": 7, "units": "imperial"}<|tool▁call▁end|><|tool▁call▁begin|>get_air_quality宽带}{location": "Berlin", "radius": 50}<|tool▁call▁end|><|tool▁calls▁end|>"#;
let config = ToolCallConfig::deepseek_v3_1().json; let config = match ToolCallConfig::deepseek_v3_1().parser_config {
super::super::config::ParserConfig::Json(cfg) => cfg,
_ => panic!("Expected JSON parser config"),
};
let (result, content) = parse_tool_calls_deepseek_v3_1(text, &config).unwrap(); let (result, content) = parse_tool_calls_deepseek_v3_1(text, &config).unwrap();
assert_eq!(content, Some(text.trim().to_string())); assert_eq!(content, Some(text.trim().to_string()));
assert_eq!(result.len(), 0); assert_eq!(result.len(), 0);
...@@ -364,7 +382,10 @@ mod tests { ...@@ -364,7 +382,10 @@ mod tests {
"Summarize the codebase purpose and functionality", "status": "pending", "activeForm": "Summarize the codebase purpose and functionality", "status": "pending", "activeForm":
"Summarizing the codebase purpose and "Summarizing the codebase purpose and
functionality"}]}<|tool▁call▁end|><|tool▁calls▁end|>"#; functionality"}]}<|tool▁call▁end|><|tool▁calls▁end|>"#;
let config = ToolCallConfig::deepseek_v3_1().json; let config = match ToolCallConfig::deepseek_v3_1().parser_config {
super::super::config::ParserConfig::Json(cfg) => cfg,
_ => panic!("Expected JSON parser config"),
};
let (tool_call_results, normal_content) = let (tool_call_results, normal_content) =
parse_tool_calls_deepseek_v3_1(text, &config).unwrap(); parse_tool_calls_deepseek_v3_1(text, &config).unwrap();
...@@ -414,7 +435,10 @@ mod detect_parser_tests { ...@@ -414,7 +435,10 @@ mod detect_parser_tests {
#[test] #[test]
fn test_detect_tool_call_start_deepseek_v3_1_chunk_with_tool_call_start_token() { fn test_detect_tool_call_start_deepseek_v3_1_chunk_with_tool_call_start_token() {
let text = r#"<|tool▁calls▁begin|><|tool▁call▁begin|>get_current_weather宽带}"#; let text = r#"<|tool▁calls▁begin|><|tool▁call▁begin|>get_current_weather宽带}"#;
let config = ToolCallConfig::deepseek_v3_1().json; let config = match ToolCallConfig::deepseek_v3_1().parser_config {
super::super::config::ParserConfig::Json(cfg) => cfg,
_ => panic!("Expected JSON parser config"),
};
let result = detect_tool_call_start_deepseek_v3_1(text, &config); let result = detect_tool_call_start_deepseek_v3_1(text, &config);
assert!(result); assert!(result);
} }
...@@ -422,7 +446,10 @@ mod detect_parser_tests { ...@@ -422,7 +446,10 @@ mod detect_parser_tests {
#[test] #[test]
fn test_detect_tool_call_start_deepseek_v3_1_chunk_without_tool_call_start_token() { fn test_detect_tool_call_start_deepseek_v3_1_chunk_without_tool_call_start_token() {
let text = r#"<|tool▁call▁begin|>get_current_weather宽带}"#; let text = r#"<|tool▁call▁begin|>get_current_weather宽带}"#;
let config = ToolCallConfig::deepseek_v3_1().json; let config = match ToolCallConfig::deepseek_v3_1().parser_config {
super::super::config::ParserConfig::Json(cfg) => cfg,
_ => panic!("Expected JSON parser config"),
};
let result = detect_tool_call_start_deepseek_v3_1(text, &config); let result = detect_tool_call_start_deepseek_v3_1(text, &config);
assert!(!result); assert!(!result);
} }
...@@ -430,7 +457,10 @@ mod detect_parser_tests { ...@@ -430,7 +457,10 @@ mod detect_parser_tests {
#[test] #[test]
fn test_detect_tool_call_start_deepseek_v3_1_chunk_with_tool_call_start_token_in_middle() { fn test_detect_tool_call_start_deepseek_v3_1_chunk_with_tool_call_start_token_in_middle() {
let text = r#"The following tool calls retrieve weather information: <|tool▁calls▁begin|><|tool▁call▁begin|>get_current_weather宽带}"#; let text = r#"The following tool calls retrieve weather information: <|tool▁calls▁begin|><|tool▁call▁begin|>get_current_weather宽带}"#;
let config = ToolCallConfig::deepseek_v3_1().json; let config = match ToolCallConfig::deepseek_v3_1().parser_config {
super::super::config::ParserConfig::Json(cfg) => cfg,
_ => panic!("Expected JSON parser config"),
};
let result = detect_tool_call_start_deepseek_v3_1(text, &config); let result = detect_tool_call_start_deepseek_v3_1(text, &config);
assert!(result); assert!(result);
} }
...@@ -438,7 +468,10 @@ mod detect_parser_tests { ...@@ -438,7 +468,10 @@ mod detect_parser_tests {
#[test] #[test]
fn test_detect_tool_call_start_deepseek_v3_1_partial_tokens() { fn test_detect_tool_call_start_deepseek_v3_1_partial_tokens() {
// Test partial token detection for streaming scenarios with unicode characters // Test partial token detection for streaming scenarios with unicode characters
let config = ToolCallConfig::deepseek_v3_1().json; let config = match ToolCallConfig::deepseek_v3_1().parser_config {
super::super::config::ParserConfig::Json(cfg) => cfg,
_ => panic!("Expected JSON parser config"),
};
// Test various partial prefixes // Test various partial prefixes
assert!( assert!(
......
...@@ -281,7 +281,10 @@ mod tests { ...@@ -281,7 +281,10 @@ mod tests {
```json ```json
{"location": "Paris"} {"location": "Paris"}
```<|tool▁call▁end|><|tool▁calls▁end|><|end▁of▁sentence|>"#; ```<|tool▁call▁end|><|tool▁calls▁end|><|end▁of▁sentence|>"#;
let config = ToolCallConfig::deepseek_v3().json; let config = match ToolCallConfig::deepseek_v3().parser_config {
super::super::config::ParserConfig::Json(cfg) => cfg,
_ => panic!("Expected JSON parser config"),
};
let (result, content) = parse_tool_calls_deepseek_v3(text, &config).unwrap(); let (result, content) = parse_tool_calls_deepseek_v3(text, &config).unwrap();
assert_eq!(content, Some("".to_string())); assert_eq!(content, Some("".to_string()));
assert_eq!(result.len(), 2); assert_eq!(result.len(), 2);
...@@ -299,7 +302,10 @@ mod tests { ...@@ -299,7 +302,10 @@ mod tests {
```json ```json
{"location": "New York"} {"location": "New York"}
```<|tool▁call▁end|><|tool▁calls▁end|><|end▁of▁sentence|>"#; ```<|tool▁call▁end|><|tool▁calls▁end|><|end▁of▁sentence|>"#;
let config = ToolCallConfig::deepseek_v3().json; let config = match ToolCallConfig::deepseek_v3().parser_config {
super::super::config::ParserConfig::Json(cfg) => cfg,
_ => panic!("Expected JSON parser config"),
};
let (result, content) = parse_tool_calls_deepseek_v3(text, &config).unwrap(); let (result, content) = parse_tool_calls_deepseek_v3(text, &config).unwrap();
assert_eq!( assert_eq!(
content, content,
...@@ -317,7 +323,10 @@ mod tests { ...@@ -317,7 +323,10 @@ mod tests {
```json ```json
} }
```<|tool▁call▁end|><|tool▁calls▁end|>"#; ```<|tool▁call▁end|><|tool▁calls▁end|>"#;
let config = ToolCallConfig::deepseek_v3().json; let config = match ToolCallConfig::deepseek_v3().parser_config {
super::super::config::ParserConfig::Json(cfg) => cfg,
_ => panic!("Expected JSON parser config"),
};
let (result, content) = parse_tool_calls_deepseek_v3(text, &config).unwrap(); let (result, content) = parse_tool_calls_deepseek_v3(text, &config).unwrap();
assert_eq!(content, Some(text.to_string())); assert_eq!(content, Some(text.to_string()));
assert_eq!(result.len(), 0); assert_eq!(result.len(), 0);
...@@ -335,7 +344,10 @@ mod tests { ...@@ -335,7 +344,10 @@ mod tests {
```json ```json
{"location": "Shanghai", "radius": 50} {"location": "Shanghai", "radius": 50}
```<|tool▁call▁end|><|tool▁calls▁end|><|end▁of▁sentence|>"#; ```<|tool▁call▁end|><|tool▁calls▁end|><|end▁of▁sentence|>"#;
let config = ToolCallConfig::deepseek_v3().json; let config = match ToolCallConfig::deepseek_v3().parser_config {
super::super::config::ParserConfig::Json(cfg) => cfg,
_ => panic!("Expected JSON parser config"),
};
let (result, content) = parse_tool_calls_deepseek_v3(text, &config).unwrap(); let (result, content) = parse_tool_calls_deepseek_v3(text, &config).unwrap();
assert_eq!(content, Some("".to_string())); assert_eq!(content, Some("".to_string()));
assert_eq!(result.len(), 3); assert_eq!(result.len(), 3);
...@@ -361,7 +373,10 @@ mod tests { ...@@ -361,7 +373,10 @@ mod tests {
```json ```json
} }
```<|tool▁call▁end|><|tool▁calls▁end|>"#; ```<|tool▁call▁end|><|tool▁calls▁end|>"#;
let config = ToolCallConfig::deepseek_v3().json; let config = match ToolCallConfig::deepseek_v3().parser_config {
super::super::config::ParserConfig::Json(cfg) => cfg,
_ => panic!("Expected JSON parser config"),
};
let (result, content) = parse_tool_calls_deepseek_v3(text, &config).unwrap(); let (result, content) = parse_tool_calls_deepseek_v3(text, &config).unwrap();
assert_eq!(content, Some(text.trim().to_string())); assert_eq!(content, Some(text.trim().to_string()));
assert_eq!(result.len(), 0); assert_eq!(result.len(), 0);
...@@ -380,7 +395,10 @@ mod tests { ...@@ -380,7 +395,10 @@ mod tests {
```json ```json
} }
```<|tool▁call▁end|><|tool▁calls▁end|>"#; ```<|tool▁call▁end|><|tool▁calls▁end|>"#;
let config = ToolCallConfig::deepseek_v3().json; let config = match ToolCallConfig::deepseek_v3().parser_config {
super::super::config::ParserConfig::Json(cfg) => cfg,
_ => panic!("Expected JSON parser config"),
};
let (result, content) = parse_tool_calls_deepseek_v3(text, &config).unwrap(); let (result, content) = parse_tool_calls_deepseek_v3(text, &config).unwrap();
assert_eq!(content, Some(text.trim().to_string())); assert_eq!(content, Some(text.trim().to_string()));
assert_eq!(result.len(), 0); assert_eq!(result.len(), 0);
...@@ -404,7 +422,10 @@ mod tests { ...@@ -404,7 +422,10 @@ mod tests {
"Summarizing the codebase purpose and "Summarizing the codebase purpose and
functionality"}]} functionality"}]}
```<|tool▁call▁end|><|tool▁calls▁end|>"#; ```<|tool▁call▁end|><|tool▁calls▁end|>"#;
let config = ToolCallConfig::deepseek_v3().json; let config = match ToolCallConfig::deepseek_v3().parser_config {
super::super::config::ParserConfig::Json(cfg) => cfg,
_ => panic!("Expected JSON parser config"),
};
let (tool_call_results, normal_content) = let (tool_call_results, normal_content) =
parse_tool_calls_deepseek_v3(text, &config).unwrap(); parse_tool_calls_deepseek_v3(text, &config).unwrap();
...@@ -454,7 +475,10 @@ mod detect_parser_tests { ...@@ -454,7 +475,10 @@ mod detect_parser_tests {
#[test] #[test]
fn test_detect_tool_call_start_deepseek_v3_chunk_with_tool_call_start_token() { fn test_detect_tool_call_start_deepseek_v3_chunk_with_tool_call_start_token() {
let text = r#"<|tool▁calls▁begin|><|tool▁call▁begin|>function宽带}"#; let text = r#"<|tool▁calls▁begin|><|tool▁call▁begin|>function宽带}"#;
let config = ToolCallConfig::deepseek_v3().json; let config = match ToolCallConfig::deepseek_v3().parser_config {
super::super::config::ParserConfig::Json(cfg) => cfg,
_ => panic!("Expected JSON parser config"),
};
let result = detect_tool_call_start_deepseek_v3(text, &config); let result = detect_tool_call_start_deepseek_v3(text, &config);
assert!(result); assert!(result);
} }
...@@ -462,7 +486,10 @@ mod detect_parser_tests { ...@@ -462,7 +486,10 @@ mod detect_parser_tests {
#[test] #[test]
fn test_detect_tool_call_start_deepseek_v3_chunk_without_tool_call_start_token() { fn test_detect_tool_call_start_deepseek_v3_chunk_without_tool_call_start_token() {
let text = r#"<|tool▁call▁begin|>function宽带}"#; let text = r#"<|tool▁call▁begin|>function宽带}"#;
let config = ToolCallConfig::deepseek_v3().json; let config = match ToolCallConfig::deepseek_v3().parser_config {
super::super::config::ParserConfig::Json(cfg) => cfg,
_ => panic!("Expected JSON parser config"),
};
let result = detect_tool_call_start_deepseek_v3(text, &config); let result = detect_tool_call_start_deepseek_v3(text, &config);
assert!(!result); assert!(!result);
} }
...@@ -470,7 +497,10 @@ mod detect_parser_tests { ...@@ -470,7 +497,10 @@ mod detect_parser_tests {
#[test] #[test]
fn test_detect_tool_call_start_deepseek_v3_chunk_with_tool_call_start_token_in_middle() { fn test_detect_tool_call_start_deepseek_v3_chunk_with_tool_call_start_token_in_middle() {
let text = r#"The following tool calls retrieve weather information: <|tool▁calls▁begin|><|tool▁call▁begin|>function宽带}"#; let text = r#"The following tool calls retrieve weather information: <|tool▁calls▁begin|><|tool▁call▁begin|>function宽带}"#;
let config = ToolCallConfig::deepseek_v3().json; let config = match ToolCallConfig::deepseek_v3().parser_config {
super::super::config::ParserConfig::Json(cfg) => cfg,
_ => panic!("Expected JSON parser config"),
};
let result = detect_tool_call_start_deepseek_v3(text, &config); let result = detect_tool_call_start_deepseek_v3(text, &config);
assert!(result); assert!(result);
} }
...@@ -478,7 +508,10 @@ mod detect_parser_tests { ...@@ -478,7 +508,10 @@ mod detect_parser_tests {
#[test] #[test]
fn test_detect_tool_call_start_deepseek_v3_partial_tokens() { fn test_detect_tool_call_start_deepseek_v3_partial_tokens() {
// Test partial token detection for streaming scenarios with unicode characters // Test partial token detection for streaming scenarios with unicode characters
let config = ToolCallConfig::deepseek_v3().json; let config = match ToolCallConfig::deepseek_v3().parser_config {
super::super::config::ParserConfig::Json(cfg) => cfg,
_ => panic!("Expected JSON parser config"),
};
// Test various partial prefixes // Test various partial prefixes
assert!( assert!(
......
...@@ -13,7 +13,7 @@ pub mod tools; ...@@ -13,7 +13,7 @@ pub mod tools;
pub mod xml; pub mod xml;
// Re-export main types and functions for convenience // Re-export main types and functions for convenience
pub use config::{JsonParserConfig, ToolCallConfig, ToolCallParserType}; pub use config::{JsonParserConfig, ParserConfig, ToolCallConfig, XmlParserConfig};
pub use harmony::parse_tool_calls_harmony_complete; pub use harmony::parse_tool_calls_harmony_complete;
pub use json::try_tool_call_parse_json; pub use json::try_tool_call_parse_json;
pub use parsers::{ pub use parsers::{
......
// SPDX-FileCopyrightText: Copyright (c) 2024-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. // SPDX-FileCopyrightText: Copyright (c) 2024-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
// SPDX-License-Identifier: Apache-2.0 // SPDX-License-Identifier: Apache-2.0
use super::config::{ToolCallConfig, ToolCallParserType}; use super::config::{ParserConfig, ToolCallConfig};
use super::harmony::{ use super::harmony::{
detect_tool_call_start_harmony, find_tool_call_end_position_harmony, detect_tool_call_start_harmony, find_tool_call_end_position_harmony,
parse_tool_calls_harmony_complete, parse_tool_calls_harmony_complete,
...@@ -50,25 +50,25 @@ pub async fn try_tool_call_parse( ...@@ -50,25 +50,25 @@ pub async fn try_tool_call_parse(
config: &ToolCallConfig, config: &ToolCallConfig,
) -> anyhow::Result<(Vec<ToolCallResponse>, Option<String>)> { ) -> anyhow::Result<(Vec<ToolCallResponse>, Option<String>)> {
// Use match statement (Rust's switch statement) to call the appropriate parser // Use match statement (Rust's switch statement) to call the appropriate parser
match config.format { match &config.parser_config {
ToolCallParserType::Json => { ParserConfig::Json(json_config) => {
let (results, normal_content) = try_tool_call_parse_json(message, &config.json)?; let (results, normal_content) = try_tool_call_parse_json(message, json_config)?;
Ok((results, normal_content)) Ok((results, normal_content))
} }
ToolCallParserType::Harmony => { ParserConfig::Harmony(json_config) => {
let (results, normal_content) = let (results, normal_content) =
parse_tool_calls_harmony_complete(message, &config.json).await?; parse_tool_calls_harmony_complete(message, json_config).await?;
Ok((results, normal_content)) Ok((results, normal_content))
} }
ToolCallParserType::Pythonic => { ParserConfig::Pythonic => {
let (results, normal_content) = try_tool_call_parse_pythonic(message)?; let (results, normal_content) = try_tool_call_parse_pythonic(message)?;
Ok((results, normal_content)) Ok((results, normal_content))
} }
ToolCallParserType::Typescript => { ParserConfig::Typescript => {
anyhow::bail!("Typescript parser not implemented"); anyhow::bail!("Typescript parser not implemented");
} }
ToolCallParserType::Xml => { ParserConfig::Xml(xml_config) => {
let (results, normal_content) = try_tool_call_parse_xml(message)?; let (results, normal_content) = try_tool_call_parse_xml(message, xml_config)?;
Ok((results, normal_content)) Ok((results, normal_content))
} }
} }
...@@ -109,16 +109,16 @@ pub fn detect_tool_call_start(chunk: &str, parser_str: Option<&str>) -> anyhow:: ...@@ -109,16 +109,16 @@ pub fn detect_tool_call_start(chunk: &str, parser_str: Option<&str>) -> anyhow::
}; };
match parser_map.get(parser_key) { match parser_map.get(parser_key) {
Some(config) => match config.format { Some(config) => match &config.parser_config {
ToolCallParserType::Json => Ok(detect_tool_call_start_json(chunk, &config.json)), ParserConfig::Json(json_config) => Ok(detect_tool_call_start_json(chunk, json_config)),
ToolCallParserType::Harmony => { ParserConfig::Harmony(json_config) => {
Ok(detect_tool_call_start_harmony(chunk, &config.json, false)) Ok(detect_tool_call_start_harmony(chunk, json_config, false))
} }
ToolCallParserType::Pythonic => Ok(detect_tool_call_start_pythonic(chunk)), ParserConfig::Pythonic => Ok(detect_tool_call_start_pythonic(chunk)),
ToolCallParserType::Typescript => { ParserConfig::Typescript => {
anyhow::bail!("Typescript parser not implemented"); anyhow::bail!("Typescript parser not implemented");
} }
ToolCallParserType::Xml => Ok(detect_tool_call_start_xml(chunk)), ParserConfig::Xml(xml_config) => Ok(detect_tool_call_start_xml(chunk, xml_config)),
}, },
None => anyhow::bail!( None => anyhow::bail!(
"Parser '{}' is not implemented. Available parsers: {:?}", "Parser '{}' is not implemented. Available parsers: {:?}",
...@@ -136,23 +136,25 @@ pub fn find_tool_call_end_position(chunk: &str, parser_str: Option<&str>) -> usi ...@@ -136,23 +136,25 @@ pub fn find_tool_call_end_position(chunk: &str, parser_str: Option<&str>) -> usi
}; };
match parser_map.get(parser_key) { match parser_map.get(parser_key) {
Some(config) => match config.format { Some(config) => match &config.parser_config {
ToolCallParserType::Json => { ParserConfig::Json(json_config) => {
// For "default", use "nemotron_deci" as the effective parser; otherwise, use the provided parser_key // For "default", use "nemotron_deci" as the effective parser; otherwise, use the provided parser_key
let effective_parser = if parser_key == "default" { let effective_parser = if parser_key == "default" {
"nemotron_deci" "nemotron_deci"
} else { } else {
parser_key parser_key
}; };
find_tool_call_end_position_json(chunk, effective_parser, &config.json) find_tool_call_end_position_json(chunk, effective_parser, json_config)
} }
ToolCallParserType::Harmony => find_tool_call_end_position_harmony(chunk, &config.json), ParserConfig::Harmony(json_config) => {
ToolCallParserType::Pythonic => find_tool_call_end_position_pythonic(chunk), find_tool_call_end_position_harmony(chunk, json_config)
ToolCallParserType::Typescript => { }
ParserConfig::Pythonic => find_tool_call_end_position_pythonic(chunk),
ParserConfig::Typescript => {
// Typescript parser not implemented // Typescript parser not implemented
chunk.len() chunk.len()
} }
ToolCallParserType::Xml => find_tool_call_end_position_xml(chunk), ParserConfig::Xml(xml_config) => find_tool_call_end_position_xml(chunk, xml_config),
}, },
None => { None => {
// Unknown parser, return full content length // Unknown parser, return full content length
...@@ -280,12 +282,11 @@ mod tests { ...@@ -280,12 +282,11 @@ mod tests {
let (result, content) = try_tool_call_parse( let (result, content) = try_tool_call_parse(
input, input,
&ToolCallConfig { &ToolCallConfig {
format: ToolCallParserType::Json, parser_config: ParserConfig::Json(JsonParserConfig {
json: JsonParserConfig {
tool_call_start_tokens: vec!["<|python_tag|>".to_string()], tool_call_start_tokens: vec!["<|python_tag|>".to_string()],
tool_call_end_tokens: vec!["".to_string()], tool_call_end_tokens: vec!["".to_string()],
..Default::default() ..Default::default()
}, }),
}, },
) )
.await .await
...@@ -534,13 +535,12 @@ Okay, the user is asking for the weather in San Francisco in Fahrenheit. Let me ...@@ -534,13 +535,12 @@ Okay, the user is asking for the weather in San Francisco in Fahrenheit. Let me
async fn test_ibm_granite_40_tiny_preview_simple() { async fn test_ibm_granite_40_tiny_preview_simple() {
let input = r#"[{"arguments": {"location": "San Francisco, CA", "unit": "fahrenheit"}, "name": "get_weather"}]"#; let input = r#"[{"arguments": {"location": "San Francisco, CA", "unit": "fahrenheit"}, "name": "get_weather"}]"#;
let config = ToolCallConfig { let config = ToolCallConfig {
format: ToolCallParserType::Json, parser_config: ParserConfig::Json(JsonParserConfig {
json: JsonParserConfig {
tool_call_start_tokens: vec![], tool_call_start_tokens: vec![],
tool_call_end_tokens: vec![], tool_call_end_tokens: vec![],
arguments_keys: vec!["arguments".to_string()], arguments_keys: vec!["arguments".to_string()],
..Default::default() ..Default::default()
}, }),
}; };
let (result, content) = try_tool_call_parse(input, &config).await.unwrap(); let (result, content) = try_tool_call_parse(input, &config).await.unwrap();
assert_eq!(content, Some("".to_string())); assert_eq!(content, Some("".to_string()));
...@@ -946,13 +946,12 @@ Remember, San Francisco weather can be quite unpredictable, particularly with it ...@@ -946,13 +946,12 @@ Remember, San Francisco weather can be quite unpredictable, particularly with it
{"name": "get_weather", "arguments": {"location": "San Francisco, CA", "unit": "fahrenheit"}} {"name": "get_weather", "arguments": {"location": "San Francisco, CA", "unit": "fahrenheit"}}
]"#; ]"#;
let config = ToolCallConfig { let config = ToolCallConfig {
format: ToolCallParserType::Json, parser_config: ParserConfig::Json(JsonParserConfig {
json: JsonParserConfig {
tool_call_start_tokens: vec![], tool_call_start_tokens: vec![],
tool_call_end_tokens: vec![], tool_call_end_tokens: vec![],
arguments_keys: vec!["arguments".to_string()], arguments_keys: vec!["arguments".to_string()],
..Default::default() ..Default::default()
}, }),
}; };
let (result, content) = try_tool_call_parse(input, &config).await.unwrap(); let (result, content) = try_tool_call_parse(input, &config).await.unwrap();
assert_eq!(content, Some("".to_string())); assert_eq!(content, Some("".to_string()));
...@@ -969,13 +968,12 @@ Remember, San Francisco weather can be quite unpredictable, particularly with it ...@@ -969,13 +968,12 @@ Remember, San Francisco weather can be quite unpredictable, particularly with it
async fn test_salesforce_llama_xlam_2_8b_fc_r_simple() { async fn test_salesforce_llama_xlam_2_8b_fc_r_simple() {
let input = r#"[{"name": "get_weather", "arguments": {"location": "San Francisco, CA", "unit": "fahrenheit"}}]"#; let input = r#"[{"name": "get_weather", "arguments": {"location": "San Francisco, CA", "unit": "fahrenheit"}}]"#;
let config = ToolCallConfig { let config = ToolCallConfig {
format: ToolCallParserType::Json, parser_config: ParserConfig::Json(JsonParserConfig {
json: JsonParserConfig {
tool_call_start_tokens: vec![], tool_call_start_tokens: vec![],
tool_call_end_tokens: vec![], tool_call_end_tokens: vec![],
arguments_keys: vec!["arguments".to_string()], arguments_keys: vec!["arguments".to_string()],
..Default::default() ..Default::default()
}, }),
}; };
let (result, content) = try_tool_call_parse(input, &config).await.unwrap(); let (result, content) = try_tool_call_parse(input, &config).await.unwrap();
assert_eq!(content, Some("".to_string())); assert_eq!(content, Some("".to_string()));
...@@ -1314,33 +1312,37 @@ Remember, San Francisco weather can be quite unpredictable, particularly with it ...@@ -1314,33 +1312,37 @@ Remember, San Francisco weather can be quite unpredictable, particularly with it
// Test that "fun" is detected as a potential tool call start (for streaming jailing) // Test that "fun" is detected as a potential tool call start (for streaming jailing)
let config = super::get_tool_parser_map().get("phi4").unwrap(); let config = super::get_tool_parser_map().get("phi4").unwrap();
let json_config = match &config.parser_config {
super::super::config::ParserConfig::Json(cfg) => cfg,
_ => panic!("Expected JSON parser config"),
};
// Test detection of partial tokens // Test detection of partial tokens
use super::super::json::detect_tool_call_start_json; use super::super::json::detect_tool_call_start_json;
assert!( assert!(
detect_tool_call_start_json("fun", &config.json), detect_tool_call_start_json("fun", json_config),
"'fun' should be detected as potential start" "'fun' should be detected as potential start"
); );
assert!( assert!(
detect_tool_call_start_json("f", &config.json), detect_tool_call_start_json("f", json_config),
"'f' should be detected as potential start" "'f' should be detected as potential start"
); );
assert!( assert!(
detect_tool_call_start_json("func", &config.json), detect_tool_call_start_json("func", json_config),
"'func' should be detected as potential start" "'func' should be detected as potential start"
); );
assert!( assert!(
detect_tool_call_start_json("functo", &config.json), detect_tool_call_start_json("functo", json_config),
"'functo' should be detected as potential start" "'functo' should be detected as potential start"
); );
// Test that unrelated text is not detected // Test that unrelated text is not detected
assert!( assert!(
!detect_tool_call_start_json("hello", &config.json), !detect_tool_call_start_json("hello", json_config),
"'hello' should not be detected" "'hello' should not be detected"
); );
assert!( assert!(
!detect_tool_call_start_json("xyz", &config.json), !detect_tool_call_start_json("xyz", json_config),
"'xyz' should not be detected" "'xyz' should not be detected"
); );
} }
......
...@@ -5,25 +5,21 @@ ...@@ -5,25 +5,21 @@
// https://github.com/sgl-project/sglang/blob/44da737770e4bcd9bfa27751f0a0751c9b5c06e1/python/sglang/srt/function_call/qwen3_coder_detector.py // https://github.com/sgl-project/sglang/blob/44da737770e4bcd9bfa27751f0a0751c9b5c06e1/python/sglang/srt/function_call/qwen3_coder_detector.py
use std::collections::HashMap; use std::collections::HashMap;
use std::sync::OnceLock;
use regex::Regex; use regex::Regex;
use uuid::Uuid; use uuid::Uuid;
use super::super::config::XmlParserConfig;
use super::response::{CalledFunction, ToolCallResponse, ToolCallType}; use super::response::{CalledFunction, ToolCallResponse, ToolCallType};
/// Check if a chunk contains the start of a xml-style tool call. /// Check if a chunk contains the start of a xml-style tool call.
/// Format: <tool_call><function=name><parameter=foo>...</parameter></function></tool_call> /// Format: <tool_call><function=name><parameter=foo>...</parameter></function></tool_call>
// TODO(2ez4bz): Add a parser config struct that allows parameterizing: pub fn detect_tool_call_start_xml(chunk: &str, config: &XmlParserConfig) -> bool {
// * the tool call start / end tokens
// * the function start / end tokens
// * the parameter start / end tokens
pub fn detect_tool_call_start_xml(chunk: &str) -> bool {
// Check for complete or partial start token. // Check for complete or partial start token.
let start_token = "<tool_call>"; let start_token = &config.tool_call_start_token;
// Check if we have the complete start token. // Check if we have the complete start token.
if chunk.contains(start_token) { if chunk.contains(start_token.as_str()) {
return true; return true;
} }
...@@ -39,10 +35,10 @@ pub fn detect_tool_call_start_xml(chunk: &str) -> bool { ...@@ -39,10 +35,10 @@ pub fn detect_tool_call_start_xml(chunk: &str) -> bool {
/// Find the end position of a Qwen3Coder tool call. /// Find the end position of a Qwen3Coder tool call.
/// Returns the position after </tool_call> or the length of the chunk if not found. /// Returns the position after </tool_call> or the length of the chunk if not found.
pub fn find_tool_call_end_position_xml(chunk: &str) -> usize { pub fn find_tool_call_end_position_xml(chunk: &str, config: &XmlParserConfig) -> usize {
let end_token = "</tool_call>"; let end_token = &config.tool_call_end_token;
if let Some(pos) = chunk.find(end_token) { if let Some(pos) = chunk.find(end_token.as_str()) {
pos + end_token.len() pos + end_token.len()
} else { } else {
chunk.len() chunk.len()
...@@ -54,8 +50,9 @@ pub fn find_tool_call_end_position_xml(chunk: &str) -> usize { ...@@ -54,8 +50,9 @@ pub fn find_tool_call_end_position_xml(chunk: &str) -> usize {
/// Returns (parsed_tool_calls, normal_text_content) /// Returns (parsed_tool_calls, normal_text_content)
pub fn try_tool_call_parse_xml( pub fn try_tool_call_parse_xml(
message: &str, message: &str,
config: &XmlParserConfig,
) -> anyhow::Result<(Vec<ToolCallResponse>, Option<String>)> { ) -> anyhow::Result<(Vec<ToolCallResponse>, Option<String>)> {
let (normal_text, tool_calls) = extract_tool_calls(message)?; let (normal_text, tool_calls) = extract_tool_calls(message, config)?;
let normal_content = if normal_text.is_empty() { let normal_content = if normal_text.is_empty() {
Some("".to_string()) Some("".to_string())
...@@ -67,29 +64,32 @@ pub fn try_tool_call_parse_xml( ...@@ -67,29 +64,32 @@ pub fn try_tool_call_parse_xml(
} }
/// Extract tool calls and normal text from message. /// Extract tool calls and normal text from message.
fn extract_tool_calls(text: &str) -> anyhow::Result<(String, Vec<ToolCallResponse>)> { fn extract_tool_calls(
text: &str,
config: &XmlParserConfig,
) -> anyhow::Result<(String, Vec<ToolCallResponse>)> {
let mut normal_parts = Vec::new(); let mut normal_parts = Vec::new();
let mut calls = Vec::new(); let mut calls = Vec::new();
let mut cursor = 0; let mut cursor = 0;
let start_token = "<tool_call>"; let start_token = &config.tool_call_start_token;
let end_token = "</tool_call>"; let end_token = &config.tool_call_end_token;
while cursor < text.len() { while cursor < text.len() {
// Find next tool call start. // Find next tool call start.
if let Some(start_pos) = text[cursor..].find(start_token) { if let Some(start_pos) = text[cursor..].find(start_token.as_str()) {
let abs_start = cursor + start_pos; let abs_start = cursor + start_pos;
// Add text before tool call to normal parts. // Add text before tool call to normal parts.
normal_parts.push(&text[cursor..abs_start]); normal_parts.push(&text[cursor..abs_start]);
// Find the corresponding end token. // Find the corresponding end token.
if let Some(end_pos) = text[abs_start..].find(end_token) { if let Some(end_pos) = text[abs_start..].find(end_token.as_str()) {
let abs_end = abs_start + end_pos + end_token.len(); let abs_end = abs_start + end_pos + end_token.len();
let block = &text[abs_start..abs_end]; let block = &text[abs_start..abs_end];
// Parse this tool call block. // Parse this tool call block.
if let Ok(mut parsed_calls) = parse_tool_call_block(block) { if let Ok(mut parsed_calls) = parse_tool_call_block(block, config) {
calls.append(&mut parsed_calls); calls.append(&mut parsed_calls);
} }
...@@ -112,21 +112,24 @@ fn extract_tool_calls(text: &str) -> anyhow::Result<(String, Vec<ToolCallRespons ...@@ -112,21 +112,24 @@ fn extract_tool_calls(text: &str) -> anyhow::Result<(String, Vec<ToolCallRespons
/// Parse a single tool call block /// Parse a single tool call block
/// Format: <tool_call><function=name><parameter=key>value</parameter>...</function></tool_call> /// Format: <tool_call><function=name><parameter=key>value</parameter>...</function></tool_call>
fn parse_tool_call_block(block: &str) -> anyhow::Result<Vec<ToolCallResponse>> { fn parse_tool_call_block(
static FUNCTION_REGEX: OnceLock<Regex> = OnceLock::new(); block: &str,
static PARAMETER_REGEX: OnceLock<Regex> = OnceLock::new(); config: &XmlParserConfig,
) -> anyhow::Result<Vec<ToolCallResponse>> {
let function_regex = FUNCTION_REGEX.get_or_init(|| { // Build regex patterns based on config
// Match <function=name>content</function> or partial <function=name>content let function_start = regex::escape(&config.function_start_token);
// (?s) makes . match newlines let function_end = regex::escape(&config.function_end_token);
Regex::new(r"(?s)<function=([^>]+)>(.*?)(?:</function>|$)").unwrap() let parameter_start = regex::escape(&config.parameter_start_token);
}); let parameter_end = regex::escape(&config.parameter_end_token);
let parameter_regex = PARAMETER_REGEX.get_or_init(|| { let function_pattern = format!(r"(?s){}([^>]+)>(.*?)(?:{}|$)", function_start, function_end);
// Match <parameter=key>value</parameter> or partial <parameter=key>value let parameter_pattern = format!(
// (?s) makes . match newlines r"(?s){}([^>]+)>(.*?)(?:{}|$)",
Regex::new(r"(?s)<parameter=([^>]+)>(.*?)(?:</parameter>|$)").unwrap() parameter_start, parameter_end
}); );
let function_regex = Regex::new(&function_pattern)?;
let parameter_regex = Regex::new(&parameter_pattern)?;
let mut results = Vec::new(); let mut results = Vec::new();
...@@ -218,23 +221,25 @@ mod tests { ...@@ -218,23 +221,25 @@ mod tests {
#[test] #[test]
fn test_detect_tool_call_start() { fn test_detect_tool_call_start() {
assert!(detect_tool_call_start_xml("<tool_call>")); let config = XmlParserConfig::default();
assert!(detect_tool_call_start_xml("text <tool_call>")); assert!(detect_tool_call_start_xml("<tool_call>", &config));
assert!(detect_tool_call_start_xml("<tool_c")); // Partial match assert!(detect_tool_call_start_xml("text <tool_call>", &config));
assert!(detect_tool_call_start_xml("<")); // Partial match assert!(detect_tool_call_start_xml("<tool_c", &config)); // Partial match
assert!(!detect_tool_call_start_xml("no tool call here")); assert!(detect_tool_call_start_xml("<", &config)); // Partial match
assert!(!detect_tool_call_start_xml("toolcall")); assert!(!detect_tool_call_start_xml("no tool call here", &config));
assert!(!detect_tool_call_start_xml("toolcall", &config));
} }
#[test] #[test]
fn test_find_tool_call_end_position() { fn test_find_tool_call_end_position() {
let config = XmlParserConfig::default();
let text = "<tool_call><function=test></function></tool_call>more text"; let text = "<tool_call><function=test></function></tool_call>more text";
let pos = find_tool_call_end_position_xml(text); let pos = find_tool_call_end_position_xml(text, &config);
assert_eq!(pos, 49); // Position after </tool_call> assert_eq!(pos, 49); // Position after </tool_call>
assert_eq!(&text[pos..], "more text"); assert_eq!(&text[pos..], "more text");
let text_no_end = "<tool_call><function=test>"; let text_no_end = "<tool_call><function=test>";
let pos = find_tool_call_end_position_xml(text_no_end); let pos = find_tool_call_end_position_xml(text_no_end, &config);
assert_eq!(pos, text_no_end.len()); assert_eq!(pos, text_no_end.len());
} }
...@@ -274,7 +279,7 @@ pwd && ls ...@@ -274,7 +279,7 @@ pwd && ls
</function> </function>
</tool_call>"#; </tool_call>"#;
let (calls, normal) = try_tool_call_parse_xml(input).unwrap(); let (calls, normal) = try_tool_call_parse_xml(input, &XmlParserConfig::default()).unwrap();
assert_eq!(calls.len(), 1); assert_eq!(calls.len(), 1);
assert_eq!(calls[0].function.name, "execute_bash"); assert_eq!(calls[0].function.name, "execute_bash");
assert_eq!(normal, Some("".to_string())); assert_eq!(normal, Some("".to_string()));
...@@ -299,7 +304,7 @@ fahrenheit ...@@ -299,7 +304,7 @@ fahrenheit
</function> </function>
</tool_call>"#; </tool_call>"#;
let (calls, _) = try_tool_call_parse_xml(input).unwrap(); let (calls, _) = try_tool_call_parse_xml(input, &XmlParserConfig::default()).unwrap();
assert_eq!(calls.len(), 1); assert_eq!(calls.len(), 1);
assert_eq!(calls[0].function.name, "get_weather"); assert_eq!(calls[0].function.name, "get_weather");
...@@ -319,7 +324,7 @@ Dallas ...@@ -319,7 +324,7 @@ Dallas
</function> </function>
</tool_call> Let me check that for you."#; </tool_call> Let me check that for you."#;
let (calls, normal) = try_tool_call_parse_xml(input).unwrap(); let (calls, normal) = try_tool_call_parse_xml(input, &XmlParserConfig::default()).unwrap();
assert_eq!(calls.len(), 1); assert_eq!(calls.len(), 1);
assert_eq!(calls[0].function.name, "get_weather"); assert_eq!(calls[0].function.name, "get_weather");
assert_eq!( assert_eq!(
...@@ -345,7 +350,7 @@ Orlando ...@@ -345,7 +350,7 @@ Orlando
</function> </function>
</tool_call>"#; </tool_call>"#;
let (calls, _) = try_tool_call_parse_xml(input).unwrap(); let (calls, _) = try_tool_call_parse_xml(input, &XmlParserConfig::default()).unwrap();
assert_eq!(calls.len(), 2); assert_eq!(calls.len(), 2);
assert_eq!(calls[0].function.name, "get_weather"); assert_eq!(calls[0].function.name, "get_weather");
assert_eq!(calls[1].function.name, "get_weather"); assert_eq!(calls[1].function.name, "get_weather");
...@@ -366,7 +371,7 @@ Orlando ...@@ -366,7 +371,7 @@ Orlando
</function> </function>
</tool_call>"#; </tool_call>"#;
let (calls, _) = try_tool_call_parse_xml(input).unwrap(); let (calls, _) = try_tool_call_parse_xml(input, &XmlParserConfig::default()).unwrap();
assert_eq!(calls.len(), 1); assert_eq!(calls.len(), 1);
let args: serde_json::Value = serde_json::from_str(&calls[0].function.arguments).unwrap(); let args: serde_json::Value = serde_json::from_str(&calls[0].function.arguments).unwrap();
...@@ -378,7 +383,7 @@ Orlando ...@@ -378,7 +383,7 @@ Orlando
#[test] #[test]
fn test_parse_no_tool_calls() { fn test_parse_no_tool_calls() {
let input = "This is just normal text without any tool calls."; let input = "This is just normal text without any tool calls.";
let (calls, normal) = try_tool_call_parse_xml(input).unwrap(); let (calls, normal) = try_tool_call_parse_xml(input, &XmlParserConfig::default()).unwrap();
assert_eq!(calls.len(), 0); assert_eq!(calls.len(), 0);
assert_eq!(normal, Some(input.to_string())); assert_eq!(normal, Some(input.to_string()));
} }
...@@ -392,7 +397,7 @@ value ...@@ -392,7 +397,7 @@ value
</tool_call>"#; </tool_call>"#;
// Should handle gracefully - might parse or return empty // Should handle gracefully - might parse or return empty
let result = try_tool_call_parse_xml(input); let result = try_tool_call_parse_xml(input, &XmlParserConfig::default());
assert!(result.is_ok()); assert!(result.is_ok());
} }
...@@ -405,7 +410,7 @@ ls -la ...@@ -405,7 +410,7 @@ ls -la
</function> </function>
</tool_call>"#; </tool_call>"#;
let (calls, _) = try_tool_call_parse_xml(input).unwrap(); let (calls, _) = try_tool_call_parse_xml(input, &XmlParserConfig::default()).unwrap();
assert_eq!(calls.len(), 1); assert_eq!(calls.len(), 1);
assert_eq!(calls[0].function.name, "execute_bash"); assert_eq!(calls[0].function.name, "execute_bash");
...@@ -422,7 +427,7 @@ Boston ...@@ -422,7 +427,7 @@ Boston
</parameter> </parameter>
</tool_call>"#; </tool_call>"#;
let (calls, _) = try_tool_call_parse_xml(input).unwrap(); let (calls, _) = try_tool_call_parse_xml(input, &XmlParserConfig::default()).unwrap();
assert_eq!(calls.len(), 1); assert_eq!(calls.len(), 1);
assert_eq!(calls[0].function.name, "get_weather"); assert_eq!(calls[0].function.name, "get_weather");
...@@ -438,7 +443,7 @@ Boston ...@@ -438,7 +443,7 @@ Boston
SELECT * FROM users SELECT * FROM users
</tool_call>"#; </tool_call>"#;
let (calls, _) = try_tool_call_parse_xml(input).unwrap(); let (calls, _) = try_tool_call_parse_xml(input, &XmlParserConfig::default()).unwrap();
assert_eq!(calls.len(), 1); assert_eq!(calls.len(), 1);
assert_eq!(calls[0].function.name, "run_query"); assert_eq!(calls[0].function.name, "run_query");
...@@ -458,7 +463,7 @@ rust programming ...@@ -458,7 +463,7 @@ rust programming
</function> </function>
</tool_call>"#; </tool_call>"#;
let (calls, _) = try_tool_call_parse_xml(input).unwrap(); let (calls, _) = try_tool_call_parse_xml(input, &XmlParserConfig::default()).unwrap();
assert_eq!(calls.len(), 1); assert_eq!(calls.len(), 1);
assert_eq!(calls[0].function.name, "search"); assert_eq!(calls[0].function.name, "search");
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment