Unverified commit fd5cc288, authored by ishandhanani and committed by GitHub
Browse files

refactor(3/3): switch dynamo-protocols to upstream async-openai types (#7625)


Co-authored-by: Dmitry Tokarev <dtokarev@nvidia.com>
parent d517fb80
This diff is collapsed.
......@@ -44,7 +44,7 @@ dynamo-tokens = { path = "lib/tokens", version = "1.0.0" }
dynamo-memory = { path = "lib/memory", version = "1.0.0" }
dynamo-mocker = { path = "lib/mocker", version = "1.0.0" }
dynamo-kv-router = { path = "lib/kv-router", version = "1.0.0", features = ["metrics", "runtime-protocols"] }
dynamo-protocols = { path = "lib/protocols", version = "1.0.0", features = ["byot"] }
dynamo-protocols = { path = "lib/protocols", version = "1.0.0" }
dynamo-parsers = { path = "lib/parsers", version = "1.0.0" }
fastokens = { version = "0.1.0" }
......
......@@ -75,7 +75,7 @@ class NvCreateImageRequest(BaseModel):
class ImageData(BaseModel):
"""Individual image data in a response.
Matches the flattened Rust Image enum in lib/async-openai/src/types/image.rs.
Matches the flattened Rust Image enum in lib/protocols/src/types/mod.rs.
"""
url: Optional[str] = None
......
......@@ -244,14 +244,16 @@ source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "4288f83726785267c6f2ef073a3d83dc3f9b81464e9f99898240cced85fce35a"
[[package]]
name = "async-openai-macros"
version = "0.1.1"
name = "async-openai"
version = "0.34.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "81872a8e595e8ceceab71c6ba1f9078e313b452a1e31934e6763ef5d308705e4"
checksum = "ec08254d61379df136135d3d1ac04301be7699fd7d9e57655c63ac7d650a6922"
dependencies = [
"proc-macro2",
"quote",
"syn",
"bytes",
"derive_builder",
"getrandom 0.3.4",
"serde",
"serde_json",
]
[[package]]
......@@ -460,20 +462,6 @@ dependencies = [
"tower-service",
]
[[package]]
name = "backoff"
version = "0.4.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "b62ddb9cb1ec0a098ad4bbf9344d0713fa193ae1a80af55febcff2627b6a00c1"
dependencies = [
"futures-core",
"getrandom 0.2.17",
"instant",
"pin-project-lite",
"rand 0.8.5",
"tokio",
]
[[package]]
name = "backon"
version = "1.6.0"
......@@ -1704,26 +1692,14 @@ dependencies = [
name = "dynamo-protocols"
version = "1.0.0"
dependencies = [
"async-openai-macros",
"backoff",
"base64 0.22.1",
"bytes",
"async-openai",
"derive_builder",
"eventsource-stream",
"futures",
"rand 0.9.2",
"reqwest",
"reqwest-eventsource",
"secrecy",
"serde",
"serde_json",
"thiserror 2.0.18",
"tokio",
"tokio-stream",
"tokio-util",
"tracing",
"url",
"utoipa",
"uuid",
]
......@@ -1980,17 +1956,6 @@ dependencies = [
"pin-project-lite",
]
[[package]]
name = "eventsource-stream"
version = "0.2.3"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "74fef4569247a5f429d9156b9d0a2599914385dd189c539334c625d8099d90ab"
dependencies = [
"futures-core",
"nom 7.1.3",
"pin-project-lite",
]
[[package]]
name = "exr"
version = "1.74.0"
......@@ -2271,12 +2236,6 @@ version = "0.3.32"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "037711b3d59c33004d3856fbdc83b99d4ff37a24768fa1be9ce3538a1cde4393"
[[package]]
name = "futures-timer"
version = "3.0.3"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "f288b0a4f20f9a56b5d1da57e2227c661b7b16168e2f72365f57b63326e29b24"
[[package]]
name = "futures-util"
version = "0.3.32"
......@@ -2932,15 +2891,6 @@ dependencies = [
"libc",
]
[[package]]
name = "instant"
version = "0.1.13"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "e0242819d153cba4b4b05a5a8f2a7e9bbf97b6055b2a002b395c96b5ff3c0222"
dependencies = [
"cfg-if",
]
[[package]]
name = "interpolate_name"
version = "0.2.4"
......@@ -5533,22 +5483,6 @@ dependencies = [
"webpki-roots 1.0.6",
]
[[package]]
name = "reqwest-eventsource"
version = "0.6.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "632c55746dbb44275691640e7b40c907c16a2dc1a5842aa98aaec90da6ec6bde"
dependencies = [
"eventsource-stream",
"futures-core",
"futures-timer",
"mime",
"nom 7.1.3",
"pin-project-lite",
"reqwest",
"thiserror 1.0.69",
]
[[package]]
name = "rgb"
version = "0.8.53"
......@@ -5911,7 +5845,6 @@ version = "0.10.3"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "e891af845473308773346dc847b2c23ee78fe442e0472ac50e22a18a93d3ae5a"
dependencies = [
"serde",
"zeroize",
]
......@@ -7360,8 +7293,6 @@ dependencies = [
"quote",
"regex",
"syn",
"url",
"uuid",
]
[[package]]
......
......@@ -131,9 +131,9 @@ dependencies = [
[[package]]
name = "arc-swap"
version = "1.9.0"
version = "1.9.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "a07d1f37ff60921c83bdfc7407723bdefe89b44b98a9b772f225c8f9d67141a6"
checksum = "6a3a1fd6f75306b68087b831f025c712524bcb19aad54e557b1129cfa0a2b207"
dependencies = [
"rustversion",
]
......@@ -244,14 +244,16 @@ source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "4288f83726785267c6f2ef073a3d83dc3f9b81464e9f99898240cced85fce35a"
[[package]]
name = "async-openai-macros"
version = "0.1.1"
name = "async-openai"
version = "0.34.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "81872a8e595e8ceceab71c6ba1f9078e313b452a1e31934e6763ef5d308705e4"
checksum = "ec08254d61379df136135d3d1ac04301be7699fd7d9e57655c63ac7d650a6922"
dependencies = [
"proc-macro2",
"quote",
"syn",
"bytes",
"derive_builder",
"getrandom 0.3.4",
"serde",
"serde_json",
]
[[package]]
......@@ -460,20 +462,6 @@ dependencies = [
"tower-service",
]
[[package]]
name = "backoff"
version = "0.4.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "b62ddb9cb1ec0a098ad4bbf9344d0713fa193ae1a80af55febcff2627b6a00c1"
dependencies = [
"futures-core",
"getrandom 0.2.17",
"instant",
"pin-project-lite",
"rand 0.8.5",
"tokio",
]
[[package]]
name = "backon"
version = "1.6.0"
......@@ -1719,26 +1707,14 @@ dependencies = [
name = "dynamo-protocols"
version = "1.0.0"
dependencies = [
"async-openai-macros",
"backoff",
"base64 0.22.1",
"bytes",
"async-openai",
"derive_builder",
"eventsource-stream",
"futures",
"rand 0.9.2",
"reqwest",
"reqwest-eventsource",
"secrecy",
"serde",
"serde_json",
"thiserror 2.0.18",
"tokio",
"tokio-stream",
"tokio-util",
"tracing",
"url",
"utoipa",
"uuid",
]
......@@ -2027,17 +2003,6 @@ dependencies = [
"pin-project-lite",
]
[[package]]
name = "eventsource-stream"
version = "0.2.3"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "74fef4569247a5f429d9156b9d0a2599914385dd189c539334c625d8099d90ab"
dependencies = [
"futures-core",
"nom 7.1.3",
"pin-project-lite",
]
[[package]]
name = "exr"
version = "1.74.0"
......@@ -2096,11 +2061,11 @@ dependencies = [
[[package]]
name = "fastrand"
version = "2.3.0"
version = "2.4.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "37909eebbb50d72f9059c3b6d82c0463f2ff062c9e95845c43a6c9c0355411be"
checksum = "a043dc74da1e37d6afe657061213aa6f425f855399a11d3463c6ecccc4dfda1f"
dependencies = [
"getrandom 0.2.17",
"getrandom 0.3.4",
]
[[package]]
......@@ -2343,12 +2308,6 @@ version = "0.3.32"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "037711b3d59c33004d3856fbdc83b99d4ff37a24768fa1be9ce3538a1cde4393"
[[package]]
name = "futures-timer"
version = "3.0.3"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "f288b0a4f20f9a56b5d1da57e2227c661b7b16168e2f72365f57b63326e29b24"
[[package]]
name = "futures-util"
version = "0.3.32"
......@@ -3004,15 +2963,6 @@ dependencies = [
"libc",
]
[[package]]
name = "instant"
version = "0.1.13"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "e0242819d153cba4b4b05a5a8f2a7e9bbf97b6055b2a002b395c96b5ff3c0222"
dependencies = [
"cfg-if",
]
[[package]]
name = "interpolate_name"
version = "0.2.4"
......@@ -5603,22 +5553,6 @@ dependencies = [
"webpki-roots 1.0.6",
]
[[package]]
name = "reqwest-eventsource"
version = "0.6.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "632c55746dbb44275691640e7b40c907c16a2dc1a5842aa98aaec90da6ec6bde"
dependencies = [
"eventsource-stream",
"futures-core",
"futures-timer",
"mime",
"nom 7.1.3",
"pin-project-lite",
"reqwest",
"thiserror 1.0.69",
]
[[package]]
name = "rgb"
version = "0.8.53"
......@@ -5981,7 +5915,6 @@ version = "0.10.3"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "e891af845473308773346dc847b2c23ee78fe442e0472ac50e22a18a93d3ae5a"
dependencies = [
"serde",
"zeroize",
]
......@@ -7430,8 +7363,6 @@ dependencies = [
"quote",
"regex",
"syn",
"url",
"uuid",
]
[[package]]
......
......@@ -182,7 +182,7 @@ pub fn final_response_to_one_chunk_stream(
// Convert FunctionCall to FunctionCallStream if present
#[allow(deprecated)]
let function_call = ch.message.function_call.as_ref().map(|fc| {
dynamo_protocols::types::FunctionCallStream {
dynamo_protocols::types::ChatCompletionStreamResponseDeltaFunctionCall {
name: Some(fc.name.clone()),
arguments: Some(fc.arguments.clone()),
}
......@@ -197,7 +197,7 @@ pub fn final_response_to_one_chunk_stream(
|(i, call)| dynamo_protocols::types::ChatCompletionMessageToolCallChunk {
index: i as u32,
id: Some(call.id.clone()),
r#type: Some(call.r#type.clone()),
r#type: Some(dynamo_protocols::types::FunctionType::Function),
function: Some(dynamo_protocols::types::FunctionCallStream {
name: Some(call.function.name.clone()),
arguments: Some(call.function.arguments.clone()),
......
// SPDX-FileCopyrightText: Copyright (c) 2024-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
// SPDX-License-Identifier: Apache-2.0
pub mod client;
pub mod service;
This diff is collapsed.
......@@ -1544,6 +1544,7 @@ async fn responses(
temperature: request.inner.temperature,
top_p: request.inner.top_p,
max_output_tokens: request.inner.max_output_tokens,
parallel_tool_calls: request.inner.parallel_tool_calls,
store: request.inner.store,
tools: request.inner.tools.clone(),
tool_choice: request.inner.tool_choice.clone(),
......@@ -1788,11 +1789,6 @@ pub fn validate_response_unsupported_fields(
VALIDATION_PREFIX.to_string() + "`prompt` is not supported.",
));
}
if inner.store == Some(true) {
return Some(ErrorMessage::not_implemented_error(
VALIDATION_PREFIX.to_string() + "`store: true` is not supported.",
));
}
None
}
......@@ -1965,6 +1961,9 @@ async fn images(
.map(|m| match m {
dynamo_protocols::types::ImageModel::DallE2 => "dall-e-2".to_string(),
dynamo_protocols::types::ImageModel::DallE3 => "dall-e-3".to_string(),
dynamo_protocols::types::ImageModel::GptImage1 => "gpt-image-1".to_string(),
dynamo_protocols::types::ImageModel::GptImage1dot5 => "gpt-image-1.5".to_string(),
dynamo_protocols::types::ImageModel::GptImage1Mini => "gpt-image-1-mini".to_string(),
dynamo_protocols::types::ImageModel::Other(s) => s.clone(),
})
.unwrap_or_else(|| "diffusion".to_string());
......@@ -2540,6 +2539,17 @@ mod tests {
assert!(result.is_none(), "parallel_tool_calls should be supported");
}
/// A request with `store: Some(true)` must pass unsupported-field validation
/// (the validator is expected to return `None`).
#[test]
fn test_validate_unsupported_fields_accepts_store() {
    let mut req = make_base_request();
    req.inner.store = Some(true);

    let rejection = validate_response_unsupported_fields(&req);

    assert!(
        rejection.is_none(),
        "store should be supported for audit opt-in"
    );
}
#[test]
fn test_validate_unsupported_fields_detects_flags() {
#[allow(clippy::type_complexity)]
......@@ -2559,7 +2569,6 @@ mod tests {
})
}),
),
("store", Box::new(|r| r.store = Some(true))),
];
for (field, set_field) in unsupported_cases {
......@@ -3290,8 +3299,7 @@ mod tests {
use dynamo_protocols::types::{
ChatChoiceStream, ChatCompletionMessageToolCallChunk, ChatCompletionStreamResponseDelta,
ChatCompletionToolType, CreateChatCompletionStreamResponse, FinishReason,
FunctionCallStream,
CreateChatCompletionStreamResponse, FinishReason, FunctionCallStream, FunctionType,
};
use dynamo_runtime::protocols::annotated::Annotated;
......@@ -3444,7 +3452,7 @@ mod tests {
let tool_call = ChatCompletionMessageToolCallChunk {
index: 0,
id: id.map(|s| s.to_string()),
r#type: Some(ChatCompletionToolType::Function),
r#type: Some(FunctionType::Function),
function: Some(FunctionCallStream {
name: name.map(|s| s.to_string()),
arguments: arguments.map(|s| s.to_string()),
......@@ -3537,7 +3545,7 @@ mod tests {
let tc1 = ChatCompletionMessageToolCallChunk {
index: 0,
id: Some("call_1".to_string()),
r#type: Some(ChatCompletionToolType::Function),
r#type: Some(FunctionType::Function),
function: Some(FunctionCallStream {
name: Some("get_weather".to_string()),
arguments: Some(r#"{"city":"Paris"}"#.to_string()),
......@@ -3546,7 +3554,7 @@ mod tests {
let tc2 = ChatCompletionMessageToolCallChunk {
index: 1,
id: Some("call_2".to_string()),
r#type: Some(ChatCompletionToolType::Function),
r#type: Some(FunctionType::Function),
function: Some(FunctionCallStream {
name: Some("get_time".to_string()),
arguments: Some(r#"{"tz":"UTC"}"#.to_string()),
......@@ -3609,7 +3617,7 @@ mod tests {
let complete = ChatCompletionMessageToolCallChunk {
index: 0,
id: Some("call_complete".to_string()),
r#type: Some(ChatCompletionToolType::Function),
r#type: Some(FunctionType::Function),
function: Some(FunctionCallStream {
name: Some("get_weather".to_string()),
arguments: Some(r#"{"city":"Paris"}"#.to_string()),
......@@ -3618,7 +3626,7 @@ mod tests {
let incomplete = ChatCompletionMessageToolCallChunk {
index: 1,
id: Some("call_partial".to_string()),
r#type: Some(ChatCompletionToolType::Function),
r#type: Some(FunctionType::Function),
function: Some(FunctionCallStream {
name: Some("search".to_string()),
arguments: None, // still streaming
......@@ -3658,7 +3666,7 @@ mod tests {
let tool_call = ChatCompletionMessageToolCallChunk {
index: 0,
id: Some("call_999".to_string()),
r#type: Some(ChatCompletionToolType::Function),
r#type: Some(FunctionType::Function),
function: None,
};
#[allow(deprecated)]
......
......@@ -947,7 +947,6 @@ mod tests {
fn create_mock_response_with_logprobs(
token_logprobs: Vec<ChatCompletionTokenLogprob>,
) -> NvCreateChatCompletionStreamResponse {
#[expect(deprecated)]
NvCreateChatCompletionStreamResponse {
inner: dynamo_protocols::types::CreateChatCompletionStreamResponse {
id: "test_id".to_string(),
......@@ -984,7 +983,6 @@ mod tests {
fn create_mock_response_with_multiple_choices(
choices_logprobs: Vec<Vec<ChatCompletionTokenLogprob>>,
) -> NvCreateChatCompletionStreamResponse {
#[expect(deprecated)]
let choices = choices_logprobs
.into_iter()
.enumerate()
......@@ -1339,7 +1337,6 @@ mod tests {
#[test]
fn test_logprob_extractor_with_missing_data() {
// Test with choice that has no logprobs
#[expect(deprecated)]
let response = NvCreateChatCompletionStreamResponse {
inner: dynamo_protocols::types::CreateChatCompletionStreamResponse {
id: "test_id".to_string(),
......
......@@ -732,7 +732,7 @@ mod tests {
use super::*;
use dynamo_protocols::types::{
ChatChoiceStream, ChatCompletionMessageContent, ChatCompletionMessageToolCallChunk,
ChatCompletionStreamResponseDelta, ChatCompletionToolType, FunctionCallStream,
ChatCompletionStreamResponseDelta, FunctionCallStream, FunctionType,
};
fn text_chunk(text: &str) -> NvCreateChatCompletionStreamResponse {
......@@ -783,7 +783,7 @@ mod tests {
tool_calls: Some(vec![ChatCompletionMessageToolCallChunk {
index: tc_index,
id: id.map(String::from),
r#type: Some(ChatCompletionToolType::Function),
r#type: Some(FunctionType::Function),
function: Some(FunctionCallStream {
name: name.map(String::from),
arguments: args.map(String::from),
......
......@@ -20,7 +20,7 @@ use dynamo_protocols::types::{
ChatCompletionRequestToolMessageContent, ChatCompletionRequestUserMessage,
ChatCompletionRequestUserMessageContent, ChatCompletionRequestUserMessageContentPart,
ChatCompletionTool, ChatCompletionToolChoiceOption, ChatCompletionToolType, FunctionName,
FunctionObject, ImageUrl, ReasoningContent,
FunctionObject, FunctionType, ImageUrl, ReasoningContent,
};
use uuid::Uuid;
......@@ -312,7 +312,7 @@ fn convert_assistant_blocks(
segments.push(std::mem::take(&mut pending_reasoning));
tool_calls.push(ChatCompletionMessageToolCall {
id: id.clone(),
r#type: ChatCompletionToolType::Function,
r#type: FunctionType::Function,
function: dynamo_protocols::types::FunctionCall {
name: name.clone(),
arguments: serde_json::to_string(input).unwrap_or_default(),
......
......@@ -35,6 +35,7 @@ pub use delta::DeltaGenerator;
#[derive(ToSchema, Serialize, Deserialize, Validate, Debug, Clone)]
pub struct NvCreateChatCompletionRequest {
#[serde(flatten)]
#[schema(value_type = Object)]
pub inner: dynamo_protocols::types::CreateChatCompletionRequest,
#[serde(flatten, default)]
......
......@@ -4,6 +4,8 @@
use futures::{Stream, StreamExt};
use std::collections::HashMap;
use dynamo_parsers::tool_calling::try_tool_call_parse_aggregate;
use super::{NvCreateChatCompletionResponse, NvCreateChatCompletionStreamResponse};
use crate::protocols::{
Annotated,
......@@ -75,11 +77,11 @@ fn convert_tool_chunk_to_message_tool_call(
chunk: &dynamo_protocols::types::ChatCompletionMessageToolCallChunk,
) -> Option<dynamo_protocols::types::ChatCompletionMessageToolCall> {
// Convert ChatCompletionMessageToolCallChunk to ChatCompletionMessageToolCall
if let (Some(id), Some(r#type), Some(function)) = (&chunk.id, &chunk.r#type, &chunk.function) {
if let (Some(id), Some(function)) = (&chunk.id, &chunk.function) {
if let (Some(name), Some(arguments)) = (&function.name, &function.arguments) {
Some(dynamo_protocols::types::ChatCompletionMessageToolCall {
id: id.clone(),
r#type: r#type.clone(),
r#type: dynamo_protocols::types::FunctionType::Function,
function: dynamo_protocols::types::FunctionCall {
name: name.clone(),
arguments: arguments.clone(),
......@@ -120,9 +122,9 @@ impl DeltaAggregator {
/// * `Err(String)` if an error occurs during processing.
pub async fn apply(
stream: impl Stream<Item = Annotated<NvCreateChatCompletionStreamResponse>>,
_parsing_options: ParsingOptions,
parsing_options: ParsingOptions,
) -> Result<NvCreateChatCompletionResponse, String> {
let aggregator = stream
let mut aggregator = stream
.fold(DeltaAggregator::new(), |mut aggregator, delta| async move {
// Attempt to unwrap the delta, capturing any errors.
let delta = match delta.ok() {
......@@ -256,6 +258,37 @@ impl DeltaAggregator {
return Err(error);
}
if let Some(parser) = parsing_options.tool_call_parser.as_deref() {
for choice in aggregator.choices.values_mut() {
if choice
.tool_calls
.as_ref()
.is_some_and(|calls| !calls.is_empty())
|| choice.text.is_empty()
{
continue;
}
let (tool_calls, content) =
match try_tool_call_parse_aggregate(&choice.text, Some(parser), None).await {
Ok(result) => result,
Err(error) => {
tracing::debug!(
error = %error,
parser,
"failed to parse aggregated chat tool calls"
);
continue;
}
};
if !tool_calls.is_empty() {
choice.tool_calls = Some(tool_calls);
choice.text = content.unwrap_or_default();
}
}
}
// Extract aggregated choices and sort them by index.
let mut choices: Vec<_> = aggregator
.choices
......@@ -405,7 +438,7 @@ mod tests {
dynamo_protocols::types::ChatCompletionMessageToolCallChunk {
index: 0,
id: Some("test_id".to_string()),
r#type: Some(dynamo_protocols::types::ChatCompletionToolType::Function),
r#type: Some(dynamo_protocols::types::FunctionType::Function),
function: Some(dynamo_protocols::types::FunctionCallStream {
name: tool_calls["name"].as_str().map(|s| s.to_string()),
arguments: Some(serde_json::to_string(&tool_calls["arguments"]).unwrap()),
......@@ -788,6 +821,10 @@ mod tests {
assert!(choice.message.tool_calls.is_some());
let tool_calls = choice.message.tool_calls.as_ref().unwrap();
assert_eq!(tool_calls.len(), 1);
assert_eq!(
tool_calls[0].r#type,
dynamo_protocols::types::FunctionType::Function
);
// Most importantly, verify that finish reason was overridden to ToolCalls despite original being Stop
assert_eq!(
......@@ -831,6 +868,10 @@ mod tests {
assert!(choice.message.tool_calls.is_some());
let tool_calls = choice.message.tool_calls.as_ref().unwrap();
assert_eq!(tool_calls.len(), 1);
assert_eq!(
tool_calls[0].r#type,
dynamo_protocols::types::FunctionType::Function
);
// Verify that finish reason was overridden to ToolCalls despite original being Length
assert_eq!(
......@@ -1073,4 +1114,75 @@ mod tests {
assert_eq!(tool_calls.len(), 1);
assert_eq!(tool_calls[0].function.name, "get_weather");
}
/// Aggregating a single delta whose text is one `<tool_call>` block, with the
/// "hermes" parser selected, must yield one structured tool call, no message
/// content, and a `ToolCalls` finish reason.
#[tokio::test]
async fn test_parses_aggregated_tool_call_text_into_tool_calls() {
    // One assistant delta carrying the raw tool-call markup, finished with `Stop`.
    let delta = create_test_delta(
        0,
        "<tool_call>\n{\"name\":\"get_weather\",\"arguments\":{\"location\":\"SF\"}}\n</tool_call>",
        Some(dynamo_protocols::types::Role::Assistant),
        Some(dynamo_protocols::types::FinishReason::Stop),
        None,
        None,
    );
    let input = Box::pin(stream::iter(vec![delta]));
    let options = ParsingOptions::new(Some("hermes".to_string()), None);

    let outcome = DeltaAggregator::apply(input, options).await;
    assert!(outcome.is_ok());

    let aggregated = outcome.unwrap();
    let first_choice = &aggregated.inner.choices[0];

    // The finish reason is rewritten and the markup is removed from the content.
    assert_eq!(
        first_choice.finish_reason,
        Some(dynamo_protocols::types::FinishReason::ToolCalls)
    );
    assert_eq!(first_choice.message.content, None);

    let parsed_calls = first_choice.message.tool_calls.as_ref().unwrap();
    assert_eq!(parsed_calls.len(), 1);

    let call = &parsed_calls[0];
    assert_eq!(call.r#type, dynamo_protocols::types::FunctionType::Function);
    assert_eq!(call.function.name, "get_weather");
    assert_eq!(call.function.arguments, "{\"location\":\"SF\"}");
}
/// When plain text precedes the `<tool_call>` block, the text must survive as
/// the message content while the tool call is still extracted.
#[tokio::test]
async fn test_preserves_non_tool_content_when_parsing_aggregated_tool_calls() {
    // "hello" followed by a tool-call block, all in one assistant delta.
    let delta = create_test_delta(
        0,
        "hello\n<tool_call>\n{\"name\":\"get_weather\",\"arguments\":{\"location\":\"SF\"}}\n</tool_call>",
        Some(dynamo_protocols::types::Role::Assistant),
        Some(dynamo_protocols::types::FinishReason::Stop),
        None,
        None,
    );
    let input = Box::pin(stream::iter(vec![delta]));
    let options = ParsingOptions::new(Some("hermes".to_string()), None);

    let outcome = DeltaAggregator::apply(input, options).await;
    assert!(outcome.is_ok());

    let aggregated = outcome.unwrap();
    let first_choice = &aggregated.inner.choices[0];

    // Only the non-tool text remains as content.
    assert_eq!(
        first_choice.message.content,
        Some(ChatCompletionMessageContent::Text("hello".to_string()))
    );
    assert_eq!(
        first_choice.finish_reason,
        Some(dynamo_protocols::types::FinishReason::ToolCalls)
    );

    let parsed_calls = first_choice.message.tool_calls.as_ref().unwrap();
    assert_eq!(
        parsed_calls[0].r#type,
        dynamo_protocols::types::FunctionType::Function
    );
}
}
......@@ -4,7 +4,7 @@
use async_stream::stream;
use dynamo_protocols::types::{
ChatChoiceLogprobs, ChatChoiceStream, ChatCompletionMessageToolCallChunk,
ChatCompletionStreamResponseDelta, FinishReason, FunctionCallStream, Role,
ChatCompletionStreamResponseDelta, FinishReason, FunctionCallStream, FunctionType, Role,
};
use dynamo_parsers::tool_calling::parsers::get_tool_parser_map;
......@@ -902,7 +902,7 @@ impl JailedStream {
.map(|(idx, tool_call)| ChatCompletionMessageToolCallChunk {
index: (tool_call_offset + idx) as u32,
id: Some(tool_call.id),
r#type: Some(tool_call.r#type),
r#type: Some(FunctionType::Function),
function: Some(FunctionCallStream {
name: Some(tool_call.function.name),
arguments: Some(tool_call.function.arguments),
......@@ -971,7 +971,7 @@ impl JailedStream {
ChatCompletionMessageToolCallChunk {
index,
id: Some(format!("call-{}", Uuid::new_v4())),
r#type: Some(dynamo_protocols::types::ChatCompletionToolType::Function),
r#type: Some(FunctionType::Function),
function: Some(FunctionCallStream {
name: Some(name),
arguments: Some(arguments),
......
......@@ -27,6 +27,7 @@ pub use delta::DeltaGenerator;
#[derive(ToSchema, Serialize, Deserialize, Validate, Debug, Clone)]
pub struct NvCreateCompletionRequest {
#[serde(flatten)]
#[schema(value_type = Object)]
pub inner: dynamo_protocols::types::CreateCompletionRequest,
#[serde(flatten)]
......@@ -47,6 +48,7 @@ pub struct NvCreateCompletionRequest {
#[derive(ToSchema, Serialize, Deserialize, Validate, Debug, Clone)]
pub struct NvCreateCompletionResponse {
#[serde(flatten)]
#[schema(value_type = Object)]
pub inner: dynamo_protocols::types::CreateCompletionResponse,
#[serde(skip_serializing_if = "Option::is_none")]
pub nvext: Option<serde_json::Value>,
......
......@@ -15,6 +15,7 @@ pub use nvext::{NvExt, NvExtProvider};
#[derive(ToSchema, Serialize, Deserialize, Validate, Debug, Clone)]
pub struct NvCreateEmbeddingRequest {
#[serde(flatten)]
#[schema(value_type = Object)]
pub inner: dynamo_protocols::types::CreateEmbeddingRequest,
#[serde(skip_serializing_if = "Option::is_none")]
......@@ -30,6 +31,7 @@ pub struct NvCreateEmbeddingRequest {
#[derive(ToSchema, Serialize, Deserialize, Validate, Debug, Clone)]
pub struct NvCreateEmbeddingResponse {
// Flattened: the wrapped response's fields are (de)serialized at the top level
// of this struct rather than under an `inner` key.
#[serde(flatten)]
// Overrides the generated OpenAPI schema for this field with a generic object.
// NOTE(review): presumably needed because the wrapped type does not implement
// `ToSchema` — confirm against the `dynamo_protocols` type definitions.
#[schema(value_type = Object)]
pub inner: dynamo_protocols::types::CreateEmbeddingResponse,
}
......
......@@ -42,6 +42,11 @@ impl NvImagesResponse {
inner: dynamo_protocols::types::ImagesResponse {
created: 0,
data: vec![],
background: None,
output_format: None,
quality: None,
size: None,
usage: None,
},
}
}
......
......@@ -9,6 +9,7 @@
//! `response.output_text.done` -> `response.content_part.done` ->
//! `response.output_item.done` -> `response.completed` -> `[DONE]`
use std::collections::HashMap;
use std::time::{SystemTime, UNIX_EPOCH};
use axum::response::sse::Event;
......@@ -121,17 +122,8 @@ impl ResponseStreamConverter {
output,
// Echo request params with spec-required defaults for omitted fields
background: Some(false),
frequency_penalty: Some(0.0),
metadata: Some(serde_json::Value::Object(Default::default())),
parallel_tool_calls: Some(true),
presence_penalty: Some(0.0),
// store: false because this branch does not persist responses.
store: self
.api_context
.as_ref()
.map(|ctx| ctx.store)
.or(self.params.store)
.or(Some(false)),
metadata: Some(HashMap::new()),
parallel_tool_calls: self.params.parallel_tool_calls.or(Some(true)),
temperature: self.params.temperature.or(Some(1.0)),
text: Some(self.params.text.clone().unwrap_or(ResponseTextParam {
format: TextResponseFormatConfiguration::Text,
......@@ -158,7 +150,6 @@ impl ResponseStreamConverter {
incomplete_details: None,
instructions: self.params.instructions.clone().map(Instructions::Text),
max_output_tokens: self.params.max_output_tokens,
max_tool_calls: None,
previous_response_id: self
.api_context
.as_ref()
......@@ -250,10 +241,11 @@ impl ResponseStreamConverter {
sequence_number: self.next_seq(),
output_index,
item: OutputItem::Message(OutputMessage {
id: Some(self.message_item_id.clone()),
id: self.message_item_id.clone(),
content: vec![],
role: AssistantRole::Assistant,
status: Some(OutputStatus::InProgress),
phase: None,
status: OutputStatus::InProgress,
}),
},
);
......@@ -333,6 +325,7 @@ impl ResponseStreamConverter {
item: OutputItem::FunctionCall(FunctionToolCall {
id: Some(item_id),
call_id,
namespace: None,
name: fc_name,
arguments: String::new(),
status: Some(OutputStatus::InProgress),
......@@ -398,6 +391,7 @@ impl ResponseStreamConverter {
item: OutputItem::FunctionCall(FunctionToolCall {
id: Some(fc_item_id),
call_id: fc_call_id,
namespace: None,
name: fc_name,
arguments: fc_args,
status: Some(OutputStatus::Completed),
......@@ -450,14 +444,15 @@ impl ResponseStreamConverter {
sequence_number: self.next_seq(),
output_index: self.message_output_index,
item: OutputItem::Message(OutputMessage {
id: Some(self.message_item_id.clone()),
id: self.message_item_id.clone(),
content: vec![OutputMessageContent::OutputText(OutputTextContent {
text: self.accumulated_text.clone(),
annotations: vec![],
logprobs: Some(vec![]),
})],
role: AssistantRole::Assistant,
status: Some(OutputStatus::Completed),
phase: None,
status: OutputStatus::Completed,
}),
});
events.push(make_sse_event(&item_done));
......@@ -497,6 +492,7 @@ impl ResponseStreamConverter {
item: OutputItem::FunctionCall(FunctionToolCall {
id: Some(item_id),
call_id,
namespace: None,
name: fc_name,
arguments: accumulated_args,
status: Some(OutputStatus::Completed),
......@@ -509,14 +505,15 @@ impl ResponseStreamConverter {
let mut output = Vec::new();
if self.message_started {
output.push(OutputItem::Message(OutputMessage {
id: Some(self.message_item_id.clone()),
id: self.message_item_id.clone(),
content: vec![OutputMessageContent::OutputText(OutputTextContent {
text: self.accumulated_text.clone(),
annotations: vec![],
logprobs: Some(vec![]),
})],
role: AssistantRole::Assistant,
status: Some(OutputStatus::Completed),
phase: None,
status: OutputStatus::Completed,
}));
}
for fc in &self.function_call_items {
......@@ -524,6 +521,7 @@ impl ResponseStreamConverter {
output.push(OutputItem::FunctionCall(FunctionToolCall {
id: Some(fc.item_id.clone()),
call_id: fc.call_id.clone(),
namespace: None,
name: fc.name.clone(),
arguments: fc.accumulated_args.clone(),
status: Some(OutputStatus::Completed),
......@@ -675,7 +673,7 @@ mod tests {
use crate::protocols::unified::ResponsesContext;
use dynamo_protocols::types::{
ChatChoiceStream, ChatCompletionMessageContent, ChatCompletionMessageToolCallChunk,
ChatCompletionStreamResponseDelta, ChatCompletionToolType, FunctionCallStream,
ChatCompletionStreamResponseDelta, FunctionCallStream, FunctionType,
};
fn default_params() -> ResponseParams {
......@@ -684,6 +682,7 @@ mod tests {
temperature: None,
top_p: None,
max_output_tokens: None,
parallel_tool_calls: None,
store: None,
tools: None,
tool_choice: None,
......@@ -714,7 +713,7 @@ mod tests {
tool_calls: Some(vec![ChatCompletionMessageToolCallChunk {
index: tc_index,
id: id.map(String::from),
r#type: Some(ChatCompletionToolType::Function),
r#type: Some(FunctionType::Function),
function: Some(FunctionCallStream {
name: name.map(String::from),
arguments: args.map(String::from),
......@@ -932,7 +931,7 @@ mod tests {
);
}
/// Verify that `with_context` populates `previous_response_id` and `store`
/// Verify that `with_context` populates `previous_response_id`
/// in the generated Response objects.
#[test]
fn test_with_context_enriches_response() {
......@@ -949,16 +948,14 @@ mod tests {
let _ = conv.process_chunk(&text_chunk("Hello"));
let _end_events = conv.emit_end_events();
// Verify the Response object carries the context values through
let response = conv.make_response(Status::Completed, vec![]);
assert_eq!(
response.previous_response_id.as_deref(),
Some("resp_prev_123")
);
assert_eq!(response.store, Some(true));
}
/// Without context, previous_response_id is None and store defaults to false.
/// Without context, previous_response_id is None.
#[test]
fn test_without_context_defaults() {
let params = ResponseParams::default();
......@@ -966,6 +963,17 @@ mod tests {
let response = conv.make_response(Status::Completed, vec![]);
assert_eq!(response.previous_response_id, None);
assert_eq!(response.store, Some(false));
}
/// A `parallel_tool_calls: Some(false)` request parameter must be echoed
/// unchanged in the `Response` built by the converter.
#[test]
fn test_stream_response_echoes_parallel_tool_calls() {
    let converter = ResponseStreamConverter::new(
        "test-model".into(),
        ResponseParams {
            parallel_tool_calls: Some(false),
            ..Default::default()
        },
    );

    let built = converter.make_response(Status::Completed, Vec::new());

    assert_eq!(built.parallel_tool_calls, Some(false));
}
}
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment