Commit 3e4fe79b authored by GoatWu

Merge branch 'main' of github.com:ModelTC/lightx2v into main

parents 8ddd33a5 d013cac7
# 🚀 Benchmark
> This document showcases the performance test results of LightX2V across different hardware environments, including detailed comparison data for H200 and RTX 4090 platforms.
---
## 🖥️ H200 Environment (~140GB VRAM)
### 📋 Software Environment Configuration
| Component | Version |
|:----------|:--------|
| **Python** | 3.11 |
| **PyTorch** | 2.7.1+cu128 |
| **SageAttention** | 2.2.0 |
| **vLLM** | 0.9.2 |
| **sgl-kernel** | 0.1.8 |
---
### 🎬 480P 5s Video Test
**Test Configuration:**
- **Model**: [Wan2.1-I2V-14B-480P-Lightx2v](https://huggingface.co/lightx2v/Wan2.1-I2V-14B-480P-Lightx2v)
- **Parameters**: `infer_steps=40`, `seed=42`, `enable_cfg=True`

#### 📊 Performance Comparison Table

| Configuration | Inference Time(s) | GPU Memory(GB) | Speedup | Video Effect |
|:-------------|:-----------------:|:--------------:|:-------:|:------------:|
...@@ -29,13 +36,15 @@
| **LightX2V_3-Distill** | 14 | 35 | **🏆 20.85x** | <video src="https://github.com/user-attachments/assets/b4dc403c-919d-4ba1-b29f-ef53640c0334" width="200px"></video> |
| **LightX2V_4** | 107 | 35 | **3.41x** | <video src="https://github.com/user-attachments/assets/49cd2760-4be2-432c-bf4e-01af9a1303dd" width="200px"></video> |
---
### 🎬 720P 5s Video Test
**Test Configuration:**
- **Model**: [Wan2.1-I2V-14B-720P-Lightx2v](https://huggingface.co/lightx2v/Wan2.1-I2V-14B-720P-Lightx2v)
- **Parameters**: `infer_steps=40`, `seed=1234`, `enable_cfg=True`

#### 📊 Performance Comparison Table

| Configuration | Inference Time(s) | GPU Memory(GB) | Speedup | Video Effect |
|:-------------|:-----------------:|:--------------:|:-------:|:------------:|
...@@ -49,27 +58,92 @@
---
## 🖥️ RTX 4090 Environment (~24GB VRAM)
### 📋 Software Environment Configuration
| Component | Version |
|:----------|:--------|
| **Python** | 3.9.16 |
| **PyTorch** | 2.5.1+cu124 |
| **SageAttention** | 2.1.0 |
| **vLLM** | 0.6.6 |
| **sgl-kernel** | 0.0.5 |
| **q8-kernels** | 0.0.0 |
---
### 🎬 480P 5s Video Test
**Test Configuration:**
- **Model**: [Wan2.1-I2V-14B-480P-Lightx2v](https://huggingface.co/lightx2v/Wan2.1-I2V-14B-480P-Lightx2v)
- **Parameters**: `infer_steps=40`, `seed=42`, `enable_cfg=True`
#### 📊 Performance Comparison Table
| Configuration | Inference Time(s) | GPU Memory(GB) | Speedup | Video Effect |
|:-------------|:-----------------:|:--------------:|:-------:|:------------:|
| **Wan2GP(profile=3)** | 779 | 20 | **1.0x** | <video src="https://github.com/user-attachments/assets/ba548a48-04f8-4616-a55a-ad7aed07d438" width="200px"></video> |
| **LightX2V_5** | 738 | 16 | **1.05x** | <video src="https://github.com/user-attachments/assets/ce72ab7d-50a7-4467-ac8c-a6ed1b3827a7" width="200px"></video> |
| **LightX2V_5-Distill** | 68 | 16 | **11.45x** | <video src="https://github.com/user-attachments/assets/5df4b8a7-3162-47f8-a359-e22fbb4d1836" width="200px"></video> |
| **LightX2V_6** | 630 | 12 | **1.24x** | <video src="https://github.com/user-attachments/assets/d13cd939-363b-4f8b-80d9-d3a145c46676" width="200px"></video> |
| **LightX2V_6-Distill** | 63 | 12 | **🏆 12.36x** | <video src="https://github.com/user-attachments/assets/f372bce4-3c2f-411d-aa6b-c4daeb467d90" width="200px"></video> |
---
### 🎬 720P 5s Video Test
**Test Configuration:**
- **Model**: [Wan2.1-I2V-14B-720P-Lightx2v](https://huggingface.co/lightx2v/Wan2.1-I2V-14B-720P-Lightx2v)
- **Parameters**: `infer_steps=40`, `seed=1234`, `enable_cfg=True`
#### 📊 Performance Comparison Table
| Configuration | Inference Time(s) | GPU Memory(GB) | Speedup | Video Effect |
|:-------------|:-----------------:|:--------------:|:-------:|:------------:|
| **Wan2GP(profile=3)** | -- | OOM | -- | <video src="--" width="200px"></video> |
| **LightX2V_5** | 2473 | 23 | -- | <video src="https://github.com/user-attachments/assets/0e83b146-3297-4c63-831c-8462cc657cad" width="200px"></video> |
| **LightX2V_5-Distill** | 183 | 23 | -- | <video src="https://github.com/user-attachments/assets/976d0af0-244c-4abe-b2cb-01f68ad69d3c" width="200px"></video> |
| **LightX2V_6** | 2169 | 18 | -- | <video src="https://github.com/user-attachments/assets/cf9edf82-53e1-46af-a000-79a88af8ad4a" width="200px"></video> |
| **LightX2V_6-Distill** | 171 | 18 | -- | <video src="https://github.com/user-attachments/assets/e3064b03-6cd6-4c82-9e31-ab28b3165798" width="200px"></video> |
---
## 📖 Configuration Descriptions
### 🖥️ H200 Environment Configuration Descriptions
| Configuration | Technical Features |
|:--------------|:------------------|
| **Wan2.1 Official** | Based on [Wan2.1 official repository](https://github.com/Wan-Video/Wan2.1) original implementation |
| **FastVideo** | Based on [FastVideo official repository](https://github.com/hao-ai-lab/FastVideo), using SageAttention2 backend optimization |
| **LightX2V_1** | Uses SageAttention2 to replace native attention mechanism, adopts DIT BF16+FP32 (partial sensitive layers) mixed precision computation, improving computational efficiency while maintaining precision |
| **LightX2V_2** | Unified BF16 precision computation, further reducing memory usage and computational overhead while maintaining generation quality |
| **LightX2V_3** | Introduces FP8 quantization technology to significantly reduce computational precision requirements, combined with Tiling VAE technology to optimize memory usage |
| **LightX2V_3-Distill** | Based on LightX2V_3 using 4-step distillation model(`infer_steps=4`, `enable_cfg=False`), further reducing inference steps while maintaining generation quality |
| **LightX2V_4** | Based on LightX2V_3 with TeaCache(teacache_thresh=0.2) caching reuse technology, achieving acceleration through intelligent redundant computation skipping |
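The FP8 row above (LightX2V_3) relies on fused FP8 kernels; the snippet below is only a minimal, self-contained sketch of the underlying idea — per-tensor symmetric FP8 (e4m3) weight quantization followed by a dequantized matmul — with arbitrary shapes and no claim to match LightX2V's actual kernels:

```python
import torch

# Per-tensor symmetric FP8 (e4m3) quantization of a weight, then a BF16 matmul.
w = torch.randn(4096, 4096, dtype=torch.bfloat16)
x = torch.randn(8, 4096, dtype=torch.bfloat16)

fp8_max = torch.finfo(torch.float8_e4m3fn).max          # ~448 for e4m3
scale = w.abs().amax().float() / fp8_max                 # single per-tensor scale
w_fp8 = (w.float() / scale).clamp(-fp8_max, fp8_max).to(torch.float8_e4m3fn)

# Dequantize for the matmul; real FP8 kernels (vLLM/sgl-kernel) fuse this step instead.
y = x @ (w_fp8.to(torch.bfloat16) * scale.to(torch.bfloat16)).t()
print(y.shape)  # torch.Size([8, 4096])
```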
### 🖥️ RTX 4090 Environment Configuration Descriptions
| Configuration | Technical Features |
|:--------------|:------------------|
| **Wan2GP(profile=3)** | Implementation based on [Wan2GP repository](https://github.com/deepbeepmeep/Wan2GP), using MMGP optimization technology. Profile=3 configuration is suitable for RTX 3090/4090 environments with at least 32GB RAM and 24GB VRAM, adapting to limited memory resources by sacrificing VRAM. Uses quantized models: [480P model](https://huggingface.co/DeepBeepMeep/Wan2.1/blob/main/wan2.1_image2video_480p_14B_quanto_mbf16_int8.safetensors) and [720P model](https://huggingface.co/DeepBeepMeep/Wan2.1/blob/main/wan2.1_image2video_720p_14B_quanto_mbf16_int8.safetensors) |
| **LightX2V_5** | Uses SageAttention2 to replace native attention mechanism, adopts DIT FP8+FP32 (partial sensitive layers) mixed precision computation, enables CPU offload technology, executes partial sensitive layers with FP32 precision, asynchronously offloads DIT inference process data to CPU, saves VRAM, with block-level offload granularity |
| **LightX2V_5-Distill** | Based on LightX2V_5 using 4-step distillation model(`infer_steps=4`, `enable_cfg=False`), further reducing inference steps while maintaining generation quality |
| **LightX2V_6** | Based on LightX2V_3 with CPU offload technology enabled, executes partial sensitive layers with FP32 precision, asynchronously offloads DIT inference process data to CPU, saves VRAM, with block-level offload granularity |
| **LightX2V_6-Distill** | Based on LightX2V_6 using 4-step distillation model(`infer_steps=4`, `enable_cfg=False`), further reducing inference steps while maintaining generation quality |
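The LightX2V_5/6 rows above describe block-level CPU offload: weights live in host RAM and are copied to the GPU only while their block executes. The sketch below is purely conceptual (a plain `nn.Linear` stands in for a DiT block, sizes are arbitrary); the actual implementation overlaps the copies with compute on a separate CUDA stream and uses pinned memory rather than blocking like this:

```python
import torch
import torch.nn as nn

# A stand-in "DiT" of 40 blocks kept in CPU memory; only the active block is on the GPU.
blocks = nn.ModuleList([nn.Linear(4096, 4096) for _ in range(40)])

def run_blocks(x):
    for block in blocks:
        block.to("cuda")   # copy this block's weights to VRAM just before it runs
        x = block(x)
        block.to("cpu")    # move them back so peak VRAM stays ~one block + activations
    return x

out = run_blocks(torch.randn(1, 4096, device="cuda"))
print(out.shape)
```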
---
## 📁 Configuration Files Reference
Benchmark-related configuration files and execution scripts are available at:
| Type | Link | Description |
|:-----|:-----|:------------|
| **Configuration Files** | [configs/bench](https://github.com/ModelTC/LightX2V/tree/main/configs/bench) | Contains JSON files with various optimization configurations |
| **Execution Scripts** | [scripts/bench](https://github.com/ModelTC/LightX2V/tree/main/scripts/bench) | Contains benchmark execution scripts |
---
> 💡 **Tip**: It is recommended to choose the optimization scheme that matches your hardware configuration to achieve the best performance.
# LightX2V Quick Start Guide
Welcome to LightX2V! This guide will help you quickly set up the environment and start using LightX2V for video generation.
## 📋 Table of Contents
- [System Requirements](#system-requirements)
- [Linux Environment Setup](#linux-environment-setup)
- [Docker Environment (Recommended)](#docker-environment-recommended)
- [Conda Environment Setup](#conda-environment-setup)
- [Windows Environment Setup](#windows-environment-setup)
- [Inference Usage](#inference-usage)
## 🚀 System Requirements
- **Operating System**: Linux (Ubuntu 18.04+) or Windows 10/11
- **Python**: 3.10 or higher
- **GPU**: NVIDIA GPU with CUDA support, at least 8GB VRAM
- **Memory**: 16GB or more recommended
- **Storage**: At least 50GB available space
## 🐧 Linux Environment Setup
### 🐳 Docker Environment (Recommended)
We strongly recommend using the Docker environment, which is the simplest and fastest installation method.
#### 1. Pull Image
Visit LightX2V's [Docker Hub](https://hub.docker.com/r/lightx2v/lightx2v/tags) and select a tag with the latest date, such as `25061301`:
```bash
# Pull the latest version of LightX2V image
docker pull lightx2v/lightx2v:25061301
```
#### 2. Run Container
```bash
docker run --gpus all -itd --ipc=host --name [container_name] -v [mount_settings] --entrypoint /bin/bash [image_id]
```
#### 3. Domestic Mirror Source (Optional)
For users in mainland China, if the network is unstable when pulling images, you can pull from [Duduniao](https://docker.aityp.com/r/docker.io/lightx2v/lightx2v):
```bash
docker pull swr.cn-north-4.myhuaweicloud.com/ddn-k8s/docker.io/lightx2v/lightx2v:25061301
```
### 🐍 Conda Environment Setup
If you prefer to set up the environment yourself using Conda, please follow these steps:
#### Step 1: Clone Repository
```bash
# Download project code
git clone https://github.com/ModelTC/LightX2V.git
cd LightX2V
```
#### Step 2: Create Conda Virtual Environment
```bash
# Create and activate conda environment
conda create -n lightx2v python=3.12 -y
conda activate lightx2v
```
#### Step 3: Install Dependencies

```bash
# Install basic dependencies
pip install -r requirements.txt
```
> 💡 **Note**: The Hunyuan model needs to run under transformers version 4.45.2. If you don't need to run the Hunyuan model, you can skip the transformers version restriction.
#### Step 4: Install Attention Operators

**Option A: Flash Attention 2**
```bash
git clone https://github.com/Dao-AILab/flash-attention.git --recursive
cd flash-attention && python setup.py install
```
**Option B: Flash Attention 3 (for Hopper architecture GPUs)**
```bash
cd flash-attention/hopper && python setup.py install
```
**Option C: SageAttention 2 (Recommended)**
```bash
git clone https://github.com/thu-ml/SageAttention.git
cd SageAttention && python setup.py install
```
## 🪟 Windows Environment Setup
Windows systems only support Conda environment setup. Please follow these steps:
### 🐍 Conda Environment Setup
#### Step 1: Check CUDA Version
First, confirm your GPU driver and CUDA version:
```cmd
nvidia-smi
```
Record the **CUDA Version** information in the output, which needs to be consistent in subsequent installations.
#### Step 2: Create Python Environment
```cmd
# Create new environment (Python 3.12 recommended)
conda create -n lightx2v python=3.12 -y
# Activate environment
conda activate lightx2v
```
> 💡 **Note**: Python 3.10 or higher is recommended for best compatibility.
#### Step 3: Install PyTorch Framework
**Method 1: Download Official Wheel Package (Recommended)**
1. Visit the [PyTorch Official Download Page](https://download.pytorch.org/whl/torch/)
2. Select the corresponding version wheel package, paying attention to matching:
- **Python Version**: Consistent with your environment
- **CUDA Version**: Matches your GPU driver
- **Platform**: Select Windows version
**Example (Python 3.12 + PyTorch 2.6 + CUDA 12.4):**
```cmd
# Download and install PyTorch
pip install torch-2.6.0+cu124-cp312-cp312-win_amd64.whl
# Install supporting packages
pip install torchvision==0.21.0 torchaudio==2.6.0
```
**Method 2: Direct Installation via pip**
```cmd
# CUDA 12.4 version example
pip install torch==2.6.0+cu124 torchvision==0.21.0+cu124 torchaudio==2.6.0+cu124 --index-url https://download.pytorch.org/whl/cu124
```
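Whichever method you use, a quick check from Python confirms that the installed build can see your GPU (standard PyTorch APIs; the values in the comments are just examples):

```python
import torch

print(torch.__version__)              # e.g. 2.6.0+cu124
print(torch.version.cuda)             # CUDA version the wheel was built with
print(torch.cuda.is_available())      # should be True
print(torch.cuda.get_device_name(0))  # your GPU model
```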
#### Step 4: Install Windows Version vLLM
Download the corresponding wheel package from [vllm-windows releases](https://github.com/SystemPanic/vllm-windows/releases).
**Version Matching Requirements:**
- Python version matching
- PyTorch version matching
- CUDA version matching
```cmd
# Install vLLM (please adjust according to actual filename)
pip install vllm-0.9.1+cu124-cp312-cp312-win_amd64.whl
```
#### Step 5: Install Attention Mechanism Operators
**Option A: Flash Attention 2**
```cmd
pip install flash-attn==2.7.2.post1
```
**Option B: SageAttention 2 (Strongly Recommended)**
**Download Sources:**
- [Windows Special Version 1](https://github.com/woct0rdho/SageAttention/releases)
- [Windows Special Version 2](https://github.com/sdbds/SageAttention-for-windows/releases)
```cmd
# Install SageAttention (please adjust according to actual filename)
pip install sageattention-2.1.1+cu126torch2.6.0-cp312-cp312-win_amd64.whl
```
> ⚠️ **Note**: SageAttention's CUDA version doesn't need to be strictly aligned, but Python and PyTorch versions must match.
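A short smoke test from Python can confirm the install; the `sageattn` call below follows the signature documented in the SageAttention repository and is an assumption — adjust it if your build differs:

```python
import torch
from sageattention import sageattn

# Tiny dummy attention call: (batch, heads, seq_len, head_dim) in FP16 on the GPU.
q = torch.randn(1, 8, 128, 64, dtype=torch.float16, device="cuda")
k = torch.randn(1, 8, 128, 64, dtype=torch.float16, device="cuda")
v = torch.randn(1, 8, 128, 64, dtype=torch.float16, device="cuda")

out = sageattn(q, k, v, tensor_layout="HND", is_causal=False)
print(out.shape)  # should match q.shape
```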
#### Step 6: Clone Repository
```cmd
# Clone project code
git clone https://github.com/ModelTC/LightX2V.git
cd LightX2V
# Install Windows-specific dependencies
pip install -r requirements_win.txt
```
## 🎯 Inference Usage
### 📥 Model Preparation
Before starting inference, you need to download the model files in advance. We recommend:
- **Download Source**: Download models from [LightX2V Official Hugging Face](https://huggingface.co/lightx2v/) or other open-source model repositories
- **Storage Location**: It's recommended to store models on SSD disks for better read performance
- **Available Models**: Including Wan2.1-I2V, Wan2.1-T2V, and other models supporting different resolutions and functionalities
### 📁 Configuration Files and Scripts
The configuration files used for inference are available [here](https://github.com/ModelTC/LightX2V/tree/main/configs), and scripts are available [here](https://github.com/ModelTC/LightX2V/tree/main/scripts).
You need to configure the downloaded model path in the run script. In addition to the input arguments in the script, there are also some necessary parameters in the configuration file specified by `--config_json`. You can modify them as needed.
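As a minimal sketch of that workflow, you could load and tweak the config JSON before launching the script; the path and keys below (`wan_t2v.json`, `infer_steps`, `seed`, `enable_cfg`) are taken from examples elsewhere in these docs and may differ in your checkout:

```python
import json

cfg_path = "configs/wan_t2v.json"  # assumed location; use the file your script passes to --config_json
with open(cfg_path) as f:
    cfg = json.load(f)

cfg["infer_steps"] = 40   # keys shown in the benchmark/config examples of these docs
cfg["seed"] = 42
cfg["enable_cfg"] = True

with open(cfg_path, "w") as f:
    json.dump(cfg, f, indent=2, ensure_ascii=False)
```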
### 🚀 Start Inference
#### Linux Environment
```bash
# Run after modifying the path in the script
bash scripts/wan/run_wan_t2v.sh
```

#### Windows Environment
```cmd
# Use Windows batch script
scripts\win\run_wan_t2v.bat
```
## 📞 Get Help
If you encounter problems during installation or usage, please:
1. Search for related issues in [GitHub Issues](https://github.com/ModelTC/LightX2V/issues)
2. Submit a new Issue describing your problem
---
🎉 **Congratulations!** You have successfully set up the LightX2V environment and can now start enjoying video generation!
# Variable Resolution Inference
## Overview

Variable resolution inference is a technical strategy for optimizing the denoising process. It improves computational efficiency while maintaining generation quality by using different resolutions at different stages of the denoising process. The core idea is to use lower resolution for coarse denoising in the early stages and switch to normal resolution for fine processing in the later stages.
## Technical Principles

### Multi-stage Denoising Strategy

Variable resolution inference is based on the following observations:
- **Early-stage denoising**: Mainly handles coarse noise and overall structure, requiring less detailed information
- **Late-stage denoising**: Focuses on detail optimization and high-frequency information recovery, requiring complete resolution information
### Resolution Switching Mechanism

1. **Low-resolution stage** (early stage)
   - Downsample the input to a lower resolution (e.g., 0.75x of original size)
   - Execute initial denoising steps
   - Quickly remove most noise and establish basic structure

2. **Normal resolution stage** (late stage)
   - Upsample the denoising result from the first stage back to original resolution
   - Continue executing remaining denoising steps
   - Restore detailed information and complete fine processing
### U-shaped Resolution Strategy
If resolution is reduced at the very beginning of the denoising steps, it may cause significant differences between the final generated video and the video generated through normal inference. Therefore, a U-shaped resolution strategy can be adopted, where the original resolution is maintained for the first few steps, then resolution is reduced for inference.
## Usage

The config files for variable resolution inference are located [here](https://github.com/ModelTC/LightX2V/tree/main/configs/changing_resolution).

You can test variable resolution inference by pointing `--config_json` to the desired config file.

You can refer to the scripts [here](https://github.com/ModelTC/LightX2V/blob/main/scripts/changing_resolution) to run.
### Example 1:
```
{
"infer_steps": 50,
"changing_resolution": true,
"resolution_rate": [0.75],
"changing_resolution_steps": [25]
}
```
This means a total of 50 steps, with resolution at 0.75x original resolution from step 0 to 25, and original resolution from step 26 to the final step.
### Example 2:
```
{
"infer_steps": 50,
"changing_resolution": true,
"resolution_rate": [1.0, 0.75],
"changing_resolution_steps": [10, 35]
}
```
This means a total of 50 steps, with original resolution from step 0 to 10, 0.75x original resolution from step 11 to 35, and original resolution from step 36 to the final step.
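To make the two examples concrete, here is a small sketch of one way a denoising step index could be mapped to a resolution scale under this config schema. It mirrors the descriptions above and is not the actual LightX2V scheduling code:

```python
def resolution_for_step(step, resolution_rate, changing_resolution_steps):
    """Return the resolution scale used at a given denoising step."""
    # Steps up to and including each boundary use the corresponding rate;
    # anything past the last boundary runs at the original resolution.
    for rate, boundary in zip(resolution_rate, changing_resolution_steps):
        if step <= boundary:
            return rate
    return 1.0

# Example 2: 50 steps, original resolution for steps 0-10, 0.75x for 11-35, original afterwards.
schedule = [resolution_for_step(s, [1.0, 0.75], [10, 35]) for s in range(50)]
print(schedule[0], schedule[20], schedule[40])  # 1.0 0.75 1.0
```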
# Gradio 部署指南

## 📖 概述

Lightx2v 是一个轻量级的视频推理和生成引擎,提供基于 Gradio 的 Web 界面,支持图像到视频(Image-to-Video)和文本到视频(Text-to-Video)两种生成模式。
## 📁 文件结构
```
LightX2V/app/
├── gradio_demo.py # 英文界面演示
├── gradio_demo_zh.py # 中文界面演示
├── run_gradio.sh # 启动脚本
├── README.md # 说明文档
├── saved_videos/ # 生成视频保存目录
└── inference_logs.log # 推理日志
```
本项目包含两个主要演示文件:

- `gradio_demo.py` - 英文界面版本
...@@ -12,27 +24,17 @@ Lightx2v 是一个轻量级的视频推理和生成引擎,提供了基于 Grad

### 环境要求

按照[快速开始文档](../getting_started/quickstart.md)安装环境
#### 推荐优化库配置

- [Flash attention](https://github.com/Dao-AILab/flash-attention)
- [Sage attention](https://github.com/thu-ml/SageAttention)
- [vllm-kernel](https://github.com/vllm-project/vllm)
- [sglang-kernel](https://github.com/sgl-project/sglang/tree/main/sgl-kernel)
- [q8-kernel](https://github.com/KONAKONA666/q8_kernels)(仅支持 ADA 架构的 GPU)
可根据需要,按照各算子的项目主页教程进行安装
### 🤖 支持的模型

...@@ -53,19 +55,22 @@ pip install gradio
| ✅ [Wan2.1-T2V-14B-Lightx2v](https://huggingface.co/lightx2v/Wan2.1-T2V-14B-Lightx2v) | 14B | 标准版本 | 平衡速度和质量 |
| ✅ [Wan2.1-T2V-14B-StepDistill-CfgDistill-Lightx2v](https://huggingface.co/lightx2v/Wan2.1-T2V-14B-StepDistill-CfgDistill-Lightx2v) | 14B | 蒸馏优化版 | 高质量+快速推理 |

**💡 模型选择建议**:

- **首次使用**: 建议选择蒸馏版本 (`wan2.1_distill`)
- **追求质量**: 选择720p分辨率或14B参数模型
- **追求速度**: 选择480p分辨率或1.3B参数模型,优先使用蒸馏版本
- **资源受限**: 优先选择蒸馏版本和较低分辨率
- **实时应用**: 强烈推荐使用蒸馏模型 (`wan2.1_distill`)
**🎯 模型类别说明**:
- **`wan2.1`**: 标准模型,提供最佳的视频生成质量,适合对质量要求极高的场景
- **`wan2.1_distill`**: 蒸馏模型,通过知识蒸馏技术优化,推理速度显著提升,在保持良好质量的同时大幅减少计算时间,适合大多数应用场景
### 启动方式

#### 方式一:使用启动脚本(推荐)

**Linux 环境:**

```bash
# 1. 编辑启动脚本,配置相关路径
cd app/
...@@ -82,41 +87,84 @@ vim run_gradio.sh

# 2. 运行启动脚本
bash run_gradio.sh

# 3. 或使用参数启动(推荐使用蒸馏模型)
bash run_gradio.sh --task i2v --lang zh --model_cls wan2.1 --model_size 14b --port 8032
bash run_gradio.sh --task t2v --lang zh --model_cls wan2.1 --model_size 1.3b --port 8032
bash run_gradio.sh --task i2v --lang zh --model_cls wan2.1_distill --model_size 14b --port 8032
bash run_gradio.sh --task t2v --lang zh --model_cls wan2.1_distill --model_size 1.3b --port 8032
```
**Windows 环境:**
```cmd
# 1. 编辑启动脚本,配置相关路径
cd app\
notepad run_gradio_win.bat
# 需要修改的配置项:
# - lightx2v_path: Lightx2v项目根目录路径
# - i2v_model_path: 图像到视频模型路径
# - t2v_model_path: 文本到视频模型路径
# 💾 重要提示:建议将模型路径指向SSD存储位置
# 例如:D:\models\ 或 E:\models\
# 2. 运行启动脚本
run_gradio_win.bat
# 3. 或使用参数启动(推荐使用蒸馏模型)
run_gradio_win.bat --task i2v --lang zh --model_cls wan2.1 --model_size 14b --port 8032
run_gradio_win.bat --task t2v --lang zh --model_cls wan2.1 --model_size 1.3b --port 8032
run_gradio_win.bat --task i2v --lang zh --model_cls wan2.1_distill --model_size 14b --port 8032
run_gradio_win.bat --task t2v --lang zh --model_cls wan2.1_distill --model_size 1.3b --port 8032
``` ```
#### 方式二:直接命令行启动

**Linux 环境:**

**图像到视频模式:**

```bash
python gradio_demo_zh.py \
--model_path /path/to/Wan2.1-I2V-14B-480P-Lightx2v \
--model_cls wan2.1 \
--model_size 14b \
--task i2v \
--server_name 0.0.0.0 \
--server_port 7862
```

**英文界面版本:**

```bash
python gradio_demo.py \
--model_path /path/to/Wan2.1-T2V-14B-StepDistill-CfgDistill-Lightx2v \
--model_cls wan2.1_distill \
--model_size 14b \
--task t2v \
--server_name 0.0.0.0 \
--server_port 7862
```
**Windows 环境:**
**图像到视频模式:**
```cmd
python gradio_demo_zh.py ^
--model_path D:\models\Wan2.1-I2V-14B-480P-Lightx2v ^
--model_cls wan2.1 ^
--model_size 14b ^
--task i2v ^
--server_name 127.0.0.1 ^
--server_port 7862
```
**英文界面版本:**

```cmd
python gradio_demo.py ^
--model_path D:\models\Wan2.1-T2V-14B-StepDistill-CfgDistill-Lightx2v ^
--model_cls wan2.1_distill ^
--model_size 14b ^
--task t2v ^
--server_name 127.0.0.1 ^
--server_port 7862
```
...@@ -125,8 +173,8 @@ python gradio_demo.py \

| 参数 | 类型 | 必需 | 默认值 | 说明 |
|------|------|------|--------|------|
| `--model_path` | str | ✅ | - | 模型文件夹路径 |
| `--model_cls` | str | ❌ | wan2.1 | 模型类别:`wan2.1`(标准模型)或 `wan2.1_distill`(蒸馏模型,推理更快) |
| `--model_size` | str | ✅ | - | 模型大小:`14b` 或 `1.3b` |
| `--task` | str | ✅ | - | 任务类型:`i2v`(图像到视频)或 `t2v`(文本到视频) |
| `--server_port` | int | ❌ | 7862 | 服务器端口 |
| `--server_name` | str | ❌ | 0.0.0.0 | 服务器IP地址 |
...@@ -178,7 +226,6 @@ python gradio_demo.py \

启用"自动配置推理选项"后,系统会根据您的硬件配置自动优化参数:

### GPU内存规则

- **80GB+**: 默认配置,无需优化
- **48GB**: 启用CPU卸载,卸载比例50%
...@@ -201,23 +248,11 @@ python gradio_demo.py \

**💡 针对显存不足或性能受限的设备**:

- **🎯 模型选择**: 优先使用蒸馏版本模型 (`wan2.1_distill`)
- **⚡ 推理步数**: 建议设置为 4 步
- **🔧 CFG设置**: 建议关闭CFG选项以提升生成速度
- **🔄 自动配置**: 启用"自动配置推理选项"
- **💾 存储优化**: 确保模型存储在SSD上以获得最佳加载性能
## 🎨 界面说明

...@@ -244,12 +279,12 @@ lightx2v/app/
   - 降低分辨率
   - 启用量化选项

2. **系统内存不足**
   - 启用CPU卸载
   - 启用延迟加载选项
   - 启用量化选项

3. **生成速度慢**
   - 减少推理步数
   - 启用自动配置
   - 使用轻量级模型
...@@ -257,13 +292,13 @@ lightx2v/app/
   - 使用量化算子
   - 💾 **检查模型是否存放在SSD上**

4. **模型加载缓慢**
   - 💾 **将模型迁移到SSD存储**
   - 启用延迟加载选项
   - 检查磁盘I/O性能
   - 考虑使用NVMe SSD

5. **视频质量不佳**
   - 增加推理步数
   - 提高CFG缩放因子
   - 使用14B模型
...@@ -282,8 +317,6 @@ nvidia-smi
htop
```

欢迎提交Issue和Pull Request来改进这个项目!

**注意**: 使用本工具生成的视频内容请遵守相关法律法规,不得用于非法用途。
# Windows 本地部署指南

## 📖 概述

本文档将详细指导您在Windows环境下完成LightX2V的本地部署配置,包括批处理文件推理、Gradio Web界面推理等多种使用方式。

## 🚀 快速开始

### 环境要求

#### 硬件要求

- **GPU**: NVIDIA GPU,建议 8GB+ VRAM
- **内存**: 建议 16GB+ RAM
- **存储**: 强烈建议使用 SSD 固态硬盘,机械硬盘会导致模型加载缓慢

## 🎯 使用方式

### 方式一:使用批处理文件推理

参考[快速开始文档](../getting_started/quickstart.md)安装环境,并使用[批处理文件](https://github.com/ModelTC/LightX2V/tree/main/scripts/win)运行。
### 方式二:使用Gradio Web界面推理

#### 手动配置Gradio

参考[快速开始文档](../getting_started/quickstart.md)安装环境,参考[Gradio部署指南](./deploy_gradio.md)

#### 一键启动Gradio(推荐)

**📦 下载软件包**

- [百度云](https://pan.baidu.com/s/1ef3hEXyIuO0z6z9MoXe4nQ?pwd=7g4f)
- [夸克网盘](https://pan.quark.cn/s/36a0cdbde7d9)

**📁 目录结构**

解压后,确保目录结构如下:
```
├── env/                                 # LightX2V 环境目录
├── LightX2V/                            # LightX2V 项目目录
├── start_lightx2v.bat                   # 一键启动脚本
├── lightx2v_config.txt                  # 配置文件
├── LightX2V使用说明.txt                  # LightX2V使用说明
└── models/                              # 模型存放目录
    ├── 说明.txt                                              # 模型说明文档
    ├── Wan2.1-I2V-14B-480P-Lightx2v/                         # 图像转视频模型(480P)
    ├── Wan2.1-I2V-14B-720P-Lightx2v/                         # 图像转视频模型(720P)
    ├── Wan2.1-I2V-14B-480P-StepDistill-CfgDistil-Lightx2v/   # 图像转视频模型(4步蒸馏,480P)
    ├── Wan2.1-I2V-14B-720P-StepDistill-CfgDistil-Lightx2v/   # 图像转视频模型(4步蒸馏,720P)
    ├── Wan2.1-T2V-1.3B-Lightx2v/                             # 文本转视频模型(1.3B参数)
    ├── Wan2.1-T2V-14B-Lightx2v/                              # 文本转视频模型(14B参数)
    └── Wan2.1-T2V-14B-StepDistill-CfgDistill-Lightx2v/       # 文本转视频模型(4步蒸馏)
```
**📋 配置参数**

编辑 `lightx2v_config.txt` 文件,根据需要修改以下参数:

```ini
# 任务类型 (i2v: 图像转视频, t2v: 文本转视频)
task=i2v

# 界面语言 (zh: 中文, en: 英文)
lang=zh

# 服务器端口
port=8032

# GPU设备ID (0, 1, 2...)
gpu=0

# 模型大小 (14b: 14B参数模型, 1.3b: 1.3B参数模型)
model_size=14b

# 模型类别 (wan2.1: 标准模型, wan2.1_distill: 蒸馏模型)
model_cls=wan2.1
```

**⚠️ 重要提示**: 如果使用蒸馏模型(模型名称包含StepDistill-CfgDistil字段),请将`model_cls`设置为`wan2.1_distill`
**🚀 启动服务**

双击运行 `start_lightx2v.bat` 文件,脚本将:

1. 自动读取配置文件
2. 验证模型路径和文件完整性
3. 启动 Gradio Web 界面
4. 自动打开浏览器访问服务

**💡 使用建议**: 打开Gradio Web页面后,建议勾选"自动配置推理选项",系统会根据您的机器自动选择合适的优化配置。重新选择分辨率后,也需要重新勾选"自动配置推理选项"。

**⚠️ 重要提示**: 首次运行时会自动解压环境文件 `env.zip`,此过程需要几分钟时间,请耐心等待。后续启动无需重复此步骤。您也可以手动解压 `env.zip` 文件到当前目录以节省首次启动时间。

### 方式三:使用ComfyUI推理

本节将指导您下载与使用便携版的Lightx2v-ComfyUI环境,免去手动配置环境的步骤,适合想在Windows系统下快速体验Lightx2v加速视频生成的用户。
#### 下载Windows便携环境

- [百度网盘下载](https://pan.baidu.com/s/1FVlicTXjmXJA1tAVvNCrBw?pwd=wfid),提取码:wfid

便携环境中已经打包了所有Python运行相关的依赖,也包括ComfyUI和LightX2V的代码及其相关依赖,下载后解压即可使用。

解压后对应的文件目录说明如下:

```shell
lightx2v_env
├──📂 ComfyUI                    # ComfyUI代码
├──📂 portable_python312_embed   # 独立的Python环境
└── run_nvidia_gpu.bat           # Windows启动脚本(双击启动)
```

#### 启动ComfyUI
直接双击 run_nvidia_gpu.bat 文件,系统会打开一个 Command Prompt 窗口并运行程序。第一次启动时间会比较久,请耐心等待;启动完成后会自动打开浏览器并显示 ComfyUI 的前端界面。

![i2v示例工作流](../../../../assets/figs/portabl_windows/pic1.png)

LightX2V-ComfyUI 使用的插件是 [ComfyUI-Lightx2vWrapper](https://github.com/ModelTC/ComfyUI-Lightx2vWrapper),示例工作流可以从该项目中获取。

#### 已测试显卡(offload模式)

- 测试模型:`Wan2.1-I2V-14B-480P`

| 显卡型号 | 任务类型 | 显存容量 | 实际最大显存占用 | 实际最大内存占用 |
|:----------|:-----------|:-----------|:--------|:----------|
| 3090Ti | I2V | 24G | 6.1G | 7.1G |
| 3080Ti | I2V | 12G | 6.1G | 7.1G |
| 3060Ti | I2V | 8G | 6.1G | 7.1G |

#### 环境打包和使用参考
- [ComfyUI](https://github.com/comfyanonymous/ComfyUI)
- [Portable-Windows-ComfyUI-Docs](https://docs.comfy.org/zh-CN/installation/comfyui_portable_windows#portable-%E5%8F%8A%E8%87%AA%E9%83%A8%E7%BD%B2)
# 低延迟场景部署

在低延迟的场景,我们会追求更快的速度,忽略显存和内存开销等问题。我们提供两套方案:
## 💡 方案一:步数蒸馏模型的推理
该方案可以参考[步数蒸馏文档](https://lightx2v-zhcn.readthedocs.io/zh-cn/latest/method_tutorials/step_distill.html)
🧠 **步数蒸馏**是非常直接的视频生成模型的加速推理方案。从50步蒸馏到4步,耗时将缩短到原来的4/50。同时,该方案下,仍然可以和以下方案结合使用:
1. [高效注意力机制方案](https://lightx2v-zhcn.readthedocs.io/zh-cn/latest/method_tutorials/attention.html)
2. [模型量化](https://lightx2v-zhcn.readthedocs.io/zh-cn/latest/method_tutorials/quantization.html)
## 💡 方案二:非步数蒸馏模型的推理
步数蒸馏需要比较大的训练资源,以及步数蒸馏后的模型,可能会出现视频动态范围变差的问题。
对于非步数蒸馏的原始模型,我们可以使用以下方案或者多种方案结合的方式进行加速:
1. [并行推理](https://lightx2v-zhcn.readthedocs.io/zh-cn/latest/method_tutorials/parallel.html) 进行多卡并行加速。
2. [特征缓存](https://lightx2v-zhcn.readthedocs.io/zh-cn/latest/method_tutorials/cache.html) 降低实际推理步数。
3. [高效注意力机制方案](https://lightx2v-zhcn.readthedocs.io/zh-cn/latest/method_tutorials/attention.html) 加速 Attention 的推理。
4. [模型量化](https://lightx2v-zhcn.readthedocs.io/zh-cn/latest/method_tutorials/quantization.html) 加速 Linear 层的推理。
5. [变分辨率推理](https://lightx2v-zhcn.readthedocs.io/zh-cn/latest/method_tutorials/changing_resolution.html) 降低中间推理步的分辨率。
## ⚠️ 注意
有一部分的加速方案之间目前无法结合使用,我们目前正在致力于解决这一问题。
如有问题,欢迎在 [🐛 GitHub Issues](https://github.com/ModelTC/lightx2v/issues) 中进行错误报告或者功能请求
# 🚀 基准测试

> 本文档展示了LightX2V在不同硬件环境下的性能测试结果,包括H200和RTX 4090平台的详细对比数据。

---

## 🖥️ H200 环境 (~140GB显存)

### 📋 软件环境配置

| 组件 | 版本 |
|:-----|:-----|
| **Python** | 3.11 |
| **PyTorch** | 2.7.1+cu128 |
| **SageAttention** | 2.2.0 |
| **vLLM** | 0.9.2 |
| **sgl-kernel** | 0.1.8 |
---

### 🎬 480P 5s视频测试

**测试配置:**
- **模型**: [Wan2.1-I2V-14B-480P-Lightx2v](https://huggingface.co/lightx2v/Wan2.1-I2V-14B-480P-Lightx2v)
- **参数**: `infer_steps=40`, `seed=42`, `enable_cfg=True`

#### 📊 性能对比

| 配置 | 推理时间(s) | GPU显存占用(GB) | 加速比 | 视频效果 |
|:-----|:----------:|:---------------:|:------:|:--------:|
...@@ -29,14 +36,15 @@
| **LightX2V_3-Distill** | 14 | 35 | **🏆 20.85x** | <video src="https://github.com/user-attachments/assets/b4dc403c-919d-4ba1-b29f-ef53640c0334" width="200px"></video> |
| **LightX2V_4** | 107 | 35 | **3.41x** | <video src="https://github.com/user-attachments/assets/49cd2760-4be2-432c-bf4e-01af9a1303dd" width="200px"></video> |
---

### 🎬 720P 5s视频测试

**测试配置:**
- **模型**: [Wan2.1-I2V-14B-720P-Lightx2v](https://huggingface.co/lightx2v/Wan2.1-I2V-14B-720P-Lightx2v)
- **参数**: `infer_steps=40`, `seed=1234`, `enable_cfg=True`

#### 📊 性能对比表

| 配置 | 推理时间(s) | GPU显存占用(GB) | 加速比 | 视频效果 |
|:-----|:----------:|:---------------:|:------:|:--------:|
...@@ -50,27 +58,92 @@

---
## 🖥️ RTX 4090 环境 (~24GB显存)
### 📋 软件环境配置
| 组件 | 版本 |
|:-----|:-----|
| **Python** | 3.9.16 |
| **PyTorch** | 2.5.1+cu124 |
| **SageAttention** | 2.1.0 |
| **vLLM** | 0.6.6 |
| **sgl-kernel** | 0.0.5 |
| **q8-kernels** | 0.0.0 |
---
### 🎬 480P 5s视频测试
**测试配置:**
- **模型**: [Wan2.1-I2V-14B-480P-Lightx2v](https://huggingface.co/lightx2v/Wan2.1-I2V-14B-480P-Lightx2v)
- **参数**: `infer_steps=40`, `seed=42`, `enable_cfg=True`
#### 📊 性能对比表
| 配置 | 推理时间(s) | GPU显存占用(GB) | 加速比 | 视频效果 |
|:-----|:----------:|:---------------:|:------:|:--------:|
| **Wan2GP(profile=3)** | 779 | 20 | **1.0x** | <video src="https://github.com/user-attachments/assets/ba548a48-04f8-4616-a55a-ad7aed07d438" width="200px"></video> |
| **LightX2V_5** | 738 | 16 | **1.05x** | <video src="https://github.com/user-attachments/assets/ce72ab7d-50a7-4467-ac8c-a6ed1b3827a7" width="200px"></video> |
| **LightX2V_5-Distill** | 68 | 16 | **11.45x** | <video src="https://github.com/user-attachments/assets/5df4b8a7-3162-47f8-a359-e22fbb4d1836" width="200px"></video> |
| **LightX2V_6** | 630 | 12 | **1.24x** | <video src="https://github.com/user-attachments/assets/d13cd939-363b-4f8b-80d9-d3a145c46676" width="200px"></video> |
| **LightX2V_6-Distill** | 63 | 12 | **🏆 12.36x** | <video src="https://github.com/user-attachments/assets/f372bce4-3c2f-411d-aa6b-c4daeb467d90" width="200px"></video> |
---
### 🎬 720P 5s视频测试
**测试配置:**
- **模型**: [Wan2.1-I2V-14B-720P-Lightx2v](https://huggingface.co/lightx2v/Wan2.1-I2V-14B-720P-Lightx2v)
- **参数**: `infer_steps=40`, `seed=1234`, `enable_cfg=True`
#### 📊 性能对比表
| 配置 | 推理时间(s) | GPU显存占用(GB) | 加速比 | 视频效果 |
|:-----|:----------:|:---------------:|:------:|:--------:|
| **Wan2GP(profile=3)** | -- | OOM | -- | <video src="--" width="200px"></video> |
| **LightX2V_5** | 2473 | 23 | -- | <video src="https://github.com/user-attachments/assets/0e83b146-3297-4c63-831c-8462cc657cad" width="200px"></video> |
| **LightX2V_5-Distill** | 183 | 23 | -- | <video src="https://github.com/user-attachments/assets/976d0af0-244c-4abe-b2cb-01f68ad69d3c" width="200px"></video> |
| **LightX2V_6** | 2169 | 18 | -- | <video src="https://github.com/user-attachments/assets/cf9edf82-53e1-46af-a000-79a88af8ad4a" width="200px"></video> |
| **LightX2V_6-Distill** | 171 | 18 | -- | <video src="https://github.com/user-attachments/assets/e3064b03-6cd6-4c82-9e31-ab28b3165798" width="200px"></video> |
---
## 📖 配置说明
### 🖥️ H200 环境配置说明
| 配置 | 技术特点 |
|:-----|:---------|
| **Wan2.1 Official** | 基于[Wan2.1官方仓库](https://github.com/Wan-Video/Wan2.1)的原始实现 |
| **FastVideo** | 基于[FastVideo官方仓库](https://github.com/hao-ai-lab/FastVideo),使用SageAttention2后端优化 |
| **LightX2V_1** | 使用SageAttention2替换原生注意力机制,采用DIT BF16+FP32(部分敏感层)混合精度计算,在保持精度的同时提升计算效率 |
| **LightX2V_2** | 统一使用BF16精度计算,进一步减少显存占用和计算开销,同时保持生成质量 |
| **LightX2V_3** | 引入FP8量化技术显著减少计算精度要求,结合Tiling VAE技术优化显存使用 |
| **LightX2V_3-Distill** | 在LightX2V_3基础上使用4步蒸馏模型(`infer_steps=4`, `enable_cfg=False`),进一步减少推理步数并保持生成质量 |
| **LightX2V_4** | 在LightX2V_3基础上加入TeaCache(teacache_thresh=0.2)缓存复用技术,通过智能跳过冗余计算实现加速 |
### 🖥️ RTX 4090 环境配置说明
| 配置 | 技术特点 |
|:-----|:---------|
| **Wan2GP(profile=3)** | 基于[Wan2GP仓库](https://github.com/deepbeepmeep/Wan2GP)实现,使用MMGP优化技术。profile=3配置适用于至少32GB内存和24GB显存的RTX 3090/4090环境,通过牺牲显存来适应有限的内存资源。使用量化模型:[480P模型](https://huggingface.co/DeepBeepMeep/Wan2.1/blob/main/wan2.1_image2video_480p_14B_quanto_mbf16_int8.safetensors)[720P模型](https://huggingface.co/DeepBeepMeep/Wan2.1/blob/main/wan2.1_image2video_720p_14B_quanto_mbf16_int8.safetensors) |
| **LightX2V_5** | 使用SageAttention2替换原生注意力机制,采用DIT FP8+FP32(部分敏感层)混合精度计算,启用CPU offload技术,将部分敏感层执行FP32精度计算,将DIT推理过程中异步数据卸载到CPU上,节省显存,offload粒度为block级别 |
| **LightX2V_5-Distill** | 在LightX2V_5基础上使用4步蒸馏模型(`infer_steps=4`, `enable_cfg=False`),进一步减少推理步数并保持生成质量 |
| **LightX2V_6** | 在LightX2V_3基础上启用CPU offload技术,将部分敏感层执行FP32精度计算,将DIT推理过程中异步数据卸载到CPU上,节省显存,offload粒度为block级别 |
| **LightX2V_6-Distill** | 在LightX2V_6基础上使用4步蒸馏模型(`infer_steps=4`, `enable_cfg=False`),进一步减少推理步数并保持生成质量 |
---
## 📁 配置文件参考

基准测试相关的配置文件和运行脚本可在以下位置获取:

| 类型 | 链接 | 说明 |
|:-----|:-----|:-----|
| **配置文件** | [configs/bench](https://github.com/ModelTC/LightX2V/tree/main/configs/bench) | 包含各种优化配置的JSON文件 |
| **运行脚本** | [scripts/bench](https://github.com/ModelTC/LightX2V/tree/main/scripts/bench) | 包含基准测试的执行脚本 |
---

> 💡 **提示**: 建议根据您的硬件配置选择合适的优化方案,以获得最佳的性能表现。
# LightX2V 快速入门指南

欢迎使用 LightX2V!本指南将帮助您快速搭建环境并开始使用 LightX2V 进行视频生成。

## 📋 目录

- [系统要求](#系统要求)
- [Linux 系统环境搭建](#linux-系统环境搭建)
- [Docker 环境(推荐)](#docker-环境推荐)
- [Conda 环境搭建](#conda-环境搭建)
- [Windows 系统环境搭建](#windows-系统环境搭建)
- [推理使用](#推理使用)
## 🚀 系统要求
- **操作系统**: Linux (Ubuntu 18.04+) 或 Windows 10/11
- **Python**: 3.10 或更高版本
- **GPU**: NVIDIA GPU,支持 CUDA,至少 8GB 显存
- **内存**: 建议 16GB 或更多
- **存储**: 至少 50GB 可用空间
## 🐧 Linux 系统环境搭建
### 🐳 Docker 环境(推荐)
我们强烈推荐使用 Docker 环境,这是最简单快捷的安装方式。
#### 1. 拉取镜像
访问 LightX2V 的 [Docker Hub](https://hub.docker.com/r/lightx2v/lightx2v/tags),选择一个最新日期的 tag,比如 `25061301`
```bash
# 拉取最新版本的 LightX2V 镜像
docker pull lightx2v/lightx2v:25061301
```

#### 2. 运行容器
```bash
docker run --gpus all -itd --ipc=host --name [容器名] -v [挂载设置] --entrypoint /bin/bash [镜像id]
```
#### 3. 国内镜像源(可选)
对于中国大陆地区,如果拉取镜像时网络不稳定,可以从[渡渡鸟](https://docker.aityp.com/r/docker.io/lightx2v/lightx2v)上拉取:
```bash
docker pull swr.cn-north-4.myhuaweicloud.com/ddn-k8s/docker.io/lightx2v/lightx2v:25061301
```
### 🐍 Conda 环境搭建
如果您希望使用 Conda 自行搭建环境,请按照以下步骤操作:
#### 步骤 1: 克隆项目
```bash
# 下载项目代码
git clone https://github.com/ModelTC/LightX2V.git
cd LightX2V
```
#### 步骤 2: 创建 conda 虚拟环境
```bash
# 创建并激活 conda 环境
conda create -n lightx2v python=3.12 -y
conda activate lightx2v
```
#### 步骤 3: 安装依赖

```bash
# 安装基础依赖
pip install -r requirements.txt
```
> 💡 **提示**: 混元模型需要在 4.45.2 版本的 transformers 下运行,如果您不需要运行混元模型,可以跳过 transformers 版本限制。
#### 步骤 4: 安装注意力机制算子

**选项 A: Flash Attention 2**

```bash
git clone https://github.com/Dao-AILab/flash-attention.git --recursive
cd flash-attention && python setup.py install
```
**选项 B: Flash Attention 3(用于 Hopper 架构显卡)**
```bash
cd flash-attention/hopper && python setup.py install
```

**选项 C: SageAttention 2(推荐)**
```bash
git clone https://github.com/thu-ml/SageAttention.git
cd SageAttention && python setup.py install
```
## 🪟 Windows 系统环境搭建
Windows 系统仅支持 Conda 环境搭建方式,请按照以下步骤操作:
### 🐍 Conda 环境搭建
#### 步骤 1: 检查 CUDA 版本
首先确认您的 GPU 驱动和 CUDA 版本:
```cmd
nvidia-smi
```
记录输出中的 **CUDA Version** 信息,后续安装时需要保持版本一致。
#### 步骤 2: 创建 Python 环境
```cmd
# 创建新环境(推荐 Python 3.12)
conda create -n lightx2v python=3.12 -y
# 激活环境
conda activate lightx2v
```
> 💡 **提示**: 建议使用 Python 3.10 或更高版本以获得最佳兼容性。
#### 步骤 3: 安装 PyTorch 框架
**方法一:下载官方 wheel 包(推荐)**
1. 访问 [PyTorch 官方下载页面](https://download.pytorch.org/whl/torch/)
2. 选择对应版本的 wheel 包,注意匹配:
- **Python 版本**: 与您的环境一致
- **CUDA 版本**: 与您的 GPU 驱动匹配
- **平台**: 选择 Windows 版本
**示例(Python 3.12 + PyTorch 2.6 + CUDA 12.4):**
```cmd
# 下载并安装 PyTorch
pip install torch-2.6.0+cu124-cp312-cp312-win_amd64.whl
# 安装配套包
pip install torchvision==0.21.0 torchaudio==2.6.0
```
**方法二:使用 pip 直接安装**
```cmd
# CUDA 12.4 版本示例
pip install torch==2.6.0+cu124 torchvision==0.21.0+cu124 torchaudio==2.6.0+cu124 --index-url https://download.pytorch.org/whl/cu124
```
#### 步骤 4: 安装 Windows 版 vLLM
[vllm-windows releases](https://github.com/SystemPanic/vllm-windows/releases) 下载对应的 wheel 包。
**版本匹配要求:**
- Python 版本匹配
- PyTorch 版本匹配
- CUDA 版本匹配
```cmd
# 安装 vLLM(请根据实际文件名调整)
pip install vllm-0.9.1+cu124-cp312-cp312-win_amd64.whl
```
#### 步骤 5: 安装注意力机制算子
**选项 A: Flash Attention 2**
```cmd
pip install flash-attn==2.7.2.post1
```
**选项 B: SageAttention 2(强烈推荐)**
**下载源:**
- [Windows 专用版本 1](https://github.com/woct0rdho/SageAttention/releases)
- [Windows 专用版本 2](https://github.com/sdbds/SageAttention-for-windows/releases)
```cmd
# 安装 SageAttention(请根据实际文件名调整)
pip install sageattention-2.1.1+cu126torch2.6.0-cp312-cp312-win_amd64.whl
```
> ⚠️ **注意**: SageAttention 的 CUDA 版本可以不严格对齐,但 Python 和 PyTorch 版本必须匹配。
#### 步骤 6: 克隆项目
```cmd
# 克隆项目代码
git clone https://github.com/ModelTC/LightX2V.git
cd LightX2V
# 安装 Windows 专用依赖
pip install -r requirements_win.txt
```
## 🎯 推理使用
### 📥 模型准备
在开始推理之前,您需要提前下载好模型文件。我们推荐:
- **下载源**: 从 [LightX2V 官方 Hugging Face](https://huggingface.co/lightx2v/)或者其他开源模型库下载模型
- **存储位置**: 建议将模型存储在 SSD 磁盘上以获得更好的读取性能
- **可用模型**: 包括 Wan2.1-I2V、Wan2.1-T2V 等多种模型,支持不同分辨率和功能
### 📁 配置文件与脚本
推理会用到的配置文件都在[这里](https://github.com/ModelTC/LightX2V/tree/main/configs),脚本都在[这里](https://github.com/ModelTC/LightX2V/tree/main/scripts)
需要将下载的模型路径配置到运行脚本中。除了脚本中的输入参数,`--config_json` 指向的配置文件中也会包含一些必要参数,您可以根据需要自行修改。
### 🚀 开始推理
#### Linux 环境
```bash
# 修改脚本中的路径后运行
bash scripts/wan/run_wan_t2v.sh
```

#### Windows 环境
```cmd
# 使用 Windows 批处理脚本
scripts\win\run_wan_t2v.bat
```
## 📞 获取帮助
如果您在安装或使用过程中遇到问题,请:
1.[GitHub Issues](https://github.com/ModelTC/LightX2V/issues) 中搜索相关问题
2. 提交新的 Issue 描述您的问题
---
🎉 **恭喜!** 现在您已经成功搭建了 LightX2V 环境,可以开始享受视频生成的乐趣了!
...@@ -9,6 +9,7 @@

### 分阶段去噪策略

变分辨率推理基于以下观察:

- **前期去噪**:主要处理粗糙的噪声和整体结构,不需要过多的细节信息
- **后期去噪**:专注于细节优化和高频信息恢复,需要完整的分辨率信息
...@@ -25,10 +26,39 @@
   - 恢复细节信息,完成精细化处理
### U型分辨率策略
如果在刚开始的去噪步,降低分辨率,可能会导致最后的生成的视频和正常推理的生成的视频,整体差异偏大。因此可以采用U型的分辨率策略,即最一开始保持几步的原始分辨率,再降低分辨率推理。
## 使用方式

变分辨率推理的config文件在[这里](https://github.com/ModelTC/LightX2V/tree/main/configs/changing_resolution)

通过指定--config_json到具体的config文件,即可以测试变分辨率推理。

可以参考[这里](https://github.com/ModelTC/LightX2V/blob/main/scripts/changing_resolution)的脚本运行。
### 举例1:
```
{
"infer_steps": 50,
"changing_resolution": true,
"resolution_rate": [0.75],
"changing_resolution_steps": [25]
}
```
表示总步数为50,0到25步的分辨率为0.75倍原始分辨率,26到最后一步的分辨率为原始分辨率。
### 举例2:
```
{
"infer_steps": 50,
"changing_resolution": true,
"resolution_rate": [1.0, 0.75],
"changing_resolution_steps": [10, 35]
}
```
表示总步数为50,0到10步的分辨率为原始分辨率,11到35步的分辨率为0.75倍原始分辨率,36到最后一步的分辨率为原始分辨率。
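为了更直观地理解上述两个配置,下面给出一个示意性的小例子(并非 LightX2V 的实际调度实现,步数边界的含义按上文说明推断),展示如何把配置映射为每一步使用的分辨率比例:

```python
def resolution_for_step(step, resolution_rate, changing_resolution_steps):
    # 到达每个边界步(含)之前使用对应的分辨率比例,超过最后一个边界后恢复原始分辨率
    for rate, boundary in zip(resolution_rate, changing_resolution_steps):
        if step <= boundary:
            return rate
    return 1.0

# 举例2:共50步,0-10步为原始分辨率,11-35步为0.75倍,之后恢复原始分辨率
schedule = [resolution_for_step(s, [1.0, 0.75], [10, 35]) for s in range(50)]
print(schedule[0], schedule[20], schedule[40])  # 1.0 0.75 1.0
```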
...@@ -25,6 +25,11 @@ try:
except ImportError:
    deep_gemm = None

try:
    from torchao.quantization.utils import quant_int8_per_token_matmul, quantize_activation_per_token_absmax
except ModuleNotFoundError:
    quant_int8_per_token_matmul, quantize_activation_per_token_absmax = None, None


class MMWeightTemplate(metaclass=ABCMeta):
    def __init__(self, weight_name, bias_name, lazy_load=False, lazy_load_file=None):
...@@ -232,6 +237,9 @@ class MMWeightQuantTemplate(MMWeightTemplate):
    # =========================
    # act quant kernels
    # =========================
    def act_quant_int8_perchannel_sym_torchao(self, x):
        # torchao path: symmetric per-token (per-channel) int8 quantization of the activation.
        input_tensor_quant, input_tensor_scale = quantize_activation_per_token_absmax(x)
        return input_tensor_quant, input_tensor_scale

    def act_quant_fp8_perchannel_sym_vllm(self, x):
        input_tensor_quant, input_tensor_scale = ops.scaled_fp8_quant(x, None, scale_ub=None, use_per_token_if_dynamic=True)
...@@ -624,6 +632,33 @@ class MMWeightWint8channelAint8channeldynamicSglActVllm(MMWeightQuantTemplate):
        return output_tensor
@MM_WEIGHT_REGISTER("W-int8-channel-sym-A-int8-channel-sym-dynamic-Torchao")
class MMWeightWint8channelAint8channeldynamicSglActVllm(MMWeightQuantTemplate):
"""
Name: W-int8-channel-sym-A-int8-channel-sym-dynamic-Torchao
Quant MM:
Weight: int8 perchannel sym
Act: int8 perchannel dynamic sym
Kernel: Torchao
"""
def __init__(self, weight_name, bias_name, lazy_load=False, lazy_load_file=None):
super().__init__(weight_name, bias_name, lazy_load, lazy_load_file)
self.load_func = self.load_int8_perchannel_sym
self.weight_need_transpose = True
self.act_quant_func = self.act_quant_int8_perchannel_sym_torchao
def apply(self, input_tensor):
input_tensor_quant, input_tensor_scale = self.act_quant_func(input_tensor)
output_tensor = quant_int8_per_token_matmul(input_tensor_quant, input_tensor_scale, self.weight, self.weight_scale.t().float(), output_dtype=torch.bfloat16)
if self.bias is not None:
output_tensor = output_tensor + self.bias
return output_tensor
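A minimal standalone sketch of the torchao int8 path used above; it mirrors the two calls in `apply`, but the tensor shapes, the CUDA device, and the random data are assumptions, and it requires a torchao build that provides these two utilities:

```python
import torch
from torchao.quantization.utils import quant_int8_per_token_matmul, quantize_activation_per_token_absmax

in_features, out_features, tokens = 4096, 8192, 16
x = torch.randn(tokens, in_features, dtype=torch.bfloat16, device="cuda")

# Pre-quantized weight stored transposed (in_features x out_features) plus per-output-channel scales.
w_int8 = torch.randint(-128, 128, (in_features, out_features), dtype=torch.int8, device="cuda")
w_scale = torch.rand(1, out_features, dtype=torch.float32, device="cuda")

# Dynamic per-token activation quantization, then the int8 matmul with bf16 output.
x_int8, x_scale = quantize_activation_per_token_absmax(x)
y = quant_int8_per_token_matmul(x_int8, x_scale, w_int8, w_scale, output_dtype=torch.bfloat16)
print(y.shape)  # torch.Size([16, 8192])
```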
if __name__ == "__main__": if __name__ == "__main__":
weight_dict = { weight_dict = {
"xx.weight": torch.randn(8192, 4096).to(torch.float8_e4m3fn), "xx.weight": torch.randn(8192, 4096).to(torch.float8_e4m3fn),
......
import torch import torch
import torch.nn as nn import torch.nn as nn
from vllm import _custom_ops as ops
try:
from vllm import _custom_ops as ops
except ModuleNotFoundError:
ops = None
class QuantLinearInt8(nn.Module): try:
from torchao.quantization.utils import quant_int8_per_token_matmul, quantize_activation_per_token_absmax
except ModuleNotFoundError:
quant_int8_per_token_matmul, quantize_activation_per_token_absmax = None, None
class VllmQuantLinearInt8(nn.Module):
def __init__(self, in_features, out_features, bias=True, dtype=torch.bfloat16): def __init__(self, in_features, out_features, bias=True, dtype=torch.bfloat16):
super().__init__() super().__init__()
self.in_features = in_features self.in_features = in_features
...@@ -54,7 +63,7 @@ class QuantLinearInt8(nn.Module): ...@@ -54,7 +63,7 @@ class QuantLinearInt8(nn.Module):
return self return self
class QuantLinearFp8(nn.Module): class VllmQuantLinearFp8(nn.Module):
def __init__(self, in_features, out_features, bias=True, dtype=torch.bfloat16): def __init__(self, in_features, out_features, bias=True, dtype=torch.bfloat16):
super().__init__() super().__init__()
self.in_features = in_features self.in_features = in_features
...@@ -101,3 +110,45 @@ class QuantLinearFp8(nn.Module): ...@@ -101,3 +110,45 @@ class QuantLinearFp8(nn.Module):
self.weight_scale = maybe_cast(self.weight_scale) self.weight_scale = maybe_cast(self.weight_scale)
self.bias = maybe_cast(self.bias) self.bias = maybe_cast(self.bias)
return self return self
class TorchaoQuantLinearInt8(nn.Module):
def __init__(self, in_features, out_features, bias=True, dtype=torch.bfloat16):
super().__init__()
self.in_features = in_features
self.out_features = out_features
self.register_buffer("weight", torch.empty((out_features, in_features), dtype=torch.int8))
self.register_buffer("weight_scale", torch.empty((out_features, 1), dtype=torch.float32))
if bias:
self.register_buffer("bias", torch.empty(out_features, dtype=dtype))
else:
self.register_buffer("bias", None)
def act_quant_func(self, x):
input_tensor_quant, input_tensor_scale = quantize_activation_per_token_absmax(x)
return input_tensor_quant, input_tensor_scale
def forward(self, input_tensor):
input_tensor = input_tensor.squeeze(0)
input_tensor_quant, input_tensor_scale = self.act_quant_func(input_tensor)
output_tensor = quant_int8_per_token_matmul(input_tensor_quant, input_tensor_scale, self.weight.t(), self.weight_scale.t().float(), output_dtype=torch.bfloat16)
if self.bias is not None:
output_tensor = output_tensor + self.bias
return output_tensor.unsqueeze(0)
def _apply(self, fn):
for module in self.children():
module._apply(fn)
def maybe_cast(t):
if t is not None and t.device != fn(t).device:
return fn(t)
return t
self.weight = maybe_cast(self.weight)
self.weight_scale = maybe_cast(self.weight_scale)
self.bias = maybe_cast(self.bias)
return self
...@@ -9,7 +9,7 @@ import torch.nn.functional as F ...@@ -9,7 +9,7 @@ import torch.nn.functional as F
from .tokenizer import HuggingfaceTokenizer from .tokenizer import HuggingfaceTokenizer
from loguru import logger from loguru import logger
from lightx2v.models.input_encoders.hf.q_linear import QuantLinearInt8, QuantLinearFp8 from lightx2v.models.input_encoders.hf.q_linear import VllmQuantLinearInt8, VllmQuantLinearFp8, TorchaoQuantLinearInt8
__all__ = [ __all__ = [
...@@ -83,9 +83,11 @@ class T5Attention(nn.Module): ...@@ -83,9 +83,11 @@ class T5Attention(nn.Module):
if quantized: if quantized:
if quant_scheme == "int8": if quant_scheme == "int8":
linear_cls = QuantLinearInt8 linear_cls = VllmQuantLinearInt8
elif quant_scheme == "fp8": elif quant_scheme == "fp8":
linear_cls = QuantLinearFp8 linear_cls = VllmQuantLinearFp8
elif quant_scheme == "int8-torchao":
linear_cls = TorchaoQuantLinearInt8
else: else:
linear_cls = nn.Linear linear_cls = nn.Linear
...@@ -144,9 +146,11 @@ class T5FeedForward(nn.Module): ...@@ -144,9 +146,11 @@ class T5FeedForward(nn.Module):
if quantized: if quantized:
if quant_scheme == "int8": if quant_scheme == "int8":
linear_cls = QuantLinearInt8 linear_cls = VllmQuantLinearInt8
elif quant_scheme == "fp8": elif quant_scheme == "fp8":
linear_cls = QuantLinearFp8 linear_cls = VllmQuantLinearFp8
elif quant_scheme == "int8-torchao":
linear_cls = TorchaoQuantLinearInt8
else: else:
linear_cls = nn.Linear linear_cls = nn.Linear
# layers # layers
......
...@@ -10,7 +10,7 @@ import torchvision.transforms as T ...@@ -10,7 +10,7 @@ import torchvision.transforms as T
from lightx2v.attentions import attention from lightx2v.attentions import attention
from loguru import logger from loguru import logger
from lightx2v.models.input_encoders.hf.q_linear import QuantLinearInt8, QuantLinearFp8 from lightx2v.models.input_encoders.hf.q_linear import VllmQuantLinearInt8, VllmQuantLinearFp8, TorchaoQuantLinearInt8
from einops import rearrange from einops import rearrange
from torch import Tensor from torch import Tensor
from transformers import CLIPVisionModel from transformers import CLIPVisionModel
...@@ -63,9 +63,11 @@ class SelfAttention(nn.Module): ...@@ -63,9 +63,11 @@ class SelfAttention(nn.Module):
# layers # layers
if quantized: if quantized:
if quant_scheme == "int8": if quant_scheme == "int8":
linear_cls = QuantLinearInt8 linear_cls = VllmQuantLinearInt8
elif quant_scheme == "fp8": elif quant_scheme == "fp8":
linear_cls = QuantLinearFp8 linear_cls = VllmQuantLinearFp8
elif quant_scheme == "int8-torchao":
linear_cls = TorchaoQuantLinearInt8
else: else:
linear_cls = nn.Linear linear_cls = nn.Linear
...@@ -135,9 +137,11 @@ class AttentionBlock(nn.Module): ...@@ -135,9 +137,11 @@ class AttentionBlock(nn.Module):
# layers # layers
if quantized: if quantized:
if quant_scheme == "int8": if quant_scheme == "int8":
linear_cls = QuantLinearInt8 linear_cls = VllmQuantLinearInt8
elif quant_scheme == "fp8": elif quant_scheme == "fp8":
linear_cls = QuantLinearFp8 linear_cls = VllmQuantLinearFp8
elif quant_scheme == "int8-torchao":
linear_cls = TorchaoQuantLinearInt8
else: else:
linear_cls = nn.Linear linear_cls = nn.Linear
......
...@@ -11,6 +11,7 @@ from lightx2v.models.networks.wan.weights.transformer_weights import ( ...@@ -11,6 +11,7 @@ from lightx2v.models.networks.wan.weights.transformer_weights import (
WanTransformerWeights, WanTransformerWeights,
) )
from lightx2v.utils.envs import * from lightx2v.utils.envs import *
from loguru import logger
class WanDistillModel(WanModel): class WanDistillModel(WanModel):
...@@ -31,7 +32,9 @@ class WanDistillModel(WanModel): ...@@ -31,7 +32,9 @@ class WanDistillModel(WanModel):
return weight_dict return weight_dict
ckpt_path = os.path.join(self.model_path, f"{ckpt_folder}/distill_model.pt") ckpt_path = os.path.join(self.model_path, f"{ckpt_folder}/distill_model.pt")
if os.path.exists(ckpt_path): if os.path.exists(ckpt_path):
logger.info(f"Loading weights from {ckpt_path}")
weight_dict = torch.load(ckpt_path, map_location="cpu", weights_only=True) weight_dict = torch.load(ckpt_path, map_location="cpu", weights_only=True)
weight_dict = { weight_dict = {
key: (weight_dict[key].to(torch.bfloat16) if use_bf16 or all(s not in key for s in skip_bf16) else weight_dict[key]).pin_memory().to(self.device) for key in weight_dict.keys() key: (weight_dict[key].to(torch.bfloat16) if use_bf16 or all(s not in key for s in skip_bf16) else weight_dict[key]).pin_memory().to(self.device) for key in weight_dict.keys()
......
...@@ -30,9 +30,7 @@ class WanAudioPreInfer(WanPreInfer): ...@@ -30,9 +30,7 @@ class WanAudioPreInfer(WanPreInfer):
prev_mask = inputs["previmg_encoder_output"]["prev_mask"] prev_mask = inputs["previmg_encoder_output"]["prev_mask"]
hidden_states = self.scheduler.latents.unsqueeze(0) hidden_states = self.scheduler.latents.unsqueeze(0)
# hidden_states = torch.cat([hidden_states[:, :ltnt_channel], prev_latents, prev_mask], dim=1) hidden_states = torch.cat([hidden_states, prev_mask, prev_latents], dim=1)
# print(f"{prev_mask.shape}, {hidden_states.shape}, {prev_latents.shape},{prev_latents[:, :, :ltnt_frames].shape}")
hidden_states = torch.cat([hidden_states, prev_mask, prev_latents[:, :, :ltnt_frames]], dim=1)
hidden_states = hidden_states.squeeze(0) hidden_states = hidden_states.squeeze(0)
x = [hidden_states] x = [hidden_states]
......
...@@ -7,6 +7,7 @@ from lightx2v.utils.envs import * ...@@ -7,6 +7,7 @@ from lightx2v.utils.envs import *
class WanPreInfer: class WanPreInfer:
def __init__(self, config): def __init__(self, config):
assert (config["dim"] % config["num_heads"]) == 0 and (config["dim"] // config["num_heads"]) % 2 == 0 assert (config["dim"] % config["num_heads"]) == 0 and (config["dim"] // config["num_heads"]) % 2 == 0
self.config = config
d = config["dim"] // config["num_heads"] d = config["dim"] // config["num_heads"]
self.clean_cuda_cache = config.get("clean_cuda_cache", False) self.clean_cuda_cache = config.get("clean_cuda_cache", False)
self.task = config["task"] self.task = config["task"]
...@@ -28,7 +29,7 @@ class WanPreInfer: ...@@ -28,7 +29,7 @@ class WanPreInfer:
self.scheduler = scheduler self.scheduler = scheduler
def infer(self, weights, inputs, positive, kv_start=0, kv_end=0): def infer(self, weights, inputs, positive, kv_start=0, kv_end=0):
x = [self.scheduler.latents] x = self.scheduler.latents
if self.scheduler.flag_df: if self.scheduler.flag_df:
t = self.scheduler.df_timesteps[self.scheduler.step_index].unsqueeze(0) t = self.scheduler.df_timesteps[self.scheduler.step_index].unsqueeze(0)
...@@ -40,27 +41,28 @@ class WanPreInfer: ...@@ -40,27 +41,28 @@ class WanPreInfer:
context = inputs["text_encoder_output"]["context"] context = inputs["text_encoder_output"]["context"]
else: else:
context = inputs["text_encoder_output"]["context_null"] context = inputs["text_encoder_output"]["context_null"]
seq_len = self.scheduler.seq_len
if self.task == "i2v": if self.task == "i2v":
clip_fea = inputs["image_encoder_output"]["clip_encoder_out"] clip_fea = inputs["image_encoder_output"]["clip_encoder_out"]
image_encoder = inputs["image_encoder_output"]["vae_encode_out"] if self.config.get("changing_resolution", False):
image_encoder = inputs["image_encoder_output"]["vae_encode_out"][self.scheduler.changing_resolution_index]
else:
image_encoder = inputs["image_encoder_output"]["vae_encode_out"]
frame_seq_length = (image_encoder.size(2) // 2) * (image_encoder.size(3) // 2) frame_seq_length = (image_encoder.size(2) // 2) * (image_encoder.size(3) // 2)
if kv_end - kv_start >= frame_seq_length:  # For CausalVid, take the corresponding slice of image_encoder
idx_s = kv_start // frame_seq_length idx_s = kv_start // frame_seq_length
idx_e = kv_end // frame_seq_length idx_e = kv_end // frame_seq_length
image_encoder = image_encoder[:, idx_s:idx_e, :, :] image_encoder = image_encoder[:, idx_s:idx_e, :, :]
y = [image_encoder] y = image_encoder
x = [torch.cat([u, v], dim=0) for u, v in zip(x, y)] x = torch.cat([x, y], dim=0)
# embeddings # embeddings
x = [weights.patch_embedding.apply(u.unsqueeze(0)) for u in x] x = weights.patch_embedding.apply(x.unsqueeze(0))
grid_sizes = torch.stack([torch.tensor(u.shape[2:], dtype=torch.long) for u in x]) grid_sizes = torch.tensor(x.shape[2:], dtype=torch.long).unsqueeze(0)
x = [u.flatten(2).transpose(1, 2) for u in x] x = x.flatten(2).transpose(1, 2).contiguous()
seq_lens = torch.tensor([u.size(1) for u in x], dtype=torch.long).cuda() seq_lens = torch.tensor(x.size(1), dtype=torch.long).cuda().unsqueeze(0)
assert seq_lens.max() <= seq_len
x = torch.cat([torch.cat([u, u.new_zeros(1, seq_len - u.size(1), u.size(2))], dim=1) for u in x])
embed = sinusoidal_embedding_1d(self.freq_dim, t.flatten()) embed = sinusoidal_embedding_1d(self.freq_dim, t.flatten())
if self.enable_dynamic_cfg: if self.enable_dynamic_cfg:
......
...@@ -140,7 +140,7 @@ class DefaultRunner(BaseRunner): ...@@ -140,7 +140,7 @@ class DefaultRunner(BaseRunner):
prompt = self.config["prompt_enhanced"] if self.config["use_prompt_enhancer"] else self.config["prompt"] prompt = self.config["prompt_enhanced"] if self.config["use_prompt_enhancer"] else self.config["prompt"]
img = Image.open(self.config["image_path"]).convert("RGB") img = Image.open(self.config["image_path"]).convert("RGB")
clip_encoder_out = self.run_image_encoder(img) clip_encoder_out = self.run_image_encoder(img)
vae_encode_out, kwargs = self.run_vae_encoder(img) vae_encode_out = self.run_vae_encoder(img)
text_encoder_output = self.run_text_encoder(prompt, img) text_encoder_output = self.run_text_encoder(prompt, img)
torch.cuda.empty_cache() torch.cuda.empty_cache()
gc.collect() gc.collect()
...@@ -158,7 +158,7 @@ class DefaultRunner(BaseRunner): ...@@ -158,7 +158,7 @@ class DefaultRunner(BaseRunner):
} }
@ProfilingContext("Run DiT") @ProfilingContext("Run DiT")
def _run_dit_local(self, kwargs): def _run_dit_local(self):
if self.config.get("lazy_load", False) or self.config.get("unload_modules", False): if self.config.get("lazy_load", False) or self.config.get("unload_modules", False):
self.model = self.load_transformer() self.model = self.load_transformer()
self.init_scheduler() self.init_scheduler()
...@@ -205,9 +205,9 @@ class DefaultRunner(BaseRunner): ...@@ -205,9 +205,9 @@ class DefaultRunner(BaseRunner):
self.inputs = self.run_input_encoder() self.inputs = self.run_input_encoder()
kwargs = self.set_target_shape() self.set_target_shape()
latents, generator = self.run_dit(kwargs) latents, generator = self.run_dit()
images = self.run_vae_decoder(latents, generator) images = self.run_vae_decoder(latents, generator)
......
...@@ -133,11 +133,31 @@ def adaptive_resize(img): ...@@ -133,11 +133,31 @@ def adaptive_resize(img):
def array_to_video( def array_to_video(
image_array: np.ndarray, image_array: np.ndarray,
output_path: str, output_path: str,
fps: Union[int, float] = 30, fps: int | float = 30,
resolution: Optional[Union[Tuple[int, int], Tuple[float, float]]] = None, resolution: tuple[int, int] | tuple[float, float] | None = None,
disable_log: bool = False, disable_log: bool = False,
lossless: bool = True, lossless: bool = True,
output_pix_fmt: str = "yuv420p",
) -> None: ) -> None:
"""Convert an array to a video directly, gif not supported.
Args:
image_array (np.ndarray): shape should be (f * h * w * 3).
output_path (str): output video file path.
fps (Union[int, float, optional): fps. Defaults to 30.
resolution (Optional[Union[Tuple[int, int], Tuple[float, float]]],
optional): (height, width) of the output video.
Defaults to None.
disable_log (bool, optional): whether close the ffmepg command info.
Defaults to False.
output_pix_fmt (str): output pix_fmt in ffmpeg command.
Raises:
FileNotFoundError: check output path.
TypeError: check input array.
Returns:
None.
"""
if not isinstance(image_array, np.ndarray): if not isinstance(image_array, np.ndarray):
raise TypeError("Input should be np.ndarray.") raise TypeError("Input should be np.ndarray.")
assert image_array.ndim == 4 assert image_array.ndim == 4
...@@ -175,6 +195,7 @@ def array_to_video( ...@@ -175,6 +195,7 @@ def array_to_video(
output_path, output_path,
] ]
else: else:
output_pix_fmt = output_pix_fmt or "yuv420p"
command = [ command = [
"/usr/bin/ffmpeg", "/usr/bin/ffmpeg",
"-y", # (optional) overwrite output file if it exists "-y", # (optional) overwrite output file if it exists
...@@ -194,10 +215,15 @@ def array_to_video( ...@@ -194,10 +215,15 @@ def array_to_video(
"-", # The input comes from a pipe "-", # The input comes from a pipe
"-vcodec", "-vcodec",
"libx264", "libx264",
"-pix_fmt",
f"{output_pix_fmt}",
"-an", # Tells FFMPEG not to expect any audio "-an", # Tells FFMPEG not to expect any audio
output_path, output_path,
] ]
if output_pix_fmt is not None:
command += ["-pix_fmt", output_pix_fmt]
if not disable_log: if not disable_log:
print(f'Running "{" ".join(command)}"') print(f'Running "{" ".join(command)}"')
process = subprocess.Popen( process = subprocess.Popen(
...@@ -260,7 +286,7 @@ def save_to_video(gen_lvideo, out_path, target_fps): ...@@ -260,7 +286,7 @@ def save_to_video(gen_lvideo, out_path, target_fps):
gen_lvideo = (gen_lvideo[0].cpu().numpy() * 127.5 + 127.5).astype(np.uint8) gen_lvideo = (gen_lvideo[0].cpu().numpy() * 127.5 + 127.5).astype(np.uint8)
gen_lvideo = gen_lvideo[..., ::-1].copy() gen_lvideo = gen_lvideo[..., ::-1].copy()
generate_unique_path(out_path) generate_unique_path(out_path)
array_to_video(gen_lvideo, output_path=out_path, fps=target_fps, lossless=False) array_to_video(gen_lvideo, output_path=out_path, fps=target_fps, lossless=False, output_pix_fmt="yuv444p")
def save_audio( def save_audio(
...@@ -474,8 +500,9 @@ class WanAudioRunner(WanRunner): ...@@ -474,8 +500,9 @@ class WanAudioRunner(WanRunner):
vae_dtype = torch.float vae_dtype = torch.float
for idx in range(interval_num): for idx in range(interval_num):
torch.manual_seed(42 + idx) self.config.seed = self.config.seed + idx
logger.info(f"### manual_seed: {42 + idx} ####") torch.manual_seed(self.config.seed)
logger.info(f"### manual_seed: {self.config.seed} ####")
useful_length = -1 useful_length = -1
if idx == 0:  # first segment: condition padding with zeros
prev_frames = torch.zeros((1, 3, max_num_frames, tgt_h, tgt_w), device=device) prev_frames = torch.zeros((1, 3, max_num_frames, tgt_h, tgt_w), device=device)
...@@ -531,8 +558,9 @@ class WanAudioRunner(WanRunner): ...@@ -531,8 +558,9 @@ class WanAudioRunner(WanRunner):
ltnt_channel, nframe, height, width = self.model.scheduler.latents.shape ltnt_channel, nframe, height, width = self.model.scheduler.latents.shape
# bs = 1 # bs = 1
frames_n = (nframe - 1) * 4 + 1 frames_n = (nframe - 1) * 4 + 1
prev_mask = torch.zeros((1, frames_n, height, width), device=device, dtype=dtype) prev_frame_len = max((prev_len - 1) * 4 + 1, 0)
prev_mask[:, prev_len:] = 0 prev_mask = torch.ones((1, frames_n, height, width), device=device, dtype=dtype)
prev_mask[:, prev_frame_len:] = 0
prev_mask = wan_mask_rearrange(prev_mask).unsqueeze(0) prev_mask = wan_mask_rearrange(prev_mask).unsqueeze(0)
previmg_encoder_output = { previmg_encoder_output = {
"prev_latents": prev_latents, "prev_latents": prev_latents,
......
...@@ -57,11 +57,12 @@ class WanRunner(DefaultRunner): ...@@ -57,11 +57,12 @@ class WanRunner(DefaultRunner):
if clip_quantized: if clip_quantized:
clip_quant_scheme = self.config.get("clip_quant_scheme", None) clip_quant_scheme = self.config.get("clip_quant_scheme", None)
assert clip_quant_scheme is not None assert clip_quant_scheme is not None
tmp_clip_quant_scheme = clip_quant_scheme.split("-")[0]
clip_quantized_ckpt = self.config.get( clip_quantized_ckpt = self.config.get(
"clip_quantized_ckpt", "clip_quantized_ckpt",
os.path.join( os.path.join(
os.path.join(self.config.model_path, clip_quant_scheme), os.path.join(self.config.model_path, tmp_clip_quant_scheme),
f"clip-{clip_quant_scheme}.pth", f"clip-{tmp_clip_quant_scheme}.pth",
), ),
) )
else: else:
...@@ -93,12 +94,13 @@ class WanRunner(DefaultRunner): ...@@ -93,12 +94,13 @@ class WanRunner(DefaultRunner):
t5_quantized = self.config.get("t5_quantized", False) t5_quantized = self.config.get("t5_quantized", False)
if t5_quantized: if t5_quantized:
t5_quant_scheme = self.config.get("t5_quant_scheme", None) t5_quant_scheme = self.config.get("t5_quant_scheme", None)
assert t5_quant_scheme is not None assert t5_quant_scheme is not None
tmp_t5_quant_scheme = t5_quant_scheme.split("-")[0]
t5_quantized_ckpt = self.config.get( t5_quantized_ckpt = self.config.get(
"t5_quantized_ckpt", "t5_quantized_ckpt",
os.path.join( os.path.join(
os.path.join(self.config.model_path, t5_quant_scheme), os.path.join(self.config.model_path, tmp_t5_quant_scheme),
f"models_t5_umt5-xxl-enc-{t5_quant_scheme}.pth", f"models_t5_umt5-xxl-enc-{tmp_t5_quant_scheme}.pth",
), ),
) )
else: else:
...@@ -202,19 +204,30 @@ class WanRunner(DefaultRunner): ...@@ -202,19 +204,30 @@ class WanRunner(DefaultRunner):
return clip_encoder_out return clip_encoder_out
def run_vae_encoder(self, img): def run_vae_encoder(self, img):
kwargs = {}
img = TF.to_tensor(img).sub_(0.5).div_(0.5).cuda() img = TF.to_tensor(img).sub_(0.5).div_(0.5).cuda()
h, w = img.shape[1:] h, w = img.shape[1:]
aspect_ratio = h / w aspect_ratio = h / w
max_area = self.config.target_height * self.config.target_width max_area = self.config.target_height * self.config.target_width
lat_h = round(np.sqrt(max_area * aspect_ratio) // self.config.vae_stride[1] // self.config.patch_size[1] * self.config.patch_size[1]) lat_h = round(np.sqrt(max_area * aspect_ratio) // self.config.vae_stride[1] // self.config.patch_size[1] * self.config.patch_size[1])
lat_w = round(np.sqrt(max_area / aspect_ratio) // self.config.vae_stride[2] // self.config.patch_size[2] * self.config.patch_size[2]) lat_w = round(np.sqrt(max_area / aspect_ratio) // self.config.vae_stride[2] // self.config.patch_size[2] * self.config.patch_size[2])
if self.config.get("changing_resolution", False):
self.config.lat_h, self.config.lat_w = lat_h, lat_w
vae_encode_out_list = []
for i in range(len(self.config["resolution_rate"])):
lat_h, lat_w = int(self.config.lat_h * self.config.resolution_rate[i]) // 2 * 2, int(self.config.lat_w * self.config.resolution_rate[i]) // 2 * 2
vae_encode_out_list.append(self.get_vae_encoder_output(img, lat_h, lat_w))
vae_encode_out_list.append(self.get_vae_encoder_output(img, self.config.lat_h, self.config.lat_w))
return vae_encode_out_list
else:
self.config.lat_h, self.config.lat_w = lat_h, lat_w
vae_encode_out = self.get_vae_encoder_output(img, lat_h, lat_w)
return vae_encode_out
def get_vae_encoder_output(self, img, lat_h, lat_w):
h = lat_h * self.config.vae_stride[1] h = lat_h * self.config.vae_stride[1]
w = lat_w * self.config.vae_stride[2] w = lat_w * self.config.vae_stride[2]
self.config.lat_h, kwargs["lat_h"] = lat_h, lat_h
self.config.lat_w, kwargs["lat_w"] = lat_w, lat_w
msk = torch.ones( msk = torch.ones(
1, 1,
self.config.target_video_length, self.config.target_video_length,
...@@ -245,7 +258,7 @@ class WanRunner(DefaultRunner): ...@@ -245,7 +258,7 @@ class WanRunner(DefaultRunner):
torch.cuda.empty_cache() torch.cuda.empty_cache()
gc.collect() gc.collect()
vae_encode_out = torch.concat([msk, vae_encode_out]).to(torch.bfloat16) vae_encode_out = torch.concat([msk, vae_encode_out]).to(torch.bfloat16)
return vae_encode_out, kwargs return vae_encode_out
def get_encoder_output_i2v(self, clip_encoder_out, vae_encode_out, text_encoder_output, img): def get_encoder_output_i2v(self, clip_encoder_out, vae_encode_out, text_encoder_output, img):
image_encoder_output = { image_encoder_output = {
...@@ -258,7 +271,6 @@ class WanRunner(DefaultRunner): ...@@ -258,7 +271,6 @@ class WanRunner(DefaultRunner):
} }
def set_target_shape(self): def set_target_shape(self):
ret = {}
num_channels_latents = self.config.get("num_channels_latents", 16) num_channels_latents = self.config.get("num_channels_latents", 16)
if self.config.task == "i2v": if self.config.task == "i2v":
self.config.target_shape = ( self.config.target_shape = (
...@@ -267,8 +279,6 @@ class WanRunner(DefaultRunner): ...@@ -267,8 +279,6 @@ class WanRunner(DefaultRunner):
self.config.lat_h, self.config.lat_h,
self.config.lat_w, self.config.lat_w,
) )
ret["lat_h"] = self.config.lat_h
ret["lat_w"] = self.config.lat_w
elif self.config.task == "t2v": elif self.config.task == "t2v":
self.config.target_shape = ( self.config.target_shape = (
num_channels_latents, num_channels_latents,
...@@ -276,8 +286,6 @@ class WanRunner(DefaultRunner): ...@@ -276,8 +286,6 @@ class WanRunner(DefaultRunner):
int(self.config.target_height) // self.config.vae_stride[1], int(self.config.target_height) // self.config.vae_stride[1],
int(self.config.target_width) // self.config.vae_stride[2], int(self.config.target_width) // self.config.vae_stride[2],
) )
ret["target_shape"] = self.config.target_shape
return ret
def save_video_func(self, images): def save_video_func(self, images):
cache_video( cache_video(
......
import os
import gc
import math
import numpy as np
import torch
from typing import List, Optional, Tuple, Union
from lightx2v.utils.envs import *
from diffusers.configuration_utils import register_to_config
from torch import Tensor
from .utils import unsqueeze_to_ndim
from diffusers import (
FlowMatchEulerDiscreteScheduler as FlowMatchEulerDiscreteSchedulerBase, # pyright: ignore
)
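# Evenly spaced timesteps from max_steps down to 0, inclusive (num_steps + 1 boundary values).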
def get_timesteps(num_steps, max_steps: int = 1000):
return np.linspace(max_steps, 0, num_steps + 1, dtype=np.float32)
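# Rescales a normalized timestep/sigma t in [0, 1] to shift * t / (1 + (shift - 1) * t);
# for shift > 1 this biases the schedule toward higher noise levels (e.g. t=0.5, shift=5.0 -> ~0.833).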
def timestep_shift(timesteps, shift: float = 1.0):
return shift * timesteps / (1 + (shift - 1) * timesteps)
class FlowMatchEulerDiscreteScheduler(FlowMatchEulerDiscreteSchedulerBase):
def __init__(self, **kwargs):
super().__init__(**kwargs)
self.init_noise_sigma = 1.0
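# Flow-matching forward process: x_t = (1 - sigma) * x0 + sigma * noise, with sigma = t / num_train_timesteps.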
def add_noise(self, x0: Tensor, noise: Tensor, timesteps: Tensor):
dtype = x0.dtype
device = x0.device
sigma = timesteps.to(device, torch.float32) / self.config.num_train_timesteps
sigma = unsqueeze_to_ndim(sigma, x0.ndim)
xt = x0.float() * (1 - sigma) + noise.float() * sigma
return xt.to(dtype)
def get_velocity(self, x0: Tensor, noise: Tensor, timesteps: Tensor | None = None):
return noise - x0
def velocity_loss_to_x_loss(self, v_loss: Tensor, timesteps: Tensor):
device = v_loss.device
sigma = timesteps.to(device, torch.float32) / self.config.num_train_timesteps
return v_loss.float() * (sigma**2)
class EulerSchedulerTimestepFix(FlowMatchEulerDiscreteScheduler):
def __init__(self, config):
self.config = config
self.step_index = 0
self.latents = None
self.caching_records = [True] * config.infer_steps
self.flag_df = False
self.transformer_infer = None
self.device = torch.device("cuda")
self.infer_steps = self.config.infer_steps
self.target_video_length = self.config.target_video_length
self.sample_shift = self.config.sample_shift
self.num_train_timesteps = 1000
self.noise_pred = None
def step_pre(self, step_index):
self.step_index = step_index
if GET_DTYPE() == "BF16":
self.latents = self.latents.to(dtype=torch.bfloat16)
def set_shift(self, shift: float = 1.0):
self.sigmas = self.timesteps_ori / self.num_train_timesteps
self.sigmas = timestep_shift(self.sigmas, shift=shift)
self.timesteps = self.sigmas * self.num_train_timesteps
def set_timesteps(
self,
infer_steps: Union[int, None] = None,
device: Union[str, torch.device] = None,
sigmas: Optional[List[float]] = None,
mu: Optional[Union[float, None]] = None,
shift: Optional[Union[float, None]] = None,
):
timesteps = get_timesteps(num_steps=infer_steps, max_steps=self.num_train_timesteps)
self.timesteps = torch.from_numpy(timesteps).to(dtype=torch.float32, device=device or self.device)
self.timesteps_ori = self.timesteps.clone()
self.set_shift(self.sample_shift)
self._step_index = None
self._begin_index = None
def prepare(self, image_encoder_output=None):
self.generator = torch.Generator(device=self.device)
self.generator.manual_seed(self.config.seed)
self.prepare_latents(self.config.target_shape, dtype=torch.float32)
if os.path.isfile(self.config.image_path):
self.seq_len = ((self.config.target_video_length - 1) // self.config.vae_stride[0] + 1) * self.config.lat_h * self.config.lat_w // (self.config.patch_size[1] * self.config.patch_size[2])
else:
self.seq_len = math.ceil((self.config.target_shape[2] * self.config.target_shape[3]) / (self.config.patch_size[1] * self.config.patch_size[2]) * self.config.target_shape[1])
self.model_outputs = [None] * self.solver_order
self.timestep_list = [None] * self.solver_order
self.last_sample = None
self.set_timesteps(infer_steps=self.infer_steps, device=self.device, shift=self.sample_shift)
def prepare_latents(self, target_shape, dtype=torch.float32):
self.latents = (
torch.randn(
target_shape[0],
target_shape[1],
target_shape[2],
target_shape[3],
dtype=dtype,
device=self.device,
generator=self.generator,
)
* self.init_noise_sigma
)
def step_post(self):
model_output = self.noise_pred.to(torch.float32)
timestep = self.timesteps[self.step_index]
sample = self.latents.to(torch.float32)
if self.step_index is None:
self._init_step_index(timestep)
sample = sample.to(torch.float32) # pyright: ignore
sigma = unsqueeze_to_ndim(self.sigmas[self.step_index], sample.ndim)
sigma_next = unsqueeze_to_ndim(self.sigmas[self.step_index + 1], sample.ndim)
# x0 = sample - model_output * sigma
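# Single Euler step of the flow ODE: x_next = x + (sigma_next - sigma) * v, where the model output is the velocity v = noise - x0.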
x_t_next = sample + (sigma_next - sigma) * model_output
self._step_index += 1
return x_t_next
def reset(self):
self.model_outputs = [None] * self.solver_order
self.timestep_list = [None] * self.solver_order
self.last_sample = None
self.noise_pred = None
self.this_order = None
self.lower_order_nums = 0
self.prepare_latents(self.config.target_shape, dtype=torch.float32)
gc.collect()
torch.cuda.empty_cache()