Unverified Commit 07af8dca authored by Asad Shahid's avatar Asad Shahid Committed by GitHub
Browse files

fix: tool_choice=required bypasses format-specific parsers (#6821) (#7589)

parent f46498f2
...@@ -1108,14 +1108,35 @@ impl OpenAIPreprocessor { ...@@ -1108,14 +1108,35 @@ impl OpenAIPreprocessor {
} }
// Configure jail based on tool_choice // Configure jail based on tool_choice
//
// When a tool_call_parser is configured, always use marker-based mode
// so that format-specific parsers (e.g. qwen3_coder XML) are invoked.
// Immediate JSON mode is only a fallback for required/named when no
// parser exists (the model is expected to emit raw JSON in that case).
match tool_choice { match tool_choice {
Some(ChatCompletionToolChoiceOption::Named(named)) => { Some(ChatCompletionToolChoiceOption::Named(named)) => {
// Immediate jail mode for named tool choice if let Some(parser) = tool_call_parser {
builder = builder.tool_choice_named(named.function.name.clone()); // Parser-aware path: use marker-based jail so the parser
// handles format-specific output (XML, pythonic, etc.).
// Also install a named-tool filter so that if the model emits
// the wrong tool, the parsed call is rejected before emission.
builder = builder
.tool_call_parser(parser)
.named_tool_filter(named.function.name.clone());
} else {
// No parser: fall back to Immediate JSON jail mode.
builder = builder.tool_choice_named(named.function.name.clone());
}
} }
Some(ChatCompletionToolChoiceOption::Required) => { Some(ChatCompletionToolChoiceOption::Required) => {
// Immediate jail mode for required tool choice if let Some(parser) = tool_call_parser {
builder = builder.tool_choice_required(); // Parser-aware path: use marker-based jail so the parser
// handles format-specific output (XML, pythonic, etc.).
builder = builder.tool_call_parser(parser);
} else {
// No parser: fall back to Immediate JSON jail mode.
builder = builder.tool_choice_required();
}
} }
Some(ChatCompletionToolChoiceOption::Auto) Some(ChatCompletionToolChoiceOption::Auto)
| Some(ChatCompletionToolChoiceOption::None) | Some(ChatCompletionToolChoiceOption::None)
......
...@@ -470,6 +470,9 @@ pub struct JailedStream { ...@@ -470,6 +470,9 @@ pub struct JailedStream {
jail_start_sequences: Vec<String>, jail_start_sequences: Vec<String>,
jail_end_sequences: Vec<String>, jail_end_sequences: Vec<String>,
tool_call_parser: Option<String>, tool_call_parser: Option<String>,
/// When set, only tool calls with this name are emitted (enforces tool_choice=named
/// when a tool_call_parser is active and the parser-aware MarkerBased path is used).
named_tool_name: Option<String>,
tool_definitions: Option<Vec<dynamo_parsers::tool_calling::ToolDefinition>>, tool_definitions: Option<Vec<dynamo_parsers::tool_calling::ToolDefinition>>,
emission_mode: EmissionMode, emission_mode: EmissionMode,
marker_matcher: MarkerMatcher, marker_matcher: MarkerMatcher,
...@@ -492,8 +495,9 @@ impl JailedStream { ...@@ -492,8 +495,9 @@ impl JailedStream {
S: Stream<Item = Annotated<NvCreateChatCompletionStreamResponse>> + Send + 'static, S: Stream<Item = Annotated<NvCreateChatCompletionStreamResponse>> + Send + 'static,
{ {
let jail_mode = self.jail_mode.clone(); let jail_mode = self.jail_mode.clone();
let named_tool_active = self.named_tool_name.is_some();
let jailed_stream = self.apply(stream); let jailed_stream = self.apply(stream);
JailedStream::fix_finish_reason(jailed_stream, jail_mode) JailedStream::fix_finish_reason(jailed_stream, jail_mode, named_tool_active)
} }
/// Apply the jail transformation to a stream of chat completion responses /// Apply the jail transformation to a stream of chat completion responses
...@@ -856,6 +860,37 @@ impl JailedStream { ...@@ -856,6 +860,37 @@ impl JailedStream {
if let Ok((tool_calls, normal_text)) = parse_result if let Ok((tool_calls, normal_text)) = parse_result
&& !tool_calls.is_empty() && !tool_calls.is_empty()
{ {
// If a named tool filter is set (tool_choice=named + parser path), reject
// tool calls that don't match the required tool name.
let tool_calls = if let Some(ref required_name) = self.named_tool_name {
let filtered: Vec<_> = tool_calls
.into_iter()
.filter(|tc| tc.function.name == *required_name)
.collect();
if filtered.is_empty() {
tracing::warn!(
required = %required_name,
"tool_choice=named: parser emitted no matching tool calls; dropping jail output"
);
}
filtered
} else {
tool_calls
};
if tool_calls.is_empty() {
// All parsed calls were for the wrong tool — return content choice
return create_choice_stream(
choice_index,
Some(Role::Assistant),
accumulated_content,
None,
base_choice.finish_reason,
base_choice.stop_reason.clone(),
base_choice.logprobs.clone(),
);
}
// Convert to streaming format // Convert to streaming format
let tool_call_chunks: Vec<ChatCompletionMessageToolCallChunk> = tool_calls let tool_call_chunks: Vec<ChatCompletionMessageToolCallChunk> = tool_calls
.into_iter() .into_iter()
...@@ -1004,6 +1039,7 @@ impl JailedStream { ...@@ -1004,6 +1039,7 @@ impl JailedStream {
fn fix_finish_reason<S>( fn fix_finish_reason<S>(
input_stream: S, input_stream: S,
jail_mode: JailMode, jail_mode: JailMode,
named_tool_active: bool,
) -> impl Stream<Item = Annotated<NvCreateChatCompletionStreamResponse>> + Send ) -> impl Stream<Item = Annotated<NvCreateChatCompletionStreamResponse>> + Send
where where
S: Stream<Item = Annotated<NvCreateChatCompletionStreamResponse>> + Send + 'static, S: Stream<Item = Annotated<NvCreateChatCompletionStreamResponse>> + Send + 'static,
...@@ -1032,10 +1068,10 @@ impl JailedStream { ...@@ -1032,10 +1068,10 @@ impl JailedStream {
match &jail_mode { match &jail_mode {
JailMode::MarkerBased => { JailMode::MarkerBased => {
// Traditional: if tool calls emitted, change to ToolCalls if has_tool_calls && !named_tool_active {
if has_tool_calls {
choice.finish_reason = Some(FinishReason::ToolCalls); choice.finish_reason = Some(FinishReason::ToolCalls);
} }
// When named_tool_active, keep Stop (OpenAI spec for tool_choice=named)
} }
JailMode::Immediate { format } => { JailMode::Immediate { format } => {
// tool_choice mode: apply specific finish_reason logic // tool_choice mode: apply specific finish_reason logic
...@@ -1070,6 +1106,9 @@ pub struct JailedStreamBuilder { ...@@ -1070,6 +1106,9 @@ pub struct JailedStreamBuilder {
jail_start_sequences: Vec<String>, jail_start_sequences: Vec<String>,
jail_end_sequences: Vec<String>, jail_end_sequences: Vec<String>,
tool_call_parser: Option<String>, tool_call_parser: Option<String>,
/// When set, only tool calls with this name are emitted (enforces tool_choice=named
/// when a tool_call_parser is active and the parser-aware MarkerBased path is used).
named_tool_name: Option<String>,
tool_definitions: Option<Vec<dynamo_parsers::tool_calling::ToolDefinition>>, tool_definitions: Option<Vec<dynamo_parsers::tool_calling::ToolDefinition>>,
emission_mode: EmissionMode, emission_mode: EmissionMode,
jail_mode: JailMode, jail_mode: JailMode,
...@@ -1082,6 +1121,7 @@ impl JailedStreamBuilder { ...@@ -1082,6 +1121,7 @@ impl JailedStreamBuilder {
jail_start_sequences: Vec::new(), jail_start_sequences: Vec::new(),
jail_end_sequences: Vec::new(), jail_end_sequences: Vec::new(),
tool_call_parser: None, tool_call_parser: None,
named_tool_name: None,
tool_definitions: None, tool_definitions: None,
emission_mode: EmissionMode::default(), emission_mode: EmissionMode::default(),
jail_mode: JailMode::MarkerBased, jail_mode: JailMode::MarkerBased,
...@@ -1126,6 +1166,14 @@ impl JailedStreamBuilder { ...@@ -1126,6 +1166,14 @@ impl JailedStreamBuilder {
self self
} }
/// Restrict emitted tool calls to the single tool named `tool_name`.
///
/// Intended for `tool_choice=named` when a `tool_call_parser` is configured and the
/// marker-based (parser-aware) jail path is used. Parsed tool calls whose function name
/// differs from `tool_name` are dropped before emission. If no parsed call matches, a
/// warning is logged and the jailed text is emitted as an ordinary content choice instead.
///
/// Returns the builder so calls can be chained.
pub fn named_tool_filter(mut self, tool_name: impl Into<String>) -> Self {
    let required_name: String = tool_name.into();
    self.named_tool_name = Some(required_name);
    self
}
/// Set the tool definitions for runtime validation and parsing /// Set the tool definitions for runtime validation and parsing
pub fn tool_definitions( pub fn tool_definitions(
mut self, mut self,
...@@ -1245,6 +1293,7 @@ impl JailedStreamBuilder { ...@@ -1245,6 +1293,7 @@ impl JailedStreamBuilder {
jail_start_sequences: self.jail_start_sequences, jail_start_sequences: self.jail_start_sequences,
jail_end_sequences: self.jail_end_sequences, jail_end_sequences: self.jail_end_sequences,
tool_call_parser: self.tool_call_parser, tool_call_parser: self.tool_call_parser,
named_tool_name: self.named_tool_name,
tool_definitions: self.tool_definitions, tool_definitions: self.tool_definitions,
emission_mode: self.emission_mode, emission_mode: self.emission_mode,
marker_matcher, marker_matcher,
......
// SPDX-FileCopyrightText: Copyright (c) 2025-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. // SPDX-FileCopyrightText: Copyright (c) 2025-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
// SPDX-License-Identifier: Apache-2.0 // SPDX-License-Identifier: Apache-2.0
use dynamo_llm::preprocessor::OpenAIPreprocessor;
use dynamo_llm::protocols::openai::chat_completions::NvCreateChatCompletionStreamResponse; use dynamo_llm::protocols::openai::chat_completions::NvCreateChatCompletionStreamResponse;
use dynamo_llm::protocols::openai::chat_completions::jail::JailedStream; use dynamo_llm::protocols::openai::chat_completions::jail::JailedStream;
use dynamo_protocols::types::{ use dynamo_protocols::types::{
ChatChoiceStream, ChatCompletionStreamResponseDelta, CompletionUsage, FinishReason, Role, ChatChoiceStream, ChatCompletionStreamResponseDelta, ChatCompletionToolChoiceOption,
CompletionUsage, FinishReason, Role,
}; };
use dynamo_runtime::protocols::annotated::Annotated; use dynamo_runtime::protocols::annotated::Annotated;
...@@ -3080,4 +3082,191 @@ mod parallel_jail_tests { ...@@ -3080,4 +3082,191 @@ mod parallel_jail_tests {
"Should have no tool calls for empty array" "Should have no tool calls for empty array"
); );
} }
/// Regression test for #6821: `tool_choice=required` combined with the `qwen3_coder` parser.
///
/// With a `tool_call_parser` configured, the jail has to run in marker-based mode so the
/// format-specific parser sees the XML. Before the fix this case used Immediate JSON mode,
/// which cannot parse qwen3_coder XML, so the raw XML leaked out as content.
#[tokio::test]
async fn test_tool_choice_required_with_qwen3_coder_parser() {
    // A single tool call in qwen3_coder XML form.
    let xml_output = r#"<tool_call>
<function=get_weather>
<parameter=city>
San Francisco
</parameter>
<parameter=unit>
fahrenheit
</parameter>
</function>
</tool_call>"#;

    let input_stream = stream::iter(vec![test_utils::create_mock_response_chunk(
        xml_output.to_string(),
        0,
    )]);

    let results: Vec<_> = OpenAIPreprocessor::apply_tool_calling_jail(
        Some("qwen3_coder".to_string()),
        Some(ChatCompletionToolChoiceOption::Required),
        None,
        input_stream,
    )
    .collect()
    .await;

    // Every choice of every emitted chunk, in emission order.
    let choices: Vec<&ChatChoiceStream> = results
        .iter()
        .filter_map(|r| r.data.as_ref())
        .flat_map(|d| d.inner.choices.iter())
        .collect();

    // The XML must have been parsed into a tool call rather than passed through.
    let tool_call_count: usize = choices
        .iter()
        .filter_map(|c| c.delta.tool_calls.as_ref())
        .map(|calls| calls.len())
        .sum();
    assert!(
        tool_call_count >= 1,
        "tool_choice=required with qwen3_coder should produce at least one tool call, got {}",
        tool_call_count
    );

    for choice in &choices {
        // Each parsed call must carry the function name from the XML.
        for tc in choice.delta.tool_calls.iter().flatten() {
            assert_eq!(
                tc.function.as_ref().unwrap().name.as_deref(),
                Some("get_weather"),
                "Tool call name should be 'get_weather'"
            );
        }

        // Any content that is present must not contain the raw tool-call markup.
        if let Some(content) = &choice.delta.content {
            let text = test_utils::extract_text(content);
            assert!(
                !text.contains("<tool_call>"),
                "Raw XML should not leak as content, got: {}",
                text
            );
        }
    }
}
/// `tool_choice=named` combined with the `qwen3_coder` parser, matching-tool case.
///
/// The model output contains exactly the tool that was named in the request. The test
/// checks that the call is parsed and emitted under that name, that no raw XML leaks into
/// content, and that a `Stop` finish reason is present in the output.
#[tokio::test]
async fn test_tool_choice_named_with_qwen3_coder_parser() {
    // A single tool call in qwen3_coder XML form.
    let xml_output = r#"<tool_call>
<function=get_weather>
<parameter=city>
San Francisco
</parameter>
<parameter=unit>
fahrenheit
</parameter>
</function>
</tool_call>"#;

    let input_stream = stream::iter(vec![
        test_utils::create_mock_response_chunk(xml_output.to_string(), 0),
        test_utils::create_final_response_chunk(0),
    ]);

    // Request the named tool `get_weather`.
    let named_choice =
        ChatCompletionToolChoiceOption::Named("get_weather".to_string().into());
    let results: Vec<_> = OpenAIPreprocessor::apply_tool_calling_jail(
        Some("qwen3_coder".to_string()),
        Some(named_choice),
        None,
        input_stream,
    )
    .collect()
    .await;

    // Every choice of every emitted chunk, in emission order.
    let choices: Vec<&ChatChoiceStream> = results
        .iter()
        .filter_map(|r| r.data.as_ref())
        .flat_map(|d| d.inner.choices.iter())
        .collect();

    // The named tool call must have been parsed.
    let tool_call_count: usize = choices
        .iter()
        .filter_map(|c| c.delta.tool_calls.as_ref())
        .map(|calls| calls.len())
        .sum();
    assert!(
        tool_call_count >= 1,
        "tool_choice=named with qwen3_coder should produce at least one tool call, got {}",
        tool_call_count
    );

    for choice in &choices {
        // Every emitted call must be for the named tool.
        for tc in choice.delta.tool_calls.iter().flatten() {
            assert_eq!(
                tc.function.as_ref().unwrap().name.as_deref(),
                Some("get_weather"),
                "Tool call name should match the named tool choice"
            );
        }

        // Any content that is present must not contain the raw tool-call markup.
        if let Some(content) = &choice.delta.content {
            let text = test_utils::extract_text(content);
            assert!(
                !text.contains("<tool_call>"),
                "Raw XML should not leak as content, got: {}",
                text
            );
        }
    }

    // For tool_choice=named the finish reason is expected to stay Stop; only the first
    // choice of each chunk is inspected.
    let saw_stop = results
        .iter()
        .filter_map(|r| r.data.as_ref())
        .filter_map(|d| d.inner.choices.first())
        .filter_map(|c| c.finish_reason)
        .any(|reason| reason == FinishReason::Stop);
    assert!(
        saw_stop,
        "tool_choice=named should have Stop finish reason"
    );
}
} }
...@@ -454,3 +454,154 @@ fn test_no_tool_choice_outputs_normal_text() { ...@@ -454,3 +454,154 @@ fn test_no_tool_choice_outputs_normal_text() {
); );
assert!(response.inner.choices[0].delta.tool_calls.is_none()); assert!(response.inner.choices[0].delta.tool_calls.is_none());
} }
// ---------------------------------------------------------------------------
// tool_choice=named + tool_call_parser enforcement (CodeRabbit PR #7589)
// ---------------------------------------------------------------------------
/// Create a one-choice streaming chunk whose delta carries `text` as assistant content.
///
/// When `finish` is true the choice is marked with `FinishReason::Stop`; otherwise its
/// finish reason is left unset. All remaining fields are fixed test placeholders.
fn make_text_chunk(
    text: &str,
    finish: bool,
) -> dynamo_llm::protocols::openai::chat_completions::NvCreateChatCompletionStreamResponse {
    use dynamo_llm::protocols::openai::chat_completions::NvCreateChatCompletionStreamResponse;
    use dynamo_protocols::types::{
        ChatChoiceStream, ChatCompletionMessageContent, ChatCompletionStreamResponseDelta,
        CreateChatCompletionStreamResponse, FinishReason, Role,
    };

    let content = ChatCompletionMessageContent::Text(text.to_string());
    let finish_reason = finish.then_some(FinishReason::Stop);

    // The struct literal below has to name a deprecated field, hence the lint allowance.
    #[allow(deprecated)]
    NvCreateChatCompletionStreamResponse {
        inner: CreateChatCompletionStreamResponse {
            id: "test-named-parser".to_string(),
            choices: vec![ChatChoiceStream {
                index: 0,
                delta: ChatCompletionStreamResponseDelta {
                    role: Some(Role::Assistant),
                    content: Some(content),
                    tool_calls: None,
                    function_call: None,
                    refusal: None,
                    reasoning_content: None,
                },
                finish_reason,
                stop_reason: None,
                logprobs: None,
            }],
            created: 1234567890,
            model: "test-model".to_string(),
            system_fingerprint: None,
            object: "chat.completion.chunk".to_string(),
            usage: None,
            service_tier: None,
        },
        nvext: None,
    }
}
/// Run `chunks` through a jail built with both `parser` and a named-tool filter for
/// `named_tool`, and return the payload of every output item that carries data.
///
/// Output items without a data payload are skipped.
async fn apply_jail_named_with_parser(
    chunks: Vec<
        dynamo_llm::protocols::openai::chat_completions::NvCreateChatCompletionStreamResponse,
    >,
    parser: &str,
    named_tool: &str,
) -> Vec<dynamo_llm::protocols::openai::chat_completions::NvCreateChatCompletionStreamResponse> {
    use dynamo_llm::protocols::openai::chat_completions::jail::JailedStream;
    use dynamo_runtime::protocols::annotated::Annotated;
    use futures::StreamExt;
    use futures::stream;

    // Wrap every raw chunk in a data-only annotation.
    let annotated_chunks = chunks.into_iter().map(|chunk| Annotated {
        data: Some(chunk),
        id: None,
        event: None,
        comment: None,
        error: None,
    });

    let jail = JailedStream::builder()
        .tool_call_parser(parser)
        .named_tool_filter(named_tool)
        .build();
    let jailed = jail.apply_with_finish_reason(stream::iter(annotated_chunks));
    tokio::pin!(jailed);

    // Drain the stream, keeping only items that carry a response payload.
    let mut emitted = Vec::new();
    while let Some(annotated) = jailed.next().await {
        if let Some(response) = annotated.data {
            emitted.push(response);
        }
    }
    emitted
}
/// Named-tool filter with a parser, matching case: when the model emits the tool that was
/// requested, the parsed call is emitted unchanged under that name.
#[tokio::test]
async fn test_named_tool_with_parser_correct_tool_passes() {
    // Hermes format: <tool_call>{"name":"get_weather","arguments":{...}}\n</tool_call>
    let hermes_payload = "<tool_call>\n{\"name\": \"get_weather\", \"arguments\": {\"location\": \"Paris\"}}\n</tool_call>";

    let responses = apply_jail_named_with_parser(
        vec![
            make_text_chunk(hermes_payload, false),
            // Trailing empty chunk that carries the finish_reason.
            make_text_chunk("", true),
        ],
        "hermes",
        "get_weather",
    )
    .await;

    // Take the tool calls from the first response whose first choice has any.
    let tool_calls = responses
        .iter()
        .filter_map(|r| r.inner.choices.first())
        .find_map(|choice| choice.delta.tool_calls.as_ref())
        .expect("expected a response with tool calls for the correct named tool");

    assert_eq!(tool_calls.len(), 1, "expected exactly one tool call");
    assert_eq!(
        tool_calls[0].function.as_ref().unwrap().name.as_deref(),
        Some("get_weather"),
        "tool call name should be get_weather"
    );
}
/// Named-tool filter with a parser, mismatching case: when the model emits a different
/// tool than the one requested, that call must not appear in the output.
/// Regression test for the CodeRabbit review on PR #7589.
#[tokio::test]
async fn test_named_tool_with_parser_wrong_tool_is_filtered() {
    // The model calls "search" although "get_weather" was required.
    let hermes_wrong_tool = "<tool_call>\n{\"name\": \"search\", \"arguments\": {\"query\": \"Paris weather\"}}\n</tool_call>";

    let responses = apply_jail_named_with_parser(
        vec![
            make_text_chunk(hermes_wrong_tool, false),
            make_text_chunk("", true),
        ],
        "hermes",
        "get_weather",
    )
    .await;

    // Names of all tool calls found on the first choice of each response; a call without
    // a function name is treated as the empty string.
    let emitted_names = responses
        .iter()
        .filter_map(|r| r.inner.choices.first())
        .filter_map(|choice| choice.delta.tool_calls.as_ref())
        .flatten()
        .map(|tc| {
            tc.function
                .as_ref()
                .and_then(|f| f.name.as_deref())
                .unwrap_or("")
        });

    for name in emitted_names {
        assert_ne!(
            name, "search",
            "wrong tool 'search' should have been filtered by named_tool_filter"
        );
    }
}
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment