"...git@developer.sourcefind.cn:2222/OpenDAS/vllm_cscc.git" did not exist on "32b14baf8a1f7195ca09484de3008063569b43c5"
Unverified Commit 1b0334c8 authored by GuanLuo's avatar GuanLuo Committed by GitHub
Browse files

feat: add TRTLLM text-to-image support (#8200)


Signed-off-by: default avatarGuan Luo <41310872+GuanLuo@users.noreply.github.com>
Signed-off-by: default avatarGuanLuo <41310872+GuanLuo@users.noreply.github.com>
parent 5135c321
# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
"""Image utilities for image diffusion.
Provides helpers for encoding numpy images to PNG format.
"""
import io
import logging
import numpy as np
logger = logging.getLogger(__name__)
def encode_to_png_bytes(
    image: np.ndarray,
) -> bytes:
    """Encode a numpy image to PNG bytes entirely in memory.

    Args:
        image: Numpy image array of shape (H, W, 3) with uint8 values 0-255.

    Returns:
        PNG-encoded bytes.

    Raises:
        ImportError: If Pillow is not available.
        RuntimeError: If encoding fails, including when ``image`` is not an
            array object that Pillow can convert.
    """
    # Pillow is imported lazily so that importing this module does not
    # require it; the dependency is only needed when encoding is requested.
    try:
        from PIL import Image
    except ImportError:
        raise ImportError(
            "Pillow is required for PNG encoding. " "Install with: pip install Pillow"
        ) from None

    try:
        # The shape lookup lives inside the guarded region so that a
        # non-array argument surfaces as the documented RuntimeError rather
        # than a bare AttributeError.
        logger.info("Encoding image of shape %s to PNG", image.shape)
        buf = io.BytesIO()
        Image.fromarray(image).save(buf, format="PNG")
        image_bytes = buf.getvalue()
        logger.info("Encoded PNG image to %d bytes", len(image_bytes))
        return image_bytes
    except Exception as e:
        logger.error("Failed to encode image to PNG bytes: %s", e)
        raise RuntimeError(f"PNG encoding failed: {e}") from e
......@@ -143,7 +143,7 @@ def parse_args(argv: Optional[Sequence[str]] = None) -> Config:
def _default_endpoint(
namespace: str, modality: Modality, disaggregation_mode: DisaggregationMode
) -> str:
if modality == Modality.VIDEO_DIFFUSION:
if Modality.is_diffusion(modality):
component_name = DEFAULT_DIFFUSION_COMPONENT
elif disaggregation_mode == DisaggregationMode.ENCODE:
component_name = DEFAULT_ENCODE_COMPONENT
......
......@@ -238,65 +238,36 @@ class DynamoTrtllmArgGroup(ArgGroup):
"Options: xgrammar, llguidance.",
)
# --- Diffusion Options ---
self._add_diffusion_arguments(parser)
self._add_diffusion_request_arguments(parser)
def _add_diffusion_arguments(self, parser: argparse.ArgumentParser) -> None:
diffusion_group = parser.add_argument_group(
"Diffusion Options [Experimental]",
"Options for video_diffusion modality",
)
add_argument(
diffusion_group,
flag_name="--default-height",
env_var="DYN_TRTLLM_DEFAULT_HEIGHT",
default=480,
arg_type=int,
help="Default video/image height in pixels.",
)
add_argument(
diffusion_group,
flag_name="--default-width",
env_var="DYN_TRTLLM_DEFAULT_WIDTH",
default=832,
arg_type=int,
help="Default video/image width in pixels.",
"Generic options for diffusion pipeline",
)
add_argument(
diffusion_group,
flag_name="--default-num-frames",
env_var="DYN_TRTLLM_DEFAULT_NUM_FRAMES",
default=81,
arg_type=int,
help="Default number of frames for video generation.",
)
add_argument(
diffusion_group,
flag_name="--default-num-inference-steps",
env_var="DYN_TRTLLM_DEFAULT_NUM_INFERENCE_STEPS",
default=50,
arg_type=int,
help="Default number of inference steps.",
)
add_argument(
diffusion_group,
flag_name="--default-guidance-scale",
env_var="DYN_TRTLLM_DEFAULT_GUIDANCE_SCALE",
default=5.0,
arg_type=float,
help="Default CFG guidance scale.",
)
add_argument(
diffusion_group,
flag_name="--torch-dtype",
env_var="DYN_TRTLLM_TORCH_DTYPE",
default="bfloat16",
choices=["bfloat16", "float16", "float32"],
help="Torch dtype for model loading. bfloat16 recommended for Ampere+ GPUs.",
flag_name="--quant-algo",
env_var="DYN_TRTLLM_QUANT_ALGO",
default=None,
choices=[
"FP8",
"FP8_BLOCK_SCALES",
"NVFP4",
"W4A16_AWQ",
"W4A8_AWQ",
"W8A8_SQ_PER_CHANNEL",
],
help="Quantization algorithm for diffusion models. BF16 weights are quantized on-the-fly during loading.",
)
add_argument(
add_negatable_bool_argument(
diffusion_group,
flag_name="--revision",
env_var="DYN_TRTLLM_REVISION",
default=None,
help="HuggingFace Hub revision (branch, tag, or commit SHA) for model download.",
flag_name="--quant-dynamic",
env_var="DYN_TRTLLM_QUANT_DYNAMIC",
default=True,
help="Enable dynamic weight quantization (quantize BF16 weights on-the-fly during loading).",
)
add_negatable_bool_argument(
diffusion_group,
......@@ -322,33 +293,18 @@ class DynamoTrtllmArgGroup(ArgGroup):
)
add_argument(
diffusion_group,
flag_name="--attn-backend",
env_var="DYN_TRTLLM_ATTN_BACKEND",
default="VANILLA",
choices=["VANILLA", "TRTLLM"],
help="Attention backend for diffusion models. VANILLA = PyTorch SDPA, TRTLLM = TensorRT-LLM kernels.",
flag_name="--torch-dtype",
env_var="DYN_TRTLLM_TORCH_DTYPE",
default="bfloat16",
choices=["bfloat16", "float16", "float32"],
help="Torch dtype for model loading. bfloat16 recommended for Ampere+ GPUs.",
)
add_argument(
diffusion_group,
flag_name="--quant-algo",
env_var="DYN_TRTLLM_QUANT_ALGO",
flag_name="--revision",
env_var="DYN_TRTLLM_REVISION",
default=None,
choices=[
"FP8",
"FP8_BLOCK_SCALES",
"NVFP4",
"W4A16_AWQ",
"W4A8_AWQ",
"W8A8_SQ_PER_CHANNEL",
],
help="Quantization algorithm for diffusion models. BF16 weights are quantized on-the-fly during loading.",
)
add_negatable_bool_argument(
diffusion_group,
flag_name="--quant-dynamic",
env_var="DYN_TRTLLM_QUANT_DYNAMIC",
default=True,
help="Enable dynamic weight quantization (quantize BF16 weights on-the-fly during loading).",
help="HuggingFace Hub revision (branch, tag, or commit SHA) for model download.",
)
add_negatable_bool_argument(
diffusion_group,
......@@ -364,13 +320,6 @@ class DynamoTrtllmArgGroup(ArgGroup):
default=False,
help="Enable torch.compile fullgraph mode (stricter but potentially faster).",
)
add_negatable_bool_argument(
diffusion_group,
flag_name="--fuse-qkv",
env_var="DYN_TRTLLM_FUSE_QKV",
default=True,
help="Enable QKV fusion for transformer attention layers.",
)
add_negatable_bool_argument(
diffusion_group,
flag_name="--enable-cuda-graph",
......@@ -378,19 +327,13 @@ class DynamoTrtllmArgGroup(ArgGroup):
default=False,
help="Enable CUDA graph capture for transformer forward passes. Mutually exclusive with torch.compile.",
)
add_negatable_bool_argument(
diffusion_group,
flag_name="--enable-layerwise-nvtx-marker",
env_var="DYN_TRTLLM_ENABLE_LAYERWISE_NVTX_MARKER",
default=False,
help="Enable per-layer NVTX markers for profiling with Nsight Systems.",
)
add_negatable_bool_argument(
add_argument(
diffusion_group,
flag_name="--skip-warmup",
env_var="DYN_TRTLLM_SKIP_WARMUP",
default=False,
help="Skip warmup inference during initialization.",
flag_name="--attn-backend",
env_var="DYN_TRTLLM_ATTN_BACKEND",
default="VANILLA",
choices=["VANILLA", "TRTLLM"],
help="Attention backend for diffusion models. VANILLA = PyTorch SDPA, TRTLLM = TensorRT-LLM kernels.",
)
add_argument(
diffusion_group,
......@@ -440,6 +383,27 @@ class DynamoTrtllmArgGroup(ArgGroup):
arg_type=int,
help="FSDP size for DiT.",
)
add_negatable_bool_argument(
diffusion_group,
flag_name="--fuse-qkv",
env_var="DYN_TRTLLM_FUSE_QKV",
default=True,
help="Enable QKV fusion for transformer attention layers.",
)
add_negatable_bool_argument(
diffusion_group,
flag_name="--enable-layerwise-nvtx-marker",
env_var="DYN_TRTLLM_ENABLE_LAYERWISE_NVTX_MARKER",
default=False,
help="Enable per-layer NVTX markers for profiling with Nsight Systems.",
)
add_negatable_bool_argument(
diffusion_group,
flag_name="--skip-warmup",
env_var="DYN_TRTLLM_SKIP_WARMUP",
default=False,
help="Skip warmup inference during initialization.",
)
add_negatable_bool_argument(
diffusion_group,
flag_name="--enable-async-cpu-offload",
......@@ -459,6 +423,65 @@ class DynamoTrtllmArgGroup(ArgGroup):
),
)
def _add_diffusion_request_arguments(self, parser: argparse.ArgumentParser) -> None:
    """Register the options that supply default values for generation requests.

    All options are added to a dedicated "Diffusion Request Options" argument
    group on ``parser``. Each one is registered through ``add_argument`` with
    a flag name, an environment variable, a default, a value type and a help
    string.

    Args:
        parser: The argument parser that receives the new argument group.
    """
    # Check TRTLLM's DiffusionRequest for list of fields, note that
    # we only add the fields that can be set in request, otherwise we use
    # TRTLLM's default values by not setting them at all.
    request_defaults_group = parser.add_argument_group(
        "Diffusion Request Options [Experimental]",
        "Options to set default values for video/image generation requests",
    )

    # One row per option: (flag, env var, default, value type, help text).
    # Rows are registered in order, so the help output keeps this ordering.
    option_specs = (
        (
            "--default-height",
            "DYN_TRTLLM_DEFAULT_HEIGHT",
            480,
            int,
            "Default video/image height in pixels.",
        ),
        (
            "--default-width",
            "DYN_TRTLLM_DEFAULT_WIDTH",
            832,
            int,
            "Default video/image width in pixels.",
        ),
        (
            "--default-num-inference-steps",
            "DYN_TRTLLM_DEFAULT_NUM_INFERENCE_STEPS",
            50,
            int,
            "Default number of inference steps.",
        ),
        (
            "--default-guidance-scale",
            "DYN_TRTLLM_DEFAULT_GUIDANCE_SCALE",
            5.0,
            float,
            "Default CFG guidance scale.",
        ),
        # Video specific args
        (
            "--default-num-frames",
            "DYN_TRTLLM_DEFAULT_NUM_FRAMES",
            81,
            int,
            "Default number of frames for video generation.",
        ),
        # Image specific args
        (
            "--default-num-images-per-prompt",
            "DYN_TRTLLM_DEFAULT_NUM_IMAGES_PER_PROMPT",
            1,
            int,
            "Default number of images per prompt for image generation.",
        ),
    )

    for flag, env_name, default_value, value_type, help_text in option_specs:
        add_argument(
            request_defaults_group,
            flag_name=flag,
            env_var=env_name,
            default=default_value,
            arg_type=value_type,
            help=help_text,
        )
class DynamoTrtllmConfig(ConfigBase):
"""Configuration for Dynamo TRT-LLM backend-specific options."""
......@@ -495,6 +518,7 @@ class DynamoTrtllmConfig(ConfigBase):
default_height: int
default_width: int
default_num_frames: int
default_num_images_per_prompt: int
default_num_inference_steps: int
default_guidance_scale: float
torch_dtype: str
......
......@@ -16,8 +16,9 @@ Fields map to TensorRT-LLM's VisualGenArgs sub-configs:
- QuantConfig: quantization algorithm and dynamic flags
"""
import dataclasses
from dataclasses import dataclass, field
from typing import Optional
from typing import Any, Optional
from dynamo.common.utils.namespace import get_worker_namespace
......@@ -64,6 +65,7 @@ class DiffusionConfig:
max_height: int = 4096
max_width: int = 4096
default_num_frames: int = 81
default_num_images_per_prompt: int = 1
default_fps: int = 24 # Used for both frame count calculation and video encoding
default_seconds: int = 4 # Default video duration when only fps is specified
default_num_inference_steps: int = 50
......@@ -117,6 +119,22 @@ class DiffusionConfig:
# "scheduler", "image_encoder", "image_processor"
skip_components: list[str] = field(default_factory=list)
@classmethod
def from_config(cls, config: Any, skip_components: list[str]) -> "DiffusionConfig":
    """Create an instance by copying same-named attributes from a worker Config.

    Every dataclass field of ``cls`` that also exists as an attribute on
    ``config`` is copied over. Fields without a counterpart on ``config``
    are left out so the dataclass defaults apply.

    Special cases:
    - model_path ← config.model (field name differs)
    - skip_components ← pre-parsed list (Config holds a raw comma-separated string)
    - max_height, max_width, default_fps, default_seconds use DiffusionConfig defaults
    (they are not exposed as CLI args in Config)

    Args:
        config: Source object whose attributes are read by field name.
        skip_components: Already-parsed list stored as ``skip_components``.

    Returns:
        A new instance of ``cls``.
    """
    absent = object()  # sentinel marking "config has no such attribute"
    kwargs = {}
    for spec in dataclasses.fields(cls):
        value = getattr(config, spec.name, absent)
        if value is not absent:
            kwargs[spec.name] = value

    # These two always override whatever the generic copy picked up.
    kwargs["model_path"] = config.model
    kwargs["skip_components"] = skip_components
    return cls(**kwargs)
def __str__(self) -> str:
return (
f"DiffusionConfig("
......@@ -129,6 +147,7 @@ class DiffusionConfig:
f"default_height={self.default_height}, "
f"default_width={self.default_width}, "
f"default_num_frames={self.default_num_frames}, "
f"default_num_images_per_prompt={self.default_num_images_per_prompt}, "
f"default_num_inference_steps={self.default_num_inference_steps}, "
f"enable_teacache={self.enable_teacache}, "
f"attn_backend={self.attn_backend}, "
......
......@@ -25,12 +25,13 @@ class Modality(Enum):
- TEXT: Text-only LLM (generates text tokens)
- MULTIMODAL: Vision-language LLM (understands images, generates text)
- VIDEO_DIFFUSION: Video generation from text (generates video files)
- IMAGE_DIFFUSION: Image generation from text (generates image files)
"""
TEXT = "text"
MULTIMODAL = "multimodal"
VIDEO_DIFFUSION = "video_diffusion"
# TODO: Add IMAGE_DIFFUSION support in follow-up PR
IMAGE_DIFFUSION = "image_diffusion"
@classmethod
def is_diffusion(cls, modality: "Modality") -> bool:
......@@ -42,7 +43,7 @@ class Modality(Enum):
Returns:
True if the modality is VIDEO_DIFFUSION.
"""
return modality == cls.VIDEO_DIFFUSION
return modality in (cls.VIDEO_DIFFUSION, cls.IMAGE_DIFFUSION)
@classmethod
def is_llm(cls, modality: "Modality") -> bool:
......
......@@ -39,6 +39,7 @@ if TYPE_CHECKING:
logger = logging.getLogger(__name__)
# Modalities that are currently supported
class DiffusionModality(str, Enum):
"""Output modality of a diffusion pipeline."""
......@@ -49,11 +50,18 @@ class DiffusionModality(str, Enum):
# Explicit mapping from TRT-LLM pipeline class names to their output modality.
# This replaces brittle substring matching and must be updated when new
# pipelines are registered in TRT-LLM's PIPELINE_REGISTRY.
# The list of supported pipelines can be found by searching @register_pipeline in TRT-LLM's visual_gen module
_PIPELINE_MODALITY_MAP: dict[str, DiffusionModality] = {
# Text-to-Video pipelines
"WanPipeline": DiffusionModality.VIDEO,
"LTX2Pipeline": DiffusionModality.VIDEO,
"LTX2TwoStagesPipeline": DiffusionModality.VIDEO,
# [gluo FIXME] Image-to-Video pipelines, should get it for free from
# text-to-video support once we connect image_reference.
"WanImageToVideoPipeline": DiffusionModality.VIDEO,
# Text-to-Image pipelines
"FluxPipeline": DiffusionModality.IMAGE,
"LTX2Pipeline": DiffusionModality.VIDEO,
"Flux2Pipeline": DiffusionModality.IMAGE,
}
# Default when the pipeline is not yet loaded or the class name is unknown.
......@@ -213,6 +221,7 @@ class DiffusionEngine:
height: int = 480,
width: int = 832,
num_frames: int = 81,
num_images_per_prompt: int = 1,
num_inference_steps: int = 50,
guidance_scale: float = 5.0,
seed: Optional[int] = None,
......@@ -231,6 +240,7 @@ class DiffusionEngine:
height: Output height in pixels.
width: Output width in pixels.
num_frames: Number of frames to generate (for video).
num_images_per_prompt: Number of images to generate per prompt (for image).
num_inference_steps: Number of denoising steps.
guidance_scale: CFG guidance scale.
seed: Random seed for reproducibility.
......@@ -265,6 +275,7 @@ class DiffusionEngine:
height=height,
width=width,
num_frames=num_frames,
num_images_per_prompt=num_images_per_prompt,
num_inference_steps=num_inference_steps,
guidance_scale=guidance_scale,
seed=seed if seed is not None else random.randint(0, 2**32 - 1),
......
# SPDX-FileCopyrightText: Copyright (c) 2025-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
"""Video diffusion request handlers for TensorRT-LLM backend.
"""Diffusion request handlers for TensorRT-LLM backend.
This module provides handlers for video generation using diffusion models.
This module provides handlers for image and video generation using diffusion models.
"""
from dynamo.trtllm.request_handlers.video_diffusion.video_handler import (
from dynamo.trtllm.request_handlers.diffusion.image_handler import (
ImageGenerationHandler,
)
from dynamo.trtllm.request_handlers.diffusion.video_handler import (
VideoGenerationHandler,
)
__all__ = ["VideoGenerationHandler"]
__all__ = ["ImageGenerationHandler", "VideoGenerationHandler"]
# SPDX-FileCopyrightText: Copyright (c) 2025-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
"""Image generation request handler for TensorRT-LLM backend.
This handler processes image generation requests using diffusion models.
It handles MediaOutput from TensorRT-LLM's visual_gen pipelines, which
can contain video, image, and/or audio tensors depending on the model.
"""
import asyncio
import base64
import logging
import time
import uuid
from typing import Any, AsyncGenerator, Optional
from dynamo._core import Context
from dynamo.common.protocols.image_protocol import (
ImageData,
ImageNvExt,
NvCreateImageRequest,
NvImagesResponse,
)
from dynamo.common.storage import get_fs, upload_to_fs
from dynamo.common.utils.image_utils import encode_to_png_bytes
from dynamo.trtllm.configs.diffusion_config import DiffusionConfig
from dynamo.trtllm.engines.diffusion_engine import DiffusionEngine
from dynamo.trtllm.request_handlers.base_generative_handler import BaseGenerativeHandler
logger = logging.getLogger(__name__)
class ImageGenerationHandler(BaseGenerativeHandler):
    """Handler for image generation requests.

    This handler receives generation requests, runs the diffusion pipeline
    via DiffusionEngine, encodes the resulting image to PNG, and returns
    either a URL to the uploaded file or base64-encoded data.

    Handling of the pipeline's MediaOutput fields:
        - image: expected to be a (B, H, W, 3) tensor; the first image of
          the batch is PNG-encoded and returned
        - video: unsupported here, raises RuntimeError (use a video handler)
        - audio: unsupported here, raises RuntimeError

    Inherits from BaseGenerativeHandler to share the common interface with
    LLM handlers.
    """

    def __init__(
        self,
        engine: DiffusionEngine,
        config: DiffusionConfig,
    ):
        """Initialize the handler.

        Args:
            engine: The DiffusionEngine instance.
            config: Diffusion generation configuration.

        Raises:
            ValueError: If ``config.media_output_fs_url`` is not set.
        """
        self.engine = engine
        self.config = config

        if not config.media_output_fs_url:
            raise ValueError(
                "media_output_fs_url must be set; use --media-output-fs-url or DYN_MEDIA_OUTPUT_FS_URL."
            )
        self.media_output_fs = get_fs(config.media_output_fs_url)
        self.media_output_http_url = config.media_output_http_url

        # Serialize pipeline access — the diffusion pipeline is not thread-safe
        # (mutable instance state, unprotected CUDA graph cache).
        # asyncio.Lock suspends waiting coroutines cooperatively so the event
        # loop stays free for health checks and signal handling.
        self._generate_lock = asyncio.Lock()

    def _parse_size(self, size: Optional[str]) -> tuple[int, int]:
        """Parse 'WxH' string to (width, height) tuple.

        The API accepts size as a string (e.g., "832x480") to match the format
        used by OpenAI's image generation API (/v1/images/generations).
        This method converts that string to a (width, height) tuple for the engine.

        A missing/empty size, or one that cannot be parsed, falls back to the
        configured default width and height (a warning is logged for the
        unparsable case).

        Args:
            size: Size string in 'WxH' format (e.g., '832x480').

        Returns:
            Tuple of (width, height).

        Raises:
            ValueError: If the resulting dimensions are outside
                [1, max_width] / [1, max_height].
        """
        if not size:
            width, height = self.config.default_width, self.config.default_height
        else:
            try:
                w, h = size.split("x")
                width, height = int(w), int(h)
            except (ValueError, AttributeError):
                logger.warning(f"Invalid size format: {size}, using defaults")
                width, height = self.config.default_width, self.config.default_height

        # Validate dimensions to prevent OOM
        self._validate_dimensions(width, height)
        return width, height

    def _validate_dimensions(self, width: int, height: int) -> None:
        """Validate that dimensions fall inside the configured limits.

        Args:
            width: Requested width in pixels.
            height: Requested height in pixels.

        Raises:
            ValueError: If width is outside [1, max_width] or height is
                outside [1, max_height]. Both problems are reported in a
                single message when both apply.
        """
        errors = []
        if not (1 <= width <= self.config.max_width):
            errors.append(f"width {width} must be in [1, {self.config.max_width}]")
        if not (1 <= height <= self.config.max_height):
            errors.append(f"height {height} must be in [1, {self.config.max_height}]")
        if errors:
            raise ValueError(
                f"Requested dimensions out of range: {', '.join(errors)}. "
                f"This is a safety check to prevent out-of-memory errors. "
                f"To allow larger sizes, increase --max-width and/or --max-height."
            )

    async def generate(
        self, request: dict[str, Any], context: Context
    ) -> AsyncGenerator[dict[str, Any], None]:
        """Generate an image from a request.

        This is the main entry point called by Dynamo's endpoint.serve_endpoint().

        Handles MediaOutput from the pipeline:
        - image tensor → PNG
        - video tensor → unsupported (raises error)
        - audio tensor → unsupported (raises error)

        Args:
            request: Request dictionary with generation parameters.
            context: Dynamo context for request tracking.

        Yields:
            Response dictionary with generated media data.

        Raises:
            ValueError: If the requested size is out of range or ``n`` is
                anything other than 1.
            RuntimeError: If the pipeline returns no output, no image, or an
                image tensor of unexpected shape.
        """
        # A monotonic clock is used for the duration so that wall-clock
        # adjustments cannot skew the reported time.
        start_time = time.monotonic()
        request_id = str(uuid.uuid4())

        logger.debug(f"Received generation request: {request_id}")

        # Parse request
        req = NvCreateImageRequest(**request)
        nvext = req.nvext or ImageNvExt()

        # Parse parameters
        width, height = self._parse_size(req.size)

        # Only a single image per request is supported, so every explicit
        # value other than 1 (including 0 and negatives) is rejected.
        if req.n is not None and req.n != 1:
            raise ValueError(
                f"Requested {req.n} images, but this handler currently supports n=1 only."
            )
        num_images_per_prompt = (
            req.n if req.n is not None else self.config.default_num_images_per_prompt
        )

        num_inference_steps = (
            nvext.num_inference_steps
            if nvext.num_inference_steps is not None
            else self.config.default_num_inference_steps
        )
        guidance_scale = (
            nvext.guidance_scale
            if nvext.guidance_scale is not None
            else self.config.default_guidance_scale
        )

        logger.debug(
            f"Request {request_id}: prompt='{req.prompt[:50]}...', "
            f"size={width}x{height}, images={num_images_per_prompt}, steps={num_inference_steps}"
        )

        # Run generation in thread pool (blocking operation).
        # Lock ensures only one request uses the pipeline at a time.
        # TODO: Add cancellation support. This requires:
        #   1. The pipeline to expose a cancellation hook in the denoising loop
        #   2. Passing a cancellation token/event to engine.generate()
        #   3. Checking context.cancelled() and propagating to the pipeline
        async with self._generate_lock:
            output = await asyncio.to_thread(
                self.engine.generate,
                prompt=req.prompt,
                negative_prompt=nvext.negative_prompt,
                height=height,
                width=width,
                num_images_per_prompt=num_images_per_prompt,
                num_inference_steps=num_inference_steps,
                guidance_scale=guidance_scale,
                seed=nvext.seed,
            )

        if output is None:
            raise RuntimeError("Pipeline returned no output (MediaOutput is None)")

        # Determine output format
        response_format = req.response_format or "url"

        # Encode media based on what the pipeline returned
        if output.image is not None:
            # MediaOutput.image is (B, H, W, C) uint8 since TRT-LLM rc9;
            images = output.image
            # Explicit check instead of ``assert`` so it still runs under
            # ``python -O`` and an empty batch fails with a clear message
            # rather than an IndexError on ``images[0]``.
            if images.ndim != 4 or images.shape[3] != 3 or images.shape[0] < 1:
                raise RuntimeError(
                    "Expected image shape (B, H, W, C) with B >= 1 and C == 3, "
                    f"got {images.shape}"
                )
            # [gluo FIXME] currently only take the first image but the protocol supports multiple images
            # verify if TRT-LLM will generate multiple images, relax this constraint if that's the case
            image_np = images[0].cpu().numpy()
            logger.debug(
                f"Request {request_id}: encoding image output "
                f"(shape={image_np.shape}) to PNG"
            )
            image_bytes = await asyncio.to_thread(encode_to_png_bytes, image_np)
        elif output.video is not None:
            raise RuntimeError(
                "Pipeline returned video-only output, but this handler "
                "only supports image. Use a video generation handler instead."
            )
        elif output.audio is not None:
            # Audio-only output is not supported by this handler either.
            raise RuntimeError(
                "Pipeline returned audio-only output, but this handler "
                "only supports image. Use an audio generation handler instead."
            )
        else:
            raise RuntimeError(
                "Pipeline returned MediaOutput with no video or image or audio data. "
                f"MediaOutput fields: video={output.video is not None}, "
                f"image={output.image is not None}, audio={output.audio is not None}"
            )

        # Return media via URL or base64
        if response_format == "url":
            storage_path = f"images/{request_id}.png"
            image_url = await upload_to_fs(
                self.media_output_fs,
                storage_path,
                image_bytes,
                self.media_output_http_url,
            )
            image_data = ImageData(url=image_url)
        else:
            b64_image = base64.b64encode(image_bytes).decode("utf-8")
            image_data = ImageData(b64_json=b64_image)

        # Covers parsing, generation, encoding and upload/base64 conversion.
        inference_time = time.monotonic() - start_time

        response = NvImagesResponse(
            created=int(time.time()),
            data=[image_data],
        )

        logger.debug(f"Request {request_id} completed in {inference_time:.2f}s")

        yield response.model_dump()

    def cleanup(self) -> None:
        """Cleanup handler resources."""
        logger.info("ImageGenerationHandler cleanup")
# SPDX-FileCopyrightText: Copyright (c) 2025-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
"""Unit tests for video diffusion components.
Tests for ImageGenerationHandler helpers, image protocol types,
and concurrency safety.
These tests do NOT require visual_gen or a GPU - they test logic only.
"""
import asyncio
import threading
import time
from types import SimpleNamespace
from unittest.mock import MagicMock, patch
import pytest
import torch
from dynamo.common.protocols.image_protocol import (
ImageData,
ImageNvExt,
NvCreateImageRequest,
NvImagesResponse,
)
from dynamo.trtllm.configs.diffusion_config import DiffusionConfig
pytestmark = [
pytest.mark.unit,
pytest.mark.trtllm,
pytest.mark.pre_merge,
pytest.mark.gpu_0,
]
# [gluo FIXME] many parts of the test are validated as part of test_trtllm_video_diffusion.py,
# we should have common test suite for diffusion and additional tests for different modalities.
# =============================================================================
# Part 1: Modality Enum Tests
# =============================================================================
# This part of the test has been covered in test_trtllm_video_diffusion.py
# =============================================================================
# Part 2: DiffusionConfig Tests
# =============================================================================
# This part of the test has been covered in test_trtllm_video_diffusion.py
# =============================================================================
# Part 3: ImageGenerationHandler Helper Tests
# =============================================================================
class MockDiffusionConfig:
    """Lightweight stand-in for DiffusionConfig used by the handler helper tests.

    Only plain class attributes are defined, so an instance can be assigned
    to a handler's ``config`` without building a full DiffusionConfig.
    """

    # Upper bounds for requested dimensions.
    max_width: int = 4096
    max_height: int = 4096

    # Values used when a request does not specify its own.
    default_width: int = 832
    default_height: int = 480
    default_num_frames: int = 81
    default_num_images_per_prompt: int = 1
    default_fps: int = 24
    default_seconds: int = 4
class TestImageHandlerParseSize:
    """Exercise ImageGenerationHandler._parse_size on a bare handler instance.

    The handler is created without running ``__init__``; only the ``config``
    attribute is populated, using MockDiffusionConfig.
    """

    def setup_method(self):
        """Build a handler that carries nothing but a mock config."""
        # Import here to avoid issues if handler has complex imports
        from dynamo.trtllm.request_handlers.diffusion.image_handler import (
            ImageGenerationHandler,
        )

        # Bypass __init__ so no engine or storage backend is needed.
        handler = object.__new__(ImageGenerationHandler)
        handler.config = MockDiffusionConfig()
        self.handler = handler

    def test_parse_size_valid(self):
        """A well-formed 'WxH' string yields (width, height)."""
        parsed_width, parsed_height = self.handler._parse_size("832x480")
        assert parsed_width == 832
        assert parsed_height == 480

    def test_parse_size_different_dimensions(self):
        """Several distinct dimension strings parse to the matching tuples."""
        expectations = (
            ("1920x1080", (1920, 1080)),
            ("640x360", (640, 360)),
            ("1x1", (1, 1)),
        )
        for text, expected in expectations:
            assert self.handler._parse_size(text) == expected

    def test_parse_size_none(self):
        """None falls back to the configured defaults."""
        parsed_width, parsed_height = self.handler._parse_size(None)
        assert parsed_width == MockDiffusionConfig.default_width
        assert parsed_height == MockDiffusionConfig.default_height

    def test_parse_size_empty_string(self):
        """An empty string falls back to the configured defaults."""
        parsed_width, parsed_height = self.handler._parse_size("")
        assert parsed_width == MockDiffusionConfig.default_width
        assert parsed_height == MockDiffusionConfig.default_height

    def test_parse_size_invalid_format(self):
        """Unparsable strings fall back to the defaults (a warning is logged)."""
        malformed_inputs = (
            "832480",  # No 'x' separator
            "832",  # Only one number
            "widthxheight",  # Non-numeric
            "832x",  # Trailing 'x'
        )
        for text in malformed_inputs:
            assert self.handler._parse_size(text) == (832, 480)

    def test_parse_size_exceeds_max_width(self):
        """A width above max_width is rejected with ValueError."""
        with pytest.raises(ValueError) as exc_info:
            self.handler._parse_size("5000x480")
        message = str(exc_info.value)
        assert "width 5000 must be in [1, 4096]" in message
        assert "safety check to prevent out-of-memory" in message

    def test_parse_size_exceeds_max_height(self):
        """A height above max_height is rejected with ValueError."""
        with pytest.raises(ValueError) as exc_info:
            self.handler._parse_size("832x5000")
        message = str(exc_info.value)
        assert "height 5000 must be in [1, 4096]" in message

    def test_parse_size_exceeds_both_dimensions(self):
        """When both dimensions are too large, both are named in the error."""
        with pytest.raises(ValueError) as exc_info:
            self.handler._parse_size("10000x10000")
        message = str(exc_info.value)
        assert "width 10000 must be in [1, 4096]" in message
        assert "height 10000 must be in [1, 4096]" in message

    def test_parse_size_at_max_boundary(self):
        """Dimensions exactly equal to the maximum are accepted."""
        # Should not raise - exactly at limit
        parsed_width, parsed_height = self.handler._parse_size("4096x4096")
        assert parsed_width == 4096
        assert parsed_height == 4096
# =============================================================================
# Part 4: Image Protocol Tests
# =============================================================================
class TestNvCreateImageRequest:
    """Checks for the NvCreateImageRequest protocol model."""

    def test_required_fields(self):
        """prompt and model are stored as given."""
        model_name = "black-forest-labs/FLUX.1-dev"
        request = NvCreateImageRequest(prompt="A cat", model=model_name)
        assert request.prompt == "A cat"
        assert request.model == "black-forest-labs/FLUX.1-dev"

    def test_required_fields_missing_prompt(self):
        """Omitting prompt fails validation."""
        with pytest.raises(Exception):  # Pydantic ValidationError
            NvCreateImageRequest(model="black-forest-labs/FLUX.1-dev")  # type: ignore

    def test_optional_fields_default_none(self):
        """Every optional field is None when only prompt is supplied."""
        request = NvCreateImageRequest(prompt="A cat")
        # [gluo NOTE] in protocol the model is optional, but actually Dynamo will
        # fill with default value if not provided.
        optional_fields = (
            "model",
            "n",
            "quality",
            "style",
            "user",
            "moderation",
            "input_reference",
            "size",
            "response_format",
            "nvext",
        )
        for field_name in optional_fields:
            assert getattr(request, field_name) is None

    def test_full_request_valid(self):
        """A fully populated request, including nvext, keeps all its values."""
        extension = ImageNvExt(
            annotations=["tag1", "tag2"],
            negative_prompt="blurry, low quality",
            num_inference_steps=30,
            guidance_scale=7.5,
            seed=42,
        )
        request = NvCreateImageRequest(
            prompt="A majestic lion",
            model="black-forest-labs/FLUX.1-dev",
            size="1920x1080",
            response_format="b64_json",
            nvext=extension,
        )

        # Top-level fields
        assert request.prompt == "A majestic lion"
        assert request.model == "black-forest-labs/FLUX.1-dev"
        assert request.size == "1920x1080"
        assert request.response_format == "b64_json"

        # Nested extension fields
        stored_ext = request.nvext
        assert stored_ext.annotations == ["tag1", "tag2"]
        assert stored_ext.negative_prompt == "blurry, low quality"
        assert stored_ext.num_inference_steps == 30
        assert stored_ext.guidance_scale == 7.5
        assert stored_ext.seed == 42
class TestImageData:
    """Checks for the ImageData protocol model."""

    def test_url_only(self):
        """Only url is set; b64_json stays None."""
        item = ImageData(url="/tmp/image.png")
        assert item.url == "/tmp/image.png"
        assert item.b64_json is None

    def test_b64_only(self):
        """Only b64_json is set; url stays None."""
        item = ImageData(b64_json="SGVsbG8gV29ybGQ=")
        assert item.url is None
        assert item.b64_json == "SGVsbG8gV29ybGQ="

    def test_both_fields(self):
        """Supplying both fields is unusual but accepted."""
        item = ImageData(url="/tmp/image.png", b64_json="SGVsbG8=")
        assert item.url == "/tmp/image.png"
        assert item.b64_json == "SGVsbG8="

    def test_empty_defaults(self):
        """With no arguments both fields default to None."""
        item = ImageData()
        assert item.url is None
        assert item.b64_json is None
class TestNvImagesResponse:
    """Checks for the NvImagesResponse protocol model."""

    def test_default_values(self):
        """A response built with only ``created`` has an empty data list."""
        result = NvImagesResponse(
            created=1234567890,
        )
        assert result.created == 1234567890
        assert result.data == []

    def test_with_image_data(self):
        """A supplied ImageData entry is kept in ``data``."""
        entry = ImageData(url="/tmp/output.png")
        result = NvImagesResponse(
            created=1234567890,
            data=[entry],
        )
        assert len(result.data) == 1
        assert result.data[0].url == "/tmp/output.png"

    def test_model_dump(self):
        """model_dump() produces a plain dict mirroring the response."""
        result = NvImagesResponse(
            id="req-123",
            created=1234567890,
            data=[ImageData(url="/tmp/image.png")],
        )
        as_dict = result.model_dump()
        assert isinstance(as_dict, dict)
        assert as_dict["created"] == 1234567890
        assert len(as_dict["data"]) == 1
        assert as_dict["data"][0]["url"] == "/tmp/image.png"
# =============================================================================
# Part 5: DiffusionEngine Unit Tests
# =============================================================================
# This part of the test has been covered in test_trtllm_video_diffusion.py
# =============================================================================
# Part 6: Concurrency Safety Tests
# =============================================================================
# [gluo NOTE] This part has been covered in test_trtllm_video_diffusion.py,
# but a sanity check with image generation is still needed because the handler
# is different. Could be merged once a base DiffusionHandler is introduced.
class ConcurrencyTracker:
    """Mock replacement for ``DiffusionEngine.generate()`` that records
    the peak number of threads executing it simultaneously.

    What it mocks:
        ``engine.generate(**kwargs)`` — the blocking call made by
        ``ImageGenerationHandler``. Per the handler tests in this file,
        the handler dispatches it via ``asyncio.to_thread()``, so each
        request runs ``generate()`` in a separate OS thread.

    What it focuses on:
        Detecting *concurrent* entry into ``generate()``. It does NOT
        test correctness of generated images, GPU memory, or CUDA
        streams — only whether multiple threads overlap inside the call.

    How it works:
        1. On entry: under ``_lock``, increment ``_active_count`` and
           update the high-water mark ``max_concurrent``.
        2. Sleep for ``sleep_seconds`` to hold the thread inside the
           function, creating a window where other threads *would*
           overlap if nothing serializes them.
        3. On exit: under ``_lock``, decrement ``_active_count`` — even
           if the sleep raised, so a failed call cannot inflate later
           readings.

    After the test, inspect ``max_concurrent``:
        - 1  → accesses were serialized (lock is working).
        - >1 → concurrent access occurred (lock is missing/broken).
    """

    def __init__(self, sleep_seconds: float = 0.1):
        """Create a tracker.

        Args:
            sleep_seconds: How long each ``generate()`` call blocks, in
                seconds; passed straight to ``time.sleep()``.
        """
        self._active_count = 0
        self._lock = threading.Lock()
        self.max_concurrent = 0
        self.sleep_seconds = sleep_seconds

    def generate(self, **kwargs):
        """Mock ``engine.generate()`` that tracks concurrent access.

        Args:
            **kwargs: Accepted and ignored, so the handler can pass its
                usual generation arguments.

        Returns:
            A ``SimpleNamespace`` with ``video=None``, ``audio=None`` and
            ``image`` set to a zero ``uint8`` tensor of shape
            ``(1, 64, 64, 3)``.
        """
        with self._lock:
            self._active_count += 1
            self.max_concurrent = max(self.max_concurrent, self._active_count)

        try:
            # Hold the thread here to widen the overlap window. Without
            # serialization, other threads will enter generate() during
            # this sleep and bump _active_count above 1.
            time.sleep(self.sleep_seconds)
        finally:
            # Always release the slot, otherwise one failed call would
            # leave the counter permanently elevated.
            with self._lock:
                self._active_count -= 1

        # Mock output object carrying only an image tensor.
        return SimpleNamespace(
            video=None,
            image=torch.zeros((1, 64, 64, 3), dtype=torch.uint8),
            audio=None,
        )
class TestVideoHandlerConcurrency:
    """Checks that ``ImageGenerationHandler`` serializes calls into the
    underlying ``engine.generate()``.

    Note: despite the class name, the handler under test here is
    ``ImageGenerationHandler``.

    Why this matters:
        The visual_gen pipeline is a global singleton with mutable state,
        unprotected CUDA graph caches, and shared config objects. It is
        NOT thread-safe. ``ImageGenerationHandler`` dispatches generate()
        via ``asyncio.to_thread()``, which runs each request in a
        separate OS thread. Without an ``asyncio.Lock`` guarding the
        call, concurrent requests would enter generate() simultaneously
        and corrupt shared pipeline state.

    How the test works:
        1. A ``ConcurrencyTracker`` stands in for the engine, so every
           generate() call sleeps long enough for overlapping threads
           to be observable.
        2. N requests are fired concurrently with ``asyncio.gather()``;
           each goes ``handler.generate()`` → ``asyncio.to_thread()``
           → ``tracker.generate()``.
        3. The test asserts ``tracker.max_concurrent == 1``: only one
           thread was inside generate() at any point.

    Why it works:
        - ``asyncio.gather()`` schedules all coroutines on the same
          event loop, so they reach ``asyncio.to_thread()`` nearly
          simultaneously.
        - Without the handler's ``asyncio.Lock``, each coroutine
          immediately spawns a thread and those threads overlap inside
          ``tracker.generate()`` during the sleep window →
          ``max_concurrent > 1``.
        - With the lock, only one coroutine is inside the
          ``async with self._generate_lock`` block at a time while the
          others suspend cooperatively on the event loop, so only one
          thread is ever inside generate() → ``max_concurrent == 1``.
    """

    def _make_handler(self):
        """Build an ``ImageGenerationHandler`` backed by a tracking mock engine.

        Returns:
            A ``(handler, tracker)`` tuple.
        """
        from dynamo.trtllm.request_handlers.diffusion.image_handler import (
            ImageGenerationHandler,
        )

        tracker = ConcurrencyTracker(sleep_seconds=0.1)
        engine = MagicMock()
        engine.generate = tracker.generate

        diffusion_config = DiffusionConfig(
            media_output_fs_url="file:///tmp/test_media",
            default_fps=24,
            default_seconds=4,
        )

        # get_fs is patched only for the duration of handler construction.
        fs_patch = patch(
            "dynamo.trtllm.request_handlers.diffusion.image_handler.get_fs",
            return_value=MagicMock(),
        )
        with fs_patch:
            handler = ImageGenerationHandler(engine=engine, config=diffusion_config)

        return handler, tracker

    def _make_request(self):
        """Return a minimal image generation request dict."""
        return dict(prompt="a test image", model="test-model")

    async def _drain_generator(self, handler, request):
        """Consume every item yielded by ``handler.generate()``."""
        stream = handler.generate(request, MagicMock())
        async for _chunk in stream:
            continue

    @pytest.mark.timeout(5)
    def test_concurrent_requests_are_serialized(self):
        """Fire 3 requests at once and require ``max_concurrent == 1``.

        If the asyncio.Lock in ImageGenerationHandler is removed, the 3
        asyncio.to_thread() calls run in parallel OS threads, overlap
        inside the tracker's sleep window, and max_concurrent rises to 3.
        """

        async def fire_all():
            handler, tracker = self._make_handler()
            payloads = [self._make_request() for _ in range(3)]

            encode_patch = patch(
                "dynamo.trtllm.request_handlers.diffusion.image_handler.encode_to_png_bytes",
                return_value=b"fake_image_bytes",
            )
            upload_patch = patch(
                "dynamo.trtllm.request_handlers.diffusion.image_handler.upload_to_fs",
                return_value="http://fake/image.png",
            )
            with encode_patch, upload_patch:
                pending = [self._drain_generator(handler, item) for item in payloads]
                await asyncio.gather(*pending)

            return tracker

        tracker = asyncio.run(fire_all())

        assert tracker.max_concurrent == 1, (
            f"Expected max_concurrent=1 (serialized), got {tracker.max_concurrent}. "
            "Pipeline was accessed concurrently — this would corrupt visual_gen state."
        )
# =============================================================================
# Part 7: ImageGenerationHandler Response Format Tests
# =============================================================================
class TestImageHandlerResponseFormats:
    """Covers how ``ImageGenerationHandler.generate()`` builds its response
    for each ``response_format`` and what happens when the engine fails."""

    @staticmethod
    async def _collect(handler, request):
        """Drain ``handler.generate()`` and return the yielded items as a list."""
        collected = []
        async for chunk in handler.generate(request, MagicMock()):
            collected.append(chunk)
        return collected

    def _make_handler(self):
        """Build a handler whose engine returns a fixed image and whose fs is mocked."""
        from dynamo.trtllm.request_handlers.diffusion.image_handler import (
            ImageGenerationHandler,
        )

        engine = MagicMock()
        engine.generate = MagicMock(
            return_value=SimpleNamespace(
                video=None,
                image=torch.zeros((1, 64, 64, 3), dtype=torch.uint8),
                audio=None,
            )
        )

        diffusion_config = DiffusionConfig(
            media_output_fs_url="file:///tmp/test_media",
            media_output_http_url="https://cdn.example.com/media",
            default_fps=24,
            default_seconds=4,
        )

        fs_patch = patch(
            "dynamo.trtllm.request_handlers.diffusion.image_handler.get_fs",
            return_value=MagicMock(),
        )
        with fs_patch:
            return ImageGenerationHandler(engine=engine, config=diffusion_config)

    @pytest.mark.asyncio
    async def test_url_response_format(self):
        """``response_format="url"`` uploads via ``upload_to_fs`` and returns its URL."""
        handler = self._make_handler()
        payload = {
            "prompt": "a test image",
            "model": "test-model",
            "response_format": "url",
        }
        expected_url = "https://cdn.example.com/media/images/test.png"

        with patch(
            "dynamo.trtllm.request_handlers.diffusion.image_handler.encode_to_png_bytes",
            return_value=b"fake_image_bytes",
        ), patch(
            "dynamo.trtllm.request_handlers.diffusion.image_handler.upload_to_fs",
            return_value=expected_url,
        ) as upload_mock:
            outputs = await self._collect(handler, payload)

        assert len(outputs) == 1
        body = outputs[0]
        assert len(body["data"]) == 1
        assert body["data"][0]["url"] == expected_url
        upload_mock.assert_called_once()

    @pytest.mark.asyncio
    async def test_b64_response_format(self):
        """``response_format="b64_json"`` returns the encoded bytes as base64, with no URL."""
        handler = self._make_handler()
        payload = {
            "prompt": "a test image",
            "model": "test-model",
            "response_format": "b64_json",
        }

        with patch(
            "dynamo.trtllm.request_handlers.diffusion.image_handler.encode_to_png_bytes",
            return_value=b"fake_image_bytes",
        ):
            outputs = await self._collect(handler, payload)

        assert len(outputs) == 1
        body = outputs[0]
        assert len(body["data"]) == 1
        entry = body["data"][0]
        assert entry["b64_json"] is not None
        assert entry.get("url") is None

        # The payload must round-trip through base64 to the mocked PNG bytes.
        import base64

        assert base64.b64decode(entry["b64_json"]) == b"fake_image_bytes"

    @pytest.mark.asyncio
    async def test_default_response_format_is_url(self):
        """Omitting ``response_format`` behaves like ``"url"``."""
        handler = self._make_handler()
        # Deliberately no response_format key.
        payload = {
            "prompt": "a test image",
            "model": "test-model",
        }

        with patch(
            "dynamo.trtllm.request_handlers.diffusion.image_handler.encode_to_png_bytes",
            return_value=b"fake_image_bytes",
        ), patch(
            "dynamo.trtllm.request_handlers.diffusion.image_handler.upload_to_fs",
            return_value="https://cdn.example.com/media/images/test.png",
        ) as upload_mock:
            outputs = await self._collect(handler, payload)

        assert len(outputs) == 1
        # The default is the "url" format, so the upload path must have run.
        upload_mock.assert_called_once()
        assert outputs[0]["data"][0]["url"] is not None

    @pytest.mark.asyncio
    async def test_error_response_on_failure(self):
        """An engine failure propagates out of ``generate()`` as an exception.

        This differs from video generation, where the error is embedded in
        the response. The image response has no error field, so the handler
        does not suppress the exception and lets it propagate.
        """
        handler = self._make_handler()
        handler.engine.generate = MagicMock(side_effect=RuntimeError("GPU OOM"))
        payload = {
            "prompt": "a test image",
            "model": "test-model",
        }

        with pytest.raises(RuntimeError) as exc_info:
            await self._collect(handler, payload)

        assert "GPU OOM" in str(exc_info.value)
......@@ -46,15 +46,20 @@ class TestModality:
"""Tests for the Modality enum and its helper methods."""
def test_modality_values_exist(self):
"""Test that TEXT, MULTIMODAL, and VIDEO_DIFFUSION exist."""
"""Test that TEXT, MULTIMODAL, VIDEO_DIFFUSION, and IMAGE_DIFFUSION exist."""
assert Modality.TEXT.value == "text"
assert Modality.MULTIMODAL.value == "multimodal"
assert Modality.VIDEO_DIFFUSION.value == "video_diffusion"
assert Modality.IMAGE_DIFFUSION.value == "image_diffusion"
def test_is_diffusion_true_for_video_diffusion(self):
"""Test that VIDEO_DIFFUSION returns True for is_diffusion."""
assert Modality.is_diffusion(Modality.VIDEO_DIFFUSION) is True
def test_is_diffusion_true_for_image_diffusion(self):
"""Test that IMAGE_DIFFUSION returns True for is_diffusion."""
assert Modality.is_diffusion(Modality.IMAGE_DIFFUSION) is True
def test_is_diffusion_false_for_text(self):
"""Test that TEXT returns False for is_diffusion."""
assert Modality.is_diffusion(Modality.TEXT) is False
......@@ -75,6 +80,10 @@ class TestModality:
"""Test that VIDEO_DIFFUSION returns False for is_llm."""
assert Modality.is_llm(Modality.VIDEO_DIFFUSION) is False
def test_is_llm_false_for_image_diffusion(self):
"""Test that IMAGE_DIFFUSION returns False for is_llm."""
assert Modality.is_llm(Modality.IMAGE_DIFFUSION) is False
# =============================================================================
# Part 2: DiffusionConfig Tests
......@@ -168,6 +177,7 @@ class MockDiffusionConfig:
default_width: int = 832
default_height: int = 480
default_num_frames: int = 81
default_num_images_per_prompt: int = 1
default_fps: int = 24
default_seconds: int = 4
max_width: int = 4096
......@@ -194,7 +204,7 @@ class TestVideHandlerParseSize:
def setup_method(self):
"""Set up mock handler for each test."""
# Import here to avoid issues if handler has complex imports
from dynamo.trtllm.request_handlers.video_diffusion.video_handler import (
from dynamo.trtllm.request_handlers.diffusion.video_handler import (
VideoGenerationHandler,
)
......@@ -274,7 +284,7 @@ class TestVideoHandlerComputeNumFrames:
def setup_method(self):
"""Set up mock handler for each test."""
from dynamo.trtllm.request_handlers.video_diffusion.video_handler import (
from dynamo.trtllm.request_handlers.diffusion.video_handler import (
VideoGenerationHandler,
)
......@@ -643,7 +653,7 @@ class TestVideoHandlerConcurrency:
def _make_handler(self):
"""Create a VideoGenerationHandler with mock engine and config."""
from dynamo.trtllm.request_handlers.video_diffusion.video_handler import (
from dynamo.trtllm.request_handlers.diffusion.video_handler import (
VideoGenerationHandler,
)
......@@ -659,7 +669,7 @@ class TestVideoHandlerConcurrency:
)
with patch(
"dynamo.trtllm.request_handlers.video_diffusion.video_handler.get_fs",
"dynamo.trtllm.request_handlers.diffusion.video_handler.get_fs",
return_value=MagicMock(),
):
handler = VideoGenerationHandler(
......@@ -696,10 +706,10 @@ class TestVideoHandlerConcurrency:
requests = [self._make_request() for _ in range(3)]
with patch(
"dynamo.trtllm.request_handlers.video_diffusion.video_handler.encode_to_mp4_bytes",
"dynamo.trtllm.request_handlers.diffusion.video_handler.encode_to_mp4_bytes",
return_value=b"fake_mp4_bytes",
), patch(
"dynamo.trtllm.request_handlers.video_diffusion.video_handler.upload_to_fs",
"dynamo.trtllm.request_handlers.diffusion.video_handler.upload_to_fs",
return_value="http://fake/video.mp4",
):
await asyncio.gather(
......@@ -726,7 +736,7 @@ class TestVideoHandlerResponseFormats:
def _make_handler(self):
"""Create a handler with mocked engine and fs."""
from dynamo.trtllm.request_handlers.video_diffusion.video_handler import (
from dynamo.trtllm.request_handlers.diffusion.video_handler import (
VideoGenerationHandler,
)
......@@ -746,7 +756,7 @@ class TestVideoHandlerResponseFormats:
)
with patch(
"dynamo.trtllm.request_handlers.video_diffusion.video_handler.get_fs",
"dynamo.trtllm.request_handlers.diffusion.video_handler.get_fs",
return_value=MagicMock(),
):
handler = VideoGenerationHandler(
......@@ -768,10 +778,10 @@ class TestVideoHandlerResponseFormats:
}
with patch(
"dynamo.trtllm.request_handlers.video_diffusion.video_handler.encode_to_mp4_bytes",
"dynamo.trtllm.request_handlers.diffusion.video_handler.encode_to_mp4_bytes",
return_value=b"fake_mp4",
), patch(
"dynamo.trtllm.request_handlers.video_diffusion.video_handler.upload_to_fs",
"dynamo.trtllm.request_handlers.diffusion.video_handler.upload_to_fs",
return_value="https://cdn.example.com/media/videos/test.mp4",
) as mock_upload:
results = []
......@@ -800,7 +810,7 @@ class TestVideoHandlerResponseFormats:
}
with patch(
"dynamo.trtllm.request_handlers.video_diffusion.video_handler.encode_to_mp4_bytes",
"dynamo.trtllm.request_handlers.diffusion.video_handler.encode_to_mp4_bytes",
return_value=b"fake_mp4_bytes",
):
results = []
......@@ -832,10 +842,10 @@ class TestVideoHandlerResponseFormats:
}
with patch(
"dynamo.trtllm.request_handlers.video_diffusion.video_handler.encode_to_mp4_bytes",
"dynamo.trtllm.request_handlers.diffusion.video_handler.encode_to_mp4_bytes",
return_value=b"fake_mp4",
), patch(
"dynamo.trtllm.request_handlers.video_diffusion.video_handler.upload_to_fs",
"dynamo.trtllm.request_handlers.diffusion.video_handler.upload_to_fs",
return_value="https://cdn.example.com/media/videos/test.mp4",
) as mock_upload:
results = []
......
......@@ -6,7 +6,7 @@
This package contains worker initialization functions for different modalities:
- llm_worker: Text and multimodal LLM inference
- video_diffusion_worker: Video generation using diffusion models
- image_diffusion_worker: Image generation using diffusion models
The init_worker() function dispatches to the appropriate worker based on modality.
Note on import strategy:
......@@ -14,6 +14,9 @@ Note on import strategy:
- video_diffusion_worker is imported lazily because it depends on visual_gen,
an optional package only available on TensorRT-LLM's feat/visual_gen branch.
Eager import would break text/multimodal users who don't have it installed.
- image_diffusion_worker is imported lazily because it depends on visual_gen,
an optional package only available on TensorRT-LLM's feat/visual_gen branch.
Eager import would break text/multimodal users who don't have it installed.
"""
import asyncio
......@@ -61,7 +64,15 @@ async def init_worker(
runtime, config, shutdown_event, shutdown_endpoints
)
return
# TODO: Add IMAGE_DIFFUSION support in follow-up PR
elif modality == Modality.IMAGE_DIFFUSION:
from dynamo.trtllm.workers.image_diffusion_worker import (
init_image_diffusion_worker,
)
await init_image_diffusion_worker(
runtime, config, shutdown_event, shutdown_endpoints
)
return
raise ValueError(f"Unsupported diffusion modality: {modality}")
# LLM modalities (text, multimodal)
......
# SPDX-FileCopyrightText: Copyright (c) 2025-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
"""Image diffusion worker initialization for TensorRT-LLM backend.
This module handles the initialization and lifecycle of image generation
workers using diffusion models (Wan, Flux, Cosmos, etc.).
"""
import asyncio
import logging
from typing import Optional
from dynamo.llm import ModelInput, ModelType, register_model
from dynamo.runtime import DistributedRuntime
from dynamo.trtllm.args import Config
async def init_image_diffusion_worker(
    runtime: DistributedRuntime,
    config: Config,
    shutdown_event: asyncio.Event,
    shutdown_endpoints: Optional[list] = None,
) -> None:
    """Initialize and run the image diffusion worker.

    Handles the ``image_diffusion`` modality: builds a ``DiffusionConfig``
    from ``config``, initializes a ``DiffusionEngine``, registers the model
    as ``ModelType.Images`` and serves ``ImageGenerationHandler.generate``
    on the configured endpoint.

    All cheap precondition checks (visual_gen import, endpoint presence,
    ``ModelType.Images`` availability) run before the engine is initialized
    so that a misconfigured environment fails without loading the model.
    Once the engine has been initialized it is always cleaned up, including
    when handler construction or model registration fails.

    Args:
        runtime: The Dynamo distributed runtime.
        config: Configuration parsed from command line.
        shutdown_event: Event to signal shutdown. Not referenced inside
            this function.
        shutdown_endpoints: Optional list to populate with endpoints for
            graceful shutdown. Its contents are replaced in place.

    Raises:
        ImportError: If ``tensorrt_llm._torch.visual_gen`` cannot be imported.
        ValueError: If ``config.endpoint`` is empty.
        RuntimeError: If the installed dynamo-runtime lacks ``ModelType.Images``.

    Note:
        ``asyncio.CancelledError`` raised while serving is logged and not
        re-raised; any other exception from serving is logged and re-raised.
    """
    # [gluo TODO] this can be the same as video diffusion worker, just need to update the handler and model type

    # Check tensorrt_llm visual_gen availability early with a clear error message.
    # visual_gen is part of TensorRT-LLM (tensorrt_llm._torch.visual_gen).
    # Without this check, users would get a cryptic ImportError deep inside
    # DiffusionEngine.initialize().
    try:
        import tensorrt_llm._torch.visual_gen  # noqa: F401
    except ImportError:
        raise ImportError(
            "Image diffusion requires TensorRT-LLM with visual_gen support.\n"
            "The visual_gen module is at tensorrt_llm._torch.visual_gen.\n"
            "Install TensorRT-LLM with AIGV support:\n"
            "  pip install tensorrt_llm\n"
            "See: https://github.com/NVIDIA/TensorRT-LLM"
        ) from None

    from dynamo.trtllm.configs.diffusion_config import DiffusionConfig
    from dynamo.trtllm.engines.diffusion_engine import DiffusionEngine
    from dynamo.trtllm.request_handlers.diffusion import ImageGenerationHandler

    logging.info(f"Initializing image diffusion worker with config: {config}")

    # Parse skip_components from comma-separated string to list
    skip_components = (
        [c.strip() for c in config.skip_components.split(",") if c.strip()]
        if config.skip_components
        else []
    )

    if not config.endpoint:
        raise ValueError("endpoint must be configured for image diffusion worker")

    # Fail fast: verify the runtime can register image models *before*
    # paying for engine initialization.
    if not hasattr(ModelType, "Images"):
        raise RuntimeError(
            "ModelType.Images not available in dynamo-runtime. "
            "Image diffusion requires a compatible dynamo-runtime version. "
            "See docs/backends/trtllm/README.md for setup instructions."
        )
    model_type = ModelType.Images

    # Build DiffusionConfig from the main Config
    diffusion_config = DiffusionConfig.from_config(config, skip_components)

    # Get the endpoint from the runtime
    endpoint = runtime.endpoint(
        f"{config.namespace}.{config.component}.{config.endpoint}"
    )

    if shutdown_endpoints is not None:
        shutdown_endpoints[:] = [endpoint]

    # Initialize the diffusion engine (auto-detects pipeline from model_index.json)
    engine = DiffusionEngine(diffusion_config)
    await engine.initialize()

    # From here on the engine is live; the outer finally guarantees it is
    # released no matter which later step fails.
    try:
        # Create the request handler
        handler = ImageGenerationHandler(engine, diffusion_config)

        # The inner finally releases the handler first, and the outer one
        # still runs engine.cleanup() even if handler.cleanup() raises.
        try:
            # Register the model with Dynamo's discovery system
            model_name = config.served_model_name or config.model

            logging.info(
                f"Registering model '{model_name}' with ModelType={model_type}"
            )

            # register_model is Dynamo's generic model registration function
            await register_model(
                ModelInput.Text,
                model_type,
                endpoint,
                config.model,
                model_name,
            )

            logging.info(f"Model registered, serving endpoint: {config.endpoint}")

            # Serve the endpoint
            try:
                await endpoint.serve_endpoint(
                    handler.generate,
                    graceful_shutdown=True,
                )
            except asyncio.CancelledError:
                logging.info("Endpoint serving cancelled")
            except Exception as e:
                logging.error(f"Error serving endpoint: {e}", exc_info=True)
                raise
        finally:
            handler.cleanup()
    finally:
        engine.cleanup()
......@@ -50,7 +50,7 @@ async def init_video_diffusion_worker(
from dynamo.trtllm.configs.diffusion_config import DiffusionConfig
from dynamo.trtllm.engines.diffusion_engine import DiffusionEngine
from dynamo.trtllm.request_handlers.video_diffusion import VideoGenerationHandler
from dynamo.trtllm.request_handlers.diffusion import VideoGenerationHandler
logging.info(f"Initializing video diffusion worker with config: {config}")
......@@ -65,52 +65,7 @@ async def init_video_diffusion_worker(
raise ValueError("endpoint must be configured for video diffusion worker")
# Build DiffusionConfig from the main Config
diffusion_config = DiffusionConfig(
namespace=config.namespace,
component=config.component,
endpoint=config.endpoint,
discovery_backend=config.discovery_backend,
request_plane=config.request_plane,
event_plane=config.event_plane,
model_path=config.model,
served_model_name=config.served_model_name,
torch_dtype=config.torch_dtype,
revision=config.revision,
media_output_fs_url=config.media_output_fs_url,
media_output_http_url=config.media_output_http_url,
default_height=config.default_height,
default_width=config.default_width,
default_num_frames=config.default_num_frames,
default_num_inference_steps=config.default_num_inference_steps,
default_guidance_scale=config.default_guidance_scale,
# Pipeline optimization
disable_torch_compile=config.disable_torch_compile,
enable_fullgraph=config.enable_fullgraph,
fuse_qkv=config.fuse_qkv,
enable_cuda_graph=config.enable_cuda_graph,
enable_layerwise_nvtx_marker=config.enable_layerwise_nvtx_marker,
skip_warmup=config.skip_warmup,
# Attention
attn_backend=config.attn_backend,
# Quantization
quant_algo=config.quant_algo,
quant_dynamic=config.quant_dynamic,
# TeaCache
enable_teacache=config.enable_teacache,
teacache_use_ret_steps=config.teacache_use_ret_steps,
teacache_thresh=config.teacache_thresh,
# Parallelism
dit_dp_size=config.dit_dp_size,
dit_tp_size=config.dit_tp_size,
dit_ulysses_size=config.dit_ulysses_size,
dit_ring_size=config.dit_ring_size,
dit_cfg_size=config.dit_cfg_size,
dit_fsdp_size=config.dit_fsdp_size,
# Offloading
enable_async_cpu_offload=config.enable_async_cpu_offload,
# Component loading
skip_components=skip_components,
)
diffusion_config = DiffusionConfig.from_config(config, skip_components)
# Get the endpoint from the runtime
endpoint = runtime.endpoint(
......
......@@ -8,27 +8,34 @@ For general TensorRT-LLM features and configuration, see the [Reference Guide](t
---
Dynamo supports video generation using diffusion models through the `--modality video_diffusion` flag.
Dynamo supports video generation using diffusion models through the `--modality video_diffusion` flag and
image generation through the `--modality image_diffusion` flag.
## Requirements
- **TensorRT-LLM with visual_gen**: The `visual_gen` module is part of TensorRT-LLM (`tensorrt_llm._torch.visual_gen`). Install TensorRT-LLM following the [official instructions](https://github.com/NVIDIA/TensorRT-LLM#installation).
- **imageio with ffmpeg**: Required for encoding generated frames to MP4 video:
- **dynamo-runtime with multimodal API**: The Dynamo runtime must include `ModelType.Videos` or `ModelType.Images` support. Ensure you're using a compatible version.
- **VIDEO diffusion: imageio with ffmpeg**: Required for encoding generated frames to MP4 video:
```bash
pip install imageio[ffmpeg]
```
- **dynamo-runtime with video API**: The Dynamo runtime must include `ModelType.Videos` support. Ensure you're using a compatible version.
## Supported Models
| Diffusers Pipeline | Description | Example Model |
|--------------------|-------------|---------------|
| `WanPipeline` | Wan 2.1/2.2 Text-to-Video | `Wan-AI/Wan2.1-T2V-1.3B-Diffusers` |
| `FluxPipeline` | FLUX Text-to-Image | `black-forest-labs/FLUX.1-dev` |
The pipeline type is **auto-detected** from the model's `model_index.json` — no `--model-type` flag is needed.
## Quick Start
### Video Diffusion
#### Launch worker
```bash
python -m dynamo.trtllm \
--modality video_diffusion \
......@@ -36,7 +43,7 @@ python -m dynamo.trtllm \
--media-output-fs-url file:///tmp/dynamo_media
```
## API Endpoint
#### API Endpoint
Video generation uses the `/v1/videos` endpoint:
......@@ -54,19 +61,45 @@ curl -X POST http://localhost:8000/v1/videos \
}'
```
### Image Diffusion
#### Launch worker
```bash
python -m dynamo.trtllm \
--modality image_diffusion \
--model-path black-forest-labs/FLUX.1-dev \
--media-output-fs-url file:///tmp/dynamo_media
```
#### API Endpoint
Image generation uses the `/v1/images/generations` endpoint:
```bash
curl -X POST http://localhost:8000/v1/images/generations \
-H "Content-Type: application/json" \
-d '{
"prompt": "A cat playing piano",
"model": "black-forest-labs/FLUX.1-dev",
"size": "256x256"
}'
```
## Configuration Options
| Flag | Description | Default |
|------|-------------|---------|
| `--media-output-fs-url` | Filesystem URL for storing generated media | `file:///tmp/dynamo_media` |
| `--default-height` | Default video height | `480` |
| `--default-width` | Default video width | `832` |
| `--default-height` | Default image/video height | `480` |
| `--default-width` | Default image/video width | `832` |
| `--default-num-frames` | Default frame count | `81` |
| `--default-num-images-per-prompt` | Default number of images per prompt | `1` |
| `--enable-teacache` | Enable TeaCache optimization | `False` |
| `--disable-torch-compile` | Disable torch.compile | `False` |
## Limitations
- Video diffusion is experimental and not recommended for production use
- Only text-to-video is supported in this release (image-to-video planned)
- Diffusion is experimental and not recommended for production use
- Only text-to-video and text-to-image are supported in this release (image-to-video planned)
- Requires GPU with sufficient VRAM for the diffusion model
......@@ -34,9 +34,9 @@ For more details, see the [Request Cancellation Architecture](../../fault-tolera
Dynamo with the TensorRT-LLM backend supports multimodal models, enabling you to process both text and images (or pre-computed embeddings) in a single request. For detailed setup instructions, example requests, and best practices, see the [TensorRT-LLM Multimodal Guide](../../features/multimodal/multimodal-trtllm.md).
## Video Diffusion Support (Experimental)
## Diffusion Support (Experimental)
Dynamo supports video generation using diffusion models through TensorRT-LLM. For requirements, supported models, API usage, and configuration options, see the [Video Diffusion Guide](./trtllm-video-diffusion.md).
Dynamo supports video and image generation using diffusion models through TensorRT-LLM. For requirements, supported models, API usage, and configuration options, see the [Diffusion Guide](./trtllm-diffusion.md).
## Logits Processing
......
......@@ -14,7 +14,7 @@ Dynamo supports serving diffusion models across multiple backends, enabling gene
| Modality | vLLM-Omni | SGLang | TRT-LLM |
|----------|-----------|--------|---------|
| Text-to-Text | ✅ | ✅ | ❌ |
| Text-to-Image | ✅ | ✅ | ❌ |
| Text-to-Image | ✅ | ✅ | ✅ |
| Text-to-Video | ✅ | ✅ | ✅ |
| Image-to-Video | ✅ | ❌ | ❌ |
......@@ -26,5 +26,5 @@ For deployment guides, configuration, and examples for each backend:
- **[vLLM-Omni](../../backends/vllm/vllm-omni.md)**
- **[SGLang Diffusion](../../backends/sglang/sglang-diffusion.md)**
- **[TRT-LLM Diffusion](../../backends/trtllm/trtllm-video-diffusion.md)**
- **[TRT-LLM Diffusion](../../backends/trtllm/trtllm-diffusion.md)**
- **[FastVideo (custom worker)](fastvideo.md)**
......@@ -9,7 +9,7 @@ sidebar-title: FastVideo
This guide covers deploying [FastVideo](https://github.com/hao-ai-lab/FastVideo) text-to-video generation on Dynamo using a custom worker (`worker.py`) exposed through the `/v1/videos` endpoint.
> [!NOTE]
> Dynamo also supports diffusion through built-in backends: [SGLang Diffusion](../../backends/sglang/sglang-diffusion.md) (LLM diffusion, image, video), [vLLM-Omni](../../backends/vllm/vllm-omni.md) (text-to-image, text-to-video), and [TRT-LLM Video Diffusion](../../backends/trtllm/trtllm-video-diffusion.md). See the [Diffusion Overview](README.md) for the full support matrix.
> Dynamo also supports diffusion through built-in backends: [SGLang Diffusion](../../backends/sglang/sglang-diffusion.md) (LLM diffusion, image, video), [vLLM-Omni](../../backends/vllm/vllm-omni.md) (text-to-image, text-to-video), and [TRT-LLM Diffusion](../../backends/trtllm/trtllm-diffusion.md) (text-to-image, text-to-video). See the [Diffusion Overview](README.md) for the full support matrix.
## Overview
......@@ -282,5 +282,5 @@ The example source lives at [`examples/diffusers/`](https://github.com/ai-dynamo
- [vLLM-Omni Text-to-Image](../../backends/vllm/vllm-omni.md#text-to-image) — vLLM-Omni image generation
- [SGLang Video Generation](../../backends/sglang/sglang-diffusion.md#video-generation) — SGLang video generation worker
- [SGLang Image Diffusion](../../backends/sglang/sglang-diffusion.md#image-diffusion) — SGLang image diffusion worker
- [TRT-LLM Video Diffusion](../../backends/trtllm/trtllm-video-diffusion.md#quick-start) — TensorRT-LLM video diffusion quick start
- [TRT-LLM Diffusion](../../backends/trtllm/trtllm-diffusion.md#quick-start) — TensorRT-LLM diffusion quick start
- [Diffusion Overview](README.md) — Full backend support matrix
......@@ -132,7 +132,7 @@ navigation:
- page: SGLang Diffusion
path: backends/sglang/sglang-diffusion.md
- page: TRT-LLM Diffusion
path: backends/trtllm/trtllm-video-diffusion.md
path: backends/trtllm/trtllm-diffusion.md
- page: Chat Processor Options
path: agents/chat-processor-options.md
- page: Tool Calling
......@@ -207,8 +207,8 @@ navigation:
path: backends/trtllm/trtllm-examples.md
- page: Observability
path: backends/trtllm/trtllm-observability.md
- page: Video Diffusion (Experimental)
path: backends/trtllm/trtllm-video-diffusion.md
- page: Diffusion (Experimental)
path: backends/trtllm/trtllm-diffusion.md
- page: Known Issues and Mitigations
path: backends/trtllm/trtllm-known-issues.md
- page: vLLM
......
#!/bin/bash
# SPDX-FileCopyrightText: Copyright (c) 2025-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Aggregated image diffusion serving with TensorRT-LLM backend.
# Launches the Dynamo frontend and one image_diffusion worker in the background.
# Uses black-forest-labs/FLUX.2-klein-4B by default (1 GPU); override with the
# MODEL_PATH and SERVED_MODEL_NAME environment variables.

# Abort on the first failing command.
set -e
# On any exit, signal the whole process group so background jobs are torn down.
trap 'echo Cleaning up...; kill 0' EXIT

# Resolve this script's directory (following symlinks) to locate shared helpers.
SCRIPT_DIR="$(dirname "$(readlink -f "$0")")"
# NOTE(review): print_launch_banner, print_curl_footer, wait_any_exit and
# EXAMPLE_PROMPT_VISUAL are not defined in this script; presumably they come
# from launch_utils.sh — confirm.
source "$SCRIPT_DIR/../../../common/launch_utils.sh"
source "$SCRIPT_DIR/../../../common/gpu_utils.sh" # build_trtllm_override_args_with_mem

# Environment variables with defaults
export DYNAMO_HOME=${DYNAMO_HOME:-"/workspace"}
export MODEL_PATH=${MODEL_PATH:-"black-forest-labs/FLUX.2-klein-4B"}
export SERVED_MODEL_NAME=${SERVED_MODEL_NAME:-"black-forest-labs/FLUX.2-klein-4B"}
export MEDIA_OUTPUT_FS_URL=${MEDIA_OUTPUT_FS_URL:-"file:///tmp/dynamo_media"}

# Parse command line arguments.
# Only -h/--help is handled here; every other argument is collected verbatim
# into EXTRA_ARGS and forwarded to the worker command below.
EXTRA_ARGS=()
while [[ $# -gt 0 ]]; do
case $1 in
-h|--help)
echo "Usage: $0 [OPTIONS]"
echo "Options:"
echo " -h, --help Show this help message"
echo ""
echo "Any additional options are passed through to dynamo.trtllm."
exit 0
;;
*)
EXTRA_ARGS+=("$1")
shift
;;
esac
done

# Build GPU memory JSON (returns bare JSON, no flag)
OVERRIDE_JSON=$(build_trtllm_override_args_with_mem)
# Add --override-engine-args if we have JSON; otherwise the array stays empty
# and expands to nothing in the worker command.
TRTLLM_OVERRIDE_ARGS=()
if [[ -n "$OVERRIDE_JSON" ]]; then
TRTLLM_OVERRIDE_ARGS=(--override-engine-args "$OVERRIDE_JSON")
fi

# Port used in the banner and the example request; defaults to 8000.
HTTP_PORT="${DYN_HTTP_PORT:-8000}"
print_launch_banner --no-curl "Launching Image Diffusion Serving (1 GPU)" "$MODEL_PATH" "$HTTP_PORT" \
"Media URL: $MEDIA_OUTPUT_FS_URL"
# The heredoc delimiter is unquoted, so ${...} is expanded and each "\\"
# becomes a single backslash in the text handed to print_curl_footer.
print_curl_footer <<CURL
curl http://localhost:${HTTP_PORT}/v1/images/generations \\
-H 'Content-Type: application/json' \\
-d '{
"model": "${SERVED_MODEL_NAME}",
"prompt": "${EXAMPLE_PROMPT_VISUAL}",
"size": "256x256",
"nvext": {"num_inference_steps": 10, "seed": 42}
}'
CURL

# run frontend
python3 -m dynamo.frontend &

# run image diffusion worker
python3 -m dynamo.trtllm \
--model-path "$MODEL_PATH" \
--served-model-name "$SERVED_MODEL_NAME" \
--modality image_diffusion \
--media-output-fs-url "$MEDIA_OUTPUT_FS_URL" \
"${TRTLLM_OVERRIDE_ARGS[@]}" \
"${EXTRA_ARGS[@]}" &

# Exit on first worker failure; kill 0 in the EXIT trap tears down the rest
wait_any_exit
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment