"examples/backends/vscode:/vscode.git/clone" did not exist on "fb60cdc56fffdb1bf43b5bb221e5afcd8030b053"
Unverified commit 9f76d060, authored by Ayush Agarwal, committed by GitHub
Browse files

feat: text to image vLLM Omni (#5912)


Signed-off-by: ayushag <ayushag@nvidia.com>
parent d14d6ff4
...@@ -270,7 +270,7 @@ impl DeltaGenerator { ...@@ -270,7 +270,7 @@ impl DeltaGenerator {
stop_reason: Option<dynamo_async_openai::types::StopReason>, stop_reason: Option<dynamo_async_openai::types::StopReason>,
) -> NvCreateChatCompletionStreamResponse { ) -> NvCreateChatCompletionStreamResponse {
let delta = dynamo_async_openai::types::ChatCompletionStreamResponseDelta { let delta = dynamo_async_openai::types::ChatCompletionStreamResponseDelta {
content: text, content: text.map(dynamo_async_openai::types::ChatCompletionMessageContent::Text),
function_call: None, function_call: None,
tool_calls: None, tool_calls: None,
role: if self.msg_counter == 0 { role: if self.msg_counter == 0 {
......
...@@ -112,7 +112,9 @@ fn create_choice_stream( ...@@ -112,7 +112,9 @@ fn create_choice_stream(
index, index,
delta: ChatCompletionStreamResponseDelta { delta: ChatCompletionStreamResponseDelta {
role, role,
content: Some(content.to_string()), content: Some(
dynamo_async_openai::types::ChatCompletionMessageContent::Text(content.to_string()),
),
tool_calls, tool_calls,
function_call: None, function_call: None,
refusal: None, refusal: None,
...@@ -533,23 +535,32 @@ impl JailedStream { ...@@ -533,23 +535,32 @@ impl JailedStream {
// Process each choice independently using the new architecture // Process each choice independently using the new architecture
for choice in &chat_response.choices { for choice in &chat_response.choices {
if let Some(ref content) = choice.delta.content { if let Some(ref content) = choice.delta.content {
let starts_jailed = matches!(self.jail_mode, JailMode::Immediate { .. }); // Jailing only applies to text content
let choice_state = choice_states.get_or_create_state(choice.index, starts_jailed); let text_content = match content {
dynamo_async_openai::types::ChatCompletionMessageContent::Text(text) => Some(text.as_str()),
// Store metadata when any choice becomes jailed (first time only) dynamo_async_openai::types::ChatCompletionMessageContent::Parts(_) => None,
if !choice_state.is_jailed && self.should_start_jail(content) };
&& last_annotated_id.is_none() {
last_annotated_id = response.id.clone(); if let Some(text) = text_content {
last_annotated_event = response.event.clone(); let starts_jailed = matches!(self.jail_mode, JailMode::Immediate { .. });
last_annotated_comment = response.comment.clone(); let choice_state = choice_states.get_or_create_state(choice.index, starts_jailed);
}
// Store metadata when any choice becomes jailed (first time only)
if !choice_state.is_jailed && self.should_start_jail(text)
&& last_annotated_id.is_none() {
last_annotated_id = response.id.clone();
last_annotated_event = response.event.clone();
last_annotated_comment = response.comment.clone();
}
// Track actual stream finish reason in the choice state // Track actual stream finish reason in the choice state
choice_state.stream_finish_reason = choice.finish_reason; choice_state.stream_finish_reason = choice.finish_reason;
// Process this choice and get emissions // Process this choice and get emissions
let emissions = choice_state.process_content(choice, content, &self).await; let emissions = choice_state.process_content(choice, text, &self).await;
all_emissions.extend(emissions); all_emissions.extend(emissions);
}
// For multimodal content, pass through unchanged (no jailing)
} else { } else {
// Handle choices without content (e.g., final chunks with finish_reason) // Handle choices without content (e.g., final chunks with finish_reason)
// Only filter out if this choice was ever jailed and lacks role // Only filter out if this choice was ever jailed and lacks role
......
...@@ -222,8 +222,20 @@ impl TryFrom<NvCreateChatCompletionResponse> for NvResponse { ...@@ -222,8 +222,20 @@ impl TryFrom<NvCreateChatCompletionResponse> for NvResponse {
.and_then(|choice| choice.message.content) .and_then(|choice| choice.message.content)
.unwrap_or_else(|| { .unwrap_or_else(|| {
tracing::warn!("No choices in chat completion response, using empty content"); tracing::warn!("No choices in chat completion response, using empty content");
String::new() dynamo_async_openai::types::ChatCompletionMessageContent::Text(String::new())
}); });
// Extract text from content (only handle text for responses API)
let text_content = match content_text {
dynamo_async_openai::types::ChatCompletionMessageContent::Text(text) => text,
dynamo_async_openai::types::ChatCompletionMessageContent::Parts(_) => {
tracing::warn!(
"Multimodal content in responses API not yet supported, using placeholder"
);
"[multimodal content]".to_string()
}
};
let message_id = format!("msg_{}", Uuid::new_v4().simple()); let message_id = format!("msg_{}", Uuid::new_v4().simple());
let response_id = format!("resp_{}", Uuid::new_v4().simple()); let response_id = format!("resp_{}", Uuid::new_v4().simple());
...@@ -232,7 +244,7 @@ impl TryFrom<NvCreateChatCompletionResponse> for NvResponse { ...@@ -232,7 +244,7 @@ impl TryFrom<NvCreateChatCompletionResponse> for NvResponse {
role: ResponseRole::Assistant, role: ResponseRole::Assistant,
status: OutputStatus::Completed, status: OutputStatus::Completed,
content: vec![Content::OutputText(OutputText { content: vec![Content::OutputText(OutputText {
text: content_text, text: text_content,
annotations: vec![], annotations: vec![],
})], })],
})]; })];
...@@ -363,7 +375,11 @@ mod tests { ...@@ -363,7 +375,11 @@ mod tests {
choices: vec![dynamo_async_openai::types::ChatChoice { choices: vec![dynamo_async_openai::types::ChatChoice {
index: 0, index: 0,
message: dynamo_async_openai::types::ChatCompletionResponseMessage { message: dynamo_async_openai::types::ChatCompletionResponseMessage {
content: Some("This is a reply".into()), content: Some(
dynamo_async_openai::types::ChatCompletionMessageContent::Text(
"This is a reply".to_string(),
),
),
refusal: None, refusal: None,
tool_calls: None, tool_calls: None,
role: dynamo_async_openai::types::Role::Assistant, role: dynamo_async_openai::types::Role::Assistant,
......
// SPDX-FileCopyrightText: Copyright (c) 2024-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. // SPDX-FileCopyrightText: Copyright (c) 2024-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
// SPDX-License-Identifier: Apache-2.0 // SPDX-License-Identifier: Apache-2.0
use dynamo_async_openai::types::ChatCompletionMessageContent;
use dynamo_llm::protocols::{ use dynamo_llm::protocols::{
ContentProvider, DataStream, ContentProvider, DataStream,
codec::{Message, SseCodecError, create_message_stream}, codec::{Message, SseCodecError, create_message_stream},
...@@ -12,6 +13,13 @@ use dynamo_llm::protocols::{ ...@@ -12,6 +13,13 @@ use dynamo_llm::protocols::{
}; };
use futures::StreamExt; use futures::StreamExt;
/// Borrows the textual payload of a chat-completion message content value.
///
/// Returns the inner string slice for the `Text` variant and an empty string
/// for the `Parts` variant, so callers can compare against plain `&str`
/// expectations without matching on the enum themselves.
fn get_text(content: &ChatCompletionMessageContent) -> &str {
    // Both variants are listed explicitly (no wildcard arm) so that adding a
    // new variant to the enum fails to compile here instead of silently
    // mapping to "".
    match content {
        ChatCompletionMessageContent::Parts(_) => "",
        ChatCompletionMessageContent::Text(body) => body.as_str(),
    }
}
const CMPL_ROOT_PATH: &str = "tests/data/replays/meta/llama-3.1-8b-instruct/completions"; const CMPL_ROOT_PATH: &str = "tests/data/replays/meta/llama-3.1-8b-instruct/completions";
const CHAT_ROOT_PATH: &str = "tests/data/replays/meta/llama-3.1-8b-instruct/chat_completions"; const CHAT_ROOT_PATH: &str = "tests/data/replays/meta/llama-3.1-8b-instruct/chat_completions";
...@@ -35,16 +43,17 @@ async fn test_openai_chat_stream() { ...@@ -35,16 +43,17 @@ async fn test_openai_chat_stream() {
// todo: provide a cleaner way to extract the content from choices // todo: provide a cleaner way to extract the content from choices
assert_eq!( assert_eq!(
result get_text(
.choices result
.first() .choices
.unwrap() .first()
.message .unwrap()
.content .message
.clone() .content
.expect("there to be content"), .as_ref()
.expect("there to be content")
),
"Deep learning is a subfield of machine learning that involves the use of artificial" "Deep learning is a subfield of machine learning that involves the use of artificial"
.to_string()
); );
} }
...@@ -59,15 +68,17 @@ async fn test_openai_chat_edge_case_multi_line_data() { ...@@ -59,15 +68,17 @@ async fn test_openai_chat_edge_case_multi_line_data() {
.unwrap(); .unwrap();
assert_eq!( assert_eq!(
result get_text(
.choices result
.first() .choices
.unwrap() .first()
.message .unwrap()
.content .message
.clone() .content
.expect("there to be content"), .as_ref()
"Deep learning".to_string() .expect("there to be content")
),
"Deep learning"
); );
} }
...@@ -82,15 +93,17 @@ async fn test_openai_chat_edge_case_comments_per_response() { ...@@ -82,15 +93,17 @@ async fn test_openai_chat_edge_case_comments_per_response() {
.unwrap(); .unwrap();
assert_eq!( assert_eq!(
result get_text(
.choices result
.first() .choices
.unwrap() .first()
.message .unwrap()
.content .message
.clone() .content
.expect("there to be content"), .as_ref()
"Deep learning".to_string() .expect("there to be content")
),
"Deep learning"
); );
} }
......
...@@ -11,8 +11,8 @@ use dynamo_llm::perf::{RecordedStream, TimestampedResponse}; ...@@ -11,8 +11,8 @@ use dynamo_llm::perf::{RecordedStream, TimestampedResponse};
use dynamo_llm::protocols::openai::chat_completions::NvCreateChatCompletionStreamResponse; use dynamo_llm::protocols::openai::chat_completions::NvCreateChatCompletionStreamResponse;
use dynamo_async_openai::types::{ use dynamo_async_openai::types::{
ChatChoiceLogprobs, ChatChoiceStream, ChatCompletionStreamResponseDelta, ChatChoiceLogprobs, ChatChoiceStream, ChatCompletionMessageContent,
ChatCompletionTokenLogprob, FinishReason, Role, TopLogprobs, ChatCompletionStreamResponseDelta, ChatCompletionTokenLogprob, FinishReason, Role, TopLogprobs,
}; };
// Type aliases to simplify complex test data structures // Type aliases to simplify complex test data structures
...@@ -380,7 +380,7 @@ fn create_response_with_linear_probs( ...@@ -380,7 +380,7 @@ fn create_response_with_linear_probs(
let choice = ChatChoiceStream { let choice = ChatChoiceStream {
index: 0, index: 0,
delta: ChatCompletionStreamResponseDelta { delta: ChatCompletionStreamResponseDelta {
content: Some(_content.to_string()), content: Some(ChatCompletionMessageContent::Text(_content.to_string())),
#[expect(deprecated)] #[expect(deprecated)]
function_call: None, function_call: None,
tool_calls: None, tool_calls: None,
...@@ -460,7 +460,7 @@ fn create_multi_choice_response( ...@@ -460,7 +460,7 @@ fn create_multi_choice_response(
ChatChoiceStream { ChatChoiceStream {
index: choice_idx as u32, index: choice_idx as u32,
delta: ChatCompletionStreamResponseDelta { delta: ChatCompletionStreamResponseDelta {
content: Some("test".to_string()), content: Some(ChatCompletionMessageContent::Text("test".to_string())),
#[expect(deprecated)] #[expect(deprecated)]
function_call: None, function_call: None,
tool_calls: None, tool_calls: None,
......
...@@ -16,6 +16,15 @@ mod tests { ...@@ -16,6 +16,15 @@ mod tests {
// Test utilities module - shared test infrastructure // Test utilities module - shared test infrastructure
pub(crate) mod test_utils { pub(crate) mod test_utils {
use super::*; use super::*;
use dynamo_async_openai::types::ChatCompletionMessageContent;
/// Returns the text carried by a `ChatCompletionMessageContent`.
///
/// * `Text(..)`  -> a borrow of the contained string.
/// * `Parts(..)` -> `""` (the parts are not inspected).
pub fn extract_text(content: &ChatCompletionMessageContent) -> &str {
    // Keep the match exhaustive: a future variant should be a compile error
    // here rather than quietly collapsing to the empty string.
    match content {
        ChatCompletionMessageContent::Parts(_) => "",
        ChatCompletionMessageContent::Text(inner) => inner.as_str(),
    }
}
/// Helper function to create a mock chat response chunk /// Helper function to create a mock chat response chunk
pub fn create_mock_response_chunk( pub fn create_mock_response_chunk(
...@@ -27,7 +36,7 @@ mod tests { ...@@ -27,7 +36,7 @@ mod tests {
index, index,
delta: ChatCompletionStreamResponseDelta { delta: ChatCompletionStreamResponseDelta {
role: Some(Role::Assistant), role: Some(Role::Assistant),
content: Some(content), content: Some(ChatCompletionMessageContent::Text(content)),
tool_calls: None, tool_calls: None,
function_call: None, function_call: None,
refusal: None, refusal: None,
...@@ -111,7 +120,7 @@ mod tests { ...@@ -111,7 +120,7 @@ mod tests {
index, index,
delta: ChatCompletionStreamResponseDelta { delta: ChatCompletionStreamResponseDelta {
role: Some(Role::Assistant), role: Some(Role::Assistant),
content: Some(content), content: Some(ChatCompletionMessageContent::Text(content)),
tool_calls: None, tool_calls: None,
function_call: None, function_call: None,
refusal: None, refusal: None,
...@@ -154,7 +163,7 @@ mod tests { ...@@ -154,7 +163,7 @@ mod tests {
index, index,
delta: ChatCompletionStreamResponseDelta { delta: ChatCompletionStreamResponseDelta {
role: Some(Role::Assistant), role: Some(Role::Assistant),
content: Some(content), content: Some(ChatCompletionMessageContent::Text(content)),
tool_calls: None, tool_calls: None,
function_call: None, function_call: None,
refusal: None, refusal: None,
...@@ -245,9 +254,11 @@ mod tests { ...@@ -245,9 +254,11 @@ mod tests {
.expect("Expected content in result"); .expect("Expected content in result");
assert_eq!( assert_eq!(
content, expected, extract_text(content),
expected,
"Content mismatch: expected '{}', got '{}'", "Content mismatch: expected '{}', got '{}'",
expected, content expected,
extract_text(content)
); );
} }
...@@ -301,7 +312,11 @@ mod tests { ...@@ -301,7 +312,11 @@ mod tests {
{ {
assert!( assert!(
choice.delta.content.is_none() choice.delta.content.is_none()
|| choice.delta.content.as_ref().unwrap().is_empty(), || choice.delta.content.as_ref().is_none_or(|c| match c {
dynamo_async_openai::types::ChatCompletionMessageContent::Text(t) =>
t.is_empty(),
_ => false,
}),
"Expected no content but got: {:?}", "Expected no content but got: {:?}",
choice.delta.content choice.delta.content
); );
...@@ -326,7 +341,7 @@ mod tests { ...@@ -326,7 +341,7 @@ mod tests {
.and_then(|d| d.choices.first()) .and_then(|d| d.choices.first())
.and_then(|c| c.delta.content.as_ref()) .and_then(|c| c.delta.content.as_ref())
}) })
.cloned() .map(extract_text)
.collect::<Vec<_>>() .collect::<Vec<_>>()
.join("") .join("")
} }
...@@ -338,7 +353,10 @@ mod tests { ...@@ -338,7 +353,10 @@ mod tests {
.as_ref() .as_ref()
.and_then(|d| d.choices.first()) .and_then(|d| d.choices.first())
.and_then(|c| c.delta.content.as_ref()) .and_then(|c| c.delta.content.as_ref())
.cloned() .and_then(|content| match content {
ChatCompletionMessageContent::Text(text) => Some(text.clone()),
ChatCompletionMessageContent::Parts(_) => None,
})
.unwrap_or_default() .unwrap_or_default()
} }
...@@ -361,7 +379,7 @@ mod tests { ...@@ -361,7 +379,7 @@ mod tests {
.as_ref() .as_ref()
.and_then(|d| d.choices.first()) .and_then(|d| d.choices.first())
.and_then(|c| c.delta.content.as_ref()) .and_then(|c| c.delta.content.as_ref())
.map(|content| !content.is_empty()) .map(|content| !extract_text(content).is_empty())
.unwrap_or(false) .unwrap_or(false)
} }
} }
...@@ -402,7 +420,8 @@ mod tests { ...@@ -402,7 +420,8 @@ mod tests {
results[0].data.as_ref().unwrap().choices[0] results[0].data.as_ref().unwrap().choices[0]
.delta .delta
.content .content
.as_deref(), .as_ref()
.map(extract_text),
Some("Hello ") Some("Hello ")
); );
...@@ -410,9 +429,7 @@ mod tests { ...@@ -410,9 +429,7 @@ mod tests {
let unjailed_content = &results[1].data.as_ref().unwrap().choices[0].delta.content; let unjailed_content = &results[1].data.as_ref().unwrap().choices[0].delta.content;
assert!(unjailed_content.is_some()); assert!(unjailed_content.is_some());
assert!( assert!(
unjailed_content extract_text(unjailed_content.as_ref().unwrap())
.as_ref()
.unwrap()
.contains("<jail>This is jailed content</jail>") .contains("<jail>This is jailed content</jail>")
); );
...@@ -421,7 +438,8 @@ mod tests { ...@@ -421,7 +438,8 @@ mod tests {
results[2].data.as_ref().unwrap().choices[0] results[2].data.as_ref().unwrap().choices[0]
.delta .delta
.content .content
.as_deref(), .as_ref()
.map(extract_text),
Some(" World") Some(" World")
); );
} }
...@@ -494,7 +512,8 @@ mod tests { ...@@ -494,7 +512,8 @@ mod tests {
results[0].data.as_ref().unwrap().choices[0] results[0].data.as_ref().unwrap().choices[0]
.delta .delta
.content .content
.as_deref(), .as_ref()
.map(extract_text),
Some("Normal text ") Some("Normal text ")
); );
...@@ -504,7 +523,7 @@ mod tests { ...@@ -504,7 +523,7 @@ mod tests {
.content .content
.as_ref() .as_ref()
.expect("Expected accumulated jailed content"); .expect("Expected accumulated jailed content");
assert!(jailed.contains("<jail><TOOLCALL>Jailed content</jail>")); assert!(extract_text(jailed).contains("<jail><TOOLCALL>Jailed content</jail>"));
} }
#[tokio::test] #[tokio::test]
...@@ -1298,11 +1317,11 @@ mod tests { ...@@ -1298,11 +1317,11 @@ mod tests {
assert!(content.is_some(), "Should have accumulated content"); assert!(content.is_some(), "Should have accumulated content");
let content = content.as_ref().unwrap(); let content = content.as_ref().unwrap();
assert!( assert!(
content.contains("<tool_call>"), test_utils::extract_text(content).contains("<tool_call>"),
"Should contain jail start marker in accumulated content" "Should contain jail start marker in accumulated content"
); );
assert!( assert!(
content.contains("incomplete_call"), test_utils::extract_text(content).contains("incomplete_call"),
"Should contain accumulated incomplete content" "Should contain accumulated incomplete content"
); );
} }
...@@ -1672,7 +1691,8 @@ mod tests { ...@@ -1672,7 +1691,8 @@ mod tests {
.as_ref() .as_ref()
.unwrap(); .unwrap();
assert_eq!( assert_eq!(
content, "Hello, world!", extract_text(content),
"Hello, world!",
"Content chunk should have 'Hello, world!'" "Content chunk should have 'Hello, world!'"
); );
...@@ -1860,7 +1880,10 @@ mod tests { ...@@ -1860,7 +1880,10 @@ mod tests {
.as_ref() .as_ref()
.and_then(|d| d.choices.first()) .and_then(|d| d.choices.first())
.and_then(|c| c.delta.content.as_ref()) .and_then(|c| c.delta.content.as_ref())
.map(|content| content.contains("Need to use function get_current_weather.")) .map(|content| {
test_utils::extract_text(content)
.contains("Need to use function get_current_weather.")
})
.unwrap_or(false) .unwrap_or(false)
}); });
assert!(has_analysis_text, "Should contain extracted analysis text"); assert!(has_analysis_text, "Should contain extracted analysis text");
...@@ -1912,7 +1935,7 @@ mod tests { ...@@ -1912,7 +1935,7 @@ mod tests {
for choice in data.choices { for choice in data.choices {
if let Some(content) = choice.delta.content { if let Some(content) = choice.delta.content {
assert!( assert!(
!content.contains("<|tool▁calls▁end|>"), !test_utils::extract_text(&content).contains("<|tool▁calls▁end|>"),
"Should not contain deepseek special tokens in content" "Should not contain deepseek special tokens in content"
); );
} }
...@@ -1986,7 +2009,7 @@ mod tests { ...@@ -1986,7 +2009,7 @@ mod tests {
for choice in data.choices { for choice in data.choices {
if let Some(content) = choice.delta.content { if let Some(content) = choice.delta.content {
assert!( assert!(
!content.contains("<|tool▁calls▁end|>"), !test_utils::extract_text(&content).contains("<|tool▁calls▁end|>"),
"Should not contain deepseek special tokens in content" "Should not contain deepseek special tokens in content"
); );
} }
...@@ -2184,7 +2207,8 @@ mod tests { ...@@ -2184,7 +2207,8 @@ mod tests {
.and_then(|c| c.delta.content.as_ref()) .and_then(|c| c.delta.content.as_ref())
}) })
.filter(|content| { .filter(|content| {
content.contains("<tool_call>") || content.contains("should not jail") test_utils::extract_text(content).contains("<tool_call>")
|| test_utils::extract_text(content).contains("should not jail")
}) })
.collect(); .collect();
...@@ -2202,7 +2226,10 @@ mod tests { ...@@ -2202,7 +2226,10 @@ mod tests {
.and_then(|d| d.choices.first()) .and_then(|d| d.choices.first())
.and_then(|c| c.delta.content.as_ref()) .and_then(|c| c.delta.content.as_ref())
}) })
.find(|content| content.contains("[[START]]") && content.contains("jailed content")); .find(|content| {
test_utils::extract_text(content).contains("[[START]]")
&& test_utils::extract_text(content).contains("jailed content")
});
assert!( assert!(
jailed_chunk.is_some(), jailed_chunk.is_some(),
...@@ -2320,6 +2347,7 @@ mod tests { ...@@ -2320,6 +2347,7 @@ mod tests {
mod parallel_jail_tests { mod parallel_jail_tests {
use super::tests::test_utils; use super::tests::test_utils;
use super::*; use super::*;
use dynamo_async_openai::types::ChatCompletionMessageContent;
use futures::StreamExt; use futures::StreamExt;
use futures::stream; use futures::stream;
use serde_json::json; use serde_json::json;
...@@ -2337,7 +2365,7 @@ mod parallel_jail_tests { ...@@ -2337,7 +2365,7 @@ mod parallel_jail_tests {
index: i as u32, index: i as u32,
delta: ChatCompletionStreamResponseDelta { delta: ChatCompletionStreamResponseDelta {
role: Some(Role::Assistant), role: Some(Role::Assistant),
content: Some(content), content: Some(ChatCompletionMessageContent::Text(content)),
tool_calls: None, tool_calls: None,
function_call: None, function_call: None,
refusal: None, refusal: None,
...@@ -2589,10 +2617,9 @@ mod parallel_jail_tests { ...@@ -2589,10 +2617,9 @@ mod parallel_jail_tests {
let normal_text_before = results.iter().find(|r| { let normal_text_before = results.iter().find(|r| {
r.data.as_ref().is_some_and(|d| { r.data.as_ref().is_some_and(|d| {
d.choices.iter().any(|c| { d.choices.iter().any(|c| {
c.delta c.delta.content.as_ref().is_some_and(|content| {
.content test_utils::extract_text(content).contains("I'll check the weather")
.as_ref() })
.is_some_and(|content| content.contains("I'll check the weather"))
}) })
}) })
}); });
...@@ -2619,10 +2646,9 @@ mod parallel_jail_tests { ...@@ -2619,10 +2646,9 @@ mod parallel_jail_tests {
let normal_text_after = results.iter().find(|r| { let normal_text_after = results.iter().find(|r| {
r.data.as_ref().is_some_and(|d| { r.data.as_ref().is_some_and(|d| {
d.choices.iter().any(|c| { d.choices.iter().any(|c| {
c.delta c.delta.content.as_ref().is_some_and(|content| {
.content test_utils::extract_text(content).contains("Let me get that information")
.as_ref() })
.is_some_and(|content| content.contains("Let me get that information"))
}) })
}) })
}); });
...@@ -2982,8 +3008,8 @@ mod parallel_jail_tests { ...@@ -2982,8 +3008,8 @@ mod parallel_jail_tests {
r.data.as_ref().is_some_and(|d| { r.data.as_ref().is_some_and(|d| {
d.choices.iter().any(|c| { d.choices.iter().any(|c| {
c.delta.content.as_ref().is_some_and(|content| { c.delta.content.as_ref().is_some_and(|content| {
content.contains("I'll help you") test_utils::extract_text(content).contains("I'll help you")
|| content.contains("don't need any tools") || test_utils::extract_text(content).contains("don't need any tools")
}) })
}) })
}) })
......
// SPDX-FileCopyrightText: Copyright (c) 2025-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. // SPDX-FileCopyrightText: Copyright (c) 2025-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
// SPDX-License-Identifier: Apache-2.0 // SPDX-License-Identifier: Apache-2.0
use dynamo_async_openai::types::{ChatChoiceStream, ChatCompletionStreamResponseDelta, Role}; use dynamo_async_openai::types::{
ChatChoiceStream, ChatCompletionMessageContent, ChatCompletionStreamResponseDelta, Role,
};
use dynamo_llm::preprocessor::OpenAIPreprocessor; use dynamo_llm::preprocessor::OpenAIPreprocessor;
use dynamo_llm::protocols::openai::chat_completions::NvCreateChatCompletionStreamResponse; use dynamo_llm::protocols::openai::chat_completions::NvCreateChatCompletionStreamResponse;
use dynamo_runtime::protocols::annotated::Annotated; use dynamo_runtime::protocols::annotated::Annotated;
use futures::{StreamExt, stream}; use futures::{StreamExt, stream};
/// Test helper: view a `ChatCompletionMessageContent` as a `&str`.
///
/// The `Text` variant yields its string; the `Parts` variant yields the empty
/// string, which lets assertions written against plain text keep working.
fn get_text(content: &ChatCompletionMessageContent) -> &str {
    match content {
        // Parts are never flattened into text by this helper.
        ChatCompletionMessageContent::Parts(_) => "",
        ChatCompletionMessageContent::Text(s) => s.as_str(),
    }
}
/// Helper function to create a mock chat response chunk /// Helper function to create a mock chat response chunk
fn create_mock_response_chunk( fn create_mock_response_chunk(
content: String, content: String,
...@@ -17,7 +27,7 @@ fn create_mock_response_chunk( ...@@ -17,7 +27,7 @@ fn create_mock_response_chunk(
index: 0, index: 0,
delta: ChatCompletionStreamResponseDelta { delta: ChatCompletionStreamResponseDelta {
role: Some(Role::Assistant), role: Some(Role::Assistant),
content: Some(content), content: Some(ChatCompletionMessageContent::Text(content)),
tool_calls: None, tool_calls: None,
function_call: None, function_call: None,
refusal: None, refusal: None,
...@@ -61,7 +71,7 @@ mod tests { ...@@ -61,7 +71,7 @@ mod tests {
match expected_content { match expected_content {
Some(expected) => { Some(expected) => {
assert_eq!( assert_eq!(
choice.delta.content.as_deref(), choice.delta.content.as_ref().map(get_text),
Some(expected), Some(expected),
"Content mismatch" "Content mismatch"
); );
...@@ -69,7 +79,7 @@ mod tests { ...@@ -69,7 +79,7 @@ mod tests {
None => { None => {
assert!( assert!(
choice.delta.content.is_none() choice.delta.content.is_none()
|| choice.delta.content.as_ref().unwrap().is_empty(), || get_text(choice.delta.content.as_ref().unwrap()).is_empty(),
"Expected content to be None or empty, got: {:?}", "Expected content to be None or empty, got: {:?}",
choice.delta.content choice.delta.content
); );
...@@ -260,7 +270,7 @@ mod tests { ...@@ -260,7 +270,7 @@ mod tests {
let output_choice = &output.data.as_ref().unwrap().choices[0]; let output_choice = &output.data.as_ref().unwrap().choices[0];
assert_choice( assert_choice(
output_choice, output_choice,
input_choice.delta.content.as_deref(), input_choice.delta.content.as_ref().map(get_text),
input_choice.delta.reasoning_content.as_deref(), input_choice.delta.reasoning_content.as_deref(),
); );
} }
...@@ -316,7 +326,8 @@ mod tests { ...@@ -316,7 +326,8 @@ mod tests {
"Should contain Mistral reasoning content" "Should contain Mistral reasoning content"
); );
assert!( assert!(
normal_content.contains("Let me think") || normal_content.contains("Here's my answer"), get_text(normal_content).contains("Let me think")
|| get_text(normal_content).contains("Here's my answer"),
"Should contain normal content" "Should contain normal content"
); );
} }
...@@ -379,7 +390,7 @@ mod tests { ...@@ -379,7 +390,7 @@ mod tests {
// Collect normal content // Collect normal content
if let Some(ref content) = choice.delta.content { if let Some(ref content) = choice.delta.content {
all_normal_content.push_str(content); all_normal_content.push_str(get_text(content));
} }
} }
} }
...@@ -450,8 +461,8 @@ mod tests { ...@@ -450,8 +461,8 @@ mod tests {
"Should contain Kimi reasoning content" "Should contain Kimi reasoning content"
); );
assert!( assert!(
normal_content.contains("Let me analyze") get_text(normal_content).contains("Let me analyze")
|| normal_content.contains("Here's my conclusion"), || get_text(normal_content).contains("Here's my conclusion"),
"Should contain normal content" "Should contain normal content"
); );
} }
...@@ -518,7 +529,7 @@ mod tests { ...@@ -518,7 +529,7 @@ mod tests {
// Collect normal content // Collect normal content
if let Some(ref content) = choice.delta.content { if let Some(ref content) = choice.delta.content {
all_normal_content.push_str(content); all_normal_content.push_str(get_text(content));
} }
// Check for tool calls // Check for tool calls
...@@ -624,7 +635,7 @@ mod tests { ...@@ -624,7 +635,7 @@ mod tests {
all_reasoning.push_str(reasoning); all_reasoning.push_str(reasoning);
} }
if let Some(ref content) = choice.delta.content { if let Some(ref content) = choice.delta.content {
all_normal_content.push_str(content); all_normal_content.push_str(get_text(content));
} }
if let Some(ref tool_calls) = choice.delta.tool_calls if let Some(ref tool_calls) = choice.delta.tool_calls
&& !tool_calls.is_empty() && !tool_calls.is_empty()
......
...@@ -26,7 +26,7 @@ across backends. ...@@ -26,7 +26,7 @@ across backends.
*/ */
use dynamo_async_openai::types::{ChatChoiceStream, FinishReason}; use dynamo_async_openai::types::{ChatChoiceStream, ChatCompletionMessageContent, FinishReason};
use dynamo_llm::preprocessor::OpenAIPreprocessor; use dynamo_llm::preprocessor::OpenAIPreprocessor;
use dynamo_llm::protocols::openai::chat_completions::NvCreateChatCompletionStreamResponse; use dynamo_llm::protocols::openai::chat_completions::NvCreateChatCompletionStreamResponse;
use dynamo_runtime::protocols::annotated::Annotated; use dynamo_runtime::protocols::annotated::Annotated;
...@@ -35,6 +35,13 @@ use std::pin::Pin; ...@@ -35,6 +35,13 @@ use std::pin::Pin;
const DATA_ROOT_PATH: &str = "tests/data/"; const DATA_ROOT_PATH: &str = "tests/data/";
/// Extracts a string slice from message content for aggregation in tests.
///
/// Returns the wrapped string for `ChatCompletionMessageContent::Text` and
/// `""` for `ChatCompletionMessageContent::Parts`.
fn get_text(content: &ChatCompletionMessageContent) -> &str {
    // Exhaustive on purpose: no `_` arm, so new variants must be handled
    // explicitly.
    match content {
        ChatCompletionMessageContent::Parts(_) => "",
        ChatCompletionMessageContent::Text(value) => value.as_str(),
    }
}
/// Test data structure containing expected results and stream data /// Test data structure containing expected results and stream data
struct TestData { struct TestData {
expected_normal_content: String, expected_normal_content: String,
...@@ -230,7 +237,7 @@ fn aggregate_content_from_chunks( ...@@ -230,7 +237,7 @@ fn aggregate_content_from_chunks(
// Collect normal content // Collect normal content
if let Some(ref content) = choice.delta.content { if let Some(ref content) = choice.delta.content {
normal_content.push_str(content); normal_content.push_str(get_text(content));
} }
// Collect tool calls // Collect tool calls
......
...@@ -2,12 +2,21 @@ ...@@ -2,12 +2,21 @@
// SPDX-License-Identifier: Apache-2.0 // SPDX-License-Identifier: Apache-2.0
use dynamo_async_openai::types::{ use dynamo_async_openai::types::{
ChatCompletionNamedToolChoice, ChatCompletionRequestMessage, ChatCompletionRequestUserMessage, ChatCompletionMessageContent, ChatCompletionNamedToolChoice, ChatCompletionRequestMessage,
ChatCompletionRequestUserMessageContent, ChatCompletionToolChoiceOption, ChatCompletionRequestUserMessage, ChatCompletionRequestUserMessageContent,
ChatCompletionToolType, CreateChatCompletionRequest, FunctionName, ChatCompletionToolChoiceOption, ChatCompletionToolType, CreateChatCompletionRequest,
FunctionName,
}; };
use dynamo_llm::protocols::common; use dynamo_llm::protocols::common;
use dynamo_llm::protocols::common::llm_backend::BackendOutput; use dynamo_llm::protocols::common::llm_backend::BackendOutput;
/// Borrows the plain-text form of a `ChatCompletionMessageContent`.
///
/// `Text` content is returned as-is (borrowed); `Parts` content is reported
/// as an empty string.
fn get_text(content: &ChatCompletionMessageContent) -> &str {
    match content {
        // No text is synthesized from parts; callers see "".
        ChatCompletionMessageContent::Parts(_) => "",
        ChatCompletionMessageContent::Text(payload) => payload.as_str(),
    }
}
use dynamo_llm::protocols::openai::DeltaGeneratorExt; use dynamo_llm::protocols::openai::DeltaGeneratorExt;
use dynamo_llm::protocols::openai::chat_completions::NvCreateChatCompletionRequest; use dynamo_llm::protocols::openai::chat_completions::NvCreateChatCompletionRequest;
...@@ -153,7 +162,7 @@ async fn test_named_tool_choice_parses_json() { ...@@ -153,7 +162,7 @@ async fn test_named_tool_choice_parses_json() {
Some(dynamo_async_openai::types::FinishReason::Stop) Some(dynamo_async_openai::types::FinishReason::Stop)
); );
let delta = &choice.delta; let delta = &choice.delta;
assert!(delta.content.is_none() || delta.content.as_deref() == Some("")); assert!(delta.content.is_none() || delta.content.as_ref().map(get_text) == Some(""));
let tool_calls = delta.tool_calls.as_ref().unwrap(); let tool_calls = delta.tool_calls.as_ref().unwrap();
assert_eq!(tool_calls.len(), 1); assert_eq!(tool_calls.len(), 1);
...@@ -195,7 +204,7 @@ async fn test_required_tool_choice_parses_json_array() { ...@@ -195,7 +204,7 @@ async fn test_required_tool_choice_parses_json_array() {
Some(dynamo_async_openai::types::FinishReason::ToolCalls) Some(dynamo_async_openai::types::FinishReason::ToolCalls)
); );
let delta = &choice.delta; let delta = &choice.delta;
assert!(delta.content.is_none() || delta.content.as_deref() == Some("")); assert!(delta.content.is_none() || delta.content.as_ref().map(get_text) == Some(""));
let tool_calls = delta.tool_calls.as_ref().unwrap(); let tool_calls = delta.tool_calls.as_ref().unwrap();
assert_eq!(tool_calls.len(), 2); assert_eq!(tool_calls.len(), 2);
...@@ -252,7 +261,7 @@ async fn test_tool_choice_parse_failure_returns_as_content() { ...@@ -252,7 +261,7 @@ async fn test_tool_choice_parse_failure_returns_as_content() {
// Jail stream behavior: if parsing fails, return accumulated content as-is // Jail stream behavior: if parsing fails, return accumulated content as-is
// This matches marker-based FC behavior // This matches marker-based FC behavior
assert_eq!(delta.content.as_deref(), Some("not-json")); assert_eq!(delta.content.as_ref().map(get_text), Some("not-json"));
assert!(delta.tool_calls.is_none()); assert!(delta.tool_calls.is_none());
} }
...@@ -434,7 +443,7 @@ fn test_no_tool_choice_outputs_normal_text() { ...@@ -434,7 +443,7 @@ fn test_no_tool_choice_outputs_normal_text() {
.expect("normal text"); .expect("normal text");
assert_eq!( assert_eq!(
response.choices[0].delta.content.as_deref(), response.choices[0].delta.content.as_ref().map(get_text),
Some("Hello world") Some("Hello world")
); );
assert!(response.choices[0].delta.tool_calls.is_none()); assert!(response.choices[0].delta.tool_calls.is_none());
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment