Unverified Commit 7e970d44 authored by Konrad Nowicki's avatar Konrad Nowicki Committed by GitHub
Browse files

feat: image diffusion with SGLang diffusion (#5609)


Signed-off-by: default avatarKonrad Nowicki <knowicki@nvidia.com>
Co-authored-by: default avatardagil-nvidia <dagil@nvidia.com>
parent f3aa1e01
...@@ -37,6 +37,7 @@ bitflags! { ...@@ -37,6 +37,7 @@ bitflags! {
const Embedding = 1 << 2; const Embedding = 1 << 2;
const TensorBased = 1 << 3; const TensorBased = 1 << 3;
const Prefill = 1 << 4; const Prefill = 1 << 4;
const Images = 1 << 5;
} }
} }
...@@ -60,6 +61,9 @@ impl ModelType { ...@@ -60,6 +61,9 @@ impl ModelType {
pub fn supports_prefill(&self) -> bool { pub fn supports_prefill(&self) -> bool {
self.contains(ModelType::Prefill) self.contains(ModelType::Prefill)
} }
pub fn supports_images(&self) -> bool {
self.contains(ModelType::Images)
}
pub fn as_vec(&self) -> Vec<&'static str> { pub fn as_vec(&self) -> Vec<&'static str> {
let mut result = Vec::new(); let mut result = Vec::new();
...@@ -78,6 +82,9 @@ impl ModelType { ...@@ -78,6 +82,9 @@ impl ModelType {
if self.supports_prefill() { if self.supports_prefill() {
result.push("prefill"); result.push("prefill");
} }
if self.supports_images() {
result.push("images");
}
result result
} }
...@@ -100,6 +107,9 @@ impl ModelType { ...@@ -100,6 +107,9 @@ impl ModelType {
if self.supports_prefill() { if self.supports_prefill() {
result.push(ModelType::Prefill); result.push(ModelType::Prefill);
} }
if self.supports_images() {
result.push(ModelType::Images);
}
result result
} }
...@@ -116,6 +126,9 @@ impl ModelType { ...@@ -116,6 +126,9 @@ impl ModelType {
if self.contains(Self::Embedding) { if self.contains(Self::Embedding) {
endpoint_types.push(crate::endpoint_type::EndpointType::Embedding); endpoint_types.push(crate::endpoint_type::EndpointType::Embedding);
} }
if self.contains(Self::Images) {
endpoint_types.push(crate::endpoint_type::EndpointType::Images);
}
// [gluo NOTE] ModelType::Tensor doesn't map to any endpoint type, // [gluo NOTE] ModelType::Tensor doesn't map to any endpoint type,
// current use of endpoint type is LLM specific and so does the HTTP // current use of endpoint type is LLM specific and so does the HTTP
// server that uses it. // server that uses it.
......
...@@ -14,6 +14,7 @@ pub mod chat_completions; ...@@ -14,6 +14,7 @@ pub mod chat_completions;
pub mod common_ext; pub mod common_ext;
pub mod completions; pub mod completions;
pub mod embeddings; pub mod embeddings;
pub mod images;
pub mod models; pub mod models;
pub mod nvext; pub mod nvext;
pub mod responses; pub mod responses;
......
// SPDX-FileCopyrightText: Copyright (c) 2025-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
// SPDX-License-Identifier: Apache-2.0
use dynamo_runtime::protocols::annotated::AnnotationsProvider;
use serde::{Deserialize, Serialize};
use validator::Validate;
mod aggregator;
mod nvext;
pub use aggregator::DeltaAggregator;
pub use nvext::{NvExt, NvExtProvider};
#[derive(Serialize, Deserialize, Validate, Debug, Clone)]
pub struct NvCreateImageRequest {
#[serde(flatten)]
pub inner: dynamo_async_openai::types::CreateImageRequest,
#[serde(skip_serializing_if = "Option::is_none")]
pub nvext: Option<NvExt>,
}
/// A response structure for image generation responses, embedding OpenAI's
/// `ImagesResponse`.
///
/// # Fields
/// - `inner`: The base OpenAI image response, embedded using `serde(flatten)`.
#[derive(Serialize, Deserialize, Validate, Debug, Clone)]
pub struct NvImagesResponse {
#[serde(flatten)]
pub inner: dynamo_async_openai::types::ImagesResponse,
}
impl NvImagesResponse {
pub fn empty() -> Self {
Self {
inner: dynamo_async_openai::types::ImagesResponse {
created: 0,
data: vec![],
},
}
}
}
/// Implements `NvExtProvider` for `NvCreateImageRequest`,
/// providing access to NVIDIA-specific extensions.
impl NvExtProvider for NvCreateImageRequest {
/// Returns a reference to the optional `NvExt` extension, if available.
fn nvext(&self) -> Option<&NvExt> {
self.nvext.as_ref()
}
}
/// Implements `AnnotationsProvider` for `NvCreateImageRequest`,
/// enabling retrieval and management of request annotations.
impl AnnotationsProvider for NvCreateImageRequest {
/// Retrieves the list of annotations from `NvExt`, if present.
fn annotations(&self) -> Option<Vec<String>> {
self.nvext
.as_ref()
.and_then(|nvext| nvext.annotations.clone())
}
/// Checks whether a specific annotation exists in the request.
///
/// # Arguments
/// * `annotation` - A string slice representing the annotation to check.
///
/// # Returns
/// `true` if the annotation exists, `false` otherwise.
fn has_annotation(&self, annotation: &str) -> bool {
self.nvext
.as_ref()
.and_then(|nvext| nvext.annotations.as_ref())
.map(|annotations| annotations.contains(&annotation.to_string()))
.unwrap_or(false)
}
}
// SPDX-FileCopyrightText: Copyright (c) 2025-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
// SPDX-License-Identifier: Apache-2.0
use futures::{Stream, StreamExt};
use crate::types::Annotated;
use super::NvImagesResponse;
/// Aggregator for combining image response deltas into a final response.
#[derive(Debug)]
pub struct DeltaAggregator {
response: Option<NvImagesResponse>,
error: Option<String>,
}
impl Default for DeltaAggregator {
/// Provides a default implementation for `DeltaAggregator` by calling [`DeltaAggregator::new`].
fn default() -> Self {
Self::new()
}
}
impl DeltaAggregator {
pub fn new() -> Self {
DeltaAggregator {
response: None,
error: None,
}
}
/// Aggregates a stream of annotated image responses into a final response.
pub async fn apply(
stream: impl Stream<Item = Annotated<NvImagesResponse>>,
) -> Result<NvImagesResponse, String> {
let aggregator = stream
.fold(DeltaAggregator::new(), |mut aggregator, delta| async move {
// Attempt to unwrap the delta, capturing any errors.
let delta = match delta.ok() {
Ok(delta) => delta,
Err(error) => {
aggregator.error = Some(error);
return aggregator;
}
};
if aggregator.error.is_none()
&& let Some(response) = delta.data
{
// For images, we typically expect a single complete response
// or we accumulate data from multiple responses
match &mut aggregator.response {
Some(existing) => {
// Merge image data if we have multiple responses
existing.inner.data.extend(response.inner.data);
}
None => {
aggregator.response = Some(response);
}
}
}
aggregator
})
.await;
// Return early if an error was encountered.
if let Some(error) = aggregator.error {
return Err(error);
}
// Return the aggregated response or an empty response if none was found.
Ok(aggregator.response.unwrap_or_else(NvImagesResponse::empty))
}
}
impl NvImagesResponse {
/// Aggregates an annotated stream of image responses into a final response.
///
/// # Arguments
/// * `stream` - A stream of annotated image responses.
///
/// # Returns
/// * `Ok(NvImagesResponse)` if aggregation succeeds.
/// * `Err(String)` if an error occurs.
pub async fn from_annotated_stream(
stream: impl Stream<Item = Annotated<NvImagesResponse>>,
) -> Result<NvImagesResponse, String> {
DeltaAggregator::apply(stream).await
}
}
// SPDX-FileCopyrightText: Copyright (c) 2025-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
// SPDX-License-Identifier: Apache-2.0
use derive_builder::Builder;
use serde::{Deserialize, Serialize};
use utoipa::ToSchema;
use validator::{Validate, ValidationError};
pub trait NvExtProvider {
fn nvext(&self) -> Option<&NvExt>;
}
/// NVIDIA extensions to the OpenAI Images API
#[derive(ToSchema, Serialize, Deserialize, Builder, Validate, Debug, Clone)]
#[validate(schema(function = "validate_nv_ext"))]
pub struct NvExt {
/// Annotations
/// User requests triggers which result in the request issue back out-of-band information in the SSE
/// stream using the `event:` field.
#[serde(default, skip_serializing_if = "Option::is_none")]
#[builder(default, setter(strip_option))]
pub annotations: Option<Vec<String>>,
/// A text description of the undesired image(s).
#[serde(skip_serializing_if = "Option::is_none")]
#[builder(default, setter(strip_option))]
pub negative_prompt: Option<String>,
/// The number of denoising steps. More denoising steps usually lead to a higher quality image at the expense of slower inference.
#[serde(skip_serializing_if = "Option::is_none")]
#[builder(default, setter(strip_option))]
pub num_inference_steps: Option<u8>,
/// The CFG scale. Higher values usually lead to more coherent images.
#[serde(skip_serializing_if = "Option::is_none")]
#[builder(default, setter(strip_option))]
pub guidance_scale: Option<f32>,
/// The seed for the random number generator.
#[serde(skip_serializing_if = "Option::is_none")]
#[builder(default, setter(strip_option))]
pub seed: Option<u32>,
}
impl Default for NvExt {
fn default() -> Self {
NvExt::builder().build().unwrap()
}
}
impl NvExt {
pub fn builder() -> NvExtBuilder {
NvExtBuilder::default()
}
}
fn validate_nv_ext(_nv_ext: &NvExt) -> Result<(), ValidationError> {
Ok(())
}
impl NvExtBuilder {
pub fn add_annotation(&mut self, annotation: impl Into<String>) -> &mut Self {
self.annotations
.get_or_insert_with(|| Some(vec![]))
.as_mut()
.expect("annotations should always be Some(Vec)")
.push(annotation.into());
self
}
}
...@@ -59,6 +59,30 @@ pub mod openai { ...@@ -59,6 +59,30 @@ pub mod openai {
pub type OpenAIEmbeddingsStreamingEngine = pub type OpenAIEmbeddingsStreamingEngine =
ServerStreamingEngine<NvCreateEmbeddingRequest, Annotated<NvCreateEmbeddingResponse>>; ServerStreamingEngine<NvCreateEmbeddingRequest, Annotated<NvCreateEmbeddingResponse>>;
} }
pub mod images {
use super::*;
pub use protocols::openai::images::{NvCreateImageRequest, NvImagesResponse};
/// A [`UnaryEngine`] implementation for the OpenAI Images API
pub type OpenAIImagesUnaryEngine = UnaryEngine<NvCreateImageRequest, NvImagesResponse>;
/// A [`ServerStreamingEngine`] implementation for the OpenAI Images API.
///
/// **Note**: This "streaming" refers to the internal routing/distribution architecture,
/// NOT client-facing Server-Sent Events (SSE) streaming. Image generation does not
/// support progressive streaming to clients - images are generated completely and
/// returned as finished artifacts (URLs or base64).
///
/// The HTTP endpoint folds this stream into a single response before returning to clients,
/// similar to how embeddings work. The streaming infrastructure is used for:
/// - Consistent routing architecture across all model types
/// - Request distribution via PushRouter
/// - Worker fault detection and load balancing
pub type OpenAIImagesStreamingEngine =
ServerStreamingEngine<NvCreateImageRequest, Annotated<NvImagesResponse>>;
}
} }
pub mod generic { pub mod generic {
......
...@@ -217,6 +217,7 @@ fn compute_index(endpoint: &Endpoint, request_type: &RequestType, status: &Statu ...@@ -217,6 +217,7 @@ fn compute_index(endpoint: &Endpoint, request_type: &RequestType, status: &Statu
Endpoint::Embeddings => todo!(), Endpoint::Embeddings => todo!(),
Endpoint::Responses => todo!(), Endpoint::Responses => todo!(),
Endpoint::Tensor => todo!(), Endpoint::Tensor => todo!(),
Endpoint::Images => todo!(),
}; };
let request_type = match request_type { let request_type = match request_type {
......
...@@ -62,7 +62,7 @@ vllm = [ ...@@ -62,7 +62,7 @@ vllm = [
sglang = [ sglang = [
"uvloop", "uvloop",
"sglang==0.5.8", "sglang[diffusion]==0.5.8",
"nixl[cu12]<=0.9.0", "nixl[cu12]<=0.9.0",
"cupy-cuda12x>=13.0.0", "cupy-cuda12x>=13.0.0",
] ]
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment