Unverified Commit 8de1c3e0 authored by Indrajit Bhosale's avatar Indrajit Bhosale Committed by GitHub
Browse files

feat: Update TRTLLM video diffusion for `visual_gen` module restructuring (PR #11462) (#6516)


Signed-off-by: Indrajit Bhosale <iamindrajitb@gmail.com>
parent 78436fbf
...@@ -242,6 +242,22 @@ class DynamoTrtllmArgGroup(ArgGroup): ...@@ -242,6 +242,22 @@ class DynamoTrtllmArgGroup(ArgGroup):
arg_type=float, arg_type=float,
help="Default CFG guidance scale.", help="Default CFG guidance scale.",
) )
add_argument(
diffusion_group,
flag_name="--torch-dtype",
env_var="DYN_TRTLLM_TORCH_DTYPE",
default="bfloat16",
choices=["bfloat16", "float16", "float32"],
help="Torch dtype for model loading. bfloat16 recommended for Ampere+ GPUs.",
)
add_argument(
diffusion_group,
flag_name="--revision",
env_var="DYN_TRTLLM_REVISION",
default=None,
help="HuggingFace Hub revision (branch, tag, or commit SHA) for model download.",
)
add_negatable_bool_argument( add_negatable_bool_argument(
diffusion_group, diffusion_group,
flag_name="--enable-teacache", flag_name="--enable-teacache",
...@@ -249,6 +265,13 @@ class DynamoTrtllmArgGroup(ArgGroup): ...@@ -249,6 +265,13 @@ class DynamoTrtllmArgGroup(ArgGroup):
default=False, default=False,
help="Enable TeaCache optimization for faster generation.", help="Enable TeaCache optimization for faster generation.",
) )
add_negatable_bool_argument(
diffusion_group,
flag_name="--teacache-use-ret-steps",
env_var="DYN_TRTLLM_TEACACHE_USE_RET_STEPS",
default=True,
help="Use retention steps for TeaCache.",
)
add_argument( add_argument(
diffusion_group, diffusion_group,
flag_name="--teacache-thresh", flag_name="--teacache-thresh",
...@@ -259,24 +282,33 @@ class DynamoTrtllmArgGroup(ArgGroup): ...@@ -259,24 +282,33 @@ class DynamoTrtllmArgGroup(ArgGroup):
) )
add_argument( add_argument(
diffusion_group, diffusion_group,
flag_name="--attn-type", flag_name="--attn-backend",
env_var="DYN_TRTLLM_ATTN_TYPE", env_var="DYN_TRTLLM_ATTN_BACKEND",
default="default", default="VANILLA",
choices=["default", "sage-attn", "sparse-videogen", "sparse-videogen2"], choices=["VANILLA", "TRTLLM"],
help="Attention type for diffusion models.", help="Attention backend for diffusion models. VANILLA = PyTorch SDPA, TRTLLM = TensorRT-LLM kernels.",
) )
add_argument( add_argument(
diffusion_group, diffusion_group,
flag_name="--linear-type", flag_name="--quant-algo",
env_var="DYN_TRTLLM_LINEAR_TYPE", env_var="DYN_TRTLLM_QUANT_ALGO",
default="default", default=None,
choices=[ choices=[
"default", "FP8",
"trtllm-fp8-blockwise", "FP8_BLOCK_SCALES",
"trtllm-fp8-per-tensor", "NVFP4",
"trtllm-nvfp4", "W4A16_AWQ",
"W4A8_AWQ",
"W8A8_SQ_PER_CHANNEL",
], ],
help="Linear type for quantization.", help="Quantization algorithm for diffusion models. BF16 weights are quantized on-the-fly during loading.",
)
add_negatable_bool_argument(
diffusion_group,
flag_name="--quant-dynamic",
env_var="DYN_TRTLLM_QUANT_DYNAMIC",
default=True,
help="Enable dynamic weight quantization (quantize BF16 weights on-the-fly during loading).",
) )
add_negatable_bool_argument( add_negatable_bool_argument(
diffusion_group, diffusion_group,
...@@ -293,6 +325,42 @@ class DynamoTrtllmArgGroup(ArgGroup): ...@@ -293,6 +325,42 @@ class DynamoTrtllmArgGroup(ArgGroup):
choices=["default", "reduce-overhead", "max-autotune"], choices=["default", "reduce-overhead", "max-autotune"],
help="torch.compile mode.", help="torch.compile mode.",
) )
add_negatable_bool_argument(
diffusion_group,
flag_name="--enable-fullgraph",
env_var="DYN_TRTLLM_ENABLE_FULLGRAPH",
default=False,
help="Enable torch.compile fullgraph mode (stricter but potentially faster).",
)
add_negatable_bool_argument(
diffusion_group,
flag_name="--fuse-qkv",
env_var="DYN_TRTLLM_FUSE_QKV",
default=True,
help="Enable QKV fusion for transformer attention layers.",
)
add_negatable_bool_argument(
diffusion_group,
flag_name="--enable-cuda-graph",
env_var="DYN_TRTLLM_ENABLE_CUDA_GRAPH",
default=False,
help="Enable CUDA graph capture for transformer forward passes. Mutually exclusive with torch.compile.",
)
add_negatable_bool_argument(
diffusion_group,
flag_name="--enable-layerwise-nvtx-marker",
env_var="DYN_TRTLLM_ENABLE_LAYERWISE_NVTX_MARKER",
default=False,
help="Enable per-layer NVTX markers for profiling with Nsight Systems.",
)
add_argument(
diffusion_group,
flag_name="--warmup-steps",
env_var="DYN_TRTLLM_WARMUP_STEPS",
default=1,
arg_type=int,
help="Number of denoising steps to run during warmup (0 to disable).",
)
add_argument( add_argument(
diffusion_group, diffusion_group,
flag_name="--dit-dp-size", flag_name="--dit-dp-size",
...@@ -348,6 +416,17 @@ class DynamoTrtllmArgGroup(ArgGroup): ...@@ -348,6 +416,17 @@ class DynamoTrtllmArgGroup(ArgGroup):
default=False, default=False,
help="Enable async CPU offload for memory efficiency.", help="Enable async CPU offload for memory efficiency.",
) )
add_argument(
diffusion_group,
flag_name="--skip-components",
env_var="DYN_TRTLLM_SKIP_COMPONENTS",
default="",
help=(
"Comma-separated list of pipeline components to skip loading. "
"Valid values: transformer, vae, text_encoder, tokenizer, scheduler, "
"image_encoder, image_processor."
),
)
class DynamoTrtllmConfig(ConfigBase): class DynamoTrtllmConfig(ConfigBase):
...@@ -383,12 +462,21 @@ class DynamoTrtllmConfig(ConfigBase): ...@@ -383,12 +462,21 @@ class DynamoTrtllmConfig(ConfigBase):
default_num_frames: int default_num_frames: int
default_num_inference_steps: int default_num_inference_steps: int
default_guidance_scale: float default_guidance_scale: float
torch_dtype: str
revision: Optional[str] = None
enable_teacache: bool enable_teacache: bool
teacache_use_ret_steps: bool
teacache_thresh: float teacache_thresh: float
attn_type: str attn_backend: str
linear_type: str quant_algo: Optional[str]
quant_dynamic: bool
disable_torch_compile: bool disable_torch_compile: bool
torch_compile_mode: str torch_compile_mode: str
enable_fullgraph: bool
fuse_qkv: bool
enable_cuda_graph: bool
enable_layerwise_nvtx_marker: bool
warmup_steps: int
dit_dp_size: int dit_dp_size: int
dit_tp_size: int dit_tp_size: int
dit_ulysses_size: int dit_ulysses_size: int
...@@ -396,6 +484,7 @@ class DynamoTrtllmConfig(ConfigBase): ...@@ -396,6 +484,7 @@ class DynamoTrtllmConfig(ConfigBase):
dit_cfg_size: int dit_cfg_size: int
dit_fsdp_size: int dit_fsdp_size: int
enable_async_cpu_offload: bool enable_async_cpu_offload: bool
skip_components: str
def validate(self) -> None: def validate(self) -> None:
if isinstance(self.disaggregation_mode, str): if isinstance(self.disaggregation_mode, str):
......
...@@ -5,9 +5,16 @@ ...@@ -5,9 +5,16 @@
This module defines the DiffusionConfig dataclass used for configuring This module defines the DiffusionConfig dataclass used for configuring
video and image diffusion workers. video and image diffusion workers.
Fields map to TensorRT-LLM's DiffusionArgs sub-configs:
- PipelineConfig: torch_compile, CUDA graph, warmup, offloading, fuse_qkv
- AttentionConfig: attention backend (VANILLA, TRTLLM)
- ParallelConfig: dit_*_size parallelism dimensions
- TeaCacheConfig: caching optimization
- QuantConfig: quantization algorithm and dynamic flags
""" """
from dataclasses import dataclass from dataclasses import dataclass, field
from typing import Optional from typing import Optional
from dynamo.common.utils.namespace import get_worker_namespace from dynamo.common.utils.namespace import get_worker_namespace
...@@ -23,7 +30,7 @@ class DiffusionConfig: ...@@ -23,7 +30,7 @@ class DiffusionConfig:
"""Configuration for diffusion model workers (video/image generation). """Configuration for diffusion model workers (video/image generation).
This configuration is used by DiffusionEngine and diffusion handlers. This configuration is used by DiffusionEngine and diffusion handlers.
It can be populated from command-line arguments in trtllm_utils.py. It can be populated from command-line arguments in backend_args.py.
""" """
# Dynamo runtime config # Dynamo runtime config
...@@ -41,6 +48,8 @@ class DiffusionConfig: ...@@ -41,6 +48,8 @@ class DiffusionConfig:
# bfloat16 is recommended for Ampere+ GPUs (A100, H100, etc.) # bfloat16 is recommended for Ampere+ GPUs (A100, H100, etc.)
# float16 can be used on older GPUs (V100, etc.) # float16 can be used on older GPUs (V100, etc.)
torch_dtype: str = "bfloat16" torch_dtype: str = "bfloat16"
# HuggingFace Hub revision (branch, tag, or commit SHA) for model download.
revision: Optional[str] = None
# Media storage # Media storage
media_output_fs_url: str = "file:///tmp/dynamo_media" media_output_fs_url: str = "file:///tmp/dynamo_media"
...@@ -58,16 +67,39 @@ class DiffusionConfig: ...@@ -58,16 +67,39 @@ class DiffusionConfig:
default_num_inference_steps: int = 50 default_num_inference_steps: int = 50
default_guidance_scale: float = 5.0 default_guidance_scale: float = 5.0
# visual_gen optimization config # ── Pipeline optimization config (maps to PipelineConfig) ──
disable_torch_compile: bool = False
torch_compile_mode: str = "default"
# Enable torch.compile fullgraph mode (stricter but potentially faster)
enable_fullgraph: bool = False
# QKV fusion for transformer attention layers
fuse_qkv: bool = True
# CUDA graph capture for transformer forward passes
# (mutually exclusive with torch.compile — torch.compile takes priority)
enable_cuda_graph: bool = False
# Enable per-layer NVTX markers for profiling
enable_layerwise_nvtx_marker: bool = False
# Number of denoising steps to run during warmup (0 to disable)
warmup_steps: int = 1
# ── Attention config (maps to AttentionConfig) ──
# Attention backend: "VANILLA" (PyTorch SDPA) or "TRTLLM"
attn_backend: str = "VANILLA"
# ── Quantization config (maps to DiffusionArgs.quant_config) ──
# Quantization algorithm. Options:
# None (no quantization), "FP8", "FP8_BLOCK_SCALES", "NVFP4",
# "W4A16_AWQ", "W4A8_AWQ", "W8A8_SQ_PER_CHANNEL"
quant_algo: Optional[str] = None
# Enable dynamic weight quantization (quantize BF16 weights on-the-fly during loading)
quant_dynamic: bool = True
# ── TeaCache optimization config (maps to TeaCacheConfig) ──
enable_teacache: bool = False enable_teacache: bool = False
teacache_use_ret_steps: bool = True teacache_use_ret_steps: bool = True
teacache_thresh: float = 0.2 teacache_thresh: float = 0.2
attn_type: str = "default"
linear_type: str = "default"
disable_torch_compile: bool = False
torch_compile_mode: str = "default"
# Parallelism config (DiTParallelConfig) # ── Parallelism config (maps to ParallelConfig) ──
dit_dp_size: int = 1 dit_dp_size: int = 1
dit_tp_size: int = 1 dit_tp_size: int = 1
dit_ulysses_size: int = 1 dit_ulysses_size: int = 1
...@@ -75,9 +107,14 @@ class DiffusionConfig: ...@@ -75,9 +107,14 @@ class DiffusionConfig:
dit_cfg_size: int = 1 dit_cfg_size: int = 1
dit_fsdp_size: int = 1 dit_fsdp_size: int = 1
# CPU offload config # ── Offloading config (maps to PipelineConfig) ──
enable_async_cpu_offload: bool = False enable_async_cpu_offload: bool = False
visual_gen_block_cpu_offload_stride: int = 1
# ── Component loading options ──
# Components to skip loading (e.g., ["text_encoder", "vae"]).
# Valid values: "transformer", "vae", "text_encoder", "tokenizer",
# "scheduler", "image_encoder", "image_processor"
skip_components: list[str] = field(default_factory=list)
def __str__(self) -> str: def __str__(self) -> str:
return ( return (
...@@ -93,8 +130,10 @@ class DiffusionConfig: ...@@ -93,8 +130,10 @@ class DiffusionConfig:
f"default_num_frames={self.default_num_frames}, " f"default_num_frames={self.default_num_frames}, "
f"default_num_inference_steps={self.default_num_inference_steps}, " f"default_num_inference_steps={self.default_num_inference_steps}, "
f"enable_teacache={self.enable_teacache}, " f"enable_teacache={self.enable_teacache}, "
f"attn_type={self.attn_type}, " f"attn_backend={self.attn_backend}, "
f"linear_type={self.linear_type}, " f"quant_algo={self.quant_algo}, "
f"enable_cuda_graph={self.enable_cuda_graph}, "
f"warmup_steps={self.warmup_steps}, "
f"dit_dp_size={self.dit_dp_size}, " f"dit_dp_size={self.dit_dp_size}, "
f"dit_tp_size={self.dit_tp_size})" f"dit_tp_size={self.dit_tp_size})"
) )
...@@ -4,7 +4,7 @@ ...@@ -4,7 +4,7 @@
"""Engine modules for TensorRT-LLM backend. """Engine modules for TensorRT-LLM backend.
This module provides engine wrappers for various generative models: This module provides engine wrappers for various generative models:
- DiffusionEngine: Generic wrapper for visual_gen diffusion pipelines - DiffusionEngine: Generic wrapper for TensorRT-LLM visual_gen diffusion pipelines
""" """
from dynamo.trtllm.engines.diffusion_engine import DiffusionEngine from dynamo.trtllm.engines.diffusion_engine import DiffusionEngine
......
...@@ -4,6 +4,8 @@ ...@@ -4,6 +4,8 @@
"""Video generation request handler for TensorRT-LLM backend. """Video generation request handler for TensorRT-LLM backend.
This handler processes video generation requests using diffusion models. This handler processes video generation requests using diffusion models.
It handles MediaOutput from TensorRT-LLM's visual_gen pipelines, which
can contain video, image, and/or audio tensors depending on the model.
""" """
import asyncio import asyncio
...@@ -32,9 +34,14 @@ logger = logging.getLogger(__name__) ...@@ -32,9 +34,14 @@ logger = logging.getLogger(__name__)
class VideoGenerationHandler(BaseGenerativeHandler): class VideoGenerationHandler(BaseGenerativeHandler):
"""Handler for video generation requests. """Handler for video generation requests.
This handler receives video generation requests, runs the diffusion This handler receives generation requests, runs the diffusion pipeline
pipeline via DiffusionEngine, encodes the output to MP4, and returns via DiffusionEngine, encodes the output to the appropriate media format,
the video URL or base64-encoded data. and returns the media URL or base64-encoded data.
Supports MediaOutput with:
- video: torch.Tensor (num_frames, H, W, 3) uint8 → encoded as MP4
- image: unsupported — raises RuntimeError (use an image handler instead)
- audio: unsupported — raises RuntimeError (future: mux into MP4)
Inherits from BaseGenerativeHandler to share the common interface with Inherits from BaseGenerativeHandler to share the common interface with
LLM handlers. LLM handlers.
...@@ -59,8 +66,8 @@ class VideoGenerationHandler(BaseGenerativeHandler): ...@@ -59,8 +66,8 @@ class VideoGenerationHandler(BaseGenerativeHandler):
) )
self.media_output_fs = get_fs(config.media_output_fs_url) self.media_output_fs = get_fs(config.media_output_fs_url)
self.media_output_http_url = config.media_output_http_url self.media_output_http_url = config.media_output_http_url
# Serialize pipeline access — visual_gen is not thread-safe (global # Serialize pipeline access — the diffusion pipeline is not thread-safe
# singleton configs, mutable instance state, unprotected CUDA graph cache). # (mutable instance state, unprotected CUDA graph cache).
# asyncio.Lock suspends waiting coroutines cooperatively so the event # asyncio.Lock suspends waiting coroutines cooperatively so the event
# loop stays free for health checks and signal handling. # loop stays free for health checks and signal handling.
self._generate_lock = asyncio.Lock() self._generate_lock = asyncio.Lock()
...@@ -159,21 +166,26 @@ class VideoGenerationHandler(BaseGenerativeHandler): ...@@ -159,21 +166,26 @@ class VideoGenerationHandler(BaseGenerativeHandler):
async def generate( async def generate(
self, request: dict[str, Any], context: Context self, request: dict[str, Any], context: Context
) -> AsyncGenerator[dict[str, Any], None]: ) -> AsyncGenerator[dict[str, Any], None]:
"""Generate video from request. """Generate video/image from request.
This is the main entry point called by Dynamo's endpoint.serve_endpoint(). This is the main entry point called by Dynamo's endpoint.serve_endpoint().
Handles MediaOutput from the pipeline:
- video tensor → MP4
- image tensor → unsupported (raises error)
- audio tensor → unsupported (raises error)
Args: Args:
request: Request dictionary with video generation parameters. request: Request dictionary with generation parameters.
context: Dynamo context for request tracking. context: Dynamo context for request tracking.
Yields: Yields:
Response dictionary with generated video data. Response dictionary with generated media data.
""" """
start_time = time.time() start_time = time.time()
request_id = str(uuid.uuid4()) request_id = str(uuid.uuid4())
logger.info(f"Received video generation request: {request_id}") logger.info(f"Received generation request: {request_id}")
try: try:
# Parse request # Parse request
...@@ -202,11 +214,11 @@ class VideoGenerationHandler(BaseGenerativeHandler): ...@@ -202,11 +214,11 @@ class VideoGenerationHandler(BaseGenerativeHandler):
# Run generation in thread pool (blocking operation). # Run generation in thread pool (blocking operation).
# Lock ensures only one request uses the pipeline at a time. # Lock ensures only one request uses the pipeline at a time.
# TODO: Add cancellation support. This requires: # TODO: Add cancellation support. This requires:
# 1. visual_gen to expose a cancellation hook in the denoising loop # 1. The pipeline to expose a cancellation hook in the denoising loop
# 2. Passing a cancellation token/event to engine.generate() # 2. Passing a cancellation token/event to engine.generate()
# 3. Checking context.cancelled() and propagating to the pipeline # 3. Checking context.cancelled() and propagating to the pipeline
async with self._generate_lock: async with self._generate_lock:
frames = await asyncio.to_thread( output = await asyncio.to_thread(
self.engine.generate, self.engine.generate,
prompt=req.prompt, prompt=req.prompt,
negative_prompt=nvext.negative_prompt, negative_prompt=nvext.negative_prompt,
...@@ -218,15 +230,47 @@ class VideoGenerationHandler(BaseGenerativeHandler): ...@@ -218,15 +230,47 @@ class VideoGenerationHandler(BaseGenerativeHandler):
seed=nvext.seed, seed=nvext.seed,
) )
if output is None:
raise RuntimeError("Pipeline returned no output (MediaOutput is None)")
# Determine output format # Determine output format
response_format = req.response_format or "url" response_format = req.response_format or "url"
fps = nvext.fps or self.config.default_fps fps = nvext.fps or self.config.default_fps
# Encode frames to MP4 bytes in memory # Encode media based on what the pipeline returned
video_bytes = await asyncio.to_thread(encode_to_mp4_bytes, frames, fps=fps) if output.video is not None:
# Video output: torch.Tensor (num_frames, H, W, 3) uint8 → MP4
frames_np = output.video.cpu().numpy()
logger.info(
f"Request {request_id}: encoding video output "
f"(shape={frames_np.shape}) to MP4 at {fps} fps"
)
video_bytes = await asyncio.to_thread(
encode_to_mp4_bytes, frames_np, fps=fps
)
elif output.image is not None:
raise RuntimeError(
"Pipeline returned image-only output, but this handler "
"only supports video. Use an image generation handler instead."
)
# Audio-only output is unsupported by this handler — fail explicitly
elif output.audio is not None:
raise RuntimeError(
"Pipeline returned audio-only output, but this handler "
"only supports video. Use an audio generation handler instead."
)
else:
raise RuntimeError(
"Pipeline returned MediaOutput with no video or image or audio data. "
f"MediaOutput fields: video={output.video is not None}, "
f"image={output.image is not None}, audio={output.audio is not None}"
)
# Return media via URL or base64
if response_format == "url": if response_format == "url":
# Upload via filesystem
storage_path = f"videos/{request_id}.mp4" storage_path = f"videos/{request_id}.mp4"
video_url = await upload_to_fs( video_url = await upload_to_fs(
self.media_output_fs, self.media_output_fs,
...@@ -236,7 +280,6 @@ class VideoGenerationHandler(BaseGenerativeHandler): ...@@ -236,7 +280,6 @@ class VideoGenerationHandler(BaseGenerativeHandler):
) )
video_data = VideoData(url=video_url) video_data = VideoData(url=video_url)
else: else:
# Encode to base64
b64_video = base64.b64encode(video_bytes).decode("utf-8") b64_video = base64.b64encode(video_bytes).decode("utf-8")
video_data = VideoData(b64_json=b64_video) video_data = VideoData(b64_json=b64_video)
......
...@@ -13,10 +13,12 @@ import asyncio ...@@ -13,10 +13,12 @@ import asyncio
import threading import threading
import time import time
from dataclasses import dataclass from dataclasses import dataclass
from types import SimpleNamespace
from typing import Optional from typing import Optional
from unittest.mock import MagicMock, patch from unittest.mock import MagicMock, patch
import pytest import pytest
import torch
from dynamo.common.protocols.video_protocol import ( from dynamo.common.protocols.video_protocol import (
NvCreateVideoRequest, NvCreateVideoRequest,
...@@ -104,8 +106,11 @@ class TestDiffusionConfig: ...@@ -104,8 +106,11 @@ class TestDiffusionConfig:
# Optimization defaults # Optimization defaults
assert config.enable_teacache is False assert config.enable_teacache is False
assert config.attn_type == "default" assert config.attn_backend == "VANILLA"
assert config.linear_type == "default" assert config.quant_algo is None
assert config.enable_cuda_graph is False
assert config.warmup_steps == 1
assert config.fuse_qkv is True
# Parallelism defaults # Parallelism defaults
assert config.dit_dp_size == 1 assert config.dit_dp_size == 1
...@@ -532,10 +537,12 @@ class ConcurrencyTracker: ...@@ -532,10 +537,12 @@ class ConcurrencyTracker:
with self._lock: with self._lock:
self._active_count -= 1 self._active_count -= 1
# Return fake frames (shape: [num_frames, H, W, C]) # Return a mock MediaOutput with a video tensor
import numpy as np return SimpleNamespace(
video=torch.zeros((4, 64, 64, 3), dtype=torch.uint8),
return np.zeros((4, 64, 64, 3), dtype=np.uint8) image=None,
audio=None,
)
class TestVideoHandlerConcurrency: class TestVideoHandlerConcurrency:
...@@ -660,16 +667,17 @@ class TestVideoHandlerResponseFormats: ...@@ -660,16 +667,17 @@ class TestVideoHandlerResponseFormats:
def _make_handler(self): def _make_handler(self):
"""Create a handler with mocked engine and fs.""" """Create a handler with mocked engine and fs."""
import numpy as np
from dynamo.trtllm.request_handlers.video_diffusion.video_handler import ( from dynamo.trtllm.request_handlers.video_diffusion.video_handler import (
VideoGenerationHandler, VideoGenerationHandler,
) )
mock_engine = MagicMock() mock_output = SimpleNamespace(
mock_engine.generate = MagicMock( video=torch.zeros((4, 64, 64, 3), dtype=torch.uint8),
return_value=np.zeros((4, 64, 64, 3), dtype=np.uint8) image=None,
audio=None,
) )
mock_engine = MagicMock()
mock_engine.generate = MagicMock(return_value=mock_output)
config = DiffusionConfig( config = DiffusionConfig(
media_output_fs_url="file:///tmp/test_media", media_output_fs_url="file:///tmp/test_media",
......
...@@ -33,20 +33,19 @@ async def init_video_diffusion_worker( ...@@ -33,20 +33,19 @@ async def init_video_diffusion_worker(
shutdown_event: Event to signal shutdown. shutdown_event: Event to signal shutdown.
shutdown_endpoints: Optional list to populate with endpoints for graceful shutdown. shutdown_endpoints: Optional list to populate with endpoints for graceful shutdown.
""" """
# Check visual_gen availability early with a clear error message. # Check tensorrt_llm visual_gen availability early with a clear error message.
# visual_gen is part of TensorRT-LLM but only available on the feat/visual_gen # visual_gen is part of TensorRT-LLM (tensorrt_llm._torch.visual_gen).
# branch — not yet in any release. Without this check, users would get a cryptic # Without this check, users would get a cryptic ImportError deep inside
# ImportError deep inside DiffusionEngine.initialize(). # DiffusionEngine.initialize().
try: try:
import visual_gen # noqa: F401 import tensorrt_llm._torch.visual_gen # noqa: F401
except ImportError: except ImportError:
raise ImportError( raise ImportError(
"Video diffusion requires the 'visual_gen' package from TensorRT-LLM's " "Video diffusion requires TensorRT-LLM with visual_gen support.\n"
"feat/visual_gen branch. Install with:\n" "The visual_gen module is at tensorrt_llm._torch.visual_gen.\n"
" git clone https://github.com/NVIDIA/TensorRT-LLM.git\n" "Install TensorRT-LLM with AIGV support:\n"
" cd TensorRT-LLM && git checkout feat/visual_gen\n" " pip install tensorrt_llm\n"
" cd tensorrt_llm/visual_gen && pip install -e .\n" "See: https://github.com/NVIDIA/TensorRT-LLM"
"See: https://github.com/NVIDIA/TensorRT-LLM/tree/feat/visual_gen/tensorrt_llm/visual_gen"
) from None ) from None
from dynamo.trtllm.configs.diffusion_config import DiffusionConfig from dynamo.trtllm.configs.diffusion_config import DiffusionConfig
...@@ -55,6 +54,13 @@ async def init_video_diffusion_worker( ...@@ -55,6 +54,13 @@ async def init_video_diffusion_worker(
logging.info(f"Initializing video diffusion worker with config: {config}") logging.info(f"Initializing video diffusion worker with config: {config}")
# Parse skip_components from comma-separated string to list
skip_components = (
[c.strip() for c in config.skip_components.split(",") if c.strip()]
if config.skip_components
else []
)
# Build DiffusionConfig from the main Config # Build DiffusionConfig from the main Config
diffusion_config = DiffusionConfig( diffusion_config = DiffusionConfig(
namespace=config.namespace, namespace=config.namespace,
...@@ -65,6 +71,8 @@ async def init_video_diffusion_worker( ...@@ -65,6 +71,8 @@ async def init_video_diffusion_worker(
event_plane=config.event_plane, event_plane=config.event_plane,
model_path=config.model, model_path=config.model,
served_model_name=config.served_model_name, served_model_name=config.served_model_name,
torch_dtype=config.torch_dtype,
revision=config.revision,
media_output_fs_url=config.media_output_fs_url, media_output_fs_url=config.media_output_fs_url,
media_output_http_url=config.media_output_http_url, media_output_http_url=config.media_output_http_url,
default_height=config.default_height, default_height=config.default_height,
...@@ -72,19 +80,34 @@ async def init_video_diffusion_worker( ...@@ -72,19 +80,34 @@ async def init_video_diffusion_worker(
default_num_frames=config.default_num_frames, default_num_frames=config.default_num_frames,
default_num_inference_steps=config.default_num_inference_steps, default_num_inference_steps=config.default_num_inference_steps,
default_guidance_scale=config.default_guidance_scale, default_guidance_scale=config.default_guidance_scale,
enable_teacache=config.enable_teacache, # Pipeline optimization
teacache_thresh=config.teacache_thresh,
attn_type=config.attn_type,
linear_type=config.linear_type,
disable_torch_compile=config.disable_torch_compile, disable_torch_compile=config.disable_torch_compile,
torch_compile_mode=config.torch_compile_mode, torch_compile_mode=config.torch_compile_mode,
enable_fullgraph=config.enable_fullgraph,
fuse_qkv=config.fuse_qkv,
enable_cuda_graph=config.enable_cuda_graph,
enable_layerwise_nvtx_marker=config.enable_layerwise_nvtx_marker,
warmup_steps=config.warmup_steps,
# Attention
attn_backend=config.attn_backend,
# Quantization
quant_algo=config.quant_algo,
quant_dynamic=config.quant_dynamic,
# TeaCache
enable_teacache=config.enable_teacache,
teacache_use_ret_steps=config.teacache_use_ret_steps,
teacache_thresh=config.teacache_thresh,
# Parallelism
dit_dp_size=config.dit_dp_size, dit_dp_size=config.dit_dp_size,
dit_tp_size=config.dit_tp_size, dit_tp_size=config.dit_tp_size,
dit_ulysses_size=config.dit_ulysses_size, dit_ulysses_size=config.dit_ulysses_size,
dit_ring_size=config.dit_ring_size, dit_ring_size=config.dit_ring_size,
dit_cfg_size=config.dit_cfg_size, dit_cfg_size=config.dit_cfg_size,
dit_fsdp_size=config.dit_fsdp_size, dit_fsdp_size=config.dit_fsdp_size,
# Offloading
enable_async_cpu_offload=config.enable_async_cpu_offload, enable_async_cpu_offload=config.enable_async_cpu_offload,
# Component loading
skip_components=skip_components,
) )
# Get the endpoint from the runtime # Get the endpoint from the runtime
......
...@@ -216,11 +216,10 @@ Dynamo supports video generation using diffusion models through the `--modality ...@@ -216,11 +216,10 @@ Dynamo supports video generation using diffusion models through the `--modality
### Requirements ### Requirements
- **visual_gen**: Part of TensorRT-LLM, located at `tensorrt_llm/visual_gen/`. Currently available **only** on the [`feat/visual_gen`](https://github.com/NVIDIA/TensorRT-LLM/tree/feat/visual_gen/tensorrt_llm/visual_gen) branch (not yet merged to main or any release). Install from source: - **TensorRT-LLM with visual_gen**: The `visual_gen` module is part of TensorRT-LLM (`tensorrt_llm._torch.visual_gen`). Install TensorRT-LLM following the [official instructions](https://github.com/NVIDIA/TensorRT-LLM#installation).
- **imageio with ffmpeg**: Required for encoding generated frames to MP4 video:
```bash ```bash
git clone https://github.com/NVIDIA/TensorRT-LLM.git pip install imageio[ffmpeg]
cd TensorRT-LLM && git checkout feat/visual_gen
cd tensorrt_llm/visual_gen && pip install -e .
``` ```
- **dynamo-runtime with video API**: The Dynamo runtime must include `ModelType.Videos` support. Ensure you're using a compatible version. - **dynamo-runtime with video API**: The Dynamo runtime must include `ModelType.Videos` support. Ensure you're using a compatible version.
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment