Unverified Commit 8de1c3e0 authored by Indrajit Bhosale's avatar Indrajit Bhosale Committed by GitHub
Browse files

feat: Update TRTLLM video diffusion for `visual_gen` module restructuring (PR #11462) (#6516)


Signed-off-by: default avatarIndrajit Bhosale <iamindrajitb@gmail.com>
parent 78436fbf
......@@ -242,6 +242,22 @@ class DynamoTrtllmArgGroup(ArgGroup):
arg_type=float,
help="Default CFG guidance scale.",
)
add_argument(
diffusion_group,
flag_name="--torch-dtype",
env_var="DYN_TRTLLM_TORCH_DTYPE",
default="bfloat16",
choices=["bfloat16", "float16", "float32"],
help="Torch dtype for model loading. bfloat16 recommended for Ampere+ GPUs.",
)
add_argument(
diffusion_group,
flag_name="--revision",
env_var="DYN_TRTLLM_REVISION",
default=None,
help="HuggingFace Hub revision (branch, tag, or commit SHA) for model download.",
)
add_negatable_bool_argument(
diffusion_group,
flag_name="--enable-teacache",
......@@ -249,6 +265,13 @@ class DynamoTrtllmArgGroup(ArgGroup):
default=False,
help="Enable TeaCache optimization for faster generation.",
)
add_negatable_bool_argument(
diffusion_group,
flag_name="--teacache-use-ret-steps",
env_var="DYN_TRTLLM_TEACACHE_USE_RET_STEPS",
default=True,
help="Use retention steps for TeaCache.",
)
add_argument(
diffusion_group,
flag_name="--teacache-thresh",
......@@ -259,24 +282,33 @@ class DynamoTrtllmArgGroup(ArgGroup):
)
add_argument(
diffusion_group,
flag_name="--attn-type",
env_var="DYN_TRTLLM_ATTN_TYPE",
default="default",
choices=["default", "sage-attn", "sparse-videogen", "sparse-videogen2"],
help="Attention type for diffusion models.",
flag_name="--attn-backend",
env_var="DYN_TRTLLM_ATTN_BACKEND",
default="VANILLA",
choices=["VANILLA", "TRTLLM"],
help="Attention backend for diffusion models. VANILLA = PyTorch SDPA, TRTLLM = TensorRT-LLM kernels.",
)
add_argument(
diffusion_group,
flag_name="--linear-type",
env_var="DYN_TRTLLM_LINEAR_TYPE",
default="default",
flag_name="--quant-algo",
env_var="DYN_TRTLLM_QUANT_ALGO",
default=None,
choices=[
"default",
"trtllm-fp8-blockwise",
"trtllm-fp8-per-tensor",
"trtllm-nvfp4",
"FP8",
"FP8_BLOCK_SCALES",
"NVFP4",
"W4A16_AWQ",
"W4A8_AWQ",
"W8A8_SQ_PER_CHANNEL",
],
help="Linear type for quantization.",
help="Quantization algorithm for diffusion models. BF16 weights are quantized on-the-fly during loading.",
)
add_negatable_bool_argument(
diffusion_group,
flag_name="--quant-dynamic",
env_var="DYN_TRTLLM_QUANT_DYNAMIC",
default=True,
help="Enable dynamic weight quantization (quantize BF16 weights on-the-fly during loading).",
)
add_negatable_bool_argument(
diffusion_group,
......@@ -293,6 +325,42 @@ class DynamoTrtllmArgGroup(ArgGroup):
choices=["default", "reduce-overhead", "max-autotune"],
help="torch.compile mode.",
)
add_negatable_bool_argument(
diffusion_group,
flag_name="--enable-fullgraph",
env_var="DYN_TRTLLM_ENABLE_FULLGRAPH",
default=False,
help="Enable torch.compile fullgraph mode (stricter but potentially faster).",
)
add_negatable_bool_argument(
diffusion_group,
flag_name="--fuse-qkv",
env_var="DYN_TRTLLM_FUSE_QKV",
default=True,
help="Enable QKV fusion for transformer attention layers.",
)
add_negatable_bool_argument(
diffusion_group,
flag_name="--enable-cuda-graph",
env_var="DYN_TRTLLM_ENABLE_CUDA_GRAPH",
default=False,
help="Enable CUDA graph capture for transformer forward passes. Mutually exclusive with torch.compile.",
)
add_negatable_bool_argument(
diffusion_group,
flag_name="--enable-layerwise-nvtx-marker",
env_var="DYN_TRTLLM_ENABLE_LAYERWISE_NVTX_MARKER",
default=False,
help="Enable per-layer NVTX markers for profiling with Nsight Systems.",
)
add_argument(
diffusion_group,
flag_name="--warmup-steps",
env_var="DYN_TRTLLM_WARMUP_STEPS",
default=1,
arg_type=int,
help="Number of denoising steps to run during warmup (0 to disable).",
)
add_argument(
diffusion_group,
flag_name="--dit-dp-size",
......@@ -348,6 +416,17 @@ class DynamoTrtllmArgGroup(ArgGroup):
default=False,
help="Enable async CPU offload for memory efficiency.",
)
add_argument(
diffusion_group,
flag_name="--skip-components",
env_var="DYN_TRTLLM_SKIP_COMPONENTS",
default="",
help=(
"Comma-separated list of pipeline components to skip loading. "
"Valid values: transformer, vae, text_encoder, tokenizer, scheduler, "
"image_encoder, image_processor."
),
)
class DynamoTrtllmConfig(ConfigBase):
......@@ -383,12 +462,21 @@ class DynamoTrtllmConfig(ConfigBase):
default_num_frames: int
default_num_inference_steps: int
default_guidance_scale: float
torch_dtype: str
revision: Optional[str] = None
enable_teacache: bool
teacache_use_ret_steps: bool
teacache_thresh: float
attn_type: str
linear_type: str
attn_backend: str
quant_algo: Optional[str]
quant_dynamic: bool
disable_torch_compile: bool
torch_compile_mode: str
enable_fullgraph: bool
fuse_qkv: bool
enable_cuda_graph: bool
enable_layerwise_nvtx_marker: bool
warmup_steps: int
dit_dp_size: int
dit_tp_size: int
dit_ulysses_size: int
......@@ -396,6 +484,7 @@ class DynamoTrtllmConfig(ConfigBase):
dit_cfg_size: int
dit_fsdp_size: int
enable_async_cpu_offload: bool
skip_components: str
def validate(self) -> None:
if isinstance(self.disaggregation_mode, str):
......
......@@ -5,9 +5,16 @@
This module defines the DiffusionConfig dataclass used for configuring
video and image diffusion workers.
Fields map to TensorRT-LLM's DiffusionArgs sub-configs:
- PipelineConfig: torch_compile, CUDA graph, warmup, offloading, fuse_qkv
- AttentionConfig: attention backend (VANILLA, TRTLLM)
- ParallelConfig: dit_*_size parallelism dimensions
- TeaCacheConfig: caching optimization
- QuantConfig: quantization algorithm and dynamic flags
"""
from dataclasses import dataclass
from dataclasses import dataclass, field
from typing import Optional
from dynamo.common.utils.namespace import get_worker_namespace
......@@ -23,7 +30,7 @@ class DiffusionConfig:
"""Configuration for diffusion model workers (video/image generation).
This configuration is used by DiffusionEngine and diffusion handlers.
It can be populated from command-line arguments in trtllm_utils.py.
It can be populated from command-line arguments in backend_args.py.
"""
# Dynamo runtime config
......@@ -41,6 +48,8 @@ class DiffusionConfig:
# bfloat16 is recommended for Ampere+ GPUs (A100, H100, etc.)
# float16 can be used on older GPUs (V100, etc.)
torch_dtype: str = "bfloat16"
# HuggingFace Hub revision (branch, tag, or commit SHA) for model download.
revision: Optional[str] = None
# Media storage
media_output_fs_url: str = "file:///tmp/dynamo_media"
......@@ -58,16 +67,39 @@ class DiffusionConfig:
default_num_inference_steps: int = 50
default_guidance_scale: float = 5.0
# visual_gen optimization config
# ── Pipeline optimization config (maps to PipelineConfig) ──
disable_torch_compile: bool = False
torch_compile_mode: str = "default"
# Enable torch.compile fullgraph mode (stricter but potentially faster)
enable_fullgraph: bool = False
# QKV fusion for transformer attention layers
fuse_qkv: bool = True
# CUDA graph capture for transformer forward passes
# (mutually exclusive with torch.compile — torch.compile takes priority)
enable_cuda_graph: bool = False
# Enable per-layer NVTX markers for profiling
enable_layerwise_nvtx_marker: bool = False
# Number of denoising steps to run during warmup (0 to disable)
warmup_steps: int = 1
# ── Attention config (maps to AttentionConfig) ──
# Attention backend: "VANILLA" (PyTorch SDPA) or "TRTLLM"
attn_backend: str = "VANILLA"
# ── Quantization config (maps to DiffusionArgs.quant_config) ──
# Quantization algorithm. Options:
# None (no quantization), "FP8", "FP8_BLOCK_SCALES", "NVFP4",
# "W4A16_AWQ", "W4A8_AWQ", "W8A8_SQ_PER_CHANNEL"
quant_algo: Optional[str] = None
# Enable dynamic weight quantization (quantize BF16 weights on-the-fly during loading)
quant_dynamic: bool = True
# ── TeaCache optimization config (maps to TeaCacheConfig) ──
enable_teacache: bool = False
teacache_use_ret_steps: bool = True
teacache_thresh: float = 0.2
attn_type: str = "default"
linear_type: str = "default"
disable_torch_compile: bool = False
torch_compile_mode: str = "default"
# Parallelism config (DiTParallelConfig)
# ── Parallelism config (maps to ParallelConfig) ──
dit_dp_size: int = 1
dit_tp_size: int = 1
dit_ulysses_size: int = 1
......@@ -75,9 +107,14 @@ class DiffusionConfig:
dit_cfg_size: int = 1
dit_fsdp_size: int = 1
# CPU offload config
# ── Offloading config (maps to PipelineConfig) ──
enable_async_cpu_offload: bool = False
visual_gen_block_cpu_offload_stride: int = 1
# ── Component loading options ──
# Components to skip loading (e.g., ["text_encoder", "vae"]).
# Valid values: "transformer", "vae", "text_encoder", "tokenizer",
# "scheduler", "image_encoder", "image_processor"
skip_components: list[str] = field(default_factory=list)
def __str__(self) -> str:
return (
......@@ -93,8 +130,10 @@ class DiffusionConfig:
f"default_num_frames={self.default_num_frames}, "
f"default_num_inference_steps={self.default_num_inference_steps}, "
f"enable_teacache={self.enable_teacache}, "
f"attn_type={self.attn_type}, "
f"linear_type={self.linear_type}, "
f"attn_backend={self.attn_backend}, "
f"quant_algo={self.quant_algo}, "
f"enable_cuda_graph={self.enable_cuda_graph}, "
f"warmup_steps={self.warmup_steps}, "
f"dit_dp_size={self.dit_dp_size}, "
f"dit_tp_size={self.dit_tp_size})"
)
......@@ -4,7 +4,7 @@
"""Engine modules for TensorRT-LLM backend.
This module provides engine wrappers for various generative models:
- DiffusionEngine: Generic wrapper for visual_gen diffusion pipelines
- DiffusionEngine: Generic wrapper for TensorRT-LLM visual_gen diffusion pipelines
"""
from dynamo.trtllm.engines.diffusion_engine import DiffusionEngine
......
......@@ -4,6 +4,8 @@
"""Video generation request handler for TensorRT-LLM backend.
This handler processes video generation requests using diffusion models.
It handles MediaOutput from TensorRT-LLM's visual_gen pipelines, which
can contain video, image, and/or audio tensors depending on the model.
"""
import asyncio
......@@ -32,9 +34,14 @@ logger = logging.getLogger(__name__)
class VideoGenerationHandler(BaseGenerativeHandler):
"""Handler for video generation requests.
This handler receives video generation requests, runs the diffusion
pipeline via DiffusionEngine, encodes the output to MP4, and returns
the video URL or base64-encoded data.
This handler receives generation requests, runs the diffusion pipeline
via DiffusionEngine, encodes the output to the appropriate media format,
and returns the media URL or base64-encoded data.
Supports MediaOutput with:
- video: torch.Tensor (num_frames, H, W, 3) uint8 → encoded as MP4
- image: logged as unsupported (use an image handler instead)
- audio: logged (future: mux into MP4)
Inherits from BaseGenerativeHandler to share the common interface with
LLM handlers.
......@@ -59,8 +66,8 @@ class VideoGenerationHandler(BaseGenerativeHandler):
)
self.media_output_fs = get_fs(config.media_output_fs_url)
self.media_output_http_url = config.media_output_http_url
# Serialize pipeline access — visual_gen is not thread-safe (global
# singleton configs, mutable instance state, unprotected CUDA graph cache).
# Serialize pipeline access — the diffusion pipeline is not thread-safe
# (mutable instance state, unprotected CUDA graph cache).
# asyncio.Lock suspends waiting coroutines cooperatively so the event
# loop stays free for health checks and signal handling.
self._generate_lock = asyncio.Lock()
......@@ -159,21 +166,26 @@ class VideoGenerationHandler(BaseGenerativeHandler):
async def generate(
self, request: dict[str, Any], context: Context
) -> AsyncGenerator[dict[str, Any], None]:
"""Generate video from request.
"""Generate video/image from request.
This is the main entry point called by Dynamo's endpoint.serve_endpoint().
Handles MediaOutput from the pipeline:
- video tensor → MP4
- image tensor → unsupported (raises error)
- audio tensor → unsupported (raises error)
Args:
request: Request dictionary with video generation parameters.
request: Request dictionary with generation parameters.
context: Dynamo context for request tracking.
Yields:
Response dictionary with generated video data.
Response dictionary with generated media data.
"""
start_time = time.time()
request_id = str(uuid.uuid4())
logger.info(f"Received video generation request: {request_id}")
logger.info(f"Received generation request: {request_id}")
try:
# Parse request
......@@ -202,11 +214,11 @@ class VideoGenerationHandler(BaseGenerativeHandler):
# Run generation in thread pool (blocking operation).
# Lock ensures only one request uses the pipeline at a time.
# TODO: Add cancellation support. This requires:
# 1. visual_gen to expose a cancellation hook in the denoising loop
# 1. The pipeline to expose a cancellation hook in the denoising loop
# 2. Passing a cancellation token/event to engine.generate()
# 3. Checking context.cancelled() and propagating to the pipeline
async with self._generate_lock:
frames = await asyncio.to_thread(
output = await asyncio.to_thread(
self.engine.generate,
prompt=req.prompt,
negative_prompt=nvext.negative_prompt,
......@@ -218,15 +230,47 @@ class VideoGenerationHandler(BaseGenerativeHandler):
seed=nvext.seed,
)
if output is None:
raise RuntimeError("Pipeline returned no output (MediaOutput is None)")
# Determine output format
response_format = req.response_format or "url"
fps = nvext.fps or self.config.default_fps
# Encode frames to MP4 bytes in memory
video_bytes = await asyncio.to_thread(encode_to_mp4_bytes, frames, fps=fps)
# Encode media based on what the pipeline returned
if output.video is not None:
# Video output: torch.Tensor (num_frames, H, W, 3) uint8 → MP4
frames_np = output.video.cpu().numpy()
logger.info(
f"Request {request_id}: encoding video output "
f"(shape={frames_np.shape}) to MP4 at {fps} fps"
)
video_bytes = await asyncio.to_thread(
encode_to_mp4_bytes, frames_np, fps=fps
)
elif output.image is not None:
raise RuntimeError(
"Pipeline returned image-only output, but this handler "
"only supports video. Use an image generation handler instead."
)
# Log audio if present (unsupported)
elif output.audio is not None:
raise RuntimeError(
"Pipeline returned audio-only output, but this handler "
"only supports video. Use an audio generation handler instead."
)
else:
raise RuntimeError(
"Pipeline returned MediaOutput with no video or image or audio data. "
f"MediaOutput fields: video={output.video is not None}, "
f"image={output.image is not None}, audio={output.audio is not None}"
)
# Return media via URL or base64
if response_format == "url":
# Upload via filesystem
storage_path = f"videos/{request_id}.mp4"
video_url = await upload_to_fs(
self.media_output_fs,
......@@ -236,7 +280,6 @@ class VideoGenerationHandler(BaseGenerativeHandler):
)
video_data = VideoData(url=video_url)
else:
# Encode to base64
b64_video = base64.b64encode(video_bytes).decode("utf-8")
video_data = VideoData(b64_json=b64_video)
......
......@@ -13,10 +13,12 @@ import asyncio
import threading
import time
from dataclasses import dataclass
from types import SimpleNamespace
from typing import Optional
from unittest.mock import MagicMock, patch
import pytest
import torch
from dynamo.common.protocols.video_protocol import (
NvCreateVideoRequest,
......@@ -104,8 +106,11 @@ class TestDiffusionConfig:
# Optimization defaults
assert config.enable_teacache is False
assert config.attn_type == "default"
assert config.linear_type == "default"
assert config.attn_backend == "VANILLA"
assert config.quant_algo is None
assert config.enable_cuda_graph is False
assert config.warmup_steps == 1
assert config.fuse_qkv is True
# Parallelism defaults
assert config.dit_dp_size == 1
......@@ -532,10 +537,12 @@ class ConcurrencyTracker:
with self._lock:
self._active_count -= 1
# Return fake frames (shape: [num_frames, H, W, C])
import numpy as np
return np.zeros((4, 64, 64, 3), dtype=np.uint8)
# Return a mock MediaOutput with a video tensor
return SimpleNamespace(
video=torch.zeros((4, 64, 64, 3), dtype=torch.uint8),
image=None,
audio=None,
)
class TestVideoHandlerConcurrency:
......@@ -660,16 +667,17 @@ class TestVideoHandlerResponseFormats:
def _make_handler(self):
"""Create a handler with mocked engine and fs."""
import numpy as np
from dynamo.trtllm.request_handlers.video_diffusion.video_handler import (
VideoGenerationHandler,
)
mock_engine = MagicMock()
mock_engine.generate = MagicMock(
return_value=np.zeros((4, 64, 64, 3), dtype=np.uint8)
mock_output = SimpleNamespace(
video=torch.zeros((4, 64, 64, 3), dtype=torch.uint8),
image=None,
audio=None,
)
mock_engine = MagicMock()
mock_engine.generate = MagicMock(return_value=mock_output)
config = DiffusionConfig(
media_output_fs_url="file:///tmp/test_media",
......
......@@ -33,20 +33,19 @@ async def init_video_diffusion_worker(
shutdown_event: Event to signal shutdown.
shutdown_endpoints: Optional list to populate with endpoints for graceful shutdown.
"""
# Check visual_gen availability early with a clear error message.
# visual_gen is part of TensorRT-LLM but only available on the feat/visual_gen
# branch — not yet in any release. Without this check, users would get a cryptic
# ImportError deep inside DiffusionEngine.initialize().
# Check tensorrt_llm visual_gen availability early with a clear error message.
# visual_gen is part of TensorRT-LLM (tensorrt_llm._torch.visual_gen).
# Without this check, users would get a cryptic ImportError deep inside
# DiffusionEngine.initialize().
try:
import visual_gen # noqa: F401
import tensorrt_llm._torch.visual_gen # noqa: F401
except ImportError:
raise ImportError(
"Video diffusion requires the 'visual_gen' package from TensorRT-LLM's "
"feat/visual_gen branch. Install with:\n"
" git clone https://github.com/NVIDIA/TensorRT-LLM.git\n"
" cd TensorRT-LLM && git checkout feat/visual_gen\n"
" cd tensorrt_llm/visual_gen && pip install -e .\n"
"See: https://github.com/NVIDIA/TensorRT-LLM/tree/feat/visual_gen/tensorrt_llm/visual_gen"
"Video diffusion requires TensorRT-LLM with visual_gen support.\n"
"The visual_gen module is at tensorrt_llm._torch.visual_gen.\n"
"Install TensorRT-LLM with AIGV support:\n"
" pip install tensorrt_llm\n"
"See: https://github.com/NVIDIA/TensorRT-LLM"
) from None
from dynamo.trtllm.configs.diffusion_config import DiffusionConfig
......@@ -55,6 +54,13 @@ async def init_video_diffusion_worker(
logging.info(f"Initializing video diffusion worker with config: {config}")
# Parse skip_components from comma-separated string to list
skip_components = (
[c.strip() for c in config.skip_components.split(",") if c.strip()]
if config.skip_components
else []
)
# Build DiffusionConfig from the main Config
diffusion_config = DiffusionConfig(
namespace=config.namespace,
......@@ -65,6 +71,8 @@ async def init_video_diffusion_worker(
event_plane=config.event_plane,
model_path=config.model,
served_model_name=config.served_model_name,
torch_dtype=config.torch_dtype,
revision=config.revision,
media_output_fs_url=config.media_output_fs_url,
media_output_http_url=config.media_output_http_url,
default_height=config.default_height,
......@@ -72,19 +80,34 @@ async def init_video_diffusion_worker(
default_num_frames=config.default_num_frames,
default_num_inference_steps=config.default_num_inference_steps,
default_guidance_scale=config.default_guidance_scale,
enable_teacache=config.enable_teacache,
teacache_thresh=config.teacache_thresh,
attn_type=config.attn_type,
linear_type=config.linear_type,
# Pipeline optimization
disable_torch_compile=config.disable_torch_compile,
torch_compile_mode=config.torch_compile_mode,
enable_fullgraph=config.enable_fullgraph,
fuse_qkv=config.fuse_qkv,
enable_cuda_graph=config.enable_cuda_graph,
enable_layerwise_nvtx_marker=config.enable_layerwise_nvtx_marker,
warmup_steps=config.warmup_steps,
# Attention
attn_backend=config.attn_backend,
# Quantization
quant_algo=config.quant_algo,
quant_dynamic=config.quant_dynamic,
# TeaCache
enable_teacache=config.enable_teacache,
teacache_use_ret_steps=config.teacache_use_ret_steps,
teacache_thresh=config.teacache_thresh,
# Parallelism
dit_dp_size=config.dit_dp_size,
dit_tp_size=config.dit_tp_size,
dit_ulysses_size=config.dit_ulysses_size,
dit_ring_size=config.dit_ring_size,
dit_cfg_size=config.dit_cfg_size,
dit_fsdp_size=config.dit_fsdp_size,
# Offloading
enable_async_cpu_offload=config.enable_async_cpu_offload,
# Component loading
skip_components=skip_components,
)
# Get the endpoint from the runtime
......
......@@ -216,11 +216,10 @@ Dynamo supports video generation using diffusion models through the `--modality
### Requirements
- **visual_gen**: Part of TensorRT-LLM, located at `tensorrt_llm/visual_gen/`. Currently available **only** on the [`feat/visual_gen`](https://github.com/NVIDIA/TensorRT-LLM/tree/feat/visual_gen/tensorrt_llm/visual_gen) branch (not yet merged to main or any release). Install from source:
- **TensorRT-LLM with visual_gen**: The `visual_gen` module is part of TensorRT-LLM (`tensorrt_llm._torch.visual_gen`). Install TensorRT-LLM following the [official instructions](https://github.com/NVIDIA/TensorRT-LLM#installation).
- **imageio with ffmpeg**: Required for encoding generated frames to MP4 video:
```bash
git clone https://github.com/NVIDIA/TensorRT-LLM.git
cd TensorRT-LLM && git checkout feat/visual_gen
cd tensorrt_llm/visual_gen && pip install -e .
pip install imageio[ffmpeg]
```
- **dynamo-runtime with video API**: The Dynamo runtime must include `ModelType.Videos` support. Ensure you're using a compatible version.
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment