Unverified commit 8de1c3e0, authored by Indrajit Bhosale, committed by GitHub
Browse files

feat: Update TRTLLM video diffusion for `visual_gen` module restructuring (PR #11462) (#6516)


Signed-off-by: Indrajit Bhosale <iamindrajitb@gmail.com>
parent 78436fbf
......@@ -242,6 +242,22 @@ class DynamoTrtllmArgGroup(ArgGroup):
arg_type=float,
help="Default CFG guidance scale.",
)
add_argument(
diffusion_group,
flag_name="--torch-dtype",
env_var="DYN_TRTLLM_TORCH_DTYPE",
default="bfloat16",
choices=["bfloat16", "float16", "float32"],
help="Torch dtype for model loading. bfloat16 recommended for Ampere+ GPUs.",
)
add_argument(
diffusion_group,
flag_name="--revision",
env_var="DYN_TRTLLM_REVISION",
default=None,
help="HuggingFace Hub revision (branch, tag, or commit SHA) for model download.",
)
add_negatable_bool_argument(
diffusion_group,
flag_name="--enable-teacache",
......@@ -249,6 +265,13 @@ class DynamoTrtllmArgGroup(ArgGroup):
default=False,
help="Enable TeaCache optimization for faster generation.",
)
add_negatable_bool_argument(
diffusion_group,
flag_name="--teacache-use-ret-steps",
env_var="DYN_TRTLLM_TEACACHE_USE_RET_STEPS",
default=True,
help="Use retention steps for TeaCache.",
)
add_argument(
diffusion_group,
flag_name="--teacache-thresh",
......@@ -259,24 +282,33 @@ class DynamoTrtllmArgGroup(ArgGroup):
)
add_argument(
diffusion_group,
flag_name="--attn-type",
env_var="DYN_TRTLLM_ATTN_TYPE",
default="default",
choices=["default", "sage-attn", "sparse-videogen", "sparse-videogen2"],
help="Attention type for diffusion models.",
flag_name="--attn-backend",
env_var="DYN_TRTLLM_ATTN_BACKEND",
default="VANILLA",
choices=["VANILLA", "TRTLLM"],
help="Attention backend for diffusion models. VANILLA = PyTorch SDPA, TRTLLM = TensorRT-LLM kernels.",
)
add_argument(
diffusion_group,
flag_name="--linear-type",
env_var="DYN_TRTLLM_LINEAR_TYPE",
default="default",
flag_name="--quant-algo",
env_var="DYN_TRTLLM_QUANT_ALGO",
default=None,
choices=[
"default",
"trtllm-fp8-blockwise",
"trtllm-fp8-per-tensor",
"trtllm-nvfp4",
"FP8",
"FP8_BLOCK_SCALES",
"NVFP4",
"W4A16_AWQ",
"W4A8_AWQ",
"W8A8_SQ_PER_CHANNEL",
],
help="Linear type for quantization.",
help="Quantization algorithm for diffusion models. BF16 weights are quantized on-the-fly during loading.",
)
add_negatable_bool_argument(
diffusion_group,
flag_name="--quant-dynamic",
env_var="DYN_TRTLLM_QUANT_DYNAMIC",
default=True,
help="Enable dynamic weight quantization (quantize BF16 weights on-the-fly during loading).",
)
add_negatable_bool_argument(
diffusion_group,
......@@ -293,6 +325,42 @@ class DynamoTrtllmArgGroup(ArgGroup):
choices=["default", "reduce-overhead", "max-autotune"],
help="torch.compile mode.",
)
add_negatable_bool_argument(
diffusion_group,
flag_name="--enable-fullgraph",
env_var="DYN_TRTLLM_ENABLE_FULLGRAPH",
default=False,
help="Enable torch.compile fullgraph mode (stricter but potentially faster).",
)
add_negatable_bool_argument(
diffusion_group,
flag_name="--fuse-qkv",
env_var="DYN_TRTLLM_FUSE_QKV",
default=True,
help="Enable QKV fusion for transformer attention layers.",
)
add_negatable_bool_argument(
diffusion_group,
flag_name="--enable-cuda-graph",
env_var="DYN_TRTLLM_ENABLE_CUDA_GRAPH",
default=False,
help="Enable CUDA graph capture for transformer forward passes. Mutually exclusive with torch.compile.",
)
add_negatable_bool_argument(
diffusion_group,
flag_name="--enable-layerwise-nvtx-marker",
env_var="DYN_TRTLLM_ENABLE_LAYERWISE_NVTX_MARKER",
default=False,
help="Enable per-layer NVTX markers for profiling with Nsight Systems.",
)
add_argument(
diffusion_group,
flag_name="--warmup-steps",
env_var="DYN_TRTLLM_WARMUP_STEPS",
default=1,
arg_type=int,
help="Number of denoising steps to run during warmup (0 to disable).",
)
add_argument(
diffusion_group,
flag_name="--dit-dp-size",
......@@ -348,6 +416,17 @@ class DynamoTrtllmArgGroup(ArgGroup):
default=False,
help="Enable async CPU offload for memory efficiency.",
)
add_argument(
diffusion_group,
flag_name="--skip-components",
env_var="DYN_TRTLLM_SKIP_COMPONENTS",
default="",
help=(
"Comma-separated list of pipeline components to skip loading. "
"Valid values: transformer, vae, text_encoder, tokenizer, scheduler, "
"image_encoder, image_processor."
),
)
class DynamoTrtllmConfig(ConfigBase):
......@@ -383,12 +462,21 @@ class DynamoTrtllmConfig(ConfigBase):
default_num_frames: int
default_num_inference_steps: int
default_guidance_scale: float
torch_dtype: str
revision: Optional[str] = None
enable_teacache: bool
teacache_use_ret_steps: bool
teacache_thresh: float
attn_type: str
linear_type: str
attn_backend: str
quant_algo: Optional[str]
quant_dynamic: bool
disable_torch_compile: bool
torch_compile_mode: str
enable_fullgraph: bool
fuse_qkv: bool
enable_cuda_graph: bool
enable_layerwise_nvtx_marker: bool
warmup_steps: int
dit_dp_size: int
dit_tp_size: int
dit_ulysses_size: int
......@@ -396,6 +484,7 @@ class DynamoTrtllmConfig(ConfigBase):
dit_cfg_size: int
dit_fsdp_size: int
enable_async_cpu_offload: bool
skip_components: str
def validate(self) -> None:
if isinstance(self.disaggregation_mode, str):
......
......@@ -5,9 +5,16 @@
This module defines the DiffusionConfig dataclass used for configuring
video and image diffusion workers.
Fields map to TensorRT-LLM's DiffusionArgs sub-configs:
- PipelineConfig: torch_compile, CUDA graph, warmup, offloading, fuse_qkv
- AttentionConfig: attention backend (VANILLA, TRTLLM)
- ParallelConfig: dit_*_size parallelism dimensions
- TeaCacheConfig: caching optimization
- QuantConfig: quantization algorithm and dynamic flags
"""
from dataclasses import dataclass
from dataclasses import dataclass, field
from typing import Optional
from dynamo.common.utils.namespace import get_worker_namespace
......@@ -23,7 +30,7 @@ class DiffusionConfig:
"""Configuration for diffusion model workers (video/image generation).
This configuration is used by DiffusionEngine and diffusion handlers.
It can be populated from command-line arguments in trtllm_utils.py.
It can be populated from command-line arguments in backend_args.py.
"""
# Dynamo runtime config
......@@ -41,6 +48,8 @@ class DiffusionConfig:
# bfloat16 is recommended for Ampere+ GPUs (A100, H100, etc.)
# float16 can be used on older GPUs (V100, etc.)
torch_dtype: str = "bfloat16"
# HuggingFace Hub revision (branch, tag, or commit SHA) for model download.
revision: Optional[str] = None
# Media storage
media_output_fs_url: str = "file:///tmp/dynamo_media"
......@@ -58,16 +67,39 @@ class DiffusionConfig:
default_num_inference_steps: int = 50
default_guidance_scale: float = 5.0
# visual_gen optimization config
# ── Pipeline optimization config (maps to PipelineConfig) ──
disable_torch_compile: bool = False
torch_compile_mode: str = "default"
# Enable torch.compile fullgraph mode (stricter but potentially faster)
enable_fullgraph: bool = False
# QKV fusion for transformer attention layers
fuse_qkv: bool = True
# CUDA graph capture for transformer forward passes
# (mutually exclusive with torch.compile — torch.compile takes priority)
enable_cuda_graph: bool = False
# Enable per-layer NVTX markers for profiling
enable_layerwise_nvtx_marker: bool = False
# Number of denoising steps to run during warmup (0 to disable)
warmup_steps: int = 1
# ── Attention config (maps to AttentionConfig) ──
# Attention backend: "VANILLA" (PyTorch SDPA) or "TRTLLM"
attn_backend: str = "VANILLA"
# ── Quantization config (maps to DiffusionArgs.quant_config) ──
# Quantization algorithm. Options:
# None (no quantization), "FP8", "FP8_BLOCK_SCALES", "NVFP4",
# "W4A16_AWQ", "W4A8_AWQ", "W8A8_SQ_PER_CHANNEL"
quant_algo: Optional[str] = None
# Enable dynamic weight quantization (quantize BF16 weights on-the-fly during loading)
quant_dynamic: bool = True
# ── TeaCache optimization config (maps to TeaCacheConfig) ──
enable_teacache: bool = False
teacache_use_ret_steps: bool = True
teacache_thresh: float = 0.2
attn_type: str = "default"
linear_type: str = "default"
disable_torch_compile: bool = False
torch_compile_mode: str = "default"
# Parallelism config (DiTParallelConfig)
# ── Parallelism config (maps to ParallelConfig) ──
dit_dp_size: int = 1
dit_tp_size: int = 1
dit_ulysses_size: int = 1
......@@ -75,9 +107,14 @@ class DiffusionConfig:
dit_cfg_size: int = 1
dit_fsdp_size: int = 1
# CPU offload config
# ── Offloading config (maps to PipelineConfig) ──
enable_async_cpu_offload: bool = False
visual_gen_block_cpu_offload_stride: int = 1
# ── Component loading options ──
# Components to skip loading (e.g., ["text_encoder", "vae"]).
# Valid values: "transformer", "vae", "text_encoder", "tokenizer",
# "scheduler", "image_encoder", "image_processor"
skip_components: list[str] = field(default_factory=list)
def __str__(self) -> str:
return (
......@@ -93,8 +130,10 @@ class DiffusionConfig:
f"default_num_frames={self.default_num_frames}, "
f"default_num_inference_steps={self.default_num_inference_steps}, "
f"enable_teacache={self.enable_teacache}, "
f"attn_type={self.attn_type}, "
f"linear_type={self.linear_type}, "
f"attn_backend={self.attn_backend}, "
f"quant_algo={self.quant_algo}, "
f"enable_cuda_graph={self.enable_cuda_graph}, "
f"warmup_steps={self.warmup_steps}, "
f"dit_dp_size={self.dit_dp_size}, "
f"dit_tp_size={self.dit_tp_size})"
)
......@@ -4,7 +4,7 @@
"""Engine modules for TensorRT-LLM backend.
This module provides engine wrappers for various generative models:
- DiffusionEngine: Generic wrapper for visual_gen diffusion pipelines
- DiffusionEngine: Generic wrapper for TensorRT-LLM visual_gen diffusion pipelines
"""
from dynamo.trtllm.engines.diffusion_engine import DiffusionEngine
......
# SPDX-FileCopyrightText: Copyright (c) 2025-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
"""Generic Diffusion Engine wrapper for visual_gen pipelines.
"""Generic Diffusion Engine wrapper for TensorRT-LLM visual_gen pipelines.
This module provides a unified interface for various diffusion models
(Wan, Flux, Cosmos, etc.) through a pipeline registry system.
(Wan, Flux, Cosmos, etc.) through TensorRT-LLM's AutoPipeline system.
The pipeline type is auto-detected from model_index.json (shipped with every
HuggingFace Diffusers model), eliminating the need for a --model-type flag.
Requirements:
- visual_gen: Part of TensorRT-LLM, located at tensorrt_llm/visual_gen/.
Currently on the feat/visual_gen branch (not yet merged to main).
See: https://github.com/NVIDIA/TensorRT-LLM/tree/feat/visual_gen/tensorrt_llm/visual_gen
- tensorrt_llm with visual_gen support (tensorrt_llm._torch.visual_gen).
See: https://github.com/NVIDIA/TensorRT-LLM
- See docs/pages/backends/trtllm/README.md for setup instructions.
Note on imports:
visual_gen is imported lazily in initialize() because:
tensorrt_llm._torch.visual_gen is imported lazily in initialize() because:
1. It's a heavy package that may not be installed in all environments
2. Importing at module load would fail if visual_gen is not available
2. Importing at module load would fail if tensorrt_llm is not available
3. This allows the module to be imported for type checking and validation
without requiring visual_gen to be installed
without requiring tensorrt_llm to be installed
"""
import importlib
import json
import logging
from dataclasses import dataclass
from pathlib import Path
from typing import TYPE_CHECKING, Any, Optional
import random
from enum import Enum
from typing import TYPE_CHECKING, Optional
import numpy as np
import torch
if TYPE_CHECKING:
from tensorrt_llm._torch.visual_gen import DiffusionArgs
from tensorrt_llm._torch.visual_gen.output import MediaOutput
from tensorrt_llm._torch.visual_gen.pipeline import BasePipeline
from dynamo.trtllm.configs.diffusion_config import DiffusionConfig
logger = logging.getLogger(__name__)
@dataclass
class PipelineInfo:
"""Auto-detected pipeline information from model_index.json."""
class DiffusionModality(str, Enum):
"""Output modality of a diffusion pipeline."""
VIDEO = "video_diffusion"
IMAGE = "image_diffusion"
# Explicit mapping from TRT-LLM pipeline class names to their output modality.
# This replaces brittle substring matching and must be updated when new
# pipelines are registered in TRT-LLM's PIPELINE_REGISTRY.
_PIPELINE_MODALITY_MAP: dict[str, DiffusionModality] = {
"WanPipeline": DiffusionModality.VIDEO,
"WanImageToVideoPipeline": DiffusionModality.VIDEO,
"FluxPipeline": DiffusionModality.IMAGE,
"LTX2Pipeline": DiffusionModality.VIDEO,
}
module_path: str
class_name: str
modalities: list[str]
config_overrides: dict[str, Any]
# Default when the pipeline is not yet loaded or the class name is unknown.
_DEFAULT_MODALITY = DiffusionModality.VIDEO
class DiffusionEngine:
"""Generic wrapper for visual_gen diffusion pipelines.
"""Generic wrapper for TensorRT-LLM visual_gen diffusion pipelines.
This engine provides:
- Auto-detection of pipeline class from model_index.json
- A registry mapping diffusers _class_name to visual_gen pipelines
- Lazy loading of pipeline modules
- Common interface for video/image generation
- Auto-detection of pipeline class from model_index.json via AutoPipeline
- Loading and initialization through PipelineLoader
- Common interface for video/image generation via pipeline.infer()
The old visual_gen standalone package (setup_configs + from_pretrained +
PIPELINE_REGISTRY) has been replaced by TensorRT-LLM's integrated
visual_gen module which uses:
- DiffusionArgs for configuration
- PipelineLoader for model loading (handles MetaInit, weight loading,
quantization, torch.compile, and warmup)
- AutoPipeline for pipeline type auto-detection
- MediaOutput for typed output (video/image/audio torch tensors)
Example:
>>> engine = DiffusionEngine(config)
>>> await engine.initialize()
>>> frames = engine.generate(prompt="A cat playing piano", ...)
>>> output = engine.generate(prompt="A cat playing piano", ...)
>>> output.video # torch.Tensor (num_frames, H, W, 3) uint8
"""
# Registry: diffusers _class_name -> (module_path, visual_gen_class, supported_modalities)
# The _class_name comes from model_index.json shipped with every HF Diffusers model.
# torch_compile_models is derived dynamically from transformer* keys in model_index.json.
#
# NOTE: This registry is initially focused on Wan text-to-video models.
# Follow-up PRs will extend support for other model families (Flux, Cosmos, etc.)
# which may require additional config fields in DiffusionConfig.
PIPELINE_REGISTRY: dict[str, tuple[str, str, list[str]]] = {
"WanPipeline": (
"visual_gen.pipelines.wan_pipeline",
"ditWanPipeline",
["video_diffusion"],
),
# TODO: Add support for WanImageToVideoPipeline, FluxPipeline, etc.
}
@classmethod
def detect_pipeline_info(cls, model_path: str) -> PipelineInfo:
"""Auto-detect pipeline class from model's model_index.json.
Reads model_index.json (local path or HuggingFace Hub) to determine:
- Which visual_gen pipeline class to use (via _class_name)
- Which transformer models to torch.compile (via transformer* keys)
Args:
model_path: Local path or HuggingFace model identifier.
Returns:
PipelineInfo with module_path, class_name, modalities, and config_overrides.
Raises:
ValueError: If _class_name is not in the registry.
FileNotFoundError: If model_index.json cannot be found locally or on HF Hub.
"""
# Try local path first
local_index = Path(model_path) / "model_index.json"
if local_index.exists():
with open(local_index) as f:
model_index = json.load(f)
else:
# Download from HuggingFace Hub
from huggingface_hub import hf_hub_download
index_path = hf_hub_download(model_path, "model_index.json")
with open(index_path) as f:
model_index = json.load(f)
class_name = model_index.get("_class_name")
if class_name not in cls.PIPELINE_REGISTRY:
supported = list(cls.PIPELINE_REGISTRY.keys())
raise ValueError(
f"Unsupported diffusion pipeline '{class_name}' from model '{model_path}'.\n"
f"Supported pipelines: {', '.join(supported)}\n"
f"Check that model_index.json has a supported _class_name."
)
module_path, vg_class, modalities = cls.PIPELINE_REGISTRY[class_name]
# Derive torch_compile_models from transformer* keys in model_index.json
transformer_keys = sorted(k for k in model_index if k.startswith("transformer"))
torch_compile_models = (
",".join(transformer_keys) if transformer_keys else "transformer"
)
config_overrides = {"torch_compile_models": torch_compile_models}
return PipelineInfo(module_path, vg_class, modalities, config_overrides)
def __init__(self, config: "DiffusionConfig"):
"""Initialize the engine with configuration.
Auto-detects the pipeline type from config.model_path's model_index.json.
Args:
config: Diffusion generation configuration.
Raises:
ValueError: If the model's pipeline type is not supported.
"""
info = self.detect_pipeline_info(config.model_path)
self.config = config
self._pipeline = None
self._pipeline: Optional["BasePipeline"] = None
self._initialized = False
self._module_path = info.module_path
self._class_name = info.class_name
self._supported_modalities = info.modalities
self._config_overrides = info.config_overrides
async def initialize(self) -> None:
"""Load and configure the diffusion pipeline.
"""Load and configure the diffusion pipeline via PipelineLoader.
This is called once at worker startup to load the model.
The specific pipeline class is determined by the auto-detected pipeline type.
PipelineLoader handles:
1. Loading config via DiffusionModelConfig.from_pretrained()
2. Creating pipeline via AutoPipeline.from_config() (auto-detects type)
3. Loading weights with optional on-the-fly quantization
4. Post-load hooks (TeaCache setup, etc.)
5. torch.compile (if enabled)
6. Warmup inference
"""
if self._initialized:
logger.warning("Engine already initialized, skipping")
return
logger.info(
f"Initializing DiffusionEngine: pipeline={self._class_name}, "
f"model_path={self.config.model_path}"
f"Initializing DiffusionEngine: model_path={self.config.model_path}"
)
# Import visual_gen setup
from visual_gen import setup_configs
# Import TensorRT-LLM visual_gen components
from tensorrt_llm._torch.visual_gen import PipelineLoader
# Build configuration dict based on model type
dit_configs = self._build_dit_configs()
logger.info(f"dit_configs: {dit_configs}")
# Build DiffusionArgs from DiffusionConfig
diffusion_args = self._build_diffusion_args()
logger.info(f"DiffusionArgs: {diffusion_args}")
# Setup global configuration (required before pipeline loading)
setup_configs(**dit_configs)
# Use PipelineLoader for the full loading flow:
# DiffusionArgs → DiffusionModelConfig → AutoPipeline → BasePipeline
loader = PipelineLoader(diffusion_args)
self._pipeline = loader.load()
# Dynamically import the pipeline class
logger.info(f"Importing pipeline from {self._module_path}.{self._class_name}")
module = importlib.import_module(self._module_path)
pipeline_class = getattr(module, self._class_name)
# Load the pipeline
# Convert torch_dtype string to actual torch dtype
dtype_map = {
"bfloat16": torch.bfloat16,
"float16": torch.float16,
"float32": torch.float32,
}
torch_dtype = dtype_map.get(self.config.torch_dtype, torch.bfloat16)
self._initialized = True
logger.info(
f"Loading pipeline from {self.config.model_path} with dtype={self.config.torch_dtype}"
f"DiffusionEngine initialization complete: "
f"{self._pipeline.__class__.__name__}"
)
self._pipeline = pipeline_class.from_pretrained(
self.config.model_path,
torch_dtype=torch_dtype,
**dit_configs,
)
# Move to target device
# NOTE: HuggingFace's from_pretrained() loads to CPU by default,
# so we must explicitly move to GPU for optimal performance.
if self.device == "cuda":
logger.info("Moving pipeline to GPU...")
self._pipeline.to(self.device)
logger.info("Pipeline moved to GPU successfully")
else:
logger.info("CPU offload enabled, pipeline stays on CPU")
self._initialized = True
logger.info(f"DiffusionEngine initialization complete: {self._class_name}")
def _build_diffusion_args(self) -> "DiffusionArgs":
"""Build DiffusionArgs from DiffusionConfig.
def _build_dit_configs(self) -> dict[str, Any]:
"""Build dit_configs dict from DiffusionConfig.
Maps dynamo's DiffusionConfig fields to TensorRT-LLM's DiffusionArgs
structure with its nested sub-configs (PipelineConfig, AttentionConfig,
ParallelConfig, TeaCacheConfig, quant_config).
Returns:
Configuration dictionary for visual_gen's setup_configs.
DiffusionArgs instance for PipelineLoader.
"""
# Get torch_compile_models from auto-detected config overrides
# Each pipeline in PIPELINE_REGISTRY specifies its required settings
torch_compile_models = self._config_overrides.get(
"torch_compile_models", "transformer"
from tensorrt_llm._torch.visual_gen import (
DiffusionArgs,
ParallelConfig,
PipelineConfig,
TeaCacheConfig,
)
return {
"pipeline": {
"enable_torch_compile": not self.config.disable_torch_compile,
"torch_compile_models": torch_compile_models,
"torch_compile_mode": self.config.torch_compile_mode,
"fuse_qkv": True,
},
"attn": {
"type": self.config.attn_type,
},
"linear": {
"type": self.config.linear_type,
"recipe": "dynamic",
},
"parallel": {
"disable_parallel_vae": False,
"parallel_vae_split_dim": "width",
"dit_dp_size": self.config.dit_dp_size,
"dit_tp_size": self.config.dit_tp_size,
"dit_ulysses_size": self.config.dit_ulysses_size,
"dit_ring_size": self.config.dit_ring_size,
"dit_cp_size": 1,
"dit_cfg_size": self.config.dit_cfg_size,
"dit_fsdp_size": self.config.dit_fsdp_size,
"t5_fsdp_size": 1,
},
"teacache": {
"enable_teacache": self.config.enable_teacache,
"use_ret_steps": self.config.teacache_use_ret_steps,
"teacache_thresh": self.config.teacache_thresh,
"ret_steps": 0,
"cutoff_steps": self.config.default_num_inference_steps,
},
from tensorrt_llm._torch.visual_gen.config import AttentionConfig
# Build quant_config dict if quantization is requested
# DiffusionArgs accepts a dict in ModelOpt format and parses it via model_validator
quant_config: dict | None = None
if self.config.quant_algo:
quant_config = {
"quant_algo": self.config.quant_algo,
"dynamic": self.config.quant_dynamic,
}
args_kwargs: dict = dict(
checkpoint_path=self.config.model_path,
device=self.device,
dtype=self.config.torch_dtype,
skip_components=self.config.skip_components,
pipeline=PipelineConfig(
enable_torch_compile=not self.config.disable_torch_compile,
torch_compile_mode=self.config.torch_compile_mode,
enable_fullgraph=self.config.enable_fullgraph,
fuse_qkv=self.config.fuse_qkv,
enable_cuda_graph=self.config.enable_cuda_graph,
enable_layerwise_nvtx_marker=self.config.enable_layerwise_nvtx_marker,
warmup_steps=self.config.warmup_steps,
enable_offloading=self.config.enable_async_cpu_offload,
),
attention=AttentionConfig(
backend=self.config.attn_backend.upper(),
),
parallel=ParallelConfig(
dit_dp_size=self.config.dit_dp_size,
dit_tp_size=self.config.dit_tp_size,
dit_ulysses_size=self.config.dit_ulysses_size,
dit_ring_size=self.config.dit_ring_size,
dit_cfg_size=self.config.dit_cfg_size,
dit_fsdp_size=self.config.dit_fsdp_size,
),
teacache=TeaCacheConfig(
enable_teacache=self.config.enable_teacache,
use_ret_steps=self.config.teacache_use_ret_steps,
teacache_thresh=self.config.teacache_thresh,
),
)
# Add optional fields
if self.config.revision:
args_kwargs["revision"] = self.config.revision
if quant_config is not None:
args_kwargs["quant_config"] = quant_config
return DiffusionArgs(**args_kwargs)
def generate(
self,
prompt: str,
......@@ -271,12 +210,15 @@ class DiffusionEngine:
num_inference_steps: int = 50,
guidance_scale: float = 5.0,
seed: Optional[int] = None,
) -> np.ndarray:
) -> "MediaOutput":
"""Generate video/image frames from text prompt.
This is a synchronous method that should be called from a thread pool
to avoid blocking the event loop.
The pipeline's infer() method handles the full generation flow:
prompt encoding, latent preparation, denoising loop, and VAE decoding.
Args:
prompt: Text description of the content to generate.
negative_prompt: Text to avoid in the generation.
......@@ -288,8 +230,10 @@ class DiffusionEngine:
seed: Random seed for reproducibility.
Returns:
numpy array of shape (num_frames, height, width, 3) with uint8 values
for video, or (height, width, 3) for images.
MediaOutput with model-specific fields populated:
- .video: torch.Tensor (num_frames, H, W, 3) uint8 for video models
- .image: torch.Tensor (H, W, 3) uint8 for image models
- .audio: torch.Tensor for audio (if supported by model)
Raises:
RuntimeError: If engine not initialized or generation fails.
......@@ -302,15 +246,14 @@ class DiffusionEngine:
f"size={width}x{height}, frames={num_frames}, steps={num_inference_steps}"
)
# Create generator for reproducibility
# Device must match pipeline device (CPU if offload enabled, CUDA otherwise)
generator = None
if seed is not None:
generator = torch.Generator(device=self.device).manual_seed(seed)
# Use TRT-LLM's DiffusionRequest dataclass so that all defaults
# (including pipeline-specific fields like max_sequence_length,
# guidance_scale_2, boundary_ratio) are owned by TRT-LLM rather
# than hardcoded here.
from tensorrt_llm._torch.visual_gen.executor import DiffusionRequest
# Run the pipeline
with torch.no_grad():
result = self._pipeline(
req = DiffusionRequest(
request_id=0,
prompt=prompt,
negative_prompt=negative_prompt,
height=height,
......@@ -318,25 +261,31 @@ class DiffusionEngine:
num_frames=num_frames,
num_inference_steps=num_inference_steps,
guidance_scale=guidance_scale,
generator=generator,
output_type="np", # Return numpy array
seed=seed if seed is not None else random.randint(0, 2**32 - 1),
)
# result.frames[0] is numpy array (num_frames, height, width, 3) uint8
frames = result.frames[0]
logger.info(f"Generated output with shape {frames.shape}")
# Run the pipeline — infer() wraps forward() with torch.no_grad()
output = self._pipeline.infer(req)
if output is not None:
if output.video is not None:
logger.info(f"Generated video output with shape {output.video.shape}")
elif output.image is not None:
logger.info(f"Generated image output with shape {output.image.shape}")
return frames
return output
def cleanup(self) -> None:
"""Cleanup resources."""
if self._pipeline is not None:
if hasattr(self._pipeline, "cleanup"):
self._pipeline.cleanup()
del self._pipeline
self._pipeline = None
self._initialized = False
if self.device == "cuda":
torch.cuda.empty_cache()
logger.info(f"DiffusionEngine cleanup complete: {self._class_name}")
logger.info("DiffusionEngine cleanup complete")
@property
def is_initialized(self) -> bool:
......@@ -345,8 +294,26 @@ class DiffusionEngine:
@property
def supported_modalities(self) -> list[str]:
"""Get the modalities supported by this engine's model type."""
return self._supported_modalities
"""Get the modalities supported by this engine's pipeline.
Uses an explicit mapping from pipeline class names to modalities
(see ``_PIPELINE_MODALITY_MAP``). The pipeline class is determined
at load time by AutoPipeline from model_index.json.
"""
if self._pipeline is None:
return [_DEFAULT_MODALITY.value]
class_name = self._pipeline.__class__.__name__
modality = _PIPELINE_MODALITY_MAP.get(class_name)
if modality is None:
logger.warning(
"Unknown pipeline class '%s' — defaulting to %s. "
"Please add it to _PIPELINE_MODALITY_MAP.",
class_name,
_DEFAULT_MODALITY.value,
)
modality = _DEFAULT_MODALITY
return [modality.value]
@property
def device(self) -> str:
......
......@@ -4,6 +4,8 @@
"""Video generation request handler for TensorRT-LLM backend.
This handler processes video generation requests using diffusion models.
It handles MediaOutput from TensorRT-LLM's visual_gen pipelines, which
can contain video, image, and/or audio tensors depending on the model.
"""
import asyncio
......@@ -32,9 +34,14 @@ logger = logging.getLogger(__name__)
class VideoGenerationHandler(BaseGenerativeHandler):
"""Handler for video generation requests.
This handler receives video generation requests, runs the diffusion
pipeline via DiffusionEngine, encodes the output to MP4, and returns
the video URL or base64-encoded data.
This handler receives generation requests, runs the diffusion pipeline
via DiffusionEngine, encodes the output to the appropriate media format,
and returns the media URL or base64-encoded data.
Supports MediaOutput with:
- video: torch.Tensor (num_frames, H, W, 3) uint8 → encoded as MP4
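- image: unsupported — image-only output raises RuntimeError (use an image handler instead)
- audio: unsupported — audio-only output raises RuntimeError; audio returned alongside video is ignored (future: mux into MP4)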
Inherits from BaseGenerativeHandler to share the common interface with
LLM handlers.
......@@ -59,8 +66,8 @@ class VideoGenerationHandler(BaseGenerativeHandler):
)
self.media_output_fs = get_fs(config.media_output_fs_url)
self.media_output_http_url = config.media_output_http_url
# Serialize pipeline access — visual_gen is not thread-safe (global
# singleton configs, mutable instance state, unprotected CUDA graph cache).
# Serialize pipeline access — the diffusion pipeline is not thread-safe
# (mutable instance state, unprotected CUDA graph cache).
# asyncio.Lock suspends waiting coroutines cooperatively so the event
# loop stays free for health checks and signal handling.
self._generate_lock = asyncio.Lock()
......@@ -159,21 +166,26 @@ class VideoGenerationHandler(BaseGenerativeHandler):
async def generate(
self, request: dict[str, Any], context: Context
) -> AsyncGenerator[dict[str, Any], None]:
"""Generate video from request.
"""Generate video/image from request.
This is the main entry point called by Dynamo's endpoint.serve_endpoint().
Handles MediaOutput from the pipeline:
- video tensor → MP4
- image tensor → unsupported (raises error)
- audio tensor → unsupported (raises error)
Args:
request: Request dictionary with video generation parameters.
request: Request dictionary with generation parameters.
context: Dynamo context for request tracking.
Yields:
Response dictionary with generated video data.
Response dictionary with generated media data.
"""
start_time = time.time()
request_id = str(uuid.uuid4())
logger.info(f"Received video generation request: {request_id}")
logger.info(f"Received generation request: {request_id}")
try:
# Parse request
......@@ -202,11 +214,11 @@ class VideoGenerationHandler(BaseGenerativeHandler):
# Run generation in thread pool (blocking operation).
# Lock ensures only one request uses the pipeline at a time.
# TODO: Add cancellation support. This requires:
# 1. visual_gen to expose a cancellation hook in the denoising loop
# 1. The pipeline to expose a cancellation hook in the denoising loop
# 2. Passing a cancellation token/event to engine.generate()
# 3. Checking context.cancelled() and propagating to the pipeline
async with self._generate_lock:
frames = await asyncio.to_thread(
output = await asyncio.to_thread(
self.engine.generate,
prompt=req.prompt,
negative_prompt=nvext.negative_prompt,
......@@ -218,15 +230,47 @@ class VideoGenerationHandler(BaseGenerativeHandler):
seed=nvext.seed,
)
if output is None:
raise RuntimeError("Pipeline returned no output (MediaOutput is None)")
# Determine output format
response_format = req.response_format or "url"
fps = nvext.fps or self.config.default_fps
# Encode frames to MP4 bytes in memory
video_bytes = await asyncio.to_thread(encode_to_mp4_bytes, frames, fps=fps)
# Encode media based on what the pipeline returned
if output.video is not None:
# Video output: torch.Tensor (num_frames, H, W, 3) uint8 → MP4
frames_np = output.video.cpu().numpy()
logger.info(
f"Request {request_id}: encoding video output "
f"(shape={frames_np.shape}) to MP4 at {fps} fps"
)
video_bytes = await asyncio.to_thread(
encode_to_mp4_bytes, frames_np, fps=fps
)
elif output.image is not None:
raise RuntimeError(
"Pipeline returned image-only output, but this handler "
"only supports video. Use an image generation handler instead."
)
# Audio-only output is not supported by this handler
elif output.audio is not None:
raise RuntimeError(
"Pipeline returned audio-only output, but this handler "
"only supports video. Use an audio generation handler instead."
)
else:
raise RuntimeError(
"Pipeline returned MediaOutput with no video or image or audio data. "
f"MediaOutput fields: video={output.video is not None}, "
f"image={output.image is not None}, audio={output.audio is not None}"
)
# Return media via URL or base64
if response_format == "url":
# Upload via filesystem
storage_path = f"videos/{request_id}.mp4"
video_url = await upload_to_fs(
self.media_output_fs,
......@@ -236,7 +280,6 @@ class VideoGenerationHandler(BaseGenerativeHandler):
)
video_data = VideoData(url=video_url)
else:
# Encode to base64
b64_video = base64.b64encode(video_bytes).decode("utf-8")
video_data = VideoData(b64_json=b64_video)
......
......@@ -13,10 +13,12 @@ import asyncio
import threading
import time
from dataclasses import dataclass
from types import SimpleNamespace
from typing import Optional
from unittest.mock import MagicMock, patch
import pytest
import torch
from dynamo.common.protocols.video_protocol import (
NvCreateVideoRequest,
......@@ -104,8 +106,11 @@ class TestDiffusionConfig:
# Optimization defaults
assert config.enable_teacache is False
assert config.attn_type == "default"
assert config.linear_type == "default"
assert config.attn_backend == "VANILLA"
assert config.quant_algo is None
assert config.enable_cuda_graph is False
assert config.warmup_steps == 1
assert config.fuse_qkv is True
# Parallelism defaults
assert config.dit_dp_size == 1
......@@ -532,10 +537,12 @@ class ConcurrencyTracker:
with self._lock:
self._active_count -= 1
# Return fake frames (shape: [num_frames, H, W, C])
import numpy as np
return np.zeros((4, 64, 64, 3), dtype=np.uint8)
# Return a mock MediaOutput with a video tensor
return SimpleNamespace(
video=torch.zeros((4, 64, 64, 3), dtype=torch.uint8),
image=None,
audio=None,
)
class TestVideoHandlerConcurrency:
......@@ -660,16 +667,17 @@ class TestVideoHandlerResponseFormats:
def _make_handler(self):
"""Create a handler with mocked engine and fs."""
import numpy as np
from dynamo.trtllm.request_handlers.video_diffusion.video_handler import (
VideoGenerationHandler,
)
mock_engine = MagicMock()
mock_engine.generate = MagicMock(
return_value=np.zeros((4, 64, 64, 3), dtype=np.uint8)
mock_output = SimpleNamespace(
video=torch.zeros((4, 64, 64, 3), dtype=torch.uint8),
image=None,
audio=None,
)
mock_engine = MagicMock()
mock_engine.generate = MagicMock(return_value=mock_output)
config = DiffusionConfig(
media_output_fs_url="file:///tmp/test_media",
......
......@@ -33,20 +33,19 @@ async def init_video_diffusion_worker(
shutdown_event: Event to signal shutdown.
shutdown_endpoints: Optional list to populate with endpoints for graceful shutdown.
"""
# Check visual_gen availability early with a clear error message.
# visual_gen is part of TensorRT-LLM but only available on the feat/visual_gen
# branch — not yet in any release. Without this check, users would get a cryptic
# ImportError deep inside DiffusionEngine.initialize().
# Check tensorrt_llm visual_gen availability early with a clear error message.
# visual_gen is part of TensorRT-LLM (tensorrt_llm._torch.visual_gen).
# Without this check, users would get a cryptic ImportError deep inside
# DiffusionEngine.initialize().
try:
import visual_gen # noqa: F401
import tensorrt_llm._torch.visual_gen # noqa: F401
except ImportError:
raise ImportError(
"Video diffusion requires the 'visual_gen' package from TensorRT-LLM's "
"feat/visual_gen branch. Install with:\n"
" git clone https://github.com/NVIDIA/TensorRT-LLM.git\n"
" cd TensorRT-LLM && git checkout feat/visual_gen\n"
" cd tensorrt_llm/visual_gen && pip install -e .\n"
"See: https://github.com/NVIDIA/TensorRT-LLM/tree/feat/visual_gen/tensorrt_llm/visual_gen"
"Video diffusion requires TensorRT-LLM with visual_gen support.\n"
"The visual_gen module is at tensorrt_llm._torch.visual_gen.\n"
"Install TensorRT-LLM with AIGV support:\n"
" pip install tensorrt_llm\n"
"See: https://github.com/NVIDIA/TensorRT-LLM"
) from None
from dynamo.trtllm.configs.diffusion_config import DiffusionConfig
......@@ -55,6 +54,13 @@ async def init_video_diffusion_worker(
logging.info(f"Initializing video diffusion worker with config: {config}")
# Parse skip_components from comma-separated string to list
skip_components = (
[c.strip() for c in config.skip_components.split(",") if c.strip()]
if config.skip_components
else []
)
# Build DiffusionConfig from the main Config
diffusion_config = DiffusionConfig(
namespace=config.namespace,
......@@ -65,6 +71,8 @@ async def init_video_diffusion_worker(
event_plane=config.event_plane,
model_path=config.model,
served_model_name=config.served_model_name,
torch_dtype=config.torch_dtype,
revision=config.revision,
media_output_fs_url=config.media_output_fs_url,
media_output_http_url=config.media_output_http_url,
default_height=config.default_height,
......@@ -72,19 +80,34 @@ async def init_video_diffusion_worker(
default_num_frames=config.default_num_frames,
default_num_inference_steps=config.default_num_inference_steps,
default_guidance_scale=config.default_guidance_scale,
enable_teacache=config.enable_teacache,
teacache_thresh=config.teacache_thresh,
attn_type=config.attn_type,
linear_type=config.linear_type,
# Pipeline optimization
disable_torch_compile=config.disable_torch_compile,
torch_compile_mode=config.torch_compile_mode,
enable_fullgraph=config.enable_fullgraph,
fuse_qkv=config.fuse_qkv,
enable_cuda_graph=config.enable_cuda_graph,
enable_layerwise_nvtx_marker=config.enable_layerwise_nvtx_marker,
warmup_steps=config.warmup_steps,
# Attention
attn_backend=config.attn_backend,
# Quantization
quant_algo=config.quant_algo,
quant_dynamic=config.quant_dynamic,
# TeaCache
enable_teacache=config.enable_teacache,
teacache_use_ret_steps=config.teacache_use_ret_steps,
teacache_thresh=config.teacache_thresh,
# Parallelism
dit_dp_size=config.dit_dp_size,
dit_tp_size=config.dit_tp_size,
dit_ulysses_size=config.dit_ulysses_size,
dit_ring_size=config.dit_ring_size,
dit_cfg_size=config.dit_cfg_size,
dit_fsdp_size=config.dit_fsdp_size,
# Offloading
enable_async_cpu_offload=config.enable_async_cpu_offload,
# Component loading
skip_components=skip_components,
)
# Get the endpoint from the runtime
......
......@@ -216,11 +216,10 @@ Dynamo supports video generation using diffusion models through the `--modality
### Requirements
- **visual_gen**: Part of TensorRT-LLM, located at `tensorrt_llm/visual_gen/`. Currently available **only** on the [`feat/visual_gen`](https://github.com/NVIDIA/TensorRT-LLM/tree/feat/visual_gen/tensorrt_llm/visual_gen) branch (not yet merged to main or any release). Install from source:
- **TensorRT-LLM with visual_gen**: The `visual_gen` module is part of TensorRT-LLM (`tensorrt_llm._torch.visual_gen`). Install TensorRT-LLM following the [official instructions](https://github.com/NVIDIA/TensorRT-LLM#installation).
- **imageio with ffmpeg**: Required for encoding generated frames to MP4 video:
```bash
git clone https://github.com/NVIDIA/TensorRT-LLM.git
cd TensorRT-LLM && git checkout feat/visual_gen
cd tensorrt_llm/visual_gen && pip install -e .
pip install imageio[ffmpeg]
```
- **dynamo-runtime with video API**: The Dynamo runtime must include `ModelType.Videos` support. Ensure you're using a compatible version.
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment