Unverified Commit df2daadd authored by Yuewei Na's avatar Yuewei Na Committed by GitHub
Browse files

feat: add video diffusion support to TRTLLM backend (wan_t2v only) (#5926)


Signed-off-by: default avatarYuewei Na <nv-yna@users.noreply.github.com>
Signed-off-by: default avatarYuewei Na <248773860+nv-yna@users.noreply.github.com>
Co-authored-by: default avatarYuewei Na <nv-yna@users.noreply.github.com>
Co-authored-by: default avatarTanmay Verma <tanmayv@nvidia.com>
parent 8707dc2c
# SPDX-FileCopyrightText: Copyright (c) 2025-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
"""Configuration classes for TensorRT-LLM backend.
This module provides configuration dataclasses:
- DiffusionConfig: Configuration for diffusion model workers
"""
from dynamo.trtllm.configs.diffusion_config import DiffusionConfig
__all__ = ["DiffusionConfig"]
# SPDX-FileCopyrightText: Copyright (c) 2025-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
"""Configuration for diffusion model workers.
This module defines the DiffusionConfig dataclass used for configuring
video and image diffusion workers.
"""
import os
from dataclasses import dataclass
from typing import Optional
DYN_NAMESPACE = os.environ.get("DYN_NAMESPACE", "dynamo")
# Default model paths
DEFAULT_VIDEO_MODEL_PATH = "Wan-AI/Wan2.1-T2V-1.3B-Diffusers"
@dataclass
class DiffusionConfig:
"""Configuration for diffusion model workers (video/image generation).
This configuration is used by DiffusionEngine and diffusion handlers.
It can be populated from command-line arguments in trtllm_utils.py.
"""
# Dynamo runtime config
namespace: str = DYN_NAMESPACE
component: str = "diffusion"
endpoint: str = "generate"
store_kv: str = "etcd"
request_plane: str = "tcp"
event_plane: str = "nats"
# Model config
model_path: str = DEFAULT_VIDEO_MODEL_PATH
served_model_name: Optional[str] = None
# torch_dtype for model loading. Options: "bfloat16", "float16", "float32"
# bfloat16 is recommended for Ampere+ GPUs (A100, H100, etc.)
# float16 can be used on older GPUs (V100, etc.)
torch_dtype: str = "bfloat16"
# Output config
output_dir: str = "/tmp/dynamo_videos"
# Default generation parameters
default_height: int = 480
default_width: int = 832
# Maximum allowed dimensions to prevent OOM. Can be increased if GPU has sufficient VRAM.
max_height: int = 4096
max_width: int = 4096
default_num_frames: int = 81
default_fps: int = 24 # Used for both frame count calculation and video encoding
default_seconds: int = 4 # Default video duration when only fps is specified
default_num_inference_steps: int = 50
default_guidance_scale: float = 5.0
# visual_gen optimization config
enable_teacache: bool = False
teacache_use_ret_steps: bool = True
teacache_thresh: float = 0.2
attn_type: str = "default"
linear_type: str = "default"
disable_torch_compile: bool = False
torch_compile_mode: str = "default"
# Parallelism config (DiTParallelConfig)
dit_dp_size: int = 1
dit_tp_size: int = 1
dit_ulysses_size: int = 1
dit_ring_size: int = 1
dit_cfg_size: int = 1
dit_fsdp_size: int = 1
# CPU offload config
enable_async_cpu_offload: bool = False
visual_gen_block_cpu_offload_stride: int = 1
def __str__(self) -> str:
return (
f"DiffusionConfig("
f"namespace={self.namespace}, "
f"component={self.component}, "
f"endpoint={self.endpoint}, "
f"model_path={self.model_path}, "
f"served_model_name={self.served_model_name}, "
f"output_dir={self.output_dir}, "
f"default_height={self.default_height}, "
f"default_width={self.default_width}, "
f"default_num_frames={self.default_num_frames}, "
f"default_num_inference_steps={self.default_num_inference_steps}, "
f"enable_teacache={self.enable_teacache}, "
f"attn_type={self.attn_type}, "
f"linear_type={self.linear_type}, "
f"dit_dp_size={self.dit_dp_size}, "
f"dit_tp_size={self.dit_tp_size})"
)
# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-FileCopyrightText: Copyright (c) 2025-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
"""Constants for TensorRT-LLM backend.
This module defines enums and constants used throughout the trtllm module.
"""
from enum import Enum
class DisaggregationMode(Enum):
"""Disaggregation mode for LLM workers."""
AGGREGATED = "prefill_and_decode"
PREFILL = "prefill"
DECODE = "decode"
ENCODE = "encode"
class Modality(Enum):
"""Modality types for different generative models.
This enum determines which type of model and handler to use:
- TEXT: Text-only LLM (generates text tokens)
- MULTIMODAL: Vision-language LLM (understands images, generates text)
- VIDEO_DIFFUSION: Video generation from text (generates video files)
"""
TEXT = "text"
MULTIMODAL = "multimodal"
VIDEO_DIFFUSION = "video_diffusion"
# TODO: Add IMAGE_DIFFUSION support in follow-up PR
@classmethod
def is_diffusion(cls, modality: "Modality") -> bool:
"""Check if a modality is a diffusion modality.
Args:
modality: The modality to check.
Returns:
True if the modality is VIDEO_DIFFUSION.
"""
return modality == cls.VIDEO_DIFFUSION
@classmethod
def is_llm(cls, modality: "Modality") -> bool:
"""Check if a modality is an LLM modality.
Args:
modality: The modality to check.
Returns:
True if the modality is TEXT or MULTIMODAL.
"""
return modality in (cls.TEXT, cls.MULTIMODAL)
# SPDX-FileCopyrightText: Copyright (c) 2025-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
"""Engine modules for TensorRT-LLM backend.
This module provides engine wrappers for various generative models:
- DiffusionEngine: Generic wrapper for visual_gen diffusion pipelines
"""
from dynamo.trtllm.engines.diffusion_engine import DiffusionEngine
__all__ = ["DiffusionEngine"]
# SPDX-FileCopyrightText: Copyright (c) 2025-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
"""Generic Diffusion Engine wrapper for visual_gen pipelines.
This module provides a unified interface for various diffusion models
(Wan, Flux, Cosmos, etc.) through a pipeline registry system.
The pipeline type is auto-detected from model_index.json (shipped with every
HuggingFace Diffusers model), eliminating the need for a --model-type flag.
Requirements:
- visual_gen: Part of TensorRT-LLM, located at tensorrt_llm/visual_gen/.
Currently on the feat/visual_gen branch (not yet merged to main).
See: https://github.com/NVIDIA/TensorRT-LLM/tree/feat/visual_gen/tensorrt_llm/visual_gen
- See docs/backends/trtllm/README.md for setup instructions.
Note on imports:
visual_gen is imported lazily in initialize() because:
1. It's a heavy package that may not be installed in all environments
2. Importing at module load would fail if visual_gen is not available
3. This allows the module to be imported for type checking and validation
without requiring visual_gen to be installed
"""
import importlib
import json
import logging
from dataclasses import dataclass
from pathlib import Path
from typing import TYPE_CHECKING, Any, Optional
import numpy as np
import torch
if TYPE_CHECKING:
from dynamo.trtllm.configs.diffusion_config import DiffusionConfig
logger = logging.getLogger(__name__)
@dataclass
class PipelineInfo:
"""Auto-detected pipeline information from model_index.json."""
module_path: str
class_name: str
modalities: list[str]
config_overrides: dict[str, Any]
class DiffusionEngine:
"""Generic wrapper for visual_gen diffusion pipelines.
This engine provides:
- Auto-detection of pipeline class from model_index.json
- A registry mapping diffusers _class_name to visual_gen pipelines
- Lazy loading of pipeline modules
- Common interface for video/image generation
Example:
>>> engine = DiffusionEngine(config)
>>> await engine.initialize()
>>> frames = engine.generate(prompt="A cat playing piano", ...)
"""
# Registry: diffusers _class_name -> (module_path, visual_gen_class, supported_modalities)
# The _class_name comes from model_index.json shipped with every HF Diffusers model.
# torch_compile_models is derived dynamically from transformer* keys in model_index.json.
#
# NOTE: This registry is initially focused on Wan text-to-video models.
# Follow-up PRs will extend support for other model families (Flux, Cosmos, etc.)
# which may require additional config fields in DiffusionConfig.
PIPELINE_REGISTRY: dict[str, tuple[str, str, list[str]]] = {
"WanPipeline": (
"visual_gen.pipelines.wan_pipeline",
"ditWanPipeline",
["video_diffusion"],
),
# TODO: Add support for WanImageToVideoPipeline, FluxPipeline, etc.
}
@classmethod
def detect_pipeline_info(cls, model_path: str) -> PipelineInfo:
"""Auto-detect pipeline class from model's model_index.json.
Reads model_index.json (local path or HuggingFace Hub) to determine:
- Which visual_gen pipeline class to use (via _class_name)
- Which transformer models to torch.compile (via transformer* keys)
Args:
model_path: Local path or HuggingFace model identifier.
Returns:
PipelineInfo with module_path, class_name, modalities, and config_overrides.
Raises:
ValueError: If _class_name is not in the registry.
FileNotFoundError: If model_index.json cannot be found locally or on HF Hub.
"""
# Try local path first
local_index = Path(model_path) / "model_index.json"
if local_index.exists():
with open(local_index) as f:
model_index = json.load(f)
else:
# Download from HuggingFace Hub
from huggingface_hub import hf_hub_download
index_path = hf_hub_download(model_path, "model_index.json")
with open(index_path) as f:
model_index = json.load(f)
class_name = model_index.get("_class_name")
if class_name not in cls.PIPELINE_REGISTRY:
supported = list(cls.PIPELINE_REGISTRY.keys())
raise ValueError(
f"Unsupported diffusion pipeline '{class_name}' from model '{model_path}'.\n"
f"Supported pipelines: {', '.join(supported)}\n"
f"Check that model_index.json has a supported _class_name."
)
module_path, vg_class, modalities = cls.PIPELINE_REGISTRY[class_name]
# Derive torch_compile_models from transformer* keys in model_index.json
transformer_keys = sorted(k for k in model_index if k.startswith("transformer"))
torch_compile_models = (
",".join(transformer_keys) if transformer_keys else "transformer"
)
config_overrides = {"torch_compile_models": torch_compile_models}
return PipelineInfo(module_path, vg_class, modalities, config_overrides)
def __init__(self, config: "DiffusionConfig"):
"""Initialize the engine with configuration.
Auto-detects the pipeline type from config.model_path's model_index.json.
Args:
config: Diffusion generation configuration.
Raises:
ValueError: If the model's pipeline type is not supported.
"""
info = self.detect_pipeline_info(config.model_path)
self.config = config
self._pipeline = None
self._initialized = False
self._module_path = info.module_path
self._class_name = info.class_name
self._supported_modalities = info.modalities
self._config_overrides = info.config_overrides
async def initialize(self) -> None:
"""Load and configure the diffusion pipeline.
This is called once at worker startup to load the model.
The specific pipeline class is determined by the auto-detected pipeline type.
"""
if self._initialized:
logger.warning("Engine already initialized, skipping")
return
logger.info(
f"Initializing DiffusionEngine: pipeline={self._class_name}, "
f"model_path={self.config.model_path}"
)
# Import visual_gen setup
from visual_gen import setup_configs
# Build configuration dict based on model type
dit_configs = self._build_dit_configs()
logger.info(f"dit_configs: {dit_configs}")
# Setup global configuration (required before pipeline loading)
setup_configs(**dit_configs)
# Dynamically import the pipeline class
logger.info(f"Importing pipeline from {self._module_path}.{self._class_name}")
module = importlib.import_module(self._module_path)
pipeline_class = getattr(module, self._class_name)
# Load the pipeline
# Convert torch_dtype string to actual torch dtype
dtype_map = {
"bfloat16": torch.bfloat16,
"float16": torch.float16,
"float32": torch.float32,
}
torch_dtype = dtype_map.get(self.config.torch_dtype, torch.bfloat16)
logger.info(
f"Loading pipeline from {self.config.model_path} with dtype={self.config.torch_dtype}"
)
self._pipeline = pipeline_class.from_pretrained(
self.config.model_path,
torch_dtype=torch_dtype,
**dit_configs,
)
# Move to target device
# NOTE: HuggingFace's from_pretrained() loads to CPU by default,
# so we must explicitly move to GPU for optimal performance.
if self.device == "cuda":
logger.info("Moving pipeline to GPU...")
self._pipeline.to(self.device)
logger.info("Pipeline moved to GPU successfully")
else:
logger.info("CPU offload enabled, pipeline stays on CPU")
self._initialized = True
logger.info(f"DiffusionEngine initialization complete: {self._class_name}")
def _build_dit_configs(self) -> dict[str, Any]:
"""Build dit_configs dict from DiffusionConfig.
Returns:
Configuration dictionary for visual_gen's setup_configs.
"""
# Get torch_compile_models from auto-detected config overrides
# Each pipeline in PIPELINE_REGISTRY specifies its required settings
torch_compile_models = self._config_overrides.get(
"torch_compile_models", "transformer"
)
return {
"pipeline": {
"enable_torch_compile": not self.config.disable_torch_compile,
"torch_compile_models": torch_compile_models,
"torch_compile_mode": self.config.torch_compile_mode,
"fuse_qkv": True,
},
"attn": {
"type": self.config.attn_type,
},
"linear": {
"type": self.config.linear_type,
"recipe": "dynamic",
},
"parallel": {
"disable_parallel_vae": False,
"parallel_vae_split_dim": "width",
"dit_dp_size": self.config.dit_dp_size,
"dit_tp_size": self.config.dit_tp_size,
"dit_ulysses_size": self.config.dit_ulysses_size,
"dit_ring_size": self.config.dit_ring_size,
"dit_cp_size": 1,
"dit_cfg_size": self.config.dit_cfg_size,
"dit_fsdp_size": self.config.dit_fsdp_size,
"t5_fsdp_size": 1,
},
"teacache": {
"enable_teacache": self.config.enable_teacache,
"use_ret_steps": self.config.teacache_use_ret_steps,
"teacache_thresh": self.config.teacache_thresh,
"ret_steps": 0,
"cutoff_steps": self.config.default_num_inference_steps,
},
}
def generate(
self,
prompt: str,
negative_prompt: Optional[str] = None,
height: int = 480,
width: int = 832,
num_frames: int = 81,
num_inference_steps: int = 50,
guidance_scale: float = 5.0,
seed: Optional[int] = None,
) -> np.ndarray:
"""Generate video/image frames from text prompt.
This is a synchronous method that should be called from a thread pool
to avoid blocking the event loop.
Args:
prompt: Text description of the content to generate.
negative_prompt: Text to avoid in the generation.
height: Output height in pixels.
width: Output width in pixels.
num_frames: Number of frames to generate (for video).
num_inference_steps: Number of denoising steps.
guidance_scale: CFG guidance scale.
seed: Random seed for reproducibility.
Returns:
numpy array of shape (num_frames, height, width, 3) with uint8 values
for video, or (height, width, 3) for images.
Raises:
RuntimeError: If engine not initialized or generation fails.
"""
if not self._initialized or self._pipeline is None:
raise RuntimeError("Engine not initialized. Call initialize() first.")
logger.info(
f"Generating: prompt='{prompt[:50]}...', "
f"size={width}x{height}, frames={num_frames}, steps={num_inference_steps}"
)
# Create generator for reproducibility
# Device must match pipeline device (CPU if offload enabled, CUDA otherwise)
generator = None
if seed is not None:
generator = torch.Generator(device=self.device).manual_seed(seed)
# Run the pipeline
with torch.no_grad():
result = self._pipeline(
prompt=prompt,
negative_prompt=negative_prompt,
height=height,
width=width,
num_frames=num_frames,
num_inference_steps=num_inference_steps,
guidance_scale=guidance_scale,
generator=generator,
output_type="np", # Return numpy array
)
# result.frames[0] is numpy array (num_frames, height, width, 3) uint8
frames = result.frames[0]
logger.info(f"Generated output with shape {frames.shape}")
return frames
def cleanup(self) -> None:
"""Cleanup resources."""
if self._pipeline is not None:
del self._pipeline
self._pipeline = None
self._initialized = False
if self.device == "cuda":
torch.cuda.empty_cache()
logger.info(f"DiffusionEngine cleanup complete: {self._class_name}")
@property
def is_initialized(self) -> bool:
"""Check if the engine is initialized."""
return self._initialized
@property
def supported_modalities(self) -> list[str]:
"""Get the modalities supported by this engine's model type."""
return self._supported_modalities
@property
def device(self) -> str:
"""Get the device where the pipeline runs.
Returns:
"cpu" if CPU offload is enabled, "cuda" otherwise.
"""
return "cpu" if self.config.enable_async_cpu_offload else "cuda"
......@@ -52,16 +52,17 @@ from dynamo.llm import (
)
from dynamo.runtime import DistributedRuntime
from dynamo.runtime.logging import configure_dynamo_logging
from dynamo.trtllm.constants import DisaggregationMode, Modality
from dynamo.trtllm.engine import Backend, TensorRTLLMEngine, get_llm_engine
from dynamo.trtllm.health_check import TrtllmHealthCheckPayload
from dynamo.trtllm.multimodal_processor import MultimodalRequestProcessor
from dynamo.trtllm.publisher import DYNAMO_COMPONENT_REGISTRY, get_publisher
from dynamo.trtllm.request_handlers.handler_base import DisaggregationMode
from dynamo.trtllm.request_handlers.handlers import (
RequestHandlerConfig,
RequestHandlerFactory,
)
from dynamo.trtllm.utils.trtllm_utils import Config, cmd_line_args, deep_update
from dynamo.trtllm.workers import init_video_diffusion_worker
# Default buffer size for kv cache events.
DEFAULT_KV_EVENT_BUFFER_MAX_SIZE = 1024
......@@ -135,10 +136,39 @@ async def init(
runtime: DistributedRuntime, config: Config, shutdown_event: asyncio.Event
):
"""
Instantiate and serve
Instantiate and serve based on modality.
For video_diffusion modality, delegates to the video diffusion worker.
For text/multimodal, uses the LLM worker.
"""
logging.info(f"Initializing the worker with config: {config}")
# Check modality and dispatch to appropriate worker
modality = Modality(config.modality)
if Modality.is_diffusion(modality):
if modality == Modality.VIDEO_DIFFUSION:
await init_video_diffusion_worker(runtime, config, shutdown_event)
return
# TODO: Add IMAGE_DIFFUSION support in follow-up PR
# LLM modalities (text, multimodal) use the existing init_llm_worker logic
await init_llm_worker(runtime, config, shutdown_event)
async def init_llm_worker(
runtime: DistributedRuntime, config: Config, shutdown_event: asyncio.Event
) -> None:
"""Initialize and run the LLM worker.
This function handles text and multimodal LLM modalities using TensorRT-LLM.
Args:
runtime: The Dynamo distributed runtime.
config: Configuration parsed from command line.
shutdown_event: Event to signal shutdown.
"""
encode_client = None
if config.encode_endpoint:
logging.info(
......
# SPDX-FileCopyrightText: Copyright (c) 2025-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
"""Protocol types for TensorRT-LLM backend.
This module provides protocol types for various modalities:
- video_protocol: NvCreateVideoRequest, NvVideosResponse for video generation
- image_protocol: (future) Protocol types for image generation
"""
from dynamo.trtllm.protocols.video_protocol import (
NvCreateVideoRequest,
NvVideosResponse,
VideoData,
)
__all__ = [
"NvCreateVideoRequest",
"NvVideosResponse",
"VideoData",
]
# SPDX-FileCopyrightText: Copyright (c) 2025-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
"""Protocol types for video generation.
These types match the Rust protocol types in lib/llm/src/protocols/openai/videos.rs
to ensure compatibility with the Dynamo HTTP frontend.
"""
from typing import Optional
from pydantic import BaseModel
class NvCreateVideoRequest(BaseModel):
"""Request for video generation (/v1/videos/generations endpoint).
Matches Rust NvCreateVideoRequest in lib/llm/src/protocols/openai/videos.rs.
"""
# Required fields
prompt: str
"""The text prompt for video generation."""
model: str
"""The model to use for video generation."""
# Optional fields
input_reference: Optional[str] = None
"""Optional input reference for I2V (image path/url)."""
seconds: Optional[int] = None
"""Duration in seconds (default: 4)."""
fps: Optional[int] = None
"""Frames per second (default: 24)."""
num_frames: Optional[int] = None
"""Number of frames to generate (overrides fps * seconds if set)."""
size: Optional[str] = None
"""Video size in WxH format (default: '832x480')."""
num_inference_steps: Optional[int] = None
"""Number of denoising steps (default: 50)."""
guidance_scale: Optional[float] = None
"""CFG guidance scale (default: 5.0)."""
negative_prompt: Optional[str] = None
"""Optional negative prompt."""
seed: Optional[int] = None
"""Random seed for reproducibility."""
user: Optional[str] = None
"""Optional user identifier."""
response_format: Optional[str] = None
"""Response format: 'url' or 'b64_json' (default: 'url')."""
class VideoData(BaseModel):
"""Video data in response.
Matches Rust VideoData in lib/llm/src/protocols/openai/videos.rs.
"""
url: Optional[str] = None
"""URL of the generated video (if response_format is 'url')."""
b64_json: Optional[str] = None
"""Base64-encoded video (if response_format is 'b64_json')."""
class NvVideosResponse(BaseModel):
"""Response structure for video generation.
Matches Rust NvVideosResponse in lib/llm/src/protocols/openai/videos.rs.
"""
id: str
"""Unique identifier for the response."""
object: str = "video"
"""Object type (always 'video')."""
model: str
"""Model used for generation."""
status: str = "completed"
"""Generation status."""
progress: int = 100
"""Progress percentage (0-100)."""
created: int
"""Unix timestamp of creation."""
data: list[VideoData] = []
"""List of generated videos."""
error: Optional[str] = None
"""Error message if generation failed."""
inference_time_s: Optional[float] = None
"""Inference time in seconds."""
# SPDX-FileCopyrightText: Copyright (c) 2025-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
"""Base generative handler for all Dynamo handlers.
This module defines the minimal interface that all generative handlers must implement.
It provides a common base class for LLM, video diffusion, and image diffusion handlers.
"""
from abc import ABC, abstractmethod
from typing import Any, AsyncGenerator
from dynamo._core import Context
class BaseGenerativeHandler(ABC):
"""Minimal base class for all generative handlers (LLM, video, image).
All handlers in the Dynamo system should inherit from this class and
implement the generate() method. This ensures a consistent interface
for the endpoint serving infrastructure.
"""
@abstractmethod
async def generate(
self, request: dict[str, Any], context: Context
) -> AsyncGenerator[dict[str, Any], None]:
"""Generate response from request.
This is the main entry point called by Dynamo's endpoint.serve_endpoint().
Subclasses implement the specific generation logic for their modality.
Args:
request: Request dictionary with generation parameters.
context: Dynamo context for request tracking and cancellation.
Yields:
Response dictionaries. For streaming outputs, multiple dicts may be
yielded. For non-streaming outputs (like video), a single dict is yielded.
Raises:
NotImplementedError: If called on BaseGenerativeHandler directly.
"""
raise NotImplementedError
# Note: This yield is needed to make this an async generator
yield {} # pragma: no cover
......@@ -39,6 +39,7 @@ from dynamo.trtllm.engine import TensorRTLLMEngine
from dynamo.trtllm.logits_processing.adapter import create_trtllm_adapters
from dynamo.trtllm.multimodal_processor import MultimodalRequestProcessor
from dynamo.trtllm.publisher import Publisher
from dynamo.trtllm.request_handlers.base_generative_handler import BaseGenerativeHandler
from dynamo.trtllm.utils.disagg_utils import (
DisaggregatedParams,
DisaggregatedParamsCodec,
......@@ -72,9 +73,16 @@ class RequestHandlerConfig:
encoder_cache_capacity_gb: float = 0 # Encoder cache capacity in GB
class HandlerBase:
class HandlerBase(BaseGenerativeHandler):
"""
Base class for request handlers.
Base class for LLM request handlers (text generation, multimodal LLM).
This class is dedicated to LLM-based generation using TensorRT-LLM engine.
For diffusion-based handlers (video, image), see VideoGenerationHandler
and ImageGenerationHandler which inherit directly from BaseGenerativeHandler.
Inherits from BaseGenerativeHandler to ensure a consistent interface
across all generative handlers (LLM, video, image).
"""
def __init__(self, config: RequestHandlerConfig):
......
# SPDX-FileCopyrightText: Copyright (c) 2025-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
"""Video diffusion request handlers for TensorRT-LLM backend.
This module provides handlers for video generation using diffusion models.
"""
from dynamo.trtllm.request_handlers.video_diffusion.video_handler import (
VideoGenerationHandler,
)
__all__ = ["VideoGenerationHandler"]
# SPDX-FileCopyrightText: Copyright (c) 2025-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
"""Video generation request handler for TensorRT-LLM backend.
This handler processes video generation requests using diffusion models.
"""
import asyncio
import base64
import logging
import time
import uuid
from typing import Any, AsyncGenerator, Optional
from dynamo._core import Component, Context
from dynamo.trtllm.configs.diffusion_config import DiffusionConfig
from dynamo.trtllm.engines.diffusion_engine import DiffusionEngine
from dynamo.trtllm.protocols.video_protocol import (
NvCreateVideoRequest,
NvVideosResponse,
VideoData,
)
from dynamo.trtllm.request_handlers.base_generative_handler import BaseGenerativeHandler
from dynamo.trtllm.request_handlers.video_diffusion.video_utils import (
encode_to_mp4,
encode_to_mp4_bytes,
)
logger = logging.getLogger(__name__)
class VideoGenerationHandler(BaseGenerativeHandler):
"""Handler for video generation requests.
This handler receives video generation requests, runs the diffusion
pipeline via DiffusionEngine, encodes the output to MP4, and returns
the video URL or base64-encoded data.
Inherits from BaseGenerativeHandler to share the common interface with
LLM handlers.
"""
def __init__(
self,
component: Component,
engine: DiffusionEngine,
config: DiffusionConfig,
):
"""Initialize the handler.
Args:
component: The Dynamo runtime component.
engine: The DiffusionEngine instance.
config: Diffusion generation configuration.
"""
self.component = component
self.engine = engine
self.config = config
# Serialize pipeline access — visual_gen is not thread-safe (global
# singleton configs, mutable instance state, unprotected CUDA graph cache).
# asyncio.Lock suspends waiting coroutines cooperatively so the event
# loop stays free for health checks and signal handling.
self._generate_lock = asyncio.Lock()
def _parse_size(self, size: Optional[str]) -> tuple[int, int]:
"""Parse 'WxH' string to (width, height) tuple.
The API accepts size as a string (e.g., "832x480") to match the format
used by OpenAI's image generation API (/v1/images/generations).
This method converts that string to a (width, height) tuple for the engine.
Args:
size: Size string in 'WxH' format (e.g., '832x480').
Returns:
Tuple of (width, height).
Raises:
ValueError: If dimensions exceed configured max_width/max_height.
"""
if not size:
width, height = self.config.default_width, self.config.default_height
else:
try:
w, h = size.split("x")
width, height = int(w), int(h)
except (ValueError, AttributeError):
logger.warning(f"Invalid size format: {size}, using defaults")
width, height = self.config.default_width, self.config.default_height
# Validate dimensions to prevent OOM
self._validate_dimensions(width, height)
return width, height
def _validate_dimensions(self, width: int, height: int) -> None:
"""Validate that dimensions don't exceed configured limits.
Args:
width: Requested width in pixels.
height: Requested height in pixels.
Raises:
ValueError: If width or height exceeds the configured maximum.
"""
errors = []
if width > self.config.max_width:
errors.append(f"width {width} exceeds max_width {self.config.max_width}")
if height > self.config.max_height:
errors.append(
f"height {height} exceeds max_height {self.config.max_height}"
)
if errors:
raise ValueError(
f"Requested dimensions too large: {', '.join(errors)}. "
f"This is a safety check to prevent out-of-memory errors. "
f"To allow larger sizes, increase --max-width and/or --max-height."
)
def _compute_num_frames(self, req: NvCreateVideoRequest) -> int:
"""Compute num_frames from request parameters.
Priority:
1. num_frames if explicitly set
2. seconds * fps
3. config defaults
Args:
req: The video generation request.
Returns:
Number of frames to generate.
"""
# Priority 1: Explicit num_frames takes precedence
if req.num_frames is not None:
return req.num_frames
# Priority 2: If user provided seconds and/or fps, calculate frame count
# Use config defaults for any unspecified value
seconds = (
req.seconds if req.seconds is not None else self.config.default_seconds
)
fps = req.fps if req.fps is not None else self.config.default_fps
computed = seconds * fps
# Priority 3: If user provided NEITHER seconds NOR fps, use config default
# This allows config.default_num_frames to take effect only when the user
# didn't specify any duration-related parameters
if req.seconds is None and req.fps is None:
return self.config.default_num_frames
# User provided at least one of (seconds, fps), so use computed value
return computed
async def generate(
self, request: dict[str, Any], context: Context
) -> AsyncGenerator[dict[str, Any], None]:
"""Generate video from request.
This is the main entry point called by Dynamo's endpoint.serve_endpoint().
Args:
request: Request dictionary with video generation parameters.
context: Dynamo context for request tracking.
Yields:
Response dictionary with generated video data.
"""
start_time = time.time()
request_id = str(uuid.uuid4())
logger.info(f"Received video generation request: {request_id}")
try:
# Parse request
req = NvCreateVideoRequest(**request)
# Parse parameters
width, height = self._parse_size(req.size)
num_frames = self._compute_num_frames(req)
num_inference_steps = (
req.num_inference_steps
if req.num_inference_steps is not None
else self.config.default_num_inference_steps
)
guidance_scale = (
req.guidance_scale
if req.guidance_scale is not None
else self.config.default_guidance_scale
)
logger.info(
f"Request {request_id}: prompt='{req.prompt[:50]}...', "
f"size={width}x{height}, frames={num_frames}, steps={num_inference_steps}"
)
# Run generation in thread pool (blocking operation).
# Lock ensures only one request uses the pipeline at a time.
# TODO: Add cancellation support. This requires:
# 1. visual_gen to expose a cancellation hook in the denoising loop
# 2. Passing a cancellation token/event to engine.generate()
# 3. Checking context.cancelled() and propagating to the pipeline
async with self._generate_lock:
frames = await asyncio.to_thread(
self.engine.generate,
prompt=req.prompt,
negative_prompt=req.negative_prompt,
height=height,
width=width,
num_frames=num_frames,
num_inference_steps=num_inference_steps,
guidance_scale=guidance_scale,
seed=req.seed,
)
# Determine output format
response_format = req.response_format or "url"
fps = req.fps or self.config.default_fps
if response_format == "url":
# Encode to MP4 and save to file
output_path = await asyncio.to_thread(
encode_to_mp4,
frames,
self.config.output_dir,
request_id,
fps=fps,
)
video_data = VideoData(url=output_path)
else:
# Encode to base64
video_bytes = await asyncio.to_thread(
encode_to_mp4_bytes, frames, fps=fps
)
b64_video = base64.b64encode(video_bytes).decode("utf-8")
video_data = VideoData(b64_json=b64_video)
inference_time = time.time() - start_time
response = NvVideosResponse(
id=request_id,
object="video",
model=req.model,
status="completed",
progress=100,
created=int(time.time()),
data=[video_data],
inference_time_s=inference_time,
)
logger.info(f"Request {request_id} completed in {inference_time:.2f}s")
yield response.model_dump()
except Exception as e:
logger.error(f"Request {request_id} failed: {e}", exc_info=True)
inference_time = time.time() - start_time
error_response = NvVideosResponse(
id=request_id,
object="video",
model=request.get("model", "unknown"),
status="failed",
progress=0,
created=int(time.time()),
data=[],
error=str(e),
inference_time_s=inference_time,
)
yield error_response.model_dump()
def cleanup(self) -> None:
"""Cleanup handler resources."""
logger.info("VideoGenerationHandler cleanup")
# SPDX-FileCopyrightText: Copyright (c) 2025-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
"""Video encoding utilities for TensorRT-LLM video diffusion.
This module provides utilities for encoding numpy video frames to MP4 format.
"""
import io
import logging
import os
import numpy as np
logger = logging.getLogger(__name__)
def encode_to_mp4(
frames: np.ndarray,
output_dir: str,
request_id: str,
fps: int = 16,
) -> str:
"""Encode numpy frames to MP4 file.
Args:
frames: Video frames as numpy array of shape (num_frames, height, width, 3)
with uint8 values 0-255.
output_dir: Directory to save the output video.
request_id: Unique identifier for the request (used in filename).
fps: Frames per second for the output video.
Returns:
Path to the saved MP4 file.
Raises:
ImportError: If imageio is not available.
RuntimeError: If encoding fails.
"""
try:
import imageio.v3 as iio
except ImportError:
try:
import imageio as iio
except ImportError:
raise ImportError(
"imageio is required for video encoding. "
"Install with: pip install imageio[ffmpeg]"
)
# Ensure output directory exists
os.makedirs(output_dir, exist_ok=True)
output_path = os.path.join(output_dir, f"{request_id}.mp4")
logger.info(f"Encoding {len(frames)} frames to {output_path} at {fps} fps")
try:
# Use imageio to write MP4
# imageio.v3 API
if hasattr(iio, "imwrite"):
iio.imwrite(output_path, frames, fps=fps, codec="libx264")
else:
# Fall back to v2 API
writer = iio.get_writer(output_path, fps=fps, codec="libx264")
try:
for frame in frames:
writer.append_data(frame)
finally:
writer.close()
logger.info(f"Video saved to {output_path}")
return output_path
except Exception as e:
logger.error(f"Failed to encode video: {e}")
raise RuntimeError(f"Video encoding failed: {e}") from e
def encode_to_mp4_bytes(
frames: np.ndarray,
fps: int = 16,
) -> bytes:
"""Encode numpy frames to MP4 bytes (in-memory).
Args:
frames: Video frames as numpy array of shape (num_frames, height, width, 3)
with uint8 values 0-255.
fps: Frames per second for the output video.
Returns:
MP4 video as bytes.
Raises:
ImportError: If imageio is not available.
RuntimeError: If encoding fails.
"""
try:
import imageio.v3 as iio
except ImportError:
try:
import imageio as iio
except ImportError:
raise ImportError(
"imageio is required for video encoding. "
"Install with: pip install imageio[ffmpeg]"
)
logger.info(f"Encoding {len(frames)} frames to bytes at {fps} fps")
try:
# Use in-memory buffer
buffer = io.BytesIO()
# imageio can write to BytesIO with format hint
if hasattr(iio, "imwrite"):
# v3 API - write to buffer
iio.imwrite(buffer, frames, extension=".mp4", fps=fps, codec="libx264")
else:
# v2 API
writer = iio.get_writer(
buffer, format="FFMPEG", mode="I", fps=fps, codec="libx264"
)
try:
for frame in frames:
writer.append_data(frame)
finally:
writer.close()
video_bytes = buffer.getvalue()
logger.info(f"Encoded video to {len(video_bytes)} bytes")
return video_bytes
except Exception as e:
logger.error(f"Failed to encode video to bytes: {e}")
raise RuntimeError(f"Video encoding to bytes failed: {e}") from e
# SPDX-FileCopyrightText: Copyright (c) 2025-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
"""Unit tests for video diffusion components.
Tests for Modality enum, DiffusionConfig, VideoGenerationHandler helpers,
video protocol types, and concurrency safety.
These tests do NOT require visual_gen, torch, or GPU - they test logic only.
"""
import asyncio
import threading
import time
from dataclasses import dataclass
from typing import Optional
from unittest.mock import MagicMock, patch
import pytest
from dynamo.trtllm.configs.diffusion_config import DiffusionConfig
from dynamo.trtllm.constants import Modality
from dynamo.trtllm.protocols.video_protocol import (
NvCreateVideoRequest,
NvVideosResponse,
VideoData,
)
pytestmark = [
pytest.mark.unit,
pytest.mark.trtllm,
pytest.mark.pre_merge,
pytest.mark.gpu_0,
]
# =============================================================================
# Part 1: Modality Enum Tests
# =============================================================================
class TestModality:
"""Tests for the Modality enum and its helper methods."""
def test_modality_values_exist(self):
"""Test that TEXT, MULTIMODAL, and VIDEO_DIFFUSION exist."""
assert Modality.TEXT.value == "text"
assert Modality.MULTIMODAL.value == "multimodal"
assert Modality.VIDEO_DIFFUSION.value == "video_diffusion"
def test_is_diffusion_true_for_video_diffusion(self):
"""Test that VIDEO_DIFFUSION returns True for is_diffusion."""
assert Modality.is_diffusion(Modality.VIDEO_DIFFUSION) is True
def test_is_diffusion_false_for_text(self):
"""Test that TEXT returns False for is_diffusion."""
assert Modality.is_diffusion(Modality.TEXT) is False
def test_is_diffusion_false_for_multimodal(self):
"""Test that MULTIMODAL returns False for is_diffusion."""
assert Modality.is_diffusion(Modality.MULTIMODAL) is False
def test_is_llm_true_for_text(self):
"""Test that TEXT returns True for is_llm."""
assert Modality.is_llm(Modality.TEXT) is True
def test_is_llm_true_for_multimodal(self):
"""Test that MULTIMODAL returns True for is_llm."""
assert Modality.is_llm(Modality.MULTIMODAL) is True
def test_is_llm_false_for_video_diffusion(self):
"""Test that VIDEO_DIFFUSION returns False for is_llm."""
assert Modality.is_llm(Modality.VIDEO_DIFFUSION) is False
# =============================================================================
# Part 2: DiffusionConfig Tests
# =============================================================================
class TestDiffusionConfig:
"""Tests for DiffusionConfig dataclass."""
def test_default_values(self):
"""Test that default values are set correctly."""
config = DiffusionConfig()
# Dynamo runtime defaults
assert config.namespace == "dynamo" # May be overridden by env var
assert config.component == "diffusion"
assert config.endpoint == "generate"
# Generation defaults
assert config.default_height == 480
assert config.default_width == 832
assert config.default_num_frames == 81
assert config.default_num_inference_steps == 50
assert config.default_guidance_scale == 5.0
# Model defaults
assert config.output_dir == "/tmp/dynamo_videos"
# Optimization defaults
assert config.enable_teacache is False
assert config.attn_type == "default"
assert config.linear_type == "default"
# Parallelism defaults
assert config.dit_dp_size == 1
assert config.dit_tp_size == 1
def test_custom_values(self):
"""Test that custom values override defaults."""
config = DiffusionConfig(
default_height=720,
default_width=1280,
default_num_frames=120,
enable_teacache=True,
dit_tp_size=2,
)
assert config.default_height == 720
assert config.default_width == 1280
assert config.default_num_frames == 120
assert config.enable_teacache is True
assert config.dit_tp_size == 2
def test_str_representation(self):
"""Test that __str__ includes key fields."""
config = DiffusionConfig(
model_path="test/model",
default_height=480,
)
str_repr = str(config)
assert "DiffusionConfig(" in str_repr
assert "model_path=test/model" in str_repr
assert "default_height=480" in str_repr
assert "dit_tp_size=" in str_repr
# =============================================================================
# Part 3: VideoGenerationHandler Helper Tests
# =============================================================================
class MockDiffusionConfig:
"""Mock config for testing handler helpers without full DiffusionConfig."""
default_width: int = 832
default_height: int = 480
default_num_frames: int = 81
default_fps: int = 24
default_seconds: int = 4
max_width: int = 4096
max_height: int = 4096
@dataclass
class MockVideoRequest:
"""Mock video request for testing _compute_num_frames."""
prompt: str = "test prompt"
model: str = "test-model"
num_frames: Optional[int] = None
seconds: Optional[int] = None
fps: Optional[int] = None
class TestVideHandlerParseSize:
"""Tests for VideoGenerationHandler._parse_size method.
We test the method logic by creating a minimal mock handler.
"""
def setup_method(self):
"""Set up mock handler for each test."""
# Import here to avoid issues if handler has complex imports
from dynamo.trtllm.request_handlers.video_diffusion.video_handler import (
VideoGenerationHandler,
)
# Create handler with mocked dependencies
self.handler = object.__new__(VideoGenerationHandler)
self.handler.config = MockDiffusionConfig()
def test_parse_size_valid(self):
"""Test valid 'WxH' string parsing."""
width, height = self.handler._parse_size("832x480")
assert width == 832
assert height == 480
def test_parse_size_different_dimensions(self):
"""Test parsing various dimension strings."""
assert self.handler._parse_size("1920x1080") == (1920, 1080)
assert self.handler._parse_size("640x360") == (640, 360)
assert self.handler._parse_size("1x1") == (1, 1)
def test_parse_size_none(self):
"""Test None returns defaults."""
width, height = self.handler._parse_size(None)
assert width == MockDiffusionConfig.default_width
assert height == MockDiffusionConfig.default_height
def test_parse_size_empty_string(self):
"""Test empty string returns defaults."""
width, height = self.handler._parse_size("")
assert width == MockDiffusionConfig.default_width
assert height == MockDiffusionConfig.default_height
def test_parse_size_invalid_format(self):
"""Test invalid format returns defaults with warning."""
# No 'x' separator
assert self.handler._parse_size("832480") == (832, 480)
# Only one number
assert self.handler._parse_size("832") == (832, 480)
# Non-numeric
assert self.handler._parse_size("widthxheight") == (832, 480)
# Trailing 'x'
assert self.handler._parse_size("832x") == (832, 480)
def test_parse_size_exceeds_max_width(self):
"""Test that width exceeding max_width raises ValueError."""
with pytest.raises(ValueError) as exc_info:
self.handler._parse_size("5000x480")
assert "width 5000 exceeds max_width 4096" in str(exc_info.value)
assert "safety check to prevent out-of-memory" in str(exc_info.value)
def test_parse_size_exceeds_max_height(self):
"""Test that height exceeding max_height raises ValueError."""
with pytest.raises(ValueError) as exc_info:
self.handler._parse_size("832x5000")
assert "height 5000 exceeds max_height 4096" in str(exc_info.value)
def test_parse_size_exceeds_both_dimensions(self):
"""Test that both dimensions exceeding raises ValueError with both errors."""
with pytest.raises(ValueError) as exc_info:
self.handler._parse_size("10000x10000")
error_msg = str(exc_info.value)
assert "width 10000 exceeds max_width 4096" in error_msg
assert "height 10000 exceeds max_height 4096" in error_msg
def test_parse_size_at_max_boundary(self):
"""Test that dimensions exactly at max are allowed."""
# Should not raise - exactly at limit
width, height = self.handler._parse_size("4096x4096")
assert width == 4096
assert height == 4096
class TestVideoHandlerComputeNumFrames:
"""Tests for VideoGenerationHandler._compute_num_frames method."""
def setup_method(self):
"""Set up mock handler for each test."""
from dynamo.trtllm.request_handlers.video_diffusion.video_handler import (
VideoGenerationHandler,
)
self.handler = object.__new__(VideoGenerationHandler)
self.handler.config = MockDiffusionConfig()
def test_compute_num_frames_explicit(self):
"""Test that explicit num_frames takes priority."""
req = NvCreateVideoRequest(
prompt="test",
model="test-model",
num_frames=100,
seconds=10, # Should be ignored
fps=30, # Should be ignored
)
assert self.handler._compute_num_frames(req) == 100
def test_compute_num_frames_from_seconds_fps(self):
"""Test computation from seconds * fps."""
req = NvCreateVideoRequest(
prompt="test",
model="test-model",
seconds=4,
fps=24,
)
assert self.handler._compute_num_frames(req) == 96 # 4 * 24
def test_compute_num_frames_only_seconds(self):
"""Test seconds with default fps (24)."""
req = NvCreateVideoRequest(
prompt="test",
model="test-model",
seconds=5,
)
# seconds=5, default fps=24 -> 5 * 24 = 120
assert self.handler._compute_num_frames(req) == 120
def test_compute_num_frames_only_fps(self):
"""Test fps with default seconds (4)."""
req = NvCreateVideoRequest(
prompt="test",
model="test-model",
fps=30,
)
# default seconds=4, fps=30 -> 4 * 30 = 120
assert self.handler._compute_num_frames(req) == 120
def test_compute_num_frames_defaults(self):
"""Test all None uses config default."""
req = NvCreateVideoRequest(
prompt="test",
model="test-model",
)
assert (
self.handler._compute_num_frames(req)
== MockDiffusionConfig.default_num_frames
)
# =============================================================================
# Part 4: Video Protocol Tests
# =============================================================================
class TestNvCreateVideoRequest:
"""Tests for NvCreateVideoRequest protocol type."""
def test_required_fields(self):
"""Test that prompt and model are required."""
req = NvCreateVideoRequest(prompt="A cat", model="wan_t2v")
assert req.prompt == "A cat"
assert req.model == "wan_t2v"
def test_required_fields_missing_prompt(self):
"""Test that missing prompt raises validation error."""
with pytest.raises(Exception): # Pydantic ValidationError
NvCreateVideoRequest(model="wan_t2v") # type: ignore
def test_required_fields_missing_model(self):
"""Test that missing model raises validation error."""
with pytest.raises(Exception): # Pydantic ValidationError
NvCreateVideoRequest(prompt="A cat") # type: ignore
def test_optional_fields_default_none(self):
"""Test that optional fields default to None."""
req = NvCreateVideoRequest(prompt="A cat", model="wan_t2v")
assert req.size is None
assert req.seconds is None
assert req.fps is None
assert req.num_frames is None
assert req.num_inference_steps is None
assert req.guidance_scale is None
assert req.negative_prompt is None
assert req.seed is None
assert req.response_format is None
def test_full_request_valid(self):
"""Test a fully populated request."""
req = NvCreateVideoRequest(
prompt="A majestic lion",
model="wan_t2v",
size="1920x1080",
seconds=5,
fps=30,
num_frames=150,
num_inference_steps=30,
guidance_scale=7.5,
negative_prompt="blurry, low quality",
seed=42,
response_format="b64_json",
)
assert req.prompt == "A majestic lion"
assert req.model == "wan_t2v"
assert req.size == "1920x1080"
assert req.seconds == 5
assert req.fps == 30
assert req.num_frames == 150
assert req.num_inference_steps == 30
assert req.guidance_scale == 7.5
assert req.negative_prompt == "blurry, low quality"
assert req.seed == 42
assert req.response_format == "b64_json"
class TestVideoData:
"""Tests for VideoData protocol type."""
def test_url_only(self):
"""Test VideoData with URL only."""
data = VideoData(url="/tmp/video.mp4")
assert data.url == "/tmp/video.mp4"
assert data.b64_json is None
def test_b64_only(self):
"""Test VideoData with base64 only."""
data = VideoData(b64_json="SGVsbG8gV29ybGQ=")
assert data.url is None
assert data.b64_json == "SGVsbG8gV29ybGQ="
def test_both_fields(self):
"""Test VideoData with both fields (unusual but valid)."""
data = VideoData(url="/tmp/video.mp4", b64_json="SGVsbG8=")
assert data.url == "/tmp/video.mp4"
assert data.b64_json == "SGVsbG8="
def test_empty_defaults(self):
"""Test VideoData with no arguments."""
data = VideoData()
assert data.url is None
assert data.b64_json is None
class TestNvVideosResponse:
"""Tests for NvVideosResponse protocol type."""
def test_default_values(self):
"""Test default values for completed response."""
response = NvVideosResponse(
id="req-123",
model="wan_t2v",
created=1234567890,
)
assert response.id == "req-123"
assert response.object == "video"
assert response.model == "wan_t2v"
assert response.status == "completed"
assert response.progress == 100
assert response.created == 1234567890
assert response.data == []
assert response.error is None
def test_error_response(self):
"""Test error response structure."""
response = NvVideosResponse(
id="req-456",
model="wan_t2v",
created=1234567890,
status="failed",
progress=0,
error="Model failed to load",
)
assert response.status == "failed"
assert response.progress == 0
assert response.error == "Model failed to load"
def test_with_video_data(self):
"""Test response with video data."""
video = VideoData(url="/tmp/output.mp4")
response = NvVideosResponse(
id="req-789",
model="wan_t2v",
created=1234567890,
data=[video],
inference_time_s=42.5,
)
assert len(response.data) == 1
assert response.data[0].url == "/tmp/output.mp4"
assert response.inference_time_s == 42.5
def test_model_dump(self):
"""Test serialization with model_dump()."""
response = NvVideosResponse(
id="req-123",
model="wan_t2v",
created=1234567890,
data=[VideoData(url="/tmp/video.mp4")],
)
dumped = response.model_dump()
assert isinstance(dumped, dict)
assert dumped["id"] == "req-123"
assert dumped["object"] == "video"
assert dumped["model"] == "wan_t2v"
assert dumped["status"] == "completed"
assert len(dumped["data"]) == 1
assert dumped["data"][0]["url"] == "/tmp/video.mp4"
# =============================================================================
# Part 5: Concurrency Safety Tests
# =============================================================================
class ConcurrencyTracker:
"""Mock replacement for ``DiffusionEngine.generate()`` that records
the peak number of threads executing it simultaneously.
What it mocks:
``engine.generate(**kwargs)`` — the blocking GPU call inside
``VideoGenerationHandler``. The handler dispatches this via
``asyncio.to_thread()``, so each request runs ``generate()``
in a separate OS thread.
What it focuses on:
Detecting *concurrent* entry into ``generate()``. It does NOT
test correctness of generated frames, GPU memory, or CUDA
streams — only whether multiple threads overlap inside the call.
How it works:
1. On entry: atomically increment ``_active_count`` and update
the high-water mark ``max_concurrent``.
2. Sleep for ``sleep_seconds`` to hold the thread inside the
function, creating a window where other threads *would*
overlap if nothing serializes them.
3. On exit: atomically decrement ``_active_count``.
After the test, inspect ``max_concurrent``:
- 1 → accesses were serialized (lock is working).
- >1 → concurrent access occurred (lock is missing/broken).
"""
def __init__(self, sleep_seconds: float = 0.1):
self._active_count = 0
self._lock = threading.Lock()
self.max_concurrent = 0
self.sleep_seconds = sleep_seconds
def generate(self, **kwargs):
"""Mock engine.generate() that tracks concurrent access."""
with self._lock:
self._active_count += 1
if self._active_count > self.max_concurrent:
self.max_concurrent = self._active_count
# Hold the thread here to widen the overlap window. Without
# serialization, other threads will enter generate() during
# this sleep and bump _active_count above 1.
time.sleep(self.sleep_seconds)
with self._lock:
self._active_count -= 1
# Return fake frames (shape: [num_frames, H, W, C])
import numpy as np
return np.zeros((4, 64, 64, 3), dtype=np.uint8)
class TestVideoHandlerConcurrency:
"""Verifies that ``VideoGenerationHandler`` serializes access to the
underlying ``engine.generate()`` call.
Why this matters:
The visual_gen pipeline is a global singleton with mutable state,
unprotected CUDA graph caches, and shared config objects. It is
NOT thread-safe. ``VideoGenerationHandler`` dispatches generate()
via ``asyncio.to_thread()``, which runs each request in a
separate OS thread. Without an ``asyncio.Lock`` guarding the
call, concurrent requests would enter generate() simultaneously
and corrupt shared pipeline state.
How the test works:
1. Wires a ``ConcurrencyTracker`` as the mock engine so that
each generate() call sleeps long enough for overlapping
threads to be observable.
2. Fires N requests concurrently with ``asyncio.gather()``,
each of which calls ``handler.generate()`` → ``asyncio.to_thread()``
→ ``tracker.generate()``.
3. Asserts ``tracker.max_concurrent == 1``: only one thread was
inside generate() at any point.
Why it works:
- ``asyncio.gather()`` schedules all coroutines on the same
event loop, so they all reach ``asyncio.to_thread()``
nearly simultaneously.
- Without the handler's ``asyncio.Lock``, each coroutine
immediately spawns a thread, and those threads overlap
inside ``tracker.generate()`` during the sleep window →
``max_concurrent > 1``.
- With the lock, only one coroutine enters the
``async with self._generate_lock`` block at a time; the
others suspend cooperatively on the event loop. So only
one thread is ever inside generate() → ``max_concurrent == 1``.
"""
def _make_handler(self):
"""Create a VideoGenerationHandler with mock engine and config."""
from dynamo.trtllm.request_handlers.video_diffusion.video_handler import (
VideoGenerationHandler,
)
tracker = ConcurrencyTracker(sleep_seconds=0.1)
mock_engine = MagicMock()
mock_engine.generate = tracker.generate
config = DiffusionConfig(
output_dir="/tmp/test_videos",
default_fps=24,
default_seconds=4,
)
handler = VideoGenerationHandler(
component=MagicMock(),
engine=mock_engine,
config=config,
)
return handler, tracker
def _make_request(self):
"""Create a minimal valid video generation request dict."""
return {
"prompt": "a test video",
"model": "test-model",
}
async def _drain_generator(self, handler, request):
"""Run handler.generate() and drain the async generator."""
async for _ in handler.generate(request, MagicMock()):
pass
def test_concurrent_requests_are_serialized(self):
"""Fires 3 concurrent requests and asserts only one thread enters
engine.generate() at a time (max_concurrent == 1).
If the asyncio.Lock in VideoGenerationHandler is removed, the 3
asyncio.to_thread() calls run in parallel OS threads, overlapping
inside the tracker's sleep window, and max_concurrent rises to 3.
"""
async def run():
handler, tracker = self._make_handler()
requests = [self._make_request() for _ in range(3)]
with patch(
"dynamo.trtllm.request_handlers.video_diffusion.video_handler.encode_to_mp4",
return_value="/tmp/test.mp4",
):
await asyncio.gather(
*(self._drain_generator(handler, req) for req in requests)
)
return tracker
tracker = asyncio.run(run())
assert tracker.max_concurrent == 1, (
f"Expected max_concurrent=1 (serialized), got {tracker.max_concurrent}. "
"Pipeline was accessed concurrently — this would corrupt visual_gen state."
)
......@@ -11,7 +11,7 @@ from dynamo._core import get_reasoning_parser_names, get_tool_parser_names
from dynamo.common.config_dump import add_config_dump_args, register_encoder
from dynamo.common.utils.runtime import parse_endpoint
from dynamo.trtllm import __version__
from dynamo.trtllm.request_handlers.handler_base import DisaggregationMode
from dynamo.trtllm.constants import DisaggregationMode, Modality
DYN_NAMESPACE = os.environ.get("DYN_NAMESPACE", "dynamo")
......@@ -23,7 +23,11 @@ DEFAULT_PREFILL_ENDPOINT = f"dyn://{DYN_NAMESPACE}.prefill.generate" # Prefill
DEFAULT_ENCODE_ENDPOINT = (
f"dyn://{DYN_NAMESPACE}.tensorrt_llm_encode.generate" # Encode workers
)
DEFAULT_DIFFUSION_ENDPOINT = (
f"dyn://{DYN_NAMESPACE}.diffusion.generate" # Diffusion workers
)
DEFAULT_MODEL_PATH = "TinyLlama/TinyLlama-1.1B-Chat-v1.0"
DEFAULT_VIDEO_MODEL_PATH = "Wan-AI/Wan2.1-T2V-1.3B-Diffusers"
DEFAULT_DISAGGREGATION_MODE = DisaggregationMode.AGGREGATED
......@@ -68,6 +72,27 @@ class Config:
# Whether to enable NATS for KV events (derived from publish_events_and_metrics)
self.use_kv_events: bool = False
# Diffusion-specific config (only used when modality is video_diffusion or image_diffusion)
self.output_dir: str = "/tmp/dynamo_videos"
self.default_height: int = 480
self.default_width: int = 832
self.default_num_frames: int = 81
self.default_num_inference_steps: int = 50
self.default_guidance_scale: float = 5.0
self.enable_teacache: bool = False
self.teacache_thresh: float = 0.2
self.attn_type: str = "default"
self.linear_type: str = "default"
self.disable_torch_compile: bool = False
self.torch_compile_mode: str = "default"
self.dit_dp_size: int = 1
self.dit_tp_size: int = 1
self.dit_ulysses_size: int = 1
self.dit_ring_size: int = 1
self.dit_cfg_size: int = 1
self.dit_fsdp_size: int = 1
self.enable_async_cpu_offload: bool = False
def __str__(self) -> str:
return (
f"Config(namespace={self.namespace}, "
......@@ -103,7 +128,10 @@ class Config:
f"request_plane={self.request_plane}, "
f"event_plane={self.event_plane}, "
f"enable_local_indexer={self.enable_local_indexer}, "
f"use_kv_events={self.use_kv_events}"
f"use_kv_events={self.use_kv_events}, "
f"output_dir={self.output_dir}, "
f"dit_dp_size={self.dit_dp_size}, "
f"dit_tp_size={self.dit_tp_size})"
)
......@@ -243,8 +271,9 @@ def cmd_line_args():
"--modality",
type=str,
default="text",
choices=["text", "multimodal"],
help="Modality to use for the model. Default: text. Current supported modalities are image.",
choices=[m.value for m in Modality],
help="Modality to use for the model. Default: text. "
"Options: text (LLM), multimodal (VLM), video_diffusion.",
)
parser.add_argument(
"--encode-endpoint",
......@@ -333,6 +362,131 @@ def cmd_line_args():
help="Enable durable KV events using NATS JetStream instead of the local indexer. By default, local indexer is enabled for lower latency. Use this flag when you need durability and multi-replica router consistency. Requires NATS with JetStream enabled. Can also be set via DYN_DURABLE_KV_EVENTS=true env var.",
)
# Diffusion-specific options (only used when modality is video_diffusion or image_diffusion)
diffusion_group = parser.add_argument_group(
"Diffusion Options [Experimental]",
"Options for video_diffusion modality",
)
diffusion_group.add_argument(
"--output-dir",
type=str,
default="/tmp/dynamo_videos",
help="Directory to store generated videos/images. Default: /tmp/dynamo_videos",
)
diffusion_group.add_argument(
"--default-height",
type=int,
default=480,
help="Default video/image height in pixels. Default: 480",
)
diffusion_group.add_argument(
"--default-width",
type=int,
default=832,
help="Default video/image width in pixels. Default: 832",
)
diffusion_group.add_argument(
"--default-num-frames",
type=int,
default=81,
help="Default number of frames for video generation. Default: 81",
)
diffusion_group.add_argument(
"--default-num-inference-steps",
type=int,
default=50,
help="Default number of inference steps. Default: 50",
)
diffusion_group.add_argument(
"--default-guidance-scale",
type=float,
default=5.0,
help="Default CFG guidance scale. Default: 5.0",
)
diffusion_group.add_argument(
"--enable-teacache",
action="store_true",
help="Enable TeaCache optimization for faster generation.",
)
diffusion_group.add_argument(
"--teacache-thresh",
type=float,
default=0.2,
help="TeaCache threshold. Default: 0.2",
)
diffusion_group.add_argument(
"--attn-type",
type=str,
default="default",
choices=["default", "sage-attn", "sparse-videogen", "sparse-videogen2"],
help="Attention type for diffusion models. Default: default",
)
diffusion_group.add_argument(
"--linear-type",
type=str,
default="default",
choices=[
"default",
"trtllm-fp8-blockwise",
"trtllm-fp8-per-tensor",
"trtllm-nvfp4",
],
help="Linear type for quantization. Default: default",
)
diffusion_group.add_argument(
"--disable-torch-compile",
action="store_true",
help="Disable torch.compile optimization.",
)
diffusion_group.add_argument(
"--torch-compile-mode",
type=str,
default="default",
choices=["default", "reduce-overhead", "max-autotune"],
help="torch.compile mode. Default: default",
)
diffusion_group.add_argument(
"--dit-dp-size",
type=int,
default=1,
help="Data parallel size for DiT. Default: 1",
)
diffusion_group.add_argument(
"--dit-tp-size",
type=int,
default=1,
help="Tensor parallel size for DiT. Default: 1",
)
diffusion_group.add_argument(
"--dit-ulysses-size",
type=int,
default=1,
help="Ulysses parallel size for DiT. Default: 1",
)
diffusion_group.add_argument(
"--dit-ring-size",
type=int,
default=1,
help="Ring parallel size for DiT. Default: 1",
)
diffusion_group.add_argument(
"--dit-cfg-size",
type=int,
default=1,
help="CFG parallel size for DiT. Default: 1",
)
diffusion_group.add_argument(
"--dit-fsdp-size",
type=int,
default=1,
help="FSDP size for DiT. Default: 1",
)
diffusion_group.add_argument(
"--enable-async-cpu-offload",
action="store_true",
help="Enable async CPU offload for memory efficiency.",
)
args = parser.parse_args()
config = Config()
......@@ -344,12 +498,17 @@ def cmd_line_args():
# This becomes an `Option` on the Rust side
config.served_model_name = None
# Set modality
config.modality = args.modality
# Set the disaggregation mode.
config.disaggregation_mode = DisaggregationMode(args.disaggregation_mode)
# Set the appropriate default for the endpoint based on disaggregation mode
# Set the appropriate default for the endpoint based on modality and disaggregation mode
if args.endpoint == "":
if config.disaggregation_mode == DisaggregationMode.ENCODE:
if Modality(args.modality) == Modality.VIDEO_DIFFUSION:
args.endpoint = DEFAULT_DIFFUSION_ENDPOINT
elif config.disaggregation_mode == DisaggregationMode.ENCODE:
args.endpoint = DEFAULT_ENCODE_ENDPOINT
elif config.disaggregation_mode == DisaggregationMode.PREFILL:
args.endpoint = DEFAULT_PREFILL_ENDPOINT
......@@ -387,7 +546,6 @@ def cmd_line_args():
config.extra_engine_args = args.extra_engine_args
config.override_engine_args = args.override_engine_args
config.publish_events_and_metrics = args.publish_events_and_metrics
config.modality = args.modality
config.reasoning_parser = args.dyn_reasoning_parser
config.tool_call_parser = args.dyn_tool_call_parser
......@@ -415,6 +573,27 @@ def cmd_line_args():
else:
config.custom_jinja_template = None
# Copy diffusion-specific args (only relevant for video_diffusion/image_diffusion)
config.output_dir = args.output_dir
config.default_height = args.default_height
config.default_width = args.default_width
config.default_num_frames = args.default_num_frames
config.default_num_inference_steps = args.default_num_inference_steps
config.default_guidance_scale = args.default_guidance_scale
config.enable_teacache = args.enable_teacache
config.teacache_thresh = args.teacache_thresh
config.attn_type = args.attn_type
config.linear_type = args.linear_type
config.disable_torch_compile = args.disable_torch_compile
config.torch_compile_mode = args.torch_compile_mode
config.dit_dp_size = args.dit_dp_size
config.dit_tp_size = args.dit_tp_size
config.dit_ulysses_size = args.dit_ulysses_size
config.dit_ring_size = args.dit_ring_size
config.dit_cfg_size = args.dit_cfg_size
config.dit_fsdp_size = args.dit_fsdp_size
config.enable_async_cpu_offload = args.enable_async_cpu_offload
return config
......
# SPDX-FileCopyrightText: Copyright (c) 2025-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
"""Worker initialization modules for TensorRT-LLM backend.
This package contains worker initialization functions for different modalities:
- video_diffusion_worker: Video generation using diffusion models
"""
from dynamo.trtllm.workers.video_diffusion_worker import init_video_diffusion_worker
__all__ = ["init_video_diffusion_worker"]
# SPDX-FileCopyrightText: Copyright (c) 2025-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
"""Video diffusion worker initialization for TensorRT-LLM backend.
This module handles the initialization and lifecycle of video generation
workers using diffusion models (Wan, Flux, Cosmos, etc.).
"""
import asyncio
import logging
from dynamo.llm import ModelInput, ModelType, register_llm
from dynamo.runtime import DistributedRuntime
from dynamo.trtllm.utils.trtllm_utils import Config
async def init_video_diffusion_worker(
runtime: DistributedRuntime, config: Config, shutdown_event: asyncio.Event
) -> None:
"""Initialize and run the video diffusion worker.
This function handles video_diffusion modality, loading the appropriate
diffusion model and serving video generation requests.
Args:
runtime: The Dynamo distributed runtime.
config: Configuration parsed from command line.
shutdown_event: Event to signal shutdown.
"""
# Import diffusion-specific modules (lazy import to avoid loading heavy deps early)
from dynamo.trtllm.configs.diffusion_config import DiffusionConfig
from dynamo.trtllm.engines.diffusion_engine import DiffusionEngine
from dynamo.trtllm.request_handlers.video_diffusion import VideoGenerationHandler
logging.info(f"Initializing video diffusion worker with config: {config}")
# Build DiffusionConfig from the main Config
diffusion_config = DiffusionConfig(
namespace=config.namespace,
component=config.component,
endpoint=config.endpoint,
store_kv=config.store_kv,
request_plane=config.request_plane,
event_plane=config.event_plane,
model_path=config.model_path,
served_model_name=config.served_model_name,
output_dir=config.output_dir,
default_height=config.default_height,
default_width=config.default_width,
default_num_frames=config.default_num_frames,
default_num_inference_steps=config.default_num_inference_steps,
default_guidance_scale=config.default_guidance_scale,
enable_teacache=config.enable_teacache,
teacache_thresh=config.teacache_thresh,
attn_type=config.attn_type,
linear_type=config.linear_type,
disable_torch_compile=config.disable_torch_compile,
torch_compile_mode=config.torch_compile_mode,
dit_dp_size=config.dit_dp_size,
dit_tp_size=config.dit_tp_size,
dit_ulysses_size=config.dit_ulysses_size,
dit_ring_size=config.dit_ring_size,
dit_cfg_size=config.dit_cfg_size,
dit_fsdp_size=config.dit_fsdp_size,
enable_async_cpu_offload=config.enable_async_cpu_offload,
)
# Get the component and endpoint from the runtime
component = runtime.namespace(config.namespace).component(config.component)
endpoint = component.endpoint(config.endpoint)
# Initialize the diffusion engine (auto-detects pipeline from model_index.json)
engine = DiffusionEngine(diffusion_config)
await engine.initialize()
# Create the request handler
handler = VideoGenerationHandler(component, engine, diffusion_config)
# Register the model with Dynamo's discovery system
model_name = config.served_model_name or config.model_path
# Use ModelType.Videos for video generation
if not hasattr(ModelType, "Videos"):
raise RuntimeError(
"ModelType.Videos not available in dynamo-runtime. "
"Video diffusion requires a compatible dynamo-runtime version. "
"See docs/backends/trtllm/README.md for setup instructions."
)
model_type = ModelType.Videos
logging.info(f"Registering model '{model_name}' with ModelType={model_type}")
# register_llm is a misnomer — it's actually Dynamo's generic model
# registration function and the video diffisuion model is not an llm
await register_llm(
ModelInput.Text,
model_type,
endpoint,
config.model_path,
model_name,
)
logging.info(f"Model registered, serving endpoint: {config.endpoint}")
# Serve the endpoint
try:
await endpoint.serve_endpoint(
handler.generate,
graceful_shutdown=True,
)
except asyncio.CancelledError:
logging.info("Endpoint serving cancelled")
except Exception as e:
logging.error(f"Error serving endpoint: {e}", exc_info=True)
raise
finally:
handler.cleanup()
engine.cleanup()
......@@ -42,6 +42,7 @@ git checkout $(git describe --tags $(git rev-list --tags --max-count=1))
- [Client](#client)
- [Benchmarking](#benchmarking)
- [Multimodal Support](#multimodal-support)
- [Video Diffusion Support](#video-diffusion-support-experimental)
- [Logits Processing](#logits-processing)
- [DP Rank Routing](#dp-rank-routing-attention-data-parallelism)
- [Performance Sweep](#performance-sweep)
......@@ -220,6 +221,70 @@ To benchmark your deployment with AIPerf, see this utility script, configuring t
Dynamo with the TensorRT-LLM backend supports multimodal models, enabling you to process both text and images (or pre-computed embeddings) in a single request. For detailed setup instructions, example requests, and best practices, see the [TensorRT-LLM Multimodal Guide](../../features/multimodal/multimodal_trtllm.md).
## Video Diffusion Support (Experimental)
Dynamo supports video generation using diffusion models through the `--modality video_diffusion` flag.
### Requirements
- **visual_gen**: Part of TensorRT-LLM, located at `tensorrt_llm/visual_gen/`. Currently available **only** on the [`feat/visual_gen`](https://github.com/NVIDIA/TensorRT-LLM/tree/feat/visual_gen/tensorrt_llm/visual_gen) branch (not yet merged to main or any release). Install from source:
```bash
git clone https://github.com/NVIDIA/TensorRT-LLM.git
cd TensorRT-LLM && git checkout feat/visual_gen
cd tensorrt_llm/visual_gen && pip install -e .
```
- **dynamo-runtime with video API**: The Dynamo runtime must include `ModelType.Videos` support. Ensure you're using a compatible version.
### Supported Models
| Diffusers Pipeline | Description | Example Model |
|--------------------|-------------|---------------|
| `WanPipeline` | Wan 2.1/2.2 Text-to-Video | `Wan-AI/Wan2.1-T2V-1.3B-Diffusers` |
The pipeline type is **auto-detected** from the model's `model_index.json` — no `--model-type` flag is needed.
### Quick Start
```bash
python -m dynamo.trtllm \
--modality video_diffusion \
--model-path Wan-AI/Wan2.1-T2V-1.3B-Diffusers \
--output-dir /tmp/videos
```
### API Endpoint
Video generation uses the `/v1/videos/generations` endpoint:
```bash
curl -X POST http://localhost:8000/v1/videos/generations \
-H "Content-Type: application/json" \
-d '{
"prompt": "A cat playing piano",
"model": "wan_t2v",
"size": "832x480",
"seconds": 4,
"fps": 24
}'
```
### Configuration Options
| Flag | Description | Default |
|------|-------------|---------|
| `--output-dir` | Directory for generated videos | `/tmp/dynamo_videos` |
| `--default-height` | Default video height | `480` |
| `--default-width` | Default video width | `832` |
| `--default-num-frames` | Default frame count | `81` |
| `--enable-teacache` | Enable TeaCache optimization | `False` |
| `--disable-torch-compile` | Disable torch.compile | `False` |
### Limitations
- Video diffusion is experimental and not recommended for production use
- Only text-to-video is supported in this release (image-to-video planned)
- Requires GPU with sufficient VRAM for the diffusion model
## Logits Processing
Logits processors let you modify the next-token logits at every decoding step (e.g., to apply custom constraints or sampling transforms). Dynamo provides a backend-agnostic interface and an adapter for TensorRT-LLM so you can plug in custom processors.
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment