Unverified Commit 5585f803 authored by Vladislav Nosivskoy's avatar Vladislav Nosivskoy Committed by GitHub
Browse files

feat: add tool_choice support (#4722)


Signed-off-by: default avatarVladislav Nosivskoy <vladnosiv@gmail.com>
parent 94d145a9
...@@ -752,25 +752,17 @@ impl OpenAIPreprocessor { ...@@ -752,25 +752,17 @@ impl OpenAIPreprocessor {
has_tools: bool, has_tools: bool,
) -> std::result::Result<bool, Error> { ) -> std::result::Result<bool, Error> {
match (tool_call_parser, tool_choice, has_tools) { match (tool_call_parser, tool_choice, has_tools) {
// No parser but tools requested - error cases // tool_choice=required/named work without parser (use Immediate jail mode)
(None, Some(ChatCompletionToolChoiceOption::Required), true) => { (None, Some(ChatCompletionToolChoiceOption::Required), true) => Ok(true),
tracing::warn!( (None, Some(ChatCompletionToolChoiceOption::Named(_)), true) => Ok(true),
"Tool choice 'required' specified but no tool parser configured; proceeding without jailing"
); // tool_choice=auto requires a parser
Ok(false)
}
(None, Some(ChatCompletionToolChoiceOption::Auto), true) => { (None, Some(ChatCompletionToolChoiceOption::Auto), true) => {
tracing::warn!( tracing::warn!(
"Tool choice 'auto' specified but no tool parser configured; proceeding without jailing" "Tool choice 'auto' specified but no tool parser configured; proceeding without jailing"
); );
Ok(false) Ok(false)
} }
(None, Some(ChatCompletionToolChoiceOption::Named(_)), _) => {
tracing::warn!(
"Named tool choice specified but no tool parser configured; proceeding without jailing"
);
Ok(false)
}
// Parser exists and tools might be called // Parser exists and tools might be called
(Some(_), Some(ChatCompletionToolChoiceOption::None), _) => { (Some(_), Some(ChatCompletionToolChoiceOption::None), _) => {
...@@ -786,15 +778,38 @@ impl OpenAIPreprocessor { ...@@ -786,15 +778,38 @@ impl OpenAIPreprocessor {
/// Apply tool calling jail to the stream if needed /// Apply tool calling jail to the stream if needed
pub fn apply_tool_calling_jail<S>( pub fn apply_tool_calling_jail<S>(
tool_call_parser: String, tool_call_parser: Option<String>,
tool_choice: Option<dynamo_async_openai::types::ChatCompletionToolChoiceOption>,
stream: S, stream: S,
) -> impl Stream<Item = Annotated<NvCreateChatCompletionStreamResponse>> + Send ) -> impl Stream<Item = Annotated<NvCreateChatCompletionStreamResponse>> + Send
where where
S: Stream<Item = Annotated<NvCreateChatCompletionStreamResponse>> + Send + 'static, S: Stream<Item = Annotated<NvCreateChatCompletionStreamResponse>> + Send + 'static,
{ {
let jail = JailedStream::builder() use dynamo_async_openai::types::ChatCompletionToolChoiceOption;
.tool_call_parser(tool_call_parser)
.build(); let mut builder = JailedStream::builder();
// Configure jail based on tool_choice
match tool_choice {
Some(ChatCompletionToolChoiceOption::Named(named)) => {
// Immediate jail mode for named tool choice
builder = builder.tool_choice_named(named.function.name.clone());
}
Some(ChatCompletionToolChoiceOption::Required) => {
// Immediate jail mode for required tool choice
builder = builder.tool_choice_required();
}
Some(ChatCompletionToolChoiceOption::Auto)
| Some(ChatCompletionToolChoiceOption::None)
| None => {
// Traditional marker-based jail for auto/none/unspecified
if let Some(parser) = tool_call_parser {
builder = builder.tool_call_parser(parser);
}
}
}
let jail = builder.build();
jail.apply_with_finish_reason(stream) jail.apply_with_finish_reason(stream)
} }
...@@ -957,11 +972,11 @@ impl ...@@ -957,11 +972,11 @@ impl
// Apply jail conditionally // Apply jail conditionally
let transformed_stream: Pin<Box<dyn Stream<Item = _> + Send>> = if should_jail { let transformed_stream: Pin<Box<dyn Stream<Item = _> + Send>> = if should_jail {
if let Some(parser) = self.tool_call_parser.clone() { Box::pin(Self::apply_tool_calling_jail(
Box::pin(Self::apply_tool_calling_jail(parser, stream)) self.tool_call_parser.clone(),
} else { request.inner.tool_choice.clone(),
Box::pin(stream) // Should not happen due to should_jail check stream,
} ))
} else { } else {
Box::pin(stream) Box::pin(stream)
}; };
......
...@@ -17,6 +17,7 @@ pub mod embeddings; ...@@ -17,6 +17,7 @@ pub mod embeddings;
pub mod models; pub mod models;
pub mod nvext; pub mod nvext;
pub mod responses; pub mod responses;
pub mod tools;
pub mod validate; pub mod validate;
use validate::{ use validate::{
...@@ -131,7 +132,7 @@ impl<T: OpenAISamplingOptionsProvider + CommonExtProvider> SamplingOptionsProvid ...@@ -131,7 +132,7 @@ impl<T: OpenAISamplingOptionsProvider + CommonExtProvider> SamplingOptionsProvid
let guided_whitespace_pattern = self.get_guided_whitespace_pattern(); let guided_whitespace_pattern = self.get_guided_whitespace_pattern();
let guided_decoding = match common::GuidedDecodingOptions::from_optional( let guided_decoding = match common::GuidedDecodingOptions::from_optional(
guided_json.cloned(), guided_json,
guided_regex, guided_regex,
guided_choice, guided_choice,
guided_grammar, guided_grammar,
......
...@@ -12,7 +12,7 @@ use super::{ ...@@ -12,7 +12,7 @@ use super::{
common_ext::{CommonExt, CommonExtProvider}, common_ext::{CommonExt, CommonExtProvider},
nvext::NvExt, nvext::NvExt,
nvext::NvExtProvider, nvext::NvExtProvider,
validate, tools, validate,
}; };
pub mod aggregator; pub mod aggregator;
...@@ -159,8 +159,24 @@ impl CommonExtProvider for NvCreateChatCompletionRequest { ...@@ -159,8 +159,24 @@ impl CommonExtProvider for NvCreateChatCompletionRequest {
} }
/// Guided Decoding Options /// Guided Decoding Options
fn get_guided_json(&self) -> Option<&serde_json::Value> { fn get_guided_json(&self) -> Option<serde_json::Value> {
self.common.guided_json.as_ref() if let Some(value) = self.common.guided_json.clone() {
return Some(value);
}
let tool_choice = self.inner.tool_choice.as_ref()?;
let tools = self.inner.tools.as_deref()?;
match tools::get_json_schema_from_tools(Some(tool_choice), Some(tools)) {
Ok(schema) => schema,
Err(err) => {
tracing::warn!(
error = %err,
"failed to derive guided_json from tool_choice"
);
None
}
}
} }
fn get_guided_regex(&self) -> Option<String> { fn get_guided_regex(&self) -> Option<String> {
......
...@@ -94,7 +94,7 @@ pub trait CommonExtProvider { ...@@ -94,7 +94,7 @@ pub trait CommonExtProvider {
fn common_ext(&self) -> Option<&CommonExt>; fn common_ext(&self) -> Option<&CommonExt>;
/// Guided Decoding Options /// Guided Decoding Options
fn get_guided_json(&self) -> Option<&serde_json::Value>; fn get_guided_json(&self) -> Option<serde_json::Value>;
fn get_guided_regex(&self) -> Option<String>; fn get_guided_regex(&self) -> Option<String>;
fn get_guided_grammar(&self) -> Option<String>; fn get_guided_grammar(&self) -> Option<String>;
fn get_guided_choice(&self) -> Option<Vec<String>>; fn get_guided_choice(&self) -> Option<Vec<String>>;
......
...@@ -183,8 +183,8 @@ impl CommonExtProvider for NvCreateCompletionRequest { ...@@ -183,8 +183,8 @@ impl CommonExtProvider for NvCreateCompletionRequest {
} }
/// Guided Decoding Options /// Guided Decoding Options
fn get_guided_json(&self) -> Option<&serde_json::Value> { fn get_guided_json(&self) -> Option<serde_json::Value> {
self.common.guided_json.as_ref() self.common.guided_json.clone()
} }
fn get_guided_regex(&self) -> Option<String> { fn get_guided_regex(&self) -> Option<String> {
......
// SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
// SPDX-License-Identifier: Apache-2.0
use std::collections::BTreeMap;
use dynamo_async_openai::types::{
ChatCompletionTool, ChatCompletionToolChoiceOption, FunctionObject,
};
use serde_json::{Value, json};
use thiserror::Error;
/// Errors that can occur when deriving JSON schemas for tool_choice requests.
///
/// Every variant is produced on the `get_json_schema_from_tools` code path
/// (directly, or through `build_required_schema` and its helpers). The
/// `#[error(...)]` strings define each variant's `Display` output.
#[derive(Debug, Error, PartialEq, Eq)]
pub enum ToolChoiceError {
/// `tool_choice` is a named tool or `required`, but `tools` was `None`.
#[error("tool_choice requires a matching `tools` array")]
MissingTools,
/// A named `tool_choice` refers to a function name that is absent from
/// `tools`. The payload is the requested name.
#[error("tool `{0}` was not provided in `tools`")]
ToolNotFound(String),
/// A tool's parameter schema has a `$defs` entry that is not a JSON object.
/// The payload is the offending function's name.
#[error("$defs for tool `{0}` must be an object")]
InvalidDefinitionMap(String),
/// Two tools declare the same `$defs` key with different schemas.
/// The payload is the conflicting key.
#[error("duplicate $defs entry `{0}` has conflicting schemas")]
ConflictingDefinition(String),
/// `tool_choice` is `required`, but the supplied `tools` slice is empty.
#[error("tool_choice `required` needs at least one tool definition")]
EmptyTools,
}
/// Derives the guided-decoding JSON schema implied by a `tool_choice`/`tools` pair.
///
/// # Returns
///
/// * `Ok(None)` when no `tool_choice` is given, or it is `none`/`auto`
///   (no schema is enforced in those modes).
/// * `Ok(Some(schema))` for a named choice (the selected tool's parameter
///   schema) or for `required` (an array-of-tool-calls schema).
///
/// # Errors
///
/// * [`ToolChoiceError::MissingTools`] if a named/`required` choice comes without `tools`.
/// * [`ToolChoiceError::ToolNotFound`] if the named tool is not in `tools`.
/// * [`ToolChoiceError::EmptyTools`] if `required` is used with an empty `tools` slice.
/// * Any error raised while building the `required` schema.
pub fn get_json_schema_from_tools(
    tool_choice: Option<&ChatCompletionToolChoiceOption>,
    tools: Option<&[ChatCompletionTool]>,
) -> Result<Option<Value>, ToolChoiceError> {
    match tool_choice {
        // Nothing to enforce: unspecified, explicitly disabled, or left to the model.
        None
        | Some(ChatCompletionToolChoiceOption::None)
        | Some(ChatCompletionToolChoiceOption::Auto) => Ok(None),

        // A specific tool was requested: its parameter schema is the whole schema.
        Some(ChatCompletionToolChoiceOption::Named(named)) => {
            let available = tools.ok_or(ToolChoiceError::MissingTools)?;
            let wanted = &named.function.name;
            match find_tool(available, wanted) {
                Some(tool) => Ok(Some(clone_parameters(&tool.function))),
                None => Err(ToolChoiceError::ToolNotFound(wanted.clone())),
            }
        }

        // Any tool may be called, but at least one call is mandatory.
        Some(ChatCompletionToolChoiceOption::Required) => match tools {
            None => Err(ToolChoiceError::MissingTools),
            Some([]) => Err(ToolChoiceError::EmptyTools),
            Some(available) => build_required_schema(available).map(Some),
        },
    }
}
/// Returns the first tool in `tools` whose function name equals `name`, if any.
fn find_tool<'a>(tools: &'a [ChatCompletionTool], name: &str) -> Option<&'a ChatCompletionTool> {
    tools
        .iter()
        .find(|candidate| candidate.function.name.as_str() == name)
}
/// Returns an owned copy of the function's parameter schema.
///
/// A function declared without parameters yields an empty object schema
/// (`{"type": "object", "properties": {}}`).
fn clone_parameters(function: &FunctionObject) -> Value {
    match &function.parameters {
        Some(parameters) => parameters.clone(),
        None => json!({"type": "object", "properties": {}}),
    }
}
/// Produces the schema used for `tool_choice=required`: a non-empty array in
/// which every element is a call to one of the given tools.
///
/// # Shape of the result
///
/// ```json
/// {
///   "type": "array",
///   "minItems": 1,
///   "items": {
///     "type": "object",
///     "anyOf": [
///       {
///         "properties": {
///           "name": {"type": "string", "enum": ["<tool name>"]},
///           "parameters": { /* that tool's parameter schema, minus $defs */ }
///         },
///         "required": ["name", "parameters"]
///       }
///       /* ...one alternative per tool, in input order... */
///     ]
///   },
///   "$defs": { /* union of every tool's $defs; omitted when empty */ }
/// }
/// ```
///
/// # Shared definitions
///
/// A tool's parameter schema may carry a top-level `$defs` map. Each such map
/// is stripped from the tool's own schema and folded into one root-level
/// `$defs`. Tools may repeat a definition name only if the schemas are
/// identical.
///
/// # Errors
///
/// * [`ToolChoiceError::InvalidDefinitionMap`] if a tool's `$defs` is not an object.
/// * [`ToolChoiceError::ConflictingDefinition`] if two tools disagree on a definition.
fn build_required_schema(tools: &[ChatCompletionTool]) -> Result<Value, ToolChoiceError> {
    // Union of all `$defs` seen so far, keyed by definition name.
    let mut shared_defs: BTreeMap<String, Value> = BTreeMap::new();
    // One `anyOf` alternative per tool.
    let mut alternatives = Vec::with_capacity(tools.len());

    for tool in tools {
        let function = &tool.function;
        let split = split_defs(function)?;
        merge_defs(&mut shared_defs, split.defs)?;

        alternatives.push(json!({
            "properties": {
                "name": {
                    "type": "string",
                    "enum": [function.name],
                },
                "parameters": split.schema,
            },
            "required": ["name", "parameters"],
        }));
    }

    let mut schema = json!({
        "type": "array",
        "minItems": 1,
        "items": {
            "type": "object",
            "anyOf": alternatives,
        },
    });

    // `schema` is always a JSON object here, so indexing by key inserts the entry.
    if !shared_defs.is_empty() {
        schema["$defs"] = Value::Object(shared_defs.into_iter().collect());
    }

    Ok(schema)
}
/// Holds a tool's parameter schema and its extracted $defs (if any).
///
/// When a tool's parameters reference shared types via `$ref`, those types
/// are defined in a `$defs` section within the schema. We extract them separately
/// to merge into a global definitions map.
///
/// Values of this type are produced by `split_defs` and consumed by
/// `build_required_schema`.
struct ParamsAndDefs {
/// The parameter schema with its top-level `$defs` entry removed (if it had one)
schema: Value,
/// Extracted `$defs` map, or None if the schema had no `$defs` entry
/// (or was not a JSON object at all)
defs: Option<BTreeMap<String, Value>>,
}
/// Separates a function's parameter schema from its top-level `$defs` map.
///
/// The schema is cloned first (or defaulted when the function has no
/// parameters), so the caller's `FunctionObject` is never modified.
///
/// # Example
///
/// Given
/// ```json
/// {
///   "type": "object",
///   "properties": {"location": {"$ref": "#/$defs/Location"}},
///   "$defs": {"Location": {"type": "object", "properties": {"city": {"type": "string"}}}}
/// }
/// ```
/// the result holds
/// - `schema`: the same object without the `$defs` key
/// - `defs`: `Some({"Location": {...}})`
///
/// A schema without `$defs`, or one that is not a JSON object, yields `defs: None`.
///
/// # Errors
///
/// [`ToolChoiceError::InvalidDefinitionMap`] when `$defs` exists but is not an object.
fn split_defs(function: &FunctionObject) -> Result<ParamsAndDefs, ToolChoiceError> {
    let mut schema = clone_parameters(function);

    let defs = schema
        .as_object_mut()
        .and_then(|object| object.remove("$defs"))
        .map(|raw_defs| convert_defs(function, raw_defs))
        .transpose()?;

    Ok(ParamsAndDefs { schema, defs })
}
/// Converts a raw `$defs` JSON value into a name -> schema map.
///
/// # Errors
///
/// [`ToolChoiceError::InvalidDefinitionMap`] (carrying `function.name`) when
/// `defs_value` is anything other than a JSON object.
fn convert_defs(
    function: &FunctionObject,
    defs_value: Value,
) -> Result<BTreeMap<String, Value>, ToolChoiceError> {
    let Value::Object(entries) = defs_value else {
        return Err(ToolChoiceError::InvalidDefinitionMap(function.name.clone()));
    };
    Ok(entries.into_iter().collect())
}
/// Merges definitions from one tool into the global `$defs` accumulator.
///
/// # Conflict Detection
///
/// If two tools define the same type name but with different schemas, we return
/// an error. This ensures consistency across tool definitions.
///
/// # Example
///
/// If `target` contains:
/// ```json
/// {"Location": {"type": "object", "properties": {"city": {"type": "string"}}}}
/// ```
///
/// And we try to merge:
/// ```json
/// {"Location": {"type": "object", "properties": {"city": {"type": "number"}}}}
/// ```
///
/// This will return `ToolChoiceError::ConflictingDefinition("Location")`.
fn merge_defs(
target: &mut BTreeMap<String, Value>,
defs: Option<BTreeMap<String, Value>>,
) -> Result<(), ToolChoiceError> {
let Some(defs) = defs else {
return Ok(());
};
for (name, schema) in defs {
if let Some(existing) = target.get(&name) {
if existing != &schema {
return Err(ToolChoiceError::ConflictingDefinition(name));
}
} else {
target.insert(name, schema);
}
}
Ok(())
}
#[cfg(test)]
mod tests {
    use super::*;
    use dynamo_async_openai::types::{
        ChatCompletionNamedToolChoice, ChatCompletionToolChoiceOption, ChatCompletionToolType,
        FunctionName,
    };

    /// Builds a function-type tool with `strict` left unset.
    fn function_tool(
        name: &str,
        description: Option<&str>,
        parameters: Value,
    ) -> ChatCompletionTool {
        ChatCompletionTool {
            r#type: ChatCompletionToolType::Function,
            function: FunctionObject {
                name: name.to_string(),
                description: description.map(String::from),
                parameters: Some(parameters),
                strict: None,
            },
        }
    }

    /// Builds a named `tool_choice` that targets the function `name`.
    fn named_choice(name: &str) -> ChatCompletionToolChoiceOption {
        ChatCompletionToolChoiceOption::Named(ChatCompletionNamedToolChoice {
            r#type: ChatCompletionToolType::Function,
            function: FunctionName {
                name: name.to_string(),
            },
        })
    }

    /// Two simple tools (`add_numbers`, `get_weather`) without `$defs`.
    fn sample_tools() -> Vec<ChatCompletionTool> {
        vec![
            function_tool(
                "add_numbers",
                Some("Add two integers"),
                json!({
                    "type": "object",
                    "properties": {
                        "a": {"type": "integer"},
                        "b": {"type": "integer"},
                    },
                    "required": ["a", "b"],
                }),
            ),
            function_tool(
                "get_weather",
                Some("Get weather"),
                json!({
                    "type": "object",
                    "properties": {
                        "location": {"type": "string"},
                        "unit": {"type": "string", "enum": ["celsius", "fahrenheit"]},
                    },
                    "required": ["location", "unit"],
                }),
            ),
        ]
    }

    /// A named choice yields exactly the selected tool's parameter schema.
    #[test]
    fn named_choice_returns_parameters() {
        let tools = sample_tools();
        let tool_choice = named_choice("get_weather");

        let schema = get_json_schema_from_tools(Some(&tool_choice), Some(&tools)).expect("schema");

        let expected = json!({
            "type": "object",
            "properties": {
                "location": {"type": "string"},
                "unit": {"type": "string", "enum": ["celsius", "fahrenheit"]},
            },
            "required": ["location", "unit"],
        });
        assert_eq!(schema.unwrap(), expected);
    }

    /// `required` yields an array schema with one `anyOf` alternative per tool.
    #[test]
    fn required_choice_builds_any_of_schema() {
        let tools = sample_tools();
        let tool_choice = ChatCompletionToolChoiceOption::Required;

        let schema = get_json_schema_from_tools(Some(&tool_choice), Some(&tools))
            .expect("schema")
            .expect("required schema");

        assert_eq!(schema["type"], "array");
        assert_eq!(schema["minItems"], 1);
        assert!(schema["items"]["anyOf"].is_array());

        let any_of = schema["items"]["anyOf"].as_array().unwrap();
        assert_eq!(any_of.len(), 2);
        assert_eq!(
            any_of[0]["properties"]["name"],
            json!({"type": "string", "enum": ["add_numbers"]})
        );
    }

    /// Naming a tool that is not in `tools` is reported as `ToolNotFound`.
    #[test]
    fn missing_tool_errors() {
        let tools = sample_tools();
        let tool_choice = named_choice("unknown");

        let err = get_json_schema_from_tools(Some(&tool_choice), Some(&tools)).unwrap_err();

        assert_eq!(err, ToolChoiceError::ToolNotFound("unknown".to_string()));
    }

    /// Two tools that define `$defs.shared` differently cannot be merged.
    #[test]
    fn conflicting_defs_errors() {
        // Same tool name on purpose: only the `$defs.shared` schema differs.
        let tool_defining_shared = |shared_type: &str| ChatCompletionTool {
            r#type: ChatCompletionToolType::Function,
            function: FunctionObject {
                name: "foo".to_string(),
                description: None,
                parameters: Some(json!({
                    "type": "object",
                    "$defs": {
                        "shared": {"type": shared_type}
                    }
                })),
                strict: None,
            },
        };
        let tools = vec![
            tool_defining_shared("string"),
            tool_defining_shared("number"),
        ];

        let err = build_required_schema(&tools).unwrap_err();

        assert_eq!(
            err,
            ToolChoiceError::ConflictingDefinition("shared".to_string())
        );
    }
}
...@@ -92,7 +92,7 @@ fn test_chat_completions_guided_decoding_from_common() { ...@@ -92,7 +92,7 @@ fn test_chat_completions_guided_decoding_from_common() {
); );
assert_eq!( assert_eq!(
request.get_guided_json(), request.get_guided_json(),
Some(&serde_json::json!({"key": "value"})) Some(serde_json::json!({"key": "value"}))
); );
// Test guided_regex can be specified at root level // Test guided_regex can be specified at root level
......
...@@ -484,7 +484,8 @@ mod tests { ...@@ -484,7 +484,8 @@ mod tests {
// Step 2: Apply tool calling jail transformation // Step 2: Apply tool calling jail transformation
let tool_parsed_stream = OpenAIPreprocessor::apply_tool_calling_jail( let tool_parsed_stream = OpenAIPreprocessor::apply_tool_calling_jail(
"nemotron_deci".to_string(), Some("nemotron_deci".to_string()),
None, // No tool_choice in this test
reasoning_parsed_stream, reasoning_parsed_stream,
); );
...@@ -596,7 +597,8 @@ mod tests { ...@@ -596,7 +597,8 @@ mod tests {
let reasoning_parsed_stream = stream::iter(debug_chunks); let reasoning_parsed_stream = stream::iter(debug_chunks);
let tool_parsed_stream = OpenAIPreprocessor::apply_tool_calling_jail( let tool_parsed_stream = OpenAIPreprocessor::apply_tool_calling_jail(
"harmony".to_string(), Some("harmony".to_string()),
None, // No tool_choice in this test
reasoning_parsed_stream, reasoning_parsed_stream,
); );
......
...@@ -158,7 +158,8 @@ async fn parse_response_stream( ...@@ -158,7 +158,8 @@ async fn parse_response_stream(
> = if tool_parse_enable { > = if tool_parse_enable {
if let Some(tool_parser) = tool_parser_str { if let Some(tool_parser) = tool_parser_str {
Box::pin(OpenAIPreprocessor::apply_tool_calling_jail( Box::pin(OpenAIPreprocessor::apply_tool_calling_jail(
tool_parser, Some(tool_parser),
None, // No tool_choice in this test
stream, stream,
)) ))
} else { } else {
......
// SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
// SPDX-License-Identifier: Apache-2.0
use dynamo_async_openai::types::{
ChatCompletionNamedToolChoice, ChatCompletionRequestMessage, ChatCompletionRequestUserMessage,
ChatCompletionRequestUserMessageContent, ChatCompletionToolChoiceOption,
ChatCompletionToolType, CreateChatCompletionRequest, FunctionName,
};
use dynamo_llm::protocols::common;
use dynamo_llm::protocols::common::llm_backend::BackendOutput;
use dynamo_llm::protocols::openai::DeltaGeneratorExt;
use dynamo_llm::protocols::openai::chat_completions::NvCreateChatCompletionRequest;
/// Builds a minimal non-streaming chat request holding a single user message
/// (`"test"`) for model `"test-model"`; every other field is left at its default.
fn create_test_request() -> NvCreateChatCompletionRequest {
    let user_message = ChatCompletionRequestUserMessage {
        content: ChatCompletionRequestUserMessageContent::Text("test".to_string()),
        name: None,
    };

    let inner = CreateChatCompletionRequest {
        model: "test-model".to_string(),
        messages: vec![ChatCompletionRequestMessage::User(user_message)],
        stream: Some(false),
        stream_options: None,
        ..Default::default()
    };

    NvCreateChatCompletionRequest {
        inner,
        common: Default::default(),
        nvext: None,
        chat_template_args: None,
        unsupported_fields: Default::default(),
    }
}
/// Runs a single raw response through a `JailedStream` configured from
/// `tool_choice` and returns the first item the jail emits.
///
/// Named and `required` choices configure the matching builder option; any
/// other value leaves the builder at its defaults.
///
/// Panics if the jail emits nothing or if the first item carries no data.
async fn apply_jail_transformation(
    raw_response: dynamo_llm::protocols::openai::chat_completions::NvCreateChatCompletionStreamResponse,
    tool_choice: Option<ChatCompletionToolChoiceOption>,
) -> dynamo_llm::protocols::openai::chat_completions::NvCreateChatCompletionStreamResponse {
    use dynamo_llm::protocols::openai::chat_completions::jail::JailedStream;
    use dynamo_runtime::protocols::annotated::Annotated;
    use futures::StreamExt;
    use futures::stream;

    let builder = match tool_choice {
        Some(ChatCompletionToolChoiceOption::Named(ref named)) => {
            JailedStream::builder().tool_choice_named(named.function.name.clone())
        }
        Some(ChatCompletionToolChoiceOption::Required) => {
            JailedStream::builder().tool_choice_required()
        }
        _ => JailedStream::builder(),
    };

    let annotated = Annotated {
        data: Some(raw_response),
        id: None,
        event: None,
        comment: None,
    };
    let input_stream = stream::iter(vec![annotated]);

    let output_stream = builder.build().apply_with_finish_reason(input_stream);
    tokio::pin!(output_stream);

    let first = output_stream.next().await.unwrap();
    first.data.unwrap()
}
/// Feeds a sequence of raw responses through a `JailedStream` configured from
/// `tool_choice` and collects every emitted item that carries data.
///
/// Named and `required` choices configure the matching builder option; any
/// other value leaves the builder at its defaults. Items without data are dropped.
async fn apply_jail_transformation_streaming(
    raw_responses: Vec<
        dynamo_llm::protocols::openai::chat_completions::NvCreateChatCompletionStreamResponse,
    >,
    tool_choice: Option<ChatCompletionToolChoiceOption>,
) -> Vec<dynamo_llm::protocols::openai::chat_completions::NvCreateChatCompletionStreamResponse> {
    use dynamo_llm::protocols::openai::chat_completions::jail::JailedStream;
    use dynamo_runtime::protocols::annotated::Annotated;
    use futures::StreamExt;
    use futures::stream;

    let builder = match tool_choice {
        Some(ChatCompletionToolChoiceOption::Named(ref named)) => {
            JailedStream::builder().tool_choice_named(named.function.name.clone())
        }
        Some(ChatCompletionToolChoiceOption::Required) => {
            JailedStream::builder().tool_choice_required()
        }
        _ => JailedStream::builder(),
    };

    let annotated_inputs = raw_responses.into_iter().map(|response| Annotated {
        data: Some(response),
        id: None,
        event: None,
        comment: None,
    });
    let input_stream = stream::iter(annotated_inputs);

    let output_stream = builder.build().apply_with_finish_reason(input_stream);
    tokio::pin!(output_stream);

    output_stream
        .filter_map(|annotated| async move { annotated.data })
        .collect()
        .await
}
/// Builds a `BackendOutput` for choice index 0 that carries `text` and a
/// `Stop` finish reason; token and log-prob fields are left empty.
fn build_backend_output(text: &str) -> BackendOutput {
    let text = Some(String::from(text));
    let finish_reason = Some(common::FinishReason::Stop);

    BackendOutput {
        token_ids: Vec::new(),
        tokens: Vec::new(),
        text,
        cum_log_probs: None,
        log_probs: None,
        top_logprobs: None,
        finish_reason,
        index: Some(0),
        completion_usage: None,
        disaggregated_params: None,
    }
}
/// A named tool choice turns the model's JSON object into a single tool call
/// for that tool, with the JSON text as its arguments.
#[tokio::test]
async fn test_named_tool_choice_parses_json() {
    let tool_choice = Some(ChatCompletionToolChoiceOption::Named(
        ChatCompletionNamedToolChoice {
            r#type: ChatCompletionToolType::Function,
            function: FunctionName {
                name: "get_weather".to_string(),
            },
        },
    ));
    let mut request = create_test_request();
    request.inner.tool_choice = tool_choice.clone();

    let mut generator = request.response_generator("req-1".to_string());
    let raw_response = generator
        .choice_from_postprocessor(build_backend_output(r#"{"location":"Paris"}"#))
        .expect("choice generation");

    let response = apply_jail_transformation(raw_response, tool_choice).await;

    let choice = &response.choices[0];
    assert_eq!(
        choice.finish_reason,
        Some(dynamo_async_openai::types::FinishReason::Stop)
    );

    // The text must have been consumed into the tool call, not left as content.
    let delta = &choice.delta;
    assert!(matches!(delta.content.as_deref(), None | Some("")));

    let tool_calls = delta.tool_calls.as_ref().unwrap();
    assert_eq!(tool_calls.len(), 1);

    let tool_call = &tool_calls[0];
    assert_eq!(tool_call.index, 0);
    assert!(tool_call.id.as_ref().unwrap().starts_with("call-"));
    assert_eq!(tool_call.r#type, Some(ChatCompletionToolType::Function));

    let function = tool_call.function.as_ref().unwrap();
    assert_eq!(function.name.as_deref(), Some("get_weather"));
    assert_eq!(
        function.arguments.as_deref(),
        Some(r#"{"location":"Paris"}"#)
    );
}
/// `tool_choice=required` turns a JSON array of `{name, parameters}` objects
/// into one tool call per element and reports `ToolCalls` as finish reason.
#[tokio::test]
async fn test_required_tool_choice_parses_json_array() {
    let tool_choice = Some(ChatCompletionToolChoiceOption::Required);
    let mut request = create_test_request();
    request.inner.tool_choice = tool_choice.clone();

    let mut generator = request.response_generator("req-2".to_string());
    let backend_output = build_backend_output(
        r#"[{"name":"search","parameters":{"query":"rust"}},
{"name":"summarize","parameters":{"topic":"memory"}}]"#,
    );
    let raw_response = generator
        .choice_from_postprocessor(backend_output)
        .expect("choice generation");

    let response = apply_jail_transformation(raw_response, tool_choice).await;

    let choice = &response.choices[0];
    assert_eq!(
        choice.finish_reason,
        Some(dynamo_async_openai::types::FinishReason::ToolCalls)
    );

    let delta = &choice.delta;
    assert!(matches!(delta.content.as_deref(), None | Some("")));

    // (index, function name, serialized arguments) expected for each call, in order.
    let expected = [
        (0, "search", r#"{"query":"rust"}"#),
        (1, "summarize", r#"{"topic":"memory"}"#),
    ];

    let tool_calls = delta.tool_calls.as_ref().unwrap();
    assert_eq!(tool_calls.len(), 2);

    for (tool_call, (index, name, arguments)) in tool_calls.iter().zip(expected) {
        assert_eq!(tool_call.index, index);
        assert!(tool_call.id.as_ref().unwrap().starts_with("call-"));
        assert_eq!(tool_call.r#type, Some(ChatCompletionToolType::Function));

        let function = tool_call.function.as_ref().unwrap();
        assert_eq!(function.name.as_deref(), Some(name));
        assert_eq!(function.arguments.as_deref(), Some(arguments));
    }
}
/// Output that cannot be parsed as tool calls is passed through as plain
/// content and no tool calls are reported.
#[tokio::test]
async fn test_tool_choice_parse_failure_returns_as_content() {
    let tool_choice = Some(ChatCompletionToolChoiceOption::Required);
    let mut request = create_test_request();
    request.inner.tool_choice = tool_choice.clone();

    let mut generator = request.response_generator("req-3".to_string());
    let raw_response = generator
        .choice_from_postprocessor(build_backend_output("not-json"))
        .expect("choice generation");

    let response = apply_jail_transformation(raw_response, tool_choice).await;

    // Jail stream behavior: if parsing fails, the accumulated content is returned as-is.
    // This matches marker-based FC behavior.
    let delta = &response.choices[0].delta;
    assert_eq!(delta.content.as_deref(), Some("not-json"));
    assert!(delta.tool_calls.is_none());
}
#[tokio::test]
async fn test_streaming_named_tool_buffers_until_finish() {
let mut request = create_test_request();
let tool_choice = Some(ChatCompletionToolChoiceOption::Named(
ChatCompletionNamedToolChoice {
r#type: ChatCompletionToolType::Function,
function: FunctionName {
name: "get_weather".to_string(),
},
},
));
request.inner.tool_choice = tool_choice.clone();
let mut generator = request.response_generator("req-stream-1".to_string());
let chunks = [r#"{"location":""#, r#"Paris","unit":""#, r#"celsius"}"#];
let mut raw_responses = Vec::new();
for (i, chunk) in chunks.iter().enumerate() {
let backend_output = BackendOutput {
token_ids: vec![],
tokens: vec![],
text: Some(chunk.to_string()),
cum_log_probs: None,
log_probs: None,
top_logprobs: None,
finish_reason: if i == chunks.len() - 1 {
Some(common::FinishReason::Stop)
} else {
None
},
index: Some(0),
completion_usage: None,
disaggregated_params: None,
};
let response = generator
.choice_from_postprocessor(backend_output)
.expect("streaming chunk");
raw_responses.push(response);
}
let all_responses = apply_jail_transformation_streaming(raw_responses, tool_choice).await;
// Jail stream buffers content until valid JSON, then emits once
assert_eq!(all_responses.len(), 1);
let response = &all_responses[0];
assert_eq!(
response.choices[0].finish_reason,
Some(dynamo_async_openai::types::FinishReason::Stop)
);
let tool_calls = response.choices[0].delta.tool_calls.as_ref().unwrap();
assert_eq!(tool_calls.len(), 1);
assert_eq!(
tool_calls[0].function.as_ref().unwrap().name.as_deref(),
Some("get_weather")
);
assert_eq!(
tool_calls[0]
.function
.as_ref()
.unwrap()
.arguments
.as_deref(),
Some(r#"{"location":"Paris","unit":"celsius"}"#)
);
}
#[tokio::test]
async fn test_streaming_required_tool_parallel() {
let mut request = create_test_request();
let tool_choice = Some(ChatCompletionToolChoiceOption::Required);
request.inner.tool_choice = tool_choice.clone();
let mut generator = request.response_generator("req-stream-2".to_string());
let chunks = [
r#"[{"name":"search","parameters":{"query":"rust"}},"#,
r#"{"name":"summarize","parameters":{"topic":"memory"}}]"#,
];
let mut raw_responses = Vec::new();
for (i, chunk) in chunks.iter().enumerate() {
let backend_output = BackendOutput {
token_ids: vec![],
tokens: vec![],
text: Some(chunk.to_string()),
cum_log_probs: None,
log_probs: None,
top_logprobs: None,
finish_reason: if i == chunks.len() - 1 {
Some(common::FinishReason::Stop)
} else {
None
},
index: Some(0),
completion_usage: None,
disaggregated_params: None,
};
let response = generator
.choice_from_postprocessor(backend_output)
.expect("streaming chunk");
raw_responses.push(response);
}
let all_responses = apply_jail_transformation_streaming(raw_responses, tool_choice).await;
// Jail stream buffers until complete JSON array
assert_eq!(all_responses.len(), 1);
let response = &all_responses[0];
assert_eq!(
response.choices[0].finish_reason,
Some(dynamo_async_openai::types::FinishReason::ToolCalls)
);
let tool_calls = response.choices[0].delta.tool_calls.as_ref().unwrap();
assert_eq!(tool_calls.len(), 2);
assert_eq!(
tool_calls[0].function.as_ref().unwrap().name.as_deref(),
Some("search")
);
assert_eq!(
tool_calls[0]
.function
.as_ref()
.unwrap()
.arguments
.as_deref(),
Some(r#"{"query":"rust"}"#)
);
assert_eq!(
tool_calls[1].function.as_ref().unwrap().name.as_deref(),
Some("summarize")
);
assert_eq!(
tool_calls[1]
.function
.as_ref()
.unwrap()
.arguments
.as_deref(),
Some(r#"{"topic":"memory"}"#)
);
}
/// Without any `tool_choice`, the response generator emits the backend text
/// as ordinary content and reports no tool calls.
#[test]
fn test_no_tool_choice_outputs_normal_text() {
    let request = create_test_request();
    let mut generator = request.response_generator("req-stream-4".to_string());

    // An intermediate chunk: same shape as the shared fixture, but not finished.
    let backend_output = BackendOutput {
        finish_reason: None,
        ..build_backend_output("Hello world")
    };

    let response = generator
        .choice_from_postprocessor(backend_output)
        .expect("normal text");

    let delta = &response.choices[0].delta;
    assert_eq!(delta.content.as_deref(), Some("Hello world"));
    assert!(delta.tool_calls.is_none());
}
// SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
// SPDX-License-Identifier: Apache-2.0
//! Tests for tool_choice finish_reason handling.
use dynamo_async_openai::types::{
ChatCompletionNamedToolChoice, ChatCompletionRequestMessage, ChatCompletionRequestUserMessage,
ChatCompletionRequestUserMessageContent, ChatCompletionToolChoiceOption,
ChatCompletionToolType, CreateChatCompletionRequest, FunctionName,
};
use dynamo_llm::protocols::common;
use dynamo_llm::protocols::common::llm_backend::BackendOutput;
use dynamo_llm::protocols::openai::DeltaGeneratorExt;
use dynamo_llm::protocols::openai::chat_completions::NvCreateChatCompletionRequest;
/// Builds a minimal non-streaming chat request holding a single user message
/// (`"test"`) for model `"test-model"`; every other field is left at its default.
fn create_test_request() -> NvCreateChatCompletionRequest {
    let user_message = ChatCompletionRequestUserMessage {
        content: ChatCompletionRequestUserMessageContent::Text("test".to_string()),
        name: None,
    };

    let inner = CreateChatCompletionRequest {
        model: "test-model".to_string(),
        messages: vec![ChatCompletionRequestMessage::User(user_message)],
        stream: Some(false),
        stream_options: None,
        ..Default::default()
    };

    NvCreateChatCompletionRequest {
        inner,
        common: Default::default(),
        nvext: None,
        chat_template_args: None,
        unsupported_fields: Default::default(),
    }
}
/// Wraps `text` in a [`BackendOutput`] for choice index 0 that terminates
/// with the supplied `finish` reason.
///
/// Token, log-probability, usage and disaggregation fields are left empty.
fn build_backend_output_with_finish(text: &str, finish: common::FinishReason) -> BackendOutput {
    let text = Some(text.to_owned());
    let finish_reason = Some(finish);

    BackendOutput {
        text,
        finish_reason,
        index: Some(0),
        token_ids: Vec::new(),
        tokens: Vec::new(),
        cum_log_probs: None,
        log_probs: None,
        top_logprobs: None,
        completion_usage: None,
        disaggregated_params: None,
    }
}
/// Feeds one response chunk through a `JailedStream` configured for
/// `tool_choice` and returns the first chunk the jail emits.
///
/// * `Named` configures the jail with the chosen function's name.
/// * `Required` configures the jail for a required tool call.
/// * Any other value (including `None`) leaves the builder unconfigured.
///
/// # Panics
///
/// Panics if the jailed stream yields no item, or if the first item carries
/// no `data`.
async fn apply_jail_transformation(
    raw_response: dynamo_llm::protocols::openai::chat_completions::NvCreateChatCompletionStreamResponse,
    tool_choice: Option<ChatCompletionToolChoiceOption>,
) -> dynamo_llm::protocols::openai::chat_completions::NvCreateChatCompletionStreamResponse {
    use dynamo_llm::protocols::openai::chat_completions::jail::JailedStream;
    use dynamo_runtime::protocols::annotated::Annotated;
    use futures::{StreamExt, stream};

    // The jail consumes a stream, so wrap the single chunk in a one-item stream.
    let only_chunk = Annotated {
        data: Some(raw_response),
        id: None,
        event: None,
        comment: None,
    };
    let input_stream = stream::iter(vec![only_chunk]);

    let mut builder = JailedStream::builder();
    builder = match &tool_choice {
        Some(ChatCompletionToolChoiceOption::Named(named)) => {
            builder.tool_choice_named(named.function.name.clone())
        }
        Some(ChatCompletionToolChoiceOption::Required) => builder.tool_choice_required(),
        _ => builder,
    };
    let jail = builder.build();

    let output_stream = jail.apply_with_finish_reason(input_stream);
    tokio::pin!(output_stream);

    // Only the first emitted chunk is inspected by the callers.
    let first_item = output_stream.next().await.unwrap();
    first_item.data.unwrap()
}
/// A named tool choice whose output was cut off by the length limit must
/// still report `Length` after the chunk has gone through the jail.
#[tokio::test]
async fn test_named_tool_choice_preserves_length_finish_reason() {
    let named_choice = ChatCompletionNamedToolChoice {
        r#type: ChatCompletionToolType::Function,
        function: FunctionName {
            name: "get_weather".to_string(),
        },
    };
    let tool_choice = Some(ChatCompletionToolChoiceOption::Named(named_choice));

    let mut request = create_test_request();
    request.inner.tool_choice = tool_choice.clone();
    let mut delta_generator = request.response_generator("req-length-1".to_string());

    // Arguments are incomplete because generation stopped at the length limit.
    let truncated_args = r#"{"location":"Par"#;
    let output = build_backend_output_with_finish(truncated_args, common::FinishReason::Length);

    let raw_response = delta_generator
        .choice_from_postprocessor(output)
        .expect("choice generation");
    let response = apply_jail_transformation(raw_response, tool_choice).await;

    // Critical: Length must survive and must NOT be rewritten to Stop.
    let choice = &response.choices[0];
    assert_eq!(
        choice.finish_reason,
        Some(dynamo_async_openai::types::FinishReason::Length),
        "Length finish reason must be preserved for tool_choice=named"
    );
}
/// With `tool_choice = required`, a chunk truncated by the length limit must
/// report `Length` rather than `ToolCalls`.
///
/// This test inspects the generator's output directly; it does not pass the
/// chunk through `apply_jail_transformation`.
#[test]
fn test_required_tool_choice_preserves_length_finish_reason() {
    let mut request = create_test_request();
    request.inner.tool_choice = Some(ChatCompletionToolChoiceOption::Required);
    let mut delta_generator = request.response_generator("req-length-2".to_string());

    // The tool-call array stops mid-argument, as a length cut-off would leave it.
    let truncated_calls = r#"[{"name":"search","parameters":{"query":"incomplete"#;
    let output = build_backend_output_with_finish(truncated_calls, common::FinishReason::Length);

    let response = delta_generator
        .choice_from_postprocessor(output)
        .expect("choice generation");

    // Critical: Length must survive and must NOT be rewritten to ToolCalls.
    let choice = &response.choices[0];
    assert_eq!(
        choice.finish_reason,
        Some(dynamo_async_openai::types::FinishReason::Length),
        "Length finish reason must be preserved for tool_choice=required"
    );
}
/// A named tool choice must keep a `ContentFilter` finish reason intact.
///
/// This test inspects the generator's output directly; it does not pass the
/// chunk through `apply_jail_transformation`.
#[test]
fn test_named_tool_choice_preserves_content_filter() {
    let named_choice = ChatCompletionNamedToolChoice {
        r#type: ChatCompletionToolType::Function,
        function: FunctionName {
            name: "search".to_string(),
        },
    };

    let mut request = create_test_request();
    request.inner.tool_choice = Some(ChatCompletionToolChoiceOption::Named(named_choice));
    let mut delta_generator = request.response_generator("req-filter-1".to_string());

    // Partial arguments: the output ends before the JSON object is closed.
    let partial_args = r#"{"query":"filtered content"#;
    let output =
        build_backend_output_with_finish(partial_args, common::FinishReason::ContentFilter);

    let response = delta_generator
        .choice_from_postprocessor(output)
        .expect("choice generation");

    // Critical: ContentFilter must be reported unchanged.
    let choice = &response.choices[0];
    assert_eq!(
        choice.finish_reason,
        Some(dynamo_async_openai::types::FinishReason::ContentFilter),
        "ContentFilter finish reason must be preserved for tool_choice=named"
    );
}
/// With `tool_choice = required`, a `ContentFilter` finish reason must be
/// reported unchanged.
///
/// This test inspects the generator's output directly; it does not pass the
/// chunk through `apply_jail_transformation`.
#[test]
fn test_required_tool_choice_preserves_content_filter() {
    let mut request = create_test_request();
    request.inner.tool_choice = Some(ChatCompletionToolChoiceOption::Required);
    let mut delta_generator = request.response_generator("req-filter-2".to_string());

    // Partial tool-call array: the output ends right after the function name.
    let partial_calls = r#"[{"name":"harmful_action"#;
    let output =
        build_backend_output_with_finish(partial_calls, common::FinishReason::ContentFilter);

    let response = delta_generator
        .choice_from_postprocessor(output)
        .expect("choice generation");

    // Critical: ContentFilter must be reported unchanged.
    let choice = &response.choices[0];
    assert_eq!(
        choice.finish_reason,
        Some(dynamo_async_openai::types::FinishReason::ContentFilter),
        "ContentFilter finish reason must be preserved for tool_choice=required"
    );
}
/// A named tool choice that completes normally keeps `Stop` as its finish
/// reason.
///
/// This test inspects the generator's output directly; it does not pass the
/// chunk through `apply_jail_transformation`.
#[test]
fn test_named_tool_choice_normal_stop_becomes_stop() {
    let named_choice = ChatCompletionNamedToolChoice {
        r#type: ChatCompletionToolType::Function,
        function: FunctionName {
            name: "get_weather".to_string(),
        },
    };

    let mut request = create_test_request();
    request.inner.tool_choice = Some(ChatCompletionToolChoiceOption::Named(named_choice));
    let mut delta_generator = request.response_generator("req-stop-1".to_string());

    // Complete, well-formed arguments for the named function.
    let complete_args = r#"{"location":"Paris","unit":"celsius"}"#;
    let output = build_backend_output_with_finish(complete_args, common::FinishReason::Stop);

    let response = delta_generator
        .choice_from_postprocessor(output)
        .expect("choice generation");

    // Normal completion: Stop stays Stop for a named tool choice.
    let choice = &response.choices[0];
    assert_eq!(
        choice.finish_reason,
        Some(dynamo_async_openai::types::FinishReason::Stop),
    );
}
/// With `tool_choice = required`, a normally completed tool-call array must
/// come out of the jail with `ToolCalls` in place of `Stop`.
#[tokio::test]
async fn test_required_tool_choice_normal_stop_becomes_tool_calls() {
    let tool_choice = Some(ChatCompletionToolChoiceOption::Required);

    let mut request = create_test_request();
    request.inner.tool_choice = tool_choice.clone();
    let mut delta_generator = request.response_generator("req-stop-2".to_string());

    // Complete, well-formed tool-call array that finished with Stop.
    let complete_calls = r#"[{"name":"search","parameters":{"query":"rust"}}]"#;
    let output = build_backend_output_with_finish(complete_calls, common::FinishReason::Stop);

    let raw_response = delta_generator
        .choice_from_postprocessor(output)
        .expect("choice generation");
    let response = apply_jail_transformation(raw_response, tool_choice).await;

    // Normal completion: Stop is rewritten to ToolCalls for a required tool choice.
    let choice = &response.choices[0];
    assert_eq!(
        choice.finish_reason,
        Some(dynamo_async_openai::types::FinishReason::ToolCalls),
    );
}
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment