"components/vscode:/vscode.git/clone" did not exist on "8bb9a555a4f7ab2dd2d85d171c7f21bdaf9d5793"
Unverified Commit 5585f803 authored by Vladislav Nosivskoy's avatar Vladislav Nosivskoy Committed by GitHub
Browse files

feat: add tool_choice support (#4722)


Signed-off-by: Vladislav Nosivskoy <vladnosiv@gmail.com>
parent 94d145a9
...@@ -752,25 +752,17 @@ impl OpenAIPreprocessor { ...@@ -752,25 +752,17 @@ impl OpenAIPreprocessor {
has_tools: bool, has_tools: bool,
) -> std::result::Result<bool, Error> { ) -> std::result::Result<bool, Error> {
match (tool_call_parser, tool_choice, has_tools) { match (tool_call_parser, tool_choice, has_tools) {
// No parser but tools requested - error cases // tool_choice=required/named work without parser (use Immediate jail mode)
(None, Some(ChatCompletionToolChoiceOption::Required), true) => { (None, Some(ChatCompletionToolChoiceOption::Required), true) => Ok(true),
tracing::warn!( (None, Some(ChatCompletionToolChoiceOption::Named(_)), true) => Ok(true),
"Tool choice 'required' specified but no tool parser configured; proceeding without jailing"
); // tool_choice=auto requires a parser
Ok(false)
}
(None, Some(ChatCompletionToolChoiceOption::Auto), true) => { (None, Some(ChatCompletionToolChoiceOption::Auto), true) => {
tracing::warn!( tracing::warn!(
"Tool choice 'auto' specified but no tool parser configured; proceeding without jailing" "Tool choice 'auto' specified but no tool parser configured; proceeding without jailing"
); );
Ok(false) Ok(false)
} }
(None, Some(ChatCompletionToolChoiceOption::Named(_)), _) => {
tracing::warn!(
"Named tool choice specified but no tool parser configured; proceeding without jailing"
);
Ok(false)
}
// Parser exists and tools might be called // Parser exists and tools might be called
(Some(_), Some(ChatCompletionToolChoiceOption::None), _) => { (Some(_), Some(ChatCompletionToolChoiceOption::None), _) => {
...@@ -786,15 +778,38 @@ impl OpenAIPreprocessor { ...@@ -786,15 +778,38 @@ impl OpenAIPreprocessor {
/// Apply tool calling jail to the stream if needed /// Apply tool calling jail to the stream if needed
pub fn apply_tool_calling_jail<S>( pub fn apply_tool_calling_jail<S>(
tool_call_parser: String, tool_call_parser: Option<String>,
tool_choice: Option<dynamo_async_openai::types::ChatCompletionToolChoiceOption>,
stream: S, stream: S,
) -> impl Stream<Item = Annotated<NvCreateChatCompletionStreamResponse>> + Send ) -> impl Stream<Item = Annotated<NvCreateChatCompletionStreamResponse>> + Send
where where
S: Stream<Item = Annotated<NvCreateChatCompletionStreamResponse>> + Send + 'static, S: Stream<Item = Annotated<NvCreateChatCompletionStreamResponse>> + Send + 'static,
{ {
let jail = JailedStream::builder() use dynamo_async_openai::types::ChatCompletionToolChoiceOption;
.tool_call_parser(tool_call_parser)
.build(); let mut builder = JailedStream::builder();
// Configure jail based on tool_choice
match tool_choice {
Some(ChatCompletionToolChoiceOption::Named(named)) => {
// Immediate jail mode for named tool choice
builder = builder.tool_choice_named(named.function.name.clone());
}
Some(ChatCompletionToolChoiceOption::Required) => {
// Immediate jail mode for required tool choice
builder = builder.tool_choice_required();
}
Some(ChatCompletionToolChoiceOption::Auto)
| Some(ChatCompletionToolChoiceOption::None)
| None => {
// Traditional marker-based jail for auto/none/unspecified
if let Some(parser) = tool_call_parser {
builder = builder.tool_call_parser(parser);
}
}
}
let jail = builder.build();
jail.apply_with_finish_reason(stream) jail.apply_with_finish_reason(stream)
} }
...@@ -957,11 +972,11 @@ impl ...@@ -957,11 +972,11 @@ impl
// Apply jail conditionally // Apply jail conditionally
let transformed_stream: Pin<Box<dyn Stream<Item = _> + Send>> = if should_jail { let transformed_stream: Pin<Box<dyn Stream<Item = _> + Send>> = if should_jail {
if let Some(parser) = self.tool_call_parser.clone() { Box::pin(Self::apply_tool_calling_jail(
Box::pin(Self::apply_tool_calling_jail(parser, stream)) self.tool_call_parser.clone(),
} else { request.inner.tool_choice.clone(),
Box::pin(stream) // Should not happen due to should_jail check stream,
} ))
} else { } else {
Box::pin(stream) Box::pin(stream)
}; };
......
...@@ -17,6 +17,7 @@ pub mod embeddings; ...@@ -17,6 +17,7 @@ pub mod embeddings;
pub mod models; pub mod models;
pub mod nvext; pub mod nvext;
pub mod responses; pub mod responses;
pub mod tools;
pub mod validate; pub mod validate;
use validate::{ use validate::{
...@@ -131,7 +132,7 @@ impl<T: OpenAISamplingOptionsProvider + CommonExtProvider> SamplingOptionsProvid ...@@ -131,7 +132,7 @@ impl<T: OpenAISamplingOptionsProvider + CommonExtProvider> SamplingOptionsProvid
let guided_whitespace_pattern = self.get_guided_whitespace_pattern(); let guided_whitespace_pattern = self.get_guided_whitespace_pattern();
let guided_decoding = match common::GuidedDecodingOptions::from_optional( let guided_decoding = match common::GuidedDecodingOptions::from_optional(
guided_json.cloned(), guided_json,
guided_regex, guided_regex,
guided_choice, guided_choice,
guided_grammar, guided_grammar,
......
...@@ -12,7 +12,7 @@ use super::{ ...@@ -12,7 +12,7 @@ use super::{
common_ext::{CommonExt, CommonExtProvider}, common_ext::{CommonExt, CommonExtProvider},
nvext::NvExt, nvext::NvExt,
nvext::NvExtProvider, nvext::NvExtProvider,
validate, tools, validate,
}; };
pub mod aggregator; pub mod aggregator;
...@@ -159,8 +159,24 @@ impl CommonExtProvider for NvCreateChatCompletionRequest { ...@@ -159,8 +159,24 @@ impl CommonExtProvider for NvCreateChatCompletionRequest {
} }
/// Guided Decoding Options /// Guided Decoding Options
fn get_guided_json(&self) -> Option<&serde_json::Value> { fn get_guided_json(&self) -> Option<serde_json::Value> {
self.common.guided_json.as_ref() if let Some(value) = self.common.guided_json.clone() {
return Some(value);
}
let tool_choice = self.inner.tool_choice.as_ref()?;
let tools = self.inner.tools.as_deref()?;
match tools::get_json_schema_from_tools(Some(tool_choice), Some(tools)) {
Ok(schema) => schema,
Err(err) => {
tracing::warn!(
error = %err,
"failed to derive guided_json from tool_choice"
);
None
}
}
} }
fn get_guided_regex(&self) -> Option<String> { fn get_guided_regex(&self) -> Option<String> {
......
...@@ -14,6 +14,7 @@ use dynamo_parsers::tool_calling::{ ...@@ -14,6 +14,7 @@ use dynamo_parsers::tool_calling::{
use dynamo_runtime::protocols::annotated::Annotated; use dynamo_runtime::protocols::annotated::Annotated;
use futures::{Stream, StreamExt}; use futures::{Stream, StreamExt};
use std::collections::HashMap; use std::collections::HashMap;
use uuid::Uuid;
use crate::utils::{MarkerMatcher, MatchResult}; use crate::utils::{MarkerMatcher, MatchResult};
...@@ -62,6 +63,24 @@ pub struct JailConfig<'a> { ...@@ -62,6 +63,24 @@ pub struct JailConfig<'a> {
pub tool_call_parser: Option<&'a str>, pub tool_call_parser: Option<&'a str>,
} }
/// How a `JailedStream` decides when a choice starts being jailed.
///
/// The builder defaults to `MarkerBased`; `tool_choice_named` and
/// `tool_choice_required` switch it to `Immediate`.
#[derive(Debug, Clone, PartialEq)]
pub enum JailMode {
/// Traditional: the choice streams normally until a start marker is seen, then it is jailed
MarkerBased,
/// Immediate: the choice is jailed from its first token (used for tool_choice);
/// `format` describes the JSON shape the accumulated output is expected to have
Immediate { format: ToolChoiceFormat },
}
/// Expected JSON shape of the model output in tool_choice immediate jail mode.
#[derive(Debug, Clone, PartialEq)]
pub enum ToolChoiceFormat {
/// tool_choice=named: expect a single object holding the arguments, e.g.
/// {"location": "Paris", ...}; `tool_name` supplies the name of the emitted tool call
SingleObject { tool_name: String },
/// tool_choice=required: expect a non-empty array of calls, e.g.
/// [{name:"search", parameters:{...}}, ...]
ArrayOfTools,
}
/// State tracking for an individual choice during jail processing /// State tracking for an individual choice during jail processing
#[derive(Debug, Clone)] #[derive(Debug, Clone)]
struct ChoiceJailState { struct ChoiceJailState {
...@@ -105,10 +124,10 @@ fn create_choice_stream( ...@@ -105,10 +124,10 @@ fn create_choice_stream(
impl ChoiceJailState { impl ChoiceJailState {
/// Create a new jail state for a choice /// Create a new jail state for a choice
fn new(index: u32) -> Self { fn new(index: u32, starts_jailed: bool) -> Self {
Self { Self {
index, index,
is_jailed: false, is_jailed: starts_jailed,
accumulated_content: String::new(), accumulated_content: String::new(),
partial_match_buffer: String::new(), partial_match_buffer: String::new(),
stream_finish_reason: None, stream_finish_reason: None,
...@@ -409,7 +428,7 @@ impl ChoiceJailStateCollection { ...@@ -409,7 +428,7 @@ impl ChoiceJailStateCollection {
} }
/// Get or create state for a choice index /// Get or create state for a choice index
fn get_or_create_state(&mut self, index: u32) -> &mut ChoiceJailState { fn get_or_create_state(&mut self, index: u32, starts_jailed: bool) -> &mut ChoiceJailState {
// Find the position where this index should be // Find the position where this index should be
match self.states.binary_search_by_key(&index, |s| s.index) { match self.states.binary_search_by_key(&index, |s| s.index) {
Ok(pos) => { Ok(pos) => {
...@@ -418,7 +437,7 @@ impl ChoiceJailStateCollection { ...@@ -418,7 +437,7 @@ impl ChoiceJailStateCollection {
} }
Err(insert_pos) => { Err(insert_pos) => {
// Need to create new state // Need to create new state
let new_state = ChoiceJailState::new(index); let new_state = ChoiceJailState::new(index, starts_jailed);
self.states.insert(insert_pos, new_state); self.states.insert(insert_pos, new_state);
&mut self.states[insert_pos] &mut self.states[insert_pos]
} }
...@@ -427,20 +446,15 @@ impl ChoiceJailStateCollection { ...@@ -427,20 +446,15 @@ impl ChoiceJailStateCollection {
} }
/// Emission mode for handling multiple choices /// Emission mode for handling multiple choices
#[derive(Debug, Clone, Copy, PartialEq, Eq)] #[derive(Debug, Clone, Copy, PartialEq, Eq, Default)]
pub enum EmissionMode { pub enum EmissionMode {
/// Pack multiple choices in the same chunk (default, matches original behavior) /// Pack multiple choices in the same chunk (default, matches original behavior)
#[default]
Packed, Packed,
/// Emit one choice per chunk for OpenAI compatibility /// Emit one choice per chunk for OpenAI compatibility
SingleChoicePerChunk, SingleChoicePerChunk,
} }
impl Default for EmissionMode {
fn default() -> Self {
Self::Packed
}
}
/// A stream transformer that can "jail" tokens based on configurable start/end sequences /// A stream transformer that can "jail" tokens based on configurable start/end sequences
/// When jailed, tokens are accumulated rather than yielded immediately /// When jailed, tokens are accumulated rather than yielded immediately
/// When the jail ends (via end sequence or stream completion), accumulated content is processed and released /// When the jail ends (via end sequence or stream completion), accumulated content is processed and released
...@@ -450,6 +464,7 @@ pub struct JailedStream { ...@@ -450,6 +464,7 @@ pub struct JailedStream {
tool_call_parser: Option<String>, tool_call_parser: Option<String>,
emission_mode: EmissionMode, emission_mode: EmissionMode,
marker_matcher: MarkerMatcher, marker_matcher: MarkerMatcher,
jail_mode: JailMode,
} }
impl JailedStream { impl JailedStream {
...@@ -467,8 +482,9 @@ impl JailedStream { ...@@ -467,8 +482,9 @@ impl JailedStream {
where where
S: Stream<Item = Annotated<NvCreateChatCompletionStreamResponse>> + Send + 'static, S: Stream<Item = Annotated<NvCreateChatCompletionStreamResponse>> + Send + 'static,
{ {
let jail_mode = self.jail_mode.clone();
let jailed_stream = self.apply(stream); let jailed_stream = self.apply(stream);
JailedStream::fix_finish_reason(jailed_stream) JailedStream::fix_finish_reason(jailed_stream, jail_mode)
} }
/// Apply the jail transformation to a stream of chat completion responses /// Apply the jail transformation to a stream of chat completion responses
...@@ -508,7 +524,8 @@ impl JailedStream { ...@@ -508,7 +524,8 @@ impl JailedStream {
// Process each choice independently using the new architecture // Process each choice independently using the new architecture
for choice in &chat_response.choices { for choice in &chat_response.choices {
if let Some(ref content) = choice.delta.content { if let Some(ref content) = choice.delta.content {
let choice_state = choice_states.get_or_create_state(choice.index); let starts_jailed = matches!(self.jail_mode, JailMode::Immediate { .. });
let choice_state = choice_states.get_or_create_state(choice.index, starts_jailed);
// Store metadata when any choice becomes jailed (first time only) // Store metadata when any choice becomes jailed (first time only)
if !choice_state.is_jailed && self.should_start_jail(content) if !choice_state.is_jailed && self.should_start_jail(content)
...@@ -526,7 +543,16 @@ impl JailedStream { ...@@ -526,7 +543,16 @@ impl JailedStream {
all_emissions.extend(emissions); all_emissions.extend(emissions);
} else { } else {
// Handle choices without content (e.g., final chunks with finish_reason) // Handle choices without content (e.g., final chunks with finish_reason)
// These should always pass through // Only filter out if this choice was ever jailed and lacks role
// (to avoid aggregator issues with deltas missing role after unjail)
let choice_state = choice_states.get_or_create_state(choice.index, false);
let was_ever_jailed = !choice_state.accumulated_content.is_empty() || choice_state.is_jailed;
let should_emit = choice.delta.role.is_some()
|| choice.delta.tool_calls.is_some()
|| !was_ever_jailed; // Always pass through if never jailed
if should_emit {
let pass_through_choice = ChatChoiceStream { let pass_through_choice = ChatChoiceStream {
index: choice.index, index: choice.index,
delta: choice.delta.clone(), delta: choice.delta.clone(),
...@@ -536,6 +562,7 @@ impl JailedStream { ...@@ -536,6 +562,7 @@ impl JailedStream {
all_emissions.push(ChoiceEmission::PassThrough(pass_through_choice)); all_emissions.push(ChoiceEmission::PassThrough(pass_through_choice));
} }
} }
}
// Emit all results based on emission mode // Emit all results based on emission mode
if !all_emissions.is_empty() { if !all_emissions.is_empty() {
...@@ -701,6 +728,8 @@ impl JailedStream { ...@@ -701,6 +728,8 @@ impl JailedStream {
/// Check if accumulated content should end jail /// Check if accumulated content should end jail
async fn should_end_jail(&self, accumulated_content: &str) -> (bool, usize) { async fn should_end_jail(&self, accumulated_content: &str) -> (bool, usize) {
match &self.jail_mode {
JailMode::MarkerBased => {
// Path 1: End sequence detected // Path 1: End sequence detected
let end_marker_info = if !self.jail_end_sequences.is_empty() { let end_marker_info = if !self.jail_end_sequences.is_empty() {
self.jail_end_sequences.iter().find_map(|seq| { self.jail_end_sequences.iter().find_map(|seq| {
...@@ -723,7 +752,8 @@ impl JailedStream { ...@@ -723,7 +752,8 @@ impl JailedStream {
if let Ok((_, _)) = if let Ok((_, _)) =
try_tool_call_parse_aggregate(accumulated_content, Some(parser)).await try_tool_call_parse_aggregate(accumulated_content, Some(parser)).await
{ {
let split_pos = find_tool_call_end_position(accumulated_content, Some(parser)); let split_pos =
find_tool_call_end_position(accumulated_content, Some(parser));
(true, split_pos) (true, split_pos)
} else { } else {
(false, accumulated_content.len()) (false, accumulated_content.len())
...@@ -735,6 +765,34 @@ impl JailedStream { ...@@ -735,6 +765,34 @@ impl JailedStream {
(false, accumulated_content.len()) (false, accumulated_content.len())
} }
} }
JailMode::Immediate { format } => {
// For tool_choice, check if we have valid complete JSON
match format {
ToolChoiceFormat::SingleObject { .. } => {
// Expect single object: {"location": "Paris", "unit": "celsius"}
if let Ok(value) =
serde_json::from_str::<serde_json::Value>(accumulated_content)
&& value.is_object()
{
return (true, accumulated_content.len());
}
(false, accumulated_content.len())
}
ToolChoiceFormat::ArrayOfTools => {
// Expect array: [{"name":"search","parameters":{...}}, ...]
if let Ok(value) =
serde_json::from_str::<serde_json::Value>(accumulated_content)
&& let Some(arr) = value.as_array()
&& !arr.is_empty()
{
return (true, accumulated_content.len());
}
(false, accumulated_content.len())
}
}
}
}
}
/// Parse tool calls from accumulated content and create choice /// Parse tool calls from accumulated content and create choice
async fn create_tool_call_choice( async fn create_tool_call_choice(
...@@ -744,8 +802,13 @@ impl JailedStream { ...@@ -744,8 +802,13 @@ impl JailedStream {
base_choice: &ChatChoiceStream, base_choice: &ChatChoiceStream,
tool_call_offset: usize, tool_call_offset: usize,
) -> ChatChoiceStream { ) -> ChatChoiceStream {
if let Ok((tool_calls, normal_text)) = match &self.jail_mode {
try_tool_call_parse_aggregate(accumulated_content, self.tool_call_parser.as_deref()) JailMode::MarkerBased => {
// Traditional marker-based tool call parsing
if let Ok((tool_calls, normal_text)) = try_tool_call_parse_aggregate(
accumulated_content,
self.tool_call_parser.as_deref(),
)
.await .await
&& !tool_calls.is_empty() && !tool_calls.is_empty()
{ {
...@@ -785,6 +848,91 @@ impl JailedStream { ...@@ -785,6 +848,91 @@ impl JailedStream {
base_choice.logprobs.clone(), base_choice.logprobs.clone(),
) )
} }
JailMode::Immediate { format } => {
// tool_choice mode: parse JSON and convert to tool calls
match self.parse_tool_choice_json(accumulated_content, format) {
Ok(tool_call_chunks) if !tool_call_chunks.is_empty() => create_choice_stream(
choice_index,
Some(Role::Assistant),
"",
Some(tool_call_chunks),
base_choice.finish_reason,
base_choice.logprobs.clone(),
),
Ok(_) | Err(_) => {
// Parsing failed, return as content
create_choice_stream(
choice_index,
Some(Role::Assistant),
accumulated_content,
None,
base_choice.finish_reason,
base_choice.logprobs.clone(),
)
}
}
}
}
}
/// Assembles one function-type tool call chunk.
///
/// The chunk gets a freshly generated id of the form `call-<uuid v4>`; `index`,
/// `name` and `arguments` are stored as given (`arguments` is expected to be the
/// already-serialized JSON text of the call's arguments).
fn create_tool_call_chunk(
    index: u32,
    name: String,
    arguments: String,
) -> ChatCompletionMessageToolCallChunk {
    let call_id = format!("call-{}", Uuid::new_v4());
    let function = FunctionCallStream {
        name: Some(name),
        arguments: Some(arguments),
    };

    ChatCompletionMessageToolCallChunk {
        index,
        id: Some(call_id),
        r#type: Some(dynamo_async_openai::types::ChatCompletionToolType::Function),
        function: Some(function),
    }
}
/// Parse tool_choice JSON output into tool call chunks.
///
/// * `SingleObject` (tool_choice=named): the whole JSON text must be an object; it
///   becomes the arguments of a single call to `tool_name` at index 0.
/// * `ArrayOfTools` (tool_choice=required): the JSON must be an array whose entries
///   carry a string `name` and a `parameters` value. Entries missing either field
///   are skipped.
///
/// Tool call indices are assigned to the entries that survive validation, so the
/// emitted chunks are always numbered 0, 1, 2, ... without gaps even when a
/// malformed entry is skipped.
///
/// # Errors
///
/// Returns an error when `json_content` is not valid JSON. Valid JSON of the wrong
/// shape yields `Ok` with an empty vector.
fn parse_tool_choice_json(
    &self,
    json_content: &str,
    format: &ToolChoiceFormat,
) -> anyhow::Result<Vec<ChatCompletionMessageToolCallChunk>> {
    let parsed: serde_json::Value = serde_json::from_str(json_content)?;

    let chunks = match format {
        // Named tool choice: the JSON itself is the parameters object, so the raw
        // text is forwarded unchanged as the call arguments.
        ToolChoiceFormat::SingleObject { tool_name } if parsed.is_object() => {
            vec![Self::create_tool_call_chunk(
                0,
                tool_name.clone(),
                json_content.to_string(),
            )]
        }
        ToolChoiceFormat::SingleObject { .. } => Vec::new(),
        // Required tool choice: the JSON is an array of {name, parameters}.
        ToolChoiceFormat::ArrayOfTools => {
            let Some(entries) = parsed.as_array() else {
                return Ok(Vec::new());
            };
            entries
                .iter()
                .filter_map(|entry| {
                    let name = entry.get("name")?.as_str()?.to_string();
                    let args = serde_json::to_string(entry.get("parameters")?).ok()?;
                    Some((name, args))
                })
                // Number only the valid entries; a u32 counter avoids both index
                // gaps and a truncating `usize as u32` cast.
                .zip(0u32..)
                .map(|((name, args), index)| Self::create_tool_call_chunk(index, name, args))
                .collect()
        }
    };

    Ok(chunks)
}
/// Check if accumulated content contains complete tool calls that can be parsed /// Check if accumulated content contains complete tool calls that can be parsed
/// Returns true if we should exit the jail early /// Returns true if we should exit the jail early
...@@ -804,8 +952,9 @@ impl JailedStream { ...@@ -804,8 +952,9 @@ impl JailedStream {
/// Post-processor that sets finish_reason to ToolCalls when tool calls were emitted /// Post-processor that sets finish_reason to ToolCalls when tool calls were emitted
/// This should be called after apply() to fix the finish_reason for tool call chunks /// This should be called after apply() to fix the finish_reason for tool call chunks
pub fn fix_finish_reason<S>( fn fix_finish_reason<S>(
input_stream: S, input_stream: S,
jail_mode: JailMode,
) -> impl Stream<Item = Annotated<NvCreateChatCompletionStreamResponse>> + Send ) -> impl Stream<Item = Annotated<NvCreateChatCompletionStreamResponse>> + Send
where where
S: Stream<Item = Annotated<NvCreateChatCompletionStreamResponse>> + Send + 'static, S: Stream<Item = Annotated<NvCreateChatCompletionStreamResponse>> + Send + 'static,
...@@ -824,16 +973,42 @@ impl JailedStream { ...@@ -824,16 +973,42 @@ impl JailedStream {
} }
} }
// If this chunk has finish_reason and the choice had tool calls, override to ToolCalls // Fix finish_reason based on jail mode and whether tool calls were emitted
if let Some(ref mut data) = response.data { if let Some(ref mut data) = response.data {
for choice in &mut data.choices { for choice in &mut data.choices {
if choice.finish_reason.is_some() && choice.finish_reason == Some(FinishReason::Stop) if let Some(finish) = choice.finish_reason {
&& has_tool_calls_per_choice.get(&choice.index).copied().unwrap_or(false) // Only modify Stop finish reason, preserve Length/ContentFilter
{ if finish == FinishReason::Stop {
let has_tool_calls = has_tool_calls_per_choice.get(&choice.index).copied().unwrap_or(false);
match &jail_mode {
JailMode::MarkerBased => {
// Traditional: if tool calls emitted, change to ToolCalls
if has_tool_calls {
choice.finish_reason = Some(FinishReason::ToolCalls);
}
}
JailMode::Immediate { format } => {
// tool_choice mode: apply specific finish_reason logic
match format {
ToolChoiceFormat::SingleObject { .. } => {
// Named tool choice: keep Stop
// (already Stop, no change needed)
}
ToolChoiceFormat::ArrayOfTools => {
// Required tool choice: change to ToolCalls
if has_tool_calls {
choice.finish_reason = Some(FinishReason::ToolCalls); choice.finish_reason = Some(FinishReason::ToolCalls);
} }
} }
} }
}
}
}
// Length and ContentFilter are preserved as-is
}
}
}
yield response; yield response;
} }
...@@ -847,6 +1022,7 @@ pub struct JailedStreamBuilder { ...@@ -847,6 +1022,7 @@ pub struct JailedStreamBuilder {
jail_end_sequences: Vec<String>, jail_end_sequences: Vec<String>,
tool_call_parser: Option<String>, tool_call_parser: Option<String>,
emission_mode: EmissionMode, emission_mode: EmissionMode,
jail_mode: JailMode,
} }
impl JailedStreamBuilder { impl JailedStreamBuilder {
...@@ -857,6 +1033,7 @@ impl JailedStreamBuilder { ...@@ -857,6 +1033,7 @@ impl JailedStreamBuilder {
jail_end_sequences: Vec::new(), jail_end_sequences: Vec::new(),
tool_call_parser: None, tool_call_parser: None,
emission_mode: EmissionMode::default(), emission_mode: EmissionMode::default(),
jail_mode: JailMode::MarkerBased,
} }
} }
...@@ -916,6 +1093,22 @@ impl JailedStreamBuilder { ...@@ -916,6 +1093,22 @@ impl JailedStreamBuilder {
self self
} }
/// Switches the builder to immediate jail mode for tool_choice=named.
///
/// The resulting stream expects a single JSON object and attributes it to `tool_name`.
pub fn tool_choice_named(mut self, tool_name: String) -> Self {
    let format = ToolChoiceFormat::SingleObject { tool_name };
    self.jail_mode = JailMode::Immediate { format };
    self
}
/// Switches the builder to immediate jail mode for tool_choice=required.
///
/// The resulting stream expects a JSON array of tool calls.
pub fn tool_choice_required(mut self) -> Self {
    let format = ToolChoiceFormat::ArrayOfTools;
    self.jail_mode = JailMode::Immediate { format };
    self
}
/// Build the configured JailedStream /// Build the configured JailedStream
pub fn build(mut self) -> JailedStream { pub fn build(mut self) -> JailedStream {
// Auto-populate jail sequences from parser config if not manually configured // Auto-populate jail sequences from parser config if not manually configured
...@@ -994,6 +1187,7 @@ impl JailedStreamBuilder { ...@@ -994,6 +1187,7 @@ impl JailedStreamBuilder {
tool_call_parser: self.tool_call_parser, tool_call_parser: self.tool_call_parser,
emission_mode: self.emission_mode, emission_mode: self.emission_mode,
marker_matcher, marker_matcher,
jail_mode: self.jail_mode,
} }
} }
} }
......
...@@ -94,7 +94,7 @@ pub trait CommonExtProvider { ...@@ -94,7 +94,7 @@ pub trait CommonExtProvider {
fn common_ext(&self) -> Option<&CommonExt>; fn common_ext(&self) -> Option<&CommonExt>;
/// Guided Decoding Options /// Guided Decoding Options
fn get_guided_json(&self) -> Option<&serde_json::Value>; fn get_guided_json(&self) -> Option<serde_json::Value>;
fn get_guided_regex(&self) -> Option<String>; fn get_guided_regex(&self) -> Option<String>;
fn get_guided_grammar(&self) -> Option<String>; fn get_guided_grammar(&self) -> Option<String>;
fn get_guided_choice(&self) -> Option<Vec<String>>; fn get_guided_choice(&self) -> Option<Vec<String>>;
......
...@@ -183,8 +183,8 @@ impl CommonExtProvider for NvCreateCompletionRequest { ...@@ -183,8 +183,8 @@ impl CommonExtProvider for NvCreateCompletionRequest {
} }
/// Guided Decoding Options /// Guided Decoding Options
fn get_guided_json(&self) -> Option<&serde_json::Value> { fn get_guided_json(&self) -> Option<serde_json::Value> {
self.common.guided_json.as_ref() self.common.guided_json.clone()
} }
fn get_guided_regex(&self) -> Option<String> { fn get_guided_regex(&self) -> Option<String> {
......
// SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
// SPDX-License-Identifier: Apache-2.0
use std::collections::BTreeMap;
use dynamo_async_openai::types::{
ChatCompletionTool, ChatCompletionToolChoiceOption, FunctionObject,
};
use serde_json::{Value, json};
use thiserror::Error;
/// Errors that can occur when deriving JSON schemas for tool_choice requests.
#[derive(Debug, Error, PartialEq, Eq)]
pub enum ToolChoiceError {
/// tool_choice was `named` or `required` but no `tools` slice was supplied.
#[error("tool_choice requires a matching `tools` array")]
MissingTools,
/// The tool named by a `named` tool_choice is not present in `tools`;
/// carries the requested tool name.
#[error("tool `{0}` was not provided in `tools`")]
ToolNotFound(String),
/// A tool's parameter schema has a `$defs` entry that is not a JSON object;
/// carries the name of the offending tool.
#[error("$defs for tool `{0}` must be an object")]
InvalidDefinitionMap(String),
/// Two tools declare the same `$defs` key with schemas that are not equal;
/// carries the conflicting key.
#[error("duplicate $defs entry `{0}` has conflicting schemas")]
ConflictingDefinition(String),
/// tool_choice was `required` but the supplied `tools` slice is empty.
#[error("tool_choice `required` needs at least one tool definition")]
EmptyTools,
}
/// Derives the JSON schema that Guided Decoding should enforce for a
/// tool_choice/tools pair.
///
/// * No tool_choice, `none`, or `auto`: no schema (`Ok(None)`).
/// * `named`: the parameter schema of the named tool.
/// * `required`: an array schema accepting one or more calls to any supplied tool.
///
/// # Errors
///
/// * [`ToolChoiceError::MissingTools`] when `named`/`required` is used without `tools`.
/// * [`ToolChoiceError::ToolNotFound`] when the named tool is not in `tools`.
/// * [`ToolChoiceError::EmptyTools`] when `required` is used with an empty `tools` slice.
/// * Any error raised while building the `required` schema.
pub fn get_json_schema_from_tools(
    tool_choice: Option<&ChatCompletionToolChoiceOption>,
    tools: Option<&[ChatCompletionTool]>,
) -> Result<Option<Value>, ToolChoiceError> {
    match tool_choice {
        None
        | Some(ChatCompletionToolChoiceOption::None)
        | Some(ChatCompletionToolChoiceOption::Auto) => Ok(None),
        Some(ChatCompletionToolChoiceOption::Named(named)) => {
            let available = tools.ok_or(ToolChoiceError::MissingTools)?;
            let wanted = named.function.name.as_str();
            match find_tool(available, wanted) {
                Some(tool) => Ok(Some(clone_parameters(&tool.function))),
                None => Err(ToolChoiceError::ToolNotFound(wanted.to_string())),
            }
        }
        Some(ChatCompletionToolChoiceOption::Required) => match tools {
            None => Err(ToolChoiceError::MissingTools),
            Some([]) => Err(ToolChoiceError::EmptyTools),
            Some(available) => build_required_schema(available).map(Some),
        },
    }
}
/// Returns the first tool whose function name equals `name`, if any.
fn find_tool<'a>(tools: &'a [ChatCompletionTool], name: &str) -> Option<&'a ChatCompletionTool> {
    for candidate in tools {
        if candidate.function.name == name {
            return Some(candidate);
        }
    }
    None
}
/// Returns an owned copy of the function's parameter schema.
///
/// A function that declares no parameters gets the empty-object schema
/// `{"type": "object", "properties": {}}`.
fn clone_parameters(function: &FunctionObject) -> Value {
    match &function.parameters {
        Some(declared) => declared.clone(),
        None => json!({"type": "object", "properties": {}}),
    }
}
/// Builds the JSON Schema enforced for `tool_choice=required`: a non-empty array
/// in which every element is a call to one of the supplied tools.
///
/// # Schema Structure
///
/// One `anyOf` branch is generated per tool, in the order the tools were given:
/// ```json
/// {
///   "type": "array",
///   "minItems": 1,
///   "items": {
///     "type": "object",
///     "anyOf": [
///       {
///         "properties": {
///           "name": {"type": "string", "enum": ["tool1"]},
///           "parameters": { /* tool1's parameter schema, without its $defs */ }
///         },
///         "required": ["name", "parameters"]
///       }
///       /* ... one branch per remaining tool ... */
///     ]
///   },
///   "$defs": { /* definitions merged from all tools */ }
/// }
/// ```
///
/// # $defs Handling
///
/// `$defs` holds shared JSON Schema definitions that parameter schemas reference
/// through `$ref` (for example `{"$ref": "#/$defs/Location"}`). Each tool's
/// top-level `$defs` is removed from its parameter schema and merged into a single
/// root-level `$defs` map. The root `$defs` key is only emitted when at least one
/// definition was collected.
///
/// # Errors
///
/// Fails when a tool's `$defs` is not an object, or when two tools declare the
/// same definition name with different schemas.
fn build_required_schema(tools: &[ChatCompletionTool]) -> Result<Value, ToolChoiceError> {
    // Definitions gathered from every tool, keyed by definition name.
    let mut shared_defs: BTreeMap<String, Value> = BTreeMap::new();

    // One branch per tool; stops at the first tool whose $defs cannot be merged.
    let branches = tools
        .iter()
        .map(|tool| -> Result<Value, ToolChoiceError> {
            let split = split_defs(&tool.function)?;
            merge_defs(&mut shared_defs, split.defs)?;
            Ok(json!({
                "properties": {
                    "name": {
                        "type": "string",
                        "enum": [tool.function.name],
                    },
                    "parameters": split.schema,
                },
                "required": ["name", "parameters"],
            }))
        })
        .collect::<Result<Vec<Value>, ToolChoiceError>>()?;

    let mut root = json!({
        "type": "array",
        "minItems": 1,
        "items": {
            "type": "object",
            "anyOf": branches,
        },
    });

    // Hoist the merged definitions to the schema root, but only when there are any.
    if !shared_defs.is_empty() {
        if let Some(root_map) = root.as_object_mut() {
            root_map.insert(
                "$defs".to_string(),
                Value::Object(shared_defs.into_iter().collect()),
            );
        }
    }

    Ok(root)
}
/// A tool's parameter schema paired with the `$defs` that were pulled out of it.
///
/// Produced by `split_defs`. `build_required_schema` embeds `schema` under the
/// tool's `parameters` property and merges `defs` into one shared definitions map.
struct ParamsAndDefs {
/// The parameter schema with its top-level `$defs` key removed (if it had one)
schema: Value,
/// The extracted `$defs` entries, or `None` when the schema had no `$defs` key
defs: Option<BTreeMap<String, Value>>,
}
/// Separates a function's parameter schema from its top-level `$defs`.
///
/// # Example
///
/// Given the parameter schema
/// ```json
/// {
///   "type": "object",
///   "properties": {"location": {"$ref": "#/$defs/Location"}},
///   "$defs": {"Location": {"type": "object", "properties": {"city": {"type": "string"}}}}
/// }
/// ```
/// the result holds
/// - `schema`: the same object without the `$defs` key
/// - `defs`: `Some({"Location": {...}})`
///
/// A schema that is not a JSON object, or that has no `$defs` key, is returned
/// unchanged with `defs` set to `None`.
///
/// # Errors
///
/// Returns [`ToolChoiceError::InvalidDefinitionMap`] when `$defs` is present but
/// is not a JSON object.
fn split_defs(function: &FunctionObject) -> Result<ParamsAndDefs, ToolChoiceError> {
    let mut schema = clone_parameters(function);

    // Detach `$defs` (if any) so the remaining schema can be embedded on its own.
    let detached = schema
        .as_object_mut()
        .and_then(|fields| fields.remove("$defs"));
    let defs = detached
        .map(|raw| convert_defs(function, raw))
        .transpose()?;

    Ok(ParamsAndDefs { schema, defs })
}
/// Converts a raw `$defs` value into a name-to-schema map.
///
/// # Errors
///
/// Returns [`ToolChoiceError::InvalidDefinitionMap`], tagged with the function's
/// name, when `defs_value` is not a JSON object.
fn convert_defs(
    function: &FunctionObject,
    defs_value: Value,
) -> Result<BTreeMap<String, Value>, ToolChoiceError> {
    let Value::Object(entries) = defs_value else {
        return Err(ToolChoiceError::InvalidDefinitionMap(function.name.clone()));
    };
    Ok(entries.into_iter().collect())
}
/// Merges definitions from one tool into the global `$defs` accumulator.
///
/// # Conflict Detection
///
/// If two tools define the same type name but with different schemas, we return
/// an error. This ensures consistency across tool definitions.
///
/// # Example
///
/// If `target` contains:
/// ```json
/// {"Location": {"type": "object", "properties": {"city": {"type": "string"}}}}
/// ```
///
/// And we try to merge:
/// ```json
/// {"Location": {"type": "object", "properties": {"city": {"type": "number"}}}}
/// ```
///
/// This will return `ToolChoiceError::ConflictingDefinition("Location")`.
fn merge_defs(
target: &mut BTreeMap<String, Value>,
defs: Option<BTreeMap<String, Value>>,
) -> Result<(), ToolChoiceError> {
let Some(defs) = defs else {
return Ok(());
};
for (name, schema) in defs {
if let Some(existing) = target.get(&name) {
if existing != &schema {
return Err(ToolChoiceError::ConflictingDefinition(name));
}
} else {
target.insert(name, schema);
}
}
Ok(())
}
#[cfg(test)]
mod tests {
    use super::*;
    use dynamo_async_openai::types::{
        ChatCompletionNamedToolChoice, ChatCompletionToolChoiceOption, ChatCompletionToolType,
        FunctionName,
    };

    /// Builds a `tool_choice` value that names the function `name`.
    fn named_choice(name: &str) -> ChatCompletionToolChoiceOption {
        ChatCompletionToolChoiceOption::Named(ChatCompletionNamedToolChoice {
            r#type: ChatCompletionToolType::Function,
            function: FunctionName {
                name: name.to_string(),
            },
        })
    }

    /// Two function tools (`add_numbers`, `get_weather`) shared by the tests.
    fn sample_tools() -> Vec<ChatCompletionTool> {
        let add_numbers = ChatCompletionTool {
            r#type: ChatCompletionToolType::Function,
            function: FunctionObject {
                name: "add_numbers".to_string(),
                description: Some("Add two integers".to_string()),
                parameters: Some(json!({
                    "type": "object",
                    "properties": {
                        "a": {"type": "integer"},
                        "b": {"type": "integer"},
                    },
                    "required": ["a", "b"],
                })),
                strict: None,
            },
        };
        let get_weather = ChatCompletionTool {
            r#type: ChatCompletionToolType::Function,
            function: FunctionObject {
                name: "get_weather".to_string(),
                description: Some("Get weather".to_string()),
                parameters: Some(json!({
                    "type": "object",
                    "properties": {
                        "location": {"type": "string"},
                        "unit": {"type": "string", "enum": ["celsius", "fahrenheit"]},
                    },
                    "required": ["location", "unit"],
                })),
                strict: None,
            },
        };
        vec![add_numbers, get_weather]
    }

    /// A named choice yields exactly the parameters schema of the named tool.
    #[test]
    fn named_choice_returns_parameters() {
        let tools = sample_tools();
        let choice = named_choice("get_weather");

        let resolved = get_json_schema_from_tools(Some(&choice), Some(&tools)).expect("schema");

        assert_eq!(
            resolved.unwrap(),
            json!({
                "type": "object",
                "properties": {
                    "location": {"type": "string"},
                    "unit": {"type": "string", "enum": ["celsius", "fahrenheit"]},
                },
                "required": ["location", "unit"],
            })
        );
    }

    /// `required` yields a non-empty array schema whose items are an `anyOf`
    /// with one alternative per tool.
    #[test]
    fn required_choice_builds_any_of_schema() {
        let tools = sample_tools();
        let required = ChatCompletionToolChoiceOption::Required;

        let schema = get_json_schema_from_tools(Some(&required), Some(&tools))
            .expect("schema")
            .expect("required schema");

        assert_eq!(schema["type"], "array");
        assert_eq!(schema["minItems"], 1);

        let alternatives = &schema["items"]["anyOf"];
        assert!(alternatives.is_array());
        let alternatives = alternatives.as_array().unwrap();
        assert_eq!(alternatives.len(), 2);
        assert_eq!(
            alternatives[0]["properties"]["name"],
            json!({"type": "string", "enum": ["add_numbers"]})
        );
    }

    /// Naming a tool that is not in the list is reported as `ToolNotFound`.
    #[test]
    fn missing_tool_errors() {
        let tools = sample_tools();
        let choice = named_choice("unknown");

        let err = get_json_schema_from_tools(Some(&choice), Some(&tools)).unwrap_err();

        assert_eq!(err, ToolChoiceError::ToolNotFound("unknown".to_string()));
    }

    /// Two tools declaring the same `$defs` name with different schemas are
    /// rejected when the required schema is built.
    #[test]
    fn conflicting_defs_errors() {
        // Same tool shape each time; only the type of the `shared` definition varies.
        let tool_defining_shared = |shared_type: &str| ChatCompletionTool {
            r#type: ChatCompletionToolType::Function,
            function: FunctionObject {
                name: "foo".to_string(),
                description: None,
                parameters: Some(json!({
                    "type": "object",
                    "$defs": {
                        "shared": {"type": shared_type}
                    }
                })),
                strict: None,
            },
        };
        let tools = vec![
            tool_defining_shared("string"),
            tool_defining_shared("number"),
        ];

        let err = build_required_schema(&tools).unwrap_err();

        assert_eq!(
            err,
            ToolChoiceError::ConflictingDefinition("shared".to_string())
        );
    }
}
...@@ -92,7 +92,7 @@ fn test_chat_completions_guided_decoding_from_common() { ...@@ -92,7 +92,7 @@ fn test_chat_completions_guided_decoding_from_common() {
); );
assert_eq!( assert_eq!(
request.get_guided_json(), request.get_guided_json(),
Some(&serde_json::json!({"key": "value"})) Some(serde_json::json!({"key": "value"}))
); );
// Test guided_regex can be specified at root level // Test guided_regex can be specified at root level
......
...@@ -484,7 +484,8 @@ mod tests { ...@@ -484,7 +484,8 @@ mod tests {
// Step 2: Apply tool calling jail transformation // Step 2: Apply tool calling jail transformation
let tool_parsed_stream = OpenAIPreprocessor::apply_tool_calling_jail( let tool_parsed_stream = OpenAIPreprocessor::apply_tool_calling_jail(
"nemotron_deci".to_string(), Some("nemotron_deci".to_string()),
None, // No tool_choice in this test
reasoning_parsed_stream, reasoning_parsed_stream,
); );
...@@ -596,7 +597,8 @@ mod tests { ...@@ -596,7 +597,8 @@ mod tests {
let reasoning_parsed_stream = stream::iter(debug_chunks); let reasoning_parsed_stream = stream::iter(debug_chunks);
let tool_parsed_stream = OpenAIPreprocessor::apply_tool_calling_jail( let tool_parsed_stream = OpenAIPreprocessor::apply_tool_calling_jail(
"harmony".to_string(), Some("harmony".to_string()),
None, // No tool_choice in this test
reasoning_parsed_stream, reasoning_parsed_stream,
); );
......
...@@ -158,7 +158,8 @@ async fn parse_response_stream( ...@@ -158,7 +158,8 @@ async fn parse_response_stream(
> = if tool_parse_enable { > = if tool_parse_enable {
if let Some(tool_parser) = tool_parser_str { if let Some(tool_parser) = tool_parser_str {
Box::pin(OpenAIPreprocessor::apply_tool_calling_jail( Box::pin(OpenAIPreprocessor::apply_tool_calling_jail(
tool_parser, Some(tool_parser),
None, // No tool_choice in this test
stream, stream,
)) ))
} else { } else {
......
// SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
// SPDX-License-Identifier: Apache-2.0
use dynamo_async_openai::types::{
ChatCompletionNamedToolChoice, ChatCompletionRequestMessage, ChatCompletionRequestUserMessage,
ChatCompletionRequestUserMessageContent, ChatCompletionToolChoiceOption,
ChatCompletionToolType, CreateChatCompletionRequest, FunctionName,
};
use dynamo_llm::protocols::common;
use dynamo_llm::protocols::common::llm_backend::BackendOutput;
use dynamo_llm::protocols::openai::DeltaGeneratorExt;
use dynamo_llm::protocols::openai::chat_completions::NvCreateChatCompletionRequest;
/// Builds a minimal non-streaming chat completion request with a single user
/// message; every field not set explicitly is left at its default.
fn create_test_request() -> NvCreateChatCompletionRequest {
    let user_message = ChatCompletionRequestMessage::User(ChatCompletionRequestUserMessage {
        content: ChatCompletionRequestUserMessageContent::Text("test".to_string()),
        name: None,
    });

    let inner = CreateChatCompletionRequest {
        model: "test-model".to_string(),
        messages: vec![user_message],
        stream: Some(false),
        stream_options: None,
        ..Default::default()
    };

    NvCreateChatCompletionRequest {
        inner,
        common: Default::default(),
        nvext: None,
        chat_template_args: None,
        unsupported_fields: Default::default(),
    }
}
/// Feeds a single response chunk through a `JailedStream` configured from
/// `tool_choice` and returns the first chunk the jail emits.
///
/// `Named` and `Required` choices configure the builder accordingly; any other
/// value leaves the builder untouched.
///
/// # Panics
///
/// Panics if the jailed stream yields nothing, or if its first item has no data.
async fn apply_jail_transformation(
    raw_response: dynamo_llm::protocols::openai::chat_completions::NvCreateChatCompletionStreamResponse,
    tool_choice: Option<ChatCompletionToolChoiceOption>,
) -> dynamo_llm::protocols::openai::chat_completions::NvCreateChatCompletionStreamResponse {
    use dynamo_llm::protocols::openai::chat_completions::jail::JailedStream;
    use dynamo_runtime::protocols::annotated::Annotated;
    use futures::{StreamExt, stream};

    let builder = JailedStream::builder();
    let builder = match tool_choice {
        Some(ChatCompletionToolChoiceOption::Named(named)) => {
            builder.tool_choice_named(named.function.name)
        }
        Some(ChatCompletionToolChoiceOption::Required) => builder.tool_choice_required(),
        _ => builder,
    };
    let jail = builder.build();

    let wrapped = Annotated {
        data: Some(raw_response),
        id: None,
        event: None,
        comment: None,
    };
    let jailed = jail.apply_with_finish_reason(stream::iter(vec![wrapped]));
    tokio::pin!(jailed);

    let first = jailed.next().await.unwrap();
    first.data.unwrap()
}
/// Feeds a sequence of response chunks through a `JailedStream` configured
/// from `tool_choice` and gathers every emitted chunk that carries data.
///
/// `Named` and `Required` choices configure the builder accordingly; any other
/// value leaves the builder untouched. Emitted items without data are skipped.
async fn apply_jail_transformation_streaming(
    raw_responses: Vec<
        dynamo_llm::protocols::openai::chat_completions::NvCreateChatCompletionStreamResponse,
    >,
    tool_choice: Option<ChatCompletionToolChoiceOption>,
) -> Vec<dynamo_llm::protocols::openai::chat_completions::NvCreateChatCompletionStreamResponse> {
    use dynamo_llm::protocols::openai::chat_completions::jail::JailedStream;
    use dynamo_runtime::protocols::annotated::Annotated;
    use futures::{StreamExt, stream};

    let builder = JailedStream::builder();
    let builder = match tool_choice {
        Some(ChatCompletionToolChoiceOption::Named(named)) => {
            builder.tool_choice_named(named.function.name)
        }
        Some(ChatCompletionToolChoiceOption::Required) => builder.tool_choice_required(),
        _ => builder,
    };
    let jail = builder.build();

    let annotated_chunks = raw_responses.into_iter().map(|chunk| Annotated {
        data: Some(chunk),
        id: None,
        event: None,
        comment: None,
    });
    let jailed = jail.apply_with_finish_reason(stream::iter(annotated_chunks));
    tokio::pin!(jailed);

    let mut emitted = Vec::new();
    while let Some(item) = jailed.next().await {
        if let Some(chunk) = item.data {
            emitted.push(chunk);
        }
    }
    emitted
}
/// Builds a final backend output for choice index 0 that carries only `text`
/// and finishes with `FinishReason::Stop`.
fn build_backend_output(text: &str) -> BackendOutput {
    let text = Some(String::from(text));
    let finish_reason = Some(common::FinishReason::Stop);

    BackendOutput {
        // No token-level data accompanies the text.
        token_ids: Vec::new(),
        tokens: Vec::new(),
        text,
        cum_log_probs: None,
        log_probs: None,
        top_logprobs: None,
        finish_reason,
        index: Some(0),
        completion_usage: None,
        disaggregated_params: None,
    }
}
/// With a named tool choice, a JSON object produced by the backend comes out
/// of the jail as a single tool call for that tool, with the JSON as its
/// arguments and the `Stop` finish reason kept.
#[tokio::test]
async fn test_named_tool_choice_parses_json() {
    use dynamo_async_openai::types::FinishReason;

    let tool_choice = Some(ChatCompletionToolChoiceOption::Named(
        ChatCompletionNamedToolChoice {
            r#type: ChatCompletionToolType::Function,
            function: FunctionName {
                name: "get_weather".to_string(),
            },
        },
    ));

    let mut request = create_test_request();
    request.inner.tool_choice = tool_choice.clone();
    let mut generator = request.response_generator("req-1".to_string());
    let raw_response = generator
        .choice_from_postprocessor(build_backend_output(r#"{"location":"Paris"}"#))
        .expect("choice generation");

    let response = apply_jail_transformation(raw_response, tool_choice).await;

    let choice = &response.choices[0];
    assert_eq!(choice.finish_reason, Some(FinishReason::Stop));

    // Content must be absent or empty once the text became a tool call.
    assert!(matches!(choice.delta.content.as_deref(), None | Some("")));

    let tool_calls = choice.delta.tool_calls.as_ref().unwrap();
    assert_eq!(tool_calls.len(), 1);

    let call = &tool_calls[0];
    assert_eq!(call.index, 0);
    assert!(call.id.as_ref().unwrap().starts_with("call-"));
    assert_eq!(call.r#type, Some(ChatCompletionToolType::Function));

    let function = call.function.as_ref().unwrap();
    assert_eq!(function.name.as_deref(), Some("get_weather"));
    assert_eq!(
        function.arguments.as_deref(),
        Some(r#"{"location":"Paris"}"#)
    );
}
/// With `tool_choice=required`, a JSON array of `{name, parameters}` objects
/// comes out of the jail as one tool call per element, in order, and the
/// finish reason becomes `ToolCalls`.
#[tokio::test]
async fn test_required_tool_choice_parses_json_array() {
    use dynamo_async_openai::types::FinishReason;

    let tool_choice = Some(ChatCompletionToolChoiceOption::Required);

    let mut request = create_test_request();
    request.inner.tool_choice = tool_choice.clone();
    let mut generator = request.response_generator("req-2".to_string());
    let backend_output = build_backend_output(
        r#"[{"name":"search","parameters":{"query":"rust"}},
{"name":"summarize","parameters":{"topic":"memory"}}]"#,
    );
    let raw_response = generator
        .choice_from_postprocessor(backend_output)
        .expect("choice generation");

    let response = apply_jail_transformation(raw_response, tool_choice).await;

    let choice = &response.choices[0];
    assert_eq!(choice.finish_reason, Some(FinishReason::ToolCalls));
    assert!(matches!(choice.delta.content.as_deref(), None | Some("")));

    // (index, function name, serialized arguments) expected for each call.
    let expected_calls = [
        (0, "search", r#"{"query":"rust"}"#),
        (1, "summarize", r#"{"topic":"memory"}"#),
    ];
    let tool_calls = choice.delta.tool_calls.as_ref().unwrap();
    assert_eq!(tool_calls.len(), 2);

    for (call, (index, name, arguments)) in tool_calls.iter().zip(expected_calls) {
        assert_eq!(call.index, index);
        assert!(call.id.as_ref().unwrap().starts_with("call-"));
        assert_eq!(call.r#type, Some(ChatCompletionToolType::Function));

        let function = call.function.as_ref().unwrap();
        assert_eq!(function.name.as_deref(), Some(name));
        assert_eq!(function.arguments.as_deref(), Some(arguments));
    }
}
/// Backend text that is not JSON is handed back unchanged as content, with no
/// tool calls attached.
#[tokio::test]
async fn test_tool_choice_parse_failure_returns_as_content() {
    let tool_choice = Some(ChatCompletionToolChoiceOption::Required);

    let mut request = create_test_request();
    request.inner.tool_choice = tool_choice.clone();
    let mut generator = request.response_generator("req-3".to_string());
    let raw_response = generator
        .choice_from_postprocessor(build_backend_output("not-json"))
        .expect("choice generation");

    let response = apply_jail_transformation(raw_response, tool_choice).await;

    // Jail stream behavior: when parsing fails, the accumulated content is
    // returned as-is, which matches marker-based function-calling behavior.
    let delta = &response.choices[0].delta;
    assert_eq!(delta.content.as_deref(), Some("not-json"));
    assert!(delta.tool_calls.is_none());
}
/// A named tool call whose JSON arrives split over three chunks is buffered by
/// the jail and emitted as exactly one response holding the reassembled
/// arguments.
#[tokio::test]
async fn test_streaming_named_tool_buffers_until_finish() {
    use dynamo_async_openai::types::FinishReason;

    let tool_choice = Some(ChatCompletionToolChoiceOption::Named(
        ChatCompletionNamedToolChoice {
            r#type: ChatCompletionToolType::Function,
            function: FunctionName {
                name: "get_weather".to_string(),
            },
        },
    ));

    let mut request = create_test_request();
    request.inner.tool_choice = tool_choice.clone();
    let mut generator = request.response_generator("req-stream-1".to_string());

    let chunks = [r#"{"location":""#, r#"Paris","unit":""#, r#"celsius"}"#];
    let final_index = chunks.len() - 1;
    let raw_responses: Vec<_> = chunks
        .iter()
        .enumerate()
        .map(|(position, piece)| {
            // Only the last chunk carries a finish reason.
            let finish_reason = (position == final_index).then_some(common::FinishReason::Stop);
            let backend_output = BackendOutput {
                token_ids: vec![],
                tokens: vec![],
                text: Some(piece.to_string()),
                cum_log_probs: None,
                log_probs: None,
                top_logprobs: None,
                finish_reason,
                index: Some(0),
                completion_usage: None,
                disaggregated_params: None,
            };
            generator
                .choice_from_postprocessor(backend_output)
                .expect("streaming chunk")
        })
        .collect();

    let all_responses = apply_jail_transformation_streaming(raw_responses, tool_choice).await;

    // Jail stream buffers content until valid JSON, then emits once.
    assert_eq!(all_responses.len(), 1);

    let choice = &all_responses[0].choices[0];
    assert_eq!(choice.finish_reason, Some(FinishReason::Stop));

    let tool_calls = choice.delta.tool_calls.as_ref().unwrap();
    assert_eq!(tool_calls.len(), 1);

    let function = tool_calls[0].function.as_ref().unwrap();
    assert_eq!(function.name.as_deref(), Some("get_weather"));
    assert_eq!(
        function.arguments.as_deref(),
        Some(r#"{"location":"Paris","unit":"celsius"}"#)
    );
}
/// A `required` tool-call array split over two chunks is buffered by the jail
/// and emitted as one response holding both tool calls, finishing with
/// `ToolCalls`.
#[tokio::test]
async fn test_streaming_required_tool_parallel() {
    use dynamo_async_openai::types::FinishReason;

    let tool_choice = Some(ChatCompletionToolChoiceOption::Required);

    let mut request = create_test_request();
    request.inner.tool_choice = tool_choice.clone();
    let mut generator = request.response_generator("req-stream-2".to_string());

    let chunks = [
        r#"[{"name":"search","parameters":{"query":"rust"}},"#,
        r#"{"name":"summarize","parameters":{"topic":"memory"}}]"#,
    ];
    let final_index = chunks.len() - 1;
    let raw_responses: Vec<_> = chunks
        .iter()
        .enumerate()
        .map(|(position, piece)| {
            // Only the last chunk carries a finish reason.
            let finish_reason = (position == final_index).then_some(common::FinishReason::Stop);
            let backend_output = BackendOutput {
                token_ids: vec![],
                tokens: vec![],
                text: Some(piece.to_string()),
                cum_log_probs: None,
                log_probs: None,
                top_logprobs: None,
                finish_reason,
                index: Some(0),
                completion_usage: None,
                disaggregated_params: None,
            };
            generator
                .choice_from_postprocessor(backend_output)
                .expect("streaming chunk")
        })
        .collect();

    let all_responses = apply_jail_transformation_streaming(raw_responses, tool_choice).await;

    // Jail stream buffers until the JSON array is complete.
    assert_eq!(all_responses.len(), 1);

    let choice = &all_responses[0].choices[0];
    assert_eq!(choice.finish_reason, Some(FinishReason::ToolCalls));

    let tool_calls = choice.delta.tool_calls.as_ref().unwrap();
    assert_eq!(tool_calls.len(), 2);

    // (function name, serialized arguments) expected for each call, in order.
    let expected_calls = [
        ("search", r#"{"query":"rust"}"#),
        ("summarize", r#"{"topic":"memory"}"#),
    ];
    for (call, (name, arguments)) in tool_calls.iter().zip(expected_calls) {
        let function = call.function.as_ref().unwrap();
        assert_eq!(function.name.as_deref(), Some(name));
        assert_eq!(function.arguments.as_deref(), Some(arguments));
    }
}
/// Without a tool choice, the response generator passes backend text through
/// as ordinary content and attaches no tool calls.
#[test]
fn test_no_tool_choice_outputs_normal_text() {
    let request = create_test_request();
    let mut generator = request.response_generator("req-stream-4".to_string());

    // An intermediate chunk: plain text and no finish reason.
    let plain_text_chunk = BackendOutput {
        token_ids: vec![],
        tokens: vec![],
        text: Some("Hello world".to_string()),
        cum_log_probs: None,
        log_probs: None,
        top_logprobs: None,
        finish_reason: None,
        index: Some(0),
        completion_usage: None,
        disaggregated_params: None,
    };

    let response = generator
        .choice_from_postprocessor(plain_text_chunk)
        .expect("normal text");

    let delta = &response.choices[0].delta;
    assert_eq!(delta.content.as_deref(), Some("Hello world"));
    assert!(delta.tool_calls.is_none());
}
// SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
// SPDX-License-Identifier: Apache-2.0
//! Tests for tool_choice finish_reason handling.
use dynamo_async_openai::types::{
ChatCompletionNamedToolChoice, ChatCompletionRequestMessage, ChatCompletionRequestUserMessage,
ChatCompletionRequestUserMessageContent, ChatCompletionToolChoiceOption,
ChatCompletionToolType, CreateChatCompletionRequest, FunctionName,
};
use dynamo_llm::protocols::common;
use dynamo_llm::protocols::common::llm_backend::BackendOutput;
use dynamo_llm::protocols::openai::DeltaGeneratorExt;
use dynamo_llm::protocols::openai::chat_completions::NvCreateChatCompletionRequest;
/// Builds a minimal non-streaming chat completion request with a single user
/// message; every field not set explicitly is left at its default.
fn create_test_request() -> NvCreateChatCompletionRequest {
    let user_message = ChatCompletionRequestMessage::User(ChatCompletionRequestUserMessage {
        content: ChatCompletionRequestUserMessageContent::Text("test".to_string()),
        name: None,
    });

    let inner = CreateChatCompletionRequest {
        model: "test-model".to_string(),
        messages: vec![user_message],
        stream: Some(false),
        stream_options: None,
        ..Default::default()
    };

    NvCreateChatCompletionRequest {
        inner,
        common: Default::default(),
        nvext: None,
        chat_template_args: None,
        unsupported_fields: Default::default(),
    }
}
/// Builds a backend output for choice index 0 that carries only `text` and
/// ends with the given `finish` reason.
fn build_backend_output_with_finish(text: &str, finish: common::FinishReason) -> BackendOutput {
    let text = Some(String::from(text));
    let finish_reason = Some(finish);

    BackendOutput {
        // No token-level data accompanies the text.
        token_ids: Vec::new(),
        tokens: Vec::new(),
        text,
        cum_log_probs: None,
        log_probs: None,
        top_logprobs: None,
        finish_reason,
        index: Some(0),
        completion_usage: None,
        disaggregated_params: None,
    }
}
/// Feeds a single response chunk through a `JailedStream` configured from
/// `tool_choice` and returns the first chunk the jail emits.
///
/// `Named` and `Required` choices configure the builder accordingly; any other
/// value leaves the builder untouched.
///
/// # Panics
///
/// Panics if the jailed stream yields nothing, or if its first item has no data.
async fn apply_jail_transformation(
    raw_response: dynamo_llm::protocols::openai::chat_completions::NvCreateChatCompletionStreamResponse,
    tool_choice: Option<ChatCompletionToolChoiceOption>,
) -> dynamo_llm::protocols::openai::chat_completions::NvCreateChatCompletionStreamResponse {
    use dynamo_llm::protocols::openai::chat_completions::jail::JailedStream;
    use dynamo_runtime::protocols::annotated::Annotated;
    use futures::{StreamExt, stream};

    let builder = JailedStream::builder();
    let builder = match tool_choice {
        Some(ChatCompletionToolChoiceOption::Named(named)) => {
            builder.tool_choice_named(named.function.name)
        }
        Some(ChatCompletionToolChoiceOption::Required) => builder.tool_choice_required(),
        _ => builder,
    };
    let jail = builder.build();

    let wrapped = Annotated {
        data: Some(raw_response),
        id: None,
        event: None,
        comment: None,
    };
    let jailed = jail.apply_with_finish_reason(stream::iter(vec![wrapped]));
    tokio::pin!(jailed);

    let first = jailed.next().await.unwrap();
    first.data.unwrap()
}
/// A named tool choice whose output was cut off keeps the `Length` finish
/// reason after passing through the jail.
#[tokio::test]
async fn test_named_tool_choice_preserves_length_finish_reason() {
    use dynamo_async_openai::types::FinishReason;

    let tool_choice = Some(ChatCompletionToolChoiceOption::Named(
        ChatCompletionNamedToolChoice {
            r#type: ChatCompletionToolType::Function,
            function: FunctionName {
                name: "get_weather".to_string(),
            },
        },
    ));

    let mut request = create_test_request();
    request.inner.tool_choice = tool_choice.clone();
    let mut generator = request.response_generator("req-length-1".to_string());

    // The JSON is deliberately incomplete, as if cut off by the length limit.
    let truncated_output =
        build_backend_output_with_finish(r#"{"location":"Par"#, common::FinishReason::Length);
    let raw_response = generator
        .choice_from_postprocessor(truncated_output)
        .expect("choice generation");

    let response = apply_jail_transformation(raw_response, tool_choice).await;

    // Critical: Length must survive and must NOT be replaced with Stop.
    assert_eq!(
        response.choices[0].finish_reason,
        Some(FinishReason::Length),
        "Length finish reason must be preserved for tool_choice=named"
    );
}
/// With `tool_choice=required`, the response generator reports `Length` for a
/// truncated output. The jail is not applied in this test.
#[test]
fn test_required_tool_choice_preserves_length_finish_reason() {
    use dynamo_async_openai::types::FinishReason;

    let mut request = create_test_request();
    request.inner.tool_choice = Some(ChatCompletionToolChoiceOption::Required);
    let mut generator = request.response_generator("req-length-2".to_string());

    // The JSON is deliberately incomplete, as if cut off by the length limit.
    let truncated_output = build_backend_output_with_finish(
        r#"[{"name":"search","parameters":{"query":"incomplete"#,
        common::FinishReason::Length,
    );
    let response = generator
        .choice_from_postprocessor(truncated_output)
        .expect("choice generation");

    // Critical: Length must survive and must NOT be replaced with ToolCalls.
    assert_eq!(
        response.choices[0].finish_reason,
        Some(FinishReason::Length),
        "Length finish reason must be preserved for tool_choice=required"
    );
}
/// With a named tool choice, the response generator reports `ContentFilter`
/// unchanged. The jail is not applied in this test.
#[test]
fn test_named_tool_choice_preserves_content_filter() {
    use dynamo_async_openai::types::FinishReason;

    let named = ChatCompletionNamedToolChoice {
        r#type: ChatCompletionToolType::Function,
        function: FunctionName {
            name: "search".to_string(),
        },
    };

    let mut request = create_test_request();
    request.inner.tool_choice = Some(ChatCompletionToolChoiceOption::Named(named));
    let mut generator = request.response_generator("req-filter-1".to_string());

    let filtered_output = build_backend_output_with_finish(
        r#"{"query":"filtered content"#,
        common::FinishReason::ContentFilter,
    );
    let response = generator
        .choice_from_postprocessor(filtered_output)
        .expect("choice generation");

    // Critical: ContentFilter must be preserved.
    assert_eq!(
        response.choices[0].finish_reason,
        Some(FinishReason::ContentFilter),
        "ContentFilter finish reason must be preserved for tool_choice=named"
    );
}
/// With `tool_choice=required`, the response generator reports
/// `ContentFilter` unchanged. The jail is not applied in this test.
#[test]
fn test_required_tool_choice_preserves_content_filter() {
    use dynamo_async_openai::types::FinishReason;

    let mut request = create_test_request();
    request.inner.tool_choice = Some(ChatCompletionToolChoiceOption::Required);
    let mut generator = request.response_generator("req-filter-2".to_string());

    let filtered_output = build_backend_output_with_finish(
        r#"[{"name":"harmful_action"#,
        common::FinishReason::ContentFilter,
    );
    let response = generator
        .choice_from_postprocessor(filtered_output)
        .expect("choice generation");

    // Critical: ContentFilter must be preserved.
    assert_eq!(
        response.choices[0].finish_reason,
        Some(FinishReason::ContentFilter),
        "ContentFilter finish reason must be preserved for tool_choice=required"
    );
}
/// With a named tool choice and a complete JSON output, the response
/// generator keeps `Stop` as the finish reason. The jail is not applied in
/// this test.
#[test]
fn test_named_tool_choice_normal_stop_becomes_stop() {
    use dynamo_async_openai::types::FinishReason;

    let named = ChatCompletionNamedToolChoice {
        r#type: ChatCompletionToolType::Function,
        function: FunctionName {
            name: "get_weather".to_string(),
        },
    };

    let mut request = create_test_request();
    request.inner.tool_choice = Some(ChatCompletionToolChoiceOption::Named(named));
    let mut generator = request.response_generator("req-stop-1".to_string());

    let complete_output = build_backend_output_with_finish(
        r#"{"location":"Paris","unit":"celsius"}"#,
        common::FinishReason::Stop,
    );
    let response = generator
        .choice_from_postprocessor(complete_output)
        .expect("choice generation");

    // Normal completion: Stop stays Stop for a named tool choice.
    assert_eq!(
        response.choices[0].finish_reason,
        Some(FinishReason::Stop),
    );
}
/// With `tool_choice=required` and a complete tool-call array, the jail turns
/// the backend's `Stop` into `ToolCalls`.
#[tokio::test]
async fn test_required_tool_choice_normal_stop_becomes_tool_calls() {
    use dynamo_async_openai::types::FinishReason;

    let tool_choice = Some(ChatCompletionToolChoiceOption::Required);

    let mut request = create_test_request();
    request.inner.tool_choice = tool_choice.clone();
    let mut generator = request.response_generator("req-stop-2".to_string());

    let complete_output = build_backend_output_with_finish(
        r#"[{"name":"search","parameters":{"query":"rust"}}]"#,
        common::FinishReason::Stop,
    );
    let raw_response = generator
        .choice_from_postprocessor(complete_output)
        .expect("choice generation");

    let response = apply_jail_transformation(raw_response, tool_choice).await;

    // Normal completion: Stop becomes ToolCalls for a required tool choice.
    assert_eq!(
        response.choices[0].finish_reason,
        Some(FinishReason::ToolCalls),
    );
}
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment