Unverified commit 91700375, authored by MatejKosec and committed by GitHub
Browse files

fix(parsers): parallel XML/JSON tool calls no longer drop N-1 of N calls (#7914)


Signed-off-by: Matej Kosec <mkosec@nvidia.com>
parent ad2205eb
......@@ -770,7 +770,7 @@ impl JailedStream {
async fn should_end_jail(&self, accumulated_content: &str) -> (bool, usize) {
match &self.jail_mode {
JailMode::MarkerBased => {
// Path 1: End sequence detected
// Path 1: End sequence detected via naive string search.
let end_marker_info = if !self.jail_end_sequences.is_empty() {
self.jail_end_sequences.iter().find_map(|seq| {
accumulated_content
......@@ -784,9 +784,11 @@ impl JailedStream {
// Path 2: Complete tool call(s) can be parsed (early exit)
let early_exit = self.should_exit_jail_early(accumulated_content).await;
if let Some((end_pos, _)) = end_marker_info {
(true, end_pos)
} else if early_exit {
// When a tool_call_parser is active, prefer Path 2 over Path 1 so
// that `find_tool_call_end_position` advances past all consecutive
// parallel tool calls instead of splitting at the first end tag.
// Fall back to Path 1 when parsing fails (e.g. malformed content).
if early_exit {
// For early exit, find where the complete tool call ends
if let Some(parser) = &self.tool_call_parser {
let tools_slice = self.tool_definitions.as_deref();
......@@ -806,6 +808,8 @@ impl JailedStream {
} else {
(false, accumulated_content.len())
}
} else if let Some((end_pos, _)) = end_marker_info {
(true, end_pos)
} else {
(false, accumulated_content.len())
}
......
......@@ -2590,6 +2590,96 @@ mod parallel_jail_tests {
validate_parallel_streaming_tool_calls(&results, &expected_calls);
}
/// Regression test for issue #6822.
///
/// Feeds the jail one streaming chunk that carries two Hermes-style tool calls
/// and checks that both come back as tool call results, with no
/// `<tool_call>` markup passed through as plain text content.
#[tokio::test]
async fn test_parallel_tool_calls_single_chunk_hermes() {
    let jail = JailedStream::builder().tool_call_parser("hermes").build();

    // One chunk, two back-to-back calls; Hermes wraps a JSON payload in the tags.
    let payload = "<tool_call>\n\
        {\"name\": \"get_current_weather\", \"arguments\": {\"city\": \"Dallas\", \"state\": \"TX\"}}\n\
        </tool_call>\n\
        <tool_call>\n\
        {\"name\": \"get_current_weather\", \"arguments\": {\"city\": \"Orlando\", \"state\": \"FL\"}}\n\
        </tool_call>"
        .to_string();
    let chunk = test_utils::create_mock_response_chunk(payload, 0);

    let results: Vec<_> = jail
        .apply_with_finish_reason(stream::iter(vec![chunk]))
        .collect()
        .await;
    assert!(!results.is_empty(), "Should have results");

    let expected_calls = [
        (
            "get_current_weather",
            json!({"city": "Dallas", "state": "TX"}),
        ),
        (
            "get_current_weather",
            json!({"city": "Orlando", "state": "FL"}),
        ),
    ];
    validate_parallel_streaming_tool_calls(&results, &expected_calls);

    // The original bug surfaced as raw XML in the text content, so inspect every
    // piece of content emitted across all results and choices.
    let emitted_contents = results
        .iter()
        .filter_map(|result| result.data.as_ref())
        .flat_map(|data| &data.inner.choices)
        .filter_map(|choice| choice.delta.content.as_ref());
    for content in emitted_contents {
        let text = test_utils::extract_text(content);
        assert!(
            !text.contains("<tool_call>"),
            "Raw XML must not leak as text content, got: {text:?}"
        );
    }
}
/// Regression test for issue #6822.
///
/// Feeds the jail one streaming chunk that carries two Qwen3Coder-style tool
/// calls and checks that both come back as tool call results. The calls use
/// the same `<tool_call>` wrapper as the Hermes test, but the payload inside is
/// `<function=...>` / `<parameter=...>` markup rather than JSON.
#[tokio::test]
async fn test_parallel_tool_calls_single_chunk_qwen3_coder() {
    let jail = JailedStream::builder()
        .tool_call_parser("qwen3_coder")
        .build();

    // One chunk, two back-to-back calls separated by a newline.
    let payload = "<tool_call>\n\
        <function=search>\n\
        <parameter=query>Rust async</parameter>\n\
        </function>\n\
        </tool_call>\n\
        <tool_call>\n\
        <function=search>\n\
        <parameter=query>Python async</parameter>\n\
        </function>\n\
        </tool_call>"
        .to_string();
    let chunk = test_utils::create_mock_response_chunk(payload, 0);

    let results: Vec<_> = jail
        .apply_with_finish_reason(stream::iter(vec![chunk]))
        .collect()
        .await;
    assert!(!results.is_empty(), "Should have results");

    let expected_calls = [
        ("search", json!({"query": "Rust async"})),
        ("search", json!({"query": "Python async"})),
    ];
    validate_parallel_streaming_tool_calls(&results, &expected_calls);
}
// =============================================================================
// 2. PARALLEL TOOL CALLS ACROSS MULTIPLE CHUNKS (STREAMING)
// =============================================================================
......
......@@ -52,12 +52,32 @@ pub fn find_tool_call_end_position_json(
) -> usize {
match parser {
"hermes" | "nemotron_deci" => {
let start_token = config.tool_call_start_tokens.first().map(|s| s.as_str());
if let Some(end_token) = config.tool_call_end_tokens.first() {
if let Some(pos) = chunk.find(end_token) {
pos + end_token.len()
let Some(first_end) = chunk.find(end_token.as_str()) else {
return chunk.len();
};
let mut cursor = first_end + end_token.len();
// Advance past any additional consecutive start→end blocks
// so that parallel tool calls are captured as one jailed region.
if let Some(start_tok) = start_token {
loop {
let rest = &chunk[cursor..];
let trimmed = rest.trim_start();
if !trimmed.starts_with(start_tok) {
break;
}
let trim_offset = rest.len() - trimmed.len();
let search_from = cursor + trim_offset + start_tok.len();
if let Some(end_pos) = chunk[search_from..].find(end_token.as_str()) {
cursor = search_from + end_pos + end_token.len();
} else {
chunk.len()
break;
}
}
}
cursor
} else {
chunk.len()
}
......@@ -72,3 +92,61 @@ pub fn find_tool_call_end_position_json(
_ => chunk.len(),
}
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Regression test for issue #6822.
    ///
    /// `find_tool_call_end_position_json` has to report the end of the *last*
    /// consecutive tool call in a chunk, so that the jail hands the whole group
    /// to the parser instead of emitting the second (and later) calls as raw
    /// trailing text.
    #[test]
    fn test_find_tool_call_end_position_parallel_calls() {
        let config = JsonParserConfig {
            tool_call_start_tokens: vec![String::from("<tool_call>")],
            tool_call_end_tokens: vec![String::from("</tool_call>")],
            ..Default::default()
        };
        // Every case below goes through the "hermes" parser branch.
        let end_of = |text: &str| find_tool_call_end_position_json(text, "hermes", &config);

        // Case 1: two calls placed directly next to each other.
        let two_calls = "<tool_call>{\"name\": \"foo\", \"arguments\": {\"x\": 1}}</tool_call>\
            <tool_call>{\"name\": \"bar\", \"arguments\": {\"y\": 2}}</tool_call>\
            trailing";
        let pos = end_of(two_calls);
        let captured = &two_calls[..pos];
        assert!(
            captured.ends_with("</tool_call>"),
            "should end at last </tool_call>, got: {:?}",
            captured
        );
        assert_eq!(&two_calls[pos..], "trailing");

        // Case 2: three calls with newlines between them.
        let three_calls = "<tool_call>{\"name\": \"a\"}</tool_call>\n\
            <tool_call>{\"name\": \"b\"}</tool_call>\n\
            <tool_call>{\"name\": \"c\"}</tool_call> done";
        let pos3 = end_of(three_calls);
        let captured3 = &three_calls[..pos3];
        assert!(
            captured3.ends_with("</tool_call>"),
            "should end at last </tool_call>, got: {:?}",
            captured3
        );
        assert_eq!(three_calls[pos3..].trim(), "done");

        // Case 3: the second call is cut off, so only the first one counts.
        let incomplete = "<tool_call>{\"name\": \"a\"}</tool_call>\
            <tool_call>{\"name\": \"b\"";
        let first_end = "<tool_call>{\"name\": \"a\"}</tool_call>".len();
        let pos_inc = end_of(incomplete);
        assert_eq!(
            pos_inc, first_end,
            "should stop at end of first complete call when second is incomplete"
        );
    }
}
......@@ -47,16 +47,43 @@ pub fn detect_tool_call_start_xml(chunk: &str, config: &XmlParserConfig) -> bool
false
}
/// Find the end position of a Qwen3Coder tool call.
/// Returns the position after </tool_call> or the length of the chunk if not found.
/// Find the end position of all consecutive XML-style tool calls.
/// When a model emits multiple parallel tool calls in one chunk
/// (e.g. `<tool_call>...</tool_call><tool_call>...</tool_call>`), this function
/// advances past every consecutive start→end pair so the entire group is captured
/// as a single jailed region. Returns the position after the last `</tool_call>`
/// found, or the length of the chunk when no end token is present.
pub fn find_tool_call_end_position_xml(chunk: &str, config: &XmlParserConfig) -> usize {
let start_token = &config.tool_call_start_token;
let end_token = &config.tool_call_end_token;
if let Some(pos) = chunk.find(end_token.as_str()) {
pos + end_token.len()
// Find the first end token — if there isn't one, the call is incomplete.
let Some(first_end) = chunk.find(end_token.as_str()) else {
return chunk.len();
};
let mut cursor = first_end + end_token.len();
// Keep consuming additional consecutive <tool_call>…</tool_call> blocks that
// follow immediately (possibly separated by whitespace).
loop {
let rest = &chunk[cursor..];
let trimmed = rest.trim_start();
if !trimmed.starts_with(start_token.as_str()) {
break;
}
// Compute where the trimmed slice starts in the original chunk.
let trim_offset = rest.len() - trimmed.len();
let search_from = cursor + trim_offset + start_token.len();
if let Some(end_pos) = chunk[search_from..].find(end_token.as_str()) {
cursor = search_from + end_pos + end_token.len();
} else {
chunk.len()
// Next block is incomplete — stop here; the jail will wait for more data.
break;
}
}
cursor
}
/// Try to parse Qwen3Coder formatted tool calls from a message.
......@@ -586,6 +613,51 @@ mod tests {
assert_eq!(pos, text_no_end.len());
}
/// Regression test for issue #6822.
///
/// `find_tool_call_end_position_xml` has to report the end of the *last*
/// consecutive tool call in a chunk, so that the jail hands the whole group to
/// extract_tool_calls instead of emitting the second (and later) calls as raw
/// trailing text.
#[test]
fn test_find_tool_call_end_position_parallel_calls() {
    let config = XmlParserConfig::default();
    let end_of = |text: &str| find_tool_call_end_position_xml(text, &config);

    // Case 1: two calls placed directly next to each other; only "trailing"
    // may be left outside the captured region.
    let two_calls = "<tool_call><function=foo><parameter=x>1</parameter></function></tool_call>\
        <tool_call><function=bar><parameter=y>2</parameter></function></tool_call>\
        trailing";
    let pos = end_of(two_calls);
    let captured = &two_calls[..pos];
    assert!(
        captured.ends_with("</tool_call>"),
        "should end at last </tool_call>, got: {:?}",
        captured
    );
    assert_eq!(&two_calls[pos..], "trailing");

    // Case 2: three calls with newlines between them.
    let three_calls = "<tool_call><function=a></function></tool_call>\n\
        <tool_call><function=b></function></tool_call>\n\
        <tool_call><function=c></function></tool_call> done";
    let pos3 = end_of(three_calls);
    let captured3 = &three_calls[..pos3];
    assert!(
        captured3.ends_with("</tool_call>"),
        "should end at last </tool_call>, got: {:?}",
        captured3
    );
    assert_eq!(three_calls[pos3..].trim(), "done");

    // Case 3: the second call never reaches </tool_call>, so the result must be
    // the length of the first complete block (46 bytes).
    let incomplete = "<tool_call><function=a></function></tool_call>\
        <tool_call><function=b>";
    let first_end = "<tool_call><function=a></function></tool_call>".len();
    let pos_inc = end_of(incomplete);
    assert_eq!(
        pos_inc, first_end,
        "should stop at end of first complete call when second is incomplete"
    );
}
#[rstest]
#[case(r#"{"key": "value"}"#, serde_json::json!({"key": "value"}), "JSON object")]
#[case(r#"[1, 2, 3]"#, serde_json::json!([1, 2, 3]), "JSON array")]
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment