"vscode:/vscode.git/clone" did not exist on "a671dbd4239ef8715229f432da0a9b148b23ae9b"
Unverified Commit 225dad4d authored by KrishnanPrash's avatar KrishnanPrash Committed by GitHub
Browse files

feat: add I2I (image-to-image) support for SGLang diffusion backend (#7870)


Signed-off-by: default avatarKrishnan Prashanth <kprashanth@nvidia.com>
parent ba2185a0
......@@ -65,6 +65,9 @@ class NvCreateImageRequest(BaseModel):
moderation: Optional[str] = None
"""Content moderation level: auto or low."""
input_reference: Optional[str] = None
"""Optional image reference that guides generation (for I2I)."""
nvext: Optional[ImageNvExt] = None
"""NVIDIA extensions."""
......
......@@ -7,6 +7,7 @@ from pydantic import BaseModel, ConfigDict, Field
from sglang.srt.entrypoints.openai.protocol import ChatCompletionRequest
from dynamo.common.multimodal import TransferRequest
from dynamo.common.protocols.image_protocol import ImageNvExt
TokenIdType = int
......@@ -143,18 +144,13 @@ class DisaggSglangMultimodalRequest(BaseModel):
# ============================================================================
class NvExt(BaseModel):
"""NVIDIA extensions for image generation"""
negative_prompt: Optional[str] = None
num_inference_steps: Optional[int] = 50
guidance_scale: float = 7.5
seed: Optional[int] = None
annotations: Optional[list[str]] = None
class CreateImageRequest(BaseModel):
"""OpenAI /v1/images/generations compatible request"""
"""OpenAI /v1/images/generations and /v1/images/edits compatible request.
Generation params (seed, guidance_scale, num_inference_steps, negative_prompt)
are specified under ``nvext``. SGLang-specific defaults (guidance_scale=7.5,
num_inference_steps=50) are applied in the handler, not the model.
"""
prompt: str
model: str # e.g. "stabilityai/stable-diffusion-3.5-medium"
......@@ -163,9 +159,9 @@ class CreateImageRequest(BaseModel):
quality: Optional[str] = "standard" # standard, hd
response_format: Optional[str] = "url" # url or b64_json
user: Optional[str] = None
input_reference: Optional[str] = None # For I2I/TI2I - image path/url
# NVIDIA extensions nested under nvext
nvext: Optional[NvExt] = None
nvext: Optional[ImageNvExt] = None
class ImageData(BaseModel):
......
......@@ -14,16 +14,19 @@ import torch
from PIL import Image
from dynamo._core import Context
from dynamo.common.protocols.image_protocol import ImageNvExt
from dynamo.common.storage import upload_to_fs
from dynamo.common.utils.otel_tracing import build_trace_headers
from dynamo.sglang.args import Config
from dynamo.sglang.protocol import CreateImageRequest, ImageData, ImagesResponse, NvExt
from dynamo.sglang.protocol import CreateImageRequest, ImageData, ImagesResponse
from dynamo.sglang.publisher import DynamoSglangPublisher
from dynamo.sglang.request_handlers.handler_base import BaseGenerativeHandler
logger = logging.getLogger(__name__)
MAX_NUM_INFERENCE_STEPS = 50
DEFAULT_NUM_INFERENCE_STEPS = 50
DEFAULT_GUIDANCE_SCALE = 7.5
class ImageDiffusionWorkerHandler(BaseGenerativeHandler):
......@@ -92,11 +95,17 @@ class ImageDiffusionWorkerHandler(BaseGenerativeHandler):
try:
req = CreateImageRequest(**request)
# get extra parameters
nvext = req.nvext or NvExt()
nvext.num_inference_steps = min(
nvext.num_inference_steps or 50, MAX_NUM_INFERENCE_STEPS
)
nvext = req.nvext or ImageNvExt()
# Apply SGLang-specific defaults for unset values
raw_steps = nvext.num_inference_steps or DEFAULT_NUM_INFERENCE_STEPS
if raw_steps > MAX_NUM_INFERENCE_STEPS:
logger.warning(
f"num_inference_steps={raw_steps} exceeds max "
f"{MAX_NUM_INFERENCE_STEPS}, clamping"
)
num_inference_steps = min(raw_steps, MAX_NUM_INFERENCE_STEPS)
guidance_scale = nvext.guidance_scale or DEFAULT_GUIDANCE_SCALE
width, height = self._parse_size(req.size)
......@@ -105,9 +114,10 @@ class ImageDiffusionWorkerHandler(BaseGenerativeHandler):
negative_prompt=nvext.negative_prompt,
width=width,
height=height,
num_inference_steps=nvext.num_inference_steps,
guidance_scale=nvext.guidance_scale,
num_inference_steps=num_inference_steps,
guidance_scale=guidance_scale,
seed=nvext.seed,
input_reference=req.input_reference,
)
context_id = context.id()
......@@ -145,6 +155,7 @@ class ImageDiffusionWorkerHandler(BaseGenerativeHandler):
guidance_scale: float,
seed: Optional[int],
negative_prompt: Optional[str] = None,
input_reference: Optional[str] = None,
) -> list[bytes]:
"""Generate images using SGLang DiffGenerator"""
args = {
......@@ -155,8 +166,15 @@ class ImageDiffusionWorkerHandler(BaseGenerativeHandler):
"num_inference_steps": num_inference_steps,
"save_output": False, # We handle saving ourselves
"guidance_scale": guidance_scale,
"seed": seed if seed else random.randint(0, 1000000),
"seed": seed if seed is not None else random.randint(0, 1000000),
}
# Add image_path for I2I/TI2I if provided
if input_reference is not None:
if not input_reference.strip():
raise ValueError("input_reference must be a non-empty string")
args["image_path"] = input_reference
result = await asyncio.to_thread(
self.generator.generate,
sampling_params_kwargs=args,
......@@ -175,7 +193,7 @@ class ImageDiffusionWorkerHandler(BaseGenerativeHandler):
for img in images:
if isinstance(img, bytes):
image_bytes_list.append(img)
elif Image is not None and isinstance(img, Image.Image):
elif isinstance(img, Image.Image):
# Convert PIL Image to bytes
buf = io.BytesIO()
img.save(buf, format="PNG")
......
......@@ -6,7 +6,7 @@
import base64
import io
from types import SimpleNamespace
from unittest.mock import MagicMock, Mock, patch
from unittest.mock import AsyncMock, MagicMock, Mock, patch
import pytest
from PIL import Image
......@@ -347,7 +347,7 @@ class TestImageDiffusionWorkerHandler:
"""Test that nvext parameters are passed to the generator."""
test_image = Image.new("RGB", (256, 256), color="yellow")
handler._generate_images = Mock(return_value=[test_image.tobytes()])
handler._generate_images = AsyncMock(return_value=[test_image.tobytes()])
request = {
"prompt": "A yellow square",
......@@ -382,4 +382,59 @@ class TestImageDiffusionWorkerHandler:
guidance_scale=7.5,
seed=42,
negative_prompt="negative",
input_reference=None,
)
@pytest.mark.asyncio
async def test_generate_i2i_passes_image_path(
self, handler, mock_context, tmp_path
):
"""Test that input_reference is passed as image_path to the generator."""
test_image = Image.new("RGB", (256, 256), color="green")
handler.generator.generate = Mock(
return_value=SimpleNamespace(frames=[test_image])
)
input_ref = str(tmp_path / "test_input.png")
request = {
"prompt": "Transform this image",
"model": "test-model",
"size": "256x256",
"response_format": "b64_json",
"input_reference": input_ref,
}
results = []
async for result in handler.generate(request, mock_context):
results.append(result)
# Verify image_path was passed to the generator
call_args = handler.generator.generate.call_args
sampling_params = call_args[1]["sampling_params_kwargs"]
assert sampling_params["image_path"] == input_ref
@pytest.mark.asyncio
async def test_generate_t2i_no_image_path(self, handler, mock_context):
"""Test that image_path is NOT passed when input_reference is absent."""
test_image = Image.new("RGB", (256, 256), color="red")
handler.generator.generate = Mock(
return_value=SimpleNamespace(frames=[test_image])
)
request = {
"prompt": "A red square",
"model": "test-model",
"size": "256x256",
"response_format": "b64_json",
}
results = []
async for result in handler.generate(request, mock_context):
results.append(result)
# Verify image_path was NOT passed
call_args = handler.generator.generate.call_args
sampling_params = call_args[1]["sampling_params_kwargs"]
assert "image_path" not in sampling_params
......@@ -2029,20 +2029,44 @@ async fn images(
Ok(Json(response).into_response())
}
/// Create an Axum [`Router`] for the OpenAI API Images endpoint
/// If not path is provided, the default path is `/v1/images/generations`
/// Handler for `/v1/images/edits` (I2I). Requires `input_reference`.
async fn images_edits(
state: State<Arc<service_v2::State>>,
headers: HeaderMap,
Json(request): Json<NvCreateImageRequest>,
) -> Result<Response, ErrorResponse> {
if request.input_reference.is_none() {
let code = StatusCode::BAD_REQUEST;
return Err((
code,
Json(ErrorMessage {
message: "input_reference is required for /v1/images/edits".to_string(),
error_type: map_error_code_to_error_type(code),
code: code.as_u16(),
}),
));
}
images(state, headers, Json(request)).await
}
/// Create an Axum [`Router`] for the OpenAI API Images endpoints.
/// `/v1/images/generations` accepts optional `input_reference` (T2I or TI2I).
/// `/v1/images/edits` requires `input_reference` (I2I).
pub fn images_router(
state: Arc<service_v2::State>,
path: Option<String>,
) -> (Vec<RouteDoc>, Router) {
let path = path.unwrap_or("/v1/images/generations".to_string());
let doc = RouteDoc::new(axum::http::Method::POST, &path);
let generations_path = path.unwrap_or("/v1/images/generations".to_string());
let edits_path = generations_path.replace("/generations", "/edits");
let doc = RouteDoc::new(axum::http::Method::POST, &generations_path);
let edits_doc = RouteDoc::new(axum::http::Method::POST, &edits_path);
let router = Router::new()
.route(&path, post(images))
.route(&generations_path, post(images))
.route(&edits_path, post(images_edits))
.layer(middleware::from_fn(smart_json_error_middleware))
.layer(axum::extract::DefaultBodyLimit::max(get_body_limit()))
.with_state(state);
(vec![doc], router)
(vec![doc, edits_doc], router)
}
async fn videos(
......
......@@ -11,11 +11,16 @@ mod nvext;
pub use aggregator::DeltaAggregator;
pub use nvext::{NvExt, NvExtProvider};
/// Image generation request with NVIDIA extensions.
#[derive(Serialize, Deserialize, Validate, Debug, Clone)]
pub struct NvCreateImageRequest {
#[serde(flatten)]
pub inner: dynamo_protocols::types::CreateImageRequest,
/// Optional image reference that guides generation (for I2I/TI2I).
#[serde(skip_serializing_if = "Option::is_none")]
pub input_reference: Option<String>,
#[serde(skip_serializing_if = "Option::is_none")]
pub nvext: Option<NvExt>,
}
......
......@@ -37,9 +37,10 @@ pub struct NvExt {
pub guidance_scale: Option<f32>,
/// The seed for the random number generator.
/// i64 to match PyTorch's torch.manual_seed() accepted range.
#[serde(skip_serializing_if = "Option::is_none")]
#[builder(default, setter(strip_option))]
pub seed: Option<u32>,
pub seed: Option<i64>,
}
impl Default for NvExt {
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment