Unverified Commit 7e970d44 authored by Konrad Nowicki's avatar Konrad Nowicki Committed by GitHub
Browse files

feat: image diffusion with SGLang diffusion (#5609)


Signed-off-by: default avatarKonrad Nowicki <knowicki@nvidia.com>
Co-authored-by: default avatardagil-nvidia <dagil@nvidia.com>
parent f3aa1e01
# SPDX-FileCopyrightText: Copyright (c) 2025-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
"""
Dynamo Common Storage Module
Filesystem Spec (fsspec) is a project to provide a unified pythonic interface to local, remote and embedded file systems and bytes storage.
https://filesystem-spec.readthedocs.io/en/latest/index.html#who-uses-fsspec
Configuration for the storage:
Local Filesystem:
1. fs_url MUST contain a root path - path must be accessible and writable
S3:
1. If you want to use S3 please install additional dependencies: fsspec[s3]
2. fs_url MUST contain a bucket name
3. Configure credentials https://s3fs.readthedocs.io/en/latest/?badge=latest#credentials
a) AWS_ACCESS_KEY_ID, AWS_SECRET_ACCESS_KEY, and AWS_SESSION_TOKEN environment variables
b) configuration files such as ~/.aws/credentials
c) for nodes on EC2, the IAM metadata provider
d) for S3 compatible storage, you can use the following environment variables:
# export FSSPEC_S3_ENDPOINT_URL=https://...
# export FSSPEC_S3_KEY='miniokey...'
# export FSSPEC_S3_SECRET='asecretkey...'
"""
import fsspec
from fsspec.implementations.dirfs import DirFileSystem
def get_fs(fs_url: str) -> DirFileSystem:
"""
Initialize fsspec filesystem for the given URL.
Args:
fs_url: The URL of the filesystem to initialize. e.g. s3://bucket, gs://bucket, file:///local/path
Returns:
The initialized DirFileSystem wrapper for the filesystem.
fs.fs.protocol to get the protocol of the filesystem
fs.path to get the bucket or root path
path to the object in the filesystem - f"{fs.fs.protocol}://{fs.path}/{path}"
"""
# Extract protocol from URL (s3://, gs://, az://, file://)
fs_url_parts = fs_url.split("://")
protocol = fs_url_parts[0] if "://" in fs_url else "file"
# ... or bucket name
root_path = fs_url_parts[1] if len(fs_url_parts) > 1 else "/"
fs_opts = {}
if protocol in "file":
# create directory for local filesystem
fs_opts = {"auto_mkdir": True}
return DirFileSystem(fs=fsspec.filesystem(protocol, **fs_opts), path=root_path)
......@@ -131,6 +131,26 @@ DYNAMO_ARGS: Dict[str, Dict[str, Any]] = {
"default": os.environ.get("DYN_LOCAL_INDEXER", "false"),
"help": "Enable worker-local KV indexer for tracking this worker's own KV cache state (can also be toggled with env var DYN_LOCAL_INDEXER).",
},
"image-diffusion-worker": {
"flags": ["--image-diffusion-worker"],
"action": "store_true",
"default": False,
"help": "Run as image diffusion worker for image generation",
},
"image-diffusion-fs-url": {
"flags": ["--image-diffusion-fs-url"],
"type": str,
"default": None,
"help": "Filesystem URL for storing generated images using fsspec (e.g., s3://bucket/path, gs://bucket/path, file:///local/path). Supports any fsspec-compatible filesystem.",
},
"image-diffusion-base-url": {
"flags": ["--image-diffusion-base-url"],
"type": str,
"default": os.environ.get(
"DYN_IMAGE_DIFFUSION_BASE_URL", "http://localhost:8008/"
),
"help": "Base URL for rewriting image URLs in responses (e.g., http://localhost:8008/). When set, generated image URLs will use this base instead of filesystem URLs. Can be set via DYN_IMAGE_DIFFUSION_URL_BASE env var.",
},
}
......@@ -173,6 +193,11 @@ class DynamoArgs:
# Whether to enable NATS for KV events (derived from server_args.kv_events_config)
use_kv_events: bool = False
# image diffusion options
image_diffusion_worker: bool = False
image_diffusion_fs_url: Optional[str] = None
image_diffusion_base_url: Optional[str] = None
class DisaggregationMode(Enum):
AGGREGATED = "agg"
......@@ -443,6 +468,8 @@ async def parse_args(args: list[str]) -> Config:
if endpoint is None:
if parsed_args.embedding_worker:
endpoint = f"dyn://{namespace}.backend.generate"
elif getattr(parsed_args, "image_diffusion_worker", False):
endpoint = f"dyn://{namespace}.backend.generate"
elif (
hasattr(parsed_args, "disaggregation_mode")
and parsed_args.disaggregation_mode == "prefill"
......@@ -521,6 +548,35 @@ async def parse_args(args: list[str]) -> Config:
# TODO: sglang downloads the model in `from_cli_args`, which means we had to
# fetch_llm (download the model) here, in `parse_args`. `parse_args` should not
# contain code to download a model, it should only parse the args.
# For diffusion workers, create a minimal dummy ServerArgs since diffusion
# doesn't use transformer models or sglang Engine - it uses DiffGenerator directly
image_diffusion_worker = getattr(parsed_args, "image_diffusion_worker", False)
if image_diffusion_worker:
logging.info(f"Image diffusion worker detected with model: {model_path}")
# Need to use ServerArgs not intended for sglang[diffusion], multimodal_gen has its own ServerArgs.
server_args = ServerArgs("none") # HACK: Avoid triggering __post_init__
server_args.model_path = model_path
server_args.served_model_name = parsed_args.served_model_name
server_args.enable_metrics = getattr(parsed_args, "enable_metrics", False)
server_args.log_level = getattr(parsed_args, "log_level", "info")
server_args.kv_events_config = getattr(parsed_args, "kv_events_config", None)
server_args.speculative_algorithm = None
server_args.disaggregation_mode = None
server_args.dllm_algorithm = False
server_args.tp_size = getattr(parsed_args, "tensor_parallel_size", 1)
server_args.dp_size = getattr(parsed_args, "data_parallel_size", 1)
parsed_args.use_sglang_tokenizer = True
parsed_args.dyn_endpoint_types = "images"
logging.info(
f"Created stub ServerArgs for diffusion: model_path={server_args.model_path}"
)
else:
server_args = ServerArgs.from_cli_args(parsed_args)
# Dynamo's streaming handlers expect disjoint output_ids from SGLang (only new
......@@ -576,6 +632,9 @@ async def parse_args(args: list[str]) -> Config:
multimodal_worker=parsed_args.multimodal_worker,
embedding_worker=parsed_args.embedding_worker,
diffusion_worker=diffusion_worker,
image_diffusion_worker=getattr(parsed_args, "image_diffusion_worker", False),
image_diffusion_fs_url=getattr(parsed_args, "image_diffusion_fs_url", None),
image_diffusion_base_url=getattr(parsed_args, "image_diffusion_base_url", None),
dump_config_to=parsed_args.dump_config_to,
enable_local_indexer=str(parsed_args.enable_local_indexer).lower() == "true",
use_kv_events=use_kv_events,
......
......@@ -118,3 +118,29 @@ class SglangPrefillHealthCheckPayload(HealthCheckPayload):
self.default_payload["request"]["token_ids"] = [bos_token_id] # type: ignore
super().__init__()
class ImageDiffusionHealthCheckPayload(HealthCheckPayload):
"""Image diffusion-specific health check payload for image generation workers.
Sends a minimal image generation request to verify the diffusion worker
is responding and the model is loaded. Uses minimal resources for fast checks.
"""
def __init__(self, model_path: str):
"""Initialize diffusion health check payload with minimal generation request.
Args:
model_path: The diffusion model being served.
"""
self.default_payload = {
"prompt": "test", # Minimal prompt
"model": model_path,
"n": 1, # Generate 1 image
"size": "512x512", # Small size for fast health check
"num_inference_steps": 1, # Just 1 step (fast but low quality)
"guidance_scale": 7.5, # Standard guidance scale
"response_format": "b64_json", # Don't require S3 for health check
}
super().__init__()
......@@ -11,12 +11,14 @@ import sglang as sgl
import uvloop
from dynamo.common.config_dump import dump_config
from dynamo.common.storage import get_fs
from dynamo.common.utils.endpoint_types import parse_endpoint_types
from dynamo.llm import ModelInput, ModelType
from dynamo.runtime import DistributedRuntime
from dynamo.runtime.logging import configure_dynamo_logging
from dynamo.sglang.args import Config, DisaggregationMode, parse_args
from dynamo.sglang.health_check import (
ImageDiffusionHealthCheckPayload,
SglangHealthCheckPayload,
SglangPrefillHealthCheckPayload,
)
......@@ -25,11 +27,15 @@ from dynamo.sglang.publisher import (
setup_prometheus_registry,
setup_sgl_metrics,
)
from dynamo.sglang.register import register_llm_with_readiness_gate
from dynamo.sglang.register import (
register_image_diffusion_model,
register_llm_with_readiness_gate,
)
from dynamo.sglang.request_handlers import (
DecodeWorkerHandler,
DiffusionWorkerHandler,
EmbeddingWorkerHandler,
ImageDiffusionWorkerHandler,
MultimodalEncodeWorkerHandler,
MultimodalPrefillWorkerHandler,
MultimodalProcessorHandler,
......@@ -113,7 +119,9 @@ async def worker():
logging.info("Signal handlers will trigger a graceful shutdown of the runtime")
if config.dynamo_args.embedding_worker:
if config.dynamo_args.image_diffusion_worker:
await init_image_diffusion(runtime, config)
elif config.dynamo_args.embedding_worker:
await init_embedding(runtime, config)
elif config.dynamo_args.multimodal_processor:
await init_multimodal_processor(runtime, config)
......@@ -128,6 +136,7 @@ async def worker():
await init_diffusion(runtime, config)
elif config.serving_mode != DisaggregationMode.PREFILL:
await init(runtime, config)
else:
await init_prefill(runtime, config)
......@@ -446,6 +455,87 @@ async def init_embedding(runtime: DistributedRuntime, config: Config):
handler.cleanup()
async def init_image_diffusion(runtime: DistributedRuntime, config: Config):
"""Initialize image diffusion worker component"""
server_args, dynamo_args = config.server_args, config.dynamo_args
# Initialize DiffGenerator (not sgl.Engine)
from sglang.multimodal_gen import DiffGenerator
if not server_args.model_path:
raise ValueError("--model is required for diffusion workers")
# Parallelism configuration
tp_size = getattr(server_args, "tp_size", 1)
dp_size = getattr(server_args, "dp_size", 1)
num_gpus = tp_size * dp_size
# Distributed configuration
dist_timeout = getattr(server_args, "dist_timeout", None)
generator = DiffGenerator.from_pretrained(
model_path=server_args.model_path,
# Parallelism configuration
num_gpus=num_gpus,
tp_size=tp_size,
dp_size=dp_size,
# Distributed configuration
dist_timeout=dist_timeout,
)
# Initialize fsspec filesystems for image storage
fs_url = dynamo_args.image_diffusion_fs_url
# Initialize primary filesystem
if not fs_url:
raise ValueError("--image-diffusion-fs-url is required for diffusion workers")
component = runtime.namespace(dynamo_args.namespace).component(
dynamo_args.component
)
generate_endpoint = component.endpoint(dynamo_args.endpoint)
# Image diffusion doesn't have metrics publisher like LLM
# Could add custom metrics for images/sec, steps/sec later
handler = ImageDiffusionWorkerHandler(
component,
generator,
config,
publisher=None,
fs=get_fs(fs_url),
)
# Create proper health check payload that sends a minimal diffusion request
health_check_payload = ImageDiffusionHealthCheckPayload(
model_path=server_args.model_path
).to_dict()
ready_event = asyncio.Event()
try:
await asyncio.gather(
generate_endpoint.serve_endpoint(
handler.generate,
graceful_shutdown=True,
metrics_labels=[], # No LLM metrics labels
health_check_payload=health_check_payload,
),
register_image_diffusion_model(
generator,
generate_endpoint,
server_args,
readiness_gate=ready_event,
),
)
except Exception as e:
logging.error(f"Failed to serve image diffusion endpoints: {e}")
raise
finally:
handler.cleanup()
async def init_multimodal_processor(runtime: DistributedRuntime, config: Config):
"""Initialize multimodal processor component"""
server_args, dynamo_args = config.server_args, config.dynamo_args
......
......@@ -131,3 +131,46 @@ class DisaggSglangMultimodalRequest(BaseModel):
request: SglangMultimodalRequest
sampling_params: dict
data_parallel_rank: Optional[int] = None
# ============================================================================
# Image diffusion Protocol Types
# ============================================================================
class NvExt(BaseModel):
"""NVIDIA extensions for image generation"""
negative_prompt: Optional[str] = None
num_inference_steps: Optional[int] = 50
guidance_scale: float = 7.5
seed: Optional[int] = None
annotations: Optional[list[str]] = None
class CreateImageRequest(BaseModel):
"""OpenAI /v1/images/generations compatible request"""
prompt: str
model: str # e.g. "stabilityai/stable-diffusion-3.5-medium"
n: int = 1 # Number of images
size: Optional[str] = "1024x1024" # "WxH" format
quality: Optional[str] = "standard" # standard, hd
response_format: Optional[str] = "url" # url or b64_json
user: Optional[str] = None
# NVIDIA extensions nested under nvext
nvext: Optional[NvExt] = None
class ImageData(BaseModel):
url: Optional[str] = None # S3 URL
b64_json: Optional[str] = None # Base64 encoded
revised_prompt: Optional[str] = None
class ImagesResponse(BaseModel):
"""OpenAI-compatible response"""
created: int # Unix timestamp
data: list[ImageData]
......@@ -4,7 +4,7 @@
import asyncio
import logging
import socket
from typing import Optional
from typing import Any, Optional
import sglang as sgl
from sglang.srt.server_args import ServerArgs
......@@ -269,3 +269,43 @@ async def register_llm_with_readiness_gate(
readiness_gate.set()
logging.info("Model registration succeeded; processing queued requests")
async def register_image_diffusion_model(
generator: Any, # DiffGenerator
endpoint: Endpoint,
server_args: ServerArgs,
readiness_gate: Optional[asyncio.Event] = None,
) -> None:
"""Register diffusion model with Dynamo runtime.
Args:
generator: The SGLang DiffGenerator instance.
endpoint: The Dynamo endpoint for generation requests.
server_args: SGLang server configuration.
readiness_gate: Optional event to signal when registration completes.
Note:
Image diffusion models use ModelInput.Text (text prompts) and ModelType.Images.
"""
# Use model_path as the model name (diffusion workers don't have served_model_name)
model_name = server_args.model_path
try:
await register_llm(
ModelInput.Text,
ModelType.Images,
endpoint,
model_name,
model_name,
)
logging.info(f"Successfully registered diffusion model: {model_name}")
except Exception as e:
logging.error(f"Failed to register diffusion model: {e}")
raise RuntimeError("Image diffusion model registration failed")
# Signal readiness
if readiness_gate:
readiness_gate.set()
logging.info(f"Image diffusion model ready: {model_name}")
......@@ -5,7 +5,10 @@
from .embedding import EmbeddingWorkerHandler
# Base handlers
from .handler_base import BaseWorkerHandler
from .handler_base import BaseGenerativeHandler, BaseWorkerHandler
# Image diffusion handlers
from .image_diffusion import ImageDiffusionWorkerHandler
# LLM handlers
from .llm import DecodeWorkerHandler, DiffusionWorkerHandler, PrefillWorkerHandler
......@@ -19,6 +22,8 @@ from .multimodal import (
)
__all__ = [
# Base handlers
"BaseGenerativeHandler",
"BaseWorkerHandler",
# LLM handlers
"DecodeWorkerHandler",
......@@ -26,6 +31,8 @@ __all__ = [
"PrefillWorkerHandler",
# Embedding handlers
"EmbeddingWorkerHandler",
# Image diffusion handlers
"ImageDiffusionWorkerHandler",
# Multimodal handlers
"MultimodalEncodeWorkerHandler",
"MultimodalPrefillWorkerHandler",
......
......@@ -18,8 +18,82 @@ from dynamo.sglang.args import Config
from dynamo.sglang.publisher import DynamoSglangPublisher
class BaseWorkerHandler(ABC):
"""Abstract base class for SGLang worker handlers."""
class BaseGenerativeHandler(ABC):
"""Minimal base class for all generative handlers (LLM, diffusion, etc.).
Provides common infrastructure for:
- Component and configuration management
- Metrics and KV event publishing
- Distributed tracing integration
"""
def __init__(
self,
component: Component,
config: Config,
publisher: Optional[DynamoSglangPublisher] = None,
) -> None:
"""Initialize base generative handler.
Args:
component: The Dynamo runtime component.
config: SGLang and Dynamo configuration.
publisher: Optional metrics publisher for the worker.
"""
self.component = component
self.config = config
# Set up metrics and KV publishers
if publisher is not None:
self.metrics_publisher = publisher.metrics_publisher
self.kv_publisher = publisher.kv_publisher
else:
self.metrics_publisher = None
self.kv_publisher = None
@abstractmethod
async def generate(
self, request: Dict[str, Any], context: Context
) -> AsyncGenerator[Dict[str, Any], None]:
"""Generate response from request.
Args:
request: Request dict with input and parameters.
context: Context object for cancellation handling.
Yields:
Response data (format varies by handler implementation).
"""
pass
def cleanup(self) -> None:
"""Cleanup resources. Override in subclasses as needed."""
pass
def _get_trace_header(self, context: Context) -> Optional[Dict[str, str]]:
"""Get trace header dict for passing to generation functions.
Args:
context: Dynamo Context object containing trace information.
Returns:
Dict with traceparent header if trace context available, None otherwise.
"""
trace_id = context.trace_id
span_id = context.span_id
if not trace_id or not span_id:
return None
return {"traceparent": f"00-{trace_id}-{span_id}-01"}
class BaseWorkerHandler(BaseGenerativeHandler):
"""Abstract base class for SGLang LLM worker handlers.
Extends BaseGenerativeHandler with LLM-specific functionality:
- SGLang Engine integration
- Tokenization and input parameter management
- Disaggregated serving support
"""
def __init__(
self,
......@@ -38,7 +112,10 @@ class BaseWorkerHandler(ABC):
publisher: Optional metrics publisher for the worker.
generate_endpoint: The endpoint handle for discovery registration.
"""
self.component = component
# Call parent constructor
super().__init__(component, config, publisher)
# LLM-specific initialization
self.engine = engine
self.config = config
self.generate_endpoint = generate_endpoint
......@@ -274,21 +351,6 @@ class BaseWorkerHandler(ABC):
return bootstrap_host, bootstrap_port
def _get_trace_header(self, context: Context) -> Optional[Dict[str, str]]:
"""Get trace header dict for passing to SGLang's external_trace_header parameter.
Args:
context: Dynamo Context object containing trace information.
Returns:
Dict with traceparent header if trace context available, None otherwise.
"""
trace_id = context.trace_id
span_id = context.span_id
if not trace_id or not span_id:
return None
return {"traceparent": f"00-{trace_id}-{span_id}-01"}
async def _handle_cancellation(
self, request_id_future: asyncio.Future, context: Context
):
......
# SPDX-FileCopyrightText: Copyright (c) 2025-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
from .image_diffusion_handler import ImageDiffusionWorkerHandler
__all__ = ["ImageDiffusionWorkerHandler"]
# SPDX-FileCopyrightText: Copyright (c) 2025-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
import asyncio
import base64
import io
import logging
import random
import time
import uuid
from typing import Any, AsyncGenerator, Optional
import torch
from PIL import Image
from dynamo._core import Component, Context
from dynamo.sglang.args import Config
from dynamo.sglang.protocol import CreateImageRequest, ImageData, ImagesResponse, NvExt
from dynamo.sglang.publisher import DynamoSglangPublisher
from dynamo.sglang.request_handlers.handler_base import BaseGenerativeHandler
logger = logging.getLogger(__name__)
MAX_NUM_INFERENCE_STEPS = 50
class ImageDiffusionWorkerHandler(BaseGenerativeHandler):
"""Handler for diffusion image generation.
Inherits from BaseGenerativeHandler for common infrastructure like
tracing, metrics publishing
"""
def __init__(
self,
component: Component,
generator: Any, # DiffGenerator, not sgl.Engine
config: Config,
publisher: Optional[DynamoSglangPublisher] = None,
fs: Any = None, # fsspec.AbstractFileSystem for primary storage
):
"""Initialize diffusion worker handler.
Args:
component: The Dynamo runtime component.
generator: The SGLang DiffGenerator instance.
config: SGLang and Dynamo configuration.
publisher: Optional metrics publisher (not used for diffusion currently).
fs: Optional fsspec filesystem for primary image storage.
"""
super().__init__(component, config, publisher)
self.generator = generator # DiffGenerator, not Engine
self.fs = fs
self.fs_url = config.dynamo_args.image_diffusion_fs_url
self.base_url = config.dynamo_args.image_diffusion_base_url
logger.info(
f"Image diffusion worker handler initialized with fs_url={self.fs_url}, url_base={self.base_url}"
)
def cleanup(self) -> None:
"""Cleanup generator resources"""
if self.generator is not None:
del self.generator
torch.cuda.empty_cache()
logger.info("Image diffusion generator cleanup complete")
super().cleanup()
async def generate(
self, request: dict[str, Any], context: Context
) -> AsyncGenerator[dict[str, Any], None]:
"""
Generate image(s) from text prompt.
Unlike LLM streaming, diffusion returns complete image(s) at end.
Args:
request: Request dict with prompt and generation parameters.
context: Context object for cancellation handling.
Yields:
Response dict with generated images (OpenAI-compatible format).
"""
logger.info(f"Image diffusion request: {request}")
# Get trace header for distributed tracing (for logging/observability)
trace_header = self._get_trace_header(context)
if trace_header:
logger.debug(f"Image diffusion request with trace: {trace_header}")
try:
req = CreateImageRequest(**request)
# get extra parameters
nvext = req.nvext or NvExt()
nvext.num_inference_steps = min(
nvext.num_inference_steps or 50, MAX_NUM_INFERENCE_STEPS
)
width, height = self._parse_size(req.size)
images = await self._generate_images(
prompt=req.prompt,
negative_prompt=nvext.negative_prompt,
width=width,
height=height,
num_inference_steps=nvext.num_inference_steps,
guidance_scale=nvext.guidance_scale,
seed=nvext.seed,
)
user_id = req.user if req.user else context.id()
image_data = []
for img in images:
# uploading or encoding the image
if req.response_format == "url":
url = await self._upload_to_fs(img, user_id, context.id())
image_data.append(ImageData(url=url))
else:
b64 = self._encode_base64(img)
image_data.append(ImageData(b64_json=b64))
response = ImagesResponse(created=int(time.time()), data=image_data)
yield response.model_dump()
except Exception as e:
logger.error(f"Error in diffusion generation: {e}", exc_info=True)
error_response = {
"created": int(time.time()),
"data": [],
"error": str(e),
}
yield error_response
async def _generate_images(
self,
prompt: str,
width: int,
height: int,
num_inference_steps: int,
guidance_scale: float,
seed: Optional[int],
negative_prompt: Optional[str] = None,
) -> list[bytes]:
"""Generate images using SGLang DiffGenerator"""
args = {
"prompt": prompt,
"negative_prompt": negative_prompt,
"height": height,
"width": width,
"num_inference_steps": num_inference_steps,
"save_output": False, # We handle saving ourselves
"guidance_scale": guidance_scale,
"seed": seed if seed else random.randint(0, 1000000),
}
result = await asyncio.to_thread(
self.generator.generate,
sampling_params_kwargs=args,
)
if result is None:
raise RuntimeError("No result from generator")
images = result["frames"] if "frames" in result else []
# Convert images to bytes (handle PIL Images, numpy arrays, or bytes)
image_bytes_list = []
for img in images:
if isinstance(img, bytes):
image_bytes_list.append(img)
elif Image is not None and isinstance(img, Image.Image):
# Convert PIL Image to bytes
buf = io.BytesIO()
img.save(buf, format="PNG")
image_bytes_list.append(buf.getvalue())
else:
try:
import numpy as np
if isinstance(img, np.ndarray):
# Convert numpy array to PIL Image then to bytes
pil_img = Image.fromarray(img)
buf = io.BytesIO()
pil_img.save(buf, format="PNG")
image_bytes_list.append(buf.getvalue())
else:
raise ValueError(f"Unsupported image type: {type(img)}")
except ImportError:
raise RuntimeError(
"Cannot convert image format. Install Pillow: pip install Pillow"
)
return image_bytes_list
def _parse_size(self, size_str: Optional[str]) -> tuple[int, int]:
"""Parse '1024x1024' -> (1024, 1024)"""
if size_str is None:
return 1024, 1024
w, h = size_str.split("x")
return int(w), int(h)
async def _upload_to_fs(
self, image_bytes: bytes, user_id: str, request_id: str
) -> str:
"""Upload image to filesystem and return URL.
Uses per-user storage path:
users/{user_id}/generations/{request_id}/{image_uuid}.png
Args:
image_bytes: Image data as bytes.
user_id: User identifier from request or context.
request_id: Request context ID.
Returns:
Public URL for the uploaded image.
"""
image_uuid = str(uuid.uuid4())
image_filename = f"{image_uuid}.png"
# Per-user storage path
storage_path = f"users/{user_id}/generations/{request_id}/{image_filename}"
# send image to filesystem
await asyncio.to_thread(self.fs.pipe, storage_path, image_bytes)
return f"{self.base_url}/{storage_path}"
def _encode_base64(self, image_bytes: bytes) -> str:
"""Encode image as base64 string"""
return base64.b64encode(image_bytes).decode("utf-8")
# SPDX-FileCopyrightText: Copyright (c) 2025-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
"""Unit tests for ImageDiffusionWorkerHandler."""
import base64
import io
from unittest.mock import MagicMock, Mock, patch
import pytest
from PIL import Image
from dynamo.sglang.request_handlers.image_diffusion.image_diffusion_handler import (
ImageDiffusionWorkerHandler,
)
pytestmark = [
pytest.mark.unit,
pytest.mark.sglang,
pytest.mark.gpu_0, # No GPU needed for unit tests
pytest.mark.pre_merge,
pytest.mark.parallel,
]
@pytest.fixture
def mock_component():
"""Mock Dynamo Component."""
return MagicMock()
@pytest.fixture
def mock_generator():
"""Mock SGLang DiffGenerator."""
generator = MagicMock()
generator.generate = MagicMock()
return generator
@pytest.fixture
def mock_config():
"""Mock Config object."""
config = MagicMock()
config.dynamo_args = MagicMock()
config.dynamo_args.image_diffusion_fs_url = "file:///tmp/images"
config.dynamo_args.image_diffusion_base_url = "file:///tmp/images"
return config
@pytest.fixture
def mock_fs():
"""Mock fsspec filesystem."""
fs = MagicMock()
fs.pipe = MagicMock()
return fs
@pytest.fixture
def mock_context():
"""Mock Context object."""
context = MagicMock()
context.id = MagicMock(return_value="test-context-id")
context.trace_id = "test-trace-id"
context.span_id = "test-span-id"
context.is_cancelled = MagicMock(return_value=False)
return context
@pytest.fixture
def handler(
mock_component, mock_generator, mock_config, mock_fs
) -> ImageDiffusionWorkerHandler:
"""Create ImageDiffusionWorkerHandler instance."""
return ImageDiffusionWorkerHandler(
component=mock_component,
generator=mock_generator,
config=mock_config,
publisher=None,
fs=mock_fs,
)
class TestImageDiffusionWorkerHandler:
"""Test suite for ImageDiffusionWorkerHandler."""
def test_initialization(self, handler, mock_generator, mock_fs):
"""Test handler initialization."""
assert handler.generator == mock_generator
assert handler.fs == mock_fs
assert handler.fs_url == "file:///tmp/images"
assert handler.base_url == "file:///tmp/images"
def test_initialization_with_url_base(
self, mock_component, mock_generator, mock_fs
):
"""Test handler initialization with URL base."""
config = MagicMock()
config.dynamo_args = MagicMock()
config.dynamo_args.image_diffusion_fs_url = "s3://my-bucket/images"
config.dynamo_args.image_diffusion_base_url = "http://localhost:8008/images"
handler = ImageDiffusionWorkerHandler(
component=mock_component,
generator=mock_generator,
config=config,
publisher=None,
fs=mock_fs,
)
assert handler.base_url == "http://localhost:8008/images"
assert handler.fs_url == "s3://my-bucket/images"
@patch("torch.cuda.empty_cache")
def test_cleanup(self, mock_empty_cache, handler):
"""Test cleanup method."""
_original_generator = handler.generator
handler.cleanup()
# Generator should be set to None after cleanup
# Note: We can't assert it's None because the attribute gets deleted
mock_empty_cache.assert_called_once()
def test_parse_size(self, handler):
"""Test _parse_size method."""
width, height = handler._parse_size("1024x1024")
assert width == 1024
assert height == 1024
width, height = handler._parse_size("512x768")
assert width == 512
assert height == 768
def test_encode_base64(self, handler):
"""Test _encode_base64 method."""
test_bytes = b"test image data"
expected = base64.b64encode(test_bytes).decode("utf-8")
result = handler._encode_base64(test_bytes)
assert result == expected
@pytest.mark.asyncio
async def test_generate_success_url_format(self, handler, mock_context):
"""Test successful image generation with URL response format."""
# Create a simple test image
test_image = Image.new("RGB", (256, 256), color="red")
img_buffer: io.BytesIO = io.BytesIO()
test_image.save(img_buffer, format="PNG")
# Mock generator response
handler.generator.generate = Mock(
return_value={"frames": [test_image.convert("RGB")]}
)
request = {
"prompt": "A red square",
"model": "test-model",
"size": "256x256",
"response_format": "url",
"user": "test-user",
"nvext": {
"num_inference_steps": 10,
"guidance_scale": 7.5,
"seed": 42,
"negative_prompt": None,
},
}
# Execute generation
results = []
async for result in handler.generate(request, mock_context):
results.append(result)
# Verify results
assert len(results) == 1
response = results[0]
assert "created" in response
assert "data" in response
assert len(response["data"]) == 1
assert "url" in response["data"][0]
assert response["data"][0]["url"].startswith("file:///tmp/images/users/")
@pytest.mark.asyncio
async def test_generate_success_b64_format(self, handler, mock_context):
"""Test successful image generation with base64 response format."""
# Create a simple test image
test_image = Image.new("RGB", (256, 256), color="blue")
# Mock generator response
handler.generator.generate = Mock(
return_value={"frames": [test_image.convert("RGB")]}
)
request = {
"prompt": "A blue square",
"model": "test-model",
"size": "256x256",
"response_format": "b64_json",
"user": "test-user",
"nvext": {
"num_inference_steps": 10,
"guidance_scale": 7.5,
"seed": 42,
"negative_prompt": None,
},
}
# Execute generation
results = []
async for result in handler.generate(request, mock_context):
results.append(result)
# Verify results
assert len(results) == 1
response = results[0]
assert "created" in response
assert "data" in response
assert len(response["data"]) == 1
assert "b64_json" in response["data"][0]
# Verify it's valid base64
b64_data = response["data"][0]["b64_json"]
decoded = base64.b64decode(b64_data)
assert len(decoded) > 0
@pytest.mark.asyncio
async def test_generate_with_default_num_inference_steps(
self, handler, mock_context
):
"""Test that num_inference_steps defaults to 50."""
test_image = Image.new("RGB", (256, 256), color="green")
handler.generator.generate = Mock(return_value={"frames": [test_image]})
request = {
"prompt": "A green square",
"model": "test-model",
"size": "256x256",
"response_format": "b64_json",
"user": "test-user",
}
# Execute generation
results = []
async for result in handler.generate(request, mock_context):
results.append(result)
@pytest.mark.asyncio
async def test_generate_error_handling(self, handler, mock_context):
"""Test error handling in generate method."""
# Mock generator to raise an exception
handler.generator.generate = Mock(side_effect=RuntimeError("Generation failed"))
request = {
"prompt": "Test prompt",
"model": "test-model",
"size": "256x256",
"response_format": "url",
"user": "test-user",
"nvext": {
"num_inference_steps": 10,
"guidance_scale": 7.5,
"seed": 42,
"negative_prompt": None,
},
}
# Execute generation
results = []
async for result in handler.generate(request, mock_context):
results.append(result)
# Verify error response
assert len(results) == 1
response = results[0]
assert "error" in response
assert "Generation failed" in response["error"]
assert response["data"] == []
@pytest.mark.asyncio
async def test_upload_to_fs(self, handler):
"""Test _upload_to_fs method."""
image_bytes = b"test image data"
user_id = "user123"
request_id = "req456"
url = await handler._upload_to_fs(image_bytes, user_id, request_id)
# Verify storage path format
assert f"users/{user_id}/generations/{request_id}/" in url
assert url.endswith(".png")
@pytest.mark.asyncio
async def test_generate_images_with_numpy_array(self, handler):
"""Test _generate_images handles numpy arrays."""
import numpy as np
# Create a numpy array representing an image
np_image = np.random.randint(0, 255, (256, 256, 3), dtype=np.uint8)
handler.generator.generate = Mock(return_value={"frames": [np_image]})
images = await handler._generate_images(
prompt="test",
width=256,
height=256,
num_inference_steps=10,
guidance_scale=7.5,
seed=42,
)
assert len(images) == 1
assert isinstance(images[0], bytes)
@pytest.mark.asyncio
async def test_generate_images_with_pil_image(self, handler):
"""Test _generate_images handles PIL Images."""
pil_image = Image.new("RGB", (256, 256), color="red")
handler.generator.generate = Mock(return_value={"frames": [pil_image]})
images = await handler._generate_images(
prompt="test",
width=256,
height=256,
num_inference_steps=10,
guidance_scale=7.5,
seed=42,
)
assert len(images) == 1
assert isinstance(images[0], bytes)
@pytest.mark.asyncio
async def test_generate_images_with_bytes(self, handler):
"""Test _generate_images handles bytes directly."""
img_bytes = b"raw image bytes"
handler.generator.generate = Mock(return_value={"frames": [img_bytes]})
images = await handler._generate_images(
prompt="test",
width=256,
height=256,
num_inference_steps=10,
guidance_scale=7.5,
seed=42,
)
assert len(images) == 1
assert images[0] == img_bytes
@pytest.mark.asyncio
async def test_generate_with_nvext(self, handler, mock_context):
"""Test that nvext parameters are passed to the generator."""
test_image = Image.new("RGB", (256, 256), color="yellow")
handler._generate_images = Mock(return_value=[test_image.tobytes()])
handler._get_trace_header = Mock(
return_value={"traceparent": "00-1234567890-1234567890-01"}
)
request = {
"prompt": "A yellow square",
"model": "test-model",
"size": "256x256",
"response_format": "b64_json",
"user": "test-user",
"nvext": {
"num_inference_steps": 10,
"guidance_scale": 7.5,
"seed": 42,
"negative_prompt": "negative",
},
}
# Execute generation
results = []
async for result in handler.generate(request, mock_context):
results.append(result)
# Verify results
handler._generate_images.assert_called_once_with(
prompt="A yellow square",
width=256,
height=256,
num_inference_steps=10,
guidance_scale=7.5,
seed=42,
negative_prompt="negative",
)
......@@ -96,15 +96,14 @@ pub enum ImageModeration {
#[builder(derive(Debug))]
#[builder(build_fn(error = "OpenAIError"))]
pub struct CreateImageRequest {
/// A text description of the desired image(s). The maximum length is 1000 characters for `dall-e-2`
/// and 4000 characters for `dall-e-3`.
/// A text description of the desired image(s).
pub prompt: String,
/// The model to use for image generation.
#[serde(skip_serializing_if = "Option::is_none")]
pub model: Option<ImageModel>,
/// The number of images to generate. Must be between 1 and 10. For `dall-e-3`, only `n=1` is supported.
/// The number of images to generate. Must be between 1 and 10.
#[serde(skip_serializing_if = "Option::is_none")]
pub n: Option<u8>, // min:1 max:10 default:1
......
......@@ -269,6 +269,7 @@ fn register_llm<'p>(
};
let is_tensor_based = model_type.inner.supports_tensor();
let is_images = model_type.inner.supports_images();
let model_type_obj = model_type.inner;
......@@ -317,8 +318,9 @@ fn register_llm<'p>(
.or_else(|| Some(source_path.clone()));
pyo3_async_runtimes::tokio::future_into_py(py, async move {
// For TensorBased models, skip HuggingFace downloads and register directly
if is_tensor_based {
// For TensorBased and Images models, skip HuggingFace downloads and register directly
// These model types don't require tokenizers
if is_tensor_based || is_images {
let model_name = model_name.unwrap_or_else(|| source_path.clone());
let mut card = llm_rs::model_card::ModelDeploymentCard::with_name_only(&model_name);
card.model_type = model_type_obj;
......@@ -519,6 +521,10 @@ impl ModelType {
const Prefill: Self = ModelType {
inner: llm_rs::model_type::ModelType::Prefill,
};
#[classattr]
const Images: Self = ModelType {
inner: llm_rs::model_type::ModelType::Images,
};
fn __or__(&self, other: &Self) -> Self {
ModelType {
......
......@@ -983,12 +983,13 @@ class ModelInput:
...
class ModelType:
"""What type of request this model needs: Chat, Completions, Embedding, Tensor or Prefill"""
"""What type of request this model needs: Chat, Completions, Embedding, Tensor, Images or Prefill"""
Chat: ModelType
Completions: ModelType
Embedding: ModelType
TensorBased: ModelType
Prefill: ModelType
Images: ModelType
...
class RouterMode:
......
......@@ -33,7 +33,7 @@ use crate::{
openai::{
chat_completions::OpenAIChatCompletionsStreamingEngine,
completions::OpenAICompletionsStreamingEngine,
embeddings::OpenAIEmbeddingsStreamingEngine,
embeddings::OpenAIEmbeddingsStreamingEngine, images::OpenAIImagesStreamingEngine,
},
},
};
......@@ -66,6 +66,7 @@ pub struct ModelManager {
completion_engines: RwLock<ModelEngines<OpenAICompletionsStreamingEngine>>,
chat_completion_engines: RwLock<ModelEngines<OpenAIChatCompletionsStreamingEngine>>,
embeddings_engines: RwLock<ModelEngines<OpenAIEmbeddingsStreamingEngine>>,
images_engines: RwLock<ModelEngines<OpenAIImagesStreamingEngine>>,
tensor_engines: RwLock<ModelEngines<TensorStreamingEngine>>,
// Prefill models don't have engines - they're only tracked for discovery/lifecycle
prefill_engines: RwLock<ModelEngines<()>>,
......@@ -91,6 +92,7 @@ impl ModelManager {
completion_engines: RwLock::new(ModelEngines::default()),
chat_completion_engines: RwLock::new(ModelEngines::default()),
embeddings_engines: RwLock::new(ModelEngines::default()),
images_engines: RwLock::new(ModelEngines::default()),
tensor_engines: RwLock::new(ModelEngines::default()),
prefill_engines: RwLock::new(ModelEngines::default()),
cards: DashMap::new(),
......@@ -114,6 +116,7 @@ impl ModelManager {
ModelType::Completions => self.completion_engines.read().checksum(model_name),
ModelType::Embedding => self.embeddings_engines.read().checksum(model_name),
ModelType::TensorBased => self.tensor_engines.read().checksum(model_name),
ModelType::Images => self.images_engines.read().checksum(model_name),
ModelType::Prefill => self.prefill_engines.read().checksum(model_name),
_ => {
continue;
......@@ -230,6 +233,16 @@ impl ModelManager {
clients.add(model, card_checksum, engine)
}
pub fn add_images_model(
&self,
model: &str,
card_checksum: &str,
engine: OpenAIImagesStreamingEngine,
) -> Result<(), ModelManagerError> {
let mut clients = self.images_engines.write();
clients.add(model, card_checksum, engine)
}
pub fn add_prefill_model(
&self,
model: &str,
......@@ -259,6 +272,11 @@ impl ModelManager {
clients.remove(model)
}
pub fn remove_images_model(&self, model: &str) -> Result<(), ModelManagerError> {
let mut clients = self.images_engines.write();
clients.remove(model)
}
pub fn remove_prefill_model(&self, model: &str) -> Result<(), ModelManagerError> {
let mut clients = self.prefill_engines.write();
clients.remove(model)
......@@ -308,6 +326,17 @@ impl ModelManager {
.ok_or(ModelManagerError::ModelNotFound(model.to_string()))
}
pub fn get_images_engine(
&self,
model: &str,
) -> Result<OpenAIImagesStreamingEngine, ModelManagerError> {
self.images_engines
.read()
.get(model)
.cloned()
.ok_or(ModelManagerError::ModelNotFound(model.to_string()))
}
/// Save a ModelDeploymentCard from an instance's key so we can fetch it later when the key is
/// deleted.
pub fn save_model_card(&self, key: &str, card: ModelDeploymentCard) -> anyhow::Result<()> {
......
......@@ -38,6 +38,7 @@ use crate::{
},
completions::{NvCreateCompletionRequest, NvCreateCompletionResponse},
embeddings::{NvCreateEmbeddingRequest, NvCreateEmbeddingResponse},
images::{NvCreateImageRequest, NvImagesResponse},
},
tensor::{NvCreateTensorRequest, NvCreateTensorResponse},
},
......@@ -619,6 +620,19 @@ impl ModelWatcher {
let engine = Arc::new(push_router);
self.manager
.add_tensor_model(card.name(), checksum, engine)?;
} else if card.model_input == ModelInput::Text && card.model_type.supports_images() {
// Case: Text + Images (diffusion models)
// Takes text prompts as input, generates images
let push_router = PushRouter::<
NvCreateImageRequest,
Annotated<NvImagesResponse>,
>::from_client_with_threshold(
client, self.router_config.router_mode, None, None
)
.await?;
let engine = Arc::new(push_router);
self.manager
.add_images_model(card.name(), checksum, engine)?;
} else if card.model_type.supports_prefill() {
// Case 6: Prefill
// Guardrail: Verify model_input is Tokens
......@@ -656,7 +670,7 @@ impl ModelWatcher {
// Reject unsupported combinations
anyhow::bail!(
"Unsupported model configuration: {} with {} input. Supported combinations: \
Tokens+(Chat|Completions|Prefill), Text+Chat, Text+Completions, Tokens+Embeddings, Tensor+TensorBased",
Tokens+(Chat|Completions|Prefill), Text+(Chat|Completions|Images), Tokens+Embeddings, Tensor+TensorBased",
card.model_type,
card.model_input.as_str()
);
......
......@@ -12,6 +12,8 @@ pub enum EndpointType {
Completion,
/// Embeddings API
Embedding,
/// Images API (Diffusion/DALL-E)
Images,
/// Responses API
Responses,
}
......@@ -22,6 +24,7 @@ impl EndpointType {
Self::Chat => "chat",
Self::Completion => "completion",
Self::Embedding => "embedding",
Self::Images => "images",
Self::Responses => "responses",
}
}
......@@ -31,6 +34,7 @@ impl EndpointType {
Self::Chat,
Self::Completion,
Self::Embedding,
Self::Images,
Self::Responses,
]
}
......
......@@ -219,6 +219,9 @@ pub enum Endpoint {
/// OAI Embeddings
Embeddings,
/// OAI Images
Images,
/// OAI Responses
Responses,
......@@ -840,6 +843,7 @@ impl std::fmt::Display for Endpoint {
Endpoint::Completions => write!(f, "completions"),
Endpoint::ChatCompletions => write!(f, "chat_completions"),
Endpoint::Embeddings => write!(f, "embeddings"),
Endpoint::Images => write!(f, "images"),
Endpoint::Responses => write!(f, "responses"),
Endpoint::Tensor => write!(f, "tensor"),
}
......@@ -852,6 +856,7 @@ impl Endpoint {
Endpoint::Completions => "completions",
Endpoint::ChatCompletions => "chat_completions",
Endpoint::Embeddings => "embeddings",
Endpoint::Images => "images",
Endpoint::Responses => "responses",
Endpoint::Tensor => "tensor",
}
......
......@@ -49,6 +49,7 @@ use crate::protocols::openai::{
},
completions::{NvCreateCompletionRequest, NvCreateCompletionResponse},
embeddings::{NvCreateEmbeddingRequest, NvCreateEmbeddingResponse},
images::{NvCreateImageRequest, NvImagesResponse},
responses::{NvCreateResponse, NvResponse},
};
use crate::request_template::RequestTemplate;
......@@ -1548,6 +1549,99 @@ pub fn responses_router(
(vec![doc], router)
}
async fn images(
State(state): State<Arc<service_v2::State>>,
headers: HeaderMap,
Json(request): Json<NvCreateImageRequest>,
) -> Result<Response, ErrorResponse> {
// return a 503 if the service is not ready
check_ready(&state)?;
let request_id = get_or_create_request_id(request.inner.user.as_deref(), &headers);
let request = Context::with_id(request, request_id);
let request_id = request.id().to_string();
// Images are typically not streamed, so we default to non-streaming
let streaming = false;
// Get the model name from the request (diffusion model)
let model = request
.inner
.model
.as_ref()
.map(|m| match m {
dynamo_async_openai::types::ImageModel::DallE2 => "dall-e-2".to_string(),
dynamo_async_openai::types::ImageModel::DallE3 => "dall-e-3".to_string(),
dynamo_async_openai::types::ImageModel::Other(s) => s.clone(),
})
.unwrap_or_else(|| "diffusion".to_string());
// Create http_queue_guard early - tracks time waiting to be processed
let http_queue_guard = state.metrics_clone().create_http_queue_guard(&model);
// Get the image generation engine
let engine = state
.manager()
.get_images_engine(&model)
.map_err(|_| ErrorMessage::model_not_found())?;
// this will increment the inflight gauge for the model
let mut inflight =
state
.metrics_clone()
.create_inflight_guard(&model, Endpoint::Images, streaming);
let mut response_collector = state.metrics_clone().create_response_collector(&model);
// Issue the generate call on the engine
// Note: This uses ServerStreamingEngine for internal routing/distribution,
// NOT for client-facing SSE streaming. The stream is immediately folded into
// a single response below.
let stream = engine
.generate(request)
.await
.map_err(|e| ErrorMessage::from_anyhow(e, "Failed to generate images"))?;
// Process stream to collect metrics and drop http_queue_guard on first response
let mut http_queue_guard = Some(http_queue_guard);
let stream = stream.inspect(move |response| {
// Calls observe_response() on each item - drops http_queue_guard on first item
process_response_and_observe_metrics(
response,
&mut response_collector,
&mut http_queue_guard,
);
});
// Images are returned as a single response (non-streaming to client)
// Fold the internal stream into a single response
let response = NvImagesResponse::from_annotated_stream(stream)
.await
.map_err(|e| {
tracing::error!("Failed to fold images stream for {}: {:?}", request_id, e);
ErrorMessage::internal_server_error("Failed to fold images stream")
})?;
inflight.mark_ok();
Ok(Json(response).into_response())
}
/// Create an Axum [`Router`] for the OpenAI API Images endpoint
/// If not path is provided, the default path is `/v1/images/generations`
pub fn images_router(
state: Arc<service_v2::State>,
path: Option<String>,
) -> (Vec<RouteDoc>, Router) {
let path = path.unwrap_or("/v1/images/generations".to_string());
let doc = RouteDoc::new(axum::http::Method::POST, &path);
let router = Router::new()
.route(&path, post(images))
.layer(middleware::from_fn(smart_json_error_middleware))
.layer(axum::extract::DefaultBodyLimit::max(get_body_limit()))
.with_state(state);
(vec![doc], router)
}
#[cfg(test)]
mod tests {
......
......@@ -46,6 +46,7 @@ struct StateFlags {
chat_endpoints_enabled: AtomicBool,
cmpl_endpoints_enabled: AtomicBool,
embeddings_endpoints_enabled: AtomicBool,
images_endpoints_enabled: AtomicBool,
responses_endpoints_enabled: AtomicBool,
}
......@@ -55,6 +56,7 @@ impl StateFlags {
EndpointType::Chat => self.chat_endpoints_enabled.load(Ordering::Relaxed),
EndpointType::Completion => self.cmpl_endpoints_enabled.load(Ordering::Relaxed),
EndpointType::Embedding => self.embeddings_endpoints_enabled.load(Ordering::Relaxed),
EndpointType::Images => self.images_endpoints_enabled.load(Ordering::Relaxed),
EndpointType::Responses => self.responses_endpoints_enabled.load(Ordering::Relaxed),
}
}
......@@ -70,6 +72,9 @@ impl StateFlags {
EndpointType::Embedding => self
.embeddings_endpoints_enabled
.store(enabled, Ordering::Relaxed),
EndpointType::Images => self
.images_endpoints_enabled
.store(enabled, Ordering::Relaxed),
EndpointType::Responses => self
.responses_endpoints_enabled
.store(enabled, Ordering::Relaxed),
......@@ -100,6 +105,7 @@ impl State {
chat_endpoints_enabled: AtomicBool::new(false),
cmpl_endpoints_enabled: AtomicBool::new(false),
embeddings_endpoints_enabled: AtomicBool::new(false),
images_endpoints_enabled: AtomicBool::new(false),
responses_endpoints_enabled: AtomicBool::new(false),
},
cancel_token,
......@@ -509,6 +515,7 @@ impl HttpServiceConfigBuilder {
super::openai::completions_router(state.clone(), var(HTTP_SVC_CMP_PATH_ENV).ok());
let (embed_docs, embed_route) =
super::openai::embeddings_router(state.clone(), var(HTTP_SVC_EMB_PATH_ENV).ok());
let (images_docs, images_route) = super::openai::images_router(state.clone(), None);
let (responses_docs, responses_route) = super::openai::responses_router(
state.clone(),
request_template.clone(),
......@@ -519,6 +526,7 @@ impl HttpServiceConfigBuilder {
endpoint_routes.insert(EndpointType::Chat, (chat_docs, chat_route));
endpoint_routes.insert(EndpointType::Completion, (cmpl_docs, cmpl_route));
endpoint_routes.insert(EndpointType::Embedding, (embed_docs, embed_route));
endpoint_routes.insert(EndpointType::Images, (images_docs, images_route));
endpoint_routes.insert(EndpointType::Responses, (responses_docs, responses_route));
for endpoint_type in EndpointType::all() {
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment