Unverified commit 8cb47d04, authored by MatejKosec, committed by GitHub
Browse files

feat: responses API compliance with upstream type alignment (#6089)


Signed-off-by: Matej Kosec <mkosec@nvidia.com>
Co-authored-by: Ishan Dhanani <ishandhanani@gmail.com>
parent f8d0a9f9
// SPDX-FileCopyrightText: Copyright (c) 2024-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
// SPDX-License-Identifier: Apache-2.0
pub mod stream_converter;
use dynamo_async_openai::types::responses::{
AssistantRole, FunctionCallOutput, FunctionToolCall, InputContent, InputItem, InputParam,
InputRole, Instructions, Item, MessageItem, OutputItem, OutputMessage, OutputMessageContent,
OutputStatus, OutputTextContent, Response, ResponseTextParam, Role as ResponseRole,
ServiceTier, Status, TextResponseFormatConfiguration, Tool, ToolChoiceOptions, ToolChoiceParam,
Truncation,
};
use dynamo_async_openai::types::{
ChatCompletionMessageToolCall, ChatCompletionNamedToolChoice,
ChatCompletionRequestAssistantMessage, ChatCompletionRequestAssistantMessageContent,
ChatCompletionRequestMessage, ChatCompletionRequestMessageContentPartImage,
ChatCompletionRequestMessageContentPartText, ChatCompletionRequestMessageContentPartVideo,
ChatCompletionRequestSystemMessage, ChatCompletionRequestSystemMessageContent,
ChatCompletionRequestToolMessage, ChatCompletionRequestToolMessageContent,
ChatCompletionRequestUserMessage, ChatCompletionRequestUserMessageContent,
ChatCompletionRequestUserMessageContentPart, ChatCompletionTool,
ChatCompletionToolChoiceOption, ChatCompletionToolType, CreateChatCompletionRequest,
FunctionName, FunctionObject, ImageDetail as ChatImageDetail, ImageUrl, VideoUrl,
};
use dynamo_runtime::protocols::annotated::AnnotationsProvider;
use serde::{Deserialize, Serialize};
use utoipa::ToSchema;
use uuid::Uuid;
use validator::Validate;
use super::chat_completions::{NvCreateChatCompletionRequest, NvCreateChatCompletionResponse};
use super::nvext::{NvExt, NvExtProvider};
use super::{OpenAISamplingOptionsProvider, OpenAIStopConditionsProvider};
/// Request wrapper for the Responses API: the upstream `CreateResponse`
/// payload plus an optional NVIDIA extension block.
#[derive(ToSchema, Serialize, Deserialize, Validate, Debug, Clone)]
pub struct NvCreateResponse {
/// Flattened CreateResponse fields (model, input, temperature, etc.)
#[serde(flatten)]
pub inner: dynamo_async_openai::types::responses::CreateResponse,
/// Optional NVIDIA-specific extensions; left out of the serialized form when `None`.
#[serde(skip_serializing_if = "Option::is_none")]
pub nvext: Option<NvExt>,
}
/// Response wrapper for the Responses API: the upstream `Response` object
/// plus an optional, free-form NVIDIA extension value.
#[derive(ToSchema, Serialize, Deserialize, Validate, Debug, Clone)]
pub struct NvResponse {
/// Flattened Response fields.
#[serde(flatten)]
pub inner: dynamo_async_openai::types::responses::Response,
/// NVIDIA extension field for response metadata (worker IDs, etc.)
/// Left out of the serialized form when `None`.
#[serde(skip_serializing_if = "Option::is_none")]
pub nvext: Option<serde_json::Value>,
}
/// `NvExtProvider` implementation for `NvCreateResponse`: exposes the
/// request's NVIDIA-specific extension block.
impl NvExtProvider for NvCreateResponse {
    /// Borrow the `nvext` block, if the request carried one.
    fn nvext(&self) -> Option<&NvExt> {
        self.nvext.as_ref()
    }

    /// Always `None`: this implementation does not expose a raw prompt.
    fn raw_prompt(&self) -> Option<String> {
        None
    }
}
/// Implements `AnnotationsProvider` for `NvCreateResponse`,
/// enabling retrieval and lookup of request annotations stored in `nvext`.
impl AnnotationsProvider for NvCreateResponse {
    /// Return an owned copy of the annotation list, or `None` when the request
    /// has no `nvext` block or the block has no annotations.
    fn annotations(&self) -> Option<Vec<String>> {
        self.nvext.as_ref()?.annotations.clone()
    }

    /// Report whether `annotation` is present in the request's annotation list.
    ///
    /// Returns `false` when there is no `nvext` block or no annotation list.
    fn has_annotation(&self, annotation: &str) -> bool {
        self.nvext
            .as_ref()
            .and_then(|nvext| nvext.annotations.as_deref())
            // Compare borrowed strings directly; no temporary `String` is
            // allocated per lookup.
            .is_some_and(|annotations| annotations.iter().any(|a| a == annotation))
    }
}
/// Sampling options for a Responses API request.
///
/// Only `temperature` and `top_p` are read from the request; every other
/// option is reported as unset.
impl OpenAISamplingOptionsProvider for NvCreateResponse {
    /// Borrow the NVIDIA extension block, if any.
    fn nvext(&self) -> Option<&NvExt> {
        self.nvext.as_ref()
    }

    /// Temperature as given on the request.
    fn get_temperature(&self) -> Option<f32> {
        self.inner.temperature
    }

    /// Nucleus-sampling `top_p` as given on the request.
    fn get_top_p(&self) -> Option<f32> {
        self.inner.top_p
    }

    // The remaining options are not read from the request and are always unset.

    fn get_frequency_penalty(&self) -> Option<f32> {
        None
    }

    fn get_presence_penalty(&self) -> Option<f32> {
        None
    }

    fn get_seed(&self) -> Option<i64> {
        None
    }

    fn get_n(&self) -> Option<u8> {
        None
    }

    fn get_best_of(&self) -> Option<u8> {
        None
    }
}
/// Stop conditions for a Responses API request.
///
/// Only the token budget (`max_output_tokens`) is read from the request;
/// minimum tokens and stop sequences are always unset.
impl OpenAIStopConditionsProvider for NvCreateResponse {
    /// Borrow the NVIDIA extension block, if any.
    fn nvext(&self) -> Option<&NvExt> {
        self.nvext.as_ref()
    }

    /// Token budget, taken from the request's `max_output_tokens`.
    // NOTE(review): the `allow(deprecated)` presumably silences a deprecation
    // on the trait method -- confirm against the trait definition.
    #[allow(deprecated)]
    fn get_max_tokens(&self) -> Option<u32> {
        self.inner.max_output_tokens
    }

    fn get_min_tokens(&self) -> Option<u32> {
        None
    }

    fn get_stop(&self) -> Option<Vec<String>> {
        None
    }
}
// ---------------------------------------------------------------------------
// Responses API -> Chat Completions conversion
// ---------------------------------------------------------------------------
/// Map a Responses API image detail level onto the equally named
/// Chat Completions `ImageDetail` variant.
fn convert_image_detail(
    detail: &dynamo_async_openai::types::responses::ImageDetail,
) -> ChatImageDetail {
    use dynamo_async_openai::types::responses::ImageDetail as ResponsesImageDetail;

    match detail {
        ResponsesImageDetail::Auto => ChatImageDetail::Auto,
        ResponsesImageDetail::Low => ChatImageDetail::Low,
        ResponsesImageDetail::High => ChatImageDetail::High,
    }
}
/// Convert a slice of InputContent to ChatCompletionRequestUserMessageContent.
///
/// A slice holding exactly one text part becomes plain `Text` content; any
/// other slice becomes an `Array` with one chat content part per input part.
///
/// # Errors
///
/// Returns an error when:
/// - an image part has no `image_url` (for example it only carries a `file_id`),
/// - an image or video URL fails to parse,
/// - the content contains audio or file parts, which are not supported yet.
fn convert_input_content_to_user_content(
    content: &[InputContent],
) -> Result<ChatCompletionRequestUserMessageContent, anyhow::Error> {
    // If there's a single InputText, treat as simple text
    if let [InputContent::InputText(only)] = content {
        return Ok(ChatCompletionRequestUserMessageContent::Text(
            only.text.clone(),
        ));
    }

    let mut chat_parts = Vec::with_capacity(content.len());
    for part in content {
        let chat_part = match part {
            InputContent::InputText(t) => ChatCompletionRequestUserMessageContentPart::Text(
                ChatCompletionRequestMessageContentPartText {
                    text: t.text.clone(),
                },
            ),
            InputContent::InputImage(img) => {
                // A missing URL used to be parsed as "" and surfaced as a
                // confusing "Invalid image URL ''" error; report the real cause.
                let url_str = img.image_url.as_deref().ok_or_else(|| {
                    anyhow::anyhow!(
                        "Image input content requires `image_url`; `file_id` references are not yet supported"
                    )
                })?;
                let url = url::Url::parse(url_str)
                    .map_err(|e| anyhow::anyhow!("Invalid image URL '{}': {}", url_str, e))?;
                ChatCompletionRequestUserMessageContentPart::ImageUrl(
                    ChatCompletionRequestMessageContentPartImage {
                        image_url: ImageUrl {
                            url,
                            detail: Some(convert_image_detail(&img.detail)),
                            uuid: None,
                        },
                    },
                )
            }
            InputContent::InputVideo(vid) => {
                let url = url::Url::parse(&vid.video)
                    .map_err(|e| anyhow::anyhow!("Invalid video URL '{}': {}", vid.video, e))?;
                ChatCompletionRequestUserMessageContentPart::VideoUrl(
                    ChatCompletionRequestMessageContentPartVideo {
                        video_url: VideoUrl {
                            url,
                            detail: None,
                            uuid: None,
                        },
                    },
                )
            }
            InputContent::InputAudio(_) => {
                return Err(anyhow::anyhow!("Audio input content is not yet supported"));
            }
            InputContent::InputFile(_) => {
                return Err(anyhow::anyhow!("File input content is not yet supported"));
            }
        };
        chat_parts.push(chat_part);
    }
    Ok(ChatCompletionRequestUserMessageContent::Array(chat_parts))
}
/// Flatten input content into one plain-text string (used for system/developer
/// messages and tool outputs).
///
/// Text parts are appended in order with no separator; non-text parts are skipped.
fn convert_input_content_to_text(content: &[InputContent]) -> String {
    let mut text = String::new();
    for part in content {
        if let InputContent::InputText(t) = part {
            text.push_str(&t.text);
        }
    }
    text
}
/// Convert InputParam::Items to a Vec of ChatCompletionRequestMessages.
fn convert_input_items_to_messages(
items: &[InputItem],
) -> Result<Vec<ChatCompletionRequestMessage>, anyhow::Error> {
let mut messages = Vec::with_capacity(items.len());
for item in items {
match item {
InputItem::Item(inner_item) => match inner_item {
Item::Message(msg_item) => match msg_item {
MessageItem::Input(msg) => {
let chat_msg = match msg.role {
InputRole::System | InputRole::Developer => {
let text = convert_input_content_to_text(&msg.content);
ChatCompletionRequestMessage::System(
ChatCompletionRequestSystemMessage {
content: ChatCompletionRequestSystemMessageContent::Text(
text,
),
name: None,
},
)
}
InputRole::User => {
let content = convert_input_content_to_user_content(&msg.content)?;
ChatCompletionRequestMessage::User(
ChatCompletionRequestUserMessage {
content,
name: None,
},
)
}
};
messages.push(chat_msg);
}
MessageItem::Output(out_msg) => {
// Previous assistant output message -> assistant message
let text = out_msg
.content
.iter()
.filter_map(|c| match c {
OutputMessageContent::OutputText(t) => Some(t.text.as_str()),
_ => None,
})
.collect::<Vec<_>>()
.join("");
messages.push(ChatCompletionRequestMessage::Assistant(
ChatCompletionRequestAssistantMessage {
content: Some(ChatCompletionRequestAssistantMessageContent::Text(
text,
)),
reasoning_content: None,
refusal: None,
name: None,
audio: None,
tool_calls: None,
#[allow(deprecated)]
function_call: None,
},
));
}
},
Item::FunctionCall(fc) => {
// A function call from a previous assistant turn -> assistant message with tool_calls
messages.push(ChatCompletionRequestMessage::Assistant(
ChatCompletionRequestAssistantMessage {
content: None,
reasoning_content: None,
refusal: None,
name: None,
audio: None,
tool_calls: Some(vec![ChatCompletionMessageToolCall {
id: fc.call_id.clone(),
r#type: ChatCompletionToolType::Function,
function: dynamo_async_openai::types::FunctionCall {
name: fc.name.clone(),
arguments: fc.arguments.clone(),
},
}]),
#[allow(deprecated)]
function_call: None,
},
));
}
Item::FunctionCallOutput(fco) => {
// The output of a function call -> tool message
let output_text = match &fco.output {
FunctionCallOutput::Text(text) => text.clone(),
FunctionCallOutput::Content(parts) => convert_input_content_to_text(parts),
};
messages.push(ChatCompletionRequestMessage::Tool(
ChatCompletionRequestToolMessage {
content: ChatCompletionRequestToolMessageContent::Text(output_text),
tool_call_id: fco.call_id.clone(),
},
));
}
other => {
tracing::debug!(
"Skipping unsupported input item type during conversion: {:?}",
std::mem::discriminant(other)
);
}
},
InputItem::EasyMessage(easy) => {
// Handle easy input messages based on role
let content_text = match &easy.content {
dynamo_async_openai::types::responses::EasyInputContent::Text(text) => {
text.clone()
}
dynamo_async_openai::types::responses::EasyInputContent::ContentList(parts) => {
convert_input_content_to_text(parts)
}
};
let chat_msg = match easy.role {
ResponseRole::System | ResponseRole::Developer => {
ChatCompletionRequestMessage::System(ChatCompletionRequestSystemMessage {
content: ChatCompletionRequestSystemMessageContent::Text(content_text),
name: None,
})
}
ResponseRole::User => {
ChatCompletionRequestMessage::User(ChatCompletionRequestUserMessage {
content: ChatCompletionRequestUserMessageContent::Text(content_text),
name: None,
})
}
ResponseRole::Assistant => ChatCompletionRequestMessage::Assistant(
ChatCompletionRequestAssistantMessage {
content: Some(ChatCompletionRequestAssistantMessageContent::Text(
content_text,
)),
reasoning_content: None,
refusal: None,
name: None,
audio: None,
tool_calls: None,
#[allow(deprecated)]
function_call: None,
},
),
};
messages.push(chat_msg);
}
InputItem::ItemReference(_) => {
// Skip item references
}
}
}
Ok(messages)
}
/// Translate Responses API tools into Chat Completions tools.
///
/// Only function tools are forwarded to chat completions; every other tool
/// kind is dropped, so the result may be shorter than the input (or empty).
fn convert_tools(tools: &[Tool]) -> Vec<ChatCompletionTool> {
    let mut converted = Vec::new();
    for tool in tools {
        if let Tool::Function(func) = tool {
            converted.push(ChatCompletionTool {
                r#type: ChatCompletionToolType::Function,
                function: FunctionObject {
                    name: func.name.clone(),
                    description: func.description.clone(),
                    parameters: func.parameters.clone(),
                    strict: func.strict,
                },
            });
        }
    }
    converted
}
/// Translate a Responses API tool choice into its Chat Completions counterpart.
///
/// The `none` / `auto` / `required` modes and named function choices map
/// one-to-one. Everything else (hosted tools, AllowedTools, Mcp, Custom, etc.)
/// is not forwarded to chat completions and falls back to `auto`.
fn convert_tool_choice(tc: &ToolChoiceParam) -> ChatCompletionToolChoiceOption {
    match tc {
        ToolChoiceParam::Mode(ToolChoiceOptions::None) => ChatCompletionToolChoiceOption::None,
        ToolChoiceParam::Mode(ToolChoiceOptions::Auto) => ChatCompletionToolChoiceOption::Auto,
        ToolChoiceParam::Mode(ToolChoiceOptions::Required) => {
            ChatCompletionToolChoiceOption::Required
        }
        ToolChoiceParam::Function(named) => {
            let function = FunctionName {
                name: named.name.clone(),
            };
            ChatCompletionToolChoiceOption::Named(ChatCompletionNamedToolChoice {
                r#type: ChatCompletionToolType::Function,
                function,
            })
        }
        // Hosted tools and all remaining choice kinds default to auto.
        _ => ChatCompletionToolChoiceOption::Auto,
    }
}
impl TryFrom<NvCreateResponse> for NvCreateChatCompletionRequest {
type Error = anyhow::Error;
fn try_from(resp: NvCreateResponse) -> Result<Self, Self::Error> {
let mut messages = Vec::new();
// Prepend instructions as system message if present
if let Some(instructions) = &resp.inner.instructions {
messages.push(ChatCompletionRequestMessage::System(
ChatCompletionRequestSystemMessage {
content: ChatCompletionRequestSystemMessageContent::Text(instructions.clone()),
name: None,
},
));
}
// Convert input to messages
match &resp.inner.input {
InputParam::Text(text) => {
messages.push(ChatCompletionRequestMessage::User(
ChatCompletionRequestUserMessage {
content: ChatCompletionRequestUserMessageContent::Text(text.clone()),
name: None,
},
));
}
InputParam::Items(items) => {
let item_messages = convert_input_items_to_messages(items)?;
messages.extend(item_messages);
}
}
let top_logprobs = convert_top_logprobs(resp.inner.top_logprobs);
// Convert tools if present
let tools = resp
.inner
.tools
.as_ref()
.map(|t| convert_tools(t))
.filter(|t: &Vec<_>| !t.is_empty());
// Convert tool_choice if present
let tool_choice = resp.inner.tool_choice.as_ref().map(convert_tool_choice);
// Determine stream setting: respect caller's preference, default to true for aggregation
let stream = resp.inner.stream.or(Some(true));
Ok(NvCreateChatCompletionRequest {
inner: CreateChatCompletionRequest {
messages,
model: resp.inner.model.unwrap_or_default(),
temperature: resp.inner.temperature,
top_p: resp.inner.top_p,
max_completion_tokens: resp.inner.max_output_tokens,
top_logprobs,
metadata: resp.inner.metadata,
stream,
tools,
tool_choice,
..Default::default()
},
common: Default::default(),
nvext: resp.nvext,
chat_template_args: None,
media_io_kwargs: None,
unsupported_fields: Default::default(),
})
}
}
/// Clamp a requested `top_logprobs` count to at most 20; `None` stays `None`.
// NOTE(review): 20 is presumably the Chat Completions upper bound -- confirm.
fn convert_top_logprobs(input: Option<u8>) -> Option<u8> {
    const MAX_TOP_LOGPROBS: u8 = 20;

    match input {
        Some(requested) if requested > MAX_TOP_LOGPROBS => Some(MAX_TOP_LOGPROBS),
        within_limit => within_limit,
    }
}
/// Extract tool calls that a model emitted as `<tool_call>{json}</tool_call>` text.
///
/// Returns `(name, arguments_json)` pairs in order of appearance. A block is
/// skipped when its body is not valid JSON or its `name` is missing, empty or
/// not a string. `arguments` is passed through as-is when it is a JSON string,
/// serialized when it is any other JSON value, and `"{}"` when absent.
/// Scanning stops at the first opening tag that has no closing tag.
/// Returns an empty vec when the text contains no `<tool_call>` tag.
fn parse_tool_call_text(text: &str) -> Vec<(String, String)> {
    const OPEN_TAG: &str = "<tool_call>";
    const CLOSE_TAG: &str = "</tool_call>";

    let mut calls = Vec::new();
    let mut rest = text;
    while let Some(open_at) = rest.find(OPEN_TAG) {
        let after_open = &rest[open_at + OPEN_TAG.len()..];
        let Some(close_at) = after_open.find(CLOSE_TAG) else {
            break;
        };
        let payload = after_open[..close_at].trim();
        // Advance past this block before deciding whether it is usable.
        rest = &after_open[close_at + CLOSE_TAG.len()..];

        let Ok(parsed) = serde_json::from_str::<serde_json::Value>(payload) else {
            continue;
        };
        let name = parsed.get("name").and_then(|v| v.as_str()).unwrap_or("");
        if name.is_empty() {
            continue;
        }
        let arguments = match parsed.get("arguments") {
            Some(serde_json::Value::String(raw)) => raw.clone(),
            Some(value) => serde_json::to_string(value).unwrap_or_else(|_| "{}".to_string()),
            None => "{}".to_string(),
        };
        calls.push((name.to_string(), arguments));
    }
    calls
}
/// Remove `<tool_call>...</tool_call>` and `<think>...</think>` blocks from text.
///
/// Tool-call blocks are removed first, then think blocks. An opening tag with
/// no matching closing tag drops everything from that tag to the end.
/// A tag kind is only stripped if it occurs in the *input* text; when neither
/// opening tag occurs, the input is returned borrowed (no allocation).
fn strip_tool_call_text(text: &str) -> std::borrow::Cow<'_, str> {
    /// Delete every `open ... close` span from `buf`, rescanning from the
    /// start after each removal.
    fn remove_blocks(buf: &mut String, open: &str, close: &str) {
        while let Some(open_at) = buf.find(open) {
            match buf[open_at..].find(close) {
                Some(rel_close) => {
                    let span_end = open_at + rel_close + close.len();
                    buf.replace_range(open_at..span_end, "");
                }
                None => {
                    // Unterminated block: drop the remainder.
                    buf.truncate(open_at);
                    break;
                }
            }
        }
    }

    let strip_tool_calls = text.contains("<tool_call>");
    let strip_thinking = text.contains("<think>");
    if !(strip_tool_calls || strip_thinking) {
        return std::borrow::Cow::Borrowed(text);
    }

    let mut cleaned = String::from(text);
    if strip_tool_calls {
        remove_blocks(&mut cleaned, "<tool_call>", "</tool_call>");
    }
    if strip_thinking {
        remove_blocks(&mut cleaned, "<think>", "</think>");
    }
    std::borrow::Cow::Owned(cleaned)
}
// ---------------------------------------------------------------------------
// Chat Completions -> Responses API response conversion
// ---------------------------------------------------------------------------
/// Request parameters to echo back in Response objects.
/// Extracted from the incoming CreateResponse request so that
/// response objects reflect actual request values.
/// Unset fields fall back to defaults inside `chat_completion_to_response`.
#[derive(Clone, Debug, Default)]
pub struct ResponseParams {
/// Sampling temperature; echoed as `1.0` when unset.
pub temperature: Option<f32>,
/// Nucleus-sampling `top_p`; echoed as `1.0` when unset.
pub top_p: Option<f32>,
/// Output token budget; echoed unchanged (stays `None` when unset).
pub max_output_tokens: Option<u32>,
/// Whether the caller asked to store the response; echoed as `false` when unset.
pub store: Option<bool>,
/// Tools from the request; echoed after `normalize_tools`, or as an empty list when unset.
pub tools: Option<Vec<Tool>>,
/// Tool choice from the request; echoed as `auto` when unset.
pub tool_choice: Option<ToolChoiceParam>,
/// Instructions from the request; echoed as text instructions when set.
pub instructions: Option<String>,
}
/// Ensure every function tool carries an explicit `strict` flag.
///
/// The upstream type uses `skip_serializing_if = "Option::is_none"` on `strict`,
/// so a `None` would be omitted from the serialized JSON, while schema
/// validators (Zod, etc.) expect the field to always be present. A missing
/// value is filled in with `true` (OpenAI defaults `strict` to `true`);
/// an explicit value and all non-function tools are left untouched.
pub(super) fn normalize_tools(tools: Vec<Tool>) -> Vec<Tool> {
    let mut normalized = tools;
    for tool in normalized.iter_mut() {
        if let Tool::Function(func) = tool {
            func.strict = func.strict.or(Some(true));
        }
    }
    normalized
}
/// Wrap `text` in a completed assistant message output item with the given id.
///
/// The message holds a single output-text part with no annotations and an
/// empty (but present) logprobs list.
fn make_text_message(id: String, text: String) -> OutputItem {
    let text_part = OutputMessageContent::OutputText(OutputTextContent {
        text,
        annotations: vec![],
        logprobs: Some(vec![]),
    });

    OutputItem::Message(OutputMessage {
        id,
        role: AssistantRole::Assistant,
        status: OutputStatus::Completed,
        content: vec![text_part],
    })
}
/// Build a completed function-call output item for `name` / `arguments`,
/// generating a fresh `call_...` call id and `fc_...` item id.
fn make_function_call(name: String, arguments: String) -> OutputItem {
    let call_id = format!("call_{}", Uuid::new_v4().simple());
    let item_id = format!("fc_{}", Uuid::new_v4().simple());

    OutputItem::FunctionCall(FunctionToolCall {
        arguments,
        call_id,
        name,
        id: Some(item_id),
        status: Some(OutputStatus::Completed),
    })
}
/// Convert a ChatCompletion response into a Responses API response object,
/// echoing back the actual request parameters from `params`.
pub fn chat_completion_to_response(
nv_resp: NvCreateChatCompletionResponse,
params: &ResponseParams,
) -> Result<NvResponse, anyhow::Error> {
let chat_resp = nv_resp;
let nvext = chat_resp.nvext.clone();
let message_id = format!("msg_{}", Uuid::new_v4().simple());
let response_id = format!("resp_{}", Uuid::new_v4().simple());
let choice = chat_resp.choices.into_iter().next();
let mut output = Vec::new();
if let Some(choice) = choice {
// Handle structured tool calls
if let Some(tool_calls) = choice.message.tool_calls {
for tc in &tool_calls {
output.push(OutputItem::FunctionCall(FunctionToolCall {
arguments: tc.function.arguments.clone(),
call_id: tc.id.clone(),
name: tc.function.name.clone(),
id: Some(format!("fc_{}", Uuid::new_v4().simple())),
status: Some(OutputStatus::Completed),
}));
}
}
// Handle text content -- also parse <tool_call> blocks from models
// that emit tool calls as text (e.g. Qwen3)
let content_text = match choice.message.content {
Some(dynamo_async_openai::types::ChatCompletionMessageContent::Text(text)) => {
Some(text)
}
Some(dynamo_async_openai::types::ChatCompletionMessageContent::Parts(_)) => {
tracing::warn!(
"Multimodal content in responses API not yet supported, using placeholder"
);
Some("[multimodal content]".to_string())
}
None => None,
};
if let Some(content_text) = content_text
&& !content_text.is_empty()
{
let parsed_calls = parse_tool_call_text(&content_text);
if !parsed_calls.is_empty() {
for (name, arguments) in parsed_calls {
output.push(make_function_call(name, arguments));
}
let remaining = strip_tool_call_text(&content_text);
if !remaining.trim().is_empty() {
output.push(make_text_message(
message_id.clone(),
remaining.into_owned(),
));
}
} else {
output.push(make_text_message(message_id.clone(), content_text));
}
}
if output.is_empty() {
output.push(make_text_message(message_id, String::new()));
}
} else {
tracing::warn!("No choices in chat completion response, using empty content");
output.push(make_text_message(message_id, String::new()));
}
let created_at = chat_resp.created as u64;
let response = Response {
id: response_id,
object: "response".to_string(),
created_at,
completed_at: Some(created_at),
model: chat_resp.model,
status: Status::Completed,
output,
// Spec-required defaults (OpenResponses requires these as non-null)
background: Some(false),
frequency_penalty: Some(0.0),
metadata: Some(serde_json::Value::Object(Default::default())),
parallel_tool_calls: Some(true),
presence_penalty: Some(0.0),
// Echo actual request values, falling back to spec defaults.
// store: false because this branch does not persist responses.
store: params.store.or(Some(false)),
temperature: params.temperature.or(Some(1.0)),
text: Some(ResponseTextParam {
format: TextResponseFormatConfiguration::Text,
verbosity: None,
}),
tool_choice: params
.tool_choice
.clone()
.or(Some(ToolChoiceParam::Mode(ToolChoiceOptions::Auto))),
tools: Some(
params
.tools
.clone()
.map(normalize_tools)
.unwrap_or_default(),
),
top_p: params.top_p.or(Some(1.0)),
truncation: Some(Truncation::Disabled),
// Nullable but required to be present (null is valid)
billing: None,
conversation: None,
error: None,
incomplete_details: None,
instructions: params.instructions.clone().map(Instructions::Text),
max_output_tokens: params.max_output_tokens,
max_tool_calls: None,
previous_response_id: None,
prompt: None,
prompt_cache_key: None,
prompt_cache_retention: None,
reasoning: None,
safety_identifier: None,
service_tier: Some(ServiceTier::Auto),
top_logprobs: Some(0),
usage: None,
};
Ok(NvResponse {
inner: response,
nvext,
})
}
#[cfg(test)]
mod tests {
use dynamo_async_openai::types::responses::{
CreateResponse, FunctionCallOutput, FunctionCallOutputItemParam, FunctionTool,
FunctionToolCall, ImageDetail, InputContent, InputImageContent, InputItem, InputMessage,
InputParam, InputRole, InputTextContent, Item, MessageItem, Tool,
};
use dynamo_async_openai::types::{
ChatCompletionRequestMessage, ChatCompletionRequestUserMessageContent,
};
use super::*;
use crate::types::openai::chat_completions::NvCreateChatCompletionResponse;
/// Fixture: a plain-text request against "test-model" with sampling options,
/// a token budget, `top_logprobs` and the annotations "debug" and "trace" set.
fn make_response_with_input(text: &str) -> NvCreateResponse {
    let inner = CreateResponse {
        input: InputParam::Text(text.into()),
        model: Some("test-model".into()),
        max_output_tokens: Some(1024),
        temperature: Some(0.5),
        top_p: Some(0.9),
        top_logprobs: Some(15),
        ..Default::default()
    };
    let nvext = NvExt {
        annotations: Some(vec!["debug".into(), "trace".into()]),
        ..Default::default()
    };

    NvCreateResponse {
        inner,
        nvext: Some(nvext),
    }
}
/// Annotations set on `nvext` are returned in order and are individually queryable.
#[test]
fn test_annotations_trait_behavior() {
    let request = make_response_with_input("hello");

    let expected = vec!["debug".to_string(), "trace".to_string()];
    assert_eq!(request.annotations(), Some(expected));

    for present in ["debug", "trace"] {
        assert!(request.has_annotation(present));
    }
    assert!(!request.has_annotation("missing"));
}
/// Temperature and top_p come from the request; the penalties are never set.
#[test]
fn test_openai_sampling_trait_behavior() {
    let request = make_response_with_input("hello");

    // Read from the request.
    assert_eq!(request.get_temperature(), Some(0.5));
    assert_eq!(request.get_top_p(), Some(0.9));

    // Not read from the request.
    assert_eq!(request.get_frequency_penalty(), None);
    assert_eq!(request.get_presence_penalty(), None);
}
/// The token budget maps to `max_output_tokens`; min tokens and stop are unset.
#[test]
fn test_openai_stop_conditions_trait_behavior() {
    let request = make_response_with_input("hello");

    assert_eq!(request.get_max_tokens(), Some(1024));
    assert_eq!(request.get_min_tokens(), None);
    assert_eq!(request.get_stop(), None);
}
/// A text-input request converts to a single user message, with the scalar
/// options carried over and streaming defaulted to on.
#[test]
fn test_into_nvcreate_chat_completion_request() {
    let chat_req: NvCreateChatCompletionRequest =
        make_response_with_input("hi there").try_into().unwrap();
    let inner = &chat_req.inner;

    assert_eq!(inner.model, "test-model");
    assert_eq!(inner.temperature, Some(0.5));
    assert_eq!(inner.top_p, Some(0.9));
    assert_eq!(inner.max_completion_tokens, Some(1024));
    assert_eq!(inner.top_logprobs, Some(15));
    assert_eq!(inner.stream, Some(true));

    assert_eq!(inner.messages.len(), 1);
    let ChatCompletionRequestMessage::User(user_msg) = &inner.messages[0] else {
        panic!("expected user message");
    };
    let ChatCompletionRequestUserMessageContent::Text(text) = &user_msg.content else {
        panic!("unexpected user content type");
    };
    assert_eq!(text, "hi there");
}
/// `instructions` become a system message placed ahead of the user input.
#[test]
fn test_instructions_prepended_as_system_message() {
    let inner = CreateResponse {
        input: InputParam::Text("hello".into()),
        model: Some("test-model".into()),
        instructions: Some("You are a helpful assistant.".into()),
        ..Default::default()
    };
    let request = NvCreateResponse { inner, nvext: None };

    let chat_req: NvCreateChatCompletionRequest = request.try_into().unwrap();
    let messages = &chat_req.inner.messages;
    assert_eq!(messages.len(), 2);

    let ChatCompletionRequestMessage::System(system_msg) = &messages[0] else {
        panic!("expected system message first");
    };
    let ChatCompletionRequestSystemMessageContent::Text(text) = &system_msg.content else {
        panic!("expected text content");
    };
    assert_eq!(text, "You are a helpful assistant.");
}
// A system / user / assistant-output / user item sequence must convert to
// four chat messages with the same roles, in the same order.
#[test]
fn test_input_items_multi_turn() {
let req = NvCreateResponse {
inner: CreateResponse {
input: InputParam::Items(vec![
// Turn 1: system prompt.
InputItem::Item(Item::Message(MessageItem::Input(InputMessage {
content: vec![InputContent::InputText(InputTextContent {
text: "Be concise.".into(),
})],
role: InputRole::System,
status: None,
}))),
// Turn 2: user question.
InputItem::Item(Item::Message(MessageItem::Input(InputMessage {
content: vec![InputContent::InputText(InputTextContent {
text: "What is 2+2?".into(),
})],
role: InputRole::User,
status: None,
}))),
// Turn 3: previous assistant output fed back as input.
InputItem::Item(Item::Message(MessageItem::Output(OutputMessage {
id: "msg_1".into(),
role: AssistantRole::Assistant,
status: OutputStatus::Completed,
content: vec![OutputMessageContent::OutputText(OutputTextContent {
text: "4".into(),
annotations: vec![],
logprobs: None,
})],
}))),
// Turn 4: follow-up user question.
InputItem::Item(Item::Message(MessageItem::Input(InputMessage {
content: vec![InputContent::InputText(InputTextContent {
text: "And 3+3?".into(),
})],
role: InputRole::User,
status: None,
}))),
]),
model: Some("test-model".into()),
..Default::default()
},
nvext: None,
};
let chat_req: NvCreateChatCompletionRequest = req.try_into().unwrap();
let messages = &chat_req.inner.messages;
// One chat message per input item, roles preserved.
assert_eq!(messages.len(), 4);
assert!(matches!(
messages[0],
ChatCompletionRequestMessage::System(_)
));
assert!(matches!(messages[1], ChatCompletionRequestMessage::User(_)));
assert!(matches!(
messages[2],
ChatCompletionRequestMessage::Assistant(_)
));
assert!(matches!(messages[3], ChatCompletionRequestMessage::User(_)));
}
// A user message mixing text and an image URL must convert to a single user
// message whose content is a two-part array (not flattened text).
#[test]
fn test_input_items_with_image() {
let req = NvCreateResponse {
inner: CreateResponse {
input: InputParam::Items(vec![InputItem::Item(Item::Message(MessageItem::Input(
InputMessage {
content: vec![
InputContent::InputText(InputTextContent {
text: "What is in this image?".into(),
}),
InputContent::InputImage(InputImageContent {
detail: ImageDetail::Auto,
file_id: None,
image_url: Some("https://example.com/cat.jpg".into()),
}),
],
role: InputRole::User,
status: None,
},
)))]),
model: Some("test-model".into()),
..Default::default()
},
nvext: None,
};
let chat_req: NvCreateChatCompletionRequest = req.try_into().unwrap();
let messages = &chat_req.inner.messages;
assert_eq!(messages.len(), 1);
match &messages[0] {
ChatCompletionRequestMessage::User(u) => match &u.content {
ChatCompletionRequestUserMessageContent::Array(parts) => {
// One part for the text, one for the image.
assert_eq!(parts.len(), 2);
}
_ => panic!("expected array content"),
},
_ => panic!("expected user message"),
}
}
// A user message followed by a function call and its output must convert to
// user -> assistant (tool call) -> tool messages.
#[test]
fn test_function_call_input_items() {
let req = NvCreateResponse {
inner: CreateResponse {
input: InputParam::Items(vec![
InputItem::Item(Item::Message(MessageItem::Input(InputMessage {
content: vec![InputContent::InputText(InputTextContent {
text: "What's the weather?".into(),
})],
role: InputRole::User,
status: None,
}))),
// The call and its output share the same call_id.
InputItem::Item(Item::FunctionCall(FunctionToolCall {
arguments: r#"{"location":"SF"}"#.into(),
call_id: "call_123".into(),
name: "get_weather".into(),
id: None,
status: None,
})),
InputItem::Item(Item::FunctionCallOutput(FunctionCallOutputItemParam {
call_id: "call_123".into(),
output: FunctionCallOutput::Text(r#"{"temp":"72F"}"#.into()),
id: None,
status: None,
})),
]),
model: Some("test-model".into()),
..Default::default()
},
nvext: None,
};
let chat_req: NvCreateChatCompletionRequest = req.try_into().unwrap();
let messages = &chat_req.inner.messages;
assert_eq!(messages.len(), 3);
assert!(matches!(messages[0], ChatCompletionRequestMessage::User(_)));
assert!(matches!(
messages[1],
ChatCompletionRequestMessage::Assistant(_)
));
assert!(matches!(messages[2], ChatCompletionRequestMessage::Tool(_)));
}
// A single function tool on the request must be forwarded as a single chat
// completion tool with the same function name.
#[test]
fn test_tools_conversion() {
let req = NvCreateResponse {
inner: CreateResponse {
input: InputParam::Text("hello".into()),
model: Some("test-model".into()),
tools: Some(vec![Tool::Function(FunctionTool {
name: "get_weather".into(),
parameters: Some(serde_json::json!({
"type": "object",
"properties": {
"location": {"type": "string"}
},
"required": ["location"]
})),
strict: Some(true),
description: Some("Get weather info".into()),
})]),
..Default::default()
},
nvext: None,
};
let chat_req: NvCreateChatCompletionRequest = req.try_into().unwrap();
assert!(chat_req.inner.tools.is_some());
let tools = chat_req.inner.tools.unwrap();
assert_eq!(tools.len(), 1);
assert_eq!(tools[0].function.name, "get_weather");
}
// A plain-text chat completion must convert to a completed response whose
// first output item is an assistant message carrying the same text.
// `allow(deprecated)`: the fixture sets the `function_call` field explicitly.
#[allow(deprecated)]
#[test]
fn test_into_nvresponse_from_chat_response() {
let now = 1_726_000_000;
let chat_resp = NvCreateChatCompletionResponse {
id: "chatcmpl-xyz".into(),
choices: vec![dynamo_async_openai::types::ChatChoice {
index: 0,
message: dynamo_async_openai::types::ChatCompletionResponseMessage {
content: Some(
dynamo_async_openai::types::ChatCompletionMessageContent::Text(
"This is a reply".to_string(),
),
),
refusal: None,
tool_calls: None,
role: dynamo_async_openai::types::Role::Assistant,
function_call: None,
audio: None,
reasoning_content: None,
},
finish_reason: None,
stop_reason: None,
logprobs: None,
}],
created: now,
model: "llama-3.1-8b-instruct".into(),
service_tier: None,
system_fingerprint: None,
object: "chat.completion".to_string(),
usage: None,
nvext: None,
};
// Default params: nothing from the request is echoed back.
let wrapped = chat_completion_to_response(chat_resp, &ResponseParams::default()).unwrap();
assert_eq!(wrapped.inner.model, "llama-3.1-8b-instruct");
assert_eq!(wrapped.inner.status, Status::Completed);
assert_eq!(wrapped.inner.object, "response");
assert!(wrapped.inner.id.starts_with("resp_"));
let msg = match &wrapped.inner.output[0] {
OutputItem::Message(m) => m,
_ => panic!("Expected Message variant"),
};
assert_eq!(msg.role, AssistantRole::Assistant);
match &msg.content[0] {
OutputMessageContent::OutputText(txt) => {
assert_eq!(txt.text, "This is a reply");
}
_ => panic!("Expected OutputText content"),
}
}
// A chat completion with one structured tool call and no text must convert to
// exactly one function-call output item that keeps the upstream call id.
// `allow(deprecated)`: the fixture sets the `function_call` field explicitly.
#[allow(deprecated)]
#[test]
fn test_response_with_tool_calls() {
let now = 1_726_000_000;
let chat_resp = NvCreateChatCompletionResponse {
id: "chatcmpl-xyz".into(),
choices: vec![dynamo_async_openai::types::ChatChoice {
index: 0,
message: dynamo_async_openai::types::ChatCompletionResponseMessage {
content: None,
refusal: None,
tool_calls: Some(vec![ChatCompletionMessageToolCall {
id: "call_abc".into(),
r#type: ChatCompletionToolType::Function,
function: dynamo_async_openai::types::FunctionCall {
name: "get_weather".into(),
arguments: r#"{"location":"SF"}"#.into(),
},
}]),
role: dynamo_async_openai::types::Role::Assistant,
function_call: None,
audio: None,
reasoning_content: None,
},
finish_reason: None,
stop_reason: None,
logprobs: None,
}],
created: now,
model: "test-model".into(),
service_tier: None,
system_fingerprint: None,
object: "chat.completion".to_string(),
usage: None,
nvext: None,
};
let wrapped = chat_completion_to_response(chat_resp, &ResponseParams::default()).unwrap();
// No empty placeholder message is added when a tool call is present.
assert_eq!(wrapped.inner.output.len(), 1);
match &wrapped.inner.output[0] {
OutputItem::FunctionCall(fc) => {
assert_eq!(fc.call_id, "call_abc");
assert_eq!(fc.name, "get_weather");
}
_ => panic!("Expected FunctionCall output"),
}
}
/// `top_logprobs` passes through up to 20 and is clamped above that.
#[test]
fn test_convert_top_logprobs_clamped() {
    let cases = [
        (Some(5), Some(5)),
        (Some(21), Some(20)),
        (Some(255), Some(20)),
        (None, None),
    ];
    for (input, expected) in cases {
        assert_eq!(convert_top_logprobs(input), expected);
    }
}
#[test]
fn test_parse_tool_call_text() {
    // Standard Qwen3 format: a reasoning block followed by a single tool call.
    let model_output = r#"<think>
Let me check the weather.
</think>
<tool_call>
{"name": "get_weather", "arguments": {"location": "San Francisco"}}
</tool_call>"#;

    let parsed = parse_tool_call_text(model_output);
    assert_eq!(parsed.len(), 1);
    assert_eq!(parsed[0].0, "get_weather");

    // The second tuple element is the arguments payload as a JSON string.
    let arguments = serde_json::from_str::<serde_json::Value>(&parsed[0].1).unwrap();
    assert_eq!(arguments["location"], "San Francisco");
}
#[test]
fn test_parse_tool_call_text_multiple() {
    // Two back-to-back tool call blocks must be returned in document order.
    let model_output = r#"<tool_call>
{"name": "func_a", "arguments": {"x": 1}}
</tool_call>
<tool_call>
{"name": "func_b", "arguments": {"y": 2}}
</tool_call>"#;

    let parsed = parse_tool_call_text(model_output);
    assert_eq!(parsed.len(), 2);
    for (position, expected_name) in ["func_a", "func_b"].into_iter().enumerate() {
        assert_eq!(parsed[position].0, expected_name);
    }
}
#[test]
fn test_parse_tool_call_text_no_calls() {
    // Plain prose without any tool call markup yields no parsed calls.
    let parsed = parse_tool_call_text("Just a regular message with no tool calls.");
    assert!(parsed.is_empty());
}
#[test]
fn test_strip_tool_call_text() {
    let model_output = r#"<think>
thinking
</think>
<tool_call>
{"name": "f", "arguments": {}}
</tool_call>"#;

    let cleaned = strip_tool_call_text(model_output);

    // Neither the tool call markup nor the reasoning markup may survive stripping.
    for opening_tag in ["<tool_call>", "<think>"] {
        assert!(!cleaned.contains(opening_tag));
    }
}
}
// SPDX-FileCopyrightText: Copyright (c) 2024-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
// SPDX-License-Identifier: Apache-2.0
//! Converts a stream of chat completion SSE chunks into Responses API SSE events.
//!
//! The event sequence follows the OpenAI Responses API streaming spec:
//! `response.created` -> `response.in_progress` -> `response.output_item.added` ->
//! `response.content_part.added` -> N x `response.output_text.delta` ->
//! `response.output_text.done` -> `response.content_part.done` ->
//! `response.output_item.done` -> `response.completed` -> `[DONE]`
use std::time::{SystemTime, UNIX_EPOCH};
use axum::response::sse::Event;
use dynamo_async_openai::types::responses::{
AssistantRole, FunctionToolCall, Instructions, OutputContent, OutputItem, OutputMessage,
OutputMessageContent, OutputStatus, OutputTextContent, Response, ResponseCompletedEvent,
ResponseContentPartAddedEvent, ResponseContentPartDoneEvent, ResponseCreatedEvent,
ResponseFailedEvent, ResponseFunctionCallArgumentsDeltaEvent,
ResponseFunctionCallArgumentsDoneEvent, ResponseInProgressEvent, ResponseOutputItemAddedEvent,
ResponseOutputItemDoneEvent, ResponseStreamEvent, ResponseTextDeltaEvent,
ResponseTextDoneEvent, ResponseTextParam, ServiceTier, Status, TextResponseFormatConfiguration,
ToolChoiceOptions, ToolChoiceParam, Truncation,
};
use uuid::Uuid;
use dynamo_async_openai::types::ChatCompletionMessageContent;
use super::ResponseParams;
use crate::protocols::openai::chat_completions::NvCreateChatCompletionStreamResponse;
/// State machine that converts a chat completion stream into Responses API events.
///
/// One converter instance tracks a single streamed response: it hands out
/// monotonically increasing sequence numbers, remembers the text and tool call
/// arguments accumulated so far, and assigns an output index to every output item.
pub struct ResponseStreamConverter {
/// Response identifier; generated as `resp_<uuid>` by `new`.
response_id: String,
/// Model name echoed in every `Response` built by `make_response`.
model: String,
/// Request parameters echoed back in every `Response` built by `make_response`.
params: ResponseParams,
/// Creation time in seconds since the Unix epoch, captured by `new`.
created_at: u64,
/// Sequence number that the next emitted event will receive; starts at 0.
sequence_number: u64,
// Text message tracking
/// Item identifier of the text message; generated as `msg_<uuid>` by `new`.
message_item_id: String,
/// Set once the first non-empty text delta has opened the message item.
message_started: bool,
/// Output index assigned to the text message when it was opened.
message_output_index: u32,
/// Concatenation of all text deltas seen so far.
accumulated_text: String,
// Function call tracking
/// Per-tool-call state, indexed by the tool call `index` from the chat stream.
function_call_items: Vec<FunctionCallState>,
// Output index counter
/// Next free output index to assign to a new output item.
next_output_index: u32,
}
/// Accumulated state for one streamed function (tool) call.
struct FunctionCallState {
/// Item identifier; generated as `fc_<uuid>` when the slot is created.
item_id: String,
/// Call identifier copied from the tool call delta's `id`; empty until one arrives.
call_id: String,
/// Function name copied from the tool call delta; empty until one arrives.
name: String,
/// Concatenation of all argument fragments seen so far.
accumulated_args: String,
/// Output index reserved for this call when the slot was created.
output_index: u32,
/// Set once the first arguments fragment has triggered `response.output_item.added`.
started: bool,
}
impl ResponseStreamConverter {
/// Create a converter for one streamed response.
///
/// Captures the current wall-clock time (seconds since the Unix epoch, 0 if the
/// clock reads earlier than the epoch) and generates fresh `resp_` / `msg_`
/// identifiers. All counters start at zero and no output item is open yet.
pub fn new(model: String, params: ResponseParams) -> Self {
    let created_at = SystemTime::now()
        .duration_since(UNIX_EPOCH)
        .map(|elapsed| elapsed.as_secs())
        .unwrap_or_default();
    let response_id = format!("resp_{}", Uuid::new_v4().simple());
    let message_item_id = format!("msg_{}", Uuid::new_v4().simple());

    Self {
        response_id,
        model,
        params,
        created_at,
        sequence_number: 0,
        message_item_id,
        message_started: false,
        message_output_index: 0,
        accumulated_text: String::new(),
        function_call_items: Vec::new(),
        next_output_index: 0,
    }
}
/// Return the sequence number for the next event and advance the counter by one.
fn next_seq(&mut self) -> u64 {
    let current = self.sequence_number;
    self.sequence_number = current + 1;
    current
}
/// Build a `Response` snapshot with the given status and output items.
///
/// Request parameters are echoed back where the caller supplied them; omitted
/// ones fall back to fixed defaults. `completed_at` is only populated when
/// `status` is `Status::Completed`.
fn make_response(&self, status: Status, output: Vec<OutputItem>) -> Response {
    let params = &self.params;

    // Only a completed response carries a completion timestamp.
    let completed_at = (status == Status::Completed).then(|| {
        SystemTime::now()
            .duration_since(UNIX_EPOCH)
            .unwrap_or_default()
            .as_secs()
    });

    // Echo request params with spec-required defaults for omitted fields.
    let tool_choice = params
        .tool_choice
        .clone()
        .or(Some(ToolChoiceParam::Mode(ToolChoiceOptions::Auto)));
    let tools = params
        .tools
        .clone()
        .map(super::normalize_tools)
        .unwrap_or_default();
    let instructions = params.instructions.clone().map(Instructions::Text);
    let text = ResponseTextParam {
        format: TextResponseFormatConfiguration::Text,
        verbosity: None,
    };

    Response {
        // Identity and lifecycle
        id: self.response_id.clone(),
        object: "response".to_string(),
        created_at: self.created_at,
        completed_at,
        status,
        model: self.model.clone(),
        output,
        // Echoed or defaulted request parameters
        background: Some(false),
        frequency_penalty: Some(0.0),
        metadata: Some(serde_json::Value::Object(serde_json::Map::new())),
        parallel_tool_calls: Some(true),
        presence_penalty: Some(0.0),
        // store: false because this branch does not persist responses.
        store: params.store.or(Some(false)),
        temperature: params.temperature.or(Some(1.0)),
        text: Some(text),
        tool_choice,
        tools: Some(tools),
        top_p: params.top_p.or(Some(1.0)),
        truncation: Some(Truncation::Disabled),
        service_tier: Some(ServiceTier::Auto),
        top_logprobs: Some(0),
        instructions,
        max_output_tokens: params.max_output_tokens,
        // Nullable required fields
        billing: None,
        conversation: None,
        error: None,
        incomplete_details: None,
        max_tool_calls: None,
        previous_response_id: None,
        prompt: None,
        prompt_cache_key: None,
        prompt_cache_retention: None,
        reasoning: None,
        safety_identifier: None,
        usage: None,
    }
}
/// Emit the initial lifecycle events: created + in_progress.
pub fn emit_start_events(&mut self) -> Vec<Result<Event, anyhow::Error>> {
let mut events = Vec::with_capacity(2);
let created = ResponseStreamEvent::ResponseCreated(ResponseCreatedEvent {
sequence_number: self.next_seq(),
response: self.make_response(Status::InProgress, vec![]),
});
events.push(make_sse_event(&created));
let in_progress = ResponseStreamEvent::ResponseInProgress(ResponseInProgressEvent {
sequence_number: self.next_seq(),
response: self.make_response(Status::InProgress, vec![]),
});
events.push(make_sse_event(&in_progress));
events
}
/// Process a single chat completion stream chunk and return zero or more SSE events.
///
/// For every choice in the chunk:
/// - A non-empty text delta opens the message item on first sight
///   (`response.output_item.added` + `response.content_part.added`) and then
///   produces one `response.output_text.delta`. The text is appended to
///   `accumulated_text`.
/// - Each tool call delta updates the state slot selected by its `index`. The
///   first delta that carries an `arguments` fragment emits
///   `response.output_item.added`; every such fragment emits
///   `response.function_call_arguments.delta`.
///
/// Events are returned in emission order, and sequence numbers are consumed in
/// that same order.
///
/// NOTE(review): all choices feed the same message and tool call state; the
/// choice index is not consulted here. A tool call whose deltas never carry an
/// `arguments` value is never marked as started.
pub fn process_chunk(
&mut self,
chunk: &NvCreateChatCompletionStreamResponse,
) -> Vec<Result<Event, anyhow::Error>> {
let mut events = Vec::new();
for choice in &chunk.choices {
let delta = &choice.delta;
// Handle text content deltas — extract text from the enum
let content_text = match &delta.content {
Some(ChatCompletionMessageContent::Text(text)) => Some(text.as_str()),
Some(ChatCompletionMessageContent::Parts(_)) => {
// Multimodal streaming not yet supported
None
}
None => None,
};
// Empty text deltas are skipped entirely: they neither open the message nor emit events.
if let Some(content) = content_text
&& !content.is_empty()
{
// Emit output_item.added + content_part.added on first text
if !self.message_started {
self.message_started = true;
// Claim the next free output index for the message item.
self.message_output_index = self.next_output_index;
let output_index = self.message_output_index;
self.next_output_index += 1;
let item_added = ResponseStreamEvent::ResponseOutputItemAdded(
ResponseOutputItemAddedEvent {
sequence_number: self.next_seq(),
output_index,
item: OutputItem::Message(OutputMessage {
id: self.message_item_id.clone(),
content: vec![],
role: AssistantRole::Assistant,
status: OutputStatus::InProgress,
}),
},
);
events.push(make_sse_event(&item_added));
let part_added = ResponseStreamEvent::ResponseContentPartAdded(
ResponseContentPartAddedEvent {
sequence_number: self.next_seq(),
item_id: self.message_item_id.clone(),
output_index,
content_index: 0,
part: OutputContent::OutputText(OutputTextContent {
text: String::new(),
annotations: vec![],
logprobs: Some(vec![]),
}),
},
);
events.push(make_sse_event(&part_added));
}
// Emit text delta
self.accumulated_text.push_str(content);
let text_delta =
ResponseStreamEvent::ResponseOutputTextDelta(ResponseTextDeltaEvent {
sequence_number: self.next_seq(),
item_id: self.message_item_id.clone(),
output_index: self.message_output_index,
content_index: 0,
delta: content.to_string(),
logprobs: Some(vec![]),
});
events.push(make_sse_event(&text_delta));
}
// Handle tool call deltas
if let Some(tool_calls) = &delta.tool_calls {
for tc in tool_calls {
let tc_index = tc.index as usize;
// Start a new function call if we haven't seen this index
// Slots are created for every index up to and including `tc_index`, and each
// new slot reserves an output index immediately, before any event is emitted.
while self.function_call_items.len() <= tc_index {
let output_index = self.next_output_index;
self.next_output_index += 1;
self.function_call_items.push(FunctionCallState {
item_id: format!("fc_{}", Uuid::new_v4().simple()),
call_id: String::new(),
name: String::new(),
accumulated_args: String::new(),
output_index,
started: false,
});
}
// Update call_id and name if provided
if let Some(id) = &tc.id {
self.function_call_items[tc_index].call_id = id.clone();
}
if let Some(func) = &tc.function {
if let Some(name) = &func.name {
self.function_call_items[tc_index].name = name.clone();
}
if let Some(args) = &func.arguments {
// Emit output_item.added on first delta for this function call
if !self.function_call_items[tc_index].started {
self.function_call_items[tc_index].started = true;
// The added event carries whatever call_id and name are known at this point.
let item_id = self.function_call_items[tc_index].item_id.clone();
let call_id = self.function_call_items[tc_index].call_id.clone();
let fc_name = self.function_call_items[tc_index].name.clone();
let output_index = self.function_call_items[tc_index].output_index;
let seq = self.next_seq();
let item_added = ResponseStreamEvent::ResponseOutputItemAdded(
ResponseOutputItemAddedEvent {
sequence_number: seq,
output_index,
item: OutputItem::FunctionCall(FunctionToolCall {
id: Some(item_id),
call_id,
name: fc_name,
arguments: String::new(),
status: Some(OutputStatus::InProgress),
}),
},
);
events.push(make_sse_event(&item_added));
}
// Accumulate the fragment, then forward it as an arguments delta.
self.function_call_items[tc_index]
.accumulated_args
.push_str(args);
let item_id = self.function_call_items[tc_index].item_id.clone();
let output_index = self.function_call_items[tc_index].output_index;
let seq = self.next_seq();
let args_delta =
ResponseStreamEvent::ResponseFunctionCallArgumentsDelta(
ResponseFunctionCallArgumentsDeltaEvent {
sequence_number: seq,
item_id,
output_index,
delta: args.clone(),
},
);
events.push(make_sse_event(&args_delta));
}
}
}
}
}
events
}
/// Emit the final events when the stream ends: done events + completed.
pub fn emit_end_events(&mut self) -> Vec<Result<Event, anyhow::Error>> {
let mut events = Vec::new();
// Close text message if it was started
if self.message_started {
let text_done = ResponseStreamEvent::ResponseOutputTextDone(ResponseTextDoneEvent {
sequence_number: self.next_seq(),
item_id: self.message_item_id.clone(),
output_index: self.message_output_index,
content_index: 0,
text: self.accumulated_text.clone(),
logprobs: Some(vec![]),
});
events.push(make_sse_event(&text_done));
let part_done =
ResponseStreamEvent::ResponseContentPartDone(ResponseContentPartDoneEvent {
sequence_number: self.next_seq(),
item_id: self.message_item_id.clone(),
output_index: self.message_output_index,
content_index: 0,
part: OutputContent::OutputText(OutputTextContent {
text: self.accumulated_text.clone(),
annotations: vec![],
logprobs: Some(vec![]),
}),
});
events.push(make_sse_event(&part_done));
let item_done =
ResponseStreamEvent::ResponseOutputItemDone(ResponseOutputItemDoneEvent {
sequence_number: self.next_seq(),
output_index: self.message_output_index,
item: OutputItem::Message(OutputMessage {
id: self.message_item_id.clone(),
content: vec![OutputMessageContent::OutputText(OutputTextContent {
text: self.accumulated_text.clone(),
annotations: vec![],
logprobs: Some(vec![]),
})],
role: AssistantRole::Assistant,
status: OutputStatus::Completed,
}),
});
events.push(make_sse_event(&item_done));
}
// Close any function call items - collect data first to avoid borrow conflicts
let fc_data: Vec<_> = self
.function_call_items
.iter()
.filter(|fc| fc.started)
.map(|fc| {
(
fc.item_id.clone(),
fc.call_id.clone(),
fc.name.clone(),
fc.output_index,
fc.accumulated_args.clone(),
)
})
.collect();
for (item_id, call_id, fc_name, output_index, accumulated_args) in fc_data {
let args_done = ResponseStreamEvent::ResponseFunctionCallArgumentsDone(
ResponseFunctionCallArgumentsDoneEvent {
sequence_number: self.next_seq(),
item_id: item_id.clone(),
output_index,
arguments: accumulated_args.clone(),
name: Some(fc_name.clone()),
},
);
events.push(make_sse_event(&args_done));
let item_done =
ResponseStreamEvent::ResponseOutputItemDone(ResponseOutputItemDoneEvent {
sequence_number: self.next_seq(),
output_index,
item: OutputItem::FunctionCall(FunctionToolCall {
id: Some(item_id),
call_id,
name: fc_name,
arguments: accumulated_args,
status: Some(OutputStatus::Completed),
}),
});
events.push(make_sse_event(&item_done));
}
// Build the final output vector from accumulated state
let mut output = Vec::new();
if self.message_started {
output.push(OutputItem::Message(OutputMessage {
id: self.message_item_id.clone(),
content: vec![OutputMessageContent::OutputText(OutputTextContent {
text: self.accumulated_text.clone(),
annotations: vec![],
logprobs: Some(vec![]),
})],
role: AssistantRole::Assistant,
status: OutputStatus::Completed,
}));
}
for fc in &self.function_call_items {
if fc.started {
output.push(OutputItem::FunctionCall(FunctionToolCall {
id: Some(fc.item_id.clone()),
call_id: fc.call_id.clone(),
name: fc.name.clone(),
arguments: fc.accumulated_args.clone(),
status: Some(OutputStatus::Completed),
}));
}
}
// Emit response.completed
let completed = ResponseStreamEvent::ResponseCompleted(ResponseCompletedEvent {
sequence_number: self.next_seq(),
response: self.make_response(Status::Completed, output),
});
events.push(make_sse_event(&completed));
events
}
/// Emit error events when the stream ends due to a backend error.
///
/// Produces a single `response.failed` event whose `Response` has status
/// `Failed` and an empty output list.
pub fn emit_error_events(&mut self) -> Vec<Result<Event, anyhow::Error>> {
    let sequence_number = self.next_seq();
    let response = self.make_response(Status::Failed, Vec::new());
    let failed = ResponseStreamEvent::ResponseFailed(ResponseFailedEvent {
        sequence_number,
        response,
    });
    vec![make_sse_event(&failed)]
}
}
/// Serialize a stream event to JSON and wrap it in an SSE `Event`.
///
/// The SSE `event:` field is the name from `get_event_type`; the `data:` field is
/// the JSON payload. Returns an error if serialization fails.
fn make_sse_event(event: &ResponseStreamEvent) -> Result<Event, anyhow::Error> {
    let payload = serde_json::to_string(event)?;
    let sse = Event::default().event(get_event_type(event)).data(payload);
    Ok(sse)
}
fn get_event_type(event: &ResponseStreamEvent) -> &'static str {
match event {
ResponseStreamEvent::ResponseCreated(_) => "response.created",
ResponseStreamEvent::ResponseInProgress(_) => "response.in_progress",
ResponseStreamEvent::ResponseCompleted(_) => "response.completed",
ResponseStreamEvent::ResponseFailed(_) => "response.failed",
ResponseStreamEvent::ResponseIncomplete(_) => "response.incomplete",
ResponseStreamEvent::ResponseQueued(_) => "response.queued",
ResponseStreamEvent::ResponseOutputItemAdded(_) => "response.output_item.added",
ResponseStreamEvent::ResponseOutputItemDone(_) => "response.output_item.done",
ResponseStreamEvent::ResponseContentPartAdded(_) => "response.content_part.added",
ResponseStreamEvent::ResponseContentPartDone(_) => "response.content_part.done",
ResponseStreamEvent::ResponseOutputTextDelta(_) => "response.output_text.delta",
ResponseStreamEvent::ResponseOutputTextDone(_) => "response.output_text.done",
ResponseStreamEvent::ResponseRefusalDelta(_) => "response.refusal.delta",
ResponseStreamEvent::ResponseRefusalDone(_) => "response.refusal.done",
ResponseStreamEvent::ResponseFunctionCallArgumentsDelta(_) => {
"response.function_call_arguments.delta"
}
ResponseStreamEvent::ResponseFunctionCallArgumentsDone(_) => {
"response.function_call_arguments.done"
}
ResponseStreamEvent::ResponseFileSearchCallInProgress(_) => {
"response.file_search_call.in_progress"
}
ResponseStreamEvent::ResponseFileSearchCallSearching(_) => {
"response.file_search_call.searching"
}
ResponseStreamEvent::ResponseFileSearchCallCompleted(_) => {
"response.file_search_call.completed"
}
ResponseStreamEvent::ResponseWebSearchCallInProgress(_) => {
"response.web_search_call.in_progress"
}
ResponseStreamEvent::ResponseWebSearchCallSearching(_) => {
"response.web_search_call.searching"
}
ResponseStreamEvent::ResponseWebSearchCallCompleted(_) => {
"response.web_search_call.completed"
}
ResponseStreamEvent::ResponseReasoningSummaryPartAdded(_) => {
"response.reasoning_summary_part.added"
}
ResponseStreamEvent::ResponseReasoningSummaryPartDone(_) => {
"response.reasoning_summary_part.done"
}
ResponseStreamEvent::ResponseReasoningSummaryTextDelta(_) => {
"response.reasoning_summary_text.delta"
}
ResponseStreamEvent::ResponseReasoningSummaryTextDone(_) => {
"response.reasoning_summary_text.done"
}
ResponseStreamEvent::ResponseReasoningTextDelta(_) => "response.reasoning_text.delta",
ResponseStreamEvent::ResponseReasoningTextDone(_) => "response.reasoning_text.done",
ResponseStreamEvent::ResponseImageGenerationCallCompleted(_) => {
"response.image_generation_call.completed"
}
ResponseStreamEvent::ResponseImageGenerationCallGenerating(_) => {
"response.image_generation_call.generating"
}
ResponseStreamEvent::ResponseImageGenerationCallInProgress(_) => {
"response.image_generation_call.in_progress"
}
ResponseStreamEvent::ResponseImageGenerationCallPartialImage(_) => {
"response.image_generation_call.partial_image"
}
ResponseStreamEvent::ResponseMCPCallArgumentsDelta(_) => {
"response.mcp_call_arguments.delta"
}
ResponseStreamEvent::ResponseMCPCallArgumentsDone(_) => "response.mcp_call_arguments.done",
ResponseStreamEvent::ResponseMCPCallCompleted(_) => "response.mcp_call.completed",
ResponseStreamEvent::ResponseMCPCallFailed(_) => "response.mcp_call.failed",
ResponseStreamEvent::ResponseMCPCallInProgress(_) => "response.mcp_call.in_progress",
ResponseStreamEvent::ResponseMCPListToolsCompleted(_) => {
"response.mcp_list_tools.completed"
}
ResponseStreamEvent::ResponseMCPListToolsFailed(_) => "response.mcp_list_tools.failed",
ResponseStreamEvent::ResponseMCPListToolsInProgress(_) => {
"response.mcp_list_tools.in_progress"
}
ResponseStreamEvent::ResponseCodeInterpreterCallInProgress(_) => {
"response.code_interpreter_call.in_progress"
}
ResponseStreamEvent::ResponseCodeInterpreterCallInterpreting(_) => {
"response.code_interpreter_call.interpreting"
}
ResponseStreamEvent::ResponseCodeInterpreterCallCompleted(_) => {
"response.code_interpreter_call.completed"
}
ResponseStreamEvent::ResponseCodeInterpreterCallCodeDelta(_) => {
"response.code_interpreter_call_code.delta"
}
ResponseStreamEvent::ResponseCodeInterpreterCallCodeDone(_) => {
"response.code_interpreter_call_code.done"
}
ResponseStreamEvent::ResponseOutputTextAnnotationAdded(_) => {
"response.output_text.annotation.added"
}
ResponseStreamEvent::ResponseCustomToolCallInputDelta(_) => {
"response.custom_tool_call_input.delta"
}
ResponseStreamEvent::ResponseCustomToolCallInputDone(_) => {
"response.custom_tool_call_input.done"
}
ResponseStreamEvent::ResponseError(_) => "error",
}
}
...@@ -156,6 +156,7 @@ def run_serve_deployment( ...@@ -156,6 +156,7 @@ def run_serve_deployment(
payload=payload.body, payload=payload.body,
timeout=payload.timeout, timeout=payload.timeout,
method=payload.method, method=payload.method,
stream=payload.http_stream,
) )
server_process.check_response(payload, response) server_process.check_response(payload, response)
......
...@@ -23,6 +23,8 @@ from tests.utils.payload_builder import ( ...@@ -23,6 +23,8 @@ from tests.utils.payload_builder import (
embedding_payload, embedding_payload,
embedding_payload_default, embedding_payload_default,
metric_payload_default, metric_payload_default,
responses_payload_default,
responses_stream_payload_default,
) )
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
...@@ -60,6 +62,8 @@ sglang_configs = { ...@@ -60,6 +62,8 @@ sglang_configs = {
request_payloads=[ request_payloads=[
chat_payload_default(), chat_payload_default(),
completion_payload_default(), completion_payload_default(),
responses_payload_default(),
responses_stream_payload_default(),
metric_payload_default(min_num_requests=6, backend="sglang"), metric_payload_default(min_num_requests=6, backend="sglang"),
], ],
), ),
......
...@@ -52,6 +52,7 @@ def send_request( ...@@ -52,6 +52,7 @@ def send_request(
timeout: float = 30.0, timeout: float = 30.0,
method: str = "POST", method: str = "POST",
log_level: int = 20, log_level: int = 20,
stream: bool = False,
) -> requests.Response: ) -> requests.Response:
""" """
Send an HTTP request to the engine with detailed logging. Send an HTTP request to the engine with detailed logging.
...@@ -97,7 +98,7 @@ def send_request( ...@@ -97,7 +98,7 @@ def send_request(
if method_upper == "GET": if method_upper == "GET":
response = requests.get(url, params=payload, timeout=timeout) response = requests.get(url, params=payload, timeout=timeout)
elif method_upper == "POST": elif method_upper == "POST":
response = requests.post(url, json=payload, timeout=timeout) response = requests.post(url, json=payload, timeout=timeout, stream=stream)
else: else:
# Fallback for other methods if needed # Fallback for other methods if needed
response = requests.request( response = requests.request(
...@@ -117,20 +118,26 @@ def send_request( ...@@ -117,20 +118,26 @@ def send_request(
logger.debug("Response headers: %s", dict(response.headers)) logger.debug("Response headers: %s", dict(response.headers))
# Try to log response body (truncated if too long) # Try to log response body (truncated if too long)
try: # Skip body logging for streaming responses to avoid consuming the stream
if response.headers.get("content-type", "").startswith("application/json"): if stream:
response_data = response.json() logger.debug("Response body: <streaming, not logged>")
response_str = json.dumps(response_data, indent=2) else:
if len(response_str) > 1000: try:
response_str = response_str[:1000] + "... (truncated)" if response.headers.get("content-type", "").startswith(
logger.debug("Response body: %s", response_str) "application/json"
else: ):
response_text = response.text response_data = response.json()
if len(response_text) > 1000: response_str = json.dumps(response_data, indent=2)
response_text = response_text[:1000] + "... (truncated)" if len(response_str) > 1000:
logger.debug("Response body: %s", response_text) response_str = response_str[:1000] + "... (truncated)"
except Exception as e: logger.debug("Response body: %s", response_str)
logger.debug("Could not parse response body: %s", e) else:
response_text = response.text
if len(response_text) > 1000:
response_text = response_text[:1000] + "... (truncated)"
logger.debug("Response body: %s", response_text)
except Exception as e:
logger.debug("Could not parse response body: %s", e)
return response return response
......
...@@ -14,6 +14,8 @@ from tests.utils.payloads import ( ...@@ -14,6 +14,8 @@ from tests.utils.payloads import (
EmbeddingPayload, EmbeddingPayload,
LMCacheMetricsPayload, LMCacheMetricsPayload,
MetricsPayload, MetricsPayload,
ResponsesPayload,
ResponsesStreamPayload,
SGLangMetricsPayload, SGLangMetricsPayload,
TRTLLMMetricsPayload, TRTLLMMetricsPayload,
VLLMMetricsPayload, VLLMMetricsPayload,
...@@ -480,3 +482,52 @@ def completion_payload_with_logprobs( ...@@ -480,3 +482,52 @@ def completion_payload_with_logprobs(
expected_log=[], expected_log=[],
expected_response=expected_response or ["AI", "knock", "joke"], expected_response=expected_response or ["AI", "knock", "joke"],
) )
def responses_payload_default(
    repeat_count: int = 1,
    expected_response: Optional[List[str]] = None,
    expected_log: Optional[List[str]] = None,
    max_tokens: int = 200,
    temperature: float = 0.0,
) -> ResponsesPayload:
    """Build the default non-streaming Responses API payload.

    The request sends ``TEXT_PROMPT`` as ``input`` together with the given
    ``max_output_tokens`` and ``temperature``. A missing or empty
    ``expected_response`` / ``expected_log`` falls back to the built-in defaults.

    For full compliance testing, use the OpenResponses bun CLI instead.
    """
    if not expected_response:
        expected_response = [
            "AI",
            "knock",
            "joke",
            "think",
            "artificial",
            "intelligence",
        ]
    if not expected_log:
        expected_log = []

    request_body = {
        "input": TEXT_PROMPT,
        "max_output_tokens": max_tokens,
        "temperature": temperature,
    }
    return ResponsesPayload(
        body=request_body,
        repeat_count=repeat_count,
        expected_log=expected_log,
        expected_response=expected_response,
    )
def responses_stream_payload_default(
    repeat_count: int = 1,
    expected_response: Optional[List[str]] = None,
    expected_log: Optional[List[str]] = None,
    max_tokens: int = 200,
    temperature: float = 0.0,
) -> ResponsesStreamPayload:
    """Build the default streaming Responses API payload.

    Same request as the non-streaming default, plus ``"stream": True``. A missing
    or empty ``expected_response`` / ``expected_log`` falls back to the built-in
    defaults.

    For full compliance testing, use the OpenResponses bun CLI instead.
    """
    if not expected_response:
        expected_response = [
            "AI",
            "knock",
            "joke",
            "think",
            "artificial",
            "intelligence",
        ]
    if not expected_log:
        expected_log = []

    request_body = {
        "input": TEXT_PROMPT,
        "stream": True,
        "max_output_tokens": max_tokens,
        "temperature": temperature,
    }
    return ResponsesStreamPayload(
        body=request_body,
        repeat_count=repeat_count,
        expected_log=expected_log,
        expected_response=expected_response,
    )
...@@ -47,6 +47,8 @@ class BasePayload: ...@@ -47,6 +47,8 @@ class BasePayload:
# Optional additional ports used by specialized payloads (e.g. LoRA system/control-plane APIs). # Optional additional ports used by specialized payloads (e.g. LoRA system/control-plane APIs).
# This is intentionally empty by default to preserve prior semantics. # This is intentionally empty by default to preserve prior semantics.
system_ports: list[int] = field(default_factory=list) system_ports: list[int] = field(default_factory=list)
# When True, the HTTP request is made with stream=True (for SSE responses).
http_stream: bool = False
def url(self) -> str: def url(self) -> str:
ep = self.endpoint.lstrip("/") ep = self.endpoint.lstrip("/")
...@@ -485,6 +487,130 @@ class CompletionPayloadWithLogprobs(CompletionPayload): ...@@ -485,6 +487,130 @@ class CompletionPayloadWithLogprobs(CompletionPayload):
) )
@dataclass
class ResponsesPayload(BasePayload):
    """Non-streaming payload for the Responses API endpoint (/v1/responses).

    The handler validates the response envelope and the first output item, then
    returns the text of the first content part.

    For full compliance testing, use the OpenResponses bun CLI:
        bun run test:compliance --base-url http://localhost:<port>/v1 --api-key test --model <model>
    See https://www.openresponses.org/compliance
    """

    endpoint: str = "/v1/responses"

    @staticmethod
    def extract_content(response):
        """Validate a Responses API reply and return its first text part.

        Raises on a non-2xx status (via ``raise_for_status``) and asserts on the
        envelope (``object``, ``id`` prefix, ``status``) and on the shape of the
        first output item and its first content part.
        """
        response.raise_for_status()
        result = response.json()

        # Envelope: object kind, id prefix and terminal status.
        assert result.get("object") == "response", (
            f"Expected object='response', got {result.get('object')}"
        )
        assert result.get("id", "").startswith("resp_"), (
            f"Expected id to start with 'resp_', got {result.get('id')}"
        )
        assert result.get("status") == "completed", (
            f"Expected status='completed', got {result.get('status')}"
        )

        # First output item must be an assistant message.
        output = result.get("output", [])
        assert len(output) > 0, "Response output is empty"
        msg = output[0]
        assert msg.get("type") == "message", (
            f"Expected output[0].type='message', got {msg.get('type')}"
        )
        assert msg.get("role") == "assistant", (
            f"Expected role='assistant', got {msg.get('role')}"
        )

        # First content part must be output text.
        content_parts = msg.get("content", [])
        assert len(content_parts) > 0, "Message content is empty"
        assert content_parts[0].get("type") == "output_text", (
            f"Expected content[0].type='output_text', got {content_parts[0].get('type')}"
        )
        return content_parts[0].get("text", "")

    def response_handler(self, response: Any) -> str:
        """Delegate to ``extract_content``."""
        return ResponsesPayload.extract_content(response)
@dataclass
class ResponsesStreamPayload(BasePayload):
    """Streaming payload for the Responses API endpoint (/v1/responses).

    Validates SSE event structure and lifecycle ordering.
    """

    endpoint: str = "/v1/responses"
    # The HTTP request must be made with stream=True so the SSE body can be
    # consumed line by line.
    http_stream: bool = True

    @staticmethod
    def extract_content(response):
        """Parse the SSE stream, validate its event structure, return the text.

        Checks that the stream starts with ``response.created`` and
        ``response.in_progress``, ends (ignoring ``[DONE]``) with
        ``response.completed``, contains the full set of text item events, and
        that the concatenated ``response.output_text.delta`` payloads equal the
        text of the single ``response.output_text.done`` event.

        Returns:
            The concatenated delta text.

        Raises:
            AssertionError: if any structural check fails.
            requests.HTTPError: if the response status is not successful.
        """
        import json

        response.raise_for_status()

        events = []
        event_type = ""
        # Server-sent event streams are UTF-8. Read raw bytes and decode each
        # line explicitly: with decode_unicode=True requests decodes using
        # response.encoding, which falls back to ISO-8859-1 for text/* content
        # types without a charset. That garbles non-ASCII text and can split a
        # data line at bytes that decode to Unicode line breaks.
        for raw_line in response.iter_lines():
            if not raw_line:
                continue
            line = raw_line.decode("utf-8") if isinstance(raw_line, bytes) else raw_line
            if line.startswith("event: "):
                event_type = line[len("event: ") :]
            elif line.startswith("data: "):
                data_str = line[len("data: ") :]
                if data_str == "[DONE]":
                    events.append(("done", None))
                else:
                    events.append((event_type, json.loads(data_str)))

        event_types = [e[0] for e in events]

        # Validate lifecycle event ordering
        assert len(event_types) >= 2, f"Too few events: {event_types}"
        assert (
            event_types[0] == "response.created"
        ), f"First event should be response.created, got {event_types[0]}"
        assert (
            event_types[1] == "response.in_progress"
        ), f"Second event should be response.in_progress, got {event_types[1]}"

        non_done = [e for e in event_types if e != "done"]
        assert (
            non_done[-1] == "response.completed"
        ), f"Last real event should be response.completed, got {non_done[-1]}"

        # Validate text content events
        assert "response.output_item.added" in event_types, "Missing output_item.added"
        assert (
            "response.content_part.added" in event_types
        ), "Missing content_part.added"
        assert "response.output_text.delta" in event_types, "Missing output_text.delta"
        assert "response.output_text.done" in event_types, "Missing output_text.done"
        assert "response.content_part.done" in event_types, "Missing content_part.done"
        assert "response.output_item.done" in event_types, "Missing output_item.done"

        # Verify text deltas concatenate to the final text
        deltas = [e[1]["delta"] for e in events if e[0] == "response.output_text.delta"]
        done_events = [e for e in events if e[0] == "response.output_text.done"]
        assert (
            len(done_events) == 1
        ), f"Expected 1 output_text.done, got {len(done_events)}"
        full_text = "".join(deltas)
        assert (
            done_events[0][1]["text"] == full_text
        ), "Concatenated deltas don't match output_text.done text"

        return full_text

    def response_handler(self, response: Any) -> str:
        """Delegate to ``extract_content``."""
        return ResponsesStreamPayload.extract_content(response)
@dataclass @dataclass
class EmbeddingPayload(BasePayload): class EmbeddingPayload(BasePayload):
"""Payload for embeddings endpoint.""" """Payload for embeddings endpoint."""
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment