Unverified commit fd5cc288, authored by ishandhanani and committed by GitHub
Browse files

refactor(3/3): switch dynamo-protocols to upstream async-openai types (#7625)


Co-authored-by: Dmitry Tokarev <dtokarev@nvidia.com>
parent d517fb80
This diff is collapsed.
...@@ -44,7 +44,7 @@ dynamo-tokens = { path = "lib/tokens", version = "1.0.0" } ...@@ -44,7 +44,7 @@ dynamo-tokens = { path = "lib/tokens", version = "1.0.0" }
dynamo-memory = { path = "lib/memory", version = "1.0.0" } dynamo-memory = { path = "lib/memory", version = "1.0.0" }
dynamo-mocker = { path = "lib/mocker", version = "1.0.0" } dynamo-mocker = { path = "lib/mocker", version = "1.0.0" }
dynamo-kv-router = { path = "lib/kv-router", version = "1.0.0", features = ["metrics", "runtime-protocols"] } dynamo-kv-router = { path = "lib/kv-router", version = "1.0.0", features = ["metrics", "runtime-protocols"] }
dynamo-protocols = { path = "lib/protocols", version = "1.0.0", features = ["byot"] } dynamo-protocols = { path = "lib/protocols", version = "1.0.0" }
dynamo-parsers = { path = "lib/parsers", version = "1.0.0" } dynamo-parsers = { path = "lib/parsers", version = "1.0.0" }
fastokens = { version = "0.1.0" } fastokens = { version = "0.1.0" }
......
...@@ -75,7 +75,7 @@ class NvCreateImageRequest(BaseModel): ...@@ -75,7 +75,7 @@ class NvCreateImageRequest(BaseModel):
class ImageData(BaseModel): class ImageData(BaseModel):
"""Individual image data in a response. """Individual image data in a response.
Matches the flattened Rust Image enum in lib/async-openai/src/types/image.rs. Matches the flattened Rust Image enum in lib/protocols/src/types/mod.rs.
""" """
url: Optional[str] = None url: Optional[str] = None
......
...@@ -244,14 +244,16 @@ source = "registry+https://github.com/rust-lang/crates.io-index" ...@@ -244,14 +244,16 @@ source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "4288f83726785267c6f2ef073a3d83dc3f9b81464e9f99898240cced85fce35a" checksum = "4288f83726785267c6f2ef073a3d83dc3f9b81464e9f99898240cced85fce35a"
[[package]] [[package]]
name = "async-openai-macros" name = "async-openai"
version = "0.1.1" version = "0.34.0"
source = "registry+https://github.com/rust-lang/crates.io-index" source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "81872a8e595e8ceceab71c6ba1f9078e313b452a1e31934e6763ef5d308705e4" checksum = "ec08254d61379df136135d3d1ac04301be7699fd7d9e57655c63ac7d650a6922"
dependencies = [ dependencies = [
"proc-macro2", "bytes",
"quote", "derive_builder",
"syn", "getrandom 0.3.4",
"serde",
"serde_json",
] ]
[[package]] [[package]]
...@@ -460,20 +462,6 @@ dependencies = [ ...@@ -460,20 +462,6 @@ dependencies = [
"tower-service", "tower-service",
] ]
[[package]]
name = "backoff"
version = "0.4.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "b62ddb9cb1ec0a098ad4bbf9344d0713fa193ae1a80af55febcff2627b6a00c1"
dependencies = [
"futures-core",
"getrandom 0.2.17",
"instant",
"pin-project-lite",
"rand 0.8.5",
"tokio",
]
[[package]] [[package]]
name = "backon" name = "backon"
version = "1.6.0" version = "1.6.0"
...@@ -1704,26 +1692,14 @@ dependencies = [ ...@@ -1704,26 +1692,14 @@ dependencies = [
name = "dynamo-protocols" name = "dynamo-protocols"
version = "1.0.0" version = "1.0.0"
dependencies = [ dependencies = [
"async-openai-macros", "async-openai",
"backoff",
"base64 0.22.1",
"bytes",
"derive_builder", "derive_builder",
"eventsource-stream",
"futures", "futures",
"rand 0.9.2",
"reqwest",
"reqwest-eventsource",
"secrecy",
"serde", "serde",
"serde_json", "serde_json",
"thiserror 2.0.18", "thiserror 2.0.18",
"tokio",
"tokio-stream",
"tokio-util",
"tracing", "tracing",
"url", "url",
"utoipa",
"uuid", "uuid",
] ]
...@@ -1980,17 +1956,6 @@ dependencies = [ ...@@ -1980,17 +1956,6 @@ dependencies = [
"pin-project-lite", "pin-project-lite",
] ]
[[package]]
name = "eventsource-stream"
version = "0.2.3"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "74fef4569247a5f429d9156b9d0a2599914385dd189c539334c625d8099d90ab"
dependencies = [
"futures-core",
"nom 7.1.3",
"pin-project-lite",
]
[[package]] [[package]]
name = "exr" name = "exr"
version = "1.74.0" version = "1.74.0"
...@@ -2271,12 +2236,6 @@ version = "0.3.32" ...@@ -2271,12 +2236,6 @@ version = "0.3.32"
source = "registry+https://github.com/rust-lang/crates.io-index" source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "037711b3d59c33004d3856fbdc83b99d4ff37a24768fa1be9ce3538a1cde4393" checksum = "037711b3d59c33004d3856fbdc83b99d4ff37a24768fa1be9ce3538a1cde4393"
[[package]]
name = "futures-timer"
version = "3.0.3"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "f288b0a4f20f9a56b5d1da57e2227c661b7b16168e2f72365f57b63326e29b24"
[[package]] [[package]]
name = "futures-util" name = "futures-util"
version = "0.3.32" version = "0.3.32"
...@@ -2932,15 +2891,6 @@ dependencies = [ ...@@ -2932,15 +2891,6 @@ dependencies = [
"libc", "libc",
] ]
[[package]]
name = "instant"
version = "0.1.13"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "e0242819d153cba4b4b05a5a8f2a7e9bbf97b6055b2a002b395c96b5ff3c0222"
dependencies = [
"cfg-if",
]
[[package]] [[package]]
name = "interpolate_name" name = "interpolate_name"
version = "0.2.4" version = "0.2.4"
...@@ -5533,22 +5483,6 @@ dependencies = [ ...@@ -5533,22 +5483,6 @@ dependencies = [
"webpki-roots 1.0.6", "webpki-roots 1.0.6",
] ]
[[package]]
name = "reqwest-eventsource"
version = "0.6.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "632c55746dbb44275691640e7b40c907c16a2dc1a5842aa98aaec90da6ec6bde"
dependencies = [
"eventsource-stream",
"futures-core",
"futures-timer",
"mime",
"nom 7.1.3",
"pin-project-lite",
"reqwest",
"thiserror 1.0.69",
]
[[package]] [[package]]
name = "rgb" name = "rgb"
version = "0.8.53" version = "0.8.53"
...@@ -5911,7 +5845,6 @@ version = "0.10.3" ...@@ -5911,7 +5845,6 @@ version = "0.10.3"
source = "registry+https://github.com/rust-lang/crates.io-index" source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "e891af845473308773346dc847b2c23ee78fe442e0472ac50e22a18a93d3ae5a" checksum = "e891af845473308773346dc847b2c23ee78fe442e0472ac50e22a18a93d3ae5a"
dependencies = [ dependencies = [
"serde",
"zeroize", "zeroize",
] ]
...@@ -7360,8 +7293,6 @@ dependencies = [ ...@@ -7360,8 +7293,6 @@ dependencies = [
"quote", "quote",
"regex", "regex",
"syn", "syn",
"url",
"uuid",
] ]
[[package]] [[package]]
......
...@@ -131,9 +131,9 @@ dependencies = [ ...@@ -131,9 +131,9 @@ dependencies = [
[[package]] [[package]]
name = "arc-swap" name = "arc-swap"
version = "1.9.0" version = "1.9.1"
source = "registry+https://github.com/rust-lang/crates.io-index" source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "a07d1f37ff60921c83bdfc7407723bdefe89b44b98a9b772f225c8f9d67141a6" checksum = "6a3a1fd6f75306b68087b831f025c712524bcb19aad54e557b1129cfa0a2b207"
dependencies = [ dependencies = [
"rustversion", "rustversion",
] ]
...@@ -244,14 +244,16 @@ source = "registry+https://github.com/rust-lang/crates.io-index" ...@@ -244,14 +244,16 @@ source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "4288f83726785267c6f2ef073a3d83dc3f9b81464e9f99898240cced85fce35a" checksum = "4288f83726785267c6f2ef073a3d83dc3f9b81464e9f99898240cced85fce35a"
[[package]] [[package]]
name = "async-openai-macros" name = "async-openai"
version = "0.1.1" version = "0.34.0"
source = "registry+https://github.com/rust-lang/crates.io-index" source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "81872a8e595e8ceceab71c6ba1f9078e313b452a1e31934e6763ef5d308705e4" checksum = "ec08254d61379df136135d3d1ac04301be7699fd7d9e57655c63ac7d650a6922"
dependencies = [ dependencies = [
"proc-macro2", "bytes",
"quote", "derive_builder",
"syn", "getrandom 0.3.4",
"serde",
"serde_json",
] ]
[[package]] [[package]]
...@@ -460,20 +462,6 @@ dependencies = [ ...@@ -460,20 +462,6 @@ dependencies = [
"tower-service", "tower-service",
] ]
[[package]]
name = "backoff"
version = "0.4.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "b62ddb9cb1ec0a098ad4bbf9344d0713fa193ae1a80af55febcff2627b6a00c1"
dependencies = [
"futures-core",
"getrandom 0.2.17",
"instant",
"pin-project-lite",
"rand 0.8.5",
"tokio",
]
[[package]] [[package]]
name = "backon" name = "backon"
version = "1.6.0" version = "1.6.0"
...@@ -1719,26 +1707,14 @@ dependencies = [ ...@@ -1719,26 +1707,14 @@ dependencies = [
name = "dynamo-protocols" name = "dynamo-protocols"
version = "1.0.0" version = "1.0.0"
dependencies = [ dependencies = [
"async-openai-macros", "async-openai",
"backoff",
"base64 0.22.1",
"bytes",
"derive_builder", "derive_builder",
"eventsource-stream",
"futures", "futures",
"rand 0.9.2",
"reqwest",
"reqwest-eventsource",
"secrecy",
"serde", "serde",
"serde_json", "serde_json",
"thiserror 2.0.18", "thiserror 2.0.18",
"tokio",
"tokio-stream",
"tokio-util",
"tracing", "tracing",
"url", "url",
"utoipa",
"uuid", "uuid",
] ]
...@@ -2027,17 +2003,6 @@ dependencies = [ ...@@ -2027,17 +2003,6 @@ dependencies = [
"pin-project-lite", "pin-project-lite",
] ]
[[package]]
name = "eventsource-stream"
version = "0.2.3"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "74fef4569247a5f429d9156b9d0a2599914385dd189c539334c625d8099d90ab"
dependencies = [
"futures-core",
"nom 7.1.3",
"pin-project-lite",
]
[[package]] [[package]]
name = "exr" name = "exr"
version = "1.74.0" version = "1.74.0"
...@@ -2096,11 +2061,11 @@ dependencies = [ ...@@ -2096,11 +2061,11 @@ dependencies = [
[[package]] [[package]]
name = "fastrand" name = "fastrand"
version = "2.3.0" version = "2.4.0"
source = "registry+https://github.com/rust-lang/crates.io-index" source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "37909eebbb50d72f9059c3b6d82c0463f2ff062c9e95845c43a6c9c0355411be" checksum = "a043dc74da1e37d6afe657061213aa6f425f855399a11d3463c6ecccc4dfda1f"
dependencies = [ dependencies = [
"getrandom 0.2.17", "getrandom 0.3.4",
] ]
[[package]] [[package]]
...@@ -2343,12 +2308,6 @@ version = "0.3.32" ...@@ -2343,12 +2308,6 @@ version = "0.3.32"
source = "registry+https://github.com/rust-lang/crates.io-index" source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "037711b3d59c33004d3856fbdc83b99d4ff37a24768fa1be9ce3538a1cde4393" checksum = "037711b3d59c33004d3856fbdc83b99d4ff37a24768fa1be9ce3538a1cde4393"
[[package]]
name = "futures-timer"
version = "3.0.3"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "f288b0a4f20f9a56b5d1da57e2227c661b7b16168e2f72365f57b63326e29b24"
[[package]] [[package]]
name = "futures-util" name = "futures-util"
version = "0.3.32" version = "0.3.32"
...@@ -3004,15 +2963,6 @@ dependencies = [ ...@@ -3004,15 +2963,6 @@ dependencies = [
"libc", "libc",
] ]
[[package]]
name = "instant"
version = "0.1.13"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "e0242819d153cba4b4b05a5a8f2a7e9bbf97b6055b2a002b395c96b5ff3c0222"
dependencies = [
"cfg-if",
]
[[package]] [[package]]
name = "interpolate_name" name = "interpolate_name"
version = "0.2.4" version = "0.2.4"
...@@ -5603,22 +5553,6 @@ dependencies = [ ...@@ -5603,22 +5553,6 @@ dependencies = [
"webpki-roots 1.0.6", "webpki-roots 1.0.6",
] ]
[[package]]
name = "reqwest-eventsource"
version = "0.6.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "632c55746dbb44275691640e7b40c907c16a2dc1a5842aa98aaec90da6ec6bde"
dependencies = [
"eventsource-stream",
"futures-core",
"futures-timer",
"mime",
"nom 7.1.3",
"pin-project-lite",
"reqwest",
"thiserror 1.0.69",
]
[[package]] [[package]]
name = "rgb" name = "rgb"
version = "0.8.53" version = "0.8.53"
...@@ -5981,7 +5915,6 @@ version = "0.10.3" ...@@ -5981,7 +5915,6 @@ version = "0.10.3"
source = "registry+https://github.com/rust-lang/crates.io-index" source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "e891af845473308773346dc847b2c23ee78fe442e0472ac50e22a18a93d3ae5a" checksum = "e891af845473308773346dc847b2c23ee78fe442e0472ac50e22a18a93d3ae5a"
dependencies = [ dependencies = [
"serde",
"zeroize", "zeroize",
] ]
...@@ -7430,8 +7363,6 @@ dependencies = [ ...@@ -7430,8 +7363,6 @@ dependencies = [
"quote", "quote",
"regex", "regex",
"syn", "syn",
"url",
"uuid",
] ]
[[package]] [[package]]
......
...@@ -182,7 +182,7 @@ pub fn final_response_to_one_chunk_stream( ...@@ -182,7 +182,7 @@ pub fn final_response_to_one_chunk_stream(
// Convert FunctionCall to FunctionCallStream if present // Convert FunctionCall to FunctionCallStream if present
#[allow(deprecated)] #[allow(deprecated)]
let function_call = ch.message.function_call.as_ref().map(|fc| { let function_call = ch.message.function_call.as_ref().map(|fc| {
dynamo_protocols::types::FunctionCallStream { dynamo_protocols::types::ChatCompletionStreamResponseDeltaFunctionCall {
name: Some(fc.name.clone()), name: Some(fc.name.clone()),
arguments: Some(fc.arguments.clone()), arguments: Some(fc.arguments.clone()),
} }
...@@ -197,7 +197,7 @@ pub fn final_response_to_one_chunk_stream( ...@@ -197,7 +197,7 @@ pub fn final_response_to_one_chunk_stream(
|(i, call)| dynamo_protocols::types::ChatCompletionMessageToolCallChunk { |(i, call)| dynamo_protocols::types::ChatCompletionMessageToolCallChunk {
index: i as u32, index: i as u32,
id: Some(call.id.clone()), id: Some(call.id.clone()),
r#type: Some(call.r#type.clone()), r#type: Some(dynamo_protocols::types::FunctionType::Function),
function: Some(dynamo_protocols::types::FunctionCallStream { function: Some(dynamo_protocols::types::FunctionCallStream {
name: Some(call.function.name.clone()), name: Some(call.function.name.clone()),
arguments: Some(call.function.arguments.clone()), arguments: Some(call.function.arguments.clone()),
......
// SPDX-FileCopyrightText: Copyright (c) 2024-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. // SPDX-FileCopyrightText: Copyright (c) 2024-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
// SPDX-License-Identifier: Apache-2.0 // SPDX-License-Identifier: Apache-2.0
pub mod client;
pub mod service; pub mod service;
This diff is collapsed.
...@@ -1544,6 +1544,7 @@ async fn responses( ...@@ -1544,6 +1544,7 @@ async fn responses(
temperature: request.inner.temperature, temperature: request.inner.temperature,
top_p: request.inner.top_p, top_p: request.inner.top_p,
max_output_tokens: request.inner.max_output_tokens, max_output_tokens: request.inner.max_output_tokens,
parallel_tool_calls: request.inner.parallel_tool_calls,
store: request.inner.store, store: request.inner.store,
tools: request.inner.tools.clone(), tools: request.inner.tools.clone(),
tool_choice: request.inner.tool_choice.clone(), tool_choice: request.inner.tool_choice.clone(),
...@@ -1788,11 +1789,6 @@ pub fn validate_response_unsupported_fields( ...@@ -1788,11 +1789,6 @@ pub fn validate_response_unsupported_fields(
VALIDATION_PREFIX.to_string() + "`prompt` is not supported.", VALIDATION_PREFIX.to_string() + "`prompt` is not supported.",
)); ));
} }
if inner.store == Some(true) {
return Some(ErrorMessage::not_implemented_error(
VALIDATION_PREFIX.to_string() + "`store: true` is not supported.",
));
}
None None
} }
...@@ -1965,6 +1961,9 @@ async fn images( ...@@ -1965,6 +1961,9 @@ async fn images(
.map(|m| match m { .map(|m| match m {
dynamo_protocols::types::ImageModel::DallE2 => "dall-e-2".to_string(), dynamo_protocols::types::ImageModel::DallE2 => "dall-e-2".to_string(),
dynamo_protocols::types::ImageModel::DallE3 => "dall-e-3".to_string(), dynamo_protocols::types::ImageModel::DallE3 => "dall-e-3".to_string(),
dynamo_protocols::types::ImageModel::GptImage1 => "gpt-image-1".to_string(),
dynamo_protocols::types::ImageModel::GptImage1dot5 => "gpt-image-1.5".to_string(),
dynamo_protocols::types::ImageModel::GptImage1Mini => "gpt-image-1-mini".to_string(),
dynamo_protocols::types::ImageModel::Other(s) => s.clone(), dynamo_protocols::types::ImageModel::Other(s) => s.clone(),
}) })
.unwrap_or_else(|| "diffusion".to_string()); .unwrap_or_else(|| "diffusion".to_string());
...@@ -2540,6 +2539,17 @@ mod tests { ...@@ -2540,6 +2539,17 @@ mod tests {
assert!(result.is_none(), "parallel_tool_calls should be supported"); assert!(result.is_none(), "parallel_tool_calls should be supported");
} }
/// Setting `store` to `Some(true)` on an otherwise valid base request must not
/// be reported by `validate_response_unsupported_fields`.
#[test]
fn test_validate_unsupported_fields_accepts_store() {
    // Start from the shared baseline request and opt in to storage.
    let mut req = make_base_request();
    req.inner.store = Some(true);

    // `None` from the validator means no unsupported field was detected.
    let outcome = validate_response_unsupported_fields(&req);
    assert!(
        outcome.is_none(),
        "store should be supported for audit opt-in"
    );
}
#[test] #[test]
fn test_validate_unsupported_fields_detects_flags() { fn test_validate_unsupported_fields_detects_flags() {
#[allow(clippy::type_complexity)] #[allow(clippy::type_complexity)]
...@@ -2559,7 +2569,6 @@ mod tests { ...@@ -2559,7 +2569,6 @@ mod tests {
}) })
}), }),
), ),
("store", Box::new(|r| r.store = Some(true))),
]; ];
for (field, set_field) in unsupported_cases { for (field, set_field) in unsupported_cases {
...@@ -3290,8 +3299,7 @@ mod tests { ...@@ -3290,8 +3299,7 @@ mod tests {
use dynamo_protocols::types::{ use dynamo_protocols::types::{
ChatChoiceStream, ChatCompletionMessageToolCallChunk, ChatCompletionStreamResponseDelta, ChatChoiceStream, ChatCompletionMessageToolCallChunk, ChatCompletionStreamResponseDelta,
ChatCompletionToolType, CreateChatCompletionStreamResponse, FinishReason, CreateChatCompletionStreamResponse, FinishReason, FunctionCallStream, FunctionType,
FunctionCallStream,
}; };
use dynamo_runtime::protocols::annotated::Annotated; use dynamo_runtime::protocols::annotated::Annotated;
...@@ -3444,7 +3452,7 @@ mod tests { ...@@ -3444,7 +3452,7 @@ mod tests {
let tool_call = ChatCompletionMessageToolCallChunk { let tool_call = ChatCompletionMessageToolCallChunk {
index: 0, index: 0,
id: id.map(|s| s.to_string()), id: id.map(|s| s.to_string()),
r#type: Some(ChatCompletionToolType::Function), r#type: Some(FunctionType::Function),
function: Some(FunctionCallStream { function: Some(FunctionCallStream {
name: name.map(|s| s.to_string()), name: name.map(|s| s.to_string()),
arguments: arguments.map(|s| s.to_string()), arguments: arguments.map(|s| s.to_string()),
...@@ -3537,7 +3545,7 @@ mod tests { ...@@ -3537,7 +3545,7 @@ mod tests {
let tc1 = ChatCompletionMessageToolCallChunk { let tc1 = ChatCompletionMessageToolCallChunk {
index: 0, index: 0,
id: Some("call_1".to_string()), id: Some("call_1".to_string()),
r#type: Some(ChatCompletionToolType::Function), r#type: Some(FunctionType::Function),
function: Some(FunctionCallStream { function: Some(FunctionCallStream {
name: Some("get_weather".to_string()), name: Some("get_weather".to_string()),
arguments: Some(r#"{"city":"Paris"}"#.to_string()), arguments: Some(r#"{"city":"Paris"}"#.to_string()),
...@@ -3546,7 +3554,7 @@ mod tests { ...@@ -3546,7 +3554,7 @@ mod tests {
let tc2 = ChatCompletionMessageToolCallChunk { let tc2 = ChatCompletionMessageToolCallChunk {
index: 1, index: 1,
id: Some("call_2".to_string()), id: Some("call_2".to_string()),
r#type: Some(ChatCompletionToolType::Function), r#type: Some(FunctionType::Function),
function: Some(FunctionCallStream { function: Some(FunctionCallStream {
name: Some("get_time".to_string()), name: Some("get_time".to_string()),
arguments: Some(r#"{"tz":"UTC"}"#.to_string()), arguments: Some(r#"{"tz":"UTC"}"#.to_string()),
...@@ -3609,7 +3617,7 @@ mod tests { ...@@ -3609,7 +3617,7 @@ mod tests {
let complete = ChatCompletionMessageToolCallChunk { let complete = ChatCompletionMessageToolCallChunk {
index: 0, index: 0,
id: Some("call_complete".to_string()), id: Some("call_complete".to_string()),
r#type: Some(ChatCompletionToolType::Function), r#type: Some(FunctionType::Function),
function: Some(FunctionCallStream { function: Some(FunctionCallStream {
name: Some("get_weather".to_string()), name: Some("get_weather".to_string()),
arguments: Some(r#"{"city":"Paris"}"#.to_string()), arguments: Some(r#"{"city":"Paris"}"#.to_string()),
...@@ -3618,7 +3626,7 @@ mod tests { ...@@ -3618,7 +3626,7 @@ mod tests {
let incomplete = ChatCompletionMessageToolCallChunk { let incomplete = ChatCompletionMessageToolCallChunk {
index: 1, index: 1,
id: Some("call_partial".to_string()), id: Some("call_partial".to_string()),
r#type: Some(ChatCompletionToolType::Function), r#type: Some(FunctionType::Function),
function: Some(FunctionCallStream { function: Some(FunctionCallStream {
name: Some("search".to_string()), name: Some("search".to_string()),
arguments: None, // still streaming arguments: None, // still streaming
...@@ -3658,7 +3666,7 @@ mod tests { ...@@ -3658,7 +3666,7 @@ mod tests {
let tool_call = ChatCompletionMessageToolCallChunk { let tool_call = ChatCompletionMessageToolCallChunk {
index: 0, index: 0,
id: Some("call_999".to_string()), id: Some("call_999".to_string()),
r#type: Some(ChatCompletionToolType::Function), r#type: Some(FunctionType::Function),
function: None, function: None,
}; };
#[allow(deprecated)] #[allow(deprecated)]
......
...@@ -947,7 +947,6 @@ mod tests { ...@@ -947,7 +947,6 @@ mod tests {
fn create_mock_response_with_logprobs( fn create_mock_response_with_logprobs(
token_logprobs: Vec<ChatCompletionTokenLogprob>, token_logprobs: Vec<ChatCompletionTokenLogprob>,
) -> NvCreateChatCompletionStreamResponse { ) -> NvCreateChatCompletionStreamResponse {
#[expect(deprecated)]
NvCreateChatCompletionStreamResponse { NvCreateChatCompletionStreamResponse {
inner: dynamo_protocols::types::CreateChatCompletionStreamResponse { inner: dynamo_protocols::types::CreateChatCompletionStreamResponse {
id: "test_id".to_string(), id: "test_id".to_string(),
...@@ -984,7 +983,6 @@ mod tests { ...@@ -984,7 +983,6 @@ mod tests {
fn create_mock_response_with_multiple_choices( fn create_mock_response_with_multiple_choices(
choices_logprobs: Vec<Vec<ChatCompletionTokenLogprob>>, choices_logprobs: Vec<Vec<ChatCompletionTokenLogprob>>,
) -> NvCreateChatCompletionStreamResponse { ) -> NvCreateChatCompletionStreamResponse {
#[expect(deprecated)]
let choices = choices_logprobs let choices = choices_logprobs
.into_iter() .into_iter()
.enumerate() .enumerate()
...@@ -1339,7 +1337,6 @@ mod tests { ...@@ -1339,7 +1337,6 @@ mod tests {
#[test] #[test]
fn test_logprob_extractor_with_missing_data() { fn test_logprob_extractor_with_missing_data() {
// Test with choice that has no logprobs // Test with choice that has no logprobs
#[expect(deprecated)]
let response = NvCreateChatCompletionStreamResponse { let response = NvCreateChatCompletionStreamResponse {
inner: dynamo_protocols::types::CreateChatCompletionStreamResponse { inner: dynamo_protocols::types::CreateChatCompletionStreamResponse {
id: "test_id".to_string(), id: "test_id".to_string(),
......
...@@ -732,7 +732,7 @@ mod tests { ...@@ -732,7 +732,7 @@ mod tests {
use super::*; use super::*;
use dynamo_protocols::types::{ use dynamo_protocols::types::{
ChatChoiceStream, ChatCompletionMessageContent, ChatCompletionMessageToolCallChunk, ChatChoiceStream, ChatCompletionMessageContent, ChatCompletionMessageToolCallChunk,
ChatCompletionStreamResponseDelta, ChatCompletionToolType, FunctionCallStream, ChatCompletionStreamResponseDelta, FunctionCallStream, FunctionType,
}; };
fn text_chunk(text: &str) -> NvCreateChatCompletionStreamResponse { fn text_chunk(text: &str) -> NvCreateChatCompletionStreamResponse {
...@@ -783,7 +783,7 @@ mod tests { ...@@ -783,7 +783,7 @@ mod tests {
tool_calls: Some(vec![ChatCompletionMessageToolCallChunk { tool_calls: Some(vec![ChatCompletionMessageToolCallChunk {
index: tc_index, index: tc_index,
id: id.map(String::from), id: id.map(String::from),
r#type: Some(ChatCompletionToolType::Function), r#type: Some(FunctionType::Function),
function: Some(FunctionCallStream { function: Some(FunctionCallStream {
name: name.map(String::from), name: name.map(String::from),
arguments: args.map(String::from), arguments: args.map(String::from),
......
...@@ -20,7 +20,7 @@ use dynamo_protocols::types::{ ...@@ -20,7 +20,7 @@ use dynamo_protocols::types::{
ChatCompletionRequestToolMessageContent, ChatCompletionRequestUserMessage, ChatCompletionRequestToolMessageContent, ChatCompletionRequestUserMessage,
ChatCompletionRequestUserMessageContent, ChatCompletionRequestUserMessageContentPart, ChatCompletionRequestUserMessageContent, ChatCompletionRequestUserMessageContentPart,
ChatCompletionTool, ChatCompletionToolChoiceOption, ChatCompletionToolType, FunctionName, ChatCompletionTool, ChatCompletionToolChoiceOption, ChatCompletionToolType, FunctionName,
FunctionObject, ImageUrl, ReasoningContent, FunctionObject, FunctionType, ImageUrl, ReasoningContent,
}; };
use uuid::Uuid; use uuid::Uuid;
...@@ -312,7 +312,7 @@ fn convert_assistant_blocks( ...@@ -312,7 +312,7 @@ fn convert_assistant_blocks(
segments.push(std::mem::take(&mut pending_reasoning)); segments.push(std::mem::take(&mut pending_reasoning));
tool_calls.push(ChatCompletionMessageToolCall { tool_calls.push(ChatCompletionMessageToolCall {
id: id.clone(), id: id.clone(),
r#type: ChatCompletionToolType::Function, r#type: FunctionType::Function,
function: dynamo_protocols::types::FunctionCall { function: dynamo_protocols::types::FunctionCall {
name: name.clone(), name: name.clone(),
arguments: serde_json::to_string(input).unwrap_or_default(), arguments: serde_json::to_string(input).unwrap_or_default(),
......
...@@ -35,6 +35,7 @@ pub use delta::DeltaGenerator; ...@@ -35,6 +35,7 @@ pub use delta::DeltaGenerator;
#[derive(ToSchema, Serialize, Deserialize, Validate, Debug, Clone)] #[derive(ToSchema, Serialize, Deserialize, Validate, Debug, Clone)]
pub struct NvCreateChatCompletionRequest { pub struct NvCreateChatCompletionRequest {
#[serde(flatten)] #[serde(flatten)]
#[schema(value_type = Object)]
pub inner: dynamo_protocols::types::CreateChatCompletionRequest, pub inner: dynamo_protocols::types::CreateChatCompletionRequest,
#[serde(flatten, default)] #[serde(flatten, default)]
......
...@@ -4,6 +4,8 @@ ...@@ -4,6 +4,8 @@
use futures::{Stream, StreamExt}; use futures::{Stream, StreamExt};
use std::collections::HashMap; use std::collections::HashMap;
use dynamo_parsers::tool_calling::try_tool_call_parse_aggregate;
use super::{NvCreateChatCompletionResponse, NvCreateChatCompletionStreamResponse}; use super::{NvCreateChatCompletionResponse, NvCreateChatCompletionStreamResponse};
use crate::protocols::{ use crate::protocols::{
Annotated, Annotated,
...@@ -75,11 +77,11 @@ fn convert_tool_chunk_to_message_tool_call( ...@@ -75,11 +77,11 @@ fn convert_tool_chunk_to_message_tool_call(
chunk: &dynamo_protocols::types::ChatCompletionMessageToolCallChunk, chunk: &dynamo_protocols::types::ChatCompletionMessageToolCallChunk,
) -> Option<dynamo_protocols::types::ChatCompletionMessageToolCall> { ) -> Option<dynamo_protocols::types::ChatCompletionMessageToolCall> {
// Convert ChatCompletionMessageToolCallChunk to ChatCompletionMessageToolCall // Convert ChatCompletionMessageToolCallChunk to ChatCompletionMessageToolCall
if let (Some(id), Some(r#type), Some(function)) = (&chunk.id, &chunk.r#type, &chunk.function) { if let (Some(id), Some(function)) = (&chunk.id, &chunk.function) {
if let (Some(name), Some(arguments)) = (&function.name, &function.arguments) { if let (Some(name), Some(arguments)) = (&function.name, &function.arguments) {
Some(dynamo_protocols::types::ChatCompletionMessageToolCall { Some(dynamo_protocols::types::ChatCompletionMessageToolCall {
id: id.clone(), id: id.clone(),
r#type: r#type.clone(), r#type: dynamo_protocols::types::FunctionType::Function,
function: dynamo_protocols::types::FunctionCall { function: dynamo_protocols::types::FunctionCall {
name: name.clone(), name: name.clone(),
arguments: arguments.clone(), arguments: arguments.clone(),
...@@ -120,9 +122,9 @@ impl DeltaAggregator { ...@@ -120,9 +122,9 @@ impl DeltaAggregator {
/// * `Err(String)` if an error occurs during processing. /// * `Err(String)` if an error occurs during processing.
pub async fn apply( pub async fn apply(
stream: impl Stream<Item = Annotated<NvCreateChatCompletionStreamResponse>>, stream: impl Stream<Item = Annotated<NvCreateChatCompletionStreamResponse>>,
_parsing_options: ParsingOptions, parsing_options: ParsingOptions,
) -> Result<NvCreateChatCompletionResponse, String> { ) -> Result<NvCreateChatCompletionResponse, String> {
let aggregator = stream let mut aggregator = stream
.fold(DeltaAggregator::new(), |mut aggregator, delta| async move { .fold(DeltaAggregator::new(), |mut aggregator, delta| async move {
// Attempt to unwrap the delta, capturing any errors. // Attempt to unwrap the delta, capturing any errors.
let delta = match delta.ok() { let delta = match delta.ok() {
...@@ -256,6 +258,37 @@ impl DeltaAggregator { ...@@ -256,6 +258,37 @@ impl DeltaAggregator {
return Err(error); return Err(error);
} }
if let Some(parser) = parsing_options.tool_call_parser.as_deref() {
for choice in aggregator.choices.values_mut() {
if choice
.tool_calls
.as_ref()
.is_some_and(|calls| !calls.is_empty())
|| choice.text.is_empty()
{
continue;
}
let (tool_calls, content) =
match try_tool_call_parse_aggregate(&choice.text, Some(parser), None).await {
Ok(result) => result,
Err(error) => {
tracing::debug!(
error = %error,
parser,
"failed to parse aggregated chat tool calls"
);
continue;
}
};
if !tool_calls.is_empty() {
choice.tool_calls = Some(tool_calls);
choice.text = content.unwrap_or_default();
}
}
}
// Extract aggregated choices and sort them by index. // Extract aggregated choices and sort them by index.
let mut choices: Vec<_> = aggregator let mut choices: Vec<_> = aggregator
.choices .choices
...@@ -405,7 +438,7 @@ mod tests { ...@@ -405,7 +438,7 @@ mod tests {
dynamo_protocols::types::ChatCompletionMessageToolCallChunk { dynamo_protocols::types::ChatCompletionMessageToolCallChunk {
index: 0, index: 0,
id: Some("test_id".to_string()), id: Some("test_id".to_string()),
r#type: Some(dynamo_protocols::types::ChatCompletionToolType::Function), r#type: Some(dynamo_protocols::types::FunctionType::Function),
function: Some(dynamo_protocols::types::FunctionCallStream { function: Some(dynamo_protocols::types::FunctionCallStream {
name: tool_calls["name"].as_str().map(|s| s.to_string()), name: tool_calls["name"].as_str().map(|s| s.to_string()),
arguments: Some(serde_json::to_string(&tool_calls["arguments"]).unwrap()), arguments: Some(serde_json::to_string(&tool_calls["arguments"]).unwrap()),
...@@ -788,6 +821,10 @@ mod tests { ...@@ -788,6 +821,10 @@ mod tests {
assert!(choice.message.tool_calls.is_some()); assert!(choice.message.tool_calls.is_some());
let tool_calls = choice.message.tool_calls.as_ref().unwrap(); let tool_calls = choice.message.tool_calls.as_ref().unwrap();
assert_eq!(tool_calls.len(), 1); assert_eq!(tool_calls.len(), 1);
assert_eq!(
tool_calls[0].r#type,
dynamo_protocols::types::FunctionType::Function
);
// Most importantly, verify that finish reason was overridden to ToolCalls despite original being Stop // Most importantly, verify that finish reason was overridden to ToolCalls despite original being Stop
assert_eq!( assert_eq!(
...@@ -831,6 +868,10 @@ mod tests { ...@@ -831,6 +868,10 @@ mod tests {
assert!(choice.message.tool_calls.is_some()); assert!(choice.message.tool_calls.is_some());
let tool_calls = choice.message.tool_calls.as_ref().unwrap(); let tool_calls = choice.message.tool_calls.as_ref().unwrap();
assert_eq!(tool_calls.len(), 1); assert_eq!(tool_calls.len(), 1);
assert_eq!(
tool_calls[0].r#type,
dynamo_protocols::types::FunctionType::Function
);
// Verify that finish reason was overridden to ToolCalls despite original being Length // Verify that finish reason was overridden to ToolCalls despite original being Length
assert_eq!( assert_eq!(
...@@ -1073,4 +1114,75 @@ mod tests { ...@@ -1073,4 +1114,75 @@ mod tests {
assert_eq!(tool_calls.len(), 1); assert_eq!(tool_calls.len(), 1);
assert_eq!(tool_calls[0].function.name, "get_weather"); assert_eq!(tool_calls[0].function.name, "get_weather");
} }
/// When the aggregated text consists solely of a `<tool_call>` payload and the
/// `hermes` parser is selected, aggregation should move it into `tool_calls`,
/// leave no message content, and report `ToolCalls` as the finish reason even
/// though the delta itself finished with `Stop`.
#[tokio::test]
async fn test_parses_aggregated_tool_call_text_into_tool_calls() {
    // A single delta whose text is nothing but a tool-call block.
    let delta = create_test_delta(
        0,
        "<tool_call>\n{\"name\":\"get_weather\",\"arguments\":{\"location\":\"SF\"}}\n</tool_call>",
        Some(dynamo_protocols::types::Role::Assistant),
        Some(dynamo_protocols::types::FinishReason::Stop),
        None,
        None,
    );
    let input = Box::pin(stream::iter(vec![delta]));
    let options = ParsingOptions::new(Some("hermes".to_string()), None);

    let outcome = DeltaAggregator::apply(input, options).await;

    assert!(outcome.is_ok());
    let aggregated = outcome.unwrap();
    let first_choice = &aggregated.inner.choices[0];

    // The finish reason is rewritten and the raw text is consumed entirely.
    assert_eq!(
        first_choice.finish_reason,
        Some(dynamo_protocols::types::FinishReason::ToolCalls)
    );
    assert_eq!(first_choice.message.content, None);

    // Exactly one function-typed call with the parsed name and arguments.
    let calls = first_choice.message.tool_calls.as_ref().unwrap();
    assert_eq!(calls.len(), 1);
    let call = &calls[0];
    assert_eq!(call.r#type, dynamo_protocols::types::FunctionType::Function);
    assert_eq!(call.function.name, "get_weather");
    assert_eq!(call.function.arguments, "{\"location\":\"SF\"}");
}
/// When ordinary text precedes a hermes-style `<tool_call>` block, the text must
/// survive as the message content ("hello") while the block is still extracted as a
/// `Function` tool call and the finish reason is switched from `Stop` to `ToolCalls`.
#[tokio::test]
async fn test_preserves_non_tool_content_when_parsing_aggregated_tool_calls() {
    use dynamo_protocols::types::{FinishReason, FunctionType, Role};

    // Plain text first, then the tool-call block on the following line.
    let mixed_payload =
        "hello\n<tool_call>\n{\"name\":\"get_weather\",\"arguments\":{\"location\":\"SF\"}}\n</tool_call>";
    let delta = create_test_delta(
        0,
        mixed_payload,
        Some(Role::Assistant),
        Some(FinishReason::Stop),
        None,
        None,
    );

    // Aggregate the single-delta stream with the "hermes" tool-call parser selected.
    let parsing_options = ParsingOptions::new(Some("hermes".to_string()), None);
    let aggregated =
        DeltaAggregator::apply(Box::pin(stream::iter(vec![delta])), parsing_options).await;
    assert!(aggregated.is_ok());

    let response = aggregated.unwrap();
    let first_choice = &response.inner.choices[0];

    // The non-tool text is kept as the message content.
    let expected_content = Some(ChatCompletionMessageContent::Text("hello".to_string()));
    assert_eq!(first_choice.message.content, expected_content);

    // The finish reason reflects the extracted tool call rather than the original `Stop`.
    assert_eq!(first_choice.finish_reason, Some(FinishReason::ToolCalls));

    // The extracted call is typed as a function call.
    let calls = first_choice.message.tool_calls.as_ref().unwrap();
    assert_eq!(calls[0].r#type, FunctionType::Function);
}
} }
...@@ -4,7 +4,7 @@ ...@@ -4,7 +4,7 @@
use async_stream::stream; use async_stream::stream;
use dynamo_protocols::types::{ use dynamo_protocols::types::{
ChatChoiceLogprobs, ChatChoiceStream, ChatCompletionMessageToolCallChunk, ChatChoiceLogprobs, ChatChoiceStream, ChatCompletionMessageToolCallChunk,
ChatCompletionStreamResponseDelta, FinishReason, FunctionCallStream, Role, ChatCompletionStreamResponseDelta, FinishReason, FunctionCallStream, FunctionType, Role,
}; };
use dynamo_parsers::tool_calling::parsers::get_tool_parser_map; use dynamo_parsers::tool_calling::parsers::get_tool_parser_map;
...@@ -902,7 +902,7 @@ impl JailedStream { ...@@ -902,7 +902,7 @@ impl JailedStream {
.map(|(idx, tool_call)| ChatCompletionMessageToolCallChunk { .map(|(idx, tool_call)| ChatCompletionMessageToolCallChunk {
index: (tool_call_offset + idx) as u32, index: (tool_call_offset + idx) as u32,
id: Some(tool_call.id), id: Some(tool_call.id),
r#type: Some(tool_call.r#type), r#type: Some(FunctionType::Function),
function: Some(FunctionCallStream { function: Some(FunctionCallStream {
name: Some(tool_call.function.name), name: Some(tool_call.function.name),
arguments: Some(tool_call.function.arguments), arguments: Some(tool_call.function.arguments),
...@@ -971,7 +971,7 @@ impl JailedStream { ...@@ -971,7 +971,7 @@ impl JailedStream {
ChatCompletionMessageToolCallChunk { ChatCompletionMessageToolCallChunk {
index, index,
id: Some(format!("call-{}", Uuid::new_v4())), id: Some(format!("call-{}", Uuid::new_v4())),
r#type: Some(dynamo_protocols::types::ChatCompletionToolType::Function), r#type: Some(FunctionType::Function),
function: Some(FunctionCallStream { function: Some(FunctionCallStream {
name: Some(name), name: Some(name),
arguments: Some(arguments), arguments: Some(arguments),
......
...@@ -27,6 +27,7 @@ pub use delta::DeltaGenerator; ...@@ -27,6 +27,7 @@ pub use delta::DeltaGenerator;
#[derive(ToSchema, Serialize, Deserialize, Validate, Debug, Clone)] #[derive(ToSchema, Serialize, Deserialize, Validate, Debug, Clone)]
pub struct NvCreateCompletionRequest { pub struct NvCreateCompletionRequest {
#[serde(flatten)] #[serde(flatten)]
#[schema(value_type = Object)]
pub inner: dynamo_protocols::types::CreateCompletionRequest, pub inner: dynamo_protocols::types::CreateCompletionRequest,
#[serde(flatten)] #[serde(flatten)]
...@@ -47,6 +48,7 @@ pub struct NvCreateCompletionRequest { ...@@ -47,6 +48,7 @@ pub struct NvCreateCompletionRequest {
#[derive(ToSchema, Serialize, Deserialize, Validate, Debug, Clone)] #[derive(ToSchema, Serialize, Deserialize, Validate, Debug, Clone)]
pub struct NvCreateCompletionResponse { pub struct NvCreateCompletionResponse {
#[serde(flatten)] #[serde(flatten)]
#[schema(value_type = Object)]
pub inner: dynamo_protocols::types::CreateCompletionResponse, pub inner: dynamo_protocols::types::CreateCompletionResponse,
#[serde(skip_serializing_if = "Option::is_none")] #[serde(skip_serializing_if = "Option::is_none")]
pub nvext: Option<serde_json::Value>, pub nvext: Option<serde_json::Value>,
......
...@@ -15,6 +15,7 @@ pub use nvext::{NvExt, NvExtProvider}; ...@@ -15,6 +15,7 @@ pub use nvext::{NvExt, NvExtProvider};
#[derive(ToSchema, Serialize, Deserialize, Validate, Debug, Clone)] #[derive(ToSchema, Serialize, Deserialize, Validate, Debug, Clone)]
pub struct NvCreateEmbeddingRequest { pub struct NvCreateEmbeddingRequest {
#[serde(flatten)] #[serde(flatten)]
#[schema(value_type = Object)]
pub inner: dynamo_protocols::types::CreateEmbeddingRequest, pub inner: dynamo_protocols::types::CreateEmbeddingRequest,
#[serde(skip_serializing_if = "Option::is_none")] #[serde(skip_serializing_if = "Option::is_none")]
...@@ -30,6 +31,7 @@ pub struct NvCreateEmbeddingRequest { ...@@ -30,6 +31,7 @@ pub struct NvCreateEmbeddingRequest {
#[derive(ToSchema, Serialize, Deserialize, Validate, Debug, Clone)] #[derive(ToSchema, Serialize, Deserialize, Validate, Debug, Clone)]
pub struct NvCreateEmbeddingResponse { pub struct NvCreateEmbeddingResponse {
#[serde(flatten)] #[serde(flatten)]
#[schema(value_type = Object)]
pub inner: dynamo_protocols::types::CreateEmbeddingResponse, pub inner: dynamo_protocols::types::CreateEmbeddingResponse,
} }
......
...@@ -42,6 +42,11 @@ impl NvImagesResponse { ...@@ -42,6 +42,11 @@ impl NvImagesResponse {
inner: dynamo_protocols::types::ImagesResponse { inner: dynamo_protocols::types::ImagesResponse {
created: 0, created: 0,
data: vec![], data: vec![],
background: None,
output_format: None,
quality: None,
size: None,
usage: None,
}, },
} }
} }
......
...@@ -9,6 +9,7 @@ ...@@ -9,6 +9,7 @@
//! `response.output_text.done` -> `response.content_part.done` -> //! `response.output_text.done` -> `response.content_part.done` ->
//! `response.output_item.done` -> `response.completed` -> `[DONE]` //! `response.output_item.done` -> `response.completed` -> `[DONE]`
use std::collections::HashMap;
use std::time::{SystemTime, UNIX_EPOCH}; use std::time::{SystemTime, UNIX_EPOCH};
use axum::response::sse::Event; use axum::response::sse::Event;
...@@ -121,17 +122,8 @@ impl ResponseStreamConverter { ...@@ -121,17 +122,8 @@ impl ResponseStreamConverter {
output, output,
// Echo request params with spec-required defaults for omitted fields // Echo request params with spec-required defaults for omitted fields
background: Some(false), background: Some(false),
frequency_penalty: Some(0.0), metadata: Some(HashMap::new()),
metadata: Some(serde_json::Value::Object(Default::default())), parallel_tool_calls: self.params.parallel_tool_calls.or(Some(true)),
parallel_tool_calls: Some(true),
presence_penalty: Some(0.0),
// store: false because this branch does not persist responses.
store: self
.api_context
.as_ref()
.map(|ctx| ctx.store)
.or(self.params.store)
.or(Some(false)),
temperature: self.params.temperature.or(Some(1.0)), temperature: self.params.temperature.or(Some(1.0)),
text: Some(self.params.text.clone().unwrap_or(ResponseTextParam { text: Some(self.params.text.clone().unwrap_or(ResponseTextParam {
format: TextResponseFormatConfiguration::Text, format: TextResponseFormatConfiguration::Text,
...@@ -158,7 +150,6 @@ impl ResponseStreamConverter { ...@@ -158,7 +150,6 @@ impl ResponseStreamConverter {
incomplete_details: None, incomplete_details: None,
instructions: self.params.instructions.clone().map(Instructions::Text), instructions: self.params.instructions.clone().map(Instructions::Text),
max_output_tokens: self.params.max_output_tokens, max_output_tokens: self.params.max_output_tokens,
max_tool_calls: None,
previous_response_id: self previous_response_id: self
.api_context .api_context
.as_ref() .as_ref()
...@@ -250,10 +241,11 @@ impl ResponseStreamConverter { ...@@ -250,10 +241,11 @@ impl ResponseStreamConverter {
sequence_number: self.next_seq(), sequence_number: self.next_seq(),
output_index, output_index,
item: OutputItem::Message(OutputMessage { item: OutputItem::Message(OutputMessage {
id: Some(self.message_item_id.clone()), id: self.message_item_id.clone(),
content: vec![], content: vec![],
role: AssistantRole::Assistant, role: AssistantRole::Assistant,
status: Some(OutputStatus::InProgress), phase: None,
status: OutputStatus::InProgress,
}), }),
}, },
); );
...@@ -333,6 +325,7 @@ impl ResponseStreamConverter { ...@@ -333,6 +325,7 @@ impl ResponseStreamConverter {
item: OutputItem::FunctionCall(FunctionToolCall { item: OutputItem::FunctionCall(FunctionToolCall {
id: Some(item_id), id: Some(item_id),
call_id, call_id,
namespace: None,
name: fc_name, name: fc_name,
arguments: String::new(), arguments: String::new(),
status: Some(OutputStatus::InProgress), status: Some(OutputStatus::InProgress),
...@@ -398,6 +391,7 @@ impl ResponseStreamConverter { ...@@ -398,6 +391,7 @@ impl ResponseStreamConverter {
item: OutputItem::FunctionCall(FunctionToolCall { item: OutputItem::FunctionCall(FunctionToolCall {
id: Some(fc_item_id), id: Some(fc_item_id),
call_id: fc_call_id, call_id: fc_call_id,
namespace: None,
name: fc_name, name: fc_name,
arguments: fc_args, arguments: fc_args,
status: Some(OutputStatus::Completed), status: Some(OutputStatus::Completed),
...@@ -450,14 +444,15 @@ impl ResponseStreamConverter { ...@@ -450,14 +444,15 @@ impl ResponseStreamConverter {
sequence_number: self.next_seq(), sequence_number: self.next_seq(),
output_index: self.message_output_index, output_index: self.message_output_index,
item: OutputItem::Message(OutputMessage { item: OutputItem::Message(OutputMessage {
id: Some(self.message_item_id.clone()), id: self.message_item_id.clone(),
content: vec![OutputMessageContent::OutputText(OutputTextContent { content: vec![OutputMessageContent::OutputText(OutputTextContent {
text: self.accumulated_text.clone(), text: self.accumulated_text.clone(),
annotations: vec![], annotations: vec![],
logprobs: Some(vec![]), logprobs: Some(vec![]),
})], })],
role: AssistantRole::Assistant, role: AssistantRole::Assistant,
status: Some(OutputStatus::Completed), phase: None,
status: OutputStatus::Completed,
}), }),
}); });
events.push(make_sse_event(&item_done)); events.push(make_sse_event(&item_done));
...@@ -497,6 +492,7 @@ impl ResponseStreamConverter { ...@@ -497,6 +492,7 @@ impl ResponseStreamConverter {
item: OutputItem::FunctionCall(FunctionToolCall { item: OutputItem::FunctionCall(FunctionToolCall {
id: Some(item_id), id: Some(item_id),
call_id, call_id,
namespace: None,
name: fc_name, name: fc_name,
arguments: accumulated_args, arguments: accumulated_args,
status: Some(OutputStatus::Completed), status: Some(OutputStatus::Completed),
...@@ -509,14 +505,15 @@ impl ResponseStreamConverter { ...@@ -509,14 +505,15 @@ impl ResponseStreamConverter {
let mut output = Vec::new(); let mut output = Vec::new();
if self.message_started { if self.message_started {
output.push(OutputItem::Message(OutputMessage { output.push(OutputItem::Message(OutputMessage {
id: Some(self.message_item_id.clone()), id: self.message_item_id.clone(),
content: vec![OutputMessageContent::OutputText(OutputTextContent { content: vec![OutputMessageContent::OutputText(OutputTextContent {
text: self.accumulated_text.clone(), text: self.accumulated_text.clone(),
annotations: vec![], annotations: vec![],
logprobs: Some(vec![]), logprobs: Some(vec![]),
})], })],
role: AssistantRole::Assistant, role: AssistantRole::Assistant,
status: Some(OutputStatus::Completed), phase: None,
status: OutputStatus::Completed,
})); }));
} }
for fc in &self.function_call_items { for fc in &self.function_call_items {
...@@ -524,6 +521,7 @@ impl ResponseStreamConverter { ...@@ -524,6 +521,7 @@ impl ResponseStreamConverter {
output.push(OutputItem::FunctionCall(FunctionToolCall { output.push(OutputItem::FunctionCall(FunctionToolCall {
id: Some(fc.item_id.clone()), id: Some(fc.item_id.clone()),
call_id: fc.call_id.clone(), call_id: fc.call_id.clone(),
namespace: None,
name: fc.name.clone(), name: fc.name.clone(),
arguments: fc.accumulated_args.clone(), arguments: fc.accumulated_args.clone(),
status: Some(OutputStatus::Completed), status: Some(OutputStatus::Completed),
...@@ -675,7 +673,7 @@ mod tests { ...@@ -675,7 +673,7 @@ mod tests {
use crate::protocols::unified::ResponsesContext; use crate::protocols::unified::ResponsesContext;
use dynamo_protocols::types::{ use dynamo_protocols::types::{
ChatChoiceStream, ChatCompletionMessageContent, ChatCompletionMessageToolCallChunk, ChatChoiceStream, ChatCompletionMessageContent, ChatCompletionMessageToolCallChunk,
ChatCompletionStreamResponseDelta, ChatCompletionToolType, FunctionCallStream, ChatCompletionStreamResponseDelta, FunctionCallStream, FunctionType,
}; };
fn default_params() -> ResponseParams { fn default_params() -> ResponseParams {
...@@ -684,6 +682,7 @@ mod tests { ...@@ -684,6 +682,7 @@ mod tests {
temperature: None, temperature: None,
top_p: None, top_p: None,
max_output_tokens: None, max_output_tokens: None,
parallel_tool_calls: None,
store: None, store: None,
tools: None, tools: None,
tool_choice: None, tool_choice: None,
...@@ -714,7 +713,7 @@ mod tests { ...@@ -714,7 +713,7 @@ mod tests {
tool_calls: Some(vec![ChatCompletionMessageToolCallChunk { tool_calls: Some(vec![ChatCompletionMessageToolCallChunk {
index: tc_index, index: tc_index,
id: id.map(String::from), id: id.map(String::from),
r#type: Some(ChatCompletionToolType::Function), r#type: Some(FunctionType::Function),
function: Some(FunctionCallStream { function: Some(FunctionCallStream {
name: name.map(String::from), name: name.map(String::from),
arguments: args.map(String::from), arguments: args.map(String::from),
...@@ -932,7 +931,7 @@ mod tests { ...@@ -932,7 +931,7 @@ mod tests {
); );
} }
/// Verify that `with_context` populates `previous_response_id` and `store` /// Verify that `with_context` populates `previous_response_id`
/// in the generated Response objects. /// in the generated Response objects.
#[test] #[test]
fn test_with_context_enriches_response() { fn test_with_context_enriches_response() {
...@@ -949,16 +948,14 @@ mod tests { ...@@ -949,16 +948,14 @@ mod tests {
let _ = conv.process_chunk(&text_chunk("Hello")); let _ = conv.process_chunk(&text_chunk("Hello"));
let _end_events = conv.emit_end_events(); let _end_events = conv.emit_end_events();
// Verify the Response object carries the context values through
let response = conv.make_response(Status::Completed, vec![]); let response = conv.make_response(Status::Completed, vec![]);
assert_eq!( assert_eq!(
response.previous_response_id.as_deref(), response.previous_response_id.as_deref(),
Some("resp_prev_123") Some("resp_prev_123")
); );
assert_eq!(response.store, Some(true));
} }
/// Without context, previous_response_id is None and store defaults to false. /// Without context, previous_response_id is None.
#[test] #[test]
fn test_without_context_defaults() { fn test_without_context_defaults() {
let params = ResponseParams::default(); let params = ResponseParams::default();
...@@ -966,6 +963,17 @@ mod tests { ...@@ -966,6 +963,17 @@ mod tests {
let response = conv.make_response(Status::Completed, vec![]); let response = conv.make_response(Status::Completed, vec![]);
assert_eq!(response.previous_response_id, None); assert_eq!(response.previous_response_id, None);
assert_eq!(response.store, Some(false)); }
/// An explicit `parallel_tool_calls` value in the request params must be echoed
/// unchanged in the `Response` built by `make_response`.
#[test]
fn test_stream_response_echoes_parallel_tool_calls() {
// `Some(false)` is used so an echoed value can be told apart from the `Some(true)`
// fallback that the response builder appears to apply when the param is absent.
let params = ResponseParams {
parallel_tool_calls: Some(false),
..Default::default()
};
let conv = ResponseStreamConverter::new("test-model".into(), params);
// No output items are needed; only the echoed request parameter is inspected.
let response = conv.make_response(Status::Completed, vec![]);
assert_eq!(response.parallel_tool_calls, Some(false));
} }
} }
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment