Unverified Commit 7e970d44 authored by Konrad Nowicki's avatar Konrad Nowicki Committed by GitHub
Browse files

feat: image diffusion with SGLang diffusion (#5609)


Signed-off-by: default avatarKonrad Nowicki <knowicki@nvidia.com>
Co-authored-by: default avatardagil-nvidia <dagil@nvidia.com>
parent f3aa1e01
# SPDX-FileCopyrightText: Copyright (c) 2025-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
"""
Dynamo Common Storage Module
Filesystem Spec (fsspec) is a project to provide a unified pythonic interface to local, remote and embedded file systems and bytes storage.
https://filesystem-spec.readthedocs.io/en/latest/index.html#who-uses-fsspec
Configuration for the storage:
Local Filesystem:
1. fs_url MUST contain a root path - path must be accessible and writable
S3:
1. If you want to use S3 please install additional dependencies: fsspec[s3]
2. fs_url MUST contain a bucket name
3. Configure credentials https://s3fs.readthedocs.io/en/latest/?badge=latest#credentials
a) AWS_ACCESS_KEY_ID, AWS_SECRET_ACCESS_KEY, and AWS_SESSION_TOKEN environment variables
b) configuration files such as ~/.aws/credentials
c) for nodes on EC2, the IAM metadata provider
d) for S3 compatible storage, you can use the following environment variables:
# export FSSPEC_S3_ENDPOINT_URL=https://...
# export FSSPEC_S3_KEY='miniokey...'
# export FSSPEC_S3_SECRET='asecretkey...'
"""
import fsspec
from fsspec.implementations.dirfs import DirFileSystem
def get_fs(fs_url: str) -> DirFileSystem:
"""
Initialize fsspec filesystem for the given URL.
Args:
fs_url: The URL of the filesystem to initialize. e.g. s3://bucket, gs://bucket, file:///local/path
Returns:
The initialized DirFileSystem wrapper for the filesystem.
fs.fs.protocol to get the protocol of the filesystem
fs.path to get the bucket or root path
path to the object in the filesystem - f"{fs.fs.protocol}://{fs.path}/{path}"
"""
# Extract protocol from URL (s3://, gs://, az://, file://)
fs_url_parts = fs_url.split("://")
protocol = fs_url_parts[0] if "://" in fs_url else "file"
# ... or bucket name
root_path = fs_url_parts[1] if len(fs_url_parts) > 1 else "/"
fs_opts = {}
if protocol in "file":
# create directory for local filesystem
fs_opts = {"auto_mkdir": True}
return DirFileSystem(fs=fsspec.filesystem(protocol, **fs_opts), path=root_path)
...@@ -131,6 +131,26 @@ DYNAMO_ARGS: Dict[str, Dict[str, Any]] = { ...@@ -131,6 +131,26 @@ DYNAMO_ARGS: Dict[str, Dict[str, Any]] = {
"default": os.environ.get("DYN_LOCAL_INDEXER", "false"), "default": os.environ.get("DYN_LOCAL_INDEXER", "false"),
"help": "Enable worker-local KV indexer for tracking this worker's own KV cache state (can also be toggled with env var DYN_LOCAL_INDEXER).", "help": "Enable worker-local KV indexer for tracking this worker's own KV cache state (can also be toggled with env var DYN_LOCAL_INDEXER).",
}, },
"image-diffusion-worker": {
"flags": ["--image-diffusion-worker"],
"action": "store_true",
"default": False,
"help": "Run as image diffusion worker for image generation",
},
"image-diffusion-fs-url": {
"flags": ["--image-diffusion-fs-url"],
"type": str,
"default": None,
"help": "Filesystem URL for storing generated images using fsspec (e.g., s3://bucket/path, gs://bucket/path, file:///local/path). Supports any fsspec-compatible filesystem.",
},
"image-diffusion-base-url": {
"flags": ["--image-diffusion-base-url"],
"type": str,
"default": os.environ.get(
"DYN_IMAGE_DIFFUSION_BASE_URL", "http://localhost:8008/"
),
"help": "Base URL for rewriting image URLs in responses (e.g., http://localhost:8008/). When set, generated image URLs will use this base instead of filesystem URLs. Can be set via DYN_IMAGE_DIFFUSION_URL_BASE env var.",
},
} }
...@@ -173,6 +193,11 @@ class DynamoArgs: ...@@ -173,6 +193,11 @@ class DynamoArgs:
# Whether to enable NATS for KV events (derived from server_args.kv_events_config) # Whether to enable NATS for KV events (derived from server_args.kv_events_config)
use_kv_events: bool = False use_kv_events: bool = False
# image diffusion options
image_diffusion_worker: bool = False
image_diffusion_fs_url: Optional[str] = None
image_diffusion_base_url: Optional[str] = None
class DisaggregationMode(Enum): class DisaggregationMode(Enum):
AGGREGATED = "agg" AGGREGATED = "agg"
...@@ -443,6 +468,8 @@ async def parse_args(args: list[str]) -> Config: ...@@ -443,6 +468,8 @@ async def parse_args(args: list[str]) -> Config:
if endpoint is None: if endpoint is None:
if parsed_args.embedding_worker: if parsed_args.embedding_worker:
endpoint = f"dyn://{namespace}.backend.generate" endpoint = f"dyn://{namespace}.backend.generate"
elif getattr(parsed_args, "image_diffusion_worker", False):
endpoint = f"dyn://{namespace}.backend.generate"
elif ( elif (
hasattr(parsed_args, "disaggregation_mode") hasattr(parsed_args, "disaggregation_mode")
and parsed_args.disaggregation_mode == "prefill" and parsed_args.disaggregation_mode == "prefill"
...@@ -521,7 +548,36 @@ async def parse_args(args: list[str]) -> Config: ...@@ -521,7 +548,36 @@ async def parse_args(args: list[str]) -> Config:
# TODO: sglang downloads the model in `from_cli_args`, which means we had to # TODO: sglang downloads the model in `from_cli_args`, which means we had to
# fetch_llm (download the model) here, in `parse_args`. `parse_args` should not # fetch_llm (download the model) here, in `parse_args`. `parse_args` should not
# contain code to download a model, it should only parse the args. # contain code to download a model, it should only parse the args.
server_args = ServerArgs.from_cli_args(parsed_args)
# For diffusion workers, create a minimal dummy ServerArgs since diffusion
# doesn't use transformer models or sglang Engine - it uses DiffGenerator directly
image_diffusion_worker = getattr(parsed_args, "image_diffusion_worker", False)
if image_diffusion_worker:
logging.info(f"Image diffusion worker detected with model: {model_path}")
# Need to use ServerArgs not intended for sglang[diffusion], multimodal_gen has its own ServerArgs.
server_args = ServerArgs("none") # HACK: Avoid triggering __post_init__
server_args.model_path = model_path
server_args.served_model_name = parsed_args.served_model_name
server_args.enable_metrics = getattr(parsed_args, "enable_metrics", False)
server_args.log_level = getattr(parsed_args, "log_level", "info")
server_args.kv_events_config = getattr(parsed_args, "kv_events_config", None)
server_args.speculative_algorithm = None
server_args.disaggregation_mode = None
server_args.dllm_algorithm = False
server_args.tp_size = getattr(parsed_args, "tensor_parallel_size", 1)
server_args.dp_size = getattr(parsed_args, "data_parallel_size", 1)
parsed_args.use_sglang_tokenizer = True
parsed_args.dyn_endpoint_types = "images"
logging.info(
f"Created stub ServerArgs for diffusion: model_path={server_args.model_path}"
)
else:
server_args = ServerArgs.from_cli_args(parsed_args)
# Dynamo's streaming handlers expect disjoint output_ids from SGLang (only new # Dynamo's streaming handlers expect disjoint output_ids from SGLang (only new
# tokens since last output), not cumulative tokens. When stream_output=True, # tokens since last output), not cumulative tokens. When stream_output=True,
...@@ -576,6 +632,9 @@ async def parse_args(args: list[str]) -> Config: ...@@ -576,6 +632,9 @@ async def parse_args(args: list[str]) -> Config:
multimodal_worker=parsed_args.multimodal_worker, multimodal_worker=parsed_args.multimodal_worker,
embedding_worker=parsed_args.embedding_worker, embedding_worker=parsed_args.embedding_worker,
diffusion_worker=diffusion_worker, diffusion_worker=diffusion_worker,
image_diffusion_worker=getattr(parsed_args, "image_diffusion_worker", False),
image_diffusion_fs_url=getattr(parsed_args, "image_diffusion_fs_url", None),
image_diffusion_base_url=getattr(parsed_args, "image_diffusion_base_url", None),
dump_config_to=parsed_args.dump_config_to, dump_config_to=parsed_args.dump_config_to,
enable_local_indexer=str(parsed_args.enable_local_indexer).lower() == "true", enable_local_indexer=str(parsed_args.enable_local_indexer).lower() == "true",
use_kv_events=use_kv_events, use_kv_events=use_kv_events,
......
...@@ -118,3 +118,29 @@ class SglangPrefillHealthCheckPayload(HealthCheckPayload): ...@@ -118,3 +118,29 @@ class SglangPrefillHealthCheckPayload(HealthCheckPayload):
self.default_payload["request"]["token_ids"] = [bos_token_id] # type: ignore self.default_payload["request"]["token_ids"] = [bos_token_id] # type: ignore
super().__init__() super().__init__()
class ImageDiffusionHealthCheckPayload(HealthCheckPayload):
"""Image diffusion-specific health check payload for image generation workers.
Sends a minimal image generation request to verify the diffusion worker
is responding and the model is loaded. Uses minimal resources for fast checks.
"""
def __init__(self, model_path: str):
"""Initialize diffusion health check payload with minimal generation request.
Args:
model_path: The diffusion model being served.
"""
self.default_payload = {
"prompt": "test", # Minimal prompt
"model": model_path,
"n": 1, # Generate 1 image
"size": "512x512", # Small size for fast health check
"num_inference_steps": 1, # Just 1 step (fast but low quality)
"guidance_scale": 7.5, # Standard guidance scale
"response_format": "b64_json", # Don't require S3 for health check
}
super().__init__()
...@@ -11,12 +11,14 @@ import sglang as sgl ...@@ -11,12 +11,14 @@ import sglang as sgl
import uvloop import uvloop
from dynamo.common.config_dump import dump_config from dynamo.common.config_dump import dump_config
from dynamo.common.storage import get_fs
from dynamo.common.utils.endpoint_types import parse_endpoint_types from dynamo.common.utils.endpoint_types import parse_endpoint_types
from dynamo.llm import ModelInput, ModelType from dynamo.llm import ModelInput, ModelType
from dynamo.runtime import DistributedRuntime from dynamo.runtime import DistributedRuntime
from dynamo.runtime.logging import configure_dynamo_logging from dynamo.runtime.logging import configure_dynamo_logging
from dynamo.sglang.args import Config, DisaggregationMode, parse_args from dynamo.sglang.args import Config, DisaggregationMode, parse_args
from dynamo.sglang.health_check import ( from dynamo.sglang.health_check import (
ImageDiffusionHealthCheckPayload,
SglangHealthCheckPayload, SglangHealthCheckPayload,
SglangPrefillHealthCheckPayload, SglangPrefillHealthCheckPayload,
) )
...@@ -25,11 +27,15 @@ from dynamo.sglang.publisher import ( ...@@ -25,11 +27,15 @@ from dynamo.sglang.publisher import (
setup_prometheus_registry, setup_prometheus_registry,
setup_sgl_metrics, setup_sgl_metrics,
) )
from dynamo.sglang.register import register_llm_with_readiness_gate from dynamo.sglang.register import (
register_image_diffusion_model,
register_llm_with_readiness_gate,
)
from dynamo.sglang.request_handlers import ( from dynamo.sglang.request_handlers import (
DecodeWorkerHandler, DecodeWorkerHandler,
DiffusionWorkerHandler, DiffusionWorkerHandler,
EmbeddingWorkerHandler, EmbeddingWorkerHandler,
ImageDiffusionWorkerHandler,
MultimodalEncodeWorkerHandler, MultimodalEncodeWorkerHandler,
MultimodalPrefillWorkerHandler, MultimodalPrefillWorkerHandler,
MultimodalProcessorHandler, MultimodalProcessorHandler,
...@@ -113,7 +119,9 @@ async def worker(): ...@@ -113,7 +119,9 @@ async def worker():
logging.info("Signal handlers will trigger a graceful shutdown of the runtime") logging.info("Signal handlers will trigger a graceful shutdown of the runtime")
if config.dynamo_args.embedding_worker: if config.dynamo_args.image_diffusion_worker:
await init_image_diffusion(runtime, config)
elif config.dynamo_args.embedding_worker:
await init_embedding(runtime, config) await init_embedding(runtime, config)
elif config.dynamo_args.multimodal_processor: elif config.dynamo_args.multimodal_processor:
await init_multimodal_processor(runtime, config) await init_multimodal_processor(runtime, config)
...@@ -128,6 +136,7 @@ async def worker(): ...@@ -128,6 +136,7 @@ async def worker():
await init_diffusion(runtime, config) await init_diffusion(runtime, config)
elif config.serving_mode != DisaggregationMode.PREFILL: elif config.serving_mode != DisaggregationMode.PREFILL:
await init(runtime, config) await init(runtime, config)
else: else:
await init_prefill(runtime, config) await init_prefill(runtime, config)
...@@ -446,6 +455,87 @@ async def init_embedding(runtime: DistributedRuntime, config: Config): ...@@ -446,6 +455,87 @@ async def init_embedding(runtime: DistributedRuntime, config: Config):
handler.cleanup() handler.cleanup()
async def init_image_diffusion(runtime: DistributedRuntime, config: Config):
"""Initialize image diffusion worker component"""
server_args, dynamo_args = config.server_args, config.dynamo_args
# Initialize DiffGenerator (not sgl.Engine)
from sglang.multimodal_gen import DiffGenerator
if not server_args.model_path:
raise ValueError("--model is required for diffusion workers")
# Parallelism configuration
tp_size = getattr(server_args, "tp_size", 1)
dp_size = getattr(server_args, "dp_size", 1)
num_gpus = tp_size * dp_size
# Distributed configuration
dist_timeout = getattr(server_args, "dist_timeout", None)
generator = DiffGenerator.from_pretrained(
model_path=server_args.model_path,
# Parallelism configuration
num_gpus=num_gpus,
tp_size=tp_size,
dp_size=dp_size,
# Distributed configuration
dist_timeout=dist_timeout,
)
# Initialize fsspec filesystems for image storage
fs_url = dynamo_args.image_diffusion_fs_url
# Initialize primary filesystem
if not fs_url:
raise ValueError("--image-diffusion-fs-url is required for diffusion workers")
component = runtime.namespace(dynamo_args.namespace).component(
dynamo_args.component
)
generate_endpoint = component.endpoint(dynamo_args.endpoint)
# Image diffusion doesn't have metrics publisher like LLM
# Could add custom metrics for images/sec, steps/sec later
handler = ImageDiffusionWorkerHandler(
component,
generator,
config,
publisher=None,
fs=get_fs(fs_url),
)
# Create proper health check payload that sends a minimal diffusion request
health_check_payload = ImageDiffusionHealthCheckPayload(
model_path=server_args.model_path
).to_dict()
ready_event = asyncio.Event()
try:
await asyncio.gather(
generate_endpoint.serve_endpoint(
handler.generate,
graceful_shutdown=True,
metrics_labels=[], # No LLM metrics labels
health_check_payload=health_check_payload,
),
register_image_diffusion_model(
generator,
generate_endpoint,
server_args,
readiness_gate=ready_event,
),
)
except Exception as e:
logging.error(f"Failed to serve image diffusion endpoints: {e}")
raise
finally:
handler.cleanup()
async def init_multimodal_processor(runtime: DistributedRuntime, config: Config): async def init_multimodal_processor(runtime: DistributedRuntime, config: Config):
"""Initialize multimodal processor component""" """Initialize multimodal processor component"""
server_args, dynamo_args = config.server_args, config.dynamo_args server_args, dynamo_args = config.server_args, config.dynamo_args
......
...@@ -131,3 +131,46 @@ class DisaggSglangMultimodalRequest(BaseModel): ...@@ -131,3 +131,46 @@ class DisaggSglangMultimodalRequest(BaseModel):
request: SglangMultimodalRequest request: SglangMultimodalRequest
sampling_params: dict sampling_params: dict
data_parallel_rank: Optional[int] = None data_parallel_rank: Optional[int] = None
# ============================================================================
# Image diffusion Protocol Types
# ============================================================================
class NvExt(BaseModel):
"""NVIDIA extensions for image generation"""
negative_prompt: Optional[str] = None
num_inference_steps: Optional[int] = 50
guidance_scale: float = 7.5
seed: Optional[int] = None
annotations: Optional[list[str]] = None
class CreateImageRequest(BaseModel):
"""OpenAI /v1/images/generations compatible request"""
prompt: str
model: str # e.g. "stabilityai/stable-diffusion-3.5-medium"
n: int = 1 # Number of images
size: Optional[str] = "1024x1024" # "WxH" format
quality: Optional[str] = "standard" # standard, hd
response_format: Optional[str] = "url" # url or b64_json
user: Optional[str] = None
# NVIDIA extensions nested under nvext
nvext: Optional[NvExt] = None
class ImageData(BaseModel):
url: Optional[str] = None # S3 URL
b64_json: Optional[str] = None # Base64 encoded
revised_prompt: Optional[str] = None
class ImagesResponse(BaseModel):
"""OpenAI-compatible response"""
created: int # Unix timestamp
data: list[ImageData]
...@@ -4,7 +4,7 @@ ...@@ -4,7 +4,7 @@
import asyncio import asyncio
import logging import logging
import socket import socket
from typing import Optional from typing import Any, Optional
import sglang as sgl import sglang as sgl
from sglang.srt.server_args import ServerArgs from sglang.srt.server_args import ServerArgs
...@@ -269,3 +269,43 @@ async def register_llm_with_readiness_gate( ...@@ -269,3 +269,43 @@ async def register_llm_with_readiness_gate(
readiness_gate.set() readiness_gate.set()
logging.info("Model registration succeeded; processing queued requests") logging.info("Model registration succeeded; processing queued requests")
async def register_image_diffusion_model(
generator: Any, # DiffGenerator
endpoint: Endpoint,
server_args: ServerArgs,
readiness_gate: Optional[asyncio.Event] = None,
) -> None:
"""Register diffusion model with Dynamo runtime.
Args:
generator: The SGLang DiffGenerator instance.
endpoint: The Dynamo endpoint for generation requests.
server_args: SGLang server configuration.
readiness_gate: Optional event to signal when registration completes.
Note:
Image diffusion models use ModelInput.Text (text prompts) and ModelType.Images.
"""
# Use model_path as the model name (diffusion workers don't have served_model_name)
model_name = server_args.model_path
try:
await register_llm(
ModelInput.Text,
ModelType.Images,
endpoint,
model_name,
model_name,
)
logging.info(f"Successfully registered diffusion model: {model_name}")
except Exception as e:
logging.error(f"Failed to register diffusion model: {e}")
raise RuntimeError("Image diffusion model registration failed")
# Signal readiness
if readiness_gate:
readiness_gate.set()
logging.info(f"Image diffusion model ready: {model_name}")
...@@ -5,7 +5,10 @@ ...@@ -5,7 +5,10 @@
from .embedding import EmbeddingWorkerHandler from .embedding import EmbeddingWorkerHandler
# Base handlers # Base handlers
from .handler_base import BaseWorkerHandler from .handler_base import BaseGenerativeHandler, BaseWorkerHandler
# Image diffusion handlers
from .image_diffusion import ImageDiffusionWorkerHandler
# LLM handlers # LLM handlers
from .llm import DecodeWorkerHandler, DiffusionWorkerHandler, PrefillWorkerHandler from .llm import DecodeWorkerHandler, DiffusionWorkerHandler, PrefillWorkerHandler
...@@ -19,6 +22,8 @@ from .multimodal import ( ...@@ -19,6 +22,8 @@ from .multimodal import (
) )
__all__ = [ __all__ = [
# Base handlers
"BaseGenerativeHandler",
"BaseWorkerHandler", "BaseWorkerHandler",
# LLM handlers # LLM handlers
"DecodeWorkerHandler", "DecodeWorkerHandler",
...@@ -26,6 +31,8 @@ __all__ = [ ...@@ -26,6 +31,8 @@ __all__ = [
"PrefillWorkerHandler", "PrefillWorkerHandler",
# Embedding handlers # Embedding handlers
"EmbeddingWorkerHandler", "EmbeddingWorkerHandler",
# Image diffusion handlers
"ImageDiffusionWorkerHandler",
# Multimodal handlers # Multimodal handlers
"MultimodalEncodeWorkerHandler", "MultimodalEncodeWorkerHandler",
"MultimodalPrefillWorkerHandler", "MultimodalPrefillWorkerHandler",
......
...@@ -18,8 +18,82 @@ from dynamo.sglang.args import Config ...@@ -18,8 +18,82 @@ from dynamo.sglang.args import Config
from dynamo.sglang.publisher import DynamoSglangPublisher from dynamo.sglang.publisher import DynamoSglangPublisher
class BaseWorkerHandler(ABC): class BaseGenerativeHandler(ABC):
"""Abstract base class for SGLang worker handlers.""" """Minimal base class for all generative handlers (LLM, diffusion, etc.).
Provides common infrastructure for:
- Component and configuration management
- Metrics and KV event publishing
- Distributed tracing integration
"""
def __init__(
self,
component: Component,
config: Config,
publisher: Optional[DynamoSglangPublisher] = None,
) -> None:
"""Initialize base generative handler.
Args:
component: The Dynamo runtime component.
config: SGLang and Dynamo configuration.
publisher: Optional metrics publisher for the worker.
"""
self.component = component
self.config = config
# Set up metrics and KV publishers
if publisher is not None:
self.metrics_publisher = publisher.metrics_publisher
self.kv_publisher = publisher.kv_publisher
else:
self.metrics_publisher = None
self.kv_publisher = None
@abstractmethod
async def generate(
self, request: Dict[str, Any], context: Context
) -> AsyncGenerator[Dict[str, Any], None]:
"""Generate response from request.
Args:
request: Request dict with input and parameters.
context: Context object for cancellation handling.
Yields:
Response data (format varies by handler implementation).
"""
pass
def cleanup(self) -> None:
"""Cleanup resources. Override in subclasses as needed."""
pass
def _get_trace_header(self, context: Context) -> Optional[Dict[str, str]]:
"""Get trace header dict for passing to generation functions.
Args:
context: Dynamo Context object containing trace information.
Returns:
Dict with traceparent header if trace context available, None otherwise.
"""
trace_id = context.trace_id
span_id = context.span_id
if not trace_id or not span_id:
return None
return {"traceparent": f"00-{trace_id}-{span_id}-01"}
class BaseWorkerHandler(BaseGenerativeHandler):
"""Abstract base class for SGLang LLM worker handlers.
Extends BaseGenerativeHandler with LLM-specific functionality:
- SGLang Engine integration
- Tokenization and input parameter management
- Disaggregated serving support
"""
def __init__( def __init__(
self, self,
...@@ -38,7 +112,10 @@ class BaseWorkerHandler(ABC): ...@@ -38,7 +112,10 @@ class BaseWorkerHandler(ABC):
publisher: Optional metrics publisher for the worker. publisher: Optional metrics publisher for the worker.
generate_endpoint: The endpoint handle for discovery registration. generate_endpoint: The endpoint handle for discovery registration.
""" """
self.component = component # Call parent constructor
super().__init__(component, config, publisher)
# LLM-specific initialization
self.engine = engine self.engine = engine
self.config = config self.config = config
self.generate_endpoint = generate_endpoint self.generate_endpoint = generate_endpoint
...@@ -274,21 +351,6 @@ class BaseWorkerHandler(ABC): ...@@ -274,21 +351,6 @@ class BaseWorkerHandler(ABC):
return bootstrap_host, bootstrap_port return bootstrap_host, bootstrap_port
def _get_trace_header(self, context: Context) -> Optional[Dict[str, str]]:
"""Get trace header dict for passing to SGLang's external_trace_header parameter.
Args:
context: Dynamo Context object containing trace information.
Returns:
Dict with traceparent header if trace context available, None otherwise.
"""
trace_id = context.trace_id
span_id = context.span_id
if not trace_id or not span_id:
return None
return {"traceparent": f"00-{trace_id}-{span_id}-01"}
async def _handle_cancellation( async def _handle_cancellation(
self, request_id_future: asyncio.Future, context: Context self, request_id_future: asyncio.Future, context: Context
): ):
......
# SPDX-FileCopyrightText: Copyright (c) 2025-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
from .image_diffusion_handler import ImageDiffusionWorkerHandler
__all__ = ["ImageDiffusionWorkerHandler"]
# SPDX-FileCopyrightText: Copyright (c) 2025-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
import asyncio
import base64
import io
import logging
import random
import time
import uuid
from typing import Any, AsyncGenerator, Optional
import torch
from PIL import Image
from dynamo._core import Component, Context
from dynamo.sglang.args import Config
from dynamo.sglang.protocol import CreateImageRequest, ImageData, ImagesResponse, NvExt
from dynamo.sglang.publisher import DynamoSglangPublisher
from dynamo.sglang.request_handlers.handler_base import BaseGenerativeHandler
logger = logging.getLogger(__name__)
MAX_NUM_INFERENCE_STEPS = 50
class ImageDiffusionWorkerHandler(BaseGenerativeHandler):
"""Handler for diffusion image generation.
Inherits from BaseGenerativeHandler for common infrastructure like
tracing, metrics publishing
"""
def __init__(
self,
component: Component,
generator: Any, # DiffGenerator, not sgl.Engine
config: Config,
publisher: Optional[DynamoSglangPublisher] = None,
fs: Any = None, # fsspec.AbstractFileSystem for primary storage
):
"""Initialize diffusion worker handler.
Args:
component: The Dynamo runtime component.
generator: The SGLang DiffGenerator instance.
config: SGLang and Dynamo configuration.
publisher: Optional metrics publisher (not used for diffusion currently).
fs: Optional fsspec filesystem for primary image storage.
"""
super().__init__(component, config, publisher)
self.generator = generator # DiffGenerator, not Engine
self.fs = fs
self.fs_url = config.dynamo_args.image_diffusion_fs_url
self.base_url = config.dynamo_args.image_diffusion_base_url
logger.info(
f"Image diffusion worker handler initialized with fs_url={self.fs_url}, url_base={self.base_url}"
)
def cleanup(self) -> None:
"""Cleanup generator resources"""
if self.generator is not None:
del self.generator
torch.cuda.empty_cache()
logger.info("Image diffusion generator cleanup complete")
super().cleanup()
async def generate(
self, request: dict[str, Any], context: Context
) -> AsyncGenerator[dict[str, Any], None]:
"""
Generate image(s) from text prompt.
Unlike LLM streaming, diffusion returns complete image(s) at end.
Args:
request: Request dict with prompt and generation parameters.
context: Context object for cancellation handling.
Yields:
Response dict with generated images (OpenAI-compatible format).
"""
logger.info(f"Image diffusion request: {request}")
# Get trace header for distributed tracing (for logging/observability)
trace_header = self._get_trace_header(context)
if trace_header:
logger.debug(f"Image diffusion request with trace: {trace_header}")
try:
req = CreateImageRequest(**request)
# get extra parameters
nvext = req.nvext or NvExt()
nvext.num_inference_steps = min(
nvext.num_inference_steps or 50, MAX_NUM_INFERENCE_STEPS
)
width, height = self._parse_size(req.size)
images = await self._generate_images(
prompt=req.prompt,
negative_prompt=nvext.negative_prompt,
width=width,
height=height,
num_inference_steps=nvext.num_inference_steps,
guidance_scale=nvext.guidance_scale,
seed=nvext.seed,
)
user_id = req.user if req.user else context.id()
image_data = []
for img in images:
# uploading or encoding the image
if req.response_format == "url":
url = await self._upload_to_fs(img, user_id, context.id())
image_data.append(ImageData(url=url))
else:
b64 = self._encode_base64(img)
image_data.append(ImageData(b64_json=b64))
response = ImagesResponse(created=int(time.time()), data=image_data)
yield response.model_dump()
except Exception as e:
logger.error(f"Error in diffusion generation: {e}", exc_info=True)
error_response = {
"created": int(time.time()),
"data": [],
"error": str(e),
}
yield error_response
async def _generate_images(
self,
prompt: str,
width: int,
height: int,
num_inference_steps: int,
guidance_scale: float,
seed: Optional[int],
negative_prompt: Optional[str] = None,
) -> list[bytes]:
"""Generate images using SGLang DiffGenerator"""
args = {
"prompt": prompt,
"negative_prompt": negative_prompt,
"height": height,
"width": width,
"num_inference_steps": num_inference_steps,
"save_output": False, # We handle saving ourselves
"guidance_scale": guidance_scale,
"seed": seed if seed else random.randint(0, 1000000),
}
result = await asyncio.to_thread(
self.generator.generate,
sampling_params_kwargs=args,
)
if result is None:
raise RuntimeError("No result from generator")
images = result["frames"] if "frames" in result else []
# Convert images to bytes (handle PIL Images, numpy arrays, or bytes)
image_bytes_list = []
for img in images:
if isinstance(img, bytes):
image_bytes_list.append(img)
elif Image is not None and isinstance(img, Image.Image):
# Convert PIL Image to bytes
buf = io.BytesIO()
img.save(buf, format="PNG")
image_bytes_list.append(buf.getvalue())
else:
try:
import numpy as np
if isinstance(img, np.ndarray):
# Convert numpy array to PIL Image then to bytes
pil_img = Image.fromarray(img)
buf = io.BytesIO()
pil_img.save(buf, format="PNG")
image_bytes_list.append(buf.getvalue())
else:
raise ValueError(f"Unsupported image type: {type(img)}")
except ImportError:
raise RuntimeError(
"Cannot convert image format. Install Pillow: pip install Pillow"
)
return image_bytes_list
def _parse_size(self, size_str: Optional[str]) -> tuple[int, int]:
"""Parse '1024x1024' -> (1024, 1024)"""
if size_str is None:
return 1024, 1024
w, h = size_str.split("x")
return int(w), int(h)
async def _upload_to_fs(
self, image_bytes: bytes, user_id: str, request_id: str
) -> str:
"""Upload image to filesystem and return URL.
Uses per-user storage path:
users/{user_id}/generations/{request_id}/{image_uuid}.png
Args:
image_bytes: Image data as bytes.
user_id: User identifier from request or context.
request_id: Request context ID.
Returns:
Public URL for the uploaded image.
"""
image_uuid = str(uuid.uuid4())
image_filename = f"{image_uuid}.png"
# Per-user storage path
storage_path = f"users/{user_id}/generations/{request_id}/{image_filename}"
# send image to filesystem
await asyncio.to_thread(self.fs.pipe, storage_path, image_bytes)
return f"{self.base_url}/{storage_path}"
def _encode_base64(self, image_bytes: bytes) -> str:
"""Encode image as base64 string"""
return base64.b64encode(image_bytes).decode("utf-8")
# SPDX-FileCopyrightText: Copyright (c) 2025-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
"""Unit tests for ImageDiffusionWorkerHandler."""
import base64
import io
from unittest.mock import MagicMock, Mock, patch
import pytest
from PIL import Image
from dynamo.sglang.request_handlers.image_diffusion.image_diffusion_handler import (
ImageDiffusionWorkerHandler,
)
pytestmark = [
pytest.mark.unit,
pytest.mark.sglang,
pytest.mark.gpu_0, # No GPU needed for unit tests
pytest.mark.pre_merge,
pytest.mark.parallel,
]
@pytest.fixture
def mock_component():
"""Mock Dynamo Component."""
return MagicMock()
@pytest.fixture
def mock_generator():
"""Mock SGLang DiffGenerator."""
generator = MagicMock()
generator.generate = MagicMock()
return generator
@pytest.fixture
def mock_config():
"""Mock Config object."""
config = MagicMock()
config.dynamo_args = MagicMock()
config.dynamo_args.image_diffusion_fs_url = "file:///tmp/images"
config.dynamo_args.image_diffusion_base_url = "file:///tmp/images"
return config
@pytest.fixture
def mock_fs():
"""Mock fsspec filesystem."""
fs = MagicMock()
fs.pipe = MagicMock()
return fs
@pytest.fixture
def mock_context():
"""Mock Context object."""
context = MagicMock()
context.id = MagicMock(return_value="test-context-id")
context.trace_id = "test-trace-id"
context.span_id = "test-span-id"
context.is_cancelled = MagicMock(return_value=False)
return context
@pytest.fixture
def handler(
mock_component, mock_generator, mock_config, mock_fs
) -> ImageDiffusionWorkerHandler:
"""Create ImageDiffusionWorkerHandler instance."""
return ImageDiffusionWorkerHandler(
component=mock_component,
generator=mock_generator,
config=mock_config,
publisher=None,
fs=mock_fs,
)
class TestImageDiffusionWorkerHandler:
"""Test suite for ImageDiffusionWorkerHandler."""
def test_initialization(self, handler, mock_generator, mock_fs):
"""Test handler initialization."""
assert handler.generator == mock_generator
assert handler.fs == mock_fs
assert handler.fs_url == "file:///tmp/images"
assert handler.base_url == "file:///tmp/images"
def test_initialization_with_url_base(
self, mock_component, mock_generator, mock_fs
):
"""Test handler initialization with URL base."""
config = MagicMock()
config.dynamo_args = MagicMock()
config.dynamo_args.image_diffusion_fs_url = "s3://my-bucket/images"
config.dynamo_args.image_diffusion_base_url = "http://localhost:8008/images"
handler = ImageDiffusionWorkerHandler(
component=mock_component,
generator=mock_generator,
config=config,
publisher=None,
fs=mock_fs,
)
assert handler.base_url == "http://localhost:8008/images"
assert handler.fs_url == "s3://my-bucket/images"
@patch("torch.cuda.empty_cache")
def test_cleanup(self, mock_empty_cache, handler):
"""Test cleanup method."""
_original_generator = handler.generator
handler.cleanup()
# Generator should be set to None after cleanup
# Note: We can't assert it's None because the attribute gets deleted
mock_empty_cache.assert_called_once()
def test_parse_size(self, handler):
"""Test _parse_size method."""
width, height = handler._parse_size("1024x1024")
assert width == 1024
assert height == 1024
width, height = handler._parse_size("512x768")
assert width == 512
assert height == 768
def test_encode_base64(self, handler):
"""Test _encode_base64 method."""
test_bytes = b"test image data"
expected = base64.b64encode(test_bytes).decode("utf-8")
result = handler._encode_base64(test_bytes)
assert result == expected
@pytest.mark.asyncio
async def test_generate_success_url_format(self, handler, mock_context):
"""Test successful image generation with URL response format."""
# Create a simple test image
test_image = Image.new("RGB", (256, 256), color="red")
img_buffer: io.BytesIO = io.BytesIO()
test_image.save(img_buffer, format="PNG")
# Mock generator response
handler.generator.generate = Mock(
return_value={"frames": [test_image.convert("RGB")]}
)
request = {
"prompt": "A red square",
"model": "test-model",
"size": "256x256",
"response_format": "url",
"user": "test-user",
"nvext": {
"num_inference_steps": 10,
"guidance_scale": 7.5,
"seed": 42,
"negative_prompt": None,
},
}
# Execute generation
results = []
async for result in handler.generate(request, mock_context):
results.append(result)
# Verify results
assert len(results) == 1
response = results[0]
assert "created" in response
assert "data" in response
assert len(response["data"]) == 1
assert "url" in response["data"][0]
assert response["data"][0]["url"].startswith("file:///tmp/images/users/")
@pytest.mark.asyncio
async def test_generate_success_b64_format(self, handler, mock_context):
"""Test successful image generation with base64 response format."""
# Create a simple test image
test_image = Image.new("RGB", (256, 256), color="blue")
# Mock generator response
handler.generator.generate = Mock(
return_value={"frames": [test_image.convert("RGB")]}
)
request = {
"prompt": "A blue square",
"model": "test-model",
"size": "256x256",
"response_format": "b64_json",
"user": "test-user",
"nvext": {
"num_inference_steps": 10,
"guidance_scale": 7.5,
"seed": 42,
"negative_prompt": None,
},
}
# Execute generation
results = []
async for result in handler.generate(request, mock_context):
results.append(result)
# Verify results
assert len(results) == 1
response = results[0]
assert "created" in response
assert "data" in response
assert len(response["data"]) == 1
assert "b64_json" in response["data"][0]
# Verify it's valid base64
b64_data = response["data"][0]["b64_json"]
decoded = base64.b64decode(b64_data)
assert len(decoded) > 0
@pytest.mark.asyncio
async def test_generate_with_default_num_inference_steps(
self, handler, mock_context
):
"""Test that num_inference_steps defaults to 50."""
test_image = Image.new("RGB", (256, 256), color="green")
handler.generator.generate = Mock(return_value={"frames": [test_image]})
request = {
"prompt": "A green square",
"model": "test-model",
"size": "256x256",
"response_format": "b64_json",
"user": "test-user",
}
# Execute generation
results = []
async for result in handler.generate(request, mock_context):
results.append(result)
@pytest.mark.asyncio
async def test_generate_error_handling(self, handler, mock_context):
"""Test error handling in generate method."""
# Mock generator to raise an exception
handler.generator.generate = Mock(side_effect=RuntimeError("Generation failed"))
request = {
"prompt": "Test prompt",
"model": "test-model",
"size": "256x256",
"response_format": "url",
"user": "test-user",
"nvext": {
"num_inference_steps": 10,
"guidance_scale": 7.5,
"seed": 42,
"negative_prompt": None,
},
}
# Execute generation
results = []
async for result in handler.generate(request, mock_context):
results.append(result)
# Verify error response
assert len(results) == 1
response = results[0]
assert "error" in response
assert "Generation failed" in response["error"]
assert response["data"] == []
@pytest.mark.asyncio
async def test_upload_to_fs(self, handler):
"""Test _upload_to_fs method."""
image_bytes = b"test image data"
user_id = "user123"
request_id = "req456"
url = await handler._upload_to_fs(image_bytes, user_id, request_id)
# Verify storage path format
assert f"users/{user_id}/generations/{request_id}/" in url
assert url.endswith(".png")
@pytest.mark.asyncio
async def test_generate_images_with_numpy_array(self, handler):
"""Test _generate_images handles numpy arrays."""
import numpy as np
# Create a numpy array representing an image
np_image = np.random.randint(0, 255, (256, 256, 3), dtype=np.uint8)
handler.generator.generate = Mock(return_value={"frames": [np_image]})
images = await handler._generate_images(
prompt="test",
width=256,
height=256,
num_inference_steps=10,
guidance_scale=7.5,
seed=42,
)
assert len(images) == 1
assert isinstance(images[0], bytes)
@pytest.mark.asyncio
async def test_generate_images_with_pil_image(self, handler):
"""Test _generate_images handles PIL Images."""
pil_image = Image.new("RGB", (256, 256), color="red")
handler.generator.generate = Mock(return_value={"frames": [pil_image]})
images = await handler._generate_images(
prompt="test",
width=256,
height=256,
num_inference_steps=10,
guidance_scale=7.5,
seed=42,
)
assert len(images) == 1
assert isinstance(images[0], bytes)
@pytest.mark.asyncio
async def test_generate_images_with_bytes(self, handler):
"""Test _generate_images handles bytes directly."""
img_bytes = b"raw image bytes"
handler.generator.generate = Mock(return_value={"frames": [img_bytes]})
images = await handler._generate_images(
prompt="test",
width=256,
height=256,
num_inference_steps=10,
guidance_scale=7.5,
seed=42,
)
assert len(images) == 1
assert images[0] == img_bytes
@pytest.mark.asyncio
async def test_generate_with_nvext(self, handler, mock_context):
"""Test that nvext parameters are passed to the generator."""
test_image = Image.new("RGB", (256, 256), color="yellow")
handler._generate_images = Mock(return_value=[test_image.tobytes()])
handler._get_trace_header = Mock(
return_value={"traceparent": "00-1234567890-1234567890-01"}
)
request = {
"prompt": "A yellow square",
"model": "test-model",
"size": "256x256",
"response_format": "b64_json",
"user": "test-user",
"nvext": {
"num_inference_steps": 10,
"guidance_scale": 7.5,
"seed": 42,
"negative_prompt": "negative",
},
}
# Execute generation
results = []
async for result in handler.generate(request, mock_context):
results.append(result)
# Verify results
handler._generate_images.assert_called_once_with(
prompt="A yellow square",
width=256,
height=256,
num_inference_steps=10,
guidance_scale=7.5,
seed=42,
negative_prompt="negative",
)
...@@ -96,15 +96,14 @@ pub enum ImageModeration { ...@@ -96,15 +96,14 @@ pub enum ImageModeration {
#[builder(derive(Debug))] #[builder(derive(Debug))]
#[builder(build_fn(error = "OpenAIError"))] #[builder(build_fn(error = "OpenAIError"))]
pub struct CreateImageRequest { pub struct CreateImageRequest {
/// A text description of the desired image(s). The maximum length is 1000 characters for `dall-e-2` /// A text description of the desired image(s).
/// and 4000 characters for `dall-e-3`.
pub prompt: String, pub prompt: String,
/// The model to use for image generation. /// The model to use for image generation.
#[serde(skip_serializing_if = "Option::is_none")] #[serde(skip_serializing_if = "Option::is_none")]
pub model: Option<ImageModel>, pub model: Option<ImageModel>,
/// The number of images to generate. Must be between 1 and 10. For `dall-e-3`, only `n=1` is supported. /// The number of images to generate. Must be between 1 and 10.
#[serde(skip_serializing_if = "Option::is_none")] #[serde(skip_serializing_if = "Option::is_none")]
pub n: Option<u8>, // min:1 max:10 default:1 pub n: Option<u8>, // min:1 max:10 default:1
......
...@@ -269,6 +269,7 @@ fn register_llm<'p>( ...@@ -269,6 +269,7 @@ fn register_llm<'p>(
}; };
let is_tensor_based = model_type.inner.supports_tensor(); let is_tensor_based = model_type.inner.supports_tensor();
let is_images = model_type.inner.supports_images();
let model_type_obj = model_type.inner; let model_type_obj = model_type.inner;
...@@ -317,8 +318,9 @@ fn register_llm<'p>( ...@@ -317,8 +318,9 @@ fn register_llm<'p>(
.or_else(|| Some(source_path.clone())); .or_else(|| Some(source_path.clone()));
pyo3_async_runtimes::tokio::future_into_py(py, async move { pyo3_async_runtimes::tokio::future_into_py(py, async move {
// For TensorBased models, skip HuggingFace downloads and register directly // For TensorBased and Images models, skip HuggingFace downloads and register directly
if is_tensor_based { // These model types don't require tokenizers
if is_tensor_based || is_images {
let model_name = model_name.unwrap_or_else(|| source_path.clone()); let model_name = model_name.unwrap_or_else(|| source_path.clone());
let mut card = llm_rs::model_card::ModelDeploymentCard::with_name_only(&model_name); let mut card = llm_rs::model_card::ModelDeploymentCard::with_name_only(&model_name);
card.model_type = model_type_obj; card.model_type = model_type_obj;
...@@ -519,6 +521,10 @@ impl ModelType { ...@@ -519,6 +521,10 @@ impl ModelType {
const Prefill: Self = ModelType { const Prefill: Self = ModelType {
inner: llm_rs::model_type::ModelType::Prefill, inner: llm_rs::model_type::ModelType::Prefill,
}; };
#[classattr]
const Images: Self = ModelType {
inner: llm_rs::model_type::ModelType::Images,
};
fn __or__(&self, other: &Self) -> Self { fn __or__(&self, other: &Self) -> Self {
ModelType { ModelType {
......
...@@ -983,12 +983,13 @@ class ModelInput: ...@@ -983,12 +983,13 @@ class ModelInput:
... ...
class ModelType: class ModelType:
"""What type of request this model needs: Chat, Completions, Embedding, Tensor or Prefill""" """What type of request this model needs: Chat, Completions, Embedding, Tensor, Images or Prefill"""
Chat: ModelType Chat: ModelType
Completions: ModelType Completions: ModelType
Embedding: ModelType Embedding: ModelType
TensorBased: ModelType TensorBased: ModelType
Prefill: ModelType Prefill: ModelType
Images: ModelType
... ...
class RouterMode: class RouterMode:
......
...@@ -33,7 +33,7 @@ use crate::{ ...@@ -33,7 +33,7 @@ use crate::{
openai::{ openai::{
chat_completions::OpenAIChatCompletionsStreamingEngine, chat_completions::OpenAIChatCompletionsStreamingEngine,
completions::OpenAICompletionsStreamingEngine, completions::OpenAICompletionsStreamingEngine,
embeddings::OpenAIEmbeddingsStreamingEngine, embeddings::OpenAIEmbeddingsStreamingEngine, images::OpenAIImagesStreamingEngine,
}, },
}, },
}; };
...@@ -66,6 +66,7 @@ pub struct ModelManager { ...@@ -66,6 +66,7 @@ pub struct ModelManager {
completion_engines: RwLock<ModelEngines<OpenAICompletionsStreamingEngine>>, completion_engines: RwLock<ModelEngines<OpenAICompletionsStreamingEngine>>,
chat_completion_engines: RwLock<ModelEngines<OpenAIChatCompletionsStreamingEngine>>, chat_completion_engines: RwLock<ModelEngines<OpenAIChatCompletionsStreamingEngine>>,
embeddings_engines: RwLock<ModelEngines<OpenAIEmbeddingsStreamingEngine>>, embeddings_engines: RwLock<ModelEngines<OpenAIEmbeddingsStreamingEngine>>,
images_engines: RwLock<ModelEngines<OpenAIImagesStreamingEngine>>,
tensor_engines: RwLock<ModelEngines<TensorStreamingEngine>>, tensor_engines: RwLock<ModelEngines<TensorStreamingEngine>>,
// Prefill models don't have engines - they're only tracked for discovery/lifecycle // Prefill models don't have engines - they're only tracked for discovery/lifecycle
prefill_engines: RwLock<ModelEngines<()>>, prefill_engines: RwLock<ModelEngines<()>>,
...@@ -91,6 +92,7 @@ impl ModelManager { ...@@ -91,6 +92,7 @@ impl ModelManager {
completion_engines: RwLock::new(ModelEngines::default()), completion_engines: RwLock::new(ModelEngines::default()),
chat_completion_engines: RwLock::new(ModelEngines::default()), chat_completion_engines: RwLock::new(ModelEngines::default()),
embeddings_engines: RwLock::new(ModelEngines::default()), embeddings_engines: RwLock::new(ModelEngines::default()),
images_engines: RwLock::new(ModelEngines::default()),
tensor_engines: RwLock::new(ModelEngines::default()), tensor_engines: RwLock::new(ModelEngines::default()),
prefill_engines: RwLock::new(ModelEngines::default()), prefill_engines: RwLock::new(ModelEngines::default()),
cards: DashMap::new(), cards: DashMap::new(),
...@@ -114,6 +116,7 @@ impl ModelManager { ...@@ -114,6 +116,7 @@ impl ModelManager {
ModelType::Completions => self.completion_engines.read().checksum(model_name), ModelType::Completions => self.completion_engines.read().checksum(model_name),
ModelType::Embedding => self.embeddings_engines.read().checksum(model_name), ModelType::Embedding => self.embeddings_engines.read().checksum(model_name),
ModelType::TensorBased => self.tensor_engines.read().checksum(model_name), ModelType::TensorBased => self.tensor_engines.read().checksum(model_name),
ModelType::Images => self.images_engines.read().checksum(model_name),
ModelType::Prefill => self.prefill_engines.read().checksum(model_name), ModelType::Prefill => self.prefill_engines.read().checksum(model_name),
_ => { _ => {
continue; continue;
...@@ -230,6 +233,16 @@ impl ModelManager { ...@@ -230,6 +233,16 @@ impl ModelManager {
clients.add(model, card_checksum, engine) clients.add(model, card_checksum, engine)
} }
pub fn add_images_model(
&self,
model: &str,
card_checksum: &str,
engine: OpenAIImagesStreamingEngine,
) -> Result<(), ModelManagerError> {
let mut clients = self.images_engines.write();
clients.add(model, card_checksum, engine)
}
pub fn add_prefill_model( pub fn add_prefill_model(
&self, &self,
model: &str, model: &str,
...@@ -259,6 +272,11 @@ impl ModelManager { ...@@ -259,6 +272,11 @@ impl ModelManager {
clients.remove(model) clients.remove(model)
} }
pub fn remove_images_model(&self, model: &str) -> Result<(), ModelManagerError> {
let mut clients = self.images_engines.write();
clients.remove(model)
}
pub fn remove_prefill_model(&self, model: &str) -> Result<(), ModelManagerError> { pub fn remove_prefill_model(&self, model: &str) -> Result<(), ModelManagerError> {
let mut clients = self.prefill_engines.write(); let mut clients = self.prefill_engines.write();
clients.remove(model) clients.remove(model)
...@@ -308,6 +326,17 @@ impl ModelManager { ...@@ -308,6 +326,17 @@ impl ModelManager {
.ok_or(ModelManagerError::ModelNotFound(model.to_string())) .ok_or(ModelManagerError::ModelNotFound(model.to_string()))
} }
pub fn get_images_engine(
&self,
model: &str,
) -> Result<OpenAIImagesStreamingEngine, ModelManagerError> {
self.images_engines
.read()
.get(model)
.cloned()
.ok_or(ModelManagerError::ModelNotFound(model.to_string()))
}
/// Save a ModelDeploymentCard from an instance's key so we can fetch it later when the key is /// Save a ModelDeploymentCard from an instance's key so we can fetch it later when the key is
/// deleted. /// deleted.
pub fn save_model_card(&self, key: &str, card: ModelDeploymentCard) -> anyhow::Result<()> { pub fn save_model_card(&self, key: &str, card: ModelDeploymentCard) -> anyhow::Result<()> {
......
...@@ -38,6 +38,7 @@ use crate::{ ...@@ -38,6 +38,7 @@ use crate::{
}, },
completions::{NvCreateCompletionRequest, NvCreateCompletionResponse}, completions::{NvCreateCompletionRequest, NvCreateCompletionResponse},
embeddings::{NvCreateEmbeddingRequest, NvCreateEmbeddingResponse}, embeddings::{NvCreateEmbeddingRequest, NvCreateEmbeddingResponse},
images::{NvCreateImageRequest, NvImagesResponse},
}, },
tensor::{NvCreateTensorRequest, NvCreateTensorResponse}, tensor::{NvCreateTensorRequest, NvCreateTensorResponse},
}, },
...@@ -619,6 +620,19 @@ impl ModelWatcher { ...@@ -619,6 +620,19 @@ impl ModelWatcher {
let engine = Arc::new(push_router); let engine = Arc::new(push_router);
self.manager self.manager
.add_tensor_model(card.name(), checksum, engine)?; .add_tensor_model(card.name(), checksum, engine)?;
} else if card.model_input == ModelInput::Text && card.model_type.supports_images() {
// Case: Text + Images (diffusion models)
// Takes text prompts as input, generates images
let push_router = PushRouter::<
NvCreateImageRequest,
Annotated<NvImagesResponse>,
>::from_client_with_threshold(
client, self.router_config.router_mode, None, None
)
.await?;
let engine = Arc::new(push_router);
self.manager
.add_images_model(card.name(), checksum, engine)?;
} else if card.model_type.supports_prefill() { } else if card.model_type.supports_prefill() {
// Case 6: Prefill // Case 6: Prefill
// Guardrail: Verify model_input is Tokens // Guardrail: Verify model_input is Tokens
...@@ -656,7 +670,7 @@ impl ModelWatcher { ...@@ -656,7 +670,7 @@ impl ModelWatcher {
// Reject unsupported combinations // Reject unsupported combinations
anyhow::bail!( anyhow::bail!(
"Unsupported model configuration: {} with {} input. Supported combinations: \ "Unsupported model configuration: {} with {} input. Supported combinations: \
Tokens+(Chat|Completions|Prefill), Text+Chat, Text+Completions, Tokens+Embeddings, Tensor+TensorBased", Tokens+(Chat|Completions|Prefill), Text+(Chat|Completions|Images), Tokens+Embeddings, Tensor+TensorBased",
card.model_type, card.model_type,
card.model_input.as_str() card.model_input.as_str()
); );
......
...@@ -12,6 +12,8 @@ pub enum EndpointType { ...@@ -12,6 +12,8 @@ pub enum EndpointType {
Completion, Completion,
/// Embeddings API /// Embeddings API
Embedding, Embedding,
/// Images API (Diffusion/DALL-E)
Images,
/// Responses API /// Responses API
Responses, Responses,
} }
...@@ -22,6 +24,7 @@ impl EndpointType { ...@@ -22,6 +24,7 @@ impl EndpointType {
Self::Chat => "chat", Self::Chat => "chat",
Self::Completion => "completion", Self::Completion => "completion",
Self::Embedding => "embedding", Self::Embedding => "embedding",
Self::Images => "images",
Self::Responses => "responses", Self::Responses => "responses",
} }
} }
...@@ -31,6 +34,7 @@ impl EndpointType { ...@@ -31,6 +34,7 @@ impl EndpointType {
Self::Chat, Self::Chat,
Self::Completion, Self::Completion,
Self::Embedding, Self::Embedding,
Self::Images,
Self::Responses, Self::Responses,
] ]
} }
......
...@@ -219,6 +219,9 @@ pub enum Endpoint { ...@@ -219,6 +219,9 @@ pub enum Endpoint {
/// OAI Embeddings /// OAI Embeddings
Embeddings, Embeddings,
/// OAI Images
Images,
/// OAI Responses /// OAI Responses
Responses, Responses,
...@@ -840,6 +843,7 @@ impl std::fmt::Display for Endpoint { ...@@ -840,6 +843,7 @@ impl std::fmt::Display for Endpoint {
Endpoint::Completions => write!(f, "completions"), Endpoint::Completions => write!(f, "completions"),
Endpoint::ChatCompletions => write!(f, "chat_completions"), Endpoint::ChatCompletions => write!(f, "chat_completions"),
Endpoint::Embeddings => write!(f, "embeddings"), Endpoint::Embeddings => write!(f, "embeddings"),
Endpoint::Images => write!(f, "images"),
Endpoint::Responses => write!(f, "responses"), Endpoint::Responses => write!(f, "responses"),
Endpoint::Tensor => write!(f, "tensor"), Endpoint::Tensor => write!(f, "tensor"),
} }
...@@ -852,6 +856,7 @@ impl Endpoint { ...@@ -852,6 +856,7 @@ impl Endpoint {
Endpoint::Completions => "completions", Endpoint::Completions => "completions",
Endpoint::ChatCompletions => "chat_completions", Endpoint::ChatCompletions => "chat_completions",
Endpoint::Embeddings => "embeddings", Endpoint::Embeddings => "embeddings",
Endpoint::Images => "images",
Endpoint::Responses => "responses", Endpoint::Responses => "responses",
Endpoint::Tensor => "tensor", Endpoint::Tensor => "tensor",
} }
......
...@@ -49,6 +49,7 @@ use crate::protocols::openai::{ ...@@ -49,6 +49,7 @@ use crate::protocols::openai::{
}, },
completions::{NvCreateCompletionRequest, NvCreateCompletionResponse}, completions::{NvCreateCompletionRequest, NvCreateCompletionResponse},
embeddings::{NvCreateEmbeddingRequest, NvCreateEmbeddingResponse}, embeddings::{NvCreateEmbeddingRequest, NvCreateEmbeddingResponse},
images::{NvCreateImageRequest, NvImagesResponse},
responses::{NvCreateResponse, NvResponse}, responses::{NvCreateResponse, NvResponse},
}; };
use crate::request_template::RequestTemplate; use crate::request_template::RequestTemplate;
...@@ -1548,6 +1549,99 @@ pub fn responses_router( ...@@ -1548,6 +1549,99 @@ pub fn responses_router(
(vec![doc], router) (vec![doc], router)
} }
async fn images(
State(state): State<Arc<service_v2::State>>,
headers: HeaderMap,
Json(request): Json<NvCreateImageRequest>,
) -> Result<Response, ErrorResponse> {
// return a 503 if the service is not ready
check_ready(&state)?;
let request_id = get_or_create_request_id(request.inner.user.as_deref(), &headers);
let request = Context::with_id(request, request_id);
let request_id = request.id().to_string();
// Images are typically not streamed, so we default to non-streaming
let streaming = false;
// Get the model name from the request (diffusion model)
let model = request
.inner
.model
.as_ref()
.map(|m| match m {
dynamo_async_openai::types::ImageModel::DallE2 => "dall-e-2".to_string(),
dynamo_async_openai::types::ImageModel::DallE3 => "dall-e-3".to_string(),
dynamo_async_openai::types::ImageModel::Other(s) => s.clone(),
})
.unwrap_or_else(|| "diffusion".to_string());
// Create http_queue_guard early - tracks time waiting to be processed
let http_queue_guard = state.metrics_clone().create_http_queue_guard(&model);
// Get the image generation engine
let engine = state
.manager()
.get_images_engine(&model)
.map_err(|_| ErrorMessage::model_not_found())?;
// this will increment the inflight gauge for the model
let mut inflight =
state
.metrics_clone()
.create_inflight_guard(&model, Endpoint::Images, streaming);
let mut response_collector = state.metrics_clone().create_response_collector(&model);
// Issue the generate call on the engine
// Note: This uses ServerStreamingEngine for internal routing/distribution,
// NOT for client-facing SSE streaming. The stream is immediately folded into
// a single response below.
let stream = engine
.generate(request)
.await
.map_err(|e| ErrorMessage::from_anyhow(e, "Failed to generate images"))?;
// Process stream to collect metrics and drop http_queue_guard on first response
let mut http_queue_guard = Some(http_queue_guard);
let stream = stream.inspect(move |response| {
// Calls observe_response() on each item - drops http_queue_guard on first item
process_response_and_observe_metrics(
response,
&mut response_collector,
&mut http_queue_guard,
);
});
// Images are returned as a single response (non-streaming to client)
// Fold the internal stream into a single response
let response = NvImagesResponse::from_annotated_stream(stream)
.await
.map_err(|e| {
tracing::error!("Failed to fold images stream for {}: {:?}", request_id, e);
ErrorMessage::internal_server_error("Failed to fold images stream")
})?;
inflight.mark_ok();
Ok(Json(response).into_response())
}
/// Create an Axum [`Router`] for the OpenAI API Images endpoint
/// If not path is provided, the default path is `/v1/images/generations`
pub fn images_router(
state: Arc<service_v2::State>,
path: Option<String>,
) -> (Vec<RouteDoc>, Router) {
let path = path.unwrap_or("/v1/images/generations".to_string());
let doc = RouteDoc::new(axum::http::Method::POST, &path);
let router = Router::new()
.route(&path, post(images))
.layer(middleware::from_fn(smart_json_error_middleware))
.layer(axum::extract::DefaultBodyLimit::max(get_body_limit()))
.with_state(state);
(vec![doc], router)
}
#[cfg(test)] #[cfg(test)]
mod tests { mod tests {
......
...@@ -46,6 +46,7 @@ struct StateFlags { ...@@ -46,6 +46,7 @@ struct StateFlags {
chat_endpoints_enabled: AtomicBool, chat_endpoints_enabled: AtomicBool,
cmpl_endpoints_enabled: AtomicBool, cmpl_endpoints_enabled: AtomicBool,
embeddings_endpoints_enabled: AtomicBool, embeddings_endpoints_enabled: AtomicBool,
images_endpoints_enabled: AtomicBool,
responses_endpoints_enabled: AtomicBool, responses_endpoints_enabled: AtomicBool,
} }
...@@ -55,6 +56,7 @@ impl StateFlags { ...@@ -55,6 +56,7 @@ impl StateFlags {
EndpointType::Chat => self.chat_endpoints_enabled.load(Ordering::Relaxed), EndpointType::Chat => self.chat_endpoints_enabled.load(Ordering::Relaxed),
EndpointType::Completion => self.cmpl_endpoints_enabled.load(Ordering::Relaxed), EndpointType::Completion => self.cmpl_endpoints_enabled.load(Ordering::Relaxed),
EndpointType::Embedding => self.embeddings_endpoints_enabled.load(Ordering::Relaxed), EndpointType::Embedding => self.embeddings_endpoints_enabled.load(Ordering::Relaxed),
EndpointType::Images => self.images_endpoints_enabled.load(Ordering::Relaxed),
EndpointType::Responses => self.responses_endpoints_enabled.load(Ordering::Relaxed), EndpointType::Responses => self.responses_endpoints_enabled.load(Ordering::Relaxed),
} }
} }
...@@ -70,6 +72,9 @@ impl StateFlags { ...@@ -70,6 +72,9 @@ impl StateFlags {
EndpointType::Embedding => self EndpointType::Embedding => self
.embeddings_endpoints_enabled .embeddings_endpoints_enabled
.store(enabled, Ordering::Relaxed), .store(enabled, Ordering::Relaxed),
EndpointType::Images => self
.images_endpoints_enabled
.store(enabled, Ordering::Relaxed),
EndpointType::Responses => self EndpointType::Responses => self
.responses_endpoints_enabled .responses_endpoints_enabled
.store(enabled, Ordering::Relaxed), .store(enabled, Ordering::Relaxed),
...@@ -100,6 +105,7 @@ impl State { ...@@ -100,6 +105,7 @@ impl State {
chat_endpoints_enabled: AtomicBool::new(false), chat_endpoints_enabled: AtomicBool::new(false),
cmpl_endpoints_enabled: AtomicBool::new(false), cmpl_endpoints_enabled: AtomicBool::new(false),
embeddings_endpoints_enabled: AtomicBool::new(false), embeddings_endpoints_enabled: AtomicBool::new(false),
images_endpoints_enabled: AtomicBool::new(false),
responses_endpoints_enabled: AtomicBool::new(false), responses_endpoints_enabled: AtomicBool::new(false),
}, },
cancel_token, cancel_token,
...@@ -509,6 +515,7 @@ impl HttpServiceConfigBuilder { ...@@ -509,6 +515,7 @@ impl HttpServiceConfigBuilder {
super::openai::completions_router(state.clone(), var(HTTP_SVC_CMP_PATH_ENV).ok()); super::openai::completions_router(state.clone(), var(HTTP_SVC_CMP_PATH_ENV).ok());
let (embed_docs, embed_route) = let (embed_docs, embed_route) =
super::openai::embeddings_router(state.clone(), var(HTTP_SVC_EMB_PATH_ENV).ok()); super::openai::embeddings_router(state.clone(), var(HTTP_SVC_EMB_PATH_ENV).ok());
let (images_docs, images_route) = super::openai::images_router(state.clone(), None);
let (responses_docs, responses_route) = super::openai::responses_router( let (responses_docs, responses_route) = super::openai::responses_router(
state.clone(), state.clone(),
request_template.clone(), request_template.clone(),
...@@ -519,6 +526,7 @@ impl HttpServiceConfigBuilder { ...@@ -519,6 +526,7 @@ impl HttpServiceConfigBuilder {
endpoint_routes.insert(EndpointType::Chat, (chat_docs, chat_route)); endpoint_routes.insert(EndpointType::Chat, (chat_docs, chat_route));
endpoint_routes.insert(EndpointType::Completion, (cmpl_docs, cmpl_route)); endpoint_routes.insert(EndpointType::Completion, (cmpl_docs, cmpl_route));
endpoint_routes.insert(EndpointType::Embedding, (embed_docs, embed_route)); endpoint_routes.insert(EndpointType::Embedding, (embed_docs, embed_route));
endpoint_routes.insert(EndpointType::Images, (images_docs, images_route));
endpoint_routes.insert(EndpointType::Responses, (responses_docs, responses_route)); endpoint_routes.insert(EndpointType::Responses, (responses_docs, responses_route));
for endpoint_type in EndpointType::all() { for endpoint_type in EndpointType::all() {
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment