Unverified Commit 8de1c3e0 authored by Indrajit Bhosale, committed by GitHub
Browse files

feat: Update TRTLLM video diffusion for `visual_gen` module restructuring (PR #11462) (#6516)


Signed-off-by: Indrajit Bhosale <iamindrajitb@gmail.com>
parent 78436fbf
...@@ -242,6 +242,22 @@ class DynamoTrtllmArgGroup(ArgGroup): ...@@ -242,6 +242,22 @@ class DynamoTrtllmArgGroup(ArgGroup):
arg_type=float, arg_type=float,
help="Default CFG guidance scale.", help="Default CFG guidance scale.",
) )
add_argument(
diffusion_group,
flag_name="--torch-dtype",
env_var="DYN_TRTLLM_TORCH_DTYPE",
default="bfloat16",
choices=["bfloat16", "float16", "float32"],
help="Torch dtype for model loading. bfloat16 recommended for Ampere+ GPUs.",
)
add_argument(
diffusion_group,
flag_name="--revision",
env_var="DYN_TRTLLM_REVISION",
default=None,
help="HuggingFace Hub revision (branch, tag, or commit SHA) for model download.",
)
add_negatable_bool_argument( add_negatable_bool_argument(
diffusion_group, diffusion_group,
flag_name="--enable-teacache", flag_name="--enable-teacache",
...@@ -249,6 +265,13 @@ class DynamoTrtllmArgGroup(ArgGroup): ...@@ -249,6 +265,13 @@ class DynamoTrtllmArgGroup(ArgGroup):
default=False, default=False,
help="Enable TeaCache optimization for faster generation.", help="Enable TeaCache optimization for faster generation.",
) )
add_negatable_bool_argument(
diffusion_group,
flag_name="--teacache-use-ret-steps",
env_var="DYN_TRTLLM_TEACACHE_USE_RET_STEPS",
default=True,
help="Use retention steps for TeaCache.",
)
add_argument( add_argument(
diffusion_group, diffusion_group,
flag_name="--teacache-thresh", flag_name="--teacache-thresh",
...@@ -259,24 +282,33 @@ class DynamoTrtllmArgGroup(ArgGroup): ...@@ -259,24 +282,33 @@ class DynamoTrtllmArgGroup(ArgGroup):
) )
add_argument( add_argument(
diffusion_group, diffusion_group,
flag_name="--attn-type", flag_name="--attn-backend",
env_var="DYN_TRTLLM_ATTN_TYPE", env_var="DYN_TRTLLM_ATTN_BACKEND",
default="default", default="VANILLA",
choices=["default", "sage-attn", "sparse-videogen", "sparse-videogen2"], choices=["VANILLA", "TRTLLM"],
help="Attention type for diffusion models.", help="Attention backend for diffusion models. VANILLA = PyTorch SDPA, TRTLLM = TensorRT-LLM kernels.",
) )
add_argument( add_argument(
diffusion_group, diffusion_group,
flag_name="--linear-type", flag_name="--quant-algo",
env_var="DYN_TRTLLM_LINEAR_TYPE", env_var="DYN_TRTLLM_QUANT_ALGO",
default="default", default=None,
choices=[ choices=[
"default", "FP8",
"trtllm-fp8-blockwise", "FP8_BLOCK_SCALES",
"trtllm-fp8-per-tensor", "NVFP4",
"trtllm-nvfp4", "W4A16_AWQ",
"W4A8_AWQ",
"W8A8_SQ_PER_CHANNEL",
], ],
help="Linear type for quantization.", help="Quantization algorithm for diffusion models. BF16 weights are quantized on-the-fly during loading.",
)
add_negatable_bool_argument(
diffusion_group,
flag_name="--quant-dynamic",
env_var="DYN_TRTLLM_QUANT_DYNAMIC",
default=True,
help="Enable dynamic weight quantization (quantize BF16 weights on-the-fly during loading).",
) )
add_negatable_bool_argument( add_negatable_bool_argument(
diffusion_group, diffusion_group,
...@@ -293,6 +325,42 @@ class DynamoTrtllmArgGroup(ArgGroup): ...@@ -293,6 +325,42 @@ class DynamoTrtllmArgGroup(ArgGroup):
choices=["default", "reduce-overhead", "max-autotune"], choices=["default", "reduce-overhead", "max-autotune"],
help="torch.compile mode.", help="torch.compile mode.",
) )
add_negatable_bool_argument(
diffusion_group,
flag_name="--enable-fullgraph",
env_var="DYN_TRTLLM_ENABLE_FULLGRAPH",
default=False,
help="Enable torch.compile fullgraph mode (stricter but potentially faster).",
)
add_negatable_bool_argument(
diffusion_group,
flag_name="--fuse-qkv",
env_var="DYN_TRTLLM_FUSE_QKV",
default=True,
help="Enable QKV fusion for transformer attention layers.",
)
add_negatable_bool_argument(
diffusion_group,
flag_name="--enable-cuda-graph",
env_var="DYN_TRTLLM_ENABLE_CUDA_GRAPH",
default=False,
help="Enable CUDA graph capture for transformer forward passes. Mutually exclusive with torch.compile.",
)
add_negatable_bool_argument(
diffusion_group,
flag_name="--enable-layerwise-nvtx-marker",
env_var="DYN_TRTLLM_ENABLE_LAYERWISE_NVTX_MARKER",
default=False,
help="Enable per-layer NVTX markers for profiling with Nsight Systems.",
)
add_argument(
diffusion_group,
flag_name="--warmup-steps",
env_var="DYN_TRTLLM_WARMUP_STEPS",
default=1,
arg_type=int,
help="Number of denoising steps to run during warmup (0 to disable).",
)
add_argument( add_argument(
diffusion_group, diffusion_group,
flag_name="--dit-dp-size", flag_name="--dit-dp-size",
...@@ -348,6 +416,17 @@ class DynamoTrtllmArgGroup(ArgGroup): ...@@ -348,6 +416,17 @@ class DynamoTrtllmArgGroup(ArgGroup):
default=False, default=False,
help="Enable async CPU offload for memory efficiency.", help="Enable async CPU offload for memory efficiency.",
) )
add_argument(
diffusion_group,
flag_name="--skip-components",
env_var="DYN_TRTLLM_SKIP_COMPONENTS",
default="",
help=(
"Comma-separated list of pipeline components to skip loading. "
"Valid values: transformer, vae, text_encoder, tokenizer, scheduler, "
"image_encoder, image_processor."
),
)
class DynamoTrtllmConfig(ConfigBase): class DynamoTrtllmConfig(ConfigBase):
...@@ -383,12 +462,21 @@ class DynamoTrtllmConfig(ConfigBase): ...@@ -383,12 +462,21 @@ class DynamoTrtllmConfig(ConfigBase):
default_num_frames: int default_num_frames: int
default_num_inference_steps: int default_num_inference_steps: int
default_guidance_scale: float default_guidance_scale: float
torch_dtype: str
revision: Optional[str] = None
enable_teacache: bool enable_teacache: bool
teacache_use_ret_steps: bool
teacache_thresh: float teacache_thresh: float
attn_type: str attn_backend: str
linear_type: str quant_algo: Optional[str]
quant_dynamic: bool
disable_torch_compile: bool disable_torch_compile: bool
torch_compile_mode: str torch_compile_mode: str
enable_fullgraph: bool
fuse_qkv: bool
enable_cuda_graph: bool
enable_layerwise_nvtx_marker: bool
warmup_steps: int
dit_dp_size: int dit_dp_size: int
dit_tp_size: int dit_tp_size: int
dit_ulysses_size: int dit_ulysses_size: int
...@@ -396,6 +484,7 @@ class DynamoTrtllmConfig(ConfigBase): ...@@ -396,6 +484,7 @@ class DynamoTrtllmConfig(ConfigBase):
dit_cfg_size: int dit_cfg_size: int
dit_fsdp_size: int dit_fsdp_size: int
enable_async_cpu_offload: bool enable_async_cpu_offload: bool
skip_components: str
def validate(self) -> None: def validate(self) -> None:
if isinstance(self.disaggregation_mode, str): if isinstance(self.disaggregation_mode, str):
......
...@@ -5,9 +5,16 @@ ...@@ -5,9 +5,16 @@
This module defines the DiffusionConfig dataclass used for configuring This module defines the DiffusionConfig dataclass used for configuring
video and image diffusion workers. video and image diffusion workers.
Fields map to TensorRT-LLM's DiffusionArgs sub-configs:
- PipelineConfig: torch_compile, CUDA graph, warmup, offloading, fuse_qkv
- AttentionConfig: attention backend (VANILLA, TRTLLM)
- ParallelConfig: dit_*_size parallelism dimensions
- TeaCacheConfig: caching optimization
- QuantConfig: quantization algorithm and dynamic flags
""" """
from dataclasses import dataclass from dataclasses import dataclass, field
from typing import Optional from typing import Optional
from dynamo.common.utils.namespace import get_worker_namespace from dynamo.common.utils.namespace import get_worker_namespace
...@@ -23,7 +30,7 @@ class DiffusionConfig: ...@@ -23,7 +30,7 @@ class DiffusionConfig:
"""Configuration for diffusion model workers (video/image generation). """Configuration for diffusion model workers (video/image generation).
This configuration is used by DiffusionEngine and diffusion handlers. This configuration is used by DiffusionEngine and diffusion handlers.
It can be populated from command-line arguments in trtllm_utils.py. It can be populated from command-line arguments in backend_args.py.
""" """
# Dynamo runtime config # Dynamo runtime config
...@@ -41,6 +48,8 @@ class DiffusionConfig: ...@@ -41,6 +48,8 @@ class DiffusionConfig:
# bfloat16 is recommended for Ampere+ GPUs (A100, H100, etc.) # bfloat16 is recommended for Ampere+ GPUs (A100, H100, etc.)
# float16 can be used on older GPUs (V100, etc.) # float16 can be used on older GPUs (V100, etc.)
torch_dtype: str = "bfloat16" torch_dtype: str = "bfloat16"
# HuggingFace Hub revision (branch, tag, or commit SHA) for model download.
revision: Optional[str] = None
# Media storage # Media storage
media_output_fs_url: str = "file:///tmp/dynamo_media" media_output_fs_url: str = "file:///tmp/dynamo_media"
...@@ -58,16 +67,39 @@ class DiffusionConfig: ...@@ -58,16 +67,39 @@ class DiffusionConfig:
default_num_inference_steps: int = 50 default_num_inference_steps: int = 50
default_guidance_scale: float = 5.0 default_guidance_scale: float = 5.0
# visual_gen optimization config # ── Pipeline optimization config (maps to PipelineConfig) ──
disable_torch_compile: bool = False
torch_compile_mode: str = "default"
# Enable torch.compile fullgraph mode (stricter but potentially faster)
enable_fullgraph: bool = False
# QKV fusion for transformer attention layers
fuse_qkv: bool = True
# CUDA graph capture for transformer forward passes
# (mutually exclusive with torch.compile — torch.compile takes priority)
enable_cuda_graph: bool = False
# Enable per-layer NVTX markers for profiling
enable_layerwise_nvtx_marker: bool = False
# Number of denoising steps to run during warmup (0 to disable)
warmup_steps: int = 1
# ── Attention config (maps to AttentionConfig) ──
# Attention backend: "VANILLA" (PyTorch SDPA) or "TRTLLM"
attn_backend: str = "VANILLA"
# ── Quantization config (maps to DiffusionArgs.quant_config) ──
# Quantization algorithm. Options:
# None (no quantization), "FP8", "FP8_BLOCK_SCALES", "NVFP4",
# "W4A16_AWQ", "W4A8_AWQ", "W8A8_SQ_PER_CHANNEL"
quant_algo: Optional[str] = None
# Enable dynamic weight quantization (quantize BF16 weights on-the-fly during loading)
quant_dynamic: bool = True
# ── TeaCache optimization config (maps to TeaCacheConfig) ──
enable_teacache: bool = False enable_teacache: bool = False
teacache_use_ret_steps: bool = True teacache_use_ret_steps: bool = True
teacache_thresh: float = 0.2 teacache_thresh: float = 0.2
attn_type: str = "default"
linear_type: str = "default"
disable_torch_compile: bool = False
torch_compile_mode: str = "default"
# Parallelism config (DiTParallelConfig) # ── Parallelism config (maps to ParallelConfig) ──
dit_dp_size: int = 1 dit_dp_size: int = 1
dit_tp_size: int = 1 dit_tp_size: int = 1
dit_ulysses_size: int = 1 dit_ulysses_size: int = 1
...@@ -75,9 +107,14 @@ class DiffusionConfig: ...@@ -75,9 +107,14 @@ class DiffusionConfig:
dit_cfg_size: int = 1 dit_cfg_size: int = 1
dit_fsdp_size: int = 1 dit_fsdp_size: int = 1
# CPU offload config # ── Offloading config (maps to PipelineConfig) ──
enable_async_cpu_offload: bool = False enable_async_cpu_offload: bool = False
visual_gen_block_cpu_offload_stride: int = 1
# ── Component loading options ──
# Components to skip loading (e.g., ["text_encoder", "vae"]).
# Valid values: "transformer", "vae", "text_encoder", "tokenizer",
# "scheduler", "image_encoder", "image_processor"
skip_components: list[str] = field(default_factory=list)
def __str__(self) -> str: def __str__(self) -> str:
return ( return (
...@@ -93,8 +130,10 @@ class DiffusionConfig: ...@@ -93,8 +130,10 @@ class DiffusionConfig:
f"default_num_frames={self.default_num_frames}, " f"default_num_frames={self.default_num_frames}, "
f"default_num_inference_steps={self.default_num_inference_steps}, " f"default_num_inference_steps={self.default_num_inference_steps}, "
f"enable_teacache={self.enable_teacache}, " f"enable_teacache={self.enable_teacache}, "
f"attn_type={self.attn_type}, " f"attn_backend={self.attn_backend}, "
f"linear_type={self.linear_type}, " f"quant_algo={self.quant_algo}, "
f"enable_cuda_graph={self.enable_cuda_graph}, "
f"warmup_steps={self.warmup_steps}, "
f"dit_dp_size={self.dit_dp_size}, " f"dit_dp_size={self.dit_dp_size}, "
f"dit_tp_size={self.dit_tp_size})" f"dit_tp_size={self.dit_tp_size})"
) )
...@@ -4,7 +4,7 @@ ...@@ -4,7 +4,7 @@
"""Engine modules for TensorRT-LLM backend. """Engine modules for TensorRT-LLM backend.
This module provides engine wrappers for various generative models: This module provides engine wrappers for various generative models:
- DiffusionEngine: Generic wrapper for visual_gen diffusion pipelines - DiffusionEngine: Generic wrapper for TensorRT-LLM visual_gen diffusion pipelines
""" """
from dynamo.trtllm.engines.diffusion_engine import DiffusionEngine from dynamo.trtllm.engines.diffusion_engine import DiffusionEngine
......
# SPDX-FileCopyrightText: Copyright (c) 2025-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. # SPDX-FileCopyrightText: Copyright (c) 2025-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0 # SPDX-License-Identifier: Apache-2.0
"""Generic Diffusion Engine wrapper for visual_gen pipelines. """Generic Diffusion Engine wrapper for TensorRT-LLM visual_gen pipelines.
This module provides a unified interface for various diffusion models This module provides a unified interface for various diffusion models
(Wan, Flux, Cosmos, etc.) through a pipeline registry system. (Wan, Flux, Cosmos, etc.) through TensorRT-LLM's AutoPipeline system.
The pipeline type is auto-detected from model_index.json (shipped with every The pipeline type is auto-detected from model_index.json (shipped with every
HuggingFace Diffusers model), eliminating the need for a --model-type flag. HuggingFace Diffusers model), eliminating the need for a --model-type flag.
Requirements: Requirements:
- visual_gen: Part of TensorRT-LLM, located at tensorrt_llm/visual_gen/. - tensorrt_llm with visual_gen support (tensorrt_llm._torch.visual_gen).
Currently on the feat/visual_gen branch (not yet merged to main). See: https://github.com/NVIDIA/TensorRT-LLM
See: https://github.com/NVIDIA/TensorRT-LLM/tree/feat/visual_gen/tensorrt_llm/visual_gen
- See docs/pages/backends/trtllm/README.md for setup instructions. - See docs/pages/backends/trtllm/README.md for setup instructions.
Note on imports: Note on imports:
visual_gen is imported lazily in initialize() because: tensorrt_llm._torch.visual_gen is imported lazily in initialize() because:
1. It's a heavy package that may not be installed in all environments 1. It's a heavy package that may not be installed in all environments
2. Importing at module load would fail if visual_gen is not available 2. Importing at module load would fail if tensorrt_llm is not available
3. This allows the module to be imported for type checking and validation 3. This allows the module to be imported for type checking and validation
without requiring visual_gen to be installed without requiring tensorrt_llm to be installed
""" """
import importlib
import json
import logging import logging
from dataclasses import dataclass import random
from pathlib import Path from enum import Enum
from typing import TYPE_CHECKING, Any, Optional from typing import TYPE_CHECKING, Optional
import numpy as np
import torch import torch
if TYPE_CHECKING: if TYPE_CHECKING:
from tensorrt_llm._torch.visual_gen import DiffusionArgs
from tensorrt_llm._torch.visual_gen.output import MediaOutput
from tensorrt_llm._torch.visual_gen.pipeline import BasePipeline
from dynamo.trtllm.configs.diffusion_config import DiffusionConfig from dynamo.trtllm.configs.diffusion_config import DiffusionConfig
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
@dataclass class DiffusionModality(str, Enum):
class PipelineInfo: """Output modality of a diffusion pipeline."""
"""Auto-detected pipeline information from model_index.json."""
VIDEO = "video_diffusion"
IMAGE = "image_diffusion"
module_path: str # Explicit mapping from TRT-LLM pipeline class names to their output modality.
class_name: str # This replaces brittle substring matching and must be updated when new
modalities: list[str] # pipelines are registered in TRT-LLM's PIPELINE_REGISTRY.
config_overrides: dict[str, Any] _PIPELINE_MODALITY_MAP: dict[str, DiffusionModality] = {
"WanPipeline": DiffusionModality.VIDEO,
"WanImageToVideoPipeline": DiffusionModality.VIDEO,
"FluxPipeline": DiffusionModality.IMAGE,
"LTX2Pipeline": DiffusionModality.VIDEO,
}
# Default when the pipeline is not yet loaded or the class name is unknown.
_DEFAULT_MODALITY = DiffusionModality.VIDEO
class DiffusionEngine: class DiffusionEngine:
"""Generic wrapper for visual_gen diffusion pipelines. """Generic wrapper for TensorRT-LLM visual_gen diffusion pipelines.
This engine provides: This engine provides:
- Auto-detection of pipeline class from model_index.json - Auto-detection of pipeline class from model_index.json via AutoPipeline
- A registry mapping diffusers _class_name to visual_gen pipelines - Loading and initialization through PipelineLoader
- Lazy loading of pipeline modules - Common interface for video/image generation via pipeline.infer()
- Common interface for video/image generation
The old visual_gen standalone package (setup_configs + from_pretrained +
PIPELINE_REGISTRY) has been replaced by TensorRT-LLM's integrated
visual_gen module which uses:
- DiffusionArgs for configuration
- PipelineLoader for model loading (handles MetaInit, weight loading,
quantization, torch.compile, and warmup)
- AutoPipeline for pipeline type auto-detection
- MediaOutput for typed output (video/image/audio torch tensors)
Example: Example:
>>> engine = DiffusionEngine(config) >>> engine = DiffusionEngine(config)
>>> await engine.initialize() >>> await engine.initialize()
>>> frames = engine.generate(prompt="A cat playing piano", ...) >>> output = engine.generate(prompt="A cat playing piano", ...)
>>> output.video # torch.Tensor (num_frames, H, W, 3) uint8
""" """
# Registry: diffusers _class_name -> (module_path, visual_gen_class, supported_modalities)
# The _class_name comes from model_index.json shipped with every HF Diffusers model.
# torch_compile_models is derived dynamically from transformer* keys in model_index.json.
#
# NOTE: This registry is initially focused on Wan text-to-video models.
# Follow-up PRs will extend support for other model families (Flux, Cosmos, etc.)
# which may require additional config fields in DiffusionConfig.
PIPELINE_REGISTRY: dict[str, tuple[str, str, list[str]]] = {
"WanPipeline": (
"visual_gen.pipelines.wan_pipeline",
"ditWanPipeline",
["video_diffusion"],
),
# TODO: Add support for WanImageToVideoPipeline, FluxPipeline, etc.
}
@classmethod
def detect_pipeline_info(cls, model_path: str) -> PipelineInfo:
"""Auto-detect pipeline class from model's model_index.json.
Reads model_index.json (local path or HuggingFace Hub) to determine:
- Which visual_gen pipeline class to use (via _class_name)
- Which transformer models to torch.compile (via transformer* keys)
Args:
model_path: Local path or HuggingFace model identifier.
Returns:
PipelineInfo with module_path, class_name, modalities, and config_overrides.
Raises:
ValueError: If _class_name is not in the registry.
FileNotFoundError: If model_index.json cannot be found locally or on HF Hub.
"""
# Try local path first
local_index = Path(model_path) / "model_index.json"
if local_index.exists():
with open(local_index) as f:
model_index = json.load(f)
else:
# Download from HuggingFace Hub
from huggingface_hub import hf_hub_download
index_path = hf_hub_download(model_path, "model_index.json")
with open(index_path) as f:
model_index = json.load(f)
class_name = model_index.get("_class_name")
if class_name not in cls.PIPELINE_REGISTRY:
supported = list(cls.PIPELINE_REGISTRY.keys())
raise ValueError(
f"Unsupported diffusion pipeline '{class_name}' from model '{model_path}'.\n"
f"Supported pipelines: {', '.join(supported)}\n"
f"Check that model_index.json has a supported _class_name."
)
module_path, vg_class, modalities = cls.PIPELINE_REGISTRY[class_name]
# Derive torch_compile_models from transformer* keys in model_index.json
transformer_keys = sorted(k for k in model_index if k.startswith("transformer"))
torch_compile_models = (
",".join(transformer_keys) if transformer_keys else "transformer"
)
config_overrides = {"torch_compile_models": torch_compile_models}
return PipelineInfo(module_path, vg_class, modalities, config_overrides)
def __init__(self, config: "DiffusionConfig"): def __init__(self, config: "DiffusionConfig"):
"""Initialize the engine with configuration. """Initialize the engine with configuration.
Auto-detects the pipeline type from config.model_path's model_index.json.
Args: Args:
config: Diffusion generation configuration. config: Diffusion generation configuration.
Raises:
ValueError: If the model's pipeline type is not supported.
""" """
info = self.detect_pipeline_info(config.model_path)
self.config = config self.config = config
self._pipeline = None self._pipeline: Optional["BasePipeline"] = None
self._initialized = False self._initialized = False
self._module_path = info.module_path
self._class_name = info.class_name
self._supported_modalities = info.modalities
self._config_overrides = info.config_overrides
async def initialize(self) -> None: async def initialize(self) -> None:
"""Load and configure the diffusion pipeline. """Load and configure the diffusion pipeline via PipelineLoader.
This is called once at worker startup to load the model. This is called once at worker startup to load the model.
The specific pipeline class is determined by the auto-detected pipeline type. PipelineLoader handles:
1. Loading config via DiffusionModelConfig.from_pretrained()
2. Creating pipeline via AutoPipeline.from_config() (auto-detects type)
3. Loading weights with optional on-the-fly quantization
4. Post-load hooks (TeaCache setup, etc.)
5. torch.compile (if enabled)
6. Warmup inference
""" """
if self._initialized: if self._initialized:
logger.warning("Engine already initialized, skipping") logger.warning("Engine already initialized, skipping")
return return
logger.info( logger.info(
f"Initializing DiffusionEngine: pipeline={self._class_name}, " f"Initializing DiffusionEngine: model_path={self.config.model_path}"
f"model_path={self.config.model_path}"
) )
# Import visual_gen setup # Import TensorRT-LLM visual_gen components
from visual_gen import setup_configs from tensorrt_llm._torch.visual_gen import PipelineLoader
# Build configuration dict based on model type
dit_configs = self._build_dit_configs()
logger.info(f"dit_configs: {dit_configs}")
# Setup global configuration (required before pipeline loading)
setup_configs(**dit_configs)
# Dynamically import the pipeline class
logger.info(f"Importing pipeline from {self._module_path}.{self._class_name}")
module = importlib.import_module(self._module_path)
pipeline_class = getattr(module, self._class_name)
# Load the pipeline
# Convert torch_dtype string to actual torch dtype
dtype_map = {
"bfloat16": torch.bfloat16,
"float16": torch.float16,
"float32": torch.float32,
}
torch_dtype = dtype_map.get(self.config.torch_dtype, torch.bfloat16)
logger.info(
f"Loading pipeline from {self.config.model_path} with dtype={self.config.torch_dtype}"
)
self._pipeline = pipeline_class.from_pretrained(
self.config.model_path,
torch_dtype=torch_dtype,
**dit_configs,
)
# Move to target device # Build DiffusionArgs from DiffusionConfig
# NOTE: HuggingFace's from_pretrained() loads to CPU by default, diffusion_args = self._build_diffusion_args()
# so we must explicitly move to GPU for optimal performance. logger.info(f"DiffusionArgs: {diffusion_args}")
if self.device == "cuda":
logger.info("Moving pipeline to GPU...") # Use PipelineLoader for the full loading flow:
self._pipeline.to(self.device) # DiffusionArgs → DiffusionModelConfig → AutoPipeline → BasePipeline
logger.info("Pipeline moved to GPU successfully") loader = PipelineLoader(diffusion_args)
else: self._pipeline = loader.load()
logger.info("CPU offload enabled, pipeline stays on CPU")
self._initialized = True self._initialized = True
logger.info(f"DiffusionEngine initialization complete: {self._class_name}") logger.info(
f"DiffusionEngine initialization complete: "
f"{self._pipeline.__class__.__name__}"
)
def _build_dit_configs(self) -> dict[str, Any]: def _build_diffusion_args(self) -> "DiffusionArgs":
"""Build dit_configs dict from DiffusionConfig. """Build DiffusionArgs from DiffusionConfig.
Maps dynamo's DiffusionConfig fields to TensorRT-LLM's DiffusionArgs
structure with its nested sub-configs (PipelineConfig, AttentionConfig,
ParallelConfig, TeaCacheConfig, quant_config).
Returns: Returns:
Configuration dictionary for visual_gen's setup_configs. DiffusionArgs instance for PipelineLoader.
""" """
# Get torch_compile_models from auto-detected config overrides from tensorrt_llm._torch.visual_gen import (
# Each pipeline in PIPELINE_REGISTRY specifies its required settings DiffusionArgs,
torch_compile_models = self._config_overrides.get( ParallelConfig,
"torch_compile_models", "transformer" PipelineConfig,
TeaCacheConfig,
)
from tensorrt_llm._torch.visual_gen.config import AttentionConfig
# Build quant_config dict if quantization is requested
# DiffusionArgs accepts a dict in ModelOpt format and parses it via model_validator
quant_config: dict | None = None
if self.config.quant_algo:
quant_config = {
"quant_algo": self.config.quant_algo,
"dynamic": self.config.quant_dynamic,
}
args_kwargs: dict = dict(
checkpoint_path=self.config.model_path,
device=self.device,
dtype=self.config.torch_dtype,
skip_components=self.config.skip_components,
pipeline=PipelineConfig(
enable_torch_compile=not self.config.disable_torch_compile,
torch_compile_mode=self.config.torch_compile_mode,
enable_fullgraph=self.config.enable_fullgraph,
fuse_qkv=self.config.fuse_qkv,
enable_cuda_graph=self.config.enable_cuda_graph,
enable_layerwise_nvtx_marker=self.config.enable_layerwise_nvtx_marker,
warmup_steps=self.config.warmup_steps,
enable_offloading=self.config.enable_async_cpu_offload,
),
attention=AttentionConfig(
backend=self.config.attn_backend.upper(),
),
parallel=ParallelConfig(
dit_dp_size=self.config.dit_dp_size,
dit_tp_size=self.config.dit_tp_size,
dit_ulysses_size=self.config.dit_ulysses_size,
dit_ring_size=self.config.dit_ring_size,
dit_cfg_size=self.config.dit_cfg_size,
dit_fsdp_size=self.config.dit_fsdp_size,
),
teacache=TeaCacheConfig(
enable_teacache=self.config.enable_teacache,
use_ret_steps=self.config.teacache_use_ret_steps,
teacache_thresh=self.config.teacache_thresh,
),
) )
return { # Add optional fields
"pipeline": { if self.config.revision:
"enable_torch_compile": not self.config.disable_torch_compile, args_kwargs["revision"] = self.config.revision
"torch_compile_models": torch_compile_models, if quant_config is not None:
"torch_compile_mode": self.config.torch_compile_mode, args_kwargs["quant_config"] = quant_config
"fuse_qkv": True,
}, return DiffusionArgs(**args_kwargs)
"attn": {
"type": self.config.attn_type,
},
"linear": {
"type": self.config.linear_type,
"recipe": "dynamic",
},
"parallel": {
"disable_parallel_vae": False,
"parallel_vae_split_dim": "width",
"dit_dp_size": self.config.dit_dp_size,
"dit_tp_size": self.config.dit_tp_size,
"dit_ulysses_size": self.config.dit_ulysses_size,
"dit_ring_size": self.config.dit_ring_size,
"dit_cp_size": 1,
"dit_cfg_size": self.config.dit_cfg_size,
"dit_fsdp_size": self.config.dit_fsdp_size,
"t5_fsdp_size": 1,
},
"teacache": {
"enable_teacache": self.config.enable_teacache,
"use_ret_steps": self.config.teacache_use_ret_steps,
"teacache_thresh": self.config.teacache_thresh,
"ret_steps": 0,
"cutoff_steps": self.config.default_num_inference_steps,
},
}
def generate( def generate(
self, self,
...@@ -271,12 +210,15 @@ class DiffusionEngine: ...@@ -271,12 +210,15 @@ class DiffusionEngine:
num_inference_steps: int = 50, num_inference_steps: int = 50,
guidance_scale: float = 5.0, guidance_scale: float = 5.0,
seed: Optional[int] = None, seed: Optional[int] = None,
) -> np.ndarray: ) -> "MediaOutput":
"""Generate video/image frames from text prompt. """Generate video/image frames from text prompt.
This is a synchronous method that should be called from a thread pool This is a synchronous method that should be called from a thread pool
to avoid blocking the event loop. to avoid blocking the event loop.
The pipeline's infer() method handles the full generation flow:
prompt encoding, latent preparation, denoising loop, and VAE decoding.
Args: Args:
prompt: Text description of the content to generate. prompt: Text description of the content to generate.
negative_prompt: Text to avoid in the generation. negative_prompt: Text to avoid in the generation.
...@@ -288,8 +230,10 @@ class DiffusionEngine: ...@@ -288,8 +230,10 @@ class DiffusionEngine:
seed: Random seed for reproducibility. seed: Random seed for reproducibility.
Returns: Returns:
numpy array of shape (num_frames, height, width, 3) with uint8 values MediaOutput with model-specific fields populated:
for video, or (height, width, 3) for images. - .video: torch.Tensor (num_frames, H, W, 3) uint8 for video models
- .image: torch.Tensor (H, W, 3) uint8 for image models
- .audio: torch.Tensor for audio (if supported by model)
Raises: Raises:
RuntimeError: If engine not initialized or generation fails. RuntimeError: If engine not initialized or generation fails.
...@@ -302,41 +246,46 @@ class DiffusionEngine: ...@@ -302,41 +246,46 @@ class DiffusionEngine:
f"size={width}x{height}, frames={num_frames}, steps={num_inference_steps}" f"size={width}x{height}, frames={num_frames}, steps={num_inference_steps}"
) )
# Create generator for reproducibility # Use TRT-LLM's DiffusionRequest dataclass so that all defaults
# Device must match pipeline device (CPU if offload enabled, CUDA otherwise) # (including pipeline-specific fields like max_sequence_length,
generator = None # guidance_scale_2, boundary_ratio) are owned by TRT-LLM rather
if seed is not None: # than hardcoded here.
generator = torch.Generator(device=self.device).manual_seed(seed) from tensorrt_llm._torch.visual_gen.executor import DiffusionRequest
# Run the pipeline req = DiffusionRequest(
with torch.no_grad(): request_id=0,
result = self._pipeline( prompt=prompt,
prompt=prompt, negative_prompt=negative_prompt,
negative_prompt=negative_prompt, height=height,
height=height, width=width,
width=width, num_frames=num_frames,
num_frames=num_frames, num_inference_steps=num_inference_steps,
num_inference_steps=num_inference_steps, guidance_scale=guidance_scale,
guidance_scale=guidance_scale, seed=seed if seed is not None else random.randint(0, 2**32 - 1),
generator=generator, )
output_type="np", # Return numpy array
) # Run the pipeline — infer() wraps forward() with torch.no_grad()
output = self._pipeline.infer(req)
# result.frames[0] is numpy array (num_frames, height, width, 3) uint8 if output is not None:
frames = result.frames[0] if output.video is not None:
logger.info(f"Generated output with shape {frames.shape}") logger.info(f"Generated video output with shape {output.video.shape}")
elif output.image is not None:
logger.info(f"Generated image output with shape {output.image.shape}")
return frames return output
def cleanup(self) -> None: def cleanup(self) -> None:
"""Cleanup resources.""" """Cleanup resources."""
if self._pipeline is not None: if self._pipeline is not None:
if hasattr(self._pipeline, "cleanup"):
self._pipeline.cleanup()
del self._pipeline del self._pipeline
self._pipeline = None self._pipeline = None
self._initialized = False self._initialized = False
if self.device == "cuda": if self.device == "cuda":
torch.cuda.empty_cache() torch.cuda.empty_cache()
logger.info(f"DiffusionEngine cleanup complete: {self._class_name}") logger.info("DiffusionEngine cleanup complete")
@property @property
def is_initialized(self) -> bool: def is_initialized(self) -> bool:
...@@ -345,8 +294,26 @@ class DiffusionEngine: ...@@ -345,8 +294,26 @@ class DiffusionEngine:
@property @property
def supported_modalities(self) -> list[str]: def supported_modalities(self) -> list[str]:
"""Get the modalities supported by this engine's model type.""" """Get the modalities supported by this engine's pipeline.
return self._supported_modalities
Uses an explicit mapping from pipeline class names to modalities
(see ``_PIPELINE_MODALITY_MAP``). The pipeline class is determined
at load time by AutoPipeline from model_index.json.
"""
if self._pipeline is None:
return [_DEFAULT_MODALITY.value]
class_name = self._pipeline.__class__.__name__
modality = _PIPELINE_MODALITY_MAP.get(class_name)
if modality is None:
logger.warning(
"Unknown pipeline class '%s' — defaulting to %s. "
"Please add it to _PIPELINE_MODALITY_MAP.",
class_name,
_DEFAULT_MODALITY.value,
)
modality = _DEFAULT_MODALITY
return [modality.value]
@property @property
def device(self) -> str: def device(self) -> str:
......
...@@ -4,6 +4,8 @@ ...@@ -4,6 +4,8 @@
"""Video generation request handler for TensorRT-LLM backend. """Video generation request handler for TensorRT-LLM backend.
This handler processes video generation requests using diffusion models. This handler processes video generation requests using diffusion models.
It handles MediaOutput from TensorRT-LLM's visual_gen pipelines, which
can contain video, image, and/or audio tensors depending on the model.
""" """
import asyncio import asyncio
...@@ -32,9 +34,14 @@ logger = logging.getLogger(__name__) ...@@ -32,9 +34,14 @@ logger = logging.getLogger(__name__)
class VideoGenerationHandler(BaseGenerativeHandler): class VideoGenerationHandler(BaseGenerativeHandler):
"""Handler for video generation requests. """Handler for video generation requests.
This handler receives video generation requests, runs the diffusion This handler receives generation requests, runs the diffusion pipeline
pipeline via DiffusionEngine, encodes the output to MP4, and returns via DiffusionEngine, encodes the output to the appropriate media format,
the video URL or base64-encoded data. and returns the media URL or base64-encoded data.
Supports MediaOutput with:
- video: torch.Tensor (num_frames, H, W, 3) uint8 → encoded as MP4
- image: unsupported — raises RuntimeError (use an image handler instead)
- audio: unsupported — raises RuntimeError (future: mux into MP4)
Inherits from BaseGenerativeHandler to share the common interface with Inherits from BaseGenerativeHandler to share the common interface with
LLM handlers. LLM handlers.
...@@ -59,8 +66,8 @@ class VideoGenerationHandler(BaseGenerativeHandler): ...@@ -59,8 +66,8 @@ class VideoGenerationHandler(BaseGenerativeHandler):
) )
self.media_output_fs = get_fs(config.media_output_fs_url) self.media_output_fs = get_fs(config.media_output_fs_url)
self.media_output_http_url = config.media_output_http_url self.media_output_http_url = config.media_output_http_url
# Serialize pipeline access — visual_gen is not thread-safe (global # Serialize pipeline access — the diffusion pipeline is not thread-safe
# singleton configs, mutable instance state, unprotected CUDA graph cache). # (mutable instance state, unprotected CUDA graph cache).
# asyncio.Lock suspends waiting coroutines cooperatively so the event # asyncio.Lock suspends waiting coroutines cooperatively so the event
# loop stays free for health checks and signal handling. # loop stays free for health checks and signal handling.
self._generate_lock = asyncio.Lock() self._generate_lock = asyncio.Lock()
...@@ -159,21 +166,26 @@ class VideoGenerationHandler(BaseGenerativeHandler): ...@@ -159,21 +166,26 @@ class VideoGenerationHandler(BaseGenerativeHandler):
async def generate( async def generate(
self, request: dict[str, Any], context: Context self, request: dict[str, Any], context: Context
) -> AsyncGenerator[dict[str, Any], None]: ) -> AsyncGenerator[dict[str, Any], None]:
"""Generate video from request. """Generate video/image from request.
This is the main entry point called by Dynamo's endpoint.serve_endpoint(). This is the main entry point called by Dynamo's endpoint.serve_endpoint().
Handles MediaOutput from the pipeline:
- video tensor → MP4
- image tensor → unsupported (raises error)
- audio tensor → unsupported (raises error)
Args: Args:
request: Request dictionary with video generation parameters. request: Request dictionary with generation parameters.
context: Dynamo context for request tracking. context: Dynamo context for request tracking.
Yields: Yields:
Response dictionary with generated video data. Response dictionary with generated media data.
""" """
start_time = time.time() start_time = time.time()
request_id = str(uuid.uuid4()) request_id = str(uuid.uuid4())
logger.info(f"Received video generation request: {request_id}") logger.info(f"Received generation request: {request_id}")
try: try:
# Parse request # Parse request
...@@ -202,11 +214,11 @@ class VideoGenerationHandler(BaseGenerativeHandler): ...@@ -202,11 +214,11 @@ class VideoGenerationHandler(BaseGenerativeHandler):
# Run generation in thread pool (blocking operation). # Run generation in thread pool (blocking operation).
# Lock ensures only one request uses the pipeline at a time. # Lock ensures only one request uses the pipeline at a time.
# TODO: Add cancellation support. This requires: # TODO: Add cancellation support. This requires:
# 1. visual_gen to expose a cancellation hook in the denoising loop # 1. The pipeline to expose a cancellation hook in the denoising loop
# 2. Passing a cancellation token/event to engine.generate() # 2. Passing a cancellation token/event to engine.generate()
# 3. Checking context.cancelled() and propagating to the pipeline # 3. Checking context.cancelled() and propagating to the pipeline
async with self._generate_lock: async with self._generate_lock:
frames = await asyncio.to_thread( output = await asyncio.to_thread(
self.engine.generate, self.engine.generate,
prompt=req.prompt, prompt=req.prompt,
negative_prompt=nvext.negative_prompt, negative_prompt=nvext.negative_prompt,
...@@ -218,15 +230,47 @@ class VideoGenerationHandler(BaseGenerativeHandler): ...@@ -218,15 +230,47 @@ class VideoGenerationHandler(BaseGenerativeHandler):
seed=nvext.seed, seed=nvext.seed,
) )
if output is None:
raise RuntimeError("Pipeline returned no output (MediaOutput is None)")
# Determine output format # Determine output format
response_format = req.response_format or "url" response_format = req.response_format or "url"
fps = nvext.fps or self.config.default_fps fps = nvext.fps or self.config.default_fps
# Encode frames to MP4 bytes in memory # Encode media based on what the pipeline returned
video_bytes = await asyncio.to_thread(encode_to_mp4_bytes, frames, fps=fps) if output.video is not None:
# Video output: torch.Tensor (num_frames, H, W, 3) uint8 → MP4
frames_np = output.video.cpu().numpy()
logger.info(
f"Request {request_id}: encoding video output "
f"(shape={frames_np.shape}) to MP4 at {fps} fps"
)
video_bytes = await asyncio.to_thread(
encode_to_mp4_bytes, frames_np, fps=fps
)
elif output.image is not None:
raise RuntimeError(
"Pipeline returned image-only output, but this handler "
"only supports video. Use an image generation handler instead."
)
# Audio-only output is not supported by this handler — raise rather than encode
elif output.audio is not None:
raise RuntimeError(
"Pipeline returned audio-only output, but this handler "
"only supports video. Use an audio generation handler instead."
)
else:
raise RuntimeError(
"Pipeline returned MediaOutput with no video or image or audio data. "
f"MediaOutput fields: video={output.video is not None}, "
f"image={output.image is not None}, audio={output.audio is not None}"
)
# Return media via URL or base64
if response_format == "url": if response_format == "url":
# Upload via filesystem
storage_path = f"videos/{request_id}.mp4" storage_path = f"videos/{request_id}.mp4"
video_url = await upload_to_fs( video_url = await upload_to_fs(
self.media_output_fs, self.media_output_fs,
...@@ -236,7 +280,6 @@ class VideoGenerationHandler(BaseGenerativeHandler): ...@@ -236,7 +280,6 @@ class VideoGenerationHandler(BaseGenerativeHandler):
) )
video_data = VideoData(url=video_url) video_data = VideoData(url=video_url)
else: else:
# Encode to base64
b64_video = base64.b64encode(video_bytes).decode("utf-8") b64_video = base64.b64encode(video_bytes).decode("utf-8")
video_data = VideoData(b64_json=b64_video) video_data = VideoData(b64_json=b64_video)
......
...@@ -13,10 +13,12 @@ import asyncio ...@@ -13,10 +13,12 @@ import asyncio
import threading import threading
import time import time
from dataclasses import dataclass from dataclasses import dataclass
from types import SimpleNamespace
from typing import Optional from typing import Optional
from unittest.mock import MagicMock, patch from unittest.mock import MagicMock, patch
import pytest import pytest
import torch
from dynamo.common.protocols.video_protocol import ( from dynamo.common.protocols.video_protocol import (
NvCreateVideoRequest, NvCreateVideoRequest,
...@@ -104,8 +106,11 @@ class TestDiffusionConfig: ...@@ -104,8 +106,11 @@ class TestDiffusionConfig:
# Optimization defaults # Optimization defaults
assert config.enable_teacache is False assert config.enable_teacache is False
assert config.attn_type == "default" assert config.attn_backend == "VANILLA"
assert config.linear_type == "default" assert config.quant_algo is None
assert config.enable_cuda_graph is False
assert config.warmup_steps == 1
assert config.fuse_qkv is True
# Parallelism defaults # Parallelism defaults
assert config.dit_dp_size == 1 assert config.dit_dp_size == 1
...@@ -532,10 +537,12 @@ class ConcurrencyTracker: ...@@ -532,10 +537,12 @@ class ConcurrencyTracker:
with self._lock: with self._lock:
self._active_count -= 1 self._active_count -= 1
# Return fake frames (shape: [num_frames, H, W, C]) # Return a mock MediaOutput with a video tensor
import numpy as np return SimpleNamespace(
video=torch.zeros((4, 64, 64, 3), dtype=torch.uint8),
return np.zeros((4, 64, 64, 3), dtype=np.uint8) image=None,
audio=None,
)
class TestVideoHandlerConcurrency: class TestVideoHandlerConcurrency:
...@@ -660,16 +667,17 @@ class TestVideoHandlerResponseFormats: ...@@ -660,16 +667,17 @@ class TestVideoHandlerResponseFormats:
def _make_handler(self): def _make_handler(self):
"""Create a handler with mocked engine and fs.""" """Create a handler with mocked engine and fs."""
import numpy as np
from dynamo.trtllm.request_handlers.video_diffusion.video_handler import ( from dynamo.trtllm.request_handlers.video_diffusion.video_handler import (
VideoGenerationHandler, VideoGenerationHandler,
) )
mock_engine = MagicMock() mock_output = SimpleNamespace(
mock_engine.generate = MagicMock( video=torch.zeros((4, 64, 64, 3), dtype=torch.uint8),
return_value=np.zeros((4, 64, 64, 3), dtype=np.uint8) image=None,
audio=None,
) )
mock_engine = MagicMock()
mock_engine.generate = MagicMock(return_value=mock_output)
config = DiffusionConfig( config = DiffusionConfig(
media_output_fs_url="file:///tmp/test_media", media_output_fs_url="file:///tmp/test_media",
......
...@@ -33,20 +33,19 @@ async def init_video_diffusion_worker( ...@@ -33,20 +33,19 @@ async def init_video_diffusion_worker(
shutdown_event: Event to signal shutdown. shutdown_event: Event to signal shutdown.
shutdown_endpoints: Optional list to populate with endpoints for graceful shutdown. shutdown_endpoints: Optional list to populate with endpoints for graceful shutdown.
""" """
# Check visual_gen availability early with a clear error message. # Check tensorrt_llm visual_gen availability early with a clear error message.
# visual_gen is part of TensorRT-LLM but only available on the feat/visual_gen # visual_gen is part of TensorRT-LLM (tensorrt_llm._torch.visual_gen).
# branch — not yet in any release. Without this check, users would get a cryptic # Without this check, users would get a cryptic ImportError deep inside
# ImportError deep inside DiffusionEngine.initialize(). # DiffusionEngine.initialize().
try: try:
import visual_gen # noqa: F401 import tensorrt_llm._torch.visual_gen # noqa: F401
except ImportError: except ImportError:
raise ImportError( raise ImportError(
"Video diffusion requires the 'visual_gen' package from TensorRT-LLM's " "Video diffusion requires TensorRT-LLM with visual_gen support.\n"
"feat/visual_gen branch. Install with:\n" "The visual_gen module is at tensorrt_llm._torch.visual_gen.\n"
" git clone https://github.com/NVIDIA/TensorRT-LLM.git\n" "Install TensorRT-LLM with AIGV support:\n"
" cd TensorRT-LLM && git checkout feat/visual_gen\n" " pip install tensorrt_llm\n"
" cd tensorrt_llm/visual_gen && pip install -e .\n" "See: https://github.com/NVIDIA/TensorRT-LLM"
"See: https://github.com/NVIDIA/TensorRT-LLM/tree/feat/visual_gen/tensorrt_llm/visual_gen"
) from None ) from None
from dynamo.trtllm.configs.diffusion_config import DiffusionConfig from dynamo.trtllm.configs.diffusion_config import DiffusionConfig
...@@ -55,6 +54,13 @@ async def init_video_diffusion_worker( ...@@ -55,6 +54,13 @@ async def init_video_diffusion_worker(
logging.info(f"Initializing video diffusion worker with config: {config}") logging.info(f"Initializing video diffusion worker with config: {config}")
# Parse skip_components from comma-separated string to list
skip_components = (
[c.strip() for c in config.skip_components.split(",") if c.strip()]
if config.skip_components
else []
)
# Build DiffusionConfig from the main Config # Build DiffusionConfig from the main Config
diffusion_config = DiffusionConfig( diffusion_config = DiffusionConfig(
namespace=config.namespace, namespace=config.namespace,
...@@ -65,6 +71,8 @@ async def init_video_diffusion_worker( ...@@ -65,6 +71,8 @@ async def init_video_diffusion_worker(
event_plane=config.event_plane, event_plane=config.event_plane,
model_path=config.model, model_path=config.model,
served_model_name=config.served_model_name, served_model_name=config.served_model_name,
torch_dtype=config.torch_dtype,
revision=config.revision,
media_output_fs_url=config.media_output_fs_url, media_output_fs_url=config.media_output_fs_url,
media_output_http_url=config.media_output_http_url, media_output_http_url=config.media_output_http_url,
default_height=config.default_height, default_height=config.default_height,
...@@ -72,19 +80,34 @@ async def init_video_diffusion_worker( ...@@ -72,19 +80,34 @@ async def init_video_diffusion_worker(
default_num_frames=config.default_num_frames, default_num_frames=config.default_num_frames,
default_num_inference_steps=config.default_num_inference_steps, default_num_inference_steps=config.default_num_inference_steps,
default_guidance_scale=config.default_guidance_scale, default_guidance_scale=config.default_guidance_scale,
enable_teacache=config.enable_teacache, # Pipeline optimization
teacache_thresh=config.teacache_thresh,
attn_type=config.attn_type,
linear_type=config.linear_type,
disable_torch_compile=config.disable_torch_compile, disable_torch_compile=config.disable_torch_compile,
torch_compile_mode=config.torch_compile_mode, torch_compile_mode=config.torch_compile_mode,
enable_fullgraph=config.enable_fullgraph,
fuse_qkv=config.fuse_qkv,
enable_cuda_graph=config.enable_cuda_graph,
enable_layerwise_nvtx_marker=config.enable_layerwise_nvtx_marker,
warmup_steps=config.warmup_steps,
# Attention
attn_backend=config.attn_backend,
# Quantization
quant_algo=config.quant_algo,
quant_dynamic=config.quant_dynamic,
# TeaCache
enable_teacache=config.enable_teacache,
teacache_use_ret_steps=config.teacache_use_ret_steps,
teacache_thresh=config.teacache_thresh,
# Parallelism
dit_dp_size=config.dit_dp_size, dit_dp_size=config.dit_dp_size,
dit_tp_size=config.dit_tp_size, dit_tp_size=config.dit_tp_size,
dit_ulysses_size=config.dit_ulysses_size, dit_ulysses_size=config.dit_ulysses_size,
dit_ring_size=config.dit_ring_size, dit_ring_size=config.dit_ring_size,
dit_cfg_size=config.dit_cfg_size, dit_cfg_size=config.dit_cfg_size,
dit_fsdp_size=config.dit_fsdp_size, dit_fsdp_size=config.dit_fsdp_size,
# Offloading
enable_async_cpu_offload=config.enable_async_cpu_offload, enable_async_cpu_offload=config.enable_async_cpu_offload,
# Component loading
skip_components=skip_components,
) )
# Get the endpoint from the runtime # Get the endpoint from the runtime
......
...@@ -216,11 +216,10 @@ Dynamo supports video generation using diffusion models through the `--modality ...@@ -216,11 +216,10 @@ Dynamo supports video generation using diffusion models through the `--modality
### Requirements ### Requirements
- **visual_gen**: Part of TensorRT-LLM, located at `tensorrt_llm/visual_gen/`. Currently available **only** on the [`feat/visual_gen`](https://github.com/NVIDIA/TensorRT-LLM/tree/feat/visual_gen/tensorrt_llm/visual_gen) branch (not yet merged to main or any release). Install from source: - **TensorRT-LLM with visual_gen**: The `visual_gen` module is part of TensorRT-LLM (`tensorrt_llm._torch.visual_gen`). Install TensorRT-LLM following the [official instructions](https://github.com/NVIDIA/TensorRT-LLM#installation).
- **imageio with ffmpeg**: Required for encoding generated frames to MP4 video:
```bash ```bash
git clone https://github.com/NVIDIA/TensorRT-LLM.git pip install imageio[ffmpeg]
cd TensorRT-LLM && git checkout feat/visual_gen
cd tensorrt_llm/visual_gen && pip install -e .
``` ```
- **dynamo-runtime with video API**: The Dynamo runtime must include `ModelType.Videos` support. Ensure you're using a compatible version. - **dynamo-runtime with video API**: The Dynamo runtime must include `ModelType.Videos` support. Ensure you're using a compatible version.
......
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment