Unverified Commit 225dad4d authored by KrishnanPrash's avatar KrishnanPrash Committed by GitHub
Browse files

feat: add I2I (image-to-image) support for SGLang diffusion backend (#7870)


Signed-off-by: default avatarKrishnan Prashanth <kprashanth@nvidia.com>
parent ba2185a0
...@@ -65,6 +65,9 @@ class NvCreateImageRequest(BaseModel): ...@@ -65,6 +65,9 @@ class NvCreateImageRequest(BaseModel):
moderation: Optional[str] = None moderation: Optional[str] = None
"""Content moderation level: auto or low.""" """Content moderation level: auto or low."""
input_reference: Optional[str] = None
"""Optional image reference that guides generation (for I2I)."""
nvext: Optional[ImageNvExt] = None nvext: Optional[ImageNvExt] = None
"""NVIDIA extensions.""" """NVIDIA extensions."""
......
...@@ -7,6 +7,7 @@ from pydantic import BaseModel, ConfigDict, Field ...@@ -7,6 +7,7 @@ from pydantic import BaseModel, ConfigDict, Field
from sglang.srt.entrypoints.openai.protocol import ChatCompletionRequest from sglang.srt.entrypoints.openai.protocol import ChatCompletionRequest
from dynamo.common.multimodal import TransferRequest from dynamo.common.multimodal import TransferRequest
from dynamo.common.protocols.image_protocol import ImageNvExt
TokenIdType = int TokenIdType = int
...@@ -143,18 +144,13 @@ class DisaggSglangMultimodalRequest(BaseModel): ...@@ -143,18 +144,13 @@ class DisaggSglangMultimodalRequest(BaseModel):
# ============================================================================ # ============================================================================
class NvExt(BaseModel):
"""NVIDIA extensions for image generation"""
negative_prompt: Optional[str] = None
num_inference_steps: Optional[int] = 50
guidance_scale: float = 7.5
seed: Optional[int] = None
annotations: Optional[list[str]] = None
class CreateImageRequest(BaseModel): class CreateImageRequest(BaseModel):
"""OpenAI /v1/images/generations compatible request""" """OpenAI /v1/images/generations and /v1/images/edits compatible request.
Generation params (seed, guidance_scale, num_inference_steps, negative_prompt)
are specified under ``nvext``. SGLang-specific defaults (guidance_scale=7.5,
num_inference_steps=50) are applied in the handler, not the model.
"""
prompt: str prompt: str
model: str # e.g. "stabilityai/stable-diffusion-3.5-medium" model: str # e.g. "stabilityai/stable-diffusion-3.5-medium"
...@@ -163,9 +159,9 @@ class CreateImageRequest(BaseModel): ...@@ -163,9 +159,9 @@ class CreateImageRequest(BaseModel):
quality: Optional[str] = "standard" # standard, hd quality: Optional[str] = "standard" # standard, hd
response_format: Optional[str] = "url" # url or b64_json response_format: Optional[str] = "url" # url or b64_json
user: Optional[str] = None user: Optional[str] = None
input_reference: Optional[str] = None # For I2I/TI2I - image path/url
# NVIDIA extensions nested under nvext nvext: Optional[ImageNvExt] = None
nvext: Optional[NvExt] = None
class ImageData(BaseModel): class ImageData(BaseModel):
......
...@@ -14,16 +14,19 @@ import torch ...@@ -14,16 +14,19 @@ import torch
from PIL import Image from PIL import Image
from dynamo._core import Context from dynamo._core import Context
from dynamo.common.protocols.image_protocol import ImageNvExt
from dynamo.common.storage import upload_to_fs from dynamo.common.storage import upload_to_fs
from dynamo.common.utils.otel_tracing import build_trace_headers from dynamo.common.utils.otel_tracing import build_trace_headers
from dynamo.sglang.args import Config from dynamo.sglang.args import Config
from dynamo.sglang.protocol import CreateImageRequest, ImageData, ImagesResponse, NvExt from dynamo.sglang.protocol import CreateImageRequest, ImageData, ImagesResponse
from dynamo.sglang.publisher import DynamoSglangPublisher from dynamo.sglang.publisher import DynamoSglangPublisher
from dynamo.sglang.request_handlers.handler_base import BaseGenerativeHandler from dynamo.sglang.request_handlers.handler_base import BaseGenerativeHandler
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
MAX_NUM_INFERENCE_STEPS = 50 MAX_NUM_INFERENCE_STEPS = 50
DEFAULT_NUM_INFERENCE_STEPS = 50
DEFAULT_GUIDANCE_SCALE = 7.5
class ImageDiffusionWorkerHandler(BaseGenerativeHandler): class ImageDiffusionWorkerHandler(BaseGenerativeHandler):
...@@ -92,11 +95,17 @@ class ImageDiffusionWorkerHandler(BaseGenerativeHandler): ...@@ -92,11 +95,17 @@ class ImageDiffusionWorkerHandler(BaseGenerativeHandler):
try: try:
req = CreateImageRequest(**request) req = CreateImageRequest(**request)
# get extra parameters nvext = req.nvext or ImageNvExt()
nvext = req.nvext or NvExt()
nvext.num_inference_steps = min( # Apply SGLang-specific defaults for unset values
nvext.num_inference_steps or 50, MAX_NUM_INFERENCE_STEPS raw_steps = nvext.num_inference_steps or DEFAULT_NUM_INFERENCE_STEPS
) if raw_steps > MAX_NUM_INFERENCE_STEPS:
logger.warning(
f"num_inference_steps={raw_steps} exceeds max "
f"{MAX_NUM_INFERENCE_STEPS}, clamping"
)
num_inference_steps = min(raw_steps, MAX_NUM_INFERENCE_STEPS)
guidance_scale = nvext.guidance_scale or DEFAULT_GUIDANCE_SCALE
width, height = self._parse_size(req.size) width, height = self._parse_size(req.size)
...@@ -105,9 +114,10 @@ class ImageDiffusionWorkerHandler(BaseGenerativeHandler): ...@@ -105,9 +114,10 @@ class ImageDiffusionWorkerHandler(BaseGenerativeHandler):
negative_prompt=nvext.negative_prompt, negative_prompt=nvext.negative_prompt,
width=width, width=width,
height=height, height=height,
num_inference_steps=nvext.num_inference_steps, num_inference_steps=num_inference_steps,
guidance_scale=nvext.guidance_scale, guidance_scale=guidance_scale,
seed=nvext.seed, seed=nvext.seed,
input_reference=req.input_reference,
) )
context_id = context.id() context_id = context.id()
...@@ -145,6 +155,7 @@ class ImageDiffusionWorkerHandler(BaseGenerativeHandler): ...@@ -145,6 +155,7 @@ class ImageDiffusionWorkerHandler(BaseGenerativeHandler):
guidance_scale: float, guidance_scale: float,
seed: Optional[int], seed: Optional[int],
negative_prompt: Optional[str] = None, negative_prompt: Optional[str] = None,
input_reference: Optional[str] = None,
) -> list[bytes]: ) -> list[bytes]:
"""Generate images using SGLang DiffGenerator""" """Generate images using SGLang DiffGenerator"""
args = { args = {
...@@ -155,8 +166,15 @@ class ImageDiffusionWorkerHandler(BaseGenerativeHandler): ...@@ -155,8 +166,15 @@ class ImageDiffusionWorkerHandler(BaseGenerativeHandler):
"num_inference_steps": num_inference_steps, "num_inference_steps": num_inference_steps,
"save_output": False, # We handle saving ourselves "save_output": False, # We handle saving ourselves
"guidance_scale": guidance_scale, "guidance_scale": guidance_scale,
"seed": seed if seed else random.randint(0, 1000000), "seed": seed if seed is not None else random.randint(0, 1000000),
} }
# Add image_path for I2I/TI2I if provided
if input_reference is not None:
if not input_reference.strip():
raise ValueError("input_reference must be a non-empty string")
args["image_path"] = input_reference
result = await asyncio.to_thread( result = await asyncio.to_thread(
self.generator.generate, self.generator.generate,
sampling_params_kwargs=args, sampling_params_kwargs=args,
...@@ -175,7 +193,7 @@ class ImageDiffusionWorkerHandler(BaseGenerativeHandler): ...@@ -175,7 +193,7 @@ class ImageDiffusionWorkerHandler(BaseGenerativeHandler):
for img in images: for img in images:
if isinstance(img, bytes): if isinstance(img, bytes):
image_bytes_list.append(img) image_bytes_list.append(img)
elif Image is not None and isinstance(img, Image.Image): elif isinstance(img, Image.Image):
# Convert PIL Image to bytes # Convert PIL Image to bytes
buf = io.BytesIO() buf = io.BytesIO()
img.save(buf, format="PNG") img.save(buf, format="PNG")
......
...@@ -6,7 +6,7 @@ ...@@ -6,7 +6,7 @@
import base64 import base64
import io import io
from types import SimpleNamespace from types import SimpleNamespace
from unittest.mock import MagicMock, Mock, patch from unittest.mock import AsyncMock, MagicMock, Mock, patch
import pytest import pytest
from PIL import Image from PIL import Image
...@@ -347,7 +347,7 @@ class TestImageDiffusionWorkerHandler: ...@@ -347,7 +347,7 @@ class TestImageDiffusionWorkerHandler:
"""Test that nvext parameters are passed to the generator.""" """Test that nvext parameters are passed to the generator."""
test_image = Image.new("RGB", (256, 256), color="yellow") test_image = Image.new("RGB", (256, 256), color="yellow")
handler._generate_images = Mock(return_value=[test_image.tobytes()]) handler._generate_images = AsyncMock(return_value=[test_image.tobytes()])
request = { request = {
"prompt": "A yellow square", "prompt": "A yellow square",
...@@ -382,4 +382,59 @@ class TestImageDiffusionWorkerHandler: ...@@ -382,4 +382,59 @@ class TestImageDiffusionWorkerHandler:
guidance_scale=7.5, guidance_scale=7.5,
seed=42, seed=42,
negative_prompt="negative", negative_prompt="negative",
input_reference=None,
) )
@pytest.mark.asyncio
async def test_generate_i2i_passes_image_path(
self, handler, mock_context, tmp_path
):
"""Test that input_reference is passed as image_path to the generator."""
test_image = Image.new("RGB", (256, 256), color="green")
handler.generator.generate = Mock(
return_value=SimpleNamespace(frames=[test_image])
)
input_ref = str(tmp_path / "test_input.png")
request = {
"prompt": "Transform this image",
"model": "test-model",
"size": "256x256",
"response_format": "b64_json",
"input_reference": input_ref,
}
results = []
async for result in handler.generate(request, mock_context):
results.append(result)
# Verify image_path was passed to the generator
call_args = handler.generator.generate.call_args
sampling_params = call_args[1]["sampling_params_kwargs"]
assert sampling_params["image_path"] == input_ref
@pytest.mark.asyncio
async def test_generate_t2i_no_image_path(self, handler, mock_context):
"""Test that image_path is NOT passed when input_reference is absent."""
test_image = Image.new("RGB", (256, 256), color="red")
handler.generator.generate = Mock(
return_value=SimpleNamespace(frames=[test_image])
)
request = {
"prompt": "A red square",
"model": "test-model",
"size": "256x256",
"response_format": "b64_json",
}
results = []
async for result in handler.generate(request, mock_context):
results.append(result)
# Verify image_path was NOT passed
call_args = handler.generator.generate.call_args
sampling_params = call_args[1]["sampling_params_kwargs"]
assert "image_path" not in sampling_params
...@@ -2029,20 +2029,44 @@ async fn images( ...@@ -2029,20 +2029,44 @@ async fn images(
Ok(Json(response).into_response()) Ok(Json(response).into_response())
} }
/// Create an Axum [`Router`] for the OpenAI API Images endpoint /// Handler for `/v1/images/edits` (I2I). Requires `input_reference`.
/// If not path is provided, the default path is `/v1/images/generations` async fn images_edits(
state: State<Arc<service_v2::State>>,
headers: HeaderMap,
Json(request): Json<NvCreateImageRequest>,
) -> Result<Response, ErrorResponse> {
if request.input_reference.is_none() {
let code = StatusCode::BAD_REQUEST;
return Err((
code,
Json(ErrorMessage {
message: "input_reference is required for /v1/images/edits".to_string(),
error_type: map_error_code_to_error_type(code),
code: code.as_u16(),
}),
));
}
images(state, headers, Json(request)).await
}
/// Create an Axum [`Router`] for the OpenAI API Images endpoints.
/// `/v1/images/generations` accepts optional `input_reference` (T2I or TI2I).
/// `/v1/images/edits` requires `input_reference` (I2I).
pub fn images_router( pub fn images_router(
state: Arc<service_v2::State>, state: Arc<service_v2::State>,
path: Option<String>, path: Option<String>,
) -> (Vec<RouteDoc>, Router) { ) -> (Vec<RouteDoc>, Router) {
let path = path.unwrap_or("/v1/images/generations".to_string()); let generations_path = path.unwrap_or("/v1/images/generations".to_string());
let doc = RouteDoc::new(axum::http::Method::POST, &path); let edits_path = generations_path.replace("/generations", "/edits");
let doc = RouteDoc::new(axum::http::Method::POST, &generations_path);
let edits_doc = RouteDoc::new(axum::http::Method::POST, &edits_path);
let router = Router::new() let router = Router::new()
.route(&path, post(images)) .route(&generations_path, post(images))
.route(&edits_path, post(images_edits))
.layer(middleware::from_fn(smart_json_error_middleware)) .layer(middleware::from_fn(smart_json_error_middleware))
.layer(axum::extract::DefaultBodyLimit::max(get_body_limit())) .layer(axum::extract::DefaultBodyLimit::max(get_body_limit()))
.with_state(state); .with_state(state);
(vec![doc], router) (vec![doc, edits_doc], router)
} }
async fn videos( async fn videos(
......
...@@ -11,11 +11,16 @@ mod nvext; ...@@ -11,11 +11,16 @@ mod nvext;
pub use aggregator::DeltaAggregator; pub use aggregator::DeltaAggregator;
pub use nvext::{NvExt, NvExtProvider}; pub use nvext::{NvExt, NvExtProvider};
/// Image generation request with NVIDIA extensions.
#[derive(Serialize, Deserialize, Validate, Debug, Clone)] #[derive(Serialize, Deserialize, Validate, Debug, Clone)]
pub struct NvCreateImageRequest { pub struct NvCreateImageRequest {
#[serde(flatten)] #[serde(flatten)]
pub inner: dynamo_protocols::types::CreateImageRequest, pub inner: dynamo_protocols::types::CreateImageRequest,
/// Optional image reference that guides generation (for I2I/TI2I).
#[serde(skip_serializing_if = "Option::is_none")]
pub input_reference: Option<String>,
#[serde(skip_serializing_if = "Option::is_none")] #[serde(skip_serializing_if = "Option::is_none")]
pub nvext: Option<NvExt>, pub nvext: Option<NvExt>,
} }
......
...@@ -37,9 +37,10 @@ pub struct NvExt { ...@@ -37,9 +37,10 @@ pub struct NvExt {
pub guidance_scale: Option<f32>, pub guidance_scale: Option<f32>,
/// The seed for the random number generator. /// The seed for the random number generator.
/// i64 to match PyTorch's torch.manual_seed() accepted range.
#[serde(skip_serializing_if = "Option::is_none")] #[serde(skip_serializing_if = "Option::is_none")]
#[builder(default, setter(strip_option))] #[builder(default, setter(strip_option))]
pub seed: Option<u32>, pub seed: Option<i64>,
} }
impl Default for NvExt { impl Default for NvExt {
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment