Commit 4796fc6e authored by gushiqiao, committed by GitHub

Merge pull request #104 from ModelTC/dev_gradio

Update gradio
parents 82a99862 38abfd92
# Lightx2v Gradio Demo Interface
## 📖 Overview
Lightx2v is a lightweight video inference and generation engine that provides a web interface based on Gradio, supporting both Image-to-Video and Text-to-Video generation modes.
This project contains two main demo files:
- `gradio_demo.py` - English interface version
- `gradio_demo_zh.py` - Chinese interface version
## 🚀 Quick Start
### System Requirements
- Python 3.10+ (recommended)
- CUDA 12.4+ (recommended)
- At least 8GB GPU VRAM
- At least 16GB system memory
- At least 128GB of SSD storage (**💾 Strongly recommended: store model files on an SSD. With "lazy loading" enabled at startup, this significantly improves model loading speed and inference performance**)
### Install Dependencies
```bash
# Install basic dependencies
pip install -r ../requirements.txt
pip install gradio
```
#### Recommended Optimization Library Configuration
- [Flash attention](https://github.com/Dao-AILab/flash-attention)
- [Sage attention](https://github.com/thu-ml/SageAttention)
- [vllm-kernel](https://github.com/vllm-project/vllm)
- [sglang-kernel](https://github.com/sgl-project/sglang/tree/main/sgl-kernel)
- [q8-kernel](https://github.com/KONAKONA666/q8_kernels) (only supports Ada-architecture GPUs)
### 🤖 Supported Models
#### 🎬 Image-to-Video Models
| Model Name | Resolution | Parameters | Features | Recommended Use |
|------------|------------|------------|----------|-----------------|
| ✅ [Wan2.1-I2V-14B-480P-Lightx2v](https://huggingface.co/lightx2v/Wan2.1-T2V-14B-StepDistill-CfgDistill) | 480p | 14B | Standard version | Balance speed and quality |
| ✅ [Wan2.1-I2V-14B-720P-Lightx2v](https://huggingface.co/lightx2v/Wan2.1-T2V-14B-StepDistill-CfgDistill) | 720p | 14B | HD version | Pursue high-quality output |
| ✅ [Wan2.1-I2V-14B-480P-Lightx2v-StepDistill-CfgDistill](https://huggingface.co/lightx2v/Wan2.1-T2V-14B-StepDistill-CfgDistill) | 480p | 14B | Distilled optimized version | Faster inference speed |
| ✅ [Wan2.1-I2V-14B-720P-Lightx2v-StepDistill-CfgDistill](https://huggingface.co/lightx2v/Wan2.1-T2V-14B-StepDistill-CfgDistill) | 720p | 14B | HD distilled version | High quality + fast inference |
#### 📝 Text-to-Video Models
| Model Name | Parameters | Features | Recommended Use |
|------------|------------|----------|-----------------|
| ✅ [Wan2.1-T2V-1.3B-Lightx2v](https://huggingface.co/lightx2v/Wan2.1-T2V-14B-StepDistill-CfgDistill) | 1.3B | Lightweight | Fast prototyping and testing |
| ✅ [Wan2.1-T2V-14B-Lightx2v](https://huggingface.co/lightx2v/Wan2.1-T2V-14B-StepDistill-CfgDistill) | 14B | Standard version | Balance speed and quality |
| ✅ [Wan2.1-T2V-14B-Lightx2v-StepDistill-CfgDistill](https://huggingface.co/lightx2v/Wan2.1-T2V-14B-StepDistill-CfgDistill) | 14B | Distilled optimized version | High quality + fast inference |
**💡 Model Selection Recommendations**:
- **First-time use**: Recommend choosing distilled versions
- **Pursuing quality**: Choose 720p resolution or 14B parameter models
- **Pursuing speed**: Choose 480p resolution or 1.3B parameter models
- **Resource-constrained**: Prioritize distilled versions and lower resolutions
### Startup Methods
#### Method 1: Using Startup Script (Recommended)
```bash
# 1. Edit the startup script to configure relevant paths
vim run_gradio.sh
# Configuration items that need to be modified:
# - lightx2v_path: Lightx2v project root directory path
# - i2v_model_path: Image-to-video model path
# - t2v_model_path: Text-to-video model path
# 💾 Important note: Recommend pointing model paths to SSD storage locations
# Example: /mnt/ssd/models/ or /data/ssd/models/
# 2. Run the startup script
bash run_gradio.sh
# 3. Or start with parameters (recommended)
bash run_gradio.sh --task i2v --lang en --port 8032
# bash run_gradio.sh --task t2v --lang en --port 8032
```
#### Method 2: Direct Command Line Startup
**Image-to-Video Mode:**
```bash
python gradio_demo.py \
--model_path /path/to/Wan2.1-I2V-14B-720P-Lightx2v \
--task i2v \
--server_name 0.0.0.0 \
--server_port 7862
```
**Text-to-Video Mode:**
```bash
python gradio_demo.py \
--model_path /path/to/Wan2.1-T2V-1.3B \
--task t2v \
--server_name 0.0.0.0 \
--server_port 7862
```
**Chinese Interface Version:**
```bash
python gradio_demo_zh.py \
--model_path /path/to/model \
--task i2v \
--server_name 0.0.0.0 \
--server_port 7862
```
## 📋 Command Line Parameters
| Parameter | Type | Required | Default | Description |
|-----------|------|----------|---------|-------------|
| `--model_path` | str | ✅ | - | Model folder path |
| `--model_cls` | str | ❌ | wan2.1 | Model class (currently only supports wan2.1) |
| `--task` | str | ✅ | - | Task type: `i2v` (image-to-video) or `t2v` (text-to-video) |
| `--server_port` | int | ❌ | 7862 | Server port |
| `--server_name` | str | ❌ | 0.0.0.0 | Server IP address |
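For reference, the parameters in the table combine into a single launch command like the sketch below; the model path is a placeholder, and `--model_cls` plus the server options can be omitted to use their defaults.

```bash
# Full invocation with every parameter spelled out (model path is a placeholder)
python gradio_demo.py \
    --model_path /path/to/Wan2.1-I2V-14B-720P-Lightx2v \
    --model_cls wan2.1 \
    --task i2v \
    --server_name 0.0.0.0 \
    --server_port 7862
```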
## 🎯 Features
### Basic Settings
#### Model Type Selection
- **Wan2.1 14B**: Large parameter count, high generation quality, suitable for high-quality video generation
- **Wan2.1 1.3B**: Lightweight model, fast speed, suitable for rapid prototyping and testing
#### Input Parameters
- **Prompt**: Describe the expected video content
- **Negative Prompt**: Specify elements you don't want to appear
- **Resolution**: Supports multiple preset resolutions (480p/540p/720p)
- **Random Seed**: Controls the randomness of generation results
- **Inference Steps**: Affects the balance between generation quality and speed
#### Video Parameters
- **FPS**: Frames per second
- **Total Frames**: Video length
- **CFG Scale Factor**: Controls prompt influence strength (1-10)
- **Distribution Shift**: Controls the degree of deviation in the generation style (0-10)
### Advanced Optimization Options
#### GPU Memory Optimization
- **Chunked Rotary Position Embedding**: Saves GPU memory
- **Rotary Embedding Chunk Size**: Controls chunk granularity
- **Clean CUDA Cache**: Promptly frees GPU memory
#### Asynchronous Offloading
- **CPU Offloading**: Transfers partial computation to CPU
- **Lazy Loading**: Loads model components on-demand, significantly reduces system memory consumption
- **Offload Granularity Control**: Fine-grained control of offloading strategies
#### Low-Precision Quantization
- **Attention Operators**: Flash Attention, Sage Attention, etc.
- **Quantization Operators**: vLLM, SGL, Q8F, etc.
- **Precision Modes**: FP8, INT8, BF16, etc.
#### VAE Optimization
- **Lightweight VAE**: Accelerates decoding process
- **VAE Tiling Inference**: Reduces memory usage
#### Feature Caching
- **Tea Cache**: Caches intermediate features to accelerate generation
- **Cache Threshold**: Controls cache trigger conditions
- **Key Step Caching**: Writes cache only at key steps
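As a rough sketch of the caching idea only (not the actual LightX2V implementation), a threshold-based feature cache can be pictured as follows; the class name, the accumulated-change heuristic, and the reset strategy are illustrative assumptions.

```python
# Minimal sketch of threshold-based feature caching in the spirit of Tea Cache.
# The accumulation heuristic and reset policy are assumptions for illustration only.
class FeatureCache:
    def __init__(self, threshold: float):
        self.threshold = threshold          # maps to the "Cache Threshold" option
        self.accumulated_change = 0.0
        self.cached_residual = None

    def should_reuse(self, change_estimate: float) -> bool:
        """Return True when the cached residual can be reused for this step."""
        self.accumulated_change += change_estimate
        if self.cached_residual is not None and self.accumulated_change < self.threshold:
            return True                     # skip the expensive blocks, reuse the cache
        self.accumulated_change = 0.0       # recompute this step and restart accumulation
        return False

    def update(self, residual) -> None:
        """Refresh the cache after a full (non-skipped) computation step."""
        self.cached_residual = residual
```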
## 🔧 Auto-Configuration Feature
After enabling "Auto-configure Inference Options", the system will automatically optimize parameters based on your hardware configuration:
### GPU Memory Rules
- **80GB+**: Default configuration, no optimization needed
- **48GB**: Enable CPU offloading, offload ratio 50%
- **40GB**: Enable CPU offloading, offload ratio 80%
- **32GB**: Enable CPU offloading, offload ratio 100%
- **24GB**: Enable BF16 precision, VAE tiling
- **16GB**: Enable chunked offloading, rotary embedding chunking
- **12GB**: Enable cache cleaning, lightweight VAE
- **8GB**: Enable quantization, lazy loading
### CPU Memory Rules
- **128GB+**: Default configuration
- **64GB**: Enable DIT quantization
- **32GB**: Enable lazy loading
- **16GB**: Enable full model quantization
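Conceptually, both rule sets behave like ordered `(threshold, updates)` lists: the detected memory is compared against thresholds from largest to smallest and the first match decides the configuration. The sketch below illustrates this with placeholder values; it is not the demo's exact rule table.

```python
# Simplified illustration of threshold-based auto-configuration (placeholder values).
import torch

def detect_gpu_memory_gb(gpu_idx: int = 0) -> float:
    free_bytes, total_bytes = torch.cuda.mem_get_info(gpu_idx)
    return total_bytes / (1024 ** 3)

# Rules are ordered from the largest threshold down; the first match wins.
GPU_RULES = [
    (80, {}),                                            # enough VRAM: default configuration
    (48, {"cpu_offload": True, "offload_ratio": 0.5}),
    (40, {"cpu_offload": True, "offload_ratio": 0.8}),
    (32, {"cpu_offload": True, "offload_ratio": 1.0}),
]

def auto_configure(rules, memory_gb: float) -> dict:
    config = {}
    for threshold, updates in rules:
        if memory_gb >= threshold:
            config.update(updates)
            break
    return config

if __name__ == "__main__":
    print(auto_configure(GPU_RULES, detect_gpu_memory_gb()))
```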
## ⚠️ Important Notes
### 🚀 Low-Resource Device Optimization Recommendations
**💡 For devices with insufficient VRAM or performance constraints**:
- **🎯 Model Selection**: Prioritize using distilled version models (StepDistill-CfgDistill)
- **⚡ Inference Steps**: Recommend setting to 4 steps
- **🔧 CFG Settings**: Recommend disabling CFG option to improve generation speed
- **🔄 Auto-Configuration**: Enable "Auto-configure Inference Options"
### 🔧 Quick Optimization Configuration Examples
```bash
# Start with a distilled model
bash run_gradio.sh --task i2v
```
Recommended interface settings:
- Inference Steps: 25
- CFG Scale Factor: 4
- Resolution: 832x480
- Auto-Configuration: Enabled
- Quantization Scheme: int8
- Tea Cache: Enabled
## 📁 File Structure
```
lightx2v/app/
├── gradio_demo.py # English interface demo
├── gradio_demo_zh.py # Chinese interface demo
├── run_gradio.sh # Startup script
├── README.md # Documentation
├── saved_videos/ # Generated video save directory
└── inference_logs.log # Inference logs
```
## 🎨 Interface Description
### Basic Settings Tab
- **Input Parameters**: Model type, prompts, resolution, and other basic settings
- **Video Parameters**: FPS, frame count, CFG, and other video generation parameters
- **Output Settings**: Video save path configuration
### Advanced Options Tab
- **GPU Memory Optimization**: Memory management related options
- **Asynchronous Offloading**: CPU offloading and lazy loading
- **Low-Precision Quantization**: Various quantization optimization options
- **VAE Optimization**: Variational Autoencoder optimization
- **Feature Caching**: Cache strategy configuration
## 🔍 Troubleshooting
### Common Issues
**💡 Tip**: In most cases, enabling "Auto-configure Inference Options" lets the system tune parameters to your hardware automatically, so performance problems are rare. If you do run into issues, refer to the following solutions:
1. **CUDA Memory Insufficient**
- Enable CPU offloading
- Reduce resolution
- Enable quantization options
2. **System Memory Insufficient**
- Enable CPU offloading
- Enable lazy loading option
- Enable quantization options
3. **Slow Generation Speed**
- Reduce inference steps
- Enable auto-configuration
- Use lightweight models
- Enable Tea Cache
- Use quantization operators
- 💾 **Check if models are stored on SSD**
4. **Slow Model Loading**
- 💾 **Migrate models to SSD storage**
- Enable lazy loading option
   - Check disk I/O performance (see the example commands after this list)
- Consider using NVMe SSD
5. **Poor Video Quality**
- Increase inference steps
- Increase CFG scale factor
- Use 14B models
- Optimize prompts
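For the disk I/O check suggested under "Slow Model Loading" (item 4), the commands below can help confirm whether the models live on an SSD and how fast it reads; the device name and model file path are placeholders to adjust for your system.

```bash
# ROTA=0 in the output indicates a non-rotational (SSD/NVMe) device
lsblk -o NAME,ROTA,MOUNTPOINT
# Sequential read benchmark of the underlying device (adjust the device name)
sudo hdparm -t /dev/nvme0n1
# Read throughput on an actual model file (path is a placeholder)
dd if=/path/to/model/diffusion_pytorch_model.safetensors of=/dev/null bs=1M status=progress
```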
### Log Viewing
```bash
# View inference logs
tail -f inference_logs.log
# View GPU usage
nvidia-smi
# View system resources
htop
```
**Note**: Please comply with relevant laws and regulations when using videos generated by this tool, and do not use them for illegal purposes.
# Lightx2v Gradio 演示界面
## 📖 概述
Lightx2v 是一个轻量级的视频推理和生成引擎,提供了基于 Gradio 的 Web 界面,支持图像到视频(Image-to-Video)和文本到视频(Text-to-Video)两种生成模式。
本项目包含两个主要演示文件:
- `gradio_demo.py` - 英文界面版本
- `gradio_demo_zh.py` - 中文界面版本
## 🚀 快速开始
### 环境要求
- Python 3.10+ (推荐)
- CUDA 12.4+ (推荐)
- 至少 8GB GPU 显存
- 至少 16GB 系统内存
- 至少 128GB SSD固态硬盘 (**💾 强烈建议使用SSD固态硬盘存储模型文件!启用"延迟加载"启动时,可显著提升模型加载速度和推理性能**)
### 安装依赖
```bash
# 安装基础依赖
pip install -r ../requirements.txt
pip install gradio
```
#### 推荐优化库配置
- [Flash attention](https://github.com/Dao-AILab/flash-attention)
- [Sage attention](https://github.com/thu-ml/SageAttention)
- [vllm-kernel](https://github.com/vllm-project/vllm)
- [sglang-kernel](https://github.com/sgl-project/sglang/tree/main/sgl-kernel)
- [q8-kernel](https://github.com/KONAKONA666/q8_kernels) (只支持Ada架构的GPU)
### 🤖 支持的模型
#### 🎬 图像到视频模型 (Image-to-Video)
| 模型名称 | 分辨率 | 参数量 | 特点 | 推荐场景 |
|----------|--------|--------|------|----------|
| ✅ [Wan2.1-I2V-14B-480P-Lightx2v](https://huggingface.co/lightx2v/Wan2.1-T2V-14B-StepDistill-CfgDistill) | 480p | 14B | 标准版本 | 平衡速度和质量 |
| ✅ [Wan2.1-I2V-14B-720P-Lightx2v](https://huggingface.co/lightx2v/Wan2.1-T2V-14B-StepDistill-CfgDistill) | 720p | 14B | 高清版本 | 追求高质量输出 |
| ✅ [Wan2.1-I2V-14B-480P-Lightx2v-StepDistill-CfgDistill](https://huggingface.co/lightx2v/Wan2.1-T2V-14B-StepDistill-CfgDistill) | 480p | 14B | 蒸馏优化版 | 更快的推理速度 |
| ✅ [Wan2.1-I2V-14B-720P-Lightx2v-StepDistill-CfgDistill](https://huggingface.co/lightx2v/Wan2.1-T2V-14B-StepDistill-CfgDistill) | 720p | 14B | 高清蒸馏版 | 高质量+快速推理 |
#### 📝 文本到视频模型 (Text-to-Video)
| 模型名称 | 参数量 | 特点 | 推荐场景 |
|----------|--------|------|----------|
| ✅ [Wan2.1-T2V-1.3B-Lightx2v](https://huggingface.co/lightx2v/Wan2.1-T2V-14B-StepDistill-CfgDistill) | 1.3B | 轻量级 | 快速原型测试 |
| ✅ [Wan2.1-T2V-14B-Lightx2v](https://huggingface.co/lightx2v/Wan2.1-T2V-14B-StepDistill-CfgDistill) | 14B | 标准版本 | 平衡速度和质量 |
| ✅ [Wan2.1-T2V-14B-Lightx2v-StepDistill-CfgDistill](https://huggingface.co/lightx2v/Wan2.1-T2V-14B-StepDistill-CfgDistill) | 14B | 蒸馏优化版 | 高质量+快速推理 |
**💡 模型选择建议**:
- **首次使用**: 建议选择蒸馏版本
- **追求质量**: 选择720p分辨率或14B参数模型
- **追求速度**: 选择480p分辨率或1.3B参数模型
- **资源受限**: 优先选择蒸馏版本和较低分辨率
### 启动方式
#### 方式一:使用启动脚本(推荐)
```bash
# 1. 编辑启动脚本,配置相关路径
vim run_gradio.sh
# 需要修改的配置项:
# - lightx2v_path: Lightx2v项目根目录路径
# - i2v_model_path: 图像到视频模型路径
# - t2v_model_path: 文本到视频模型路径
# 💾 重要提示:建议将模型路径指向SSD存储位置
# 例如:/mnt/ssd/models/ 或 /data/ssd/models/
# 2. 运行启动脚本
bash run_gradio.sh
# 3. 或使用参数启动(推荐)
bash run_gradio.sh --task i2v --lang zh --port 8032
# bash run_gradio.sh --task t2v --lang zh --port 8032
```
#### 方式二:直接命令行启动
**图像到视频模式:**
```bash
python gradio_demo_zh.py \
--model_path /path/to/Wan2.1-I2V-14B-720P-Lightx2v \
--task i2v \
--server_name 0.0.0.0 \
--server_port 7862
```
**文本到视频模式:**
```bash
python gradio_demo_zh.py \
--model_path /path/to/Wan2.1-T2V-1.3B \
--task t2v \
--server_name 0.0.0.0 \
--server_port 7862
```
**英文界面版本:**
```bash
python gradio_demo.py \
--model_path /path/to/model \
--task i2v \
--server_name 0.0.0.0 \
--server_port 7862
```
## 📋 命令行参数
| 参数 | 类型 | 必需 | 默认值 | 说明 |
|------|------|------|--------|------|
| `--model_path` | str | ✅ | - | 模型文件夹路径 |
| `--model_cls` | str | ❌ | wan2.1 | 模型类别(目前仅支持wan2.1) |
| `--task` | str | ✅ | - | 任务类型:`i2v`(图像到视频)或 `t2v`(文本到视频) |
| `--server_port` | int | ❌ | 7862 | 服务器端口 |
| `--server_name` | str | ❌ | 0.0.0.0 | 服务器IP地址 |
## 🎯 功能特性
### 基本设置
#### 模型类型选择
- **Wan2.1 14B**: 参数量大,生成质量高,适合高质量视频生成
- **Wan2.1 1.3B**: 轻量级模型,速度快,适合快速原型和测试
#### 输入参数
- **提示词 (Prompt)**: 描述期望的视频内容
- **负向提示词 (Negative Prompt)**: 指定不希望出现的元素
- **分辨率**: 支持多种预设分辨率(480p/540p/720p)
- **随机种子**: 控制生成结果的随机性
- **推理步数**: 影响生成质量和速度的平衡
#### 视频参数
- **FPS**: 每秒帧数
- **总帧数**: 视频长度
- **CFG缩放因子**: 控制提示词影响强度(1-10)
- **分布偏移**: 控制生成风格偏离程度(0-10)
### 高级优化选项
#### GPU内存优化
- **分块旋转位置编码**: 节省GPU内存
- **旋转编码块大小**: 控制分块粒度
- **清理CUDA缓存**: 及时释放GPU内存
#### 异步卸载
- **CPU卸载**: 将部分计算转移到CPU
- **延迟加载**: 按需加载模型组件,显著节省系统内存消耗
- **卸载粒度控制**: 精细控制卸载策略
#### 低精度量化
- **注意力算子**: Flash Attention、Sage Attention等
- **量化算子**: vLLM、SGL、Q8F等
- **精度模式**: FP8、INT8、BF16等
#### VAE优化
- **轻量级VAE**: 加速解码过程
- **VAE分块推理**: 减少内存占用
#### 特征缓存
- **Tea Cache**: 缓存中间特征加速生成
- **缓存阈值**: 控制缓存触发条件
- **关键步缓存**: 仅在关键步骤写入缓存
## 🔧 自动配置功能
启用"自动配置推理选项"后,系统会根据您的硬件配置自动优化参数:
### GPU内存规则
- **80GB+**: 默认配置,无需优化
- **48GB**: 启用CPU卸载,卸载比例50%
- **40GB**: 启用CPU卸载,卸载比例80%
- **32GB**: 启用CPU卸载,卸载比例100%
- **24GB**: 启用BF16精度、VAE分块
- **16GB**: 启用分块卸载、旋转编码分块
- **12GB**: 启用清理缓存、轻量级VAE
- **8GB**: 启用量化、延迟加载
### CPU内存规则
- **128GB+**: 默认配置
- **64GB**: 启用DIT量化
- **32GB**: 启用延迟加载
- **16GB**: 启用全模型量化
## ⚠️ 重要注意事项
### 🚀 低资源设备优化建议
**💡 针对显存不足或性能受限的设备**:
- **🎯 模型选择**: 优先使用蒸馏版本模型 (StepDistill-CfgDistill)
- **⚡ 推理步数**: 建议设置为 4 步
- **🔧 CFG设置**: 建议关闭CFG选项以提升生成速度
- **🔄 自动配置**: 启用"自动配置推理选项"
### 🔧 快速优化配置示例
```bash
# 启动时使用蒸馏模型
bash run_gradio.sh --task i2v
```
界面设置建议:
- 推理步数: 25
- CFG缩放因子: 4
- 分辨率: 832x480
- 自动配置: 开启
- 量化方案: int8
- Tea Cache: 开启
## 📁 文件结构
```
lightx2v/app/
├── gradio_demo.py # 英文界面演示
├── gradio_demo_zh.py # 中文界面演示
├── run_gradio.sh # 启动脚本
├── README.md # 说明文档
├── saved_videos/ # 生成视频保存目录
└── inference_logs.log # 推理日志
```
## 🎨 界面说明
### 基本设置标签页
- **输入参数**: 模型类型、提示词、分辨率等基本设置
- **视频参数**: FPS、帧数、CFG等视频生成参数
- **输出设置**: 视频保存路径配置
### 高级选项标签页
- **GPU内存优化**: 内存管理相关选项
- **异步卸载**: CPU卸载和延迟加载
- **低精度量化**: 各种量化优化选项
- **VAE优化**: 变分自编码器优化
- **特征缓存**: 缓存策略配置
## 🔍 故障排除
### 常见问题
**💡 提示**: 一般情况下,启用"自动配置推理选项"后,系统会根据您的硬件配置自动优化参数设置,通常不会出现性能问题。如果遇到问题,请参考以下解决方案:
1. **CUDA内存不足**
- 启用CPU卸载
- 降低分辨率
- 启用量化选项
2. **系统内存不足**
- 启用CPU卸载
- 启用延迟加载选项
- 启用量化选项
3. **生成速度慢**
- 减少推理步数
- 启用自动配置
- 使用轻量级模型
- 启用Tea Cache
- 使用量化算子
- 💾 **检查模型是否存放在SSD上**
4. **模型加载缓慢**
- 💾 **将模型迁移到SSD存储**
- 启用延迟加载选项
- 检查磁盘I/O性能
- 考虑使用NVMe SSD
5. **视频质量不佳**
- 增加推理步数
- 提高CFG缩放因子
- 使用14B模型
- 优化提示词
### 日志查看
```bash
# 查看推理日志
tail -f inference_logs.log
# 查看GPU使用情况
nvidia-smi
# 查看系统资源
htop
```
欢迎提交Issue和Pull Request来改进这个项目!
**注意**: 使用本工具生成的视频内容请遵守相关法律法规,不得用于非法用途。
......@@ -148,10 +148,8 @@ for op_name, is_installed in available_attn_ops:
def run_inference(
model_type,
task,
prompt,
negative_prompt,
image_path,
save_video_path,
torch_compile,
infer_steps,
......@@ -181,22 +179,18 @@ def run_inference(
rotary_chunk,
rotary_chunk_size,
clean_cuda_cache,
image_path=None,
):
quant_op = quant_op.split("(")[0].strip()
attention_type = attention_type.split("(")[0].strip()
global global_runner, current_config, model_path
global global_runner, current_config, model_path, task
global cur_dit_quant_scheme, cur_clip_quant_scheme, cur_t5_quant_scheme, cur_precision_mode, cur_enable_teacache
if os.path.exists(os.path.join(model_path, "config.json")):
with open(os.path.join(model_path, "config.json"), "r") as f:
model_config = json.load(f)
if task == "Image to Video":
task = "i2v"
elif task == "Text to Video":
task = "t2v"
if task == "t2v":
if model_type == "Wan2.1 1.3B":
# 1.3B
......@@ -551,6 +545,7 @@ def auto_configure(enable_auto_config, model_type, resolution):
"rotary_chunk_val": True,
"rotary_chunk_size_val": 100,
"clean_cuda_cache_val": True,
"use_tiny_vae_val": True,
},
),
(
......@@ -569,6 +564,7 @@ def auto_configure(enable_auto_config, model_type, resolution):
"clip_quant_scheme_val": quant_type,
"dit_quant_scheme_val": quant_type,
"lazy_load_val": True,
"use_tiny_vae_val": True,
},
),
]
......@@ -606,6 +602,7 @@ def auto_configure(enable_auto_config, model_type, resolution):
"lazy_load_val": True,
"rotary_chunk_val": True,
"rotary_chunk_size_val": 10000,
"use_tiny_vae_val": True,
}
if res == "540p"
else {
......@@ -619,11 +616,15 @@ def auto_configure(enable_auto_config, model_type, resolution):
"clip_quant_scheme_val": quant_type,
"dit_quant_scheme_val": quant_type,
"lazy_load_val": True,
"use_tiny_vae_val": True,
}
),
),
]
else:
gpu_rules = {}
if is_14b:
cpu_rules = [
(128, {}),
......@@ -639,6 +640,8 @@ def auto_configure(enable_auto_config, model_type, resolution):
},
),
]
else:
cpu_rules = {}
for threshold, updates in gpu_rules:
if gpu_memory >= threshold:
......@@ -654,12 +657,6 @@ def auto_configure(enable_auto_config, model_type, resolution):
def main():
def update_model_type(task_type):
if task_type == "Image to Video":
return gr.update(choices=["Wan2.1 14B"], value="Wan2.1 14B")
elif task_type == "Text to Video":
return gr.update(choices=["Wan2.1 14B", "Wan2.1 1.3B"], value="Wan2.1 14B")
def toggle_image_input(task):
return gr.update(visible=(task == "Image to Video"))
......@@ -684,36 +681,28 @@ def main():
gr.Markdown("## 📥 Input Parameters")
with gr.Row():
task = gr.Dropdown(
choices=["Image to Video", "Text to Video"],
value="Image to Video",
label="Task Type",
)
model_type = gr.Dropdown(
choices=["Wan2.1 14B"],
value="Wan2.1 14B",
label="Model Type",
)
task.change(
fn=update_model_type,
inputs=task,
outputs=model_type,
)
if task == "i2v":
model_type = gr.Dropdown(
choices=["Wan2.1 14B"],
value="Wan2.1 14B",
label="Model Type",
)
else:
model_type = gr.Dropdown(
choices=["Wan2.1 14B", "Wan2.1 1.3B"],
value="Wan2.1 14B",
label="Model Type",
)
with gr.Row():
image_path = gr.Image(
label="Input Image",
type="filepath",
height=300,
interactive=True,
visible=True, # Initially visible
)
task.change(
fn=toggle_image_input,
inputs=task,
outputs=image_path,
)
if task == "i2v":
with gr.Row():
image_path = gr.Image(
label="Input Image",
type="filepath",
height=300,
interactive=True,
visible=True,
)
with gr.Row():
with gr.Column():
......@@ -755,6 +744,13 @@ def main():
value="832x480",
label="Maximum Resolution",
)
with gr.Column():
enable_auto_config = gr.Checkbox(
label="Auto-configure Inference Options",
value=False,
info="Automatically optimize GPU settings to match the current resolution. After changing the resolution, please re-check this option to prevent potential performance degradation or runtime errors.",
)
with gr.Column(scale=9):
seed = gr.Slider(
label="Random Seed",
......@@ -836,14 +832,6 @@ def main():
with gr.Tab("⚙️ Advanced Options", id=2):
with gr.Group(elem_classes="advanced-options"):
gr.Markdown("### Auto configuration")
with gr.Row():
enable_auto_config = gr.Checkbox(
label="Auto configuration",
value=False,
info="Auto-tune optimization settings for your GPU",
)
gr.Markdown("### GPU Memory Optimization")
with gr.Row():
rotary_chunk = gr.Checkbox(
......@@ -1007,47 +995,85 @@ def main():
use_ret_steps,
],
)
infer_btn.click(
fn=run_inference,
inputs=[
model_type,
task,
prompt,
negative_prompt,
image_path,
save_video_path,
torch_compile,
infer_steps,
num_frames,
resolution,
seed,
sample_shift,
enable_teacache,
teacache_thresh,
use_ret_steps,
enable_cfg,
cfg_scale,
dit_quant_scheme,
t5_quant_scheme,
clip_quant_scheme,
fps,
use_tiny_vae,
use_tiling_vae,
lazy_load,
precision_mode,
cpu_offload,
offload_granularity,
offload_ratio,
t5_offload_granularity,
attention_type,
quant_op,
rotary_chunk,
rotary_chunk_size,
clean_cuda_cache,
],
outputs=output_video,
)
if task == "i2v":
infer_btn.click(
fn=run_inference,
inputs=[
model_type,
prompt,
negative_prompt,
save_video_path,
torch_compile,
infer_steps,
num_frames,
resolution,
seed,
sample_shift,
enable_teacache,
teacache_thresh,
use_ret_steps,
enable_cfg,
cfg_scale,
dit_quant_scheme,
t5_quant_scheme,
clip_quant_scheme,
fps,
use_tiny_vae,
use_tiling_vae,
lazy_load,
precision_mode,
cpu_offload,
offload_granularity,
offload_ratio,
t5_offload_granularity,
attention_type,
quant_op,
rotary_chunk,
rotary_chunk_size,
clean_cuda_cache,
image_path,
],
outputs=output_video,
)
else:
infer_btn.click(
fn=run_inference,
inputs=[
model_type,
prompt,
negative_prompt,
save_video_path,
torch_compile,
infer_steps,
num_frames,
resolution,
seed,
sample_shift,
enable_teacache,
teacache_thresh,
use_ret_steps,
enable_cfg,
cfg_scale,
dit_quant_scheme,
t5_quant_scheme,
clip_quant_scheme,
fps,
use_tiny_vae,
use_tiling_vae,
lazy_load,
precision_mode,
cpu_offload,
offload_granularity,
offload_ratio,
t5_offload_granularity,
attention_type,
quant_op,
rotary_chunk,
rotary_chunk_size,
clean_cuda_cache,
],
outputs=output_video,
)
demo.launch(share=True, server_port=args.server_port, server_name=args.server_name)
......@@ -1062,6 +1088,7 @@ if __name__ == "__main__":
default="wan2.1",
help="Model class to use",
)
parser.add_argument("--task", type=str, required=True, choices=["i2v", "t2v"], help="Specify the task type. 'i2v' for image-to-video translation, 't2v' for text-to-video generation.")
parser.add_argument("--server_port", type=int, default=7862, help="Server port")
parser.add_argument("--server_name", type=str, default="0.0.0.0", help="Server ip")
args = parser.parse_args()
......@@ -1069,5 +1096,6 @@ if __name__ == "__main__":
global model_path, model_cls
model_path = args.model_path
model_cls = args.model_cls
task = args.task
main()
......@@ -13,7 +13,6 @@ import importlib.util
import psutil
import random
logger.add(
"inference_logs.log",
rotation="100 MB",
......@@ -98,7 +97,7 @@ def get_gpu_memory(gpu_idx=0):
try:
with torch.cuda.device(gpu_idx):
memory_info = torch.cuda.mem_get_info()
total_memory = memory_info[1] / (1024**3)
total_memory = memory_info[1] / (1024**3) # Convert bytes to GB
return total_memory
except Exception as e:
logger.warning(f"获取GPU内存失败: {e}")
......@@ -149,10 +148,8 @@ for op_name, is_installed in available_attn_ops:
def run_inference(
model_type,
task,
prompt,
negative_prompt,
image_path,
save_video_path,
torch_compile,
infer_steps,
......@@ -182,22 +179,18 @@ def run_inference(
rotary_chunk,
rotary_chunk_size,
clean_cuda_cache,
image_path=None,
):
quant_op = quant_op.split("(")[0].strip()
attention_type = attention_type.split("(")[0].strip()
global global_runner, current_config, model_path
global global_runner, current_config, model_path, task
global cur_dit_quant_scheme, cur_clip_quant_scheme, cur_t5_quant_scheme, cur_precision_mode, cur_enable_teacache
if os.path.exists(os.path.join(model_path, "config.json")):
with open(os.path.join(model_path, "config.json"), "r") as f:
model_config = json.load(f)
if task == "图像生成视频":
task = "i2v"
elif task == "文本生成视频":
task = "t2v"
if task == "t2v":
if model_type == "Wan2.1 1.3B":
# 1.3B
......@@ -407,6 +400,7 @@ def run_inference(
logger.info(f"使用模型: {model_path}")
logger.info(f"推理配置:\n{json.dumps(config, indent=4, ensure_ascii=False)}")
# Initialize or reuse the runner
runner = global_runner
if needs_reinit:
if runner is not None:
......@@ -551,6 +545,7 @@ def auto_configure(enable_auto_config, model_type, resolution):
"rotary_chunk_val": True,
"rotary_chunk_size_val": 100,
"clean_cuda_cache_val": True,
"use_tiny_vae_val": True,
},
),
(
......@@ -569,6 +564,7 @@ def auto_configure(enable_auto_config, model_type, resolution):
"clip_quant_scheme_val": quant_type,
"dit_quant_scheme_val": quant_type,
"lazy_load_val": True,
"use_tiny_vae_val": True,
},
),
]
......@@ -606,6 +602,7 @@ def auto_configure(enable_auto_config, model_type, resolution):
"lazy_load_val": True,
"rotary_chunk_val": True,
"rotary_chunk_size_val": 10000,
"use_tiny_vae_val": True,
}
if res == "540p"
else {
......@@ -619,11 +616,15 @@ def auto_configure(enable_auto_config, model_type, resolution):
"clip_quant_scheme_val": quant_type,
"dit_quant_scheme_val": quant_type,
"lazy_load_val": True,
"use_tiny_vae_val": True,
}
),
),
]
else:
gpu_rules = {}
if is_14b:
cpu_rules = [
(128, {}),
......@@ -639,6 +640,8 @@ def auto_configure(enable_auto_config, model_type, resolution):
},
),
]
else:
cpu_rules = {}
for threshold, updates in gpu_rules:
if gpu_memory >= threshold:
......@@ -654,17 +657,11 @@ def auto_configure(enable_auto_config, model_type, resolution):
def main():
def update_model_type(task_type):
if task_type == "图像生成视频":
return gr.update(choices=["Wan2.1 14B"], value="Wan2.1 14B")
elif task_type == "文本生成视频":
return gr.update(choices=["Wan2.1 14B", "Wan2.1 1.3B"], value="Wan2.1 14B")
def toggle_image_input(task):
return gr.update(visible=(task == "图像生成视频"))
return gr.update(visible=(task == "i2v"))
with gr.Blocks(
title="Lightx2v (轻量级视频生成推理引擎)",
title="Lightx2v (轻量级视频推理和生成引擎)",
css="""
.main-content { max-width: 1400px; margin: auto; }
.output-video { max-height: 650px; }
......@@ -684,36 +681,28 @@ def main():
gr.Markdown("## 📥 输入参数")
with gr.Row():
task = gr.Dropdown(
choices=["图像生成视频", "文本生成视频"],
value="图像生成视频",
label="任务类型",
)
model_type = gr.Dropdown(
choices=["Wan2.1 14B"],
value="Wan2.1 14B",
label="模型类型",
)
task.change(
fn=update_model_type,
inputs=task,
outputs=model_type,
)
if task == "i2v":
model_type = gr.Dropdown(
choices=["Wan2.1 14B"],
value="Wan2.1 14B",
label="模型类型",
)
else:
model_type = gr.Dropdown(
choices=["Wan2.1 14B", "Wan2.1 1.3B"],
value="Wan2.1 14B",
label="模型类型",
)
with gr.Row():
image_path = gr.Image(
label="输入图像",
type="filepath",
height=300,
interactive=True,
visible=True,
)
task.change(
fn=toggle_image_input,
inputs=task,
outputs=image_path,
)
if task == "i2v":
with gr.Row():
image_path = gr.Image(
label="输入图像",
type="filepath",
height=300,
interactive=True,
visible=True,
)
with gr.Row():
with gr.Column():
......@@ -755,6 +744,11 @@ def main():
value="832x480",
label="最大分辨率",
)
with gr.Column():
enable_auto_config = gr.Checkbox(
label="自动配置推理选项", value=False, info="自动优化GPU设置以匹配当前分辨率。修改分辨率后,请重新勾选此选项,否则可能导致性能下降或运行失败。"
)
with gr.Column(scale=9):
seed = gr.Slider(
label="随机种子",
......@@ -764,9 +758,10 @@ def main():
value=generate_random_seed(),
)
with gr.Column(scale=1):
randomize_btn = gr.Button("🎲 生成随机种子", variant="secondary")
randomize_btn = gr.Button("🎲 随机化", variant="secondary")
randomize_btn.click(fn=generate_random_seed, inputs=None, outputs=seed)
with gr.Column():
infer_steps = gr.Slider(
label="推理步数",
......@@ -774,7 +769,7 @@ def main():
maximum=100,
step=1,
value=40,
info="视频生成的推理步数。增加步数可能提高质量但降低速度",
info="视频生成的推理步数。增加步数可能提高质量但降低速度",
)
enable_cfg = gr.Checkbox(
......@@ -788,7 +783,7 @@ def main():
maximum=10,
step=1,
value=5,
info="控制提示词的影响强度。值越高,提示词的影响越大",
info="控制提示词的影响强度。值越高,提示词的影响越大",
)
sample_shift = gr.Slider(
label="分布偏移",
......@@ -796,7 +791,7 @@ def main():
minimum=0,
maximum=10,
step=1,
info="控制样本分布偏移的程度。值越大表示偏移越明显",
info="控制样本分布偏移的程度。值越大表示偏移越明显",
)
fps = gr.Slider(
......@@ -805,7 +800,7 @@ def main():
maximum=30,
step=1,
value=16,
info="视频的每秒帧数。较高的FPS会产生更流畅的视频",
info="视频的每秒帧数。较高的FPS会产生更流畅的视频",
)
num_frames = gr.Slider(
label="总帧数",
......@@ -813,7 +808,7 @@ def main():
maximum=120,
step=1,
value=81,
info="视频中的总帧数。更多帧数会产生更长的视频",
info="视频中的总帧数。更多帧数会产生更长的视频",
)
save_video_path = gr.Textbox(
......@@ -835,14 +830,6 @@ def main():
with gr.Tab("⚙️ 高级选项", id=2):
with gr.Group(elem_classes="advanced-options"):
gr.Markdown("### 自动配置")
with gr.Row():
enable_auto_config = gr.Checkbox(
label="自动配置",
value=False,
info="自动调整优化设置以适应您的GPU",
)
gr.Markdown("### GPU内存优化")
with gr.Row():
rotary_chunk = gr.Checkbox(
......@@ -857,13 +844,13 @@ def main():
minimum=100,
maximum=10000,
step=100,
info="控制应用旋转编码的块大小, 较大的值可能提高性能但增加内存使用, 仅在'rotary_chunk'勾选时有效",
info="控制应用旋转编码的块大小较大的值可能提高性能但增加内存使用仅在'rotary_chunk'勾选时有效",
)
clean_cuda_cache = gr.Checkbox(
label="清理CUDA内存缓存",
value=False,
info="及时释放GPU内存, 但会减慢推理速度。",
info="启用时,及时释放GPU内存但会减慢推理速度。",
)
gr.Markdown("### 异步卸载")
......@@ -877,14 +864,14 @@ def main():
lazy_load = gr.Checkbox(
label="启用延迟加载",
value=False,
info="在推理过程中延迟加载模型组件, 仅在'cpu_offload'勾选和使用量化Dit模型时有效",
info="在推理过程中延迟加载模型组件。需要CPU加载和DIT量化。",
)
offload_granularity = gr.Dropdown(
label="Dit卸载粒度",
choices=["block", "phase"],
value="phase",
info="设置Dit模型卸载粒度: 块或计算阶段",
info="设置Dit模型卸载粒度块或计算阶段",
)
offload_ratio = gr.Slider(
label="Dit模型卸载比例",
......@@ -926,25 +913,25 @@ def main():
label="Dit",
choices=["fp8", "int8", "bf16"],
value="bf16",
info="Dit模型的推理精度",
info="Dit模型的量化精度",
)
t5_quant_scheme = gr.Dropdown(
label="T5编码器",
choices=["fp8", "int8", "bf16"],
value="bf16",
info="T5编码器模型的推理精度",
info="T5编码器模型的量化精度",
)
clip_quant_scheme = gr.Dropdown(
label="Clip编码器",
choices=["fp8", "int8", "fp16"],
value="fp16",
info="Clip编码器的推理精度",
info="Clip编码器的量化精度",
)
precision_mode = gr.Dropdown(
label="敏感层精度",
label="敏感层精度模式",
choices=["fp32", "bf16"],
value="fp32",
info="选择用于敏感层(如norm层和embedding层)的数值精度",
info="选择用于关键模型组件(如归一化和嵌入层)的数值精度。FP32提供更高精度,而BF16在兼容硬件上提高性能。",
)
gr.Markdown("### 变分自编码器(VAE)")
......@@ -1006,47 +993,85 @@ def main():
use_ret_steps,
],
)
infer_btn.click(
fn=run_inference,
inputs=[
model_type,
task,
prompt,
negative_prompt,
image_path,
save_video_path,
torch_compile,
infer_steps,
num_frames,
resolution,
seed,
sample_shift,
enable_teacache,
teacache_thresh,
use_ret_steps,
enable_cfg,
cfg_scale,
dit_quant_scheme,
t5_quant_scheme,
clip_quant_scheme,
fps,
use_tiny_vae,
use_tiling_vae,
lazy_load,
precision_mode,
cpu_offload,
offload_granularity,
offload_ratio,
t5_offload_granularity,
attention_type,
quant_op,
rotary_chunk,
rotary_chunk_size,
clean_cuda_cache,
],
outputs=output_video,
)
if task == "i2v":
infer_btn.click(
fn=run_inference,
inputs=[
model_type,
prompt,
negative_prompt,
save_video_path,
torch_compile,
infer_steps,
num_frames,
resolution,
seed,
sample_shift,
enable_teacache,
teacache_thresh,
use_ret_steps,
enable_cfg,
cfg_scale,
dit_quant_scheme,
t5_quant_scheme,
clip_quant_scheme,
fps,
use_tiny_vae,
use_tiling_vae,
lazy_load,
precision_mode,
cpu_offload,
offload_granularity,
offload_ratio,
t5_offload_granularity,
attention_type,
quant_op,
rotary_chunk,
rotary_chunk_size,
clean_cuda_cache,
image_path,
],
outputs=output_video,
)
else:
infer_btn.click(
fn=run_inference,
inputs=[
model_type,
prompt,
negative_prompt,
save_video_path,
torch_compile,
infer_steps,
num_frames,
resolution,
seed,
sample_shift,
enable_teacache,
teacache_thresh,
use_ret_steps,
enable_cfg,
cfg_scale,
dit_quant_scheme,
t5_quant_scheme,
clip_quant_scheme,
fps,
use_tiny_vae,
use_tiling_vae,
lazy_load,
precision_mode,
cpu_offload,
offload_granularity,
offload_ratio,
t5_offload_granularity,
attention_type,
quant_op,
rotary_chunk,
rotary_chunk_size,
clean_cuda_cache,
],
outputs=output_video,
)
demo.launch(share=True, server_port=args.server_port, server_name=args.server_name)
......@@ -1061,6 +1086,7 @@ if __name__ == "__main__":
default="wan2.1",
help="要使用的模型类别",
)
parser.add_argument("--task", type=str, required=True, choices=["i2v", "t2v"], help="指定任务类型。'i2v'用于图像到视频转换,'t2v'用于文本到视频生成。")
parser.add_argument("--server_port", type=int, default=7862, help="服务器端口")
parser.add_argument("--server_name", type=str, default="0.0.0.0", help="服务器IP")
args = parser.parse_args()
......@@ -1068,5 +1094,6 @@ if __name__ == "__main__":
global model_path, model_cls
model_path = args.model_path
model_cls = args.model_cls
task = args.task
main()
#!/bin/bash
lightx2v_path=/mtc/gushiqiao/llmc_workspace/lightx2v_new/lightx2v
model_path=/data/nvme0/gushiqiao/models/I2V/Wan2.1-I2V-14B-720P-Lightx2v-Step-Distill
# Lightx2v Gradio Demo Startup Script
# Supports both Image-to-Video (i2v) and Text-to-Video (t2v) modes
export CUDA_VISIBLE_DEVICES=7
# ==================== Configuration Area ====================
# ⚠️ Important: Please modify the following paths according to your actual environment
# 🚨 Storage Performance Tips 🚨
# 💾 Strongly recommend storing model files on SSD solid-state drives!
# 📈 SSD can significantly improve model loading speed and inference performance
# 🐌 Using mechanical hard drives (HDD) may cause slow model loading and affect overall experience
# Lightx2v project root directory path
# Example: /home/user/lightx2v or /data/video_gen/lightx2v
lightx2v_path=/path/to/lightx2v
# Model path configuration
# Image-to-video model path (for i2v tasks)
# Example: /path/to/Wan2.1-I2V-14B-720P-Lightx2v
i2v_model_path=/path/to/Wan2.1-I2V-14B-720P-Lightx2v
# Text-to-video model path (for t2v tasks)
# Example: /path/to/Wan2.1-T2V-1.3B
t2v_model_path=/path/to/Wan2.1-T2V-1.3B
# Server configuration
server_name="0.0.0.0"
server_port=8032
# GPU configuration
gpu_id=0
# ==================== Environment Variables Setup ====================
export CUDA_VISIBLE_DEVICES=$gpu_id
export CUDA_LAUNCH_BLOCKING=1
export PYTHONPATH=${lightx2v_path}:$PYTHONPATH
export ENABLE_PROFILING_DEBUG=true
export PYTORCH_CUDA_ALLOC_CONF=expandable_segments:True
python gradio_demo.py \
--model_path $model_path \
--server_name 0.0.0.0 \
--server_port 8005
# ==================== Parameter Parsing ====================
# Default task type
task="i2v"
# Default interface language
lang="zh"
# Parse command line arguments
while [[ $# -gt 0 ]]; do
case $1 in
--task)
task="$2"
shift 2
;;
--lang)
lang="$2"
shift 2
;;
--port)
server_port="$2"
shift 2
;;
--gpu)
gpu_id="$2"
export CUDA_VISIBLE_DEVICES=$gpu_id
shift 2
;;
--help)
echo "🎬 Lightx2v Gradio Demo Startup Script"
echo "=========================================="
echo "Usage: $0 [options]"
echo ""
echo "📋 Available options:"
echo " --task i2v|t2v Task type (default: i2v)"
echo " i2v: Image-to-video generation"
echo " t2v: Text-to-video generation"
echo " --lang zh|en Interface language (default: zh)"
echo " zh: Chinese interface"
echo " en: English interface"
echo " --port PORT Server port (default: 8032)"
echo " --gpu GPU_ID GPU device ID (default: 0)"
echo " --help Show this help message"
echo ""
echo "🚀 Usage examples:"
echo " $0 # Default startup for image-to-video mode"
echo " $0 --task i2v --lang zh --port 8032 # Start with specified parameters"
echo " $0 --task t2v --lang en --port 7860 # Text-to-video with English interface"
echo " $0 --task i2v --gpu 1 --port 8032 # Use GPU 1"
echo ""
echo "📝 Notes:"
echo " - Edit script to configure model paths before first use"
echo " - Ensure required Python dependencies are installed"
echo " - Recommended to use GPU with 8GB+ VRAM"
echo " - 🚨 Strongly recommend storing models on SSD for better performance"
exit 0
;;
*)
echo "Unknown parameter: $1"
echo "Use --help to see help information"
exit 1
;;
esac
done
# ==================== Parameter Validation ====================
if [[ "$task" != "i2v" && "$task" != "t2v" ]]; then
echo "Error: Task type must be 'i2v' or 't2v'"
exit 1
fi
if [[ "$lang" != "zh" && "$lang" != "en" ]]; then
echo "Error: Language must be 'zh' or 'en'"
exit 1
fi
# Select model path based on task type
if [[ "$task" == "i2v" ]]; then
model_path=$i2v_model_path
echo "🎬 Starting Image-to-Video mode"
else
model_path=$t2v_model_path
echo "🎬 Starting Text-to-Video mode"
fi
# Check if model path exists
if [[ ! -d "$model_path" ]]; then
echo "❌ Error: Model path does not exist"
echo "📁 Path: $model_path"
echo "🔧 Solutions:"
echo " 1. Check model path configuration in script"
echo " 2. Ensure model files are properly downloaded"
echo " 3. Verify path permissions are correct"
echo " 4. 💾 Recommend storing models on SSD for faster loading"
exit 1
fi
# Select demo file based on language
if [[ "$lang" == "zh" ]]; then
demo_file="gradio_demo_zh.py"
echo "🌏 Using Chinese interface"
else
demo_file="gradio_demo.py"
echo "🌏 Using English interface"
fi
# Check if demo file exists
if [[ ! -f "$demo_file" ]]; then
echo "❌ Error: Demo file does not exist"
echo "📄 File: $demo_file"
echo "🔧 Solutions:"
echo " 1. Ensure script is run in the correct directory"
echo " 2. Check if file has been renamed or moved"
echo " 3. Re-clone or download project files"
exit 1
fi
# ==================== System Information Display ====================
echo "=========================================="
echo "🚀 Lightx2v Gradio Demo Starting..."
echo "=========================================="
echo "📁 Project path: $lightx2v_path"
echo "🤖 Model path: $model_path"
echo "🎯 Task type: $task"
echo "🌏 Interface language: $lang"
echo "🖥️ GPU device: $gpu_id"
echo "🌐 Server address: $server_name:$server_port"
echo "=========================================="
# Display system resource information
echo "💻 System resource information:"
free -h | grep -E "Mem|Swap"
echo ""
# Display GPU information
if command -v nvidia-smi &> /dev/null; then
echo "🎮 GPU information:"
nvidia-smi --query-gpu=name,memory.total,memory.free --format=csv,noheader,nounits | head -1
echo ""
fi
# ==================== Start Demo ====================
echo "🎬 Starting Gradio demo..."
echo "📱 Please access in browser: http://$server_name:$server_port"
echo "⏹️ Press Ctrl+C to stop service"
echo "🔄 First startup may take several minutes to load model..."
echo "=========================================="
# Start Python demo
python $demo_file \
--model_path "$model_path" \
--task "$task" \
--server_name "$server_name" \
--server_port "$server_port"
# python gradio_demo_zh.py \
# --model_path $model_path \
# --server_name 0.0.0.0 \
# --server_port 8005
# Display final system resource usage
echo ""
echo "=========================================="
echo "📊 Final system resource usage:"
free -h | grep -E "Mem|Swap"
......@@ -121,8 +121,9 @@ class LazyWeightAsyncStreamManager(WeightAsyncStreamManager):
except Exception as e:
logger.error(f"Disk worker thread error: {e}")
def _async_prefetch_block(self, weights):
next_block_idx = self.pin_memory_buffer.get_max_block_index()
def _async_prefetch_block(self, blocks, next_block_idx=None):
if next_block_idx is None:
next_block_idx = self.pin_memory_buffer.get_max_block_index()
if next_block_idx < 0:
next_block_idx = 0
......@@ -137,7 +138,7 @@ class LazyWeightAsyncStreamManager(WeightAsyncStreamManager):
with self.task_lock:
self.pending_tasks[obj_key] = True
phase = weights.blocks[next_block_idx].compute_phases[phase_idx]
phase = blocks[next_block_idx].compute_phases[phase_idx]
priority_key = (next_block_idx, phase_idx)
self.disk_task_queue.put((priority_key, (next_block_idx, phase_idx, phase)))
......@@ -149,20 +150,20 @@ class LazyWeightAsyncStreamManager(WeightAsyncStreamManager):
with self.task_lock:
self.pending_tasks[obj_key] = True
block = weights.blocks[next_block_idx]
block = blocks[next_block_idx]
self.disk_task_queue.put((obj_key, (next_block_idx, block)))
def _sync_prefetch_block(self, weights):
def _sync_prefetch_block(self, blocks):
block_idx = 0
while not self.pin_memory_buffer.is_nearly_full():
if self.offload_gra == "phase":
for phase_idx in range(self.phases_num):
phase = weights.blocks[block_idx].compute_phases[phase_idx]
phase = blocks[block_idx].compute_phases[phase_idx]
logger.info(f"Synchronous loading: block={block_idx}, phase={phase_idx}")
phase.load_from_disk()
self.pin_memory_buffer.push((block_idx, phase_idx), phase)
else:
block = weights.blocks[block_idx]
block = blocks[block_idx]
logger.info(f"Synchronous loading: block={block_idx}")
for phase in block.compute_phases:
phase.load_from_disk()
......@@ -170,11 +171,11 @@ class LazyWeightAsyncStreamManager(WeightAsyncStreamManager):
block_idx += 1
def prefetch_weights_from_disk(self, weights):
def prefetch_weights_from_disk(self, blocks):
if self.initial_prefetch_done:
return
self._sync_prefetch_block(weights)
self._sync_prefetch_block(blocks)
self.initial_prefetch_done = True
def prefetch_weights(self, block_idx, blocks):
......@@ -193,7 +194,15 @@ class LazyWeightAsyncStreamManager(WeightAsyncStreamManager):
if time.time() - start_time > 5:
raise TimeoutError(f"Load timeout: block={block_idx}")
else:
logger.info("Not find prefetch block={block_idx} task. This is a bug.")
logger.info("Not find prefetch block={block_idx} task.")
logger.info("Sync prefetch block={block_idx}.")
self._async_prefetch_block(blocks, block_idx)
start_time = time.time()
for phase_idx in range(self.phases_num):
while not self.pin_memory_buffer.exists((block_idx, phase_idx)):
time.sleep(0.001)
if time.time() - start_time > 15:
raise TimeoutError(f"Load timeout: block={block_idx}, phase={phase_idx}")
with torch.cuda.stream(self.cuda_load_stream):
block = self.pin_memory_buffer.get(obj_key)
......@@ -224,7 +233,14 @@ class LazyWeightAsyncStreamManager(WeightAsyncStreamManager):
if time.time() - start_time > 5:
raise TimeoutError(f"Load timeout: block={block_idx}, phase={phase_idx}")
else:
logger.info("Not find prefetch block={block_idx}, phase={phase_idx} task. This is a bug.")
logger.info(f"Not find block={block_idx}, phase={phase_idx} task.")
logger.info(f"Sync prefetch block={block_idx}, phase={phase_idx}.")
self._async_prefetch_block(blocks, block_idx)
start_time = time.time()
while not self.pin_memory_buffer.exists((block_idx, phase_idx)):
time.sleep(0.001)
if time.time() - start_time > 5:
raise TimeoutError(f"Load timeout: block={block_idx}, phase={phase_idx}")
with torch.cuda.stream(self.cuda_load_stream):
phase = self.pin_memory_buffer.get(obj_key)
......
......@@ -2,14 +2,9 @@ import torch
import torch.nn as nn
from vllm import _custom_ops as ops
try:
import q8_kernels.functional as Q8F
except ImportError:
Q8F = None
class QuantLinearInt8(nn.Module):
def __init__(self, in_features, out_features, bias=True):
def __init__(self, in_features, out_features, bias=True, dtype=torch.bfloat16):
super().__init__()
self.in_features = in_features
self.out_features = out_features
......@@ -18,7 +13,7 @@ class QuantLinearInt8(nn.Module):
self.register_buffer("weight_scale", torch.empty((out_features, 1), dtype=torch.float32))
if bias:
self.register_buffer("bias", torch.empty(out_features, dtype=torch.float32))
self.register_buffer("bias", torch.empty(out_features, dtype=dtype))
else:
self.register_buffer("bias", None)
......@@ -44,18 +39,31 @@ class QuantLinearInt8(nn.Module):
)
return output_tensor.unsqueeze(0)
def _apply(self, fn):
for module in self.children():
module._apply(fn)
def maybe_cast(t):
if t is not None and t.device != fn(t).device:
return fn(t)
return t
self.weight = maybe_cast(self.weight)
self.weight_scale = maybe_cast(self.weight_scale)
self.bias = maybe_cast(self.bias)
return self
class QuantLinearFp8(nn.Module):
def __init__(self, in_features, out_features, bias=True):
def __init__(self, in_features, out_features, bias=True, dtype=torch.bfloat16):
super().__init__()
self.in_features = in_features
self.out_features = out_features
self.register_buffer("weight", torch.empty((out_features, in_features), dtype=torch.float8_e4m3fn))
self.register_buffer("weight_scale", torch.empty((out_features, 1), dtype=torch.float32))
if bias:
self.register_buffer("bias", torch.empty(out_features, dtype=torch.float32))
self.register_buffer("bias", torch.empty(out_features, dtype=dtype))
else:
self.register_buffer("bias", None)
......@@ -65,7 +73,6 @@ class QuantLinearFp8(nn.Module):
def forward(self, input_tensor):
input_tensor = input_tensor.squeeze(0)
self.weight = self.weight.to(torch.float8_e4m3fn)
shape = (input_tensor.shape[0], self.weight.shape[0])
dtype = input_tensor.dtype
device = input_tensor.device
......@@ -79,4 +86,19 @@ class QuantLinearFp8(nn.Module):
self.weight_scale.float(),
self.bias,
)
return output_tensor.unsqueeze(0)
def _apply(self, fn):
for module in self.children():
module._apply(fn)
def maybe_cast(t):
if t is not None and t.device != fn(t).device:
return fn(t)
return t
self.weight = maybe_cast(self.weight)
self.weight_scale = maybe_cast(self.weight_scale)
self.bias = maybe_cast(self.bias)
return self
......@@ -51,11 +51,11 @@ class GELU(nn.Module):
class T5LayerNorm(nn.Module):
def __init__(self, dim, eps=1e-6):
def __init__(self, dim, eps=1e-6, dtype=torch.float16):
super(T5LayerNorm, self).__init__()
self.dim = dim
self.eps = eps
self.weight = nn.Parameter(torch.ones(dim))
self.weight = nn.Parameter(torch.ones(dim, dtype=dtype))
def forward(self, x):
x = x * torch.rsqrt(x.float().pow(2).mean(dim=-1, keepdim=True) + self.eps)
......@@ -65,7 +65,7 @@ class T5LayerNorm(nn.Module):
class T5Attention(nn.Module):
def __init__(self, dim, dim_attn, num_heads, dropout=0.1, quantized=False, quant_scheme=None):
def __init__(self, dim, dim_attn, num_heads, dropout=0.1, quantized=False, quant_scheme=None, dtype=torch.bfloat16):
assert dim_attn % num_heads == 0
super(T5Attention, self).__init__()
self.dim = dim
......@@ -82,10 +82,10 @@ class T5Attention(nn.Module):
linear_cls = nn.Linear
# layers
self.q = linear_cls(dim, dim_attn, bias=False)
self.k = linear_cls(dim, dim_attn, bias=False)
self.v = linear_cls(dim, dim_attn, bias=False)
self.o = linear_cls(dim_attn, dim, bias=False)
self.q = linear_cls(dim, dim_attn, bias=False, dtype=dtype)
self.k = linear_cls(dim, dim_attn, bias=False, dtype=dtype)
self.v = linear_cls(dim, dim_attn, bias=False, dtype=dtype)
self.o = linear_cls(dim_attn, dim, bias=False, dtype=dtype)
self.dropout = nn.Dropout(dropout)
def forward(self, x, context=None, mask=None, pos_bias=None):
......@@ -125,7 +125,7 @@ class T5Attention(nn.Module):
class T5FeedForward(nn.Module):
def __init__(self, dim, dim_ffn, dropout=0.1, quantized=False, quant_scheme=None):
def __init__(self, dim, dim_ffn, dropout=0.1, quantized=False, quant_scheme=None, dtype=torch.bfloat16):
super(T5FeedForward, self).__init__()
self.dim = dim
self.dim_ffn = dim_ffn
......@@ -138,9 +138,9 @@ class T5FeedForward(nn.Module):
else:
linear_cls = nn.Linear
# layers
self.gate = nn.Sequential(linear_cls(dim, dim_ffn, bias=False), GELU())
self.fc1 = linear_cls(dim, dim_ffn, bias=False)
self.fc2 = linear_cls(dim_ffn, dim, bias=False)
self.gate = nn.Sequential(linear_cls(dim, dim_ffn, bias=False, dtype=dtype), GELU())
self.fc1 = linear_cls(dim, dim_ffn, bias=False, dtype=dtype)
self.fc2 = linear_cls(dim_ffn, dim, bias=False, dtype=dtype)
self.dropout = nn.Dropout(dropout)
def forward(self, x):
......@@ -152,7 +152,7 @@ class T5FeedForward(nn.Module):
class T5SelfAttention(nn.Module):
def __init__(self, dim, dim_attn, dim_ffn, num_heads, num_buckets, shared_pos=True, dropout=0.1, quantized=False, quant_scheme=None):
def __init__(self, dim, dim_attn, dim_ffn, num_heads, num_buckets, shared_pos=True, dropout=0.1, quantized=False, quant_scheme=None, dtype=torch.bfloat16):
super(T5SelfAttention, self).__init__()
self.dim = dim
self.dim_attn = dim_attn
......@@ -162,11 +162,11 @@ class T5SelfAttention(nn.Module):
self.shared_pos = shared_pos
# layers
self.norm1 = T5LayerNorm(dim)
self.attn = T5Attention(dim, dim_attn, num_heads, dropout, quantized, quant_scheme)
self.norm2 = T5LayerNorm(dim)
self.ffn = T5FeedForward(dim, dim_ffn, dropout, quantized, quant_scheme)
self.pos_embedding = None if shared_pos else T5RelativeEmbedding(num_buckets, num_heads, bidirectional=True)
self.norm1 = T5LayerNorm(dim, dtype=dtype)
self.attn = T5Attention(dim, dim_attn, num_heads, dropout, quantized, quant_scheme, dtype)
self.norm2 = T5LayerNorm(dim, dtype=dtype)
self.ffn = T5FeedForward(dim, dim_ffn, dropout, quantized, quant_scheme, dtype=dtype)
self.pos_embedding = None if shared_pos else T5RelativeEmbedding(num_buckets, num_heads, bidirectional=True, dtype=dtype)
def forward(self, x, mask=None, pos_bias=None):
e = pos_bias if self.shared_pos else self.pos_embedding(x.size(1), x.size(1))
......@@ -212,7 +212,7 @@ class T5CrossAttention(nn.Module):
class T5RelativeEmbedding(nn.Module):
def __init__(self, num_buckets, num_heads, bidirectional, max_dist=128):
def __init__(self, num_buckets, num_heads, bidirectional, dtype=torch.bfloat16, max_dist=128):
super(T5RelativeEmbedding, self).__init__()
self.num_buckets = num_buckets
self.num_heads = num_heads
......@@ -220,7 +220,7 @@ class T5RelativeEmbedding(nn.Module):
self.max_dist = max_dist
# layers
self.embedding = nn.Embedding(num_buckets, num_heads)
self.embedding = nn.Embedding(num_buckets, num_heads, dtype=dtype)
def forward(self, lq, lk):
device = self.embedding.weight.device
......@@ -252,7 +252,7 @@ class T5RelativeEmbedding(nn.Module):
class T5Encoder(nn.Module):
def __init__(self, vocab, dim, dim_attn, dim_ffn, num_heads, num_layers, num_buckets, shared_pos=True, dropout=0.1, cpu_offload=False, quantized=False, quant_scheme=None):
def __init__(self, dtype, vocab, dim, dim_attn, dim_ffn, num_heads, num_layers, num_buckets, shared_pos=True, dropout=0.1, cpu_offload=False, quantized=False, quant_scheme=None):
super(T5Encoder, self).__init__()
self.cpu_offload = cpu_offload
......@@ -266,11 +266,11 @@ class T5Encoder(nn.Module):
self.quant_scheme = quant_scheme
# layers
self.token_embedding = vocab if isinstance(vocab, nn.Embedding) else nn.Embedding(vocab, dim)
self.pos_embedding = T5RelativeEmbedding(num_buckets, num_heads, bidirectional=True) if shared_pos else None
self.token_embedding = vocab.to(dtype) if isinstance(vocab, nn.Embedding) else nn.Embedding(vocab, dim, dtype=dtype)
self.pos_embedding = T5RelativeEmbedding(num_buckets, num_heads, bidirectional=True, dtype=dtype) if shared_pos else None
self.dropout = nn.Dropout(dropout)
self.blocks = nn.ModuleList([T5SelfAttention(dim, dim_attn, dim_ffn, num_heads, num_buckets, shared_pos, dropout, quantized, quant_scheme) for _ in range(num_layers)])
self.norm = T5LayerNorm(dim)
self.blocks = nn.ModuleList([T5SelfAttention(dim, dim_attn, dim_ffn, num_heads, num_buckets, shared_pos, dropout, quantized, quant_scheme, dtype) for _ in range(num_layers)])
self.norm = T5LayerNorm(dim, dtype=dtype)
# initialize weights
# self.apply(init_weights)
......@@ -443,10 +443,10 @@ def _t5(
# init model
with torch.device(device):
model = model_cls(**kwargs)
model = model_cls(dtype=dtype, **kwargs)
# set device
model = model.to(dtype=dtype, device=device)
model = model.to(device=device)
return model
......@@ -511,9 +511,10 @@ class T5EncoderModel:
.requires_grad_(False)
)
logger.info(f"Loading weights from {self.checkpoint_path}")
logger.info(f"Start Loading weights from {self.checkpoint_path}")
model.load_state_dict(torch.load(self.checkpoint_path, map_location="cpu", weights_only=True))
logger.info(f"End Loading weights from {self.checkpoint_path}")
self.model = model
if shard_fn is not None:
self.model = shard_fn(self.model, sync_module_states=False)
......
......@@ -50,7 +50,7 @@ class LayerNorm(nn.LayerNorm):
class SelfAttention(nn.Module):
def __init__(self, dim, num_heads, causal=False, attn_dropout=0.0, proj_dropout=0.0, quantized=False, quant_scheme=None):
def __init__(self, dim, num_heads, causal=False, attn_dropout=0.0, proj_dropout=0.0, quantized=False, quant_scheme=None, dtype=None):
assert dim % num_heads == 0
super().__init__()
self.dim = dim
......@@ -69,8 +69,8 @@ class SelfAttention(nn.Module):
else:
linear_cls = nn.Linear
self.to_qkv = linear_cls(dim, dim * 3)
self.proj = linear_cls(dim, dim)
self.to_qkv = linear_cls(dim, dim * 3, dtype=dtype)
self.proj = linear_cls(dim, dim, dtype=dtype)
def forward(self, x):
"""
......@@ -108,7 +108,21 @@ class SwiGLU(nn.Module):
class AttentionBlock(nn.Module):
def __init__(self, dim, mlp_ratio, num_heads, post_norm=False, causal=False, activation="quick_gelu", attn_dropout=0.0, proj_dropout=0.0, norm_eps=1e-5, quantized=False, quant_scheme=None):
def __init__(
self,
dim,
mlp_ratio,
num_heads,
post_norm=False,
causal=False,
activation="quick_gelu",
attn_dropout=0.0,
proj_dropout=0.0,
norm_eps=1e-5,
quantized=False,
quant_scheme=None,
dtype=torch.float16,
):
assert activation in ["quick_gelu", "gelu", "swi_glu"]
super().__init__()
self.dim = dim
......@@ -127,13 +141,18 @@ class AttentionBlock(nn.Module):
else:
linear_cls = nn.Linear
self.norm1 = LayerNorm(dim, eps=norm_eps)
self.attn = SelfAttention(dim, num_heads, causal, attn_dropout, proj_dropout, quantized, quant_scheme)
self.norm2 = LayerNorm(dim, eps=norm_eps)
self.norm1 = LayerNorm(dim, eps=norm_eps, dtype=dtype)
self.attn = SelfAttention(dim, num_heads, causal, attn_dropout, proj_dropout, quantized, quant_scheme, dtype)
self.norm2 = LayerNorm(dim, eps=norm_eps, dtype=dtype)
if activation == "swi_glu":
self.mlp = SwiGLU(dim, int(dim * mlp_ratio))
self.mlp = SwiGLU(dim, int(dim * mlp_ratio), dtype=dtype)
else:
self.mlp = nn.Sequential(linear_cls(dim, int(dim * mlp_ratio)), QuickGELU() if activation == "quick_gelu" else nn.GELU(), linear_cls(int(dim * mlp_ratio), dim), nn.Dropout(proj_dropout))
self.mlp = nn.Sequential(
linear_cls(dim, int(dim * mlp_ratio), dtype=dtype),
QuickGELU() if activation == "quick_gelu" else nn.GELU(),
linear_cls(int(dim * mlp_ratio), dim, dtype=dtype),
nn.Dropout(proj_dropout),
)
def forward(self, x):
if self.post_norm:
......@@ -146,7 +165,7 @@ class AttentionBlock(nn.Module):
class AttentionPool(nn.Module):
def __init__(self, dim, mlp_ratio, num_heads, activation="gelu", proj_dropout=0.0, norm_eps=1e-5):
def __init__(self, dim, mlp_ratio, num_heads, activation="gelu", proj_dropout=0.0, norm_eps=1e-5, dtype=torch.float16):
assert dim % num_heads == 0
super().__init__()
self.dim = dim
......@@ -159,11 +178,13 @@ class AttentionPool(nn.Module):
# layers
gain = 1.0 / math.sqrt(dim)
self.cls_embedding = nn.Parameter(gain * torch.randn(1, 1, dim))
self.to_q = nn.Linear(dim, dim)
self.to_kv = nn.Linear(dim, dim * 2)
self.proj = nn.Linear(dim, dim)
self.norm = LayerNorm(dim, eps=norm_eps)
self.mlp = nn.Sequential(nn.Linear(dim, int(dim * mlp_ratio)), QuickGELU() if activation == "quick_gelu" else nn.GELU(), nn.Linear(int(dim * mlp_ratio), dim), nn.Dropout(proj_dropout))
self.to_q = nn.Linear(dim, dim, dtype=dtype)
self.to_kv = nn.Linear(dim, dim * 2, dtype=dtype)
self.proj = nn.Linear(dim, dim, dtype=dtype)
self.norm = LayerNorm(dim, eps=norm_eps, dtype=dtype)
self.mlp = nn.Sequential(
nn.Linear(dim, int(dim * mlp_ratio), dtype=dtype), QuickGELU() if activation == "quick_gelu" else nn.GELU(), nn.Linear(int(dim * mlp_ratio), dim, dtype=dtype), nn.Dropout(proj_dropout)
)
def forward(self, x):
"""
......@@ -191,6 +212,7 @@ class AttentionPool(nn.Module):
class VisionTransformer(nn.Module):
def __init__(
self,
dtype=torch.float16,
image_size=224,
patch_size=16,
dim=768,
......@@ -228,26 +250,26 @@ class VisionTransformer(nn.Module):
# embeddings
gain = 1.0 / math.sqrt(dim)
self.patch_embedding = nn.Conv2d(3, dim, kernel_size=patch_size, stride=patch_size, bias=not pre_norm)
self.patch_embedding = nn.Conv2d(3, dim, kernel_size=patch_size, stride=patch_size, bias=not pre_norm, dtype=dtype)
if pool_type in ("token", "token_fc"):
self.cls_embedding = nn.Parameter(gain * torch.randn(1, 1, dim))
self.pos_embedding = nn.Parameter(gain * torch.randn(1, self.num_patches + (1 if pool_type in ("token", "token_fc") else 0), dim))
self.cls_embedding = nn.Parameter(gain * torch.randn(1, 1, dim, dtype=dtype))
self.pos_embedding = nn.Parameter(gain * torch.randn(1, self.num_patches + (1 if pool_type in ("token", "token_fc") else 0), dim, dtype=dtype))
self.dropout = nn.Dropout(embedding_dropout)
# transformer
self.pre_norm = LayerNorm(dim, eps=norm_eps) if pre_norm else None
self.pre_norm = LayerNorm(dim, eps=norm_eps, dtype=dtype) if pre_norm else None
self.transformer = nn.Sequential(
*[AttentionBlock(dim, mlp_ratio, num_heads, post_norm, False, activation, attn_dropout, proj_dropout, norm_eps, quantized, quant_scheme) for _ in range(num_layers)]
*[AttentionBlock(dim, mlp_ratio, num_heads, post_norm, False, activation, attn_dropout, proj_dropout, norm_eps, quantized, quant_scheme, dtype) for _ in range(num_layers)]
)
self.post_norm = LayerNorm(dim, eps=norm_eps)
self.post_norm = LayerNorm(dim, eps=norm_eps, dtype=dtype)
# head
if pool_type == "token":
self.head = nn.Parameter(gain * torch.randn(dim, out_dim))
self.head = nn.Parameter(gain * torch.randn(dim, out_dim, dtype=dtype))
elif pool_type == "token_fc":
self.head = nn.Linear(dim, out_dim)
self.head = nn.Linear(dim, out_dim, dtype=dtype)
elif pool_type == "attn_pool":
self.head = AttentionPool(dim, mlp_ratio, num_heads, activation, proj_dropout, norm_eps)
self.head = AttentionPool(dim, mlp_ratio, num_heads, activation, proj_dropout, norm_eps, dtype=dtype)
def forward(self, x, interpolation=False, use_31_block=False):
b = x.size(0)
@@ -276,6 +298,7 @@ class VisionTransformer(nn.Module):
class XLMRobertaCLIP(nn.Module):
def __init__(
self,
dtype=torch.float16,
embed_dim=1024,
image_size=224,
patch_size=14,
@@ -317,6 +340,7 @@ class XLMRobertaCLIP(nn.Module):
# models
self.visual = VisionTransformer(
dtype=dtype,
image_size=image_size,
patch_size=patch_size,
dim=vision_dim,
@@ -341,12 +365,11 @@ class XLMRobertaCLIP(nn.Module):
def _clip(pretrained=False, pretrained_name=None, model_cls=XLMRobertaCLIP, return_transforms=False, return_tokenizer=False, tokenizer_padding="eos", dtype=torch.float32, device="cpu", **kwargs):
# init a model on device
with torch.device(device):
model = model_cls(**kwargs)
model = model_cls(dtype=dtype, **kwargs)
# set device
model = model.to(dtype=dtype, device=device)
output = (model,)
model = model.to(device=device)
output = (model,)
# init transforms
if return_transforms:
# mean and std
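The `_clip` change above passes `dtype` into the model constructor and drops the later `.to(dtype=...)` cast. A minimal sketch of the difference (sizes are illustrative, not project code):

```python
# Allocating parameters directly in the target dtype avoids briefly holding a
# full fp32 copy of the weights in memory before conversion.
import torch
import torch.nn as nn

dim = 4096

# old pattern: fp32 allocation first, then cast
m_old = nn.Linear(dim, dim)                    # ~64 MB of fp32 weights materialized
m_old = m_old.to(dtype=torch.float16)          # converted after the fact

# new pattern: allocate in fp16 from the start
m_new = nn.Linear(dim, dim, dtype=torch.float16)
print(m_old.weight.dtype, m_new.weight.dtype)  # torch.float16 torch.float16
```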
@@ -395,20 +418,20 @@ class CLIPModel:
else:
self.checkpoint_path = checkpoint_path
logger.info(f"Loading weights from {self.checkpoint_path}")
# init model
self.model, self.transforms = clip_xlm_roberta_vit_h_14(
pretrained=False, return_transforms=True, return_tokenizer=False, dtype=dtype, device=device, quantized=self.quantized, quant_scheme=quant_scheme
)
self.model = self.model.eval().requires_grad_(False)
weight_dict = torch.load(self.checkpoint_path, map_location="cpu", weights_only=True)
keys = list(weight_dict.keys())
for key in keys:
if "textual" in key:
weight_dict.pop(key)
logger.info(f"Start Loading weights from {self.checkpoint_path}")
self.model.load_state_dict(weight_dict)
logger.info(f"End Loading weights from {self.checkpoint_path}")
def visual(self, videos, args):
if args.cpu_offload:
......
import flash_attn
try:
import flash_attn
except ModuleNotFoundError:
flash_attn = None
import math
import torch
import torch.nn as nn
......
@@ -104,7 +104,7 @@ class WanTransformerInfer(BaseTransformerInfer):
return x
def _infer_with_lazy_offload(self, weights, grid_sizes, embed, x, embed0, seq_lens, freqs, context):
self.weights_stream_mgr.prefetch_weights_from_disk(weights)
self.weights_stream_mgr.prefetch_weights_from_disk(weights.blocks)
for block_idx in range(self.blocks_num):
if block_idx == 0:
@@ -132,7 +132,7 @@ class WanTransformerInfer(BaseTransformerInfer):
if block_idx == self.blocks_num - 1:
self.weights_stream_mgr.pin_memory_buffer.pop_front()
self.weights_stream_mgr._async_prefetch_block(weights)
self.weights_stream_mgr._async_prefetch_block(weights.blocks)
if self.clean_cuda_cache:
del grid_sizes, embed, embed0, seq_lens, freqs, context
@@ -189,7 +189,7 @@ class WanTransformerInfer(BaseTransformerInfer):
return x
def _infer_with_phases_lazy_offload(self, weights, grid_sizes, embed, x, embed0, seq_lens, freqs, context, audio_dit_blocks=None):
self.weights_stream_mgr.prefetch_weights_from_disk(weights)
self.weights_stream_mgr.prefetch_weights_from_disk(weights.blocks)
for block_idx in range(weights.blocks_num):
for phase_idx in range(self.weights_stream_mgr.phases_num):
@@ -236,7 +236,7 @@ class WanTransformerInfer(BaseTransformerInfer):
self.weights_stream_mgr.swap_phases()
self.weights_stream_mgr._async_prefetch_block(weights)
self.weights_stream_mgr._async_prefetch_block(weights.blocks)
if self.clean_cuda_cache:
del attn_out, y_out, y
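The lazy-offload changes above now hand `weights.blocks` to the stream manager for both the disk prefetch and the async block prefetch. As a rough mental model (hypothetical names, not the project's `weights_stream_mgr` API), the pattern is a prefetch loop that overlaps the host-to-GPU copy of the next block with compute on the current one:

```python
# Hypothetical sketch of the lazy-offload idea (needs a CUDA device); names are
# illustrative and not LightX2V's weights_stream_mgr API. While block i runs on
# the compute stream, block i+1 is copied host->GPU on a side stream.
import torch

def run_blocks(blocks, x, copy_stream):
    with torch.cuda.stream(copy_stream):                      # start copying block 0
        next_w = blocks[0].to("cuda", non_blocking=True)
    for i in range(len(blocks)):
        torch.cuda.current_stream().wait_stream(copy_stream)  # block i has arrived
        w = next_w
        if i + 1 < len(blocks):
            with torch.cuda.stream(copy_stream):              # overlap next copy with compute
                next_w = blocks[i + 1].to("cuda", non_blocking=True)
        x = x @ w                                             # stand-in for the block's forward pass
        w.record_stream(torch.cuda.current_stream())          # keep w's memory valid for the matmul
    return x

if torch.cuda.is_available():
    copy_stream = torch.cuda.Stream()
    blocks = [torch.randn(64, 64).pin_memory() for _ in range(4)]
    out = run_blocks(blocks, torch.randn(1, 64, device="cuda"), copy_stream)
    print(out.shape)  # torch.Size([1, 64])
```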
......
@@ -10,6 +10,7 @@ from safetensors import safe_open, torch as st
from loguru import logger
from tqdm import tqdm
from collections import defaultdict
from qtorch.quant import float_quantize
def get_key_mapping_rules(direction, model_type):
@@ -314,7 +315,8 @@ def quantize_tensor(w, w_bit=8, dtype=torch.int8):
max_val = w.abs().amax(dim=1, keepdim=True).clamp(min=1e-5)
if dtype == torch.float8_e4m3fn:
qmin, qmax = -448, 448
finfo = torch.finfo(dtype)
qmin, qmax = finfo.min, finfo.max
elif dtype == torch.int8:
qmin, qmax = -128, 127
@@ -322,7 +324,9 @@ def quantize_tensor(w, w_bit=8, dtype=torch.int8):
scales = max_val / qmax
if dtype == torch.float8_e4m3fn:
w_q = torch.clamp(w / scales, qmin, qmax).to(dtype)
scaled_tensor = w / scales
scaled_tensor = torch.clip(scaled_tensor, qmin, qmax)
w_q = float_quantize(scaled_tensor.float(), 4, 3, rounding="nearest").to(dtype)
else:
w_q = torch.clamp(torch.round(w / scales), qmin, qmax).to(dtype)
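The new fp8 branch above can be read as per-row symmetric quantization onto the e4m3 grid, rounded with `qtorch` (the import added earlier) before the final cast. A self-contained sketch of that path, with illustrative shapes:

```python
# Minimal sketch of the float8_e4m3fn path in quantize_tensor: per-row symmetric
# scales, clip to the e4m3 range, round onto the 4-exponent/3-mantissa grid with
# qtorch, then cast to torch.float8_e4m3fn. Requires qtorch to be installed.
import torch
from qtorch.quant import float_quantize

def quantize_fp8_e4m3(w: torch.Tensor):
    finfo = torch.finfo(torch.float8_e4m3fn)                  # min/max = -448 / 448
    max_val = w.abs().amax(dim=1, keepdim=True).clamp(min=1e-5)
    scales = max_val / finfo.max                              # per-row scale
    scaled = torch.clip(w / scales, finfo.min, finfo.max)
    w_q = float_quantize(scaled.float(), 4, 3, rounding="nearest").to(torch.float8_e4m3fn)
    return w_q, scales

w = torch.randn(128, 256)
w_q, scales = quantize_fp8_e4m3(w)
w_deq = w_q.to(torch.float32) * scales                        # dequantize to check the error
print(w_q.dtype, (w - w_deq).abs().max().item())
```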
@@ -341,7 +345,8 @@ def quantize_model(
target_keys=["attn", "ffn"],
key_idx=2,
ignore_key=None,
dtype=torch.int8,
linear_dtype=torch.int8,
non_linear_dtype=torch.float,
):
"""
Quantize model weights in-place
@@ -370,16 +375,20 @@
# Skip non-tensors, small tensors, and non-2D tensors
if not isinstance(tensor, torch.Tensor) or tensor.dim() != 2:
if tensor.dtype != non_linear_dtype:
weights[key] = tensor.to(non_linear_dtype)
continue
# Check if key matches target modules
parts = key.split(".")
if len(parts) < key_idx + 1 or parts[key_idx] not in target_keys:
if tensor.dtype != non_linear_dtype:
weights[key] = tensor.to(non_linear_dtype)
continue
try:
# Quantize tensor and store results
w_q, scales = quantize_tensor(tensor, w_bit, dtype)
w_q, scales = quantize_tensor(tensor, w_bit, linear_dtype)
# Replace original tensor and store scales
weights[key] = w_q
@@ -500,7 +509,8 @@ def convert_weights(args):
target_keys=args.target_keys,
key_idx=args.key_idx,
ignore_key=args.ignore_key,
dtype=args.dtype,
linear_dtype=args.linear_dtype,
non_linear_dtype=args.non_linear_dtype,
)
os.makedirs(args.output, exist_ok=True)
@@ -637,10 +647,17 @@ def main():
help="Device to use for quantization (cpu/cuda)",
)
parser.add_argument(
"--dtype",
"--linear_dtype",
type=str,
choices=["torch.int8", "torch.float8_e4m3fn"],
help="Data type for quantization",
help="Data type for linear",
)
parser.add_argument(
"--non_linear_dtype",
type=str,
default="torch.float32",
choices=["torch.bfloat16", "torch.float16"],
help="Data type for non-linear",
)
parser.add_argument("--lora_path", type=str, nargs="*", help="Path(s) to LoRA file(s). Can specify multiple paths separated by spaces.")
parser.add_argument(
@@ -654,12 +671,8 @@
args = parser.parse_args()
if args.quantized:
if args.dtype == "torch.int8":
args.dtype = torch.int8
elif args.dtype == "torch.float8_e4m3fn":
args.dtype = torch.float8_e4m3fn
else:
raise ValueError(f"Not support dtype :{args.dtype}")
args.linear_dtype = eval(args.linear_dtype)
args.non_linear_dtype = eval(args.non_linear_dtype)
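The new CLI parsing converts the dtype strings with `eval` over the whitelisted `choices`. An equivalent, slightly more defensive alternative (a sketch, not what `converter.py` ships) is an explicit lookup table:

```python
# Hedged alternative sketch: map the CLI dtype strings to torch dtypes without eval().
import torch

_DTYPE_MAP = {
    "torch.int8": torch.int8,
    "torch.float8_e4m3fn": torch.float8_e4m3fn,
    "torch.float16": torch.float16,
    "torch.bfloat16": torch.bfloat16,
    "torch.float32": torch.float32,
}

def parse_dtype(name: str) -> torch.dtype:
    try:
        return _DTYPE_MAP[name]
    except KeyError as e:
        raise ValueError(f"Unsupported dtype string: {name}") from e

print(parse_dtype("torch.float8_e4m3fn"), parse_dtype("torch.bfloat16"))
```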
model_type_keys_map = {
"wan_dit": {
......
@@ -36,7 +36,7 @@ python converter.py \
--output /Path/To/output \
--output_ext .safetensors \
--output_name wan_int8 \
--dtype torch.int8 \
--linear_dtype torch.int8 \
--model_type wan_dit \
--quantized \
--save_by_block
@@ -48,7 +48,7 @@ python converter.py \
--output /Path/To/output \
--output_ext .safetensors \
--output_name wan_fp8 \
--dtype torch.float8_e4m3fn \
--linear_dtype torch.float8_e4m3fn \
--model_type wan_dit \
--quantized \
--save_by_block
@@ -62,7 +62,7 @@ python converter.py \
--output /Path/To/output \
--output_ext .safetensors \
--output_name wan_int8 \
--dtype torch.int8 \
--linear_dtype torch.int8 \
--model_type wan_dit \
--lora_path /Path/To/LoRA1/ /Path/To/LoRA2/ \
--lora_alpha 1.0 1.0 \
@@ -78,7 +78,7 @@ python converter.py \
--output /Path/To/output \
--output_ext .safetensors \
--output_name hunyuan_int8 \
--dtype torch.int8 \
--linear_dtype torch.int8 \
--model_type hunyuan_dit \
--quantized
```
@@ -89,7 +89,7 @@ python converter.py \
--output /Path/To/output \
--output_ext .safetensors \
--output_name hunyuan_fp8 \
--dtype torch.float8_e4m3fn \
--linear_dtype torch.float8_e4m3fn \
--model_type hunyuan_dit \
--quantized
```
@@ -103,7 +103,8 @@ python converter.py \
--output /Path/To/output \
--output_ext .pth \
--output_name models_t5_umt5-xxl-enc-int8 \
--dtype torch.int8 \
--linear_dtype torch.int8 \
--non_linear_dtype torch.bfloat16 \
--model_type wan_t5 \
--quantized
```
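As the `quantize_model` changes above show, `--linear_dtype` sets the dtype that matching 2D linear weights are quantized to, while the new `--non_linear_dtype` sets the dtype that the remaining tensors (norms, embeddings, biases, and anything outside the target modules) are cast to; for the T5 and CLIP encoders this replaces the single `--dtype` flag.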
@@ -111,10 +112,11 @@ python converter.py \
```bash
python converter.py \
--source /Path/To/Wan-AI/Wan2.1-I2V-14B-480P/models_t5_umt5-xxl-enc-bf16.pth \
--output /Path/To/output \
--output /Path/To/Wan-AI/Wan2.1-I2V-14B-480P/fp8 \
--output_ext .pth \
--output_name models_t5_umt5-xxl-enc-fp8 \
--dtype torch.float8_e4m3fn \
--linear_dtype torch.float8_e4m3fn \
--non_linear_dtype torch.bfloat16 \
--model_type wan_t5 \
--quantized
```
@@ -128,7 +130,8 @@ python converter.py \
--output /Path/To/output \
--output_ext .pth \
--output_name clip-int8 \
--dtype torch.int8 \
--linear_dtype torch.int8 \
--non_linear_dtype torch.float16 \
--model_type wan_clip \
--quantized
@@ -136,10 +139,11 @@ python converter.py \
```bash
python converter.py \
--source /Path/To/Wan-AI/Wan2.1-I2V-14B-480P/models_clip_open-clip-xlm-roberta-large-vit-huge-14.pth \
--output /Path/To/output \
--output ./output \
--output_ext .pth \
--output_name clip-fp8 \
--dtype torch.float8_e4m3fn \
--linear_dtype torch.float8_e4m3fn \
--non_linear_dtype torch.float16 \
--model_type wan_clip \
--quantized
```