"vllm/vscode:/vscode.git/clone" did not exist on "ac0675ff6b40768293ef4b87741b161f6cf4518b"
Unverified commit 9b2b44e3, authored by MatejKosec and committed by GitHub
Browse files

fix(responses): accept assistant output_text messages without id/status in input (#6599)


Signed-off-by: Marko Kosec <mkosec@nvidia.com>
Signed-off-by: Matej Kosec <mkosec@nvidia.com>
Signed-off-by: Vasilis Vagias <vvagias@nvidia.com>
Co-authored-by: vvagias <vasilis.n.vagias@gmail.com>
Co-authored-by: ishandhanani <82981111+ishandhanani@users.noreply.github.com>
parent abc02c68
......@@ -24,10 +24,11 @@ pub enum Role {
}
/// Status of input/output items.
#[derive(Debug, Serialize, Deserialize, Clone, Copy, PartialEq, ToSchema)]
#[derive(Debug, Serialize, Deserialize, Clone, Copy, PartialEq, Default, ToSchema)]
#[serde(rename_all = "snake_case")]
pub enum OutputStatus {
InProgress,
#[default]
Completed,
Incomplete,
}
......@@ -367,6 +368,8 @@ pub struct CustomToolCallOutput {
#[builder(build_fn(error = "OpenAIError"))]
pub struct EasyInputMessage {
/// The type of the message input. Always set to `message`.
/// Optional in the "easy" format — defaults to `message` when omitted.
#[serde(default)]
pub r#type: MessageType,
/// The role of the message input. One of `user`, `assistant`, `system`, or `developer`.
pub role: Role,
......@@ -423,6 +426,7 @@ pub enum EasyInputContent {
}
/// Parts of a message: text, image, file, or audio.
/// Also accepts `output_text` for replaying assistant turns in the "easy" input format.
#[derive(Debug, Serialize, Deserialize, Clone, PartialEq, ToSchema)]
#[serde(tag = "type", rename_all = "snake_case")]
pub enum InputContent {
......@@ -437,6 +441,11 @@ pub enum InputContent {
InputVideo(InputVideoContent),
/// An audio input to the model.
InputAudio(InputAudioContent),
/// An output text content item, accepted when replaying assistant messages
/// in the "easy" input format (role: assistant with output_text content).
OutputText(OutputTextContent),
/// A refusal content item, accepted when replaying assistant messages.
Refusal(RefusalContent),
}
/// Video content for input messages.
......@@ -894,6 +903,7 @@ pub struct ResponseTextParam {
/// Setting to `{ "type": "json_object" }` enables the older JSON mode, which
/// ensures the message the model generates is valid JSON. Using `json_schema`
/// is preferred for models that support it.
#[serde(default)]
pub format: TextResponseFormatConfiguration,
/// Constrains the verbosity of the model's response. Lower values will result in
......@@ -904,10 +914,11 @@ pub struct ResponseTextParam {
pub verbosity: Option<Verbosity>,
}
#[derive(Debug, Deserialize, Serialize, Clone, PartialEq, ToSchema)]
#[derive(Debug, Default, Deserialize, Serialize, Clone, PartialEq, ToSchema)]
#[serde(tag = "type", rename_all = "snake_case")]
pub enum TextResponseFormatConfiguration {
/// Default response format. Used to generate text responses.
#[default]
Text,
/// JSON object response format. An older method of generating JSON responses.
/// Using `json_schema` is recommended for models that support it.
......@@ -1473,6 +1484,8 @@ pub struct ResponseLogProb {
#[derive(Debug, Serialize, Deserialize, Clone, PartialEq, ToSchema)]
pub struct OutputTextContent {
/// The annotations of the text output.
/// Defaults to empty when not provided (e.g., replaying assistant turns as input).
#[serde(default)]
pub annotations: Vec<Annotation>,
pub logprobs: Option<Vec<LogProb>>,
/// The text output from the model.
......@@ -1545,17 +1558,26 @@ pub struct RefusalContent {
}
/// A message generated by the model.
///
/// `id` and `status` use `#[serde(default)]` so that clients can feed back a
/// previous assistant message without those fields (e.g. multi-turn
/// conversations where the caller only has the `output_text` content).
/// The `MessageItem` enum is `#[serde(untagged)]` and tries `Output` first;
/// without defaults the missing fields would cause deserialization to fall
/// through to `Input`, which rejects `role: "assistant"`.
#[derive(Debug, Serialize, Deserialize, Clone, PartialEq, ToSchema)]
pub struct OutputMessage {
/// The content of the output message.
pub content: Vec<OutputMessageContent>,
/// The unique ID of the output message.
pub id: String,
/// Optional when provided as input (e.g., replaying assistant turns in conversation history).
/// Always present in model-generated output.
#[serde(default, skip_serializing_if = "Option::is_none")]
pub id: Option<String>,
/// The role of the output message. Always `assistant`.
pub role: AssistantRole,
/// The status of the message input. One of `in_progress`, `completed`, or
/// `incomplete`. Populated when input items are returned via API.
pub status: OutputStatus,
/// Optional when provided as input (e.g., replaying assistant turns in conversation history).
#[serde(default, skip_serializing_if = "Option::is_none")]
pub status: Option<OutputStatus>,
///// The type of the output message. Always `message`.
//pub r#type: MessageType,
}
......@@ -2841,3 +2863,94 @@ pub struct CompactResource {
/// Token accounting for the compaction pass, including cached, reasoning, and total tokens.
pub usage: ResponseUsage,
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Issue #6: an assistant message carrying `output_text` content must
    /// deserialize even when `id` and `status` are absent, because clients
    /// replay earlier assistant turns without that output metadata.
    ///
    /// The message is expected to land in `MessageItem::Output` with both
    /// optional fields left as `None`.
    #[test]
    fn test_assistant_output_text_without_id_status() {
        let json = r#"{
"role": "assistant",
"content": [{"type": "output_text", "text": "Hello!"}],
"type": "message"
}"#;
        let item: InputItem = serde_json::from_str(json)
            .expect("assistant output_text without id/status should deserialize");
        match &item {
            InputItem::Item(Item::Message(MessageItem::Output(out_msg))) => {
                assert!(out_msg.id.is_none());
                assert!(out_msg.status.is_none());
                assert_eq!(out_msg.content.len(), 1);
            }
            other => panic!("expected OutputMessage, got {:?}", other),
        }
    }

    /// Issue #6 extended: a full multi-turn request whose history contains an
    /// assistant `output_text` turn between two user turns.
    ///
    /// Besides the item count, the assistant turn is checked to be an
    /// `OutputMessage` with no `id`/`status`; the count alone would still pass
    /// if that turn were deserialized into some other variant.
    #[test]
    fn test_multiturn_with_output_text_history() {
        let json = r#"{
"model": "test-model",
"input": [
{"role": "user", "content": "hi", "type": "message"},
{
"role": "assistant",
"content": [{"type": "output_text", "text": "Hello!"}],
"type": "message"
},
{"role": "user", "content": "bye", "type": "message"}
],
"stream": false
}"#;
        let request: CreateResponse = serde_json::from_str(json)
            .expect("multi-turn with output_text history should deserialize");
        match &request.input {
            InputParam::Items(items) => {
                assert_eq!(items.len(), 3);
                // Pin the variant of the replayed assistant turn (index 1).
                match &items[1] {
                    InputItem::Item(Item::Message(MessageItem::Output(out_msg))) => {
                        assert!(out_msg.id.is_none());
                        assert!(out_msg.status.is_none());
                        assert_eq!(out_msg.content.len(), 1);
                    }
                    other => panic!(
                        "expected OutputMessage for the assistant turn, got {:?}",
                        other
                    ),
                }
            }
            other => panic!("expected Items, got {:?}", other),
        }
    }

    /// Issue #7: a `reasoning` item placed in the input array must deserialize
    /// into `Item::Reasoning` with its `id` and `summary` preserved.
    #[test]
    fn test_reasoning_item_in_input() {
        let json = r#"{
"type": "reasoning",
"id": "rs_1",
"summary": [{"text": "thinking", "type": "summary_text"}]
}"#;
        let item: InputItem =
            serde_json::from_str(json).expect("reasoning item should deserialize");
        match &item {
            InputItem::Item(Item::Reasoning(r)) => {
                assert_eq!(r.id, "rs_1");
                assert_eq!(r.summary.len(), 1);
            }
            other => panic!("expected Reasoning item, got {:?}", other),
        }
    }

    /// Backwards compatibility: an output message that does carry `id` and
    /// `status` must still deserialize, with both values surfaced as `Some`.
    #[test]
    fn test_output_message_with_id_and_status() {
        let json = r#"{
"role": "assistant",
"id": "msg_abc123",
"status": "completed",
"content": [{"type": "output_text", "text": "Hello!"}],
"type": "message"
}"#;
        let item: InputItem = serde_json::from_str(json)
            .expect("output message with id/status should still deserialize");
        match &item {
            InputItem::Item(Item::Message(MessageItem::Output(out_msg))) => {
                assert_eq!(out_msg.id.as_deref(), Some("msg_abc123"));
                assert_eq!(out_msg.status, Some(OutputStatus::Completed));
            }
            other => panic!("expected OutputMessage, got {:?}", other),
        }
    }
}
......@@ -1340,6 +1340,7 @@ async fn responses(
// Extract request parameters before into_parts() consumes the request.
// These are echoed back in the Response object per the OpenAI spec.
let response_params = ResponseParams {
model: request.inner.model.clone(),
temperature: request.inner.temperature,
top_p: request.inner.top_p,
max_output_tokens: request.inner.max_output_tokens,
......@@ -1347,6 +1348,11 @@ async fn responses(
tools: request.inner.tools.clone(),
tool_choice: request.inner.tool_choice.clone(),
instructions: request.inner.instructions.clone(),
reasoning: request.inner.reasoning.clone(),
text: request.inner.text.clone(),
service_tier: request.inner.service_tier,
include: request.inner.include.clone(),
truncation: request.inner.truncation,
};
let request_id = request.id().to_string();
let (orig_request, context) = request.into_parts();
......@@ -1367,11 +1373,14 @@ async fn responses(
err_response
})?;
// For non-streaming responses, we still use internal streaming for aggregation,
// but we set the chat completion stream flag appropriately.
if !streaming {
chat_request.inner.stream = Some(true); // Internal streaming for aggregation
}
// Always use internal streaming for aggregation.
// Set stream_options.include_usage so the backend sends token counts in the final chunk.
chat_request.inner.stream = Some(true);
chat_request.inner.stream_options =
Some(dynamo_async_openai::types::ChatCompletionStreamOptions {
include_usage: true,
continuous_usage_stats: false,
});
let request = context.map(|mut _req| chat_request);
......@@ -1556,11 +1565,6 @@ pub fn validate_response_unsupported_fields(
VALIDATION_PREFIX.to_string() + "`background: true` is not supported.",
));
}
if inner.include.is_some() {
return Some(ErrorMessage::not_implemented_error(
VALIDATION_PREFIX.to_string() + "`include` is not supported.",
));
}
if inner.previous_response_id.is_some() {
return Some(ErrorMessage::not_implemented_error(
VALIDATION_PREFIX.to_string() + "`previous_response_id` is not supported.",
......@@ -1571,31 +1575,11 @@ pub fn validate_response_unsupported_fields(
VALIDATION_PREFIX.to_string() + "`prompt` is not supported.",
));
}
if inner.reasoning.is_some() {
return Some(ErrorMessage::not_implemented_error(
VALIDATION_PREFIX.to_string() + "`reasoning` is not supported.",
));
}
if inner.service_tier.is_some() {
return Some(ErrorMessage::not_implemented_error(
VALIDATION_PREFIX.to_string() + "`service_tier` is not supported.",
));
}
if inner.store == Some(true) {
return Some(ErrorMessage::not_implemented_error(
VALIDATION_PREFIX.to_string() + "`store: true` is not supported.",
));
}
if inner.text.is_some() {
return Some(ErrorMessage::not_implemented_error(
VALIDATION_PREFIX.to_string() + "`text` is not supported.",
));
}
if inner.truncation.is_some() {
return Some(ErrorMessage::not_implemented_error(
VALIDATION_PREFIX.to_string() + "`truncation` is not supported.",
));
}
None
}
......@@ -2063,10 +2047,7 @@ mod tests {
use crate::protocols::openai::common_ext::CommonExt;
use crate::protocols::openai::completions::NvCreateCompletionRequest;
use crate::protocols::openai::responses::NvCreateResponse;
use dynamo_async_openai::types::responses::{
CreateResponse, IncludeEnum, Input, PromptConfig, ServiceTier, TextConfig,
TextResponseFormat, Truncation,
};
use dynamo_async_openai::types::responses::{CreateResponse, Input, PromptConfig};
use dynamo_async_openai::types::{
ChatCompletionRequestMessage, ChatCompletionRequestUserMessage,
ChatCompletionRequestUserMessageContent, CreateChatCompletionRequest,
......@@ -2174,10 +2155,6 @@ mod tests {
#[allow(clippy::type_complexity)]
let unsupported_cases: Vec<(&str, Box<dyn FnOnce(&mut CreateResponse)>)> = vec![
("background", Box::new(|r| r.background = Some(true))),
(
"include",
Box::new(|r| r.include = Some(vec![IncludeEnum::FileSearchCallResults])),
),
(
"previous_response_id",
Box::new(|r| r.previous_response_id = Some("prev-id".into())),
......@@ -2192,28 +2169,7 @@ mod tests {
})
}),
),
(
"reasoning",
Box::new(|r| r.reasoning = Some(Default::default())),
),
(
"service_tier",
Box::new(|r| r.service_tier = Some(ServiceTier::Auto)),
),
("store", Box::new(|r| r.store = Some(true))),
(
"text",
Box::new(|r| {
r.text = Some(TextConfig {
format: TextResponseFormat::Text,
verbosity: None,
})
}),
),
(
"truncation",
Box::new(|r| r.truncation = Some(Truncation::Auto)),
),
];
for (field, set_field) in unsupported_cases {
......
......@@ -13,14 +13,14 @@ use std::time::{SystemTime, UNIX_EPOCH};
use axum::response::sse::Event;
use dynamo_async_openai::types::responses::{
AssistantRole, FunctionToolCall, Instructions, OutputContent, OutputItem, OutputMessage,
OutputMessageContent, OutputStatus, OutputTextContent, Response, ResponseCompletedEvent,
ResponseContentPartAddedEvent, ResponseContentPartDoneEvent, ResponseCreatedEvent,
ResponseFailedEvent, ResponseFunctionCallArgumentsDeltaEvent,
AssistantRole, FunctionToolCall, InputTokenDetails, Instructions, OutputContent, OutputItem,
OutputMessage, OutputMessageContent, OutputStatus, OutputTextContent, OutputTokenDetails,
Response, ResponseCompletedEvent, ResponseContentPartAddedEvent, ResponseContentPartDoneEvent,
ResponseCreatedEvent, ResponseFailedEvent, ResponseFunctionCallArgumentsDeltaEvent,
ResponseFunctionCallArgumentsDoneEvent, ResponseInProgressEvent, ResponseOutputItemAddedEvent,
ResponseOutputItemDoneEvent, ResponseStreamEvent, ResponseTextDeltaEvent,
ResponseTextDoneEvent, ResponseTextParam, ServiceTier, Status, TextResponseFormatConfiguration,
ToolChoiceOptions, ToolChoiceParam, Truncation,
ResponseTextDoneEvent, ResponseTextParam, ResponseUsage, ServiceTier, Status,
TextResponseFormatConfiguration, ToolChoiceOptions, ToolChoiceParam, Truncation,
};
use uuid::Uuid;
......@@ -45,6 +45,8 @@ pub struct ResponseStreamConverter {
function_call_items: Vec<FunctionCallState>,
// Output index counter
next_output_index: u32,
// Usage stats from the backend's final chunk
usage: Option<ResponseUsage>,
}
struct FunctionCallState {
......@@ -75,6 +77,7 @@ impl ResponseStreamConverter {
accumulated_text: String::new(),
function_call_items: Vec::new(),
next_output_index: 0,
usage: None,
}
}
......@@ -112,10 +115,10 @@ impl ResponseStreamConverter {
// store: false because this branch does not persist responses.
store: self.params.store.or(Some(false)),
temperature: self.params.temperature.or(Some(1.0)),
text: Some(ResponseTextParam {
text: Some(self.params.text.clone().unwrap_or(ResponseTextParam {
format: TextResponseFormatConfiguration::Text,
verbosity: None,
}),
})),
tool_choice: self
.params
.tool_choice
......@@ -129,7 +132,7 @@ impl ResponseStreamConverter {
.unwrap_or_default(),
),
top_p: self.params.top_p.or(Some(1.0)),
truncation: Some(Truncation::Disabled),
truncation: Some(self.params.truncation.unwrap_or(Truncation::Disabled)),
// Nullable required fields
billing: None,
conversation: None,
......@@ -142,11 +145,11 @@ impl ResponseStreamConverter {
prompt: None,
prompt_cache_key: None,
prompt_cache_retention: None,
reasoning: None,
reasoning: self.params.reasoning.clone(),
safety_identifier: None,
service_tier: Some(ServiceTier::Auto),
service_tier: Some(self.params.service_tier.unwrap_or(ServiceTier::Auto)),
top_logprobs: Some(0),
usage: None,
usage: self.usage.clone(),
}
}
......@@ -176,6 +179,29 @@ impl ResponseStreamConverter {
) -> Vec<Result<Event, anyhow::Error>> {
let mut events = Vec::new();
// Capture usage stats from the final chunk (sent when stream_options.include_usage=true)
if let Some(ref u) = chunk.usage {
self.usage = Some(ResponseUsage {
input_tokens: u.prompt_tokens,
input_tokens_details: InputTokenDetails {
cached_tokens: u
.prompt_tokens_details
.as_ref()
.and_then(|d| d.cached_tokens)
.unwrap_or(0),
},
output_tokens: u.completion_tokens,
output_tokens_details: OutputTokenDetails {
reasoning_tokens: u
.completion_tokens_details
.as_ref()
.and_then(|d| d.reasoning_tokens)
.unwrap_or(0),
},
total_tokens: u.total_tokens,
});
}
for choice in &chunk.choices {
let delta = &choice.delta;
......@@ -203,10 +229,10 @@ impl ResponseStreamConverter {
sequence_number: self.next_seq(),
output_index,
item: OutputItem::Message(OutputMessage {
id: self.message_item_id.clone(),
id: Some(self.message_item_id.clone()),
content: vec![],
role: AssistantRole::Assistant,
status: OutputStatus::InProgress,
status: Some(OutputStatus::InProgress),
}),
},
);
......@@ -354,14 +380,14 @@ impl ResponseStreamConverter {
sequence_number: self.next_seq(),
output_index: self.message_output_index,
item: OutputItem::Message(OutputMessage {
id: self.message_item_id.clone(),
id: Some(self.message_item_id.clone()),
content: vec![OutputMessageContent::OutputText(OutputTextContent {
text: self.accumulated_text.clone(),
annotations: vec![],
logprobs: Some(vec![]),
})],
role: AssistantRole::Assistant,
status: OutputStatus::Completed,
status: Some(OutputStatus::Completed),
}),
});
events.push(make_sse_event(&item_done));
......@@ -413,14 +439,14 @@ impl ResponseStreamConverter {
let mut output = Vec::new();
if self.message_started {
output.push(OutputItem::Message(OutputMessage {
id: self.message_item_id.clone(),
id: Some(self.message_item_id.clone()),
content: vec![OutputMessageContent::OutputText(OutputTextContent {
text: self.accumulated_text.clone(),
annotations: vec![],
logprobs: Some(vec![]),
})],
role: AssistantRole::Assistant,
status: OutputStatus::Completed,
status: Some(OutputStatus::Completed),
}));
}
for fc in &self.function_call_items {
......
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment