# LightX2V Usage Examples
This document describes how to use LightX2V for video generation, covering basic usage and advanced configuration.
## 📋 Table of Contents
- [Environment Setup](#environment-setup)
- [Basic Usage Examples](#basic-usage-examples)
- [Model Path Configuration](#model-path-configuration)
- [Creating Generator](#creating-generator)
- [Advanced Configurations](#advanced-configurations)
- [Parameter Offloading](#parameter-offloading)
- [Model Quantization](#model-quantization)
- [Parallel Inference](#parallel-inference)
- [Feature Caching](#feature-caching)
- [Lightweight VAE](#lightweight-vae)
## 🔧 Environment Setup
Please refer to the main project's [Quick Start Guide](../docs/EN/source/getting_started/quickstart.md) for environment setup.
## 🚀 Basic Usage Examples
A minimal code example can be found in `examples/wan_t2v.py`:
```python
from lightx2v import LightX2VPipeline

pipe = LightX2VPipeline(
    model_path="/path/to/Wan2.1-T2V-14B",
    model_cls="wan2.1",
    task="t2v",
)

pipe.create_generator(
    attn_mode="sage_attn2",
    infer_steps=50,
    height=480,
    width=832,
    num_frames=81,
    guidance_scale=5.0,
    sample_shift=5.0,
)

seed = 42
prompt = "Your prompt here"
negative_prompt = ""
save_result_path = "/path/to/save_results/output.mp4"

pipe.generate(
    seed=seed,
    prompt=prompt,
    negative_prompt=negative_prompt,
    save_result_path=save_result_path,
)
```
## 📁 Model Path Configuration
### Basic Configuration
Pass the model path to `LightX2VPipeline`:
```python
pipe = LightX2VPipeline(
    model_path="/path/to/Wan2.2-I2V-A14B",
    model_cls="wan2.2_moe",  # For wan2.1, use "wan2.1"
    task="i2v",
)
```
### Specifying Multiple Model Weight Versions
When the `model_path` directory contains multiple versions of the bf16-precision DIT model safetensors files, use the following parameters to specify which weights to load:
- **`dit_original_ckpt`**: Used to specify the original DIT weight path for models like wan2.1 and hunyuan15
- **`low_noise_original_ckpt`**: Used to specify the low noise branch weight path for wan2.2 models
- **`high_noise_original_ckpt`**: Used to specify the high noise branch weight path for wan2.2 models
**Usage Example:**
```python
pipe = LightX2VPipeline(
    model_path="/path/to/Wan2.2-I2V-A14B",
    model_cls="wan2.2_moe",
    task="i2v",
    low_noise_original_ckpt="/path/to/low_noise_model.safetensors",
    high_noise_original_ckpt="/path/to/high_noise_model.safetensors",
)
```
## 🎛️ Creating Generator
### Loading from Configuration File
The generator can be loaded directly from a JSON configuration file. Configuration files are located in the `configs` directory:
```python
pipe.create_generator(config_json="../configs/wan/wan_t2v.json")
```
### Creating Generator Manually
You can also create the generator manually and configure multiple parameters:
```python
pipe.create_generator(
    attn_mode="flash_attn2",  # Options: flash_attn2, flash_attn3, sage_attn2, sage_attn3 (Blackwell GPUs)
    infer_steps=50,           # Number of inference steps
    num_frames=81,            # Number of video frames
    height=480,               # Video height
    width=832,                # Video width
    guidance_scale=5.0,       # CFG guidance strength (CFG is disabled when set to 1)
    sample_shift=5.0,         # Sample shift
    fps=16,                   # Frame rate
    aspect_ratio="16:9",      # Aspect ratio
    boundary=0.900,           # Boundary value
    boundary_step_index=2,    # Boundary step index
    denoising_step_list=[1000, 750, 500, 250],  # Denoising step list
)
```
**Parameter Description:**
- **Resolution**: Specified via `height` and `width`
- **CFG**: Specified via `guidance_scale` (set to 1 to disable CFG)
- **FPS**: Specified via `fps`
- **Video Length**: Specified via `num_frames`
- **Inference Steps**: Specified via `infer_steps`
- **Sample Shift**: Specified via `sample_shift`
- **Attention Mode**: Specified via `attn_mode`; options include `flash_attn2`, `flash_attn3`, `sage_attn2`, and `sage_attn3` (Blackwell-architecture GPUs only)
## ⚙️ Advanced Configurations
**⚠️ Important: When creating the generator manually, you can also configure several advanced options. All advanced configurations must be applied before `create_generator()` is called, otherwise they will not take effect!**
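For example, the call order looks like this (a minimal sketch; the concrete `enable_*` arguments are described in the subsections below):

```python
pipe = LightX2VPipeline(
    model_path="/path/to/Wan2.1-T2V-14B",
    model_cls="wan2.1",
    task="t2v",
)

# 1. Apply advanced configurations first ...
pipe.enable_offload(
    cpu_offload=True,
    offload_granularity="block",
    text_encoder_offload=False,
    image_encoder_offload=False,
    vae_offload=False,
)

# 2. ... then create the generator.
pipe.create_generator(
    attn_mode="sage_attn2",
    infer_steps=50,
    height=480,
    width=832,
    num_frames=81,
    guidance_scale=5.0,
    sample_shift=5.0,
)
```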
### Parameter Offloading
Significantly reduces memory usage with almost no impact on inference speed. Suitable for RTX 30/40/50 series GPUs.
```python
pipe.enable_offload(
    cpu_offload=True,             # Enable CPU offloading
    offload_granularity="block",  # Offload granularity: "block" or "phase"
    text_encoder_offload=False,   # Whether to offload text encoder
    image_encoder_offload=False,  # Whether to offload image encoder
    vae_offload=False,            # Whether to offload VAE
)
```
**Notes:**
- For Wan models, `offload_granularity` supports both `"block"` and `"phase"`
- For HunyuanVideo-1.5, only `"block"` is currently supported
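For example, a more aggressive offloading setup for a Wan model could look like the following sketch (all parameters are the ones listed above; `"phase"` granularity applies to Wan models only):

```python
# Offload at phase granularity and also offload the text encoder and VAE
pipe.enable_offload(
    cpu_offload=True,
    offload_granularity="phase",  # "phase" is supported for Wan models
    text_encoder_offload=True,
    image_encoder_offload=False,
    vae_offload=True,
)
```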
### Model Quantization
Quantization can significantly reduce memory usage and accelerate inference.
```python
pipe.enable_quantize(
    dit_quantized=False,               # Whether to use quantized DIT model
    text_encoder_quantized=False,      # Whether to use quantized text encoder
    image_encoder_quantized=False,     # Whether to use quantized image encoder
    dit_quantized_ckpt=None,           # DIT quantized weight path (required when model_path doesn't contain quantized weights or has multiple weight files)
    low_noise_quantized_ckpt=None,     # Wan2.2 low noise branch quantized weight path
    high_noise_quantized_ckpt=None,    # Wan2.2 high noise branch quantized weight path
    text_encoder_quantized_ckpt=None,  # Text encoder quantized weight path (required when model_path doesn't contain quantized weights or has multiple weight files)
    image_encoder_quantized_ckpt=None, # Image encoder quantized weight path (required when model_path doesn't contain quantized weights or has multiple weight files)
    quant_scheme="fp8-sgl",            # Quantization scheme
)
```
**Parameter Description:**
- **`dit_quantized_ckpt`**: When the `model_path` directory doesn't contain quantized weights, or has multiple weight files, you need to specify the specific DIT quantized weight path
- **`text_encoder_quantized_ckpt`** and **`image_encoder_quantized_ckpt`**: Similarly, used to specify encoder quantized weight paths
- **`low_noise_quantized_ckpt`** and **`high_noise_quantized_ckpt`**: Used to specify dual-branch quantized weights for Wan2.2 models
**Quantized Model Downloads:**
- **Wan-2.1 Quantized Models**: Download from [Wan2.1-Distill-Models](https://huggingface.co/lightx2v/Wan2.1-Distill-Models)
- **Wan-2.2 Quantized Models**: Download from [Wan2.2-Distill-Models](https://huggingface.co/lightx2v/Wan2.2-Distill-Models)
- **HunyuanVideo-1.5 Quantized Models**: Download from [Hy1.5-Quantized-Models](https://huggingface.co/lightx2v/Hy1.5-Quantized-Models)
- `hy15_qwen25vl_llm_encoder_fp8_e4m3_lightx2v.safetensors` is the quantized weight for the text encoder
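The files listed above can also be fetched programmatically. Below is a minimal sketch using `huggingface_hub` (assumed to be installed; check each model page for the exact file names and any subfolders):

```python
from huggingface_hub import hf_hub_download

# Download the HunyuanVideo-1.5 quantized text encoder weight (file name taken from the list above)
text_encoder_ckpt = hf_hub_download(
    repo_id="lightx2v/Hy1.5-Quantized-Models",
    filename="hy15_qwen25vl_llm_encoder_fp8_e4m3_lightx2v.safetensors",
)
print(text_encoder_ckpt)  # Local path to pass as text_encoder_quantized_ckpt
```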
**Usage Examples:**
```python
# HunyuanVideo-1.5 Quantization Example
pipe.enable_quantize(
    quant_scheme='fp8-sgl',
    dit_quantized=True,
    dit_quantized_ckpt="/path/to/hy15_720p_i2v_fp8_e4m3_lightx2v.safetensors",
    text_encoder_quantized=True,
    image_encoder_quantized=False,
    text_encoder_quantized_ckpt="/path/to/hy15_qwen25vl_llm_encoder_fp8_e4m3_lightx2v.safetensors",
)

# Wan2.1 Quantization Example
pipe.enable_quantize(
    dit_quantized=True,
    dit_quantized_ckpt="/path/to/wan2.1_i2v_480p_scaled_fp8_e4m3_lightx2v_4step.safetensors",
)

# Wan2.2 Quantization Example
pipe.enable_quantize(
    dit_quantized=True,
    low_noise_quantized_ckpt="/path/to/wan2.2_i2v_A14b_low_noise_scaled_fp8_e4m3_lightx2v_4step.safetensors",
    high_noise_quantized_ckpt="/path/to/wan2.2_i2v_A14b_high_noise_scaled_fp8_e4m3_lightx2v_4step_1030.safetensors",
)
```
**Quantization Scheme Reference:** For detailed information, please refer to the [Quantization Documentation](../docs/EN/source/method_tutorials/quantization.md)
### Parallel Inference
Supports multi-GPU parallel inference. Requires running with `torchrun`:
```python
pipe.enable_parallel(
    seq_p_size=4,               # Sequence parallel size
    seq_p_attn_type="ulysses",  # Sequence parallel attention type
)
```
**Running Method:**
```bash
torchrun --nproc_per_node=4 your_script.py
```
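Putting it together, a complete multi-GPU run might look like the sketch below (save it as `your_script.py` and launch it with the `torchrun` command above; `seq_p_size` should match `--nproc_per_node`, and paths are placeholders):

```python
from lightx2v import LightX2VPipeline

pipe = LightX2VPipeline(
    model_path="/path/to/Wan2.1-T2V-14B",
    model_cls="wan2.1",
    task="t2v",
)

# Parallelism must be configured before create_generator()
pipe.enable_parallel(
    seq_p_size=4,               # Matches --nproc_per_node
    seq_p_attn_type="ulysses",
)

pipe.create_generator(
    attn_mode="sage_attn2",
    infer_steps=50,
    height=480,
    width=832,
    num_frames=81,
    guidance_scale=5.0,
    sample_shift=5.0,
)

pipe.generate(
    seed=42,
    prompt="Your prompt here",
    negative_prompt="",
    save_result_path="/path/to/save_results/output.mp4",
)
```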
### Feature Caching
Set the cache method to `'Mag'` or `'Tea'` to use the MagCache or TeaCache acceleration method, respectively:
```python
pipe.enable_cache(
    cache_method='Tea',  # Cache method: 'Tea' or 'Mag'
    coefficients=[-3.08907507e+04, 1.67786188e+04, -3.19178643e+03,
                  2.60740519e+02, -8.19205881e+00, 1.07913775e-01],  # Coefficients
    teacache_thresh=0.15,  # TeaCache threshold
)
```
**Coefficient Reference:** Refer to configuration files in `configs/caching` or `configs/hunyuan_video_15/cache` directories
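For instance, rather than hard-coding the coefficients, they can be read from one of those configuration files. The sketch below uses a hypothetical file name and assumes the JSON exposes `coefficients` and `teacache_thresh` keys; check the actual files under `configs/caching` for the real names and keys:

```python
import json

# Hypothetical config file name; see configs/caching for the actual files
with open("../configs/caching/wan_t2v_teacache.json") as f:
    cache_cfg = json.load(f)

pipe.enable_cache(
    cache_method="Tea",
    coefficients=cache_cfg["coefficients"],        # assumed key name
    teacache_thresh=cache_cfg["teacache_thresh"],  # assumed key name
)
```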
### Lightweight VAE
Using lightweight VAE can accelerate decoding and reduce memory usage.
```python
pipe.enable_lightvae(
    use_lightvae=False,  # Whether to use LightVAE
    use_tae=False,       # Whether to use LightTAE
    vae_path=None,       # Path to LightVAE
    tae_path=None,       # Path to LightTAE
)
```
**Support Status:**
- **LightVAE**: Currently only supports wan2.1, wan2.2 moe
- **LightTAE**: Currently only supports wan2.1, wan2.2-ti2v, wan2.2 moe, HunyuanVideo-1.5
**Model Downloads:** Lightweight VAE models can be downloaded from [Autoencoders](https://huggingface.co/lightx2v/Autoencoders)
- LightVAE for Wan-2.1: [lightvaew2_1.safetensors](https://huggingface.co/lightx2v/Autoencoders/blob/main/lightvaew2_1.safetensors)
- LightTAE for Wan-2.1: [lighttaew2_1.safetensors](https://huggingface.co/lightx2v/Autoencoders/blob/main/lighttaew2_1.safetensors)
- LightTAE for Wan-2.2-ti2v: [lighttaew2_2.safetensors](https://huggingface.co/lightx2v/Autoencoders/blob/main/lighttaew2_2.safetensors)
- LightTAE for HunyuanVideo-1.5: [lighttaehy1_5.safetensors](https://huggingface.co/lightx2v/Autoencoders/blob/main/lighttaehy1_5.safetensors)
**Usage Example:**
```python
# Using LightTAE for HunyuanVideo-1.5
pipe.enable_lightvae(
    use_tae=True,
    tae_path="/path/to/lighttaehy1_5.safetensors",
    use_lightvae=False,
    vae_path=None,
)
```
## 📚 More Resources
- [Full Documentation](https://lightx2v-en.readthedocs.io/en/latest/)
- [GitHub Repository](https://github.com/ModelTC/LightX2V)
- [HuggingFace Model Hub](https://huggingface.co/lightx2v)
# LightX2V 使用示例
本文档介绍如何使用 LightX2V 进行视频生成,包括基础使用和进阶配置。
## 📋 目录
- [环境安装](#环境安装)
- [基础运行示例](#基础运行示例)
- [模型路径配置](#模型路径配置)
- [创建生成器](#创建生成器)
- [进阶配置](#进阶配置)
- [参数卸载 (Offload)](#参数卸载-offload)
- [模型量化 (Quantization)](#模型量化-quantization)
- [并行推理 (Parallel Inference)](#并行推理-parallel-inference)
- [特征缓存 (Cache)](#特征缓存-cache)
- [轻量 VAE (Light VAE)](#轻量-vae-light-vae)
## 🔧 环境安装
请参考主项目的[快速入门文档](../docs/ZH_CN/source/getting_started/quickstart.md)进行环境安装。
## 🚀 基础运行示例
最小化代码示例可参考 `examples/wan_t2v.py`
```python
from lightx2v import LightX2VPipeline
pipe = LightX2VPipeline(
model_path="/path/to/Wan2.1-T2V-14B",
model_cls="wan2.1",
task="t2v",
)
pipe.create_generator(
attn_mode="sage_attn2",
infer_steps=50,
height=480,
width=832,
num_frames=81,
guidance_scale=5.0,
sample_shift=5.0,
)
seed = 42
prompt = "Your prompt here"
negative_prompt = ""
save_result_path="/path/to/save_results/output.mp4"
pipe.generate(
seed=seed,
prompt=prompt,
negative_prompt=negative_prompt,
save_result_path=save_result_path,
)
```
## 📁 模型路径配置
### 基础配置
将模型路径传入 `LightX2VPipeline`
```python
pipe = LightX2VPipeline(
model_path="/path/to/Wan2.2-I2V-A14B",
model_cls="wan2.2_moe", # 对于 wan2.1,使用 "wan2.1"
task="i2v",
)
```
### 多版本模型权重指定
`model_path` 目录下存在多个不同版本的 bf16 精度 DIT 模型 safetensors 文件时,需要使用以下参数指定具体使用哪个权重:
- **`dit_original_ckpt`**: 用于指定 wan2.1 和 hunyuan15 等模型的原始 DIT 权重路径
- **`low_noise_original_ckpt`**: 用于指定 wan2.2 模型的低噪声分支权重路径
- **`high_noise_original_ckpt`**: 用于指定 wan2.2 模型的高噪声分支权重路径
**使用示例:**
```python
pipe = LightX2VPipeline(
model_path="/path/to/Wan2.2-I2V-A14B",
model_cls="wan2.2_moe",
task="i2v",
low_noise_original_ckpt="/path/to/low_noise_model.safetensors",
high_noise_original_ckpt="/path/to/high_noise_model.safetensors",
)
```
## 🎛️ 创建生成器
### 从配置文件加载
生成器可以从 JSON 配置文件直接加载,配置文件位于 `configs` 目录:
```python
pipe.create_generator(config_json="../configs/wan/wan_t2v.json")
```
### 手动创建生成器
也可以手动创建生成器,并配置多个参数:
```python
pipe.create_generator(
attn_mode="flash_attn2", # 可选: flash_attn2, flash_attn3, sage_attn2, sage_attn3 (B架构显卡适用)
infer_steps=50, # 推理步数
num_frames=81, # 视频帧数
height=480, # 视频高度
width=832, # 视频宽度
guidance_scale=5.0, # CFG引导强度 (=1时禁用CFG)
sample_shift=5.0, # 采样偏移
fps=16, # 帧率
aspect_ratio="16:9", # 宽高比
boundary=0.900, # 边界值
boundary_step_index=2, # 边界步索引
denoising_step_list=[1000, 750, 500, 250], # 去噪步列表
)
```
**参数说明:**
- **分辨率**: 通过 `height` 和 `width` 指定
- **CFG**: 通过 `guidance_scale` 指定(设置为 1 时禁用 CFG)
- **FPS**: 通过 `fps` 指定帧率
- **视频长度**: 通过 `num_frames` 指定帧数
- **推理步数**: 通过 `infer_steps` 指定
- **采样偏移**: 通过 `sample_shift` 指定
- **注意力模式**: 通过 `attn_mode` 指定,可选 `flash_attn2`, `flash_attn3`, `sage_attn2`, `sage_attn3`(B架构显卡适用)
## ⚙️ 进阶配置
**⚠️ 重要提示:手动创建生成器时,可以配置一些进阶选项,所有进阶配置必须在 `create_generator()` 之前指定,否则会失效!**
### 参数卸载 (Offload)
显著降低显存占用,几乎不影响推理速度,适用于 RTX 30/40/50 系列显卡。
```python
pipe.enable_offload(
cpu_offload=True, # 启用 CPU 卸载
offload_granularity="block", # 卸载粒度: "block" 或 "phase"
text_encoder_offload=False, # 文本编码器是否卸载
image_encoder_offload=False, # 图像编码器是否卸载
vae_offload=False, # VAE 是否卸载
)
```
**说明:**
- 对于 Wan 模型,`offload_granularity` 支持 `"block"` 和 `"phase"`
- 对于 HunyuanVideo-1.5,目前只支持 `"block"`
### 模型量化 (Quantization)
量化可以显著降低显存占用并加速推理。
```python
pipe.enable_quantize(
dit_quantized=False, # 是否使用量化的 DIT 模型
text_encoder_quantized=False, # 是否使用量化的文本编码器
image_encoder_quantized=False, # 是否使用量化的图像编码器
dit_quantized_ckpt=None, # DIT 量化权重路径(当 model_path 下没有量化权重或存在多个权重时需要指定)
low_noise_quantized_ckpt=None, # Wan2.2 低噪声分支量化权重路径
high_noise_quantized_ckpt=None, # Wan2.2 高噪声分支量化权重路径
text_encoder_quantized_ckpt=None, # 文本编码器量化权重路径(当 model_path 下没有量化权重或存在多个权重时需要指定)
image_encoder_quantized_ckpt=None, # 图像编码器量化权重路径(当 model_path 下没有量化权重或存在多个权重时需要指定)
quant_scheme="fp8-sgl", # 量化方案
)
```
**参数说明:**
- **`dit_quantized_ckpt`**: 当 `model_path` 目录下没有量化权重,或存在多个权重文件时,需要指定具体的 DIT 量化权重路径
- **`text_encoder_quantized_ckpt`** 和 **`image_encoder_quantized_ckpt`**: 类似地,用于指定编码器的量化权重路径
- **`low_noise_quantized_ckpt`** 和 **`high_noise_quantized_ckpt`**: 用于指定 Wan2.2 模型的双分支量化权重
**量化模型下载:**
- **Wan-2.1 量化模型**: 从 [Wan2.1-Distill-Models](https://huggingface.co/lightx2v/Wan2.1-Distill-Models) 下载
- **Wan-2.2 量化模型**: 从 [Wan2.2-Distill-Models](https://huggingface.co/lightx2v/Wan2.2-Distill-Models) 下载
- **HunyuanVideo-1.5 量化模型**: 从 [Hy1.5-Quantized-Models](https://huggingface.co/lightx2v/Hy1.5-Quantized-Models) 下载
- `hy15_qwen25vl_llm_encoder_fp8_e4m3_lightx2v.safetensors` 是文本编码器的量化权重
**使用示例:**
```python
# HunyuanVideo-1.5 量化示例
pipe.enable_quantize(
quant_scheme='fp8-sgl',
dit_quantized=True,
dit_quantized_ckpt="/path/to/hy15_720p_i2v_fp8_e4m3_lightx2v.safetensors",
text_encoder_quantized=True,
image_encoder_quantized=False,
text_encoder_quantized_ckpt="/path/to/hy15_qwen25vl_llm_encoder_fp8_e4m3_lightx2v.safetensors",
)
# Wan2.1 量化示例
pipe.enable_quantize(
dit_quantized=True,
dit_quantized_ckpt="/path/to/wan2.1_i2v_480p_scaled_fp8_e4m3_lightx2v_4step.safetensors",
)
# Wan2.2 量化示例
pipe.enable_quantize(
dit_quantized=True,
low_noise_quantized_ckpt="/path/to/wan2.2_i2v_A14b_low_noise_scaled_fp8_e4m3_lightx2v_4step.safetensors",
high_noise_quantized_ckpt="/path/to/wan2.2_i2v_A14b_high_noise_scaled_fp8_e4m3_lightx2v_4step_1030.safetensors",
)
```
**量化方案参考:** 详细说明请参考 [量化文档](../docs/ZH_CN/source/method_tutorials/quantization.md)
### 并行推理 (Parallel Inference)
支持多 GPU 并行推理,需要使用 `torchrun` 运行:
```python
pipe.enable_parallel(
seq_p_size=4, # 序列并行大小
seq_p_attn_type="ulysses", # 序列并行注意力类型
)
```
**运行方式:**
```bash
torchrun --nproc_per_node=4 your_script.py
```
### 特征缓存 (Cache)
可以指定缓存方法为 Mag 或 Tea,使用 MagCache 和 TeaCache 方法:
```python
pipe.enable_cache(
cache_method='Tea', # 缓存方法: 'Tea' 或 'Mag'
coefficients=[-3.08907507e+04, 1.67786188e+04, -3.19178643e+03,
2.60740519e+02, -8.19205881e+00, 1.07913775e-01], # 系数
teacache_thresh=0.15, # TeaCache 阈值
)
```
**系数参考:** 可参考 `configs/caching` 和 `configs/hunyuan_video_15/cache` 目录下的配置文件
### 轻量 VAE (Light VAE)
使用轻量 VAE 可以加速解码并降低显存占用。
```python
pipe.enable_lightvae(
use_lightvae=False, # 是否使用 LightVAE
use_tae=False, # 是否使用 LightTAE
vae_path=None, # LightVAE 的路径
tae_path=None, # LightTAE 的路径
)
```
**支持情况:**
- **LightVAE**: 目前只支持 wan2.1、wan2.2 moe
- **LightTAE**: 目前只支持 wan2.1、wan2.2-ti2v、wan2.2 moe、HunyuanVideo-1.5
**模型下载:** 轻量 VAE 模型可从 [Autoencoders](https://huggingface.co/lightx2v/Autoencoders) 下载
- Wan-2.1 的 LightVAE: [lightvaew2_1.safetensors](https://huggingface.co/lightx2v/Autoencoders/blob/main/lightvaew2_1.safetensors)
- Wan-2.1 的 LightTAE: [lighttaew2_1.safetensors](https://huggingface.co/lightx2v/Autoencoders/blob/main/lighttaew2_1.safetensors)
- Wan-2.2-ti2v 的 LightTAE: [lighttaew2_2.safetensors](https://huggingface.co/lightx2v/Autoencoders/blob/main/lighttaew2_2.safetensors)
- HunyuanVideo-1.5 的 LightTAE: [lighttaehy1_5.safetensors](https://huggingface.co/lightx2v/Autoencoders/blob/main/lighttaehy1_5.safetensors)
**使用示例:**
```python
# 使用 HunyuanVideo-1.5 的 LightTAE
pipe.enable_lightvae(
use_tae=True,
tae_path="/path/to/lighttaehy1_5.safetensors",
use_lightvae=False,
vae_path=None
)
```
## 📚 更多资源
- [完整文档](https://lightx2v-zhcn.readthedocs.io/zh-cn/latest/)
- [GitHub 仓库](https://github.com/ModelTC/LightX2V)
- [HuggingFace 模型库](https://huggingface.co/lightx2v)
"""
HunyuanVideo-1.5 image-to-video generation example with quantization.
This example demonstrates how to use LightX2V with HunyuanVideo-1.5 model for I2V generation,
including quantized model usage for reduced memory consumption.
"""
from lightx2v import LightX2VPipeline
# Initialize pipeline for HunyuanVideo-1.5 I2V task
pipe = LightX2VPipeline(
model_path="/path/to/ckpts/hunyuanvideo-1.5/",
model_cls="hunyuan_video_1.5",
transformer_model_name="720p_i2v",
task="i2v",
)
# Alternative: create generator from config JSON file
# pipe.create_generator(config_json="../configs/hunyuan_video_15/hunyuan_video_i2v_720p.json")
# Enable offloading to significantly reduce VRAM usage with minimal speed impact
# Suitable for RTX 30/40/50 consumer GPUs
pipe.enable_offload(
cpu_offload=True,
offload_granularity="block", # For HunyuanVideo-1.5, only "block" is supported
text_encoder_offload=True,
image_encoder_offload=False,
vae_offload=False,
)
# Enable quantization for reduced memory usage
# Quantized models can be downloaded from: https://huggingface.co/lightx2v/Hy1.5-Quantized-Models
pipe.enable_quantize(
quant_scheme="fp8-sgl",
dit_quantized=True,
dit_quantized_ckpt="/path/to/hy15_720p_i2v_fp8_e4m3_lightx2v.safetensors",
text_encoder_quantized=True,
image_encoder_quantized=False,
text_encoder_quantized_ckpt="/path/to/hy15_qwen25vl_llm_encoder_fp8_e4m3_lightx2v.safetensors",
)
# Create generator with specified parameters
pipe.create_generator(
attn_mode="sage_attn2",
infer_steps=50,
num_frames=121,
guidance_scale=6.0,
sample_shift=7.0,
fps=24,
)
# Generation parameters
seed = 42
prompt = "Summer beach vacation style, a white cat wearing sunglasses sits on a surfboard. The fluffy-furred feline gazes directly at the camera with a relaxed expression. Blurred beach scenery forms the background featuring crystal-clear waters, distant green hills, and a blue sky dotted with white clouds. The cat assumes a naturally relaxed posture, as if savoring the sea breeze and warm sunlight. A close-up shot highlights the feline's intricate details and the refreshing atmosphere of the seaside."
negative_prompt = ""
save_result_path = "/path/to/save_results/output2.mp4"
# Generate video
pipe.generate(
seed=seed,
prompt=prompt,
negative_prompt=negative_prompt,
save_result_path=save_result_path,
)
"""
HunyuanVideo-1.5 text-to-video generation example.
This example demonstrates how to use LightX2V with HunyuanVideo-1.5 model for T2V generation.
"""
from lightx2v import LightX2VPipeline
# Initialize pipeline for HunyuanVideo-1.5
pipe = LightX2VPipeline(
model_path="/path/to/ckpts/hunyuanvideo-1.5/",
model_cls="hunyuan_video_1.5",
transformer_model_name="720p_t2v",
task="t2v",
)
# Alternative: create generator from config JSON file
# pipe.create_generator(config_json="../configs/hunyuan_video_15/hunyuan_video_t2v_720p.json")
# Enable offloading to significantly reduce VRAM usage with minimal speed impact
# Suitable for RTX 30/40/50 consumer GPUs
pipe.enable_offload(
cpu_offload=True,
offload_granularity="block", # For HunyuanVideo-1.5, only "block" is supported
text_encoder_offload=True,
image_encoder_offload=False,
vae_offload=False,
)
# Use lighttae
pipe.enable_lightvae(
use_tae=True,
tae_path="/path/to/lighttaehy1_5.safetensors",
use_lightvae=False,
vae_path=None,
)
# Create generator with specified parameters
pipe.create_generator(
attn_mode="sage_attn2",
infer_steps=50,
num_frames=121,
guidance_scale=6.0,
sample_shift=9.0,
aspect_ratio="16:9",
fps=24,
)
# Generation parameters
seed = 123
prompt = "A close-up shot captures a scene on a polished, light-colored granite kitchen counter, illuminated by soft natural light from an unseen window. Initially, the frame focuses on a tall, clear glass filled with golden, translucent apple juice standing next to a single, shiny red apple with a green leaf still attached to its stem. The camera moves horizontally to the right. As the shot progresses, a white ceramic plate smoothly enters the frame, revealing a fresh arrangement of about seven or eight more apples, a mix of vibrant reds and greens, piled neatly upon it. A shallow depth of field keeps the focus sharply on the fruit and glass, while the kitchen backsplash in the background remains softly blurred. The scene is in a realistic style."
negative_prompt = ""
save_result_path = "/path/to/save_results/output.mp4"
# Generate video
pipe.generate(
seed=seed,
prompt=prompt,
negative_prompt=negative_prompt,
save_result_path=save_result_path,
)
"""
HunyuanVideo-1.5 text-to-video generation example.
This example demonstrates how to use LightX2V with HunyuanVideo-1.5 4-step distilled model for T2V generation.
"""
from lightx2v import LightX2VPipeline
# Initialize pipeline for HunyuanVideo-1.5
pipe = LightX2VPipeline(
model_path="/path/to/ckpts/hunyuanvideo-1.5/",
model_cls="hunyuan_video_1.5",
transformer_model_name="480p_t2v",
task="t2v",
# 4-step distilled model ckpt
dit_original_ckpt="/path/to/hy1.5_t2v_480p_lightx2v_4step.safetensors",
)
# Alternative: create generator from config JSON file
# pipe.create_generator(config_json="../configs/hunyuan_video_15/hunyuan_video_t2v_720p.json")
# Enable offloading to significantly reduce VRAM usage with minimal speed impact
# Suitable for RTX 30/40/50 consumer GPUs
pipe.enable_offload(
cpu_offload=True,
offload_granularity="block", # For HunyuanVideo-1.5, only "block" is supported
text_encoder_offload=True,
image_encoder_offload=False,
vae_offload=False,
)
# Use lighttae
pipe.enable_lightvae(
use_tae=True,
tae_path="/path/to/lighttaehy1_5.safetensors",
use_lightvae=False,
vae_path=None,
)
# Create generator with specified parameters
pipe.create_generator(attn_mode="sage_attn2", infer_steps=4, num_frames=81, guidance_scale=1, sample_shift=9.0, aspect_ratio="16:9", fps=16, denoising_step_list=[1000, 750, 500, 250])
# Generation parameters
seed = 123
prompt = "A close-up shot captures a scene on a polished, light-colored granite kitchen counter, illuminated by soft natural light from an unseen window. Initially, the frame focuses on a tall, clear glass filled with golden, translucent apple juice standing next to a single, shiny red apple with a green leaf still attached to its stem. The camera moves horizontally to the right. As the shot progresses, a white ceramic plate smoothly enters the frame, revealing a fresh arrangement of about seven or eight more apples, a mix of vibrant reds and greens, piled neatly upon it. A shallow depth of field keeps the focus sharply on the fruit and glass, while the kitchen backsplash in the background remains softly blurred. The scene is in a realistic style."
negative_prompt = ""
save_result_path = "/data/nvme0/gushiqiao/LightX2V/save_results/output.mp4"
# Generate video
pipe.generate(
seed=seed,
prompt=prompt,
negative_prompt=negative_prompt,
save_result_path=save_result_path,
)
"""
Wan2.2 animate video generation example.
This example demonstrates how to use LightX2V with Wan2.2 model for animate video generation.
First, run preprocessing:
1. Set up environment: pip install -r ../requirements_animate.txt
2. For animate mode:
python ../tools/preprocess/preprocess_data.py \
--ckpt_path /path/to/Wan2.1-FLF2V-14B-720P/process_checkpoint \
--video_path /path/to/video \
--refer_path /path/to/ref_img \
--save_path ../save_results/animate/process_results \
--resolution_area 1280 720 \
--retarget_flag
3. For replace mode:
python ../tools/preprocess/preprocess_data.py \
--ckpt_path /path/to/Wan2.1-FLF2V-14B-720P/process_checkpoint \
--video_path /path/to/video \
--refer_path /path/to/ref_img \
--save_path ../save_results/replace/process_results \
--resolution_area 1280 720 \
--iterations 3 \
--k 7 \
--w_len 1 \
--h_len 1 \
--replace_flag
"""
from lightx2v import LightX2VPipeline
# Initialize pipeline for animate task
pipe = LightX2VPipeline(
model_path="/path/to/Wan2.1-FLF2V-14B-720P",
model_cls="wan2.2_animate",
task="animate",
)
pipe.replace_flag = True # Set to True for replace mode, False for animate mode
# Alternative: create generator from config JSON file
# pipe.create_generator(
# config_json="../configs/wan/wan_animate_replace.json"
# )
# Create generator with specified parameters
pipe.create_generator(
attn_mode="sage_attn2",
infer_steps=20,
height=480, # Can be set to 720 for higher resolution
width=832, # Can be set to 1280 for higher resolution
num_frames=77,
guidance_scale=1,
sample_shift=5.0,
fps=30,
)
seed = 42
prompt = "视频中的人在做动作"
negative_prompt = "镜头晃动,色调艳丽,过曝,静态,细节模糊不清,字幕,风格,作品,画作,画面,静止,整体发灰,最差质量,低质量,JPEG压缩残留,丑陋的,残缺的,多余的手指,画得不好的手部,画得不好的脸部,畸形的,毁容的,形态畸形的肢体,手指融合,静止不动的画面,杂乱的背景,三条腿,背景人很多,倒着走"
src_pose_path = "../save_results/animate/process_results/src_pose.mp4"
src_face_path = "../save_results/animate/process_results/src_face.mp4"
src_ref_images = "../save_results/animate/process_results/src_ref.png"
save_result_path = "/path/to/save_results/output.mp4"
pipe.generate(
seed=seed,
src_pose_path=src_pose_path,
src_face_path=src_face_path,
src_ref_images=src_ref_images,
prompt=prompt,
negative_prompt=negative_prompt,
save_result_path=save_result_path,
)
"""
Wan2.1 first-last-frame-to-video generation example.
This example demonstrates how to use LightX2V with Wan2.1 model for FLF2V generation.
"""
from lightx2v import LightX2VPipeline
# Initialize pipeline for FLF2V task
pipe = LightX2VPipeline(
model_path="/path/to/Wan2.1-FLF2V-14B-720P",
model_cls="wan2.1",
task="flf2v",
)
# Alternative: create generator from config JSON file
# pipe.create_generator(
# config_json="../configs/wan/wan_flf2v.json"
# )
# Optional: enable offloading to significantly reduce VRAM usage
# Suitable for RTX 30/40/50 consumer GPUs
# pipe.enable_offload(
# cpu_offload=True,
# offload_granularity="block",
# text_encoder_offload=True,
# image_encoder_offload=False,
# vae_offload=False,
# )
# Create generator with specified parameters
pipe.create_generator(
attn_mode="sage_attn2",
infer_steps=40,
height=480, # Can be set to 720 for higher resolution
width=832, # Can be set to 1280 for higher resolution
num_frames=81,
guidance_scale=5,
sample_shift=5.0,
)
seed = 42
prompt = "CG animation style, a small blue bird takes off from the ground, flapping its wings. The bird’s feathers are delicate, with a unique pattern on its chest. The background shows a blue sky with white clouds under bright sunshine. The camera follows the bird upward, capturing its flight and the vastness of the sky from a close-up, low-angle perspective."
negative_prompt = "镜头晃动,色调艳丽,过曝,静态,细节模糊不清,字幕,风格,作品,画作,画面,静止,整体发灰,最差质量,低质量,JPEG压缩残留,丑陋的,残缺的,多余的手指,画得不好的手部,画得不好的脸部,畸形的,毁容的,形态畸形的肢体,手指融合,静止不动的画面,杂乱的背景,三条腿,背景人很多,倒着走"
image_path = "../assets/inputs/imgs/flf2v_input_first_frame-fs8.png"
last_frame_path = "../assets/inputs/imgs/flf2v_input_last_frame-fs8.png"
save_result_path = "/path/to/save_results/output.mp4"
pipe.generate(
image_path=image_path,
last_frame_path=last_frame_path,
seed=seed,
prompt=prompt,
negative_prompt=negative_prompt,
save_result_path=save_result_path,
)
"""
Wan2.2 image-to-video generation example.
This example demonstrates how to use LightX2V with Wan2.2 model for I2V generation.
"""
from lightx2v import LightX2VPipeline
# Initialize pipeline for Wan2.2 I2V task
# For wan2.1, use model_cls="wan2.1"
pipe = LightX2VPipeline(
model_path="/path/to/Wan2.2-I2V-A14B",
model_cls="wan2.2_moe",
task="i2v",
)
# Alternative: create generator from config JSON file
# pipe.create_generator(
# config_json="../configs/wan22/wan_moe_i2v.json"
# )
# Enable offloading to significantly reduce VRAM usage with minimal speed impact
# Suitable for RTX 30/40/50 consumer GPUs
pipe.enable_offload(
cpu_offload=True,
offload_granularity="block", # For Wan models, supports both "block" and "phase"
text_encoder_offload=True,
image_encoder_offload=False,
vae_offload=False,
)
# Create generator manually with specified parameters
pipe.create_generator(
attn_mode="sage_attn2",
infer_steps=40,
height=480, # Can be set to 720 for higher resolution
width=832, # Can be set to 1280 for higher resolution
num_frames=81,
guidance_scale=[3.5, 3.5], # For wan2.1, guidance_scale is a scalar (e.g., 5.0)
sample_shift=5.0,
)
# Generation parameters
seed = 42
prompt = "Summer beach vacation style, a white cat wearing sunglasses sits on a surfboard. The fluffy-furred feline gazes directly at the camera with a relaxed expression. Blurred beach scenery forms the background featuring crystal-clear waters, distant green hills, and a blue sky dotted with white clouds. The cat assumes a naturally relaxed posture, as if savoring the sea breeze and warm sunlight. A close-up shot highlights the feline's intricate details and the refreshing atmosphere of the seaside."
negative_prompt = "镜头晃动,色调艳丽,过曝,静态,细节模糊不清,字幕,风格,作品,画作,画面,静止,整体发灰,最差质量,低质量,JPEG压缩残留,丑陋的,残缺的,多余的手指,画得不好的手部,画得不好的脸部,畸形的,毁容的,形态畸形的肢体,手指融合,静止不动的画面,杂乱的背景,三条腿,背景人很多,倒着走"
image_path = "/path/to/img_0.jpg"
save_result_path = "/path/to/save_results/output.mp4"
# Generate video
pipe.generate(
seed=seed,
image_path=image_path,
prompt=prompt,
negative_prompt=negative_prompt,
save_result_path=save_result_path,
)
"""
Wan2.2 distilled model image-to-video generation example.
This example demonstrates how to use LightX2V with Wan2.2 distilled model for I2V generation.
"""
from lightx2v import LightX2VPipeline
# Initialize pipeline for Wan2.2 distilled I2V task
# For wan2.1, use model_cls="wan2.1_distill"
pipe = LightX2VPipeline(
model_path="/path/to/wan2.2/Wan2.2-I2V-A14B",
model_cls="wan2.2_moe_distill",
task="i2v",
# Distilled weights: For wan2.1, only need to specify dit_original_ckpt="/path/to/wan2.1_i2v_720p_lightx2v_4step.safetensors"
low_noise_original_ckpt="/path/to/wan2.2_i2v_A14b_low_noise_lightx2v_4step.safetensors",
high_noise_original_ckpt="/path/to/wan2.2_i2v_A14b_high_noise_lightx2v_4step_1030.safetensors",
)
# Alternative: create generator from config JSON file
# pipe.create_generator(
# config_json="../configs/wan22/wan_moe_i2v_distill.json"
# )
# Enable offloading to significantly reduce VRAM usage
# Suitable for RTX 30/40/50 consumer GPUs
pipe.enable_offload(
cpu_offload=True,
offload_granularity="block",
text_encoder_offload=True,
image_encoder_offload=False,
vae_offload=False,
)
# Create generator manually with specified parameters
pipe.create_generator(
attn_mode="sage_attn2",
infer_steps=4,
height=480, # Can be set to 720 for higher resolution
width=832, # Can be set to 1280 for higher resolution
num_frames=81,
guidance_scale=1,
sample_shift=5.0,
)
seed = 42
prompt = "Summer beach vacation style, a white cat wearing sunglasses sits on a surfboard. The fluffy-furred feline gazes directly at the camera with a relaxed expression. Blurred beach scenery forms the background featuring crystal-clear waters, distant green hills, and a blue sky dotted with white clouds. The cat assumes a naturally relaxed posture, as if savoring the sea breeze and warm sunlight. A close-up shot highlights the feline's intricate details and the refreshing atmosphere of the seaside."
negative_prompt = "镜头晃动,色调艳丽,过曝,静态,细节模糊不清,字幕,风格,作品,画作,画面,静止,整体发灰,最差质量,低质量,JPEG压缩残留,丑陋的,残缺的,多余的手指,画得不好的手部,画得不好的脸部,畸形的,毁容的,形态畸形的肢体,手指融合,静止不动的画面,杂乱的背景,三条腿,背景人很多,倒着走"
save_result_path = "/path/to/save_results/output.mp4"
image_path = "/path/to/img_0.jpg"
pipe.generate(
seed=seed,
image_path=image_path,
prompt=prompt,
negative_prompt=negative_prompt,
save_result_path=save_result_path,
)
"""
Wan2.2 distilled model with LoRA image-to-video generation example.
This example demonstrates how to use LightX2V with Wan2.2 distilled model and LoRA for I2V generation.
"""
from lightx2v import LightX2VPipeline
# Initialize pipeline for Wan2.2 distilled I2V task with LoRA
# For wan2.1, use model_cls="wan2.1_distill"
pipe = LightX2VPipeline(
model_path="/path/to/wan2.2/Wan2.2-I2V-A14B",
model_cls="wan2.2_moe_distill",
task="i2v",
)
# Alternative: create generator from config JSON file
# pipe.create_generator(
# config_json="../configs/wan22/wan_moe_i2v_distill_with_lora.json"
# )
# Enable offloading to significantly reduce VRAM usage
# Suitable for RTX 30/40/50 consumer GPUs
pipe.enable_offload(
cpu_offload=True,
offload_granularity="block",
text_encoder_offload=True,
image_encoder_offload=False,
vae_offload=False,
)
# Load distilled LoRA weights
pipe.enable_lora(
[
{"name": "high_noise_model", "path": "/path/to/wan2.2_i2v_A14b_high_noise_lora_rank64_lightx2v_4step_1022.safetensors", "strength": 1.0},
{"name": "low_noise_model", "path": "/path/to/wan2.2_i2v_A14b_low_noise_lora_rank64_lightx2v_4step_1022.safetensors", "strength": 1.0},
]
)
# Create generator with specified parameters
pipe.create_generator(
attn_mode="sage_attn2",
infer_steps=4,
height=480, # Can be set to 720 for higher resolution
width=832, # Can be set to 1280 for higher resolution
num_frames=81,
guidance_scale=1,
sample_shift=5.0,
)
seed = 42
prompt = "Summer beach vacation style, a white cat wearing sunglasses sits on a surfboard. The fluffy-furred feline gazes directly at the camera with a relaxed expression. Blurred beach scenery forms the background featuring crystal-clear waters, distant green hills, and a blue sky dotted with white clouds. The cat assumes a naturally relaxed posture, as if savoring the sea breeze and warm sunlight. A close-up shot highlights the feline's intricate details and the refreshing atmosphere of the seaside."
negative_prompt = "镜头晃动,色调艳丽,过曝,静态,细节模糊不清,字幕,风格,作品,画作,画面,静止,整体发灰,最差质量,低质量,JPEG压缩残留,丑陋的,残缺的,多余的手指,画得不好的手部,画得不好的脸部,畸形的,毁容的,形态畸形的肢体,手指融合,静止不动的画面,杂乱的背景,三条腿,背景人很多,倒着走"
save_result_path = "/path/to/save_results/output.mp4"
image_path = "/path/to/img_0.jpg"
pipe.generate(
seed=seed,
image_path=image_path,
prompt=prompt,
negative_prompt=negative_prompt,
save_result_path=save_result_path,
)
"""
Wan2.1 text-to-video generation example.
This example demonstrates how to use LightX2V with Wan2.1 model for T2V generation.
"""
from lightx2v import LightX2VPipeline
# Initialize pipeline for Wan2.1 T2V task
pipe = LightX2VPipeline(
model_path="/path/to/Wan2.1-T2V-14B",
model_cls="wan2.1",
task="t2v",
)
# Alternative: create generator from config JSON file
# pipe.create_generator(config_json="../configs/wan/wan_t2v.json")
# Create generator with specified parameters
pipe.create_generator(
attn_mode="sage_attn2",
infer_steps=50,
height=480, # Can be set to 720 for higher resolution
width=832, # Can be set to 1280 for higher resolution
num_frames=81,
guidance_scale=5.0,
sample_shift=5.0,
)
seed = 42
prompt = "Two anthropomorphic cats in comfy boxing gear and bright gloves fight intensely on a spotlighted stage."
negative_prompt = "镜头晃动,色调艳丽,过曝,静态,细节模糊不清,字幕,风格,作品,画作,画面,静止,整体发灰,最差质量,低质量,JPEG压缩残留,丑陋的,残缺的,多余的手指,画得不好的手部,画得不好的脸部,畸形的,毁容的,形态畸形的肢体,手指融合,静止不动的画面,杂乱的背景,三条腿,背景人很多,倒着走"
save_result_path = "/path/to/save_results/output.mp4"
pipe.generate(
seed=seed,
prompt=prompt,
negative_prompt=negative_prompt,
save_result_path=save_result_path,
)
"""
Wan2.1 VACE (All-in-one Video Creation and Editing) generation example.
This example demonstrates how to use LightX2V with Wan2.1 VACE model for character exchange in videos.
"""
from lightx2v import LightX2VPipeline
# Initialize pipeline for VACE task
pipe = LightX2VPipeline(
model_path="/path/to/Wan2.1-VACE-1.3B",
src_ref_images="../assets/inputs/imgs/girl.png,../assets/inputs/imgs/snake.png",
model_cls="wan2.1_vace",
task="vace",
)
# Alternative: create generator from config JSON file
# pipe.create_generator(
# config_json="../configs/wan/wan_vace.json"
# )
# Optional: enable offloading to significantly reduce VRAM usage
# Suitable for RTX 30/40/50 consumer GPUs
# pipe.enable_offload(
# cpu_offload=True,
# offload_granularity="block",
# text_encoder_offload=True,
# image_encoder_offload=False,
# vae_offload=False,
# )
# Create generator with specified parameters
pipe.create_generator(
attn_mode="sage_attn2",
infer_steps=40,
height=480, # Can be set to 720 for higher resolution
width=832, # Can be set to 1280 for higher resolution
num_frames=81,
guidance_scale=5,
sample_shift=16,
)
seed = 42
prompt = "在一个欢乐而充满节日气氛的场景中,穿着鲜艳红色春服的小女孩正与她的可爱卡通蛇嬉戏。她的春服上绣着金色吉祥图案,散发着喜庆的气息,脸上洋溢着灿烂的笑容。蛇身呈现出亮眼的绿色,形状圆润,宽大的眼睛让它显得既友善又幽默。小女孩欢快地用手轻轻抚摸着蛇的头部,共同享受着这温馨的时刻。周围五彩斑斓的灯笼和彩带装饰着环境,阳光透过洒在她们身上,营造出一个充满友爱与幸福的新年氛围。"
negative_prompt = "镜头晃动,色调艳丽,过曝,静态,细节模糊不清,字幕,风格,作品,画作,画面,静止,整体发灰,最差质量,低质量,JPEG压缩残留,丑陋的,残缺的,多余的手指,画得不好的手部,画得不好的脸部,畸形的,毁容的,形态畸形的肢体,手指融合,静止不动的画面,杂乱的背景,三条腿,背景人很多,倒着走"
save_result_path = "/path/to/save_results/output.mp4"
pipe.generate(
seed=seed,
prompt=prompt,
negative_prompt=negative_prompt,
save_result_path=save_result_path,
)
__version__ = "0.1.0"
__author__ = "LightX2V Contributors"
__license__ = "Apache 2.0"
import lightx2v_platform.set_ai_device
from lightx2v import common, deploy, models, utils
from lightx2v.pipeline import LightX2VPipeline
__all__ = [
"__version__",
"__author__",
"__license__",
"models",
"common",
"deploy",
"utils",
"LightX2VPipeline",
]
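# WeightModule is a minimal container for model weights: it registers child modules and
# parameters in plain dicts and provides the load/offload helpers (load_state_dict,
# to_cpu/to_cuda and their async variants) used by the offloading machinery.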
class WeightModule:
def __init__(self):
self._modules = {}
self._parameters = {}
def is_empty(self):
return len(self._modules) == 0 and len(self._parameters) == 0
def add_module(self, name, module):
self._modules[name] = module
setattr(self, name, module)
def register_parameter(self, name, param):
self._parameters[name] = param
setattr(self, name, param)
def load(self, weight_dict):
for _, module in self._modules.items():
if hasattr(module, "load"):
module.load(weight_dict)
for _, parameter in self._parameters.items():
if hasattr(parameter, "load"):
parameter.load(weight_dict)
def state_dict(self, destination=None):
if destination is None:
destination = {}
for _, param in self._parameters.items():
if param is not None:
param.state_dict(destination)
for _, module in self._modules.items():
if module is not None:
module.state_dict(destination)
return destination
def load_state_dict(self, destination, block_index, adapter_block_index=None):
if destination is None:
destination = {}
for _, param in self._parameters.items():
if param is not None:
param.load_state_dict(destination, block_index, adapter_block_index)
for _, module in self._modules.items():
if module is not None:
module.load_state_dict(destination, block_index, adapter_block_index)
return destination
def load_state_dict_from_disk(self, block_index, adapter_block_index=None):
for _, param in self._parameters.items():
if param is not None:
param.load_state_dict_from_disk(block_index, adapter_block_index)
for _, module in self._modules.items():
if module is not None:
module.load_state_dict_from_disk(block_index, adapter_block_index)
def named_parameters(self, prefix=""):
for name, param in self._parameters.items():
if param is not None:
yield prefix + name, param
for name, module in self._modules.items():
if module is not None:
yield from module.named_parameters(prefix + name + ".")
def to_cpu(self):
for name, param in self._parameters.items():
if param is not None:
if hasattr(param, "cpu"):
self._parameters[name] = param.cpu()
setattr(self, name, self._parameters[name])
elif hasattr(param, "to_cpu"):
self._parameters[name].to_cpu()
setattr(self, name, self._parameters[name])
for module in self._modules.values():
if isinstance(module, WeightModuleList):
for i in range(len(module)):
for m in module[i]._modules.values():
if m is not None and hasattr(m, "to_cpu"):
m.to_cpu()
for m in module[i]._parameters.values():
if m is not None and hasattr(m, "to_cpu"):
m.to_cpu()
else:
if module is not None and hasattr(module, "to_cpu"):
module.to_cpu()
def to_cuda(self):
for name, param in self._parameters.items():
if param is not None:
if hasattr(param, "cuda"):
self._parameters[name] = param.cuda()
elif hasattr(param, "to_cuda"):
self._parameters[name].to_cuda()
setattr(self, name, self._parameters[name])
for module in self._modules.values():
if isinstance(module, WeightModuleList):
for i in range(len(module)):
for m in module[i]._modules.values():
if m is not None and hasattr(m, "to_cuda"):
m.to_cuda()
for m in module[i]._parameters.values():
if m is not None and hasattr(m, "to_cuda"):
m.to_cuda()
else:
if module is not None and hasattr(module, "to_cuda"):
module.to_cuda()
def to_cpu_async(self):
for name, param in self._parameters.items():
if param is not None:
if hasattr(param, "cpu"):
self._parameters[name] = param.cpu(non_blocking=True)
setattr(self, name, self._parameters[name])
elif hasattr(param, "to_cpu"):
self._parameters[name].to_cpu(non_blocking=True)
setattr(self, name, self._parameters[name])
for module in self._modules.values():
if isinstance(module, WeightModuleList):
for i in range(len(module)):
for m in module[i]._modules.values():
if m is not None and hasattr(m, "to_cpu"):
m.to_cpu(non_blocking=True)
for m in module[i]._parameters.values():
if m is not None and hasattr(m, "to_cpu"):
m.to_cpu(non_blocking=True)
else:
if module is not None and hasattr(module, "to_cpu"):
module.to_cpu(non_blocking=True)
def to_cuda_async(self):
for name, param in self._parameters.items():
if param is not None:
if hasattr(param, "cuda"):
self._parameters[name] = param.cuda(non_blocking=True)
elif hasattr(param, "to_cuda"):
self._parameters[name].to_cuda(non_blocking=True)
setattr(self, name, self._parameters[name])
for module in self._modules.values():
if isinstance(module, WeightModuleList):
for i in range(len(module)):
for m in module[i]._modules.values():
if m is not None and hasattr(m, "to_cuda"):
m.to_cuda(non_blocking=True)
for m in module[i]._parameters.values():
if m is not None and hasattr(m, "to_cuda"):
m.to_cuda(non_blocking=True)
else:
if module is not None and hasattr(module, "to_cuda"):
module.to_cuda(non_blocking=True)
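# WeightModuleList is a list-like WeightModule that stores elements in order and registers
# each one under its string index, similar to torch.nn.ModuleList.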
class WeightModuleList(WeightModule):
def __init__(self, modules=None):
super().__init__()
self._list = []
if modules is not None:
for idx, module in enumerate(modules):
self.append(module)
def append(self, module):
idx = len(self._list)
self._list.append(module)
self.add_module(str(idx), module)
def __getitem__(self, idx):
return self._list[idx]
def __setitem__(self, idx, module):
self._list[idx] = module
self.add_module(str(idx), module)
def __len__(self):
return len(self._list)
def __iter__(self):
return iter(self._list)
from concurrent.futures import ThreadPoolExecutor
import torch
from loguru import logger
from packaging.version import parse
from tqdm import tqdm
from lightx2v.utils.profiler import ExcludedProfilingContext
from lightx2v_platform.base.global_var import AI_DEVICE
torch_device_module = getattr(torch, AI_DEVICE)
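# WeightAsyncStreamManager overlaps weight loading with computation: it keeps paired
# CPU/CUDA buffers, prefetches the next block or phase on a dedicated load stream
# (optionally reading from disk via a thread pool), and swaps buffers once both the
# load and compute streams have synchronized.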
class WeightAsyncStreamManager(object):
def __init__(self, offload_granularity):
self.offload_granularity = offload_granularity
self.init_stream = torch_device_module.Stream(priority=0)
self.need_init_first_buffer = True
self.lazy_load = False
torch_version = parse(torch.__version__.split("+")[0])
if AI_DEVICE == "cuda" and torch_version >= parse("2.7"):
self.cuda_load_stream = torch_device_module.Stream(priority=1)
self.compute_stream = torch_device_module.Stream(priority=1)
else:
self.cuda_load_stream = torch_device_module.Stream(priority=0)
self.compute_stream = torch_device_module.Stream(priority=-1)
def init_cpu_buffer(self, blocks_cpu_buffer=None, phases_cpu_buffer=None):
self.need_init_first_buffer = True
if self.offload_granularity == "block":
assert blocks_cpu_buffer is not None
self.cpu_buffers = [blocks_cpu_buffer[i] for i in range(len(blocks_cpu_buffer))]
elif self.offload_granularity == "phase":
assert phases_cpu_buffer is not None
self.cpu_buffers = [phases_cpu_buffer[i] for i in range(len(phases_cpu_buffer))]
else:
raise NotImplementedError
def init_cuda_buffer(self, blocks_cuda_buffer=None, phases_cuda_buffer=None):
self.need_init_first_buffer = True
if self.offload_granularity == "block":
assert blocks_cuda_buffer is not None
self.cuda_buffers = [blocks_cuda_buffer[i] for i in range(len(blocks_cuda_buffer))]
elif self.offload_granularity == "phase":
assert phases_cuda_buffer is not None
self.cuda_buffers = [phases_cuda_buffer[i] for i in range(len(phases_cuda_buffer))]
else:
raise NotImplementedError
def init_first_buffer(self, blocks, adapter_block_idx=None):
with torch_device_module.stream(self.init_stream):
if hasattr(self, "cpu_buffers"):
self.cuda_buffers[0].load_state_dict(self.cpu_buffers[0][0].state_dict(), 0, adapter_block_idx)
else:
if self.offload_granularity == "block":
self.cuda_buffers[0].load_state_dict(blocks[0].state_dict(), 0, adapter_block_idx)
else:
self.cuda_buffers[0].load_state_dict(blocks[0].compute_phases[0].state_dict(), 0, adapter_block_idx)
self.init_stream.synchronize()
self.need_init_first_buffer = False
def prefetch_weights(self, block_idx, blocks, adapter_block_idx=None):
with torch_device_module.stream(self.cuda_load_stream):
if hasattr(self, "cpu_buffers"):
self.cpu_buffers[1].load_state_dict_from_disk(block_idx, adapter_block_idx)
self.cuda_buffers[1].load_state_dict(self.cpu_buffers[1].state_dict(), block_idx, adapter_block_idx)
else:
self.cuda_buffers[1].load_state_dict(blocks[block_idx].state_dict(), block_idx, adapter_block_idx)
def prefetch_phase(self, block_idx, phase_idx, blocks, adapter_block_idx=None):
with torch_device_module.stream(self.cuda_load_stream):
if hasattr(self, "cpu_buffers"):
self.cuda_buffers[phase_idx].load_state_dict(self.cpu_buffers[0][phase_idx].state_dict(), block_idx, adapter_block_idx)
else:
self.cuda_buffers[phase_idx].load_state_dict(blocks[block_idx].compute_phases[phase_idx].state_dict(), block_idx, adapter_block_idx)
def swap_blocks(self):
self.cuda_load_stream.synchronize()
self.compute_stream.synchronize()
self.cuda_buffers[0], self.cuda_buffers[1] = (
self.cuda_buffers[1],
self.cuda_buffers[0],
)
def swap_phases(self):
self.cuda_load_stream.synchronize()
self.compute_stream.synchronize()
@ExcludedProfilingContext("🔥 warm_up_cpu_buffers")
def warm_up_cpu_buffers(self, blocks_num):
logger.info("🔥 Warming up cpu buffers...")
for i in tqdm(range(blocks_num)):
for phase in self.cpu_buffers[0]:
phase.load_state_dict_from_disk(i, None)
for phase in self.cpu_buffers[1]:
phase.load_state_dict_from_disk(i, None)
for phase in self.cpu_buffers[0]:
phase.load_state_dict_from_disk(0, None)
for phase in self.cpu_buffers[1]:
phase.load_state_dict_from_disk(1, None)
logger.info("✅ CPU buffers warm-up completed.")
def init_lazy_load(self, num_workers=6):
self.lazy_load = True
self.executor = ThreadPoolExecutor(max_workers=num_workers)
self.prefetch_futures = []
self.prefetch_block_idx = -1
def start_prefetch_block(self, block_idx, adapter_block_idx=None):
self.prefetch_block_idx = block_idx
self.prefetch_futures = []
for phase in self.cpu_buffers[1]:
future = self.executor.submit(phase.load_state_dict_from_disk, block_idx, adapter_block_idx)
self.prefetch_futures.append(future)
def swap_cpu_buffers(self):
# wait_start = time.time()
# already_done = all(f.done() for f in self.prefetch_futures)
for f in self.prefetch_futures:
f.result()
# wait_time = time.time() - wait_start
# logger.debug(f"[Prefetch] block {self.prefetch_block_idx}: wait={wait_time:.3f}s, already_done={already_done}")
self.cpu_buffers = [self.cpu_buffers[1], self.cpu_buffers[0]]
def __del__(self):
if hasattr(self, "executor") and self.executor is not None:
for f in self.prefetch_futures:
if not f.done():
f.result()
self.executor.shutdown(wait=False)
self.executor = None
logger.debug("ThreadPoolExecutor shut down successfully.")
from .attn import *
from .conv import *
from .embedding import *
from .mm import *
from .norm import *
from .tensor import *
from .flash_attn import FlashAttn2Weight, FlashAttn3Weight
from .nbhd_attn import NbhdAttnWeight, NbhdAttnWeightFlashInfer
from .radial_attn import RadialAttnWeight
from .ring_attn import RingAttnWeight
from .sage_attn import SageAttn2Weight, SageAttn3Weight
from .spassage_attn import SageAttnWeight
from .svg2_attn import Svg2AttnWeight
from .svg_attn import SvgAttnWeight
from .torch_sdpa import TorchSDPAWeight
from .ulysses_attn import Ulysses4090AttnWeight, UlyssesAttnWeight
from loguru import logger
try:
import flash_attn # noqa: F401
from flash_attn.flash_attn_interface import flash_attn_varlen_func
except ImportError:
logger.info("flash_attn_varlen_func not found, please install flash_attn2 first")
flash_attn_varlen_func = None
try:
from flash_attn_interface import flash_attn_varlen_func as flash_attn_varlen_func_v3
except ImportError:
logger.info("flash_attn_varlen_func_v3 not found, please install flash_attn3 first")
flash_attn_varlen_func_v3 = None
from lightx2v.utils.registry_factory import ATTN_WEIGHT_REGISTER
from .template import AttnWeightTemplate
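# Registry-backed attention wrappers: each class flattens an optional batch dimension
# and dispatches to the corresponding FlashAttention varlen kernel.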
@ATTN_WEIGHT_REGISTER("flash_attn2")
class FlashAttn2Weight(AttnWeightTemplate):
def __init__(self):
self.config = {}
def apply(
self,
q,
k,
v,
cu_seqlens_q=None,
cu_seqlens_kv=None,
max_seqlen_q=None,
max_seqlen_kv=None,
model_cls=None,
):
if len(q.shape) == 3:
bs = 1
elif len(q.shape) == 4:
bs = q.shape[0]
q = q.reshape(-1, q.shape[-2], q.shape[-1])
k = k.reshape(-1, k.shape[-2], k.shape[-1])
v = v.reshape(-1, v.shape[-2], v.shape[-1])
x = flash_attn_varlen_func(
q,
k,
v,
cu_seqlens_q,
cu_seqlens_kv,
max_seqlen_q,
max_seqlen_kv,
).reshape(bs * max_seqlen_q, -1)
return x
@ATTN_WEIGHT_REGISTER("flash_attn3")
class FlashAttn3Weight(AttnWeightTemplate):
def __init__(self):
self.config = {}
def apply(
self,
q,
k,
v,
cu_seqlens_q=None,
cu_seqlens_kv=None,
max_seqlen_q=None,
max_seqlen_kv=None,
model_cls=None,
):
if len(q.shape) == 3:
bs = 1
elif len(q.shape) == 4:
bs = q.shape[0]
if model_cls is not None and model_cls in ["hunyuan_video_1.5"]:
q = q.reshape(-1, q.shape[-2], q.shape[-1])
k = k.reshape(-1, k.shape[-2], k.shape[-1])
v = v.reshape(-1, v.shape[-2], v.shape[-1])
x = flash_attn_varlen_func_v3(
q,
k,
v,
cu_seqlens_q,
cu_seqlens_kv,
max_seqlen_q,
max_seqlen_kv,
).reshape(bs * max_seqlen_q, -1)
return x