Unverified Commit 3bfee568 authored by MatejKosec's avatar MatejKosec Committed by GitHub
Browse files

feat: unified internal request representation for lossless API conversion (#7202)


Signed-off-by: default avatarMatej Kosec <mkosec@nvidia.com>
Signed-off-by: default avatarMarko Kosec <mkosec@nvidia.com>
parent 8fe2082c
...@@ -40,9 +40,10 @@ use crate::protocols::anthropic::types::{ ...@@ -40,9 +40,10 @@ use crate::protocols::anthropic::types::{
chat_completion_to_anthropic_response, chat_completion_to_anthropic_response,
}; };
use crate::protocols::openai::chat_completions::{ use crate::protocols::openai::chat_completions::{
NvCreateChatCompletionRequest, NvCreateChatCompletionResponse, NvCreateChatCompletionResponse, NvCreateChatCompletionStreamResponse,
NvCreateChatCompletionStreamResponse, aggregator::ChatCompletionAggregator, aggregator::ChatCompletionAggregator,
}; };
use crate::protocols::unified::UnifiedRequest;
use crate::request_template::RequestTemplate; use crate::request_template::RequestTemplate;
use crate::types::Annotated; use crate::types::Annotated;
...@@ -213,20 +214,25 @@ async fn anthropic_messages( ...@@ -213,20 +214,25 @@ async fn anthropic_messages(
.as_ref() .as_ref()
.is_some_and(|t| t.thinking_type == "disabled"); .is_some_and(|t| t.thinking_type == "disabled");
// Convert Anthropic request -> Chat Completion request // Convert Anthropic request -> UnifiedRequest -> Chat Completion request
let mut chat_request: NvCreateChatCompletionRequest = let unified_request: UnifiedRequest = orig_request.try_into().map_err(|e: anyhow::Error| {
orig_request.try_into().map_err(|e: anyhow::Error| { tracing::error!(
tracing::error!( request_id,
request_id, error = %e,
error = %e, "Failed to convert AnthropicCreateMessageRequest to UnifiedRequest",
"Failed to convert AnthropicCreateMessageRequest to NvCreateChatCompletionRequest", );
); anthropic_error(
anthropic_error( StatusCode::BAD_REQUEST,
StatusCode::BAD_REQUEST, "invalid_request_error",
"invalid_request_error", &format!("Failed to convert request: {}", e),
&format!("Failed to convert request: {}", e), )
) })?;
})?;
// Extract the API context before consuming the UnifiedRequest — this
// carries Anthropic-specific fields (thinking config, cache breakpoints,
// etc.) that the stream converter needs for faithful response reconstruction.
let anthropic_ctx = unified_request.anthropic_context().cloned();
let mut chat_request = unified_request.into_inner();
// When a reasoning parser is configured and the client hasn't explicitly // When a reasoning parser is configured and the client hasn't explicitly
// disabled thinking, assume the model's chat template will inject `<think>`. // disabled thinking, assume the model's chat template will inject `<think>`.
...@@ -309,7 +315,10 @@ async fn anthropic_messages( ...@@ -309,7 +315,10 @@ async fn anthropic_messages(
use std::sync::atomic::{AtomicBool, Ordering}; use std::sync::atomic::{AtomicBool, Ordering};
let mut converter = AnthropicStreamConverter::new(model_for_resp); let mut converter = match anthropic_ctx {
Some(ctx) => AnthropicStreamConverter::with_context(model_for_resp, ctx),
None => AnthropicStreamConverter::new(model_for_resp),
};
let start_events = converter.emit_start_events(); let start_events = converter.emit_start_events();
let converter = std::sync::Arc::new(std::sync::Mutex::new(converter)); let converter = std::sync::Arc::new(std::sync::Mutex::new(converter));
...@@ -406,7 +415,11 @@ async fn anthropic_messages( ...@@ -406,7 +415,11 @@ async fn anthropic_messages(
) )
})?; })?;
let response = chat_completion_to_anthropic_response(chat_response, &model_for_resp); let response = chat_completion_to_anthropic_response(
chat_response,
&model_for_resp,
anthropic_ctx.as_ref(),
);
inflight_guard.mark_ok(); inflight_guard.mark_ok();
......
...@@ -57,6 +57,7 @@ use crate::protocols::openai::{ ...@@ -57,6 +57,7 @@ use crate::protocols::openai::{
responses::{NvCreateResponse, NvResponse, ResponseParams, chat_completion_to_response}, responses::{NvCreateResponse, NvResponse, ResponseParams, chat_completion_to_response},
videos::{NvCreateVideoRequest, NvVideosResponse}, videos::{NvCreateVideoRequest, NvVideosResponse},
}; };
use crate::protocols::unified::UnifiedRequest;
use crate::request_template::RequestTemplate; use crate::request_template::RequestTemplate;
use crate::types::Annotated; use crate::types::Annotated;
use dynamo_runtime::logging::get_distributed_tracing_context; use dynamo_runtime::logging::get_distributed_tracing_context;
...@@ -1513,21 +1514,25 @@ async fn responses( ...@@ -1513,21 +1514,25 @@ async fn responses(
let request_id = request.id().to_string(); let request_id = request.id().to_string();
let (orig_request, context) = request.into_parts(); let (orig_request, context) = request.into_parts();
let mut chat_request: NvCreateChatCompletionRequest = let unified_request: UnifiedRequest = orig_request.try_into().map_err(|e: anyhow::Error| {
orig_request.try_into().map_err(|e: anyhow::Error| { tracing::error!(
tracing::error!( request_id,
request_id, error = %e,
error = %e, "Failed to convert NvCreateResponse to UnifiedRequest",
"Failed to convert NvCreateResponse to NvCreateChatCompletionRequest", );
); let err_response = ErrorMessage::not_implemented_error(
let err_response = ErrorMessage::not_implemented_error( VALIDATION_PREFIX.to_string()
VALIDATION_PREFIX.to_string() + "Failed to convert responses request: "
+ "Failed to convert responses request: " + &e.to_string(),
+ &e.to_string(), );
); inflight_guard.mark_error(extract_error_type_from_response(&err_response));
inflight_guard.mark_error(extract_error_type_from_response(&err_response)); err_response
err_response })?;
})?; // Extract the API context before consuming the UnifiedRequest — this
// carries Responses-specific fields (previous_response_id, store, etc.)
// that the stream converter needs for faithful response reconstruction.
let responses_ctx = unified_request.responses_context().cloned();
let mut chat_request = unified_request.into_inner();
// Always use internal streaming for aggregation. // Always use internal streaming for aggregation.
// Set stream_options.include_usage so the backend sends token counts in the final chunk. // Set stream_options.include_usage so the backend sends token counts in the final chunk.
...@@ -1577,7 +1582,10 @@ async fn responses( ...@@ -1577,7 +1582,10 @@ async fn responses(
use crate::protocols::openai::responses::stream_converter::ResponseStreamConverter; use crate::protocols::openai::responses::stream_converter::ResponseStreamConverter;
use std::sync::atomic::{AtomicBool, Ordering}; use std::sync::atomic::{AtomicBool, Ordering};
let mut converter = ResponseStreamConverter::new(model.clone(), response_params); let mut converter = match responses_ctx {
Some(ctx) => ResponseStreamConverter::with_context(model.clone(), response_params, ctx),
None => ResponseStreamConverter::new(model.clone(), response_params),
};
let start_events = converter.emit_start_events(); let start_events = converter.emit_start_events();
// Use std::sync::Mutex (not tokio) since process_chunk/emit_end_events are // Use std::sync::Mutex (not tokio) since process_chunk/emit_end_events are
...@@ -1685,18 +1693,19 @@ async fn responses( ...@@ -1685,18 +1693,19 @@ async fn responses(
})?; })?;
// Convert NvCreateChatCompletionResponse --> NvResponse // Convert NvCreateChatCompletionResponse --> NvResponse
let response: NvResponse = chat_completion_to_response(response, &response_params) let response: NvResponse =
.map_err(|e| { chat_completion_to_response(response, &response_params, responses_ctx.as_ref())
tracing::error!( .map_err(|e| {
request_id, tracing::error!(
"Failed to convert NvCreateChatCompletionResponse to NvResponse: {:?}", request_id,
e "Failed to convert NvCreateChatCompletionResponse to NvResponse: {:?}",
); e
let err_response = );
ErrorMessage::internal_server_error("Failed to convert internal response"); let err_response =
inflight_guard.mark_error(extract_error_type_from_response(&err_response)); ErrorMessage::internal_server_error("Failed to convert internal response");
err_response inflight_guard.mark_error(extract_error_type_from_response(&err_response));
})?; err_response
})?;
inflight_guard.mark_ok(); inflight_guard.mark_ok();
// If the engine context was killed (client disconnect), the response was // If the engine context was killed (client disconnect), the response was
......
...@@ -15,6 +15,7 @@ pub mod codec; ...@@ -15,6 +15,7 @@ pub mod codec;
pub mod common; pub mod common;
pub mod openai; pub mod openai;
pub mod tensor; pub mod tensor;
pub(crate) mod unified;
/// The token ID type /// The token ID type
pub type TokenIdType = u32; pub type TokenIdType = u32;
......
...@@ -18,11 +18,14 @@ use super::types::{ ...@@ -18,11 +18,14 @@ use super::types::{
AnthropicResponseContentBlock, AnthropicStopReason, AnthropicStreamEvent, AnthropicUsage, AnthropicResponseContentBlock, AnthropicStopReason, AnthropicStreamEvent, AnthropicUsage,
}; };
use crate::protocols::openai::chat_completions::NvCreateChatCompletionStreamResponse; use crate::protocols::openai::chat_completions::NvCreateChatCompletionStreamResponse;
use crate::protocols::unified::AnthropicContext;
/// State machine that converts a chat completion stream into Anthropic SSE events. /// State machine that converts a chat completion stream into Anthropic SSE events.
pub struct AnthropicStreamConverter { pub struct AnthropicStreamConverter {
model: String, model: String,
message_id: String, message_id: String,
/// Preserved Anthropic-specific request context for faithful response reconstruction.
api_context: Option<AnthropicContext>,
// Thinking/reasoning tracking // Thinking/reasoning tracking
thinking_block_started: bool, thinking_block_started: bool,
thinking_block_closed: bool, thinking_block_closed: bool,
...@@ -60,6 +63,7 @@ impl AnthropicStreamConverter { ...@@ -60,6 +63,7 @@ impl AnthropicStreamConverter {
Self { Self {
model, model,
message_id: format!("msg_{}", Uuid::new_v4().simple()), message_id: format!("msg_{}", Uuid::new_v4().simple()),
api_context: None,
thinking_block_started: false, thinking_block_started: false,
thinking_block_closed: false, thinking_block_closed: false,
thinking_block_index: 0, thinking_block_index: 0,
...@@ -76,8 +80,19 @@ impl AnthropicStreamConverter { ...@@ -76,8 +80,19 @@ impl AnthropicStreamConverter {
} }
} }
/// Create a converter seeded with the original Anthropic request context.
/// This allows the response stream to carry forward metadata that was lost
/// during the Anthropic-to-OpenAI request conversion.
pub fn with_context(model: String, context: AnthropicContext) -> Self {
let mut converter = Self::new(model);
converter.api_context = Some(context);
converter
}
/// Emit the initial `message_start` event. /// Emit the initial `message_start` event.
pub fn emit_start_events(&mut self) -> Vec<Result<Event, anyhow::Error>> { pub fn emit_start_events(&mut self) -> Vec<Result<Event, anyhow::Error>> {
// TODO: When AnthropicMessageResponse gains a `service_tier` field,
// populate it from `self.api_context` (if the original request specified one).
let message = AnthropicMessageResponse { let message = AnthropicMessageResponse {
id: self.message_id.clone(), id: self.message_id.clone(),
object_type: "message".to_string(), object_type: "message".to_string(),
...@@ -182,6 +197,11 @@ impl AnthropicStreamConverter { ...@@ -182,6 +197,11 @@ impl AnthropicStreamConverter {
// Emit signature delta to close the thinking block. // Emit signature delta to close the thinking block.
// The engine doesn't produce Anthropic-style cryptographic signatures, // The engine doesn't produce Anthropic-style cryptographic signatures,
// so we use "erased" (the standard placeholder per the Anthropic spec). // so we use "erased" (the standard placeholder per the Anthropic spec).
// When `api_context` is available and the original request had
// `thinking.thinking_type == "enabled"`, this is expected — the backend
// simply doesn't generate real signatures. If/when the backend starts
// returning real signatures, we can use the context to validate or
// pass them through instead of hardcoding "erased".
let sig_delta = AnthropicStreamEvent::ContentBlockDelta { let sig_delta = AnthropicStreamEvent::ContentBlockDelta {
index: self.thinking_block_index, index: self.thinking_block_index,
delta: AnthropicDelta::SignatureDelta { delta: AnthropicDelta::SignatureDelta {
...@@ -1071,4 +1091,35 @@ mod tests { ...@@ -1071,4 +1091,35 @@ mod tests {
"no block stops in end events" "no block stops in end events"
); );
} }
/// Verify that `with_context` stores the context and produces the same
/// event structure as `new` — the context is carried for future enrichment.
#[test]
fn test_with_context_preserves_context() {
use crate::protocols::unified::AnthropicContext;
let ctx = AnthropicContext {
service_tier: Some("priority".to_string()),
..Default::default()
};
let mut conv = AnthropicStreamConverter::with_context("test-model".into(), ctx);
assert!(conv.api_context.is_some());
assert_eq!(
conv.api_context.as_ref().unwrap().service_tier.as_deref(),
Some("priority")
);
// Should produce the same events as a regular converter
let ev = conv.process_chunk_tagged(&text_chunk("Hello"));
assert_eq!(
event_types(&ev),
vec!["content_block_start", "content_block_delta"]
);
let end = conv.emit_end_events_tagged();
assert_eq!(
event_types(&end),
vec!["content_block_stop", "message_delta", "message_stop"]
);
}
} }
...@@ -120,7 +120,10 @@ impl TryFrom<AnthropicCreateMessageRequest> for NvCreateChatCompletionRequest { ...@@ -120,7 +120,10 @@ impl TryFrom<AnthropicCreateMessageRequest> for NvCreateChatCompletionRequest {
..Default::default() ..Default::default()
}, },
nvext: { nvext: {
// Collect per-block cache_control: use the last one found // Lossy: collapse all per-block cache_control into a single
// last-one-wins value. Sufficient for backends with a single
// prefix cache boundary. Full per-block breakpoints are
// preserved in AnthropicContext::cache_breakpoints via UnifiedRequest.
let mut last_block_cc: Option<CacheControl> = None; let mut last_block_cc: Option<CacheControl> = None;
for msg in &req.messages { for msg in &req.messages {
if let AnthropicMessageContent::Blocks { content } = &msg.content { if let AnthropicMessageContent::Blocks { content } = &msg.content {
...@@ -472,7 +475,9 @@ fn convert_anthropic_tool_choice(tc: &AnthropicToolChoice) -> ChatCompletionTool ...@@ -472,7 +475,9 @@ fn convert_anthropic_tool_choice(tc: &AnthropicToolChoice) -> ChatCompletionTool
pub fn chat_completion_to_anthropic_response( pub fn chat_completion_to_anthropic_response(
chat_resp: NvCreateChatCompletionResponse, chat_resp: NvCreateChatCompletionResponse,
model: &str, model: &str,
api_context: Option<&crate::protocols::unified::AnthropicContext>,
) -> AnthropicMessageResponse { ) -> AnthropicMessageResponse {
let _ = api_context; // Available for future enrichment (service_tier, etc.)
let msg_id = format!("msg_{}", Uuid::new_v4().simple()); let msg_id = format!("msg_{}", Uuid::new_v4().simple());
let choice = chat_resp.inner.choices.into_iter().next(); let choice = chat_resp.inner.choices.into_iter().next();
...@@ -853,7 +858,7 @@ mod tests { ...@@ -853,7 +858,7 @@ mod tests {
nvext: None, nvext: None,
}; };
let response = chat_completion_to_anthropic_response(chat_resp, "test-model"); let response = chat_completion_to_anthropic_response(chat_resp, "test-model", None);
assert!(response.id.starts_with("msg_")); assert!(response.id.starts_with("msg_"));
assert_eq!(response.object_type, "message"); assert_eq!(response.object_type, "message");
assert_eq!(response.role, "assistant"); assert_eq!(response.role, "assistant");
......
...@@ -37,7 +37,7 @@ pub struct AnnotatedDelta<R> { ...@@ -37,7 +37,7 @@ pub struct AnnotatedDelta<R> {
pub comment: Option<String>, pub comment: Option<String>,
} }
trait OpenAISamplingOptionsProvider { pub(crate) trait OpenAISamplingOptionsProvider {
fn get_temperature(&self) -> Option<f32>; fn get_temperature(&self) -> Option<f32>;
fn get_top_p(&self) -> Option<f32>; fn get_top_p(&self) -> Option<f32>;
...@@ -55,7 +55,7 @@ trait OpenAISamplingOptionsProvider { ...@@ -55,7 +55,7 @@ trait OpenAISamplingOptionsProvider {
fn nvext(&self) -> Option<&nvext::NvExt>; fn nvext(&self) -> Option<&nvext::NvExt>;
} }
trait OpenAIStopConditionsProvider { pub(crate) trait OpenAIStopConditionsProvider {
fn get_max_tokens(&self) -> Option<u32>; fn get_max_tokens(&self) -> Option<u32>;
fn get_min_tokens(&self) -> Option<u32>; fn get_min_tokens(&self) -> Option<u32>;
...@@ -82,7 +82,7 @@ trait OpenAIStopConditionsProvider { ...@@ -82,7 +82,7 @@ trait OpenAIStopConditionsProvider {
} }
} }
trait OpenAIOutputOptionsProvider { pub(crate) trait OpenAIOutputOptionsProvider {
fn get_logprobs(&self) -> Option<u32>; fn get_logprobs(&self) -> Option<u32>;
fn get_prompt_logprobs(&self) -> Option<u32>; fn get_prompt_logprobs(&self) -> Option<u32>;
......
...@@ -695,6 +695,7 @@ fn make_function_call(name: String, arguments: String) -> OutputItem { ...@@ -695,6 +695,7 @@ fn make_function_call(name: String, arguments: String) -> OutputItem {
pub fn chat_completion_to_response( pub fn chat_completion_to_response(
nv_resp: NvCreateChatCompletionResponse, nv_resp: NvCreateChatCompletionResponse,
params: &ResponseParams, params: &ResponseParams,
api_context: Option<&crate::protocols::unified::ResponsesContext>,
) -> Result<NvResponse, anyhow::Error> { ) -> Result<NvResponse, anyhow::Error> {
let nvext = nv_resp.nvext.clone(); let nvext = nv_resp.nvext.clone();
let chat_resp = nv_resp.inner; let chat_resp = nv_resp.inner;
...@@ -814,7 +815,10 @@ pub fn chat_completion_to_response( ...@@ -814,7 +815,10 @@ pub fn chat_completion_to_response(
presence_penalty: Some(0.0), presence_penalty: Some(0.0),
// Echo actual request values, falling back to spec defaults. // Echo actual request values, falling back to spec defaults.
// store: false because this branch does not persist responses. // store: false because this branch does not persist responses.
store: params.store.or(Some(false)), store: api_context
.map(|ctx| ctx.store)
.or(params.store)
.or(Some(false)),
temperature: params.temperature.or(Some(1.0)), temperature: params.temperature.or(Some(1.0)),
text: Some(params.text.clone().unwrap_or(ResponseTextParam { text: Some(params.text.clone().unwrap_or(ResponseTextParam {
format: TextResponseFormatConfiguration::Text, format: TextResponseFormatConfiguration::Text,
...@@ -841,7 +845,7 @@ pub fn chat_completion_to_response( ...@@ -841,7 +845,7 @@ pub fn chat_completion_to_response(
instructions: params.instructions.clone().map(Instructions::Text), instructions: params.instructions.clone().map(Instructions::Text),
max_output_tokens: params.max_output_tokens, max_output_tokens: params.max_output_tokens,
max_tool_calls: None, max_tool_calls: None,
previous_response_id: None, previous_response_id: api_context.and_then(|ctx| ctx.previous_response_id.clone()),
prompt: None, prompt: None,
prompt_cache_key: None, prompt_cache_key: None,
prompt_cache_retention: None, prompt_cache_retention: None,
...@@ -1194,7 +1198,8 @@ mod tests { ...@@ -1194,7 +1198,8 @@ mod tests {
nvext: None, nvext: None,
}; };
let wrapped = chat_completion_to_response(chat_resp, &ResponseParams::default()).unwrap(); let wrapped =
chat_completion_to_response(chat_resp, &ResponseParams::default(), None).unwrap();
assert_eq!(wrapped.inner.model, "llama-3.1-8b-instruct"); assert_eq!(wrapped.inner.model, "llama-3.1-8b-instruct");
assert_eq!(wrapped.inner.status, Status::Completed); assert_eq!(wrapped.inner.status, Status::Completed);
...@@ -1254,7 +1259,8 @@ mod tests { ...@@ -1254,7 +1259,8 @@ mod tests {
nvext: None, nvext: None,
}; };
let wrapped = chat_completion_to_response(chat_resp, &ResponseParams::default()).unwrap(); let wrapped =
chat_completion_to_response(chat_resp, &ResponseParams::default(), None).unwrap();
assert_eq!(wrapped.inner.output.len(), 1); assert_eq!(wrapped.inner.output.len(), 1);
match &wrapped.inner.output[0] { match &wrapped.inner.output[0] {
OutputItem::FunctionCall(fc) => { OutputItem::FunctionCall(fc) => {
...@@ -1449,7 +1455,7 @@ thinking ...@@ -1449,7 +1455,7 @@ thinking
nvext: None, nvext: None,
}; };
let resp = chat_completion_to_response(chat_resp, &params).unwrap(); let resp = chat_completion_to_response(chat_resp, &params, None).unwrap();
let reasoning = resp.inner.reasoning.unwrap(); let reasoning = resp.inner.reasoning.unwrap();
assert_eq!(reasoning.effort, Some(ReasoningEffort::High)); assert_eq!(reasoning.effort, Some(ReasoningEffort::High));
} }
...@@ -1482,7 +1488,7 @@ thinking ...@@ -1482,7 +1488,7 @@ thinking
nvext: None, nvext: None,
}; };
let resp = chat_completion_to_response(chat_resp, &params).unwrap(); let resp = chat_completion_to_response(chat_resp, &params, None).unwrap();
let text = resp.inner.text.unwrap(); let text = resp.inner.text.unwrap();
assert_eq!(text.format, TextResponseFormatConfiguration::JsonObject); assert_eq!(text.format, TextResponseFormatConfiguration::JsonObject);
} }
...@@ -1510,7 +1516,7 @@ thinking ...@@ -1510,7 +1516,7 @@ thinking
nvext: None, nvext: None,
}; };
let resp = chat_completion_to_response(chat_resp, &params).unwrap(); let resp = chat_completion_to_response(chat_resp, &params, None).unwrap();
assert_eq!(resp.inner.service_tier, Some(ServiceTier::Flex)); assert_eq!(resp.inner.service_tier, Some(ServiceTier::Flex));
} }
...@@ -1598,7 +1604,7 @@ thinking ...@@ -1598,7 +1604,7 @@ thinking
fn test_include_logprobs_stripped_by_default() { fn test_include_logprobs_stripped_by_default() {
let chat_resp = make_chat_resp_with_text("hello"); let chat_resp = make_chat_resp_with_text("hello");
let params = ResponseParams::default(); let params = ResponseParams::default();
let resp = chat_completion_to_response(chat_resp, &params).unwrap(); let resp = chat_completion_to_response(chat_resp, &params, None).unwrap();
for item in &resp.inner.output { for item in &resp.inner.output {
if let OutputItem::Message(msg) = item { if let OutputItem::Message(msg) = item {
...@@ -1623,7 +1629,7 @@ thinking ...@@ -1623,7 +1629,7 @@ thinking
include: Some(vec![IncludeEnum::MessageOutputTextLogprobs]), include: Some(vec![IncludeEnum::MessageOutputTextLogprobs]),
..Default::default() ..Default::default()
}; };
let resp = chat_completion_to_response(chat_resp, &params).unwrap(); let resp = chat_completion_to_response(chat_resp, &params, None).unwrap();
let mut found_text = false; let mut found_text = false;
for item in &resp.inner.output { for item in &resp.inner.output {
...@@ -1651,7 +1657,7 @@ thinking ...@@ -1651,7 +1657,7 @@ thinking
truncation: Some(Truncation::Auto), truncation: Some(Truncation::Auto),
..Default::default() ..Default::default()
}; };
let resp = chat_completion_to_response(chat_resp, &params).unwrap(); let resp = chat_completion_to_response(chat_resp, &params, None).unwrap();
assert_eq!(resp.inner.truncation, Some(Truncation::Auto)); assert_eq!(resp.inner.truncation, Some(Truncation::Auto));
} }
...@@ -1659,7 +1665,7 @@ thinking ...@@ -1659,7 +1665,7 @@ thinking
fn test_truncation_defaults_to_disabled() { fn test_truncation_defaults_to_disabled() {
let chat_resp = make_chat_resp_with_text("hello"); let chat_resp = make_chat_resp_with_text("hello");
let params = ResponseParams::default(); let params = ResponseParams::default();
let resp = chat_completion_to_response(chat_resp, &params).unwrap(); let resp = chat_completion_to_response(chat_resp, &params, None).unwrap();
assert_eq!(resp.inner.truncation, Some(Truncation::Disabled)); assert_eq!(resp.inner.truncation, Some(Truncation::Disabled));
} }
} }
...@@ -28,12 +28,15 @@ use dynamo_async_openai::types::ChatCompletionMessageContent; ...@@ -28,12 +28,15 @@ use dynamo_async_openai::types::ChatCompletionMessageContent;
use super::ResponseParams; use super::ResponseParams;
use crate::protocols::openai::chat_completions::NvCreateChatCompletionStreamResponse; use crate::protocols::openai::chat_completions::NvCreateChatCompletionStreamResponse;
use crate::protocols::unified::ResponsesContext;
/// State machine that converts a chat completion stream into Responses API events. /// State machine that converts a chat completion stream into Responses API events.
pub struct ResponseStreamConverter { pub struct ResponseStreamConverter {
response_id: String, response_id: String,
model: String, model: String,
params: ResponseParams, params: ResponseParams,
/// Preserved Responses API-specific request context for faithful response reconstruction.
api_context: Option<ResponsesContext>,
created_at: u64, created_at: u64,
sequence_number: u64, sequence_number: u64,
// Text message tracking // Text message tracking
...@@ -72,6 +75,7 @@ impl ResponseStreamConverter { ...@@ -72,6 +75,7 @@ impl ResponseStreamConverter {
response_id: format!("resp_{}", Uuid::new_v4().simple()), response_id: format!("resp_{}", Uuid::new_v4().simple()),
model, model,
params, params,
api_context: None,
created_at, created_at,
sequence_number: 0, sequence_number: 0,
message_item_id: format!("msg_{}", Uuid::new_v4().simple()), message_item_id: format!("msg_{}", Uuid::new_v4().simple()),
...@@ -84,6 +88,12 @@ impl ResponseStreamConverter { ...@@ -84,6 +88,12 @@ impl ResponseStreamConverter {
} }
} }
pub fn with_context(model: String, params: ResponseParams, context: ResponsesContext) -> Self {
let mut converter = Self::new(model, params);
converter.api_context = Some(context);
converter
}
fn next_seq(&mut self) -> u64 { fn next_seq(&mut self) -> u64 {
let seq = self.sequence_number; let seq = self.sequence_number;
self.sequence_number += 1; self.sequence_number += 1;
...@@ -116,7 +126,12 @@ impl ResponseStreamConverter { ...@@ -116,7 +126,12 @@ impl ResponseStreamConverter {
parallel_tool_calls: Some(true), parallel_tool_calls: Some(true),
presence_penalty: Some(0.0), presence_penalty: Some(0.0),
// store: false because this branch does not persist responses. // store: false because this branch does not persist responses.
store: self.params.store.or(Some(false)), store: self
.api_context
.as_ref()
.map(|ctx| ctx.store)
.or(self.params.store)
.or(Some(false)),
temperature: self.params.temperature.or(Some(1.0)), temperature: self.params.temperature.or(Some(1.0)),
text: Some(self.params.text.clone().unwrap_or(ResponseTextParam { text: Some(self.params.text.clone().unwrap_or(ResponseTextParam {
format: TextResponseFormatConfiguration::Text, format: TextResponseFormatConfiguration::Text,
...@@ -144,7 +159,10 @@ impl ResponseStreamConverter { ...@@ -144,7 +159,10 @@ impl ResponseStreamConverter {
instructions: self.params.instructions.clone().map(Instructions::Text), instructions: self.params.instructions.clone().map(Instructions::Text),
max_output_tokens: self.params.max_output_tokens, max_output_tokens: self.params.max_output_tokens,
max_tool_calls: None, max_tool_calls: None,
previous_response_id: None, previous_response_id: self
.api_context
.as_ref()
.and_then(|ctx| ctx.previous_response_id.clone()),
prompt: None, prompt: None,
prompt_cache_key: None, prompt_cache_key: None,
prompt_cache_retention: None, prompt_cache_retention: None,
...@@ -654,6 +672,7 @@ fn get_event_type(event: &ResponseStreamEvent) -> &'static str { ...@@ -654,6 +672,7 @@ fn get_event_type(event: &ResponseStreamEvent) -> &'static str {
#[cfg(test)] #[cfg(test)]
mod tests { mod tests {
use super::*; use super::*;
use crate::protocols::unified::ResponsesContext;
use dynamo_async_openai::types::{ use dynamo_async_openai::types::{
ChatChoiceStream, ChatCompletionMessageContent, ChatCompletionMessageToolCallChunk, ChatChoiceStream, ChatCompletionMessageContent, ChatCompletionMessageToolCallChunk,
ChatCompletionStreamResponseDelta, ChatCompletionToolType, FunctionCallStream, ChatCompletionStreamResponseDelta, ChatCompletionToolType, FunctionCallStream,
...@@ -912,4 +931,41 @@ mod tests { ...@@ -912,4 +931,41 @@ mod tests {
"output_item.done inline after text: {tool_types:?}" "output_item.done inline after text: {tool_types:?}"
); );
} }
/// Verify that `with_context` populates `previous_response_id` and `store`
/// in the generated Response objects.
#[test]
fn test_with_context_enriches_response() {
let ctx = ResponsesContext {
previous_response_id: Some("resp_prev_123".to_string()),
store: true,
..Default::default()
};
let params = ResponseParams::default();
let mut conv = ResponseStreamConverter::with_context("test-model".into(), params, ctx);
// Process one text chunk so there's output
let _ = conv.emit_start_events();
let _ = conv.process_chunk(&text_chunk("Hello"));
let _end_events = conv.emit_end_events();
// Verify the Response object carries the context values through
let response = conv.make_response(Status::Completed, vec![]);
assert_eq!(
response.previous_response_id.as_deref(),
Some("resp_prev_123")
);
assert_eq!(response.store, Some(true));
}
/// Without context, previous_response_id is None and store defaults to false.
#[test]
fn test_without_context_defaults() {
let params = ResponseParams::default();
let conv = ResponseStreamConverter::new("test-model".into(), params);
let response = conv.make_response(Status::Completed, vec![]);
assert_eq!(response.previous_response_id, None);
assert_eq!(response.store, Some(false));
}
} }
This diff is collapsed.
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment