Unverified Commit 44e8600a authored by William Zhang's avatar William Zhang Committed by GitHub
Browse files

refactor: New config types for tool calls (#4575)

* Why?

We would like the ability to configure different parser types. Prior to
this commit, only the JSON parser could be configured.

* What?

This commit refactors the tool parser config in the following ways:
- the `format` and `json` fields of `ToolParserConfig` are merged into
  a single `config` field that is a "discriminated union" type. Each
  parser type can declare its own configuration options.
- a `XmlParserConfig` is defined with a default factory method that
  corresponds to the Qwen3 coder configuration.
- affected calls and tests are adjusted.
parent 262cce76
......@@ -895,14 +895,14 @@ impl JailedStreamBuilder {
if let Some(config) = parser_map.get(parser_name.as_str()) {
// Auto-populate start sequences if none configured
if self.jail_start_sequences.is_empty() {
self.jail_start_sequences = config.json.tool_call_start_tokens.clone();
self.jail_start_sequences = config.parser_config.tool_call_start_tokens();
}
// Auto-populate end sequences if none configured
if self.jail_end_sequences.is_empty() {
self.jail_end_sequences = config
.json
.tool_call_end_tokens
.parser_config
.tool_call_end_tokens()
.iter()
.filter(|&s| !s.is_empty())
.cloned()
......@@ -922,7 +922,7 @@ impl JailedStreamBuilder {
let parser_map = get_tool_parser_map();
if let Some(config) = parser_map.get(parser_name.as_str()) {
// Add start tokens from the parser config
all_patterns.extend(config.json.tool_call_start_tokens.clone());
all_patterns.extend(config.parser_config.tool_call_start_tokens());
}
}
......
......@@ -1989,6 +1989,239 @@ mod tests {
}
}
#[tokio::test]
async fn test_jailed_stream_qwen3_coder_parser() {
// Input:
// "I'll call a function. "
// + "<tool_call><function=get_weather><parameter=location>San Francisco</parameter><parameter=unit>celsius</parameter></function></tool_call>"
// + " Done."
// Expected output: 3 chunks [Content(), ToolCall(), Content()]
let chunks = vec![
create_mock_response_chunk("I'll call a function. ".to_string(), 0),
create_mock_response_chunk("<tool_call>".to_string(), 0),
create_mock_response_chunk("<function=get_weather>".to_string(), 0),
create_mock_response_chunk(
"<parameter=location>San Francisco</parameter>".to_string(),
0,
),
create_mock_response_chunk("<parameter=unit>celsius</parameter>".to_string(), 0),
create_mock_response_chunk("</function>".to_string(), 0),
create_mock_response_chunk("</tool_call>".to_string(), 0),
create_mock_response_chunk(" Done.".to_string(), 0),
];
let input_stream = stream::iter(chunks);
let jail = JailedStream::builder()
.tool_call_parser("qwen3_coder")
.build();
let results: Vec<_> = jail.apply_with_finish_reason(input_stream).collect().await;
assert_eq!(
results.len(),
3,
"Should have content, tool call, and trailing content"
);
// Verify exact output structure: [Content(), ToolCall(), Content()].
test_utils::assert_content(&results[0], "I'll call a function. ");
test_utils::assert_tool_call(
&results[1],
"get_weather",
serde_json::json!({"location": "San Francisco", "unit": "celsius"}),
);
test_utils::assert_content(&results[2], " Done.");
// Verify content reconstruction excludes tool calls.
let reconstructed = test_utils::reconstruct_content(&results);
assert_eq!(reconstructed, "I'll call a function. Done.");
}
#[tokio::test]
async fn test_jailed_stream_qwen3_coder_multiple_params() {
let chunks = vec![
create_mock_response_chunk("Let me search for that. ".to_string(), 0),
create_mock_response_chunk(
"<tool_call><function=web_search><parameter=query>Rust programming</parameter><parameter=max_results>10</parameter><parameter=filter>recent</parameter></function></tool_call>".to_string(),
0,
),
create_mock_response_chunk(" Searching now.".to_string(), 0),
];
let input_stream = stream::iter(chunks);
let jail = JailedStream::builder()
.tool_call_parser("qwen3_coder")
.build();
let results: Vec<_> = jail.apply_with_finish_reason(input_stream).collect().await;
assert_eq!(results.len(), 3, "Should have 3 chunks");
test_utils::assert_content(&results[0], "Let me search for that. ");
test_utils::assert_tool_call(
&results[1],
"web_search",
serde_json::json!({
"query": "Rust programming",
"max_results": 10,
"filter": "recent"
}),
);
test_utils::assert_content(&results[2], " Searching now.");
}
#[tokio::test]
async fn test_jailed_stream_xml_parser_config_tokens_auto_population() {
// Tests that parser config tokens are auto-populated when using `.tool_call_parser()`.
// This verifies the jail system reads `tool_call_start_token` and `tool_call_end_token`
// from the `qwen3_coder` parser config.
let chunks = vec![
create_mock_response_chunk("Before tool call. ".to_string(), 0),
create_mock_response_chunk("<tool_call>".to_string(), 0), // Default qwen3_coder token
create_mock_response_chunk("<function=get_weather>".to_string(), 0),
create_mock_response_chunk("<parameter=city>Seattle</parameter>".to_string(), 0),
create_mock_response_chunk("</function>".to_string(), 0),
create_mock_response_chunk("</tool_call>".to_string(), 0), // Default qwen3_coder token
create_mock_response_chunk(" After tool call.".to_string(), 0),
];
let input_stream = stream::iter(chunks);
// Create JailedStream using ONLY `.tool_call_parser()`.
// This should auto-populate jail sequences from the qwen3_coder config
let jail = JailedStream::builder()
.tool_call_parser("qwen3_coder")
.build();
let results: Vec<_> = jail.apply_with_finish_reason(input_stream).collect().await;
assert_eq!(
results.len(),
3,
"Should have content, tool call, and trailing content"
);
test_utils::assert_content(&results[0], "Before tool call. ");
test_utils::assert_tool_call(
&results[1],
"get_weather",
serde_json::json!({"city": "Seattle"}),
);
test_utils::assert_content(&results[2], " After tool call.");
let reconstructed = test_utils::reconstruct_content(&results);
assert_eq!(reconstructed, "Before tool call. After tool call.");
}
#[tokio::test]
async fn test_jailed_stream_xml_manual_sequences_prevent_auto_population() {
// Tests that manually setting jail sequences prevents auto-population.
// This verifies the builder respects manual configuration over auto-population.
//
// When custom sequences are set, the default parser tokens (<tool_call>) should
// NOT trigger jailing and should pass through as regular content.
let chunks = vec![
create_mock_response_chunk("Text with ".to_string(), 0),
// Default qwen3_coder token - should NOT trigger jailing.
create_mock_response_chunk("<tool_call>".to_string(), 0),
create_mock_response_chunk("should not jail".to_string(), 0),
create_mock_response_chunk("</tool_call>".to_string(), 0),
create_mock_response_chunk(" because custom ".to_string(), 0),
// Custom marker - this SHOULD trigger jailing since we register it below.
create_mock_response_chunk("[[START]]".to_string(), 0),
create_mock_response_chunk("jailed content".to_string(), 0),
create_mock_response_chunk("[[END]]".to_string(), 0),
create_mock_response_chunk(" text.".to_string(), 0),
];
let input_stream = stream::iter(chunks);
// Set custom jail sequences - this should prevent auto-population.
// The default <tool_call> tokens should NOT trigger jailing.
let jail = JailedStream::builder()
.jail_start_sequence("[[START]]")
.jail_end_sequence("[[END]]")
.tool_call_parser("qwen3_coder")
.build();
let results: Vec<_> = jail.apply_with_finish_reason(input_stream).collect().await;
// The exact number of chunks depends on emission mode (packed vs single-choice-per-chunk)
// but we can verify the key behaviors:
// 1. Default <tool_call> tokens pass through as content (not jailed)
// 2. Custom [[START]]/[[END]] markers trigger jailing
// 3. No tool calls are extracted (because jailed content isn't valid XML)
// Find chunk(s) containing the default tokens that passed through.
let default_token_chunks: Vec<_> = results
.iter()
.filter_map(|r| {
r.data
.as_ref()
.and_then(|d| d.choices.first())
.and_then(|c| c.delta.content.as_ref())
})
.filter(|content| {
content.contains("<tool_call>") || content.contains("should not jail")
})
.collect();
assert!(
!default_token_chunks.is_empty(),
"Default <tool_call> should pass through as content when manual sequences are set"
);
// Find chunk containing the jailed content that was released.
let jailed_chunk = results
.iter()
.filter_map(|r| {
r.data
.as_ref()
.and_then(|d| d.choices.first())
.and_then(|c| c.delta.content.as_ref())
})
.find(|content| content.contains("[[START]]") && content.contains("jailed content"));
assert!(
jailed_chunk.is_some(),
"Custom markers should trigger jailing and accumulated content should be released"
);
// Since the custom markers include non-XML content, the parser should not extract tool calls.
// The accumulated content "[[START]]jailed content[[END]]", although compatible with the
// way we configured `jail` above, is not consistent with what `qwen_coder` expects, and
// there is (at time of writing) no way to pass a parser instance - only a string that
// internally gets mapped to default way of instantiating a particular parser.
let tool_call_count = results
.iter()
.filter(|r| {
r.data
.as_ref()
.and_then(|d| d.choices.first())
.and_then(|c| c.delta.tool_calls.as_ref())
.map(|tc| !tc.is_empty())
.unwrap_or(false)
})
.count();
assert_eq!(
tool_call_count, 0,
"Should have 0 tool calls because jailed content doesn't match XML format"
);
// Verify content reconstruction - all original content should be preserved.
let reconstructed = test_utils::reconstruct_content(&results);
assert!(
reconstructed.contains("<tool_call>") && reconstructed.contains("should not jail"),
"Reconstructed content should include default tokens that passed through"
);
assert!(
reconstructed.contains("[[START]]") && reconstructed.contains("jailed content"),
"Reconstructed content should include jailed content with custom markers"
);
}
#[tokio::test]
async fn test_jailed_stream_mistral_false_positive_curly() {
// Curly brace in normal text should not trigger tool call detection for mistral
......
......@@ -3,20 +3,6 @@
use super::json::JsonParserType;
/// Represents the format type for tool calls
#[derive(Clone, Debug, serde::Serialize, serde::Deserialize)]
pub enum ToolCallParserType {
/// JSON format: `{"name": "function", "arguments": {...}}`
Json,
Pythonic,
Harmony,
/// <function_call>```typescript
/// functions.get_current_weather({"location": "Shanghai"})
/// ```
Typescript,
Xml,
}
#[derive(Clone, Debug, serde::Serialize, serde::Deserialize)]
pub struct JsonParserConfig {
/// Start token for individual tool calls (e.g., "<TOOLCALL>")
......@@ -54,21 +40,83 @@ impl Default for JsonParserConfig {
}
}
#[derive(Clone, Debug, serde::Serialize, serde::Deserialize)]
pub struct XmlParserConfig {
/// Start token for individual tool calls (e.g., "<tool_call>")
pub tool_call_start_token: String,
/// End token for individual tool calls (e.g., "</tool_call>")
pub tool_call_end_token: String,
/// Start token for function name (e.g., "<function=")
pub function_start_token: String,
/// End token for function (e.g., "</function>")
pub function_end_token: String,
/// Start token for parameter (e.g., "<parameter=")
pub parameter_start_token: String,
/// End token for parameter (e.g., "</parameter>")
pub parameter_end_token: String,
}
impl Default for XmlParserConfig {
fn default() -> Self {
Self {
tool_call_start_token: "<tool_call>".to_string(),
tool_call_end_token: "</tool_call>".to_string(),
function_start_token: "<function=".to_string(),
function_end_token: "</function>".to_string(),
parameter_start_token: "<parameter=".to_string(),
parameter_end_token: "</parameter>".to_string(),
}
}
}
/// Parser-specific configuration
#[derive(Clone, Debug, serde::Serialize, serde::Deserialize)]
#[serde(tag = "type", rename_all = "snake_case")]
pub enum ParserConfig {
Json(JsonParserConfig),
Xml(XmlParserConfig),
Pythonic,
Harmony(JsonParserConfig),
Typescript,
}
impl ParserConfig {
/// Get the tool call start tokens for this parser configuration
/// Returns a vector of start tokens that indicate the beginning of a tool call
pub fn tool_call_start_tokens(&self) -> Vec<String> {
match self {
ParserConfig::Json(config) => config.tool_call_start_tokens.clone(),
ParserConfig::Harmony(config) => config.tool_call_start_tokens.clone(),
ParserConfig::Xml(config) => vec![config.tool_call_start_token.clone()],
ParserConfig::Pythonic => vec![],
ParserConfig::Typescript => vec![],
}
}
/// Get the tool call end tokens for this parser configuration
/// Returns a vector of end tokens that indicate the end of a tool call
pub fn tool_call_end_tokens(&self) -> Vec<String> {
match self {
ParserConfig::Json(config) => config.tool_call_end_tokens.clone(),
ParserConfig::Harmony(config) => config.tool_call_end_tokens.clone(),
ParserConfig::Xml(config) => vec![config.tool_call_end_token.clone()],
ParserConfig::Pythonic => vec![],
ParserConfig::Typescript => vec![],
}
}
}
/// Configuration for parsing tool calls with different formats
// TODO(2ez4bz): refactor to allow other parser configs than `JsonParserConfig`.
#[derive(Clone, Debug, serde::Serialize, serde::Deserialize)]
pub struct ToolCallConfig {
/// The format type for tool calls
pub format: ToolCallParserType,
/// The config for the JSON parser
pub json: JsonParserConfig,
/// Parser-specific configuration.
pub parser_config: ParserConfig,
}
impl Default for ToolCallConfig {
fn default() -> Self {
Self {
format: ToolCallParserType::Json,
json: JsonParserConfig::default(),
parser_config: ParserConfig::Json(JsonParserConfig::default()),
}
}
}
......@@ -78,12 +126,11 @@ impl ToolCallConfig {
/// <tool_call>{"name": "get_weather", "arguments": {"location": "San Francisco, CA", "unit": "fahrenheit"}}\n</tool_call>
pub fn hermes() -> Self {
Self {
format: ToolCallParserType::Json,
json: JsonParserConfig {
parser_config: ParserConfig::Json(JsonParserConfig {
tool_call_start_tokens: vec!["<tool_call>".to_string()],
tool_call_end_tokens: vec!["</tool_call>".to_string()],
..Default::default()
},
}),
}
}
......@@ -91,12 +138,11 @@ impl ToolCallConfig {
/// <TOOLCALL>[{"name": "get_weather", "arguments": {"location": "San Francisco, CA", "unit": "fahrenheit"}}]</TOOLCALL>
pub fn nemotron_deci() -> Self {
Self {
format: ToolCallParserType::Json,
json: JsonParserConfig {
parser_config: ParserConfig::Json(JsonParserConfig {
tool_call_start_tokens: vec!["<TOOLCALL>".to_string()],
tool_call_end_tokens: vec!["</TOOLCALL>".to_string()],
..Default::default()
},
}),
}
}
......@@ -104,52 +150,47 @@ impl ToolCallConfig {
// <|python_tag|>{ "name": "get_weather", "arguments": {"location": "San Francisco, CA", "unit": "fahrenheit"} }
// or { "name": "get_weather", "arguments": {"location": "San Francisco, CA", "unit": "fahrenheit"} }
Self {
format: ToolCallParserType::Json,
json: JsonParserConfig {
parser_config: ParserConfig::Json(JsonParserConfig {
tool_call_start_tokens: vec!["<|python_tag|>".to_string()],
tool_call_end_tokens: vec!["".to_string()],
..Default::default()
},
}),
}
}
pub fn mistral() -> Self {
Self {
format: ToolCallParserType::Json,
json: JsonParserConfig {
parser_config: ParserConfig::Json(JsonParserConfig {
tool_call_start_tokens: vec!["[TOOL_CALLS]".to_string()],
tool_call_end_tokens: vec!["[/TOOL_CALLS]".to_string(), "".to_string()],
..Default::default()
},
}),
}
}
pub fn phi4() -> Self {
Self {
format: ToolCallParserType::Json,
json: JsonParserConfig {
parser_config: ParserConfig::Json(JsonParserConfig {
tool_call_start_tokens: vec!["functools".to_string()],
tool_call_end_tokens: vec!["".to_string()],
..Default::default()
},
}),
}
}
pub fn pythonic() -> Self {
Self {
format: ToolCallParserType::Pythonic,
json: JsonParserConfig::default(), // This is noop here, but we keep it for consistency
parser_config: ParserConfig::Pythonic,
}
}
pub fn harmony() -> Self {
Self {
format: ToolCallParserType::Harmony,
json: JsonParserConfig {
parser_config: ParserConfig::Harmony(JsonParserConfig {
tool_call_start_tokens: vec!["<|start|>assistant<|channel|>commentary".to_string()],
tool_call_end_tokens: vec!["<|call|>".to_string()],
..Default::default()
},
}),
}
}
......@@ -161,8 +202,7 @@ impl ToolCallConfig {
// so the tool parser can properly consume all tool call tokens.
// https://huggingface.co/deepseek-ai/DeepSeek-V3.1#toolcall
Self {
format: ToolCallParserType::Json,
json: JsonParserConfig {
parser_config: ParserConfig::Json(JsonParserConfig {
tool_call_start_tokens: vec![
"<|tool▁calls▁begin|>".to_string(),
// "<|tool▁call▁begin|>".to_string(),
......@@ -174,7 +214,7 @@ impl ToolCallConfig {
tool_call_separator_tokens: vec!["<|tool▁sep|>".to_string()],
parser_type: JsonParserType::DeepseekV31,
..Default::default()
},
}),
}
}
......@@ -183,22 +223,20 @@ impl ToolCallConfig {
// <|tool▁calls▁begin|><|tool▁call▁begin|>{type}<|tool▁sep|>{function_name}\n```json\n{arguments}\n```<|tool▁call▁end|><|tool▁calls▁end|>
// There are some differences between DeepSeek V3 and DeepSeek V3.1
Self {
format: ToolCallParserType::Json,
json: JsonParserConfig {
parser_config: ParserConfig::Json(JsonParserConfig {
tool_call_start_tokens: vec!["<|tool▁calls▁begin|>".to_string()],
tool_call_end_tokens: vec!["<|tool▁calls▁end|>".to_string()],
tool_call_separator_tokens: vec!["<|tool▁sep|>".to_string()],
parser_type: JsonParserType::DeepseekV3,
..Default::default()
},
}),
}
}
pub fn qwen3_coder() -> Self {
// <tool_call><function=name><parameter=key>value</parameter></function></tool_call>
Self {
format: ToolCallParserType::Xml,
json: JsonParserConfig::default(), // Not used for qwen3_coder but kept for consistency.
parser_config: ParserConfig::Xml(XmlParserConfig::default()),
}
}
}
......@@ -271,7 +271,10 @@ mod tests {
#[test]
fn test_parse_tool_calls_deepseek_v3_1_basic() {
let text = r#"<|tool▁calls▁begin|><|tool▁call▁begin|>get_current_weather<|tool▁sep|>{"location": "Tokyo"}<|tool▁call▁end|><|tool▁call▁begin|>get_current_weather<|tool▁sep|>{"location": "Paris"}<|tool▁call▁end|><|tool▁calls▁end|><|end▁of▁sentence|>"#;
let config = ToolCallConfig::deepseek_v3_1().json;
let config = match ToolCallConfig::deepseek_v3_1().parser_config {
super::super::config::ParserConfig::Json(cfg) => cfg,
_ => panic!("Expected JSON parser config"),
};
let (result, content) = parse_tool_calls_deepseek_v3_1(text, &config).unwrap();
assert_eq!(content, Some("".to_string()));
assert_eq!(result.len(), 2);
......@@ -286,7 +289,10 @@ mod tests {
#[test]
fn test_parse_tool_calls_deepseek_v3_1_with_normal_text() {
let text = r#"The following tool call retrieves weather information: <|tool▁calls▁begin|><|tool▁call▁begin|>get_current_weather<|tool▁sep|>{"location": "New York"}<|tool▁call▁end|><|tool▁calls▁end|><|end▁of▁sentence|>"#;
let config = ToolCallConfig::deepseek_v3_1().json;
let config = match ToolCallConfig::deepseek_v3_1().parser_config {
super::super::config::ParserConfig::Json(cfg) => cfg,
_ => panic!("Expected JSON parser config"),
};
let (result, content) = parse_tool_calls_deepseek_v3_1(text, &config).unwrap();
assert_eq!(
content,
......@@ -301,7 +307,10 @@ mod tests {
#[test]
fn test_parse_tool_calls_deepseek_v3_1_without_tool_call_start_token() {
let text = r#"<|tool▁call▁begin|>get_current_weather宽带}{location": "Tokyo"}<|tool▁call▁end|><|tool▁calls▁end|>"#;
let config = ToolCallConfig::deepseek_v3_1().json;
let config = match ToolCallConfig::deepseek_v3_1().parser_config {
super::super::config::ParserConfig::Json(cfg) => cfg,
_ => panic!("Expected JSON parser config"),
};
let (result, content) = parse_tool_calls_deepseek_v3_1(text, &config).unwrap();
assert_eq!(content, Some(text.to_string()));
assert_eq!(result.len(), 0);
......@@ -310,7 +319,10 @@ mod tests {
#[test]
fn test_parse_tool_calls_deepseek_v3_1_with_multi_tool_calls_with_multiple_args() {
let text = r#"<|tool▁calls▁begin|><|tool▁call▁begin|>get_current_weather<|tool▁sep|>{"location": "Berlin", "units": "metric"}<|tool▁call▁end|><|tool▁call▁begin|>get_weather_forecast<|tool▁sep|>{"location": "Berlin", "days": 7, "units": "imperial"}<|tool▁call▁end|><|tool▁call▁begin|>get_air_quality<|tool▁sep|>{"location": "Berlin", "radius": 50}<|tool▁call▁end|><|tool▁calls▁end|><|end▁of▁sentence|>"#;
let config = ToolCallConfig::deepseek_v3_1().json;
let config = match ToolCallConfig::deepseek_v3_1().parser_config {
super::super::config::ParserConfig::Json(cfg) => cfg,
_ => panic!("Expected JSON parser config"),
};
let (result, content) = parse_tool_calls_deepseek_v3_1(text, &config).unwrap();
assert_eq!(content, Some("".to_string()));
assert_eq!(result.len(), 3);
......@@ -333,7 +345,10 @@ mod tests {
fn test_parse_tool_calls_deepseek_v3_1_with_invalid_json() {
// Everything is normal text in case of invalid json
let text = r#"<|tool▁calls▁begin|><|tool▁call▁begin|>get_current_weather}{location": "Tokyo"}<|tool▁call▁end|><|tool▁calls▁end|>"#;
let config = ToolCallConfig::deepseek_v3_1().json;
let config = match ToolCallConfig::deepseek_v3_1().parser_config {
super::super::config::ParserConfig::Json(cfg) => cfg,
_ => panic!("Expected JSON parser config"),
};
let (result, content) = parse_tool_calls_deepseek_v3_1(text, &config).unwrap();
assert_eq!(content, Some(text.trim().to_string()));
assert_eq!(result.len(), 0);
......@@ -343,7 +358,10 @@ mod tests {
fn test_parse_tool_calls_deepseek_v3_1_with_multi_tool_calls_with_normal_text() {
// Everything is normal text in case of invalid json
let text = r#"The following tool calls retrieve weather information: <|tool▁calls▁begin|><|tool▁call▁begin|>get_current_weather宽带}{location": "Tokyo"}<|tool▁call▁end|><|tool▁call▁begin|>get_weather_forecast宽带}{location": "Berlin", "days": 7, "units": "imperial"}<|tool▁call▁end|><|tool▁call▁begin|>get_air_quality宽带}{location": "Berlin", "radius": 50}<|tool▁call▁end|><|tool▁calls▁end|>"#;
let config = ToolCallConfig::deepseek_v3_1().json;
let config = match ToolCallConfig::deepseek_v3_1().parser_config {
super::super::config::ParserConfig::Json(cfg) => cfg,
_ => panic!("Expected JSON parser config"),
};
let (result, content) = parse_tool_calls_deepseek_v3_1(text, &config).unwrap();
assert_eq!(content, Some(text.trim().to_string()));
assert_eq!(result.len(), 0);
......@@ -364,7 +382,10 @@ mod tests {
"Summarize the codebase purpose and functionality", "status": "pending", "activeForm":
"Summarizing the codebase purpose and
functionality"}]}<|tool▁call▁end|><|tool▁calls▁end|>"#;
let config = ToolCallConfig::deepseek_v3_1().json;
let config = match ToolCallConfig::deepseek_v3_1().parser_config {
super::super::config::ParserConfig::Json(cfg) => cfg,
_ => panic!("Expected JSON parser config"),
};
let (tool_call_results, normal_content) =
parse_tool_calls_deepseek_v3_1(text, &config).unwrap();
......@@ -414,7 +435,10 @@ mod detect_parser_tests {
#[test]
fn test_detect_tool_call_start_deepseek_v3_1_chunk_with_tool_call_start_token() {
let text = r#"<|tool▁calls▁begin|><|tool▁call▁begin|>get_current_weather宽带}"#;
let config = ToolCallConfig::deepseek_v3_1().json;
let config = match ToolCallConfig::deepseek_v3_1().parser_config {
super::super::config::ParserConfig::Json(cfg) => cfg,
_ => panic!("Expected JSON parser config"),
};
let result = detect_tool_call_start_deepseek_v3_1(text, &config);
assert!(result);
}
......@@ -422,7 +446,10 @@ mod detect_parser_tests {
#[test]
fn test_detect_tool_call_start_deepseek_v3_1_chunk_without_tool_call_start_token() {
let text = r#"<|tool▁call▁begin|>get_current_weather宽带}"#;
let config = ToolCallConfig::deepseek_v3_1().json;
let config = match ToolCallConfig::deepseek_v3_1().parser_config {
super::super::config::ParserConfig::Json(cfg) => cfg,
_ => panic!("Expected JSON parser config"),
};
let result = detect_tool_call_start_deepseek_v3_1(text, &config);
assert!(!result);
}
......@@ -430,7 +457,10 @@ mod detect_parser_tests {
#[test]
fn test_detect_tool_call_start_deepseek_v3_1_chunk_with_tool_call_start_token_in_middle() {
let text = r#"The following tool calls retrieve weather information: <|tool▁calls▁begin|><|tool▁call▁begin|>get_current_weather宽带}"#;
let config = ToolCallConfig::deepseek_v3_1().json;
let config = match ToolCallConfig::deepseek_v3_1().parser_config {
super::super::config::ParserConfig::Json(cfg) => cfg,
_ => panic!("Expected JSON parser config"),
};
let result = detect_tool_call_start_deepseek_v3_1(text, &config);
assert!(result);
}
......@@ -438,7 +468,10 @@ mod detect_parser_tests {
#[test]
fn test_detect_tool_call_start_deepseek_v3_1_partial_tokens() {
// Test partial token detection for streaming scenarios with unicode characters
let config = ToolCallConfig::deepseek_v3_1().json;
let config = match ToolCallConfig::deepseek_v3_1().parser_config {
super::super::config::ParserConfig::Json(cfg) => cfg,
_ => panic!("Expected JSON parser config"),
};
// Test various partial prefixes
assert!(
......
......@@ -281,7 +281,10 @@ mod tests {
```json
{"location": "Paris"}
```<|tool▁call▁end|><|tool▁calls▁end|><|end▁of▁sentence|>"#;
let config = ToolCallConfig::deepseek_v3().json;
let config = match ToolCallConfig::deepseek_v3().parser_config {
super::super::config::ParserConfig::Json(cfg) => cfg,
_ => panic!("Expected JSON parser config"),
};
let (result, content) = parse_tool_calls_deepseek_v3(text, &config).unwrap();
assert_eq!(content, Some("".to_string()));
assert_eq!(result.len(), 2);
......@@ -299,7 +302,10 @@ mod tests {
```json
{"location": "New York"}
```<|tool▁call▁end|><|tool▁calls▁end|><|end▁of▁sentence|>"#;
let config = ToolCallConfig::deepseek_v3().json;
let config = match ToolCallConfig::deepseek_v3().parser_config {
super::super::config::ParserConfig::Json(cfg) => cfg,
_ => panic!("Expected JSON parser config"),
};
let (result, content) = parse_tool_calls_deepseek_v3(text, &config).unwrap();
assert_eq!(
content,
......@@ -317,7 +323,10 @@ mod tests {
```json
}
```<|tool▁call▁end|><|tool▁calls▁end|>"#;
let config = ToolCallConfig::deepseek_v3().json;
let config = match ToolCallConfig::deepseek_v3().parser_config {
super::super::config::ParserConfig::Json(cfg) => cfg,
_ => panic!("Expected JSON parser config"),
};
let (result, content) = parse_tool_calls_deepseek_v3(text, &config).unwrap();
assert_eq!(content, Some(text.to_string()));
assert_eq!(result.len(), 0);
......@@ -335,7 +344,10 @@ mod tests {
```json
{"location": "Shanghai", "radius": 50}
```<|tool▁call▁end|><|tool▁calls▁end|><|end▁of▁sentence|>"#;
let config = ToolCallConfig::deepseek_v3().json;
let config = match ToolCallConfig::deepseek_v3().parser_config {
super::super::config::ParserConfig::Json(cfg) => cfg,
_ => panic!("Expected JSON parser config"),
};
let (result, content) = parse_tool_calls_deepseek_v3(text, &config).unwrap();
assert_eq!(content, Some("".to_string()));
assert_eq!(result.len(), 3);
......@@ -361,7 +373,10 @@ mod tests {
```json
}
```<|tool▁call▁end|><|tool▁calls▁end|>"#;
let config = ToolCallConfig::deepseek_v3().json;
let config = match ToolCallConfig::deepseek_v3().parser_config {
super::super::config::ParserConfig::Json(cfg) => cfg,
_ => panic!("Expected JSON parser config"),
};
let (result, content) = parse_tool_calls_deepseek_v3(text, &config).unwrap();
assert_eq!(content, Some(text.trim().to_string()));
assert_eq!(result.len(), 0);
......@@ -380,7 +395,10 @@ mod tests {
```json
}
```<|tool▁call▁end|><|tool▁calls▁end|>"#;
let config = ToolCallConfig::deepseek_v3().json;
let config = match ToolCallConfig::deepseek_v3().parser_config {
super::super::config::ParserConfig::Json(cfg) => cfg,
_ => panic!("Expected JSON parser config"),
};
let (result, content) = parse_tool_calls_deepseek_v3(text, &config).unwrap();
assert_eq!(content, Some(text.trim().to_string()));
assert_eq!(result.len(), 0);
......@@ -404,7 +422,10 @@ mod tests {
"Summarizing the codebase purpose and
functionality"}]}
```<|tool▁call▁end|><|tool▁calls▁end|>"#;
let config = ToolCallConfig::deepseek_v3().json;
let config = match ToolCallConfig::deepseek_v3().parser_config {
super::super::config::ParserConfig::Json(cfg) => cfg,
_ => panic!("Expected JSON parser config"),
};
let (tool_call_results, normal_content) =
parse_tool_calls_deepseek_v3(text, &config).unwrap();
......@@ -454,7 +475,10 @@ mod detect_parser_tests {
#[test]
fn test_detect_tool_call_start_deepseek_v3_chunk_with_tool_call_start_token() {
let text = r#"<|tool▁calls▁begin|><|tool▁call▁begin|>function宽带}"#;
let config = ToolCallConfig::deepseek_v3().json;
let config = match ToolCallConfig::deepseek_v3().parser_config {
super::super::config::ParserConfig::Json(cfg) => cfg,
_ => panic!("Expected JSON parser config"),
};
let result = detect_tool_call_start_deepseek_v3(text, &config);
assert!(result);
}
......@@ -462,7 +486,10 @@ mod detect_parser_tests {
#[test]
fn test_detect_tool_call_start_deepseek_v3_chunk_without_tool_call_start_token() {
let text = r#"<|tool▁call▁begin|>function宽带}"#;
let config = ToolCallConfig::deepseek_v3().json;
let config = match ToolCallConfig::deepseek_v3().parser_config {
super::super::config::ParserConfig::Json(cfg) => cfg,
_ => panic!("Expected JSON parser config"),
};
let result = detect_tool_call_start_deepseek_v3(text, &config);
assert!(!result);
}
......@@ -470,7 +497,10 @@ mod detect_parser_tests {
#[test]
fn test_detect_tool_call_start_deepseek_v3_chunk_with_tool_call_start_token_in_middle() {
let text = r#"The following tool calls retrieve weather information: <|tool▁calls▁begin|><|tool▁call▁begin|>function宽带}"#;
let config = ToolCallConfig::deepseek_v3().json;
let config = match ToolCallConfig::deepseek_v3().parser_config {
super::super::config::ParserConfig::Json(cfg) => cfg,
_ => panic!("Expected JSON parser config"),
};
let result = detect_tool_call_start_deepseek_v3(text, &config);
assert!(result);
}
......@@ -478,7 +508,10 @@ mod detect_parser_tests {
#[test]
fn test_detect_tool_call_start_deepseek_v3_partial_tokens() {
// Test partial token detection for streaming scenarios with unicode characters
let config = ToolCallConfig::deepseek_v3().json;
let config = match ToolCallConfig::deepseek_v3().parser_config {
super::super::config::ParserConfig::Json(cfg) => cfg,
_ => panic!("Expected JSON parser config"),
};
// Test various partial prefixes
assert!(
......
......@@ -13,7 +13,7 @@ pub mod tools;
pub mod xml;
// Re-export main types and functions for convenience
pub use config::{JsonParserConfig, ToolCallConfig, ToolCallParserType};
pub use config::{JsonParserConfig, ParserConfig, ToolCallConfig, XmlParserConfig};
pub use harmony::parse_tool_calls_harmony_complete;
pub use json::try_tool_call_parse_json;
pub use parsers::{
......
// SPDX-FileCopyrightText: Copyright (c) 2024-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
// SPDX-License-Identifier: Apache-2.0
use super::config::{ToolCallConfig, ToolCallParserType};
use super::config::{ParserConfig, ToolCallConfig};
use super::harmony::{
detect_tool_call_start_harmony, find_tool_call_end_position_harmony,
parse_tool_calls_harmony_complete,
......@@ -50,25 +50,25 @@ pub async fn try_tool_call_parse(
config: &ToolCallConfig,
) -> anyhow::Result<(Vec<ToolCallResponse>, Option<String>)> {
// Use match statement (Rust's switch statement) to call the appropriate parser
match config.format {
ToolCallParserType::Json => {
let (results, normal_content) = try_tool_call_parse_json(message, &config.json)?;
match &config.parser_config {
ParserConfig::Json(json_config) => {
let (results, normal_content) = try_tool_call_parse_json(message, json_config)?;
Ok((results, normal_content))
}
ToolCallParserType::Harmony => {
ParserConfig::Harmony(json_config) => {
let (results, normal_content) =
parse_tool_calls_harmony_complete(message, &config.json).await?;
parse_tool_calls_harmony_complete(message, json_config).await?;
Ok((results, normal_content))
}
ToolCallParserType::Pythonic => {
ParserConfig::Pythonic => {
let (results, normal_content) = try_tool_call_parse_pythonic(message)?;
Ok((results, normal_content))
}
ToolCallParserType::Typescript => {
ParserConfig::Typescript => {
anyhow::bail!("Typescript parser not implemented");
}
ToolCallParserType::Xml => {
let (results, normal_content) = try_tool_call_parse_xml(message)?;
ParserConfig::Xml(xml_config) => {
let (results, normal_content) = try_tool_call_parse_xml(message, xml_config)?;
Ok((results, normal_content))
}
}
......@@ -109,16 +109,16 @@ pub fn detect_tool_call_start(chunk: &str, parser_str: Option<&str>) -> anyhow::
};
match parser_map.get(parser_key) {
Some(config) => match config.format {
ToolCallParserType::Json => Ok(detect_tool_call_start_json(chunk, &config.json)),
ToolCallParserType::Harmony => {
Ok(detect_tool_call_start_harmony(chunk, &config.json, false))
Some(config) => match &config.parser_config {
ParserConfig::Json(json_config) => Ok(detect_tool_call_start_json(chunk, json_config)),
ParserConfig::Harmony(json_config) => {
Ok(detect_tool_call_start_harmony(chunk, json_config, false))
}
ToolCallParserType::Pythonic => Ok(detect_tool_call_start_pythonic(chunk)),
ToolCallParserType::Typescript => {
ParserConfig::Pythonic => Ok(detect_tool_call_start_pythonic(chunk)),
ParserConfig::Typescript => {
anyhow::bail!("Typescript parser not implemented");
}
ToolCallParserType::Xml => Ok(detect_tool_call_start_xml(chunk)),
ParserConfig::Xml(xml_config) => Ok(detect_tool_call_start_xml(chunk, xml_config)),
},
None => anyhow::bail!(
"Parser '{}' is not implemented. Available parsers: {:?}",
......@@ -136,23 +136,25 @@ pub fn find_tool_call_end_position(chunk: &str, parser_str: Option<&str>) -> usi
};
match parser_map.get(parser_key) {
Some(config) => match config.format {
ToolCallParserType::Json => {
Some(config) => match &config.parser_config {
ParserConfig::Json(json_config) => {
// For "default", use "nemotron_deci" as the effective parser; otherwise, use the provided parser_key
let effective_parser = if parser_key == "default" {
"nemotron_deci"
} else {
parser_key
};
find_tool_call_end_position_json(chunk, effective_parser, &config.json)
find_tool_call_end_position_json(chunk, effective_parser, json_config)
}
ToolCallParserType::Harmony => find_tool_call_end_position_harmony(chunk, &config.json),
ToolCallParserType::Pythonic => find_tool_call_end_position_pythonic(chunk),
ToolCallParserType::Typescript => {
ParserConfig::Harmony(json_config) => {
find_tool_call_end_position_harmony(chunk, json_config)
}
ParserConfig::Pythonic => find_tool_call_end_position_pythonic(chunk),
ParserConfig::Typescript => {
// Typescript parser not implemented
chunk.len()
}
ToolCallParserType::Xml => find_tool_call_end_position_xml(chunk),
ParserConfig::Xml(xml_config) => find_tool_call_end_position_xml(chunk, xml_config),
},
None => {
// Unknown parser, return full content length
......@@ -280,12 +282,11 @@ mod tests {
let (result, content) = try_tool_call_parse(
input,
&ToolCallConfig {
format: ToolCallParserType::Json,
json: JsonParserConfig {
parser_config: ParserConfig::Json(JsonParserConfig {
tool_call_start_tokens: vec!["<|python_tag|>".to_string()],
tool_call_end_tokens: vec!["".to_string()],
..Default::default()
},
}),
},
)
.await
......@@ -534,13 +535,12 @@ Okay, the user is asking for the weather in San Francisco in Fahrenheit. Let me
async fn test_ibm_granite_40_tiny_preview_simple() {
let input = r#"[{"arguments": {"location": "San Francisco, CA", "unit": "fahrenheit"}, "name": "get_weather"}]"#;
let config = ToolCallConfig {
format: ToolCallParserType::Json,
json: JsonParserConfig {
parser_config: ParserConfig::Json(JsonParserConfig {
tool_call_start_tokens: vec![],
tool_call_end_tokens: vec![],
arguments_keys: vec!["arguments".to_string()],
..Default::default()
},
}),
};
let (result, content) = try_tool_call_parse(input, &config).await.unwrap();
assert_eq!(content, Some("".to_string()));
......@@ -946,13 +946,12 @@ Remember, San Francisco weather can be quite unpredictable, particularly with it
{"name": "get_weather", "arguments": {"location": "San Francisco, CA", "unit": "fahrenheit"}}
]"#;
let config = ToolCallConfig {
format: ToolCallParserType::Json,
json: JsonParserConfig {
parser_config: ParserConfig::Json(JsonParserConfig {
tool_call_start_tokens: vec![],
tool_call_end_tokens: vec![],
arguments_keys: vec!["arguments".to_string()],
..Default::default()
},
}),
};
let (result, content) = try_tool_call_parse(input, &config).await.unwrap();
assert_eq!(content, Some("".to_string()));
......@@ -969,13 +968,12 @@ Remember, San Francisco weather can be quite unpredictable, particularly with it
async fn test_salesforce_llama_xlam_2_8b_fc_r_simple() {
let input = r#"[{"name": "get_weather", "arguments": {"location": "San Francisco, CA", "unit": "fahrenheit"}}]"#;
let config = ToolCallConfig {
format: ToolCallParserType::Json,
json: JsonParserConfig {
parser_config: ParserConfig::Json(JsonParserConfig {
tool_call_start_tokens: vec![],
tool_call_end_tokens: vec![],
arguments_keys: vec!["arguments".to_string()],
..Default::default()
},
}),
};
let (result, content) = try_tool_call_parse(input, &config).await.unwrap();
assert_eq!(content, Some("".to_string()));
......@@ -1314,33 +1312,37 @@ Remember, San Francisco weather can be quite unpredictable, particularly with it
// Test that "fun" is detected as a potential tool call start (for streaming jailing)
let config = super::get_tool_parser_map().get("phi4").unwrap();
let json_config = match &config.parser_config {
super::super::config::ParserConfig::Json(cfg) => cfg,
_ => panic!("Expected JSON parser config"),
};
// Test detection of partial tokens
use super::super::json::detect_tool_call_start_json;
assert!(
detect_tool_call_start_json("fun", &config.json),
detect_tool_call_start_json("fun", json_config),
"'fun' should be detected as potential start"
);
assert!(
detect_tool_call_start_json("f", &config.json),
detect_tool_call_start_json("f", json_config),
"'f' should be detected as potential start"
);
assert!(
detect_tool_call_start_json("func", &config.json),
detect_tool_call_start_json("func", json_config),
"'func' should be detected as potential start"
);
assert!(
detect_tool_call_start_json("functo", &config.json),
detect_tool_call_start_json("functo", json_config),
"'functo' should be detected as potential start"
);
// Test that unrelated text is not detected
assert!(
!detect_tool_call_start_json("hello", &config.json),
!detect_tool_call_start_json("hello", json_config),
"'hello' should not be detected"
);
assert!(
!detect_tool_call_start_json("xyz", &config.json),
!detect_tool_call_start_json("xyz", json_config),
"'xyz' should not be detected"
);
}
......
......@@ -5,25 +5,21 @@
// https://github.com/sgl-project/sglang/blob/44da737770e4bcd9bfa27751f0a0751c9b5c06e1/python/sglang/srt/function_call/qwen3_coder_detector.py
use std::collections::HashMap;
use std::sync::OnceLock;
use regex::Regex;
use uuid::Uuid;
use super::super::config::XmlParserConfig;
use super::response::{CalledFunction, ToolCallResponse, ToolCallType};
/// Check if a chunk contains the start of a xml-style tool call.
/// Format: <tool_call><function=name><parameter=foo>...</parameter></function></tool_call>
// TODO(2ez4bz): Add a parser config struct that allows parameterizing:
// * the tool call start / end tokens
// * the function start / end tokens
// * the parameter start / end tokens
pub fn detect_tool_call_start_xml(chunk: &str) -> bool {
pub fn detect_tool_call_start_xml(chunk: &str, config: &XmlParserConfig) -> bool {
// Check for complete or partial start token.
let start_token = "<tool_call>";
let start_token = &config.tool_call_start_token;
// Check if we have the complete start token.
if chunk.contains(start_token) {
if chunk.contains(start_token.as_str()) {
return true;
}
......@@ -39,10 +35,10 @@ pub fn detect_tool_call_start_xml(chunk: &str) -> bool {
/// Find the end position of a Qwen3Coder tool call.
/// Returns the position after </tool_call> or the length of the chunk if not found.
pub fn find_tool_call_end_position_xml(chunk: &str) -> usize {
let end_token = "</tool_call>";
pub fn find_tool_call_end_position_xml(chunk: &str, config: &XmlParserConfig) -> usize {
let end_token = &config.tool_call_end_token;
if let Some(pos) = chunk.find(end_token) {
if let Some(pos) = chunk.find(end_token.as_str()) {
pos + end_token.len()
} else {
chunk.len()
......@@ -54,8 +50,9 @@ pub fn find_tool_call_end_position_xml(chunk: &str) -> usize {
/// Returns (parsed_tool_calls, normal_text_content)
pub fn try_tool_call_parse_xml(
message: &str,
config: &XmlParserConfig,
) -> anyhow::Result<(Vec<ToolCallResponse>, Option<String>)> {
let (normal_text, tool_calls) = extract_tool_calls(message)?;
let (normal_text, tool_calls) = extract_tool_calls(message, config)?;
let normal_content = if normal_text.is_empty() {
Some("".to_string())
......@@ -67,29 +64,32 @@ pub fn try_tool_call_parse_xml(
}
/// Extract tool calls and normal text from message.
fn extract_tool_calls(text: &str) -> anyhow::Result<(String, Vec<ToolCallResponse>)> {
fn extract_tool_calls(
text: &str,
config: &XmlParserConfig,
) -> anyhow::Result<(String, Vec<ToolCallResponse>)> {
let mut normal_parts = Vec::new();
let mut calls = Vec::new();
let mut cursor = 0;
let start_token = "<tool_call>";
let end_token = "</tool_call>";
let start_token = &config.tool_call_start_token;
let end_token = &config.tool_call_end_token;
while cursor < text.len() {
// Find next tool call start.
if let Some(start_pos) = text[cursor..].find(start_token) {
if let Some(start_pos) = text[cursor..].find(start_token.as_str()) {
let abs_start = cursor + start_pos;
// Add text before tool call to normal parts.
normal_parts.push(&text[cursor..abs_start]);
// Find the corresponding end token.
if let Some(end_pos) = text[abs_start..].find(end_token) {
if let Some(end_pos) = text[abs_start..].find(end_token.as_str()) {
let abs_end = abs_start + end_pos + end_token.len();
let block = &text[abs_start..abs_end];
// Parse this tool call block.
if let Ok(mut parsed_calls) = parse_tool_call_block(block) {
if let Ok(mut parsed_calls) = parse_tool_call_block(block, config) {
calls.append(&mut parsed_calls);
}
......@@ -112,21 +112,24 @@ fn extract_tool_calls(text: &str) -> anyhow::Result<(String, Vec<ToolCallRespons
/// Parse a single tool call block
/// Format: <tool_call><function=name><parameter=key>value</parameter>...</function></tool_call>
fn parse_tool_call_block(block: &str) -> anyhow::Result<Vec<ToolCallResponse>> {
static FUNCTION_REGEX: OnceLock<Regex> = OnceLock::new();
static PARAMETER_REGEX: OnceLock<Regex> = OnceLock::new();
let function_regex = FUNCTION_REGEX.get_or_init(|| {
// Match <function=name>content</function> or partial <function=name>content
// (?s) makes . match newlines
Regex::new(r"(?s)<function=([^>]+)>(.*?)(?:</function>|$)").unwrap()
});
let parameter_regex = PARAMETER_REGEX.get_or_init(|| {
// Match <parameter=key>value</parameter> or partial <parameter=key>value
// (?s) makes . match newlines
Regex::new(r"(?s)<parameter=([^>]+)>(.*?)(?:</parameter>|$)").unwrap()
});
fn parse_tool_call_block(
block: &str,
config: &XmlParserConfig,
) -> anyhow::Result<Vec<ToolCallResponse>> {
// Build regex patterns based on config
let function_start = regex::escape(&config.function_start_token);
let function_end = regex::escape(&config.function_end_token);
let parameter_start = regex::escape(&config.parameter_start_token);
let parameter_end = regex::escape(&config.parameter_end_token);
let function_pattern = format!(r"(?s){}([^>]+)>(.*?)(?:{}|$)", function_start, function_end);
let parameter_pattern = format!(
r"(?s){}([^>]+)>(.*?)(?:{}|$)",
parameter_start, parameter_end
);
let function_regex = Regex::new(&function_pattern)?;
let parameter_regex = Regex::new(&parameter_pattern)?;
let mut results = Vec::new();
......@@ -218,23 +221,25 @@ mod tests {
#[test]
fn test_detect_tool_call_start() {
assert!(detect_tool_call_start_xml("<tool_call>"));
assert!(detect_tool_call_start_xml("text <tool_call>"));
assert!(detect_tool_call_start_xml("<tool_c")); // Partial match
assert!(detect_tool_call_start_xml("<")); // Partial match
assert!(!detect_tool_call_start_xml("no tool call here"));
assert!(!detect_tool_call_start_xml("toolcall"));
let config = XmlParserConfig::default();
assert!(detect_tool_call_start_xml("<tool_call>", &config));
assert!(detect_tool_call_start_xml("text <tool_call>", &config));
assert!(detect_tool_call_start_xml("<tool_c", &config)); // Partial match
assert!(detect_tool_call_start_xml("<", &config)); // Partial match
assert!(!detect_tool_call_start_xml("no tool call here", &config));
assert!(!detect_tool_call_start_xml("toolcall", &config));
}
#[test]
fn test_find_tool_call_end_position() {
let config = XmlParserConfig::default();
let text = "<tool_call><function=test></function></tool_call>more text";
let pos = find_tool_call_end_position_xml(text);
let pos = find_tool_call_end_position_xml(text, &config);
assert_eq!(pos, 49); // Position after </tool_call>
assert_eq!(&text[pos..], "more text");
let text_no_end = "<tool_call><function=test>";
let pos = find_tool_call_end_position_xml(text_no_end);
let pos = find_tool_call_end_position_xml(text_no_end, &config);
assert_eq!(pos, text_no_end.len());
}
......@@ -274,7 +279,7 @@ pwd && ls
</function>
</tool_call>"#;
let (calls, normal) = try_tool_call_parse_xml(input).unwrap();
let (calls, normal) = try_tool_call_parse_xml(input, &XmlParserConfig::default()).unwrap();
assert_eq!(calls.len(), 1);
assert_eq!(calls[0].function.name, "execute_bash");
assert_eq!(normal, Some("".to_string()));
......@@ -299,7 +304,7 @@ fahrenheit
</function>
</tool_call>"#;
let (calls, _) = try_tool_call_parse_xml(input).unwrap();
let (calls, _) = try_tool_call_parse_xml(input, &XmlParserConfig::default()).unwrap();
assert_eq!(calls.len(), 1);
assert_eq!(calls[0].function.name, "get_weather");
......@@ -319,7 +324,7 @@ Dallas
</function>
</tool_call> Let me check that for you."#;
let (calls, normal) = try_tool_call_parse_xml(input).unwrap();
let (calls, normal) = try_tool_call_parse_xml(input, &XmlParserConfig::default()).unwrap();
assert_eq!(calls.len(), 1);
assert_eq!(calls[0].function.name, "get_weather");
assert_eq!(
......@@ -345,7 +350,7 @@ Orlando
</function>
</tool_call>"#;
let (calls, _) = try_tool_call_parse_xml(input).unwrap();
let (calls, _) = try_tool_call_parse_xml(input, &XmlParserConfig::default()).unwrap();
assert_eq!(calls.len(), 2);
assert_eq!(calls[0].function.name, "get_weather");
assert_eq!(calls[1].function.name, "get_weather");
......@@ -366,7 +371,7 @@ Orlando
</function>
</tool_call>"#;
let (calls, _) = try_tool_call_parse_xml(input).unwrap();
let (calls, _) = try_tool_call_parse_xml(input, &XmlParserConfig::default()).unwrap();
assert_eq!(calls.len(), 1);
let args: serde_json::Value = serde_json::from_str(&calls[0].function.arguments).unwrap();
......@@ -378,7 +383,7 @@ Orlando
#[test]
fn test_parse_no_tool_calls() {
let input = "This is just normal text without any tool calls.";
let (calls, normal) = try_tool_call_parse_xml(input).unwrap();
let (calls, normal) = try_tool_call_parse_xml(input, &XmlParserConfig::default()).unwrap();
assert_eq!(calls.len(), 0);
assert_eq!(normal, Some(input.to_string()));
}
......@@ -392,7 +397,7 @@ value
</tool_call>"#;
// Should handle gracefully - might parse or return empty
let result = try_tool_call_parse_xml(input);
let result = try_tool_call_parse_xml(input, &XmlParserConfig::default());
assert!(result.is_ok());
}
......@@ -405,7 +410,7 @@ ls -la
</function>
</tool_call>"#;
let (calls, _) = try_tool_call_parse_xml(input).unwrap();
let (calls, _) = try_tool_call_parse_xml(input, &XmlParserConfig::default()).unwrap();
assert_eq!(calls.len(), 1);
assert_eq!(calls[0].function.name, "execute_bash");
......@@ -422,7 +427,7 @@ Boston
</parameter>
</tool_call>"#;
let (calls, _) = try_tool_call_parse_xml(input).unwrap();
let (calls, _) = try_tool_call_parse_xml(input, &XmlParserConfig::default()).unwrap();
assert_eq!(calls.len(), 1);
assert_eq!(calls[0].function.name, "get_weather");
......@@ -438,7 +443,7 @@ Boston
SELECT * FROM users
</tool_call>"#;
let (calls, _) = try_tool_call_parse_xml(input).unwrap();
let (calls, _) = try_tool_call_parse_xml(input, &XmlParserConfig::default()).unwrap();
assert_eq!(calls.len(), 1);
assert_eq!(calls[0].function.name, "run_query");
......@@ -458,7 +463,7 @@ rust programming
</function>
</tool_call>"#;
let (calls, _) = try_tool_call_parse_xml(input).unwrap();
let (calls, _) = try_tool_call_parse_xml(input, &XmlParserConfig::default()).unwrap();
assert_eq!(calls.len(), 1);
assert_eq!(calls[0].function.name, "search");
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment