Commit a1ebc651 authored by xuwx1's avatar xuwx1
Browse files

updata lightx2v

parent 5a4db490
Pipeline #3149 canceled with stages
*.pth
*.pt
*.onnx
*.pk
*.model
*.zip
*.tar
*.pyc
*.log
*.o
*.so
*.a
*.exe
*.out
.idea
**.DS_Store**
**/__pycache__/**
**.swp
.vscode/
.env
.log
*.pid
*.ipynb*
*.mp4
build/
dist/
.cache/
server_cache/
app/.gradio/
*.pkl
# Follow https://verdantfox.com/blog/how-to-use-git-pre-commit-hooks-the-hard-way-and-the-easy-way
repos:
- repo: https://github.com/astral-sh/ruff-pre-commit
rev: v0.11.0
hooks:
- id: ruff
args: [--fix, --respect-gitignore, --config=pyproject.toml]
- id: ruff-format
args: [--config=pyproject.toml]
- repo: https://github.com/pre-commit/pre-commit-hooks
rev: v4.5.0
hooks:
- id: trailing-whitespace
- id: end-of-file-fixer
- id: check-yaml
- id: check-toml
- id: check-added-large-files
args: ['--maxkb=3000'] # Allow files up to 3MB
- id: check-case-conflict
- id: check-merge-conflict
- id: debug-statements
Apache License
Version 2.0, January 2004
http://www.apache.org/licenses/
TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
1. Definitions.
"License" shall mean the terms and conditions for use, reproduction,
and distribution as defined by Sections 1 through 9 of this document.
"Licensor" shall mean the copyright owner or entity authorized by
the copyright owner that is granting the License.
"Legal Entity" shall mean the union of the acting entity and all
other entities that control, are controlled by, or are under common
control with that entity. For the purposes of this definition,
"control" means (i) the power, direct or indirect, to cause the
direction or management of such entity, whether by contract or
otherwise, or (ii) ownership of fifty percent (50%) or more of the
outstanding shares, or (iii) beneficial ownership of such entity.
"You" (or "Your") shall mean an individual or Legal Entity
exercising permissions granted by this License.
"Source" form shall mean the preferred form for making modifications,
including but not limited to software source code, documentation
source, and configuration files.
"Object" form shall mean any form resulting from mechanical
transformation or translation of a Source form, including but
not limited to compiled object code, generated documentation,
and conversions to other media types.
"Work" shall mean the work of authorship, whether in Source or
Object form, made available under the License, as indicated by a
copyright notice that is included in or attached to the work
(an example is provided in the Appendix below).
"Derivative Works" shall mean any work, whether in Source or Object
form, that is based on (or derived from) the Work and for which the
editorial revisions, annotations, elaborations, or other modifications
represent, as a whole, an original work of authorship. For the purposes
of this License, Derivative Works shall not include works that remain
separable from, or merely link (or bind by name) to the interfaces of,
the Work and Derivative Works thereof.
"Contribution" shall mean any work of authorship, including
the original version of the Work and any modifications or additions
to that Work or Derivative Works thereof, that is intentionally
submitted to Licensor for inclusion in the Work by the copyright owner
or by an individual or Legal Entity authorized to submit on behalf of
the copyright owner. For the purposes of this definition, "submitted"
means any form of electronic, verbal, or written communication sent
to the Licensor or its representatives, including but not limited to
communication on electronic mailing lists, source code control systems,
and issue tracking systems that are managed by, or on behalf of, the
Licensor for the purpose of discussing and improving the Work, but
excluding communication that is conspicuously marked or otherwise
designated in writing by the copyright owner as "Not a Contribution."
"Contributor" shall mean Licensor and any individual or Legal Entity
on behalf of whom a Contribution has been received by Licensor and
subsequently incorporated within the Work.
2. Grant of Copyright License. Subject to the terms and conditions of
this License, each Contributor hereby grants to You a perpetual,
worldwide, non-exclusive, no-charge, royalty-free, irrevocable
copyright license to reproduce, prepare Derivative Works of,
publicly display, publicly perform, sublicense, and distribute the
Work and such Derivative Works in Source or Object form.
3. Grant of Patent License. Subject to the terms and conditions of
this License, each Contributor hereby grants to You a perpetual,
worldwide, non-exclusive, no-charge, royalty-free, irrevocable
(except as stated in this section) patent license to make, have made,
use, offer to sell, sell, import, and otherwise transfer the Work,
where such license applies only to those patent claims licensable
by such Contributor that are necessarily infringed by their
Contribution(s) alone or by combination of their Contribution(s)
with the Work to which such Contribution(s) was submitted. If You
institute patent litigation against any entity (including a
cross-claim or counterclaim in a lawsuit) alleging that the Work
or a Contribution incorporated within the Work constitutes direct
or contributory patent infringement, then any patent licenses
granted to You under this License for that Work shall terminate
as of the date such litigation is filed.
4. Redistribution. You may reproduce and distribute copies of the
Work or Derivative Works thereof in any medium, with or without
modifications, and in Source or Object form, provided that You
meet the following conditions:
(a) You must give any other recipients of the Work or
Derivative Works a copy of this License; and
(b) You must cause any modified files to carry prominent notices
stating that You changed the files; and
(c) You must retain, in the Source form of any Derivative Works
that You distribute, all copyright, patent, trademark, and
attribution notices from the Source form of the Work,
excluding those notices that do not pertain to any part of
the Derivative Works; and
(d) If the Work includes a "NOTICE" text file as part of its
distribution, then any Derivative Works that You distribute must
include a readable copy of the attribution notices contained
within such NOTICE file, excluding those notices that do not
pertain to any part of the Derivative Works, in at least one
of the following places: within a NOTICE text file distributed
as part of the Derivative Works; within the Source form or
documentation, if provided along with the Derivative Works; or,
within a display generated by the Derivative Works, if and
wherever such third-party notices normally appear. The contents
of the NOTICE file are for informational purposes only and
do not modify the License. You may add Your own attribution
notices within Derivative Works that You distribute, alongside
or as an addendum to the NOTICE text from the Work, provided
that such additional attribution notices cannot be construed
as modifying the License.
You may add Your own copyright statement to Your modifications and
may provide additional or different license terms and conditions
for use, reproduction, or distribution of Your modifications, or
for any such Derivative Works as a whole, provided Your use,
reproduction, and distribution of the Work otherwise complies with
the conditions stated in this License.
5. Submission of Contributions. Unless You explicitly state otherwise,
any Contribution intentionally submitted for inclusion in the Work
by You to the Licensor shall be under the terms and conditions of
this License, without any additional terms or conditions.
Notwithstanding the above, nothing herein shall supersede or modify
the terms of any separate license agreement you may have executed
with Licensor regarding such Contributions.
6. Trademarks. This License does not grant permission to use the trade
names, trademarks, service marks, or product names of the Licensor,
except as required for reasonable and customary use in describing the
origin of the Work and reproducing the content of the NOTICE file.
7. Disclaimer of Warranty. Unless required by applicable law or
agreed to in writing, Licensor provides the Work (and each
Contributor provides its Contributions) on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
implied, including, without limitation, any warranties or conditions
of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
PARTICULAR PURPOSE. You are solely responsible for determining the
appropriateness of using or redistributing the Work and assume any
risks associated with Your exercise of permissions under this License.
8. Limitation of Liability. In no event and under no legal theory,
whether in tort (including negligence), contract, or otherwise,
unless required by applicable law (such as deliberate and grossly
negligent acts) or agreed to in writing, shall any Contributor be
liable to You for damages, including any direct, indirect, special,
incidental, or consequential damages of any character arising as a
result of this License or out of the use or inability to use the
Work (including but not limited to damages for loss of goodwill,
work stoppage, computer failure or malfunction, or any and all
other commercial damages or losses), even if such Contributor
has been advised of the possibility of such damages.
9. Accepting Warranty or Additional Liability. While redistributing
the Work or Derivative Works thereof, You may choose to offer,
and charge a fee for, acceptance of support, warranty, indemnity,
or other liability obligations and/or rights consistent with this
License. However, in accepting such obligations, You may act only
on Your own behalf and on Your sole responsibility, not on behalf
of any other Contributor, and only if You agree to indemnify,
defend, and hold each Contributor harmless for any liability
incurred by, or claims asserted against, such Contributor by reason
of your accepting any such warranty or additional liability.
END OF TERMS AND CONDITIONS
APPENDIX: How to apply the Apache License to your work.
To apply the Apache License to your work, attach the following
boilerplate notice, with the fields enclosed by brackets "[]"
replaced with your own identifying information. (Don't include
the brackets!) The text should be enclosed in the appropriate
comment syntax for the file format. We also recommend that a
file or class name and description of purpose be included on the
same "printed page" as the copyright notice for easier
identification within third-party archives.
Copyright [yyyy] [name of copyright owner]
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
# LightX2V
<div align="center" style="font-family: charter;">
<h1>⚡️ LightX2V:<br> 轻量级视频生成推理框架</h1>
<img alt="logo" src="assets/img_lightx2v.png" width=75%></img>
[![License](https://img.shields.io/badge/License-Apache_2.0-blue.svg)](https://opensource.org/licenses/Apache-2.0)
[![Ask DeepWiki](https://deepwiki.com/badge.svg)](https://deepwiki.com/ModelTC/lightx2v)
[![Doc](https://img.shields.io/badge/docs-English-99cc2)](https://lightx2v-en.readthedocs.io/en/latest)
[![Doc](https://img.shields.io/badge/文档-中文-99cc2)](https://lightx2v-zhcn.readthedocs.io/zh-cn/latest)
[![Papers](https://img.shields.io/badge/论文集-中文-99cc2)](https://lightx2v-papers-zhcn.readthedocs.io/zh-cn/latest)
[![Docker](https://img.shields.io/badge/Docker-2496ED?style=flat&logo=docker&logoColor=white)](https://hub.docker.com/r/lightx2v/lightx2v/tags)
**\[ [English](README.md) | 中文 \]**
</div>
--------------------------------------------------------------------------------
**LightX2V** 是一个先进的轻量级视频生成推理框架,专为提供高效、高性能的视频合成解决方案而设计。该统一平台集成了多种前沿的视频生成技术,支持文本生成视频(T2V)和图像生成视频(I2V)等多样化生成任务。**X2V 表示将不同的输入模态(X,如文本或图像)转换为视频输出(V)**
> 🌐 **立即在线体验!** 无需安装即可体验 LightX2V:**[LightX2V 在线服务](https://x2v.light-ai.top/login)** - 免费、轻量、快速的AI数字人视频生成平台。
## :fire: 最新动态
- **2025年12月4日:** 🚀 支持 GGUF 格式模型推理,以及在寒武纪 MLU590、MetaX C500 硬件上的部署。
- **2025年11月24日:** 🚀 我们发布了HunyuanVideo-1.5的4步蒸馏模型!这些模型支持**超快速4步推理**,无需CFG配置,相比标准50步推理可实现约**25倍加速**。现已提供基础版本和FP8量化版本:[Hy1.5-Distill-Models](https://huggingface.co/lightx2v/Hy1.5-Distill-Models)
- **2025年11月21日:** 🚀 我们Day0支持了[HunyuanVideo-1.5](https://huggingface.co/tencent/HunyuanVideo-1.5)的视频生成模型,同样GPU数量,LightX2V可带来约2倍以上的速度提升,并支持更低显存GPU部署(如24G RTX4090)。支持CFG并行/Ulysses并行,高效Offload,TeaCache/MagCache等技术。同时支持沐曦,寒武纪等国产芯片部署。我们很快将在我们的[HuggingFace主页](https://huggingface.co/lightx2v)更新更多模型,包括步数蒸馏,VAE蒸馏等相关模型。量化模型和轻量VAE模型现已可用:[Hy1.5-Quantized-Models](https://huggingface.co/lightx2v/Hy1.5-Quantized-Models)用于量化推理,[HunyuanVideo-1.5轻量TAE](https://huggingface.co/lightx2v/Autoencoders/blob/main/lighttaehy1_5.safetensors)用于快速VAE解码。使用教程参考[这里](https://github.com/ModelTC/LightX2V/tree/main/scripts/hunyuan_video_15),或查看[示例目录](https://github.com/ModelTC/LightX2V/tree/main/examples)获取代码示例。
## 🏆 性能测试数据 (更新于 2025.12.01)
### 📊 推理框架之间性能对比 (H100)
| Framework | GPUs | Step Time | Speedup |
|-----------|---------|---------|---------|
| Diffusers | 1 | 9.77s/it | 1x |
| xDiT | 1 | 8.93s/it | 1.1x |
| FastVideo | 1 | 7.35s/it | 1.3x |
| SGL-Diffusion | 1 | 6.13s/it | 1.6x |
| **LightX2V** | 1 | **5.18s/it** | **1.9x** 🚀 |
| FastVideo | 8 | 2.94s/it | 1x |
| xDiT | 8 | 2.70s/it | 1.1x |
| SGL-Diffusion | 8 | 1.19s/it | 2.5x |
| **LightX2V** | 8 | **0.75s/it** | **3.9x** 🚀 |
### 📊 推理框架之间性能对比 (RTX 4090D)
| Framework | GPUs | Step Time | Speedup |
|-----------|---------|---------|---------|
| Diffusers | 1 | 30.50s/it | 1x |
| FastVideo | 1 | 22.66s/it | 1.3x |
| xDiT | 1 | OOM | OOM |
| SGL-Diffusion | 1 | OOM | OOM |
| **LightX2V** | 1 | **20.26s/it** | **1.5x** 🚀 |
| FastVideo | 8 | 15.48s/it | 1x |
| xDiT | 8 | OOM | OOM |
| SGL-Diffusion | 8 | OOM | OOM |
| **LightX2V** | 8 | **4.75s/it** | **3.3x** 🚀 |
### 📊 LightX2V不同配置之间性能对比
| Framework | GPU | Configuration | Step Time | Speedup |
|-----------|-----|---------------|-----------|---------------|
| **LightX2V** | H100 | 8 GPUs + cfg | 0.75s/it | 1x |
| **LightX2V** | H100 | 8 GPUs + no cfg | 0.39s/it | 1.9x |
| **LightX2V** | H100 | **8 GPUs + no cfg + fp8** | **0.35s/it** | **2.1x** 🚀 |
| **LightX2V** | 4090D | 8 GPUs + cfg | 4.75s/it | 1x |
| **LightX2V** | 4090D | 8 GPUs + no cfg | 3.13s/it | 1.5x |
| **LightX2V** | 4090D | **8 GPUs + no cfg + fp8** | **2.35s/it** | **2.0x** 🚀 |
**注意**: 所有以上性能数据均在 Wan2.1-I2V-14B-480P(40 steps, 81 frames) 上测试。此外,我们[HuggingFace 主页](https://huggingface.co/lightx2v)还提供了4步蒸馏模型。
## 💡 快速开始
详细使用说明请参考我们的文档:**[英文文档](https://lightx2v-en.readthedocs.io/en/latest/) | [中文文档](https://lightx2v-zhcn.readthedocs.io/zh-cn/latest/)**
**我们强烈推荐使用 Docker 环境,这是最简单快捷的环境安装方式。具体参考:文档中的快速入门章节。**
### 从 Git 安装
```bash
pip install -v git+https://github.com/ModelTC/LightX2V.git
```
### 从源码构建
```bash
git clone https://github.com/ModelTC/LightX2V.git
cd LightX2V
uv pip install -v . # pip install -v .
```
### (可选)安装注意力/量化算子
注意力算子安装说明请参考我们的文档:**[英文文档](https://lightx2v-en.readthedocs.io/en/latest/getting_started/quickstart.html#step-4-install-attention-operators) | [中文文档](https://lightx2v-zhcn.readthedocs.io/zh-cn/latest/getting_started/quickstart.html#id9)**
### 使用示例
```python
# examples/wan/wan_i2v.py
"""
Wan2.2 image-to-video generation example.
This example demonstrates how to use LightX2V with Wan2.2 model for I2V generation.
"""
from lightx2v import LightX2VPipeline
# Initialize pipeline for Wan2.2 I2V task
# For wan2.1, use model_cls="wan2.1"
pipe = LightX2VPipeline(
model_path="/path/to/Wan2.2-I2V-A14B",
model_cls="wan2.2_moe",
task="i2v",
)
# Alternative: create generator from config JSON file
# pipe.create_generator(
# config_json="configs/wan22/wan_moe_i2v.json"
# )
# Enable offloading to significantly reduce VRAM usage with minimal speed impact
# Suitable for RTX 30/40/50 consumer GPUs
pipe.enable_offload(
cpu_offload=True,
offload_granularity="block", # For Wan models, supports both "block" and "phase"
text_encoder_offload=True,
image_encoder_offload=False,
vae_offload=False,
)
# Create generator manually with specified parameters
pipe.create_generator(
attn_mode="sage_attn2",
infer_steps=40,
height=480, # Can be set to 720 for higher resolution
width=832, # Can be set to 1280 for higher resolution
num_frames=81,
guidance_scale=[3.5, 3.5], # For wan2.1, guidance_scale is a scalar (e.g., 5.0)
sample_shift=5.0,
)
# Generation parameters
seed = 42
prompt = "Summer beach vacation style, a white cat wearing sunglasses sits on a surfboard. The fluffy-furred feline gazes directly at the camera with a relaxed expression. Blurred beach scenery forms the background featuring crystal-clear waters, distant green hills, and a blue sky dotted with white clouds. The cat assumes a naturally relaxed posture, as if savoring the sea breeze and warm sunlight. A close-up shot highlights the feline's intricate details and the refreshing atmosphere of the seaside."
negative_prompt = "镜头晃动,色调艳丽,过曝,静态,细节模糊不清,字幕,风格,作品,画作,画面,静止,整体发灰,最差质量,低质量,JPEG压缩残留,丑陋的,残缺的,多余的手指,画得不好的手部,画得不好的脸部,畸形的,毁容的,形态畸形的肢体,手指融合,静止不动的画面,杂乱的背景,三条腿,背景人很多,倒着走"
image_path="/path/to/img_0.jpg"
save_result_path = "/path/to/save_results/output.mp4"
# Generate video
pipe.generate(
seed=seed,
image_path=image_path,
prompt=prompt,
negative_prompt=negative_prompt,
save_result_path=save_result_path,
)
```
> 💡 **更多示例**: 更多使用案例,包括量化、卸载、缓存等进阶配置,请参考 [examples 目录](https://github.com/ModelTC/LightX2V/tree/main/examples)。
## 🤖 支持的模型生态
### 官方开源模型
-[HunyuanVideo-1.5](https://huggingface.co/tencent/HunyuanVideo-1.5)
-[Wan2.1 & Wan2.2](https://huggingface.co/Wan-AI/)
-[Qwen-Image](https://huggingface.co/Qwen/Qwen-Image)
-[Qwen-Image-Edit](https://huggingface.co/spaces/Qwen/Qwen-Image-Edit)
-[Qwen-Image-Edit-2509](https://huggingface.co/Qwen/Qwen-Image-Edit-2509)
### 量化模型和蒸馏模型/Lora (**🚀 推荐:4步推理**)
-[Wan2.1-Distill-Models](https://huggingface.co/lightx2v/Wan2.1-Distill-Models)
-[Wan2.2-Distill-Models](https://huggingface.co/lightx2v/Wan2.2-Distill-Models)
-[Wan2.1-Distill-Loras](https://huggingface.co/lightx2v/Wan2.1-Distill-Loras)
-[Wan2.2-Distill-Loras](https://huggingface.co/lightx2v/Wan2.2-Distill-Loras)
### 轻量级自编码器模型(**🚀 推荐:推理快速 + 内存占用低**)
-[Autoencoders](https://huggingface.co/lightx2v/Autoencoders)
### 自回归模型
-[Wan2.1-T2V-CausVid](https://huggingface.co/lightx2v/Wan2.1-T2V-14B-CausVid)
-[Self-Forcing](https://github.com/guandeh17/Self-Forcing)
-[Matrix-Game-2.0](https://huggingface.co/Skywork/Matrix-Game-2.0)
🔔 可以关注我们的[HuggingFace主页](https://huggingface.co/lightx2v),及时获取我们团队的模型。
💡 参考[模型结构文档](https://lightx2v-zhcn.readthedocs.io/zh-cn/latest/getting_started/model_structure.html)快速上手 LightX2V
## 🚀 前端展示
我们提供了多种前端界面部署方式:
- **🎨 Gradio界面**: 简洁易用的Web界面,适合快速体验和原型开发
- 📖 [Gradio部署文档](https://lightx2v-zhcn.readthedocs.io/zh-cn/latest/deploy_guides/deploy_gradio.html)
- **🎯 ComfyUI界面**: 强大的节点式工作流界面,支持复杂的视频生成任务
- 📖 [ComfyUI部署文档](https://lightx2v-zhcn.readthedocs.io/zh-cn/latest/deploy_guides/deploy_comfyui.html)
- **🚀 Windows一键部署**: 专为Windows用户设计的便捷部署方案,支持自动环境配置和智能参数优化
- 📖 [Windows一键部署文档](https://lightx2v-zhcn.readthedocs.io/zh-cn/latest/deploy_guides/deploy_local_windows.html)
**💡 推荐方案**:
- **首次使用**: 建议选择Windows一键部署方案
- **高级用户**: 推荐使用ComfyUI界面获得更多自定义选项
- **快速体验**: Gradio界面提供最直观的操作体验
## 🚀 核心特性
### 🎯 **极致性能优化**
- **🔥 SOTA推理速度**: 通过步数蒸馏和系统优化实现**20倍**极速加速(单GPU)
- **⚡️ 革命性4步蒸馏**: 将原始40-50步推理压缩至仅需4步,且无需CFG配置
- **🛠️ 先进算子支持**: 集成顶尖算子,包括[Sage Attention](https://github.com/thu-ml/SageAttention)[Flash Attention](https://github.com/Dao-AILab/flash-attention)[Radial Attention](https://github.com/mit-han-lab/radial-attention)[q8-kernel](https://github.com/KONAKONA666/q8_kernels)[sgl-kernel](https://github.com/sgl-project/sglang/tree/main/sgl-kernel)[vllm](https://github.com/vllm-project/vllm)
### 💾 **资源高效部署**
- **💡 突破硬件限制**: **仅需8GB显存 + 16GB内存**即可运行14B模型生成480P/720P视频
- **🔧 智能参数卸载**: 先进的磁盘-CPU-GPU三级卸载架构,支持阶段/块级别的精细化管理
- **⚙️ 全面量化支持**: 支持`w8a8-int8``w8a8-fp8``w4a4-nvfp4`等多种量化策略
### 🎨 **丰富功能生态**
- **📈 智能特征缓存**: 智能缓存机制,消除冗余计算,提升效率
- **🔄 并行推理加速**: 多GPU并行处理,显著提升性能表现
- **📱 灵活部署选择**: 支持Gradio、服务化部署、ComfyUI等多种部署方式
- **🎛️ 动态分辨率推理**: 自适应分辨率调整,优化生成质量
- **🎞️ 视频帧插值**: 基于RIFE的帧插值技术,实现流畅的帧率提升
## 📚 技术文档
### 📖 **方法教程**
- [模型量化](https://lightx2v-zhcn.readthedocs.io/zh-cn/latest/method_tutorials/quantization.html) - 量化策略全面指南
- [特征缓存](https://lightx2v-zhcn.readthedocs.io/zh-cn/latest/method_tutorials/cache.html) - 智能缓存机制详解
- [注意力机制](https://lightx2v-zhcn.readthedocs.io/zh-cn/latest/method_tutorials/attention.html) - 前沿注意力算子
- [参数卸载](https://lightx2v-zhcn.readthedocs.io/zh-cn/latest/method_tutorials/offload.html) - 三级存储架构
- [并行推理](https://lightx2v-zhcn.readthedocs.io/zh-cn/latest/method_tutorials/parallel.html) - 多GPU加速策略
- [变分辨率推理](https://lightx2v-zhcn.readthedocs.io/zh-cn/latest/method_tutorials/changing_resolution.html) - U型分辨率策略
- [步数蒸馏](https://lightx2v-zhcn.readthedocs.io/zh-cn/latest/method_tutorials/step_distill.html) - 4步推理技术
- [视频帧插值](https://lightx2v-zhcn.readthedocs.io/zh-cn/latest/method_tutorials/video_frame_interpolation.html) - 基于RIFE的帧插值技术
### 🛠️ **部署指南**
- [低资源场景部署](https://lightx2v-zhcn.readthedocs.io/zh-cn/latest/deploy_guides/for_low_resource.html) - 优化的8GB显存解决方案
- [低延迟场景部署](https://lightx2v-zhcn.readthedocs.io/zh-cn/latest/deploy_guides/for_low_latency.html) - 极速推理优化
- [Gradio部署](https://lightx2v-zhcn.readthedocs.io/zh-cn/latest/deploy_guides/deploy_gradio.html) - Web界面搭建
- [服务化部署](https://lightx2v-zhcn.readthedocs.io/zh-cn/latest/deploy_guides/deploy_service.html) - 生产级API服务部署
- [Lora模型部署](https://lightx2v-zhcn.readthedocs.io/zh-cn/latest/deploy_guides/lora_deploy.html) - Lora灵活部署
## 🧾 代码贡献指南
我们通过自动化的预提交钩子来保证代码质量,确保项目代码格式的一致性。
> [!TIP]
> **安装说明:**
>
> 1. 安装必要的依赖:
> ```shell
> pip install ruff pre-commit
> ```
>
> 2. 提交前运行:
> ```shell
> pre-commit run --all-files
> ```
感谢您为LightX2V的改进做出贡献!
## 🤝 致谢
我们向所有启发和促进LightX2V开发的模型仓库和研究社区表示诚挚的感谢。此框架基于开源社区的集体努力而构建。
## 🌟 Star 历史
[![Star History Chart](https://api.star-history.com/svg?repos=ModelTC/lightx2v&type=Timeline)](https://star-history.com/#ModelTC/lightx2v&Timeline)
## ✏️ 引用
如果您发现LightX2V对您的研究有用,请考虑引用我们的工作:
```bibtex
@misc{lightx2v,
author = {LightX2V Contributors},
title = {LightX2V: Light Video Generation Inference Framework},
year = {2025},
publisher = {GitHub},
journal = {GitHub repository},
howpublished = {\url{https://github.com/ModelTC/lightx2v}},
}
```
## 📞 联系与支持
如有任何问题、建议或需要支持,欢迎通过以下方式联系我们:
- 🐛 [GitHub Issues](https://github.com/ModelTC/lightx2v/issues) - 错误报告和功能请求
---
<div align="center">
由 LightX2V 团队用 ❤️ 构建
</div>
# Gradio Demo
Please refer our gradio deployment doc:
[English doc: Gradio Deployment](https://lightx2v-en.readthedocs.io/en/latest/deploy_guides/deploy_gradio.html)
[中文文档: Gradio 部署](https://lightx2v-zhcn.readthedocs.io/zh-cn/latest/deploy_guides/deploy_gradio.html)
## 🚀 Quick Start (快速开始)
For Windows users, we provide a convenient one-click deployment solution with automatic environment configuration and intelligent parameter optimization. Please refer to the [One-Click Gradio Launch](https://lightx2v-en.readthedocs.io/en/latest/deploy_guides/deploy_local_windows.html) section for detailed instructions.
对于Windows用户,我们提供了便捷的一键部署方式,支持自动环境配置和智能参数优化。详细操作请参考[一键启动Gradio](https://lightx2v-zhcn.readthedocs.io/zh-cn/latest/deploy_guides/deploy_local_windows.html)章节。
import argparse
import gc
import glob
import importlib.util
import json
import os
os.environ["PROFILING_DEBUG_LEVEL"] = "2"
os.environ["PYTORCH_CUDA_ALLOC_CONF"] = "expandable_segments:True"
os.environ["DTYPE"] = "BF16"
import random
from datetime import datetime
import gradio as gr
import psutil
import torch
from loguru import logger
from lightx2v.utils.input_info import set_input_info
from lightx2v.utils.set_config import get_default_config
try:
from flashinfer.rope import apply_rope_with_cos_sin_cache_inplace
except ImportError:
apply_rope_with_cos_sin_cache_inplace = None
logger.add(
"inference_logs.log",
rotation="100 MB",
encoding="utf-8",
enqueue=True,
backtrace=True,
diagnose=True,
)
MAX_NUMPY_SEED = 2**32 - 1
def scan_model_path_contents(model_path):
"""Scan model_path directory and return available files and subdirectories"""
if not model_path or not os.path.exists(model_path):
return {"dirs": [], "files": [], "safetensors_dirs": [], "pth_files": []}
dirs = []
files = []
safetensors_dirs = []
pth_files = []
try:
for item in os.listdir(model_path):
item_path = os.path.join(model_path, item)
if os.path.isdir(item_path):
dirs.append(item)
# Check if directory contains safetensors files
if glob.glob(os.path.join(item_path, "*.safetensors")):
safetensors_dirs.append(item)
elif os.path.isfile(item_path):
files.append(item)
if item.endswith(".pth"):
pth_files.append(item)
except Exception as e:
logger.warning(f"Failed to scan directory: {e}")
return {
"dirs": sorted(dirs),
"files": sorted(files),
"safetensors_dirs": sorted(safetensors_dirs),
"pth_files": sorted(pth_files),
}
def get_dit_choices(model_path, model_type="wan2.1"):
"""Get Diffusion model options (filtered by model type)"""
contents = scan_model_path_contents(model_path)
excluded_keywords = ["vae", "tae", "clip", "t5", "high_noise", "low_noise"]
fp8_supported = is_fp8_supported_gpu()
if model_type == "wan2.1":
# wan2.1: filter files/dirs containing wan2.1 or Wan2.1
def is_valid(name):
name_lower = name.lower()
if "wan2.1" not in name_lower:
return False
if not fp8_supported and "fp8" in name_lower:
return False
return not any(kw in name_lower for kw in excluded_keywords)
else:
# wan2.2: filter files/dirs containing wan2.2 or Wan2.2
def is_valid(name):
name_lower = name.lower()
if "wan2.2" not in name_lower:
return False
if not fp8_supported and "fp8" in name_lower:
return False
return not any(kw in name_lower for kw in excluded_keywords)
# Filter matching directories and files
dir_choices = [d for d in contents["dirs"] if is_valid(d)]
file_choices = [f for f in contents["files"] if is_valid(f)]
choices = dir_choices + file_choices
return choices if choices else [""]
def get_high_noise_choices(model_path):
"""Get high noise model options (files/dirs containing high_noise)"""
contents = scan_model_path_contents(model_path)
fp8_supported = is_fp8_supported_gpu()
def is_valid(name):
name_lower = name.lower()
if not fp8_supported and "fp8" in name_lower:
return False
return "high_noise" in name_lower or "high-noise" in name_lower
dir_choices = [d for d in contents["dirs"] if is_valid(d)]
file_choices = [f for f in contents["files"] if is_valid(f)]
choices = dir_choices + file_choices
return choices if choices else [""]
def get_low_noise_choices(model_path):
"""Get low noise model options (files/dirs containing low_noise)"""
contents = scan_model_path_contents(model_path)
fp8_supported = is_fp8_supported_gpu()
def is_valid(name):
name_lower = name.lower()
if not fp8_supported and "fp8" in name_lower:
return False
return "low_noise" in name_lower or "low-noise" in name_lower
dir_choices = [d for d in contents["dirs"] if is_valid(d)]
file_choices = [f for f in contents["files"] if is_valid(f)]
choices = dir_choices + file_choices
return choices if choices else [""]
def get_t5_choices(model_path):
"""Get T5 model options (.pth or .safetensors files containing t5 keyword)"""
contents = scan_model_path_contents(model_path)
fp8_supported = is_fp8_supported_gpu()
# Filter from .pth files
pth_choices = [f for f in contents["pth_files"] if "t5" in f.lower() and (fp8_supported or "fp8" not in f.lower())]
# Filter from .safetensors files
safetensors_choices = [f for f in contents["files"] if f.endswith(".safetensors") and "t5" in f.lower() and (fp8_supported or "fp8" not in f.lower())]
# Filter from directories containing safetensors
safetensors_dir_choices = [d for d in contents["safetensors_dirs"] if "t5" in d.lower() and (fp8_supported or "fp8" not in d.lower())]
choices = pth_choices + safetensors_choices + safetensors_dir_choices
return choices if choices else [""]
def get_clip_choices(model_path):
"""Get CLIP model options (.pth or .safetensors files containing clip keyword)"""
contents = scan_model_path_contents(model_path)
fp8_supported = is_fp8_supported_gpu()
# Filter from .pth files
pth_choices = [f for f in contents["pth_files"] if "clip" in f.lower() and (fp8_supported or "fp8" not in f.lower())]
# Filter from .safetensors files
safetensors_choices = [f for f in contents["files"] if f.endswith(".safetensors") and "clip" in f.lower() and (fp8_supported or "fp8" not in f.lower())]
# Filter from directories containing safetensors
safetensors_dir_choices = [d for d in contents["safetensors_dirs"] if "clip" in d.lower() and (fp8_supported or "fp8" not in d.lower())]
choices = pth_choices + safetensors_choices + safetensors_dir_choices
return choices if choices else [""]
def get_vae_choices(model_path):
"""Get VAE model options (.pth or .safetensors files containing vae/VAE/tae keyword)"""
contents = scan_model_path_contents(model_path)
fp8_supported = is_fp8_supported_gpu()
# Filter from .pth files
pth_choices = [f for f in contents["pth_files"] if any(kw in f.lower() for kw in ["vae", "tae"]) and (fp8_supported or "fp8" not in f.lower())]
# Filter from .safetensors files
safetensors_choices = [f for f in contents["files"] if f.endswith(".safetensors") and any(kw in f.lower() for kw in ["vae", "tae"]) and (fp8_supported or "fp8" not in f.lower())]
# Filter from directories containing safetensors
safetensors_dir_choices = [d for d in contents["safetensors_dirs"] if any(kw in d.lower() for kw in ["vae", "tae"]) and (fp8_supported or "fp8" not in d.lower())]
choices = pth_choices + safetensors_choices + safetensors_dir_choices
return choices if choices else [""]
def detect_quant_scheme(model_name):
"""Automatically detect quantization scheme from model name
- If model name contains "int8" → "int8"
- If model name contains "fp8" and device supports → "fp8"
- Otherwise return None (no quantization)
"""
if not model_name:
return None
name_lower = model_name.lower()
if "int8" in name_lower:
return "int8"
elif "fp8" in name_lower:
if is_fp8_supported_gpu():
return "fp8"
else:
# Device doesn't support fp8, return None (use default precision)
return None
return None
def update_model_path_options(model_path, model_type="wan2.1"):
"""Update all model path selectors when model_path or model_type changes"""
dit_choices = get_dit_choices(model_path, model_type)
high_noise_choices = get_high_noise_choices(model_path)
low_noise_choices = get_low_noise_choices(model_path)
t5_choices = get_t5_choices(model_path)
clip_choices = get_clip_choices(model_path)
vae_choices = get_vae_choices(model_path)
return (
gr.update(choices=dit_choices, value=dit_choices[0] if dit_choices else ""),
gr.update(choices=high_noise_choices, value=high_noise_choices[0] if high_noise_choices else ""),
gr.update(choices=low_noise_choices, value=low_noise_choices[0] if low_noise_choices else ""),
gr.update(choices=t5_choices, value=t5_choices[0] if t5_choices else ""),
gr.update(choices=clip_choices, value=clip_choices[0] if clip_choices else ""),
gr.update(choices=vae_choices, value=vae_choices[0] if vae_choices else ""),
)
def generate_random_seed():
return random.randint(0, MAX_NUMPY_SEED)
def is_module_installed(module_name):
try:
spec = importlib.util.find_spec(module_name)
return spec is not None
except ModuleNotFoundError:
return False
def get_available_quant_ops():
available_ops = []
vllm_installed = is_module_installed("vllm")
if vllm_installed:
available_ops.append(("vllm", True))
else:
available_ops.append(("vllm", False))
sgl_installed = is_module_installed("sgl_kernel")
if sgl_installed:
available_ops.append(("sgl", True))
else:
available_ops.append(("sgl", False))
q8f_installed = is_module_installed("q8_kernels")
if q8f_installed:
available_ops.append(("q8f", True))
else:
available_ops.append(("q8f", False))
return available_ops
def get_available_attn_ops():
available_ops = []
vllm_installed = is_module_installed("flash_attn")
if vllm_installed:
available_ops.append(("flash_attn2", True))
else:
available_ops.append(("flash_attn2", False))
sgl_installed = is_module_installed("flash_attn_interface")
if sgl_installed:
available_ops.append(("flash_attn3", True))
else:
available_ops.append(("flash_attn3", False))
sage_installed = is_module_installed("sageattention")
if sage_installed:
available_ops.append(("sage_attn2", True))
else:
available_ops.append(("sage_attn2", False))
sage3_installed = is_module_installed("sageattn3")
if sage3_installed:
available_ops.append(("sage_attn3", True))
else:
available_ops.append(("sage_attn3", False))
torch_installed = is_module_installed("torch")
if torch_installed:
available_ops.append(("torch_sdpa", True))
else:
available_ops.append(("torch_sdpa", False))
return available_ops
def get_gpu_memory(gpu_idx=0):
if not torch.cuda.is_available():
return 0
try:
with torch.cuda.device(gpu_idx):
memory_info = torch.cuda.mem_get_info()
total_memory = memory_info[1] / (1024**3) # Convert bytes to GB
return total_memory
except Exception as e:
logger.warning(f"Failed to get GPU memory: {e}")
return 0
def get_cpu_memory():
available_bytes = psutil.virtual_memory().available
return available_bytes / 1024**3
def cleanup_memory():
gc.collect()
if torch.cuda.is_available():
torch.cuda.empty_cache()
torch.cuda.synchronize()
try:
import psutil
if hasattr(psutil, "virtual_memory"):
if os.name == "posix":
try:
os.system("sync")
except: # noqa
pass
except: # noqa
pass
def generate_unique_filename(output_dir):
os.makedirs(output_dir, exist_ok=True)
timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
return os.path.join(output_dir, f"{timestamp}.mp4")
def is_fp8_supported_gpu():
if not torch.cuda.is_available():
return False
compute_capability = torch.cuda.get_device_capability(0)
major, minor = compute_capability
return (major == 8 and minor == 9) or (major >= 9)
def is_ada_architecture_gpu():
if not torch.cuda.is_available():
return False
try:
gpu_name = torch.cuda.get_device_name(0).upper()
ada_keywords = ["RTX 40", "RTX40", "4090", "4080", "4070", "4060"]
return any(keyword in gpu_name for keyword in ada_keywords)
except Exception as e:
logger.warning(f"Failed to get GPU name: {e}")
return False
def get_quantization_options(model_path):
"""Get quantization options dynamically based on model_path"""
import os
# Check subdirectories
subdirs = ["original", "fp8", "int8"]
has_subdirs = {subdir: os.path.exists(os.path.join(model_path, subdir)) for subdir in subdirs}
# Check original files in root directory
t5_bf16_exists = os.path.exists(os.path.join(model_path, "models_t5_umt5-xxl-enc-bf16.pth"))
clip_fp16_exists = os.path.exists(os.path.join(model_path, "models_clip_open-clip-xlm-roberta-large-vit-huge-14.pth"))
# Generate options
def get_choices(has_subdirs, original_type, fp8_type, int8_type, fallback_type, has_original_file=False):
choices = []
if has_subdirs["original"]:
choices.append(original_type)
if has_subdirs["fp8"]:
choices.append(fp8_type)
if has_subdirs["int8"]:
choices.append(int8_type)
# If no subdirectories but original file exists, add original type
if has_original_file:
if not choices or "original" not in choices:
choices.append(original_type)
# If no options at all, use default value
if not choices:
choices = [fallback_type]
return choices, choices[0]
# DIT options
dit_choices, dit_default = get_choices(has_subdirs, "bf16", "fp8", "int8", "bf16")
# T5 options - check if original file exists
t5_choices, t5_default = get_choices(has_subdirs, "bf16", "fp8", "int8", "bf16", t5_bf16_exists)
# CLIP options - check if original file exists
clip_choices, clip_default = get_choices(has_subdirs, "fp16", "fp8", "int8", "fp16", clip_fp16_exists)
return {"dit_choices": dit_choices, "dit_default": dit_default, "t5_choices": t5_choices, "t5_default": t5_default, "clip_choices": clip_choices, "clip_default": clip_default}
def determine_model_cls(model_type, dit_name, high_noise_name):
"""Determine model_cls based on model type and file name"""
# Determine file name to check
if model_type == "wan2.1":
check_name = dit_name.lower() if dit_name else ""
is_distill = "4step" in check_name
return "wan2.1_distill" if is_distill else "wan2.1"
else:
# wan2.2
check_name = high_noise_name.lower() if high_noise_name else ""
is_distill = "4step" in check_name
return "wan2.2_moe_distill" if is_distill else "wan2.2_moe"
global_runner = None
current_config = None
cur_dit_path = None
cur_t5_path = None
cur_clip_path = None
available_quant_ops = get_available_quant_ops()
quant_op_choices = []
for op_name, is_installed in available_quant_ops:
status_text = "✅ Installed" if is_installed else "❌ Not Installed"
display_text = f"{op_name} ({status_text})"
quant_op_choices.append((op_name, display_text))
available_attn_ops = get_available_attn_ops()
# Priority order
attn_priority = ["sage_attn3", "sage_attn2", "flash_attn3", "flash_attn2", "torch_sdpa"]
# Sort by priority, installed ones first, uninstalled ones last
attn_op_choices = []
attn_op_dict = dict(available_attn_ops)
# Add installed ones first (by priority)
for op_name in attn_priority:
if op_name in attn_op_dict and attn_op_dict[op_name]:
status_text = "✅ Installed"
display_text = f"{op_name} ({status_text})"
attn_op_choices.append((op_name, display_text))
# Add uninstalled ones (by priority)
for op_name in attn_priority:
if op_name in attn_op_dict and not attn_op_dict[op_name]:
status_text = "❌ Not Installed"
display_text = f"{op_name} ({status_text})"
attn_op_choices.append((op_name, display_text))
# Add other operators not in priority list (installed ones first)
other_ops = [(op_name, is_installed) for op_name, is_installed in available_attn_ops if op_name not in attn_priority]
for op_name, is_installed in sorted(other_ops, key=lambda x: not x[1]): # Installed ones first
status_text = "✅ Installed" if is_installed else "❌ Not Installed"
display_text = f"{op_name} ({status_text})"
attn_op_choices.append((op_name, display_text))
def run_inference(
prompt,
negative_prompt,
save_result_path,
infer_steps,
num_frames,
resolution,
seed,
sample_shift,
enable_cfg,
cfg_scale,
fps,
use_tiling_vae,
lazy_load,
cpu_offload,
offload_granularity,
t5_cpu_offload,
clip_cpu_offload,
vae_cpu_offload,
unload_modules,
attention_type,
quant_op,
rope_chunk,
rope_chunk_size,
clean_cuda_cache,
model_path_input,
model_type_input,
task_type_input,
dit_path_input,
high_noise_path_input,
low_noise_path_input,
t5_path_input,
clip_path_input,
vae_path_input,
image_path=None,
):
cleanup_memory()
quant_op = quant_op.split("(")[0].strip()
attention_type = attention_type.split("(")[0].strip()
global global_runner, current_config, model_path, model_cls
global cur_dit_path, cur_t5_path, cur_clip_path
task = task_type_input
model_cls = determine_model_cls(model_type_input, dit_path_input, high_noise_path_input)
logger.info(f"Auto-determined model_cls: {model_cls} (Model type: {model_type_input})")
if model_type_input == "wan2.1":
dit_quant_detected = detect_quant_scheme(dit_path_input)
else:
dit_quant_detected = detect_quant_scheme(high_noise_path_input)
t5_quant_detected = detect_quant_scheme(t5_path_input)
clip_quant_detected = detect_quant_scheme(clip_path_input)
logger.info(f"Auto-detected quantization scheme - DIT: {dit_quant_detected}, T5: {t5_quant_detected}, CLIP: {clip_quant_detected}")
if model_path_input and model_path_input.strip():
model_path = model_path_input.strip()
if os.path.exists(os.path.join(model_path, "config.json")):
with open(os.path.join(model_path, "config.json"), "r") as f:
model_config = json.load(f)
else:
model_config = {}
save_result_path = generate_unique_filename(output_dir)
is_dit_quant = dit_quant_detected != "bf16"
is_t5_quant = t5_quant_detected != "bf16"
is_clip_quant = clip_quant_detected != "fp16"
dit_quantized_ckpt = None
dit_original_ckpt = None
high_noise_quantized_ckpt = None
low_noise_quantized_ckpt = None
high_noise_original_ckpt = None
low_noise_original_ckpt = None
if is_dit_quant:
dit_quant_scheme = f"{dit_quant_detected}-{quant_op}"
if "wan2.1" in model_cls:
dit_quantized_ckpt = os.path.join(model_path, dit_path_input)
else:
high_noise_quantized_ckpt = os.path.join(model_path, high_noise_path_input)
low_noise_quantized_ckpt = os.path.join(model_path, low_noise_path_input)
else:
dit_quantized_ckpt = "Default"
if "wan2.1" in model_cls:
dit_original_ckpt = os.path.join(model_path, dit_path_input)
else:
high_noise_original_ckpt = os.path.join(model_path, high_noise_path_input)
low_noise_original_ckpt = os.path.join(model_path, low_noise_path_input)
# Use frontend-selected T5 path
if is_t5_quant:
t5_quantized_ckpt = os.path.join(model_path, t5_path_input)
t5_quant_scheme = f"{t5_quant_detected}-{quant_op}"
t5_original_ckpt = None
else:
t5_quantized_ckpt = None
t5_quant_scheme = None
t5_original_ckpt = os.path.join(model_path, t5_path_input)
# Use frontend-selected CLIP path
if is_clip_quant:
clip_quantized_ckpt = os.path.join(model_path, clip_path_input)
clip_quant_scheme = f"{clip_quant_detected}-{quant_op}"
clip_original_ckpt = None
else:
clip_quantized_ckpt = None
clip_quant_scheme = None
clip_original_ckpt = os.path.join(model_path, clip_path_input)
if model_type_input == "wan2.1":
current_dit_path = dit_path_input
else:
current_dit_path = f"{high_noise_path_input}|{low_noise_path_input}" if high_noise_path_input and low_noise_path_input else None
current_t5_path = t5_path_input
current_clip_path = clip_path_input
needs_reinit = (
lazy_load
or unload_modules
or global_runner is None
or current_config is None
or cur_dit_path is None
or cur_dit_path != current_dit_path
or cur_t5_path is None
or cur_t5_path != current_t5_path
or cur_clip_path is None
or cur_clip_path != current_clip_path
)
if cfg_scale == 1:
enable_cfg = False
else:
enable_cfg = True
vae_name_lower = vae_path_input.lower() if vae_path_input else ""
use_tae = "tae" in vae_name_lower or "lighttae" in vae_name_lower
use_lightvae = "lightvae" in vae_name_lower
need_scaled = "lighttae" in vae_name_lower
logger.info(f"VAE configuration - use_tae: {use_tae}, use_lightvae: {use_lightvae}, need_scaled: {need_scaled} (VAE: {vae_path_input})")
config_graio = {
"infer_steps": infer_steps,
"target_video_length": num_frames,
"target_width": int(resolution.split("x")[0]),
"target_height": int(resolution.split("x")[1]),
"self_attn_1_type": attention_type,
"cross_attn_1_type": attention_type,
"cross_attn_2_type": attention_type,
"enable_cfg": enable_cfg,
"sample_guide_scale": cfg_scale,
"sample_shift": sample_shift,
"fps": fps,
"feature_caching": "NoCaching",
"do_mm_calib": False,
"parallel_attn_type": None,
"parallel_vae": False,
"max_area": False,
"vae_stride": (4, 8, 8),
"patch_size": (1, 2, 2),
"lora_path": None,
"strength_model": 1.0,
"use_prompt_enhancer": False,
"text_len": 512,
"denoising_step_list": [1000, 750, 500, 250],
"cpu_offload": True if "wan2.2" in model_cls else cpu_offload,
"offload_granularity": "phase" if "wan2.2" in model_cls else offload_granularity,
"t5_cpu_offload": t5_cpu_offload,
"clip_cpu_offload": clip_cpu_offload,
"vae_cpu_offload": vae_cpu_offload,
"dit_quantized": is_dit_quant,
"dit_quant_scheme": dit_quant_scheme,
"dit_quantized_ckpt": dit_quantized_ckpt,
"dit_original_ckpt": dit_original_ckpt,
"high_noise_quantized_ckpt": high_noise_quantized_ckpt,
"low_noise_quantized_ckpt": low_noise_quantized_ckpt,
"high_noise_original_ckpt": high_noise_original_ckpt,
"low_noise_original_ckpt": low_noise_original_ckpt,
"t5_original_ckpt": t5_original_ckpt,
"t5_quantized": is_t5_quant,
"t5_quantized_ckpt": t5_quantized_ckpt,
"t5_quant_scheme": t5_quant_scheme,
"clip_original_ckpt": clip_original_ckpt,
"clip_quantized": is_clip_quant,
"clip_quantized_ckpt": clip_quantized_ckpt,
"clip_quant_scheme": clip_quant_scheme,
"vae_path": os.path.join(model_path, vae_path_input),
"use_tiling_vae": use_tiling_vae,
"use_tae": use_tae,
"use_lightvae": use_lightvae,
"need_scaled": need_scaled,
"lazy_load": lazy_load,
"rope_chunk": rope_chunk,
"rope_chunk_size": rope_chunk_size,
"clean_cuda_cache": clean_cuda_cache,
"unload_modules": unload_modules,
"seq_parallel": False,
"warm_up_cpu_buffers": False,
"boundary_step_index": 2,
"boundary": 0.900,
"use_image_encoder": False if "wan2.2" in model_cls else True,
"rope_type": "flashinfer" if apply_rope_with_cos_sin_cache_inplace else "torch",
}
args = argparse.Namespace(
model_cls=model_cls,
seed=seed,
task=task,
model_path=model_path,
prompt_enhancer=None,
prompt=prompt,
negative_prompt=negative_prompt,
image_path=image_path,
save_result_path=save_result_path,
return_result_tensor=False,
)
config = get_default_config()
config.update({k: v for k, v in vars(args).items()})
config.update(model_config)
config.update(config_graio)
logger.info(f"Using model: {model_path}")
logger.info(f"Inference configuration:\n{json.dumps(config, indent=4, ensure_ascii=False)}")
# Initialize or reuse the runner
runner = global_runner
if needs_reinit:
if runner is not None:
del runner
torch.cuda.empty_cache()
gc.collect()
from lightx2v.infer import init_runner # noqa
runner = init_runner(config)
input_info = set_input_info(args)
current_config = config
cur_dit_path = current_dit_path
cur_t5_path = current_t5_path
cur_clip_path = current_clip_path
if not lazy_load:
global_runner = runner
else:
runner.config = config
runner.run_pipeline(input_info)
cleanup_memory()
return save_result_path
def handle_lazy_load_change(lazy_load_enabled):
"""Handle lazy_load checkbox change to automatically enable unload_modules"""
return gr.update(value=lazy_load_enabled)
def auto_configure(resolution):
"""Auto-configure inference options based on machine configuration and resolution"""
default_config = {
"lazy_load_val": False,
"rope_chunk_val": False,
"rope_chunk_size_val": 100,
"clean_cuda_cache_val": False,
"cpu_offload_val": False,
"offload_granularity_val": "block",
"t5_cpu_offload_val": False,
"clip_cpu_offload_val": False,
"vae_cpu_offload_val": False,
"unload_modules_val": False,
"attention_type_val": attn_op_choices[0][1],
"quant_op_val": quant_op_choices[0][1],
"use_tiling_vae_val": False,
}
gpu_memory = round(get_gpu_memory())
cpu_memory = round(get_cpu_memory())
attn_priority = ["sage_attn3", "sage_attn2", "flash_attn3", "flash_attn2", "torch_sdpa"]
if is_ada_architecture_gpu():
quant_op_priority = ["q8f", "vllm", "sgl"]
else:
quant_op_priority = ["vllm", "sgl", "q8f"]
for op in attn_priority:
if dict(available_attn_ops).get(op):
default_config["attention_type_val"] = dict(attn_op_choices)[op]
break
for op in quant_op_priority:
if dict(available_quant_ops).get(op):
default_config["quant_op_val"] = dict(quant_op_choices)[op]
break
if resolution in [
"1280x720",
"720x1280",
"1280x544",
"544x1280",
"1104x832",
"832x1104",
"960x960",
]:
res = "720p"
elif resolution in [
"960x544",
"544x960",
]:
res = "540p"
else:
res = "480p"
if res == "720p":
gpu_rules = [
(80, {}),
(40, {"cpu_offload_val": False, "t5_cpu_offload_val": True, "vae_cpu_offload_val": True, "clip_cpu_offload_val": True}),
(32, {"cpu_offload_val": True, "t5_cpu_offload_val": False, "vae_cpu_offload_val": False, "clip_cpu_offload_val": False}),
(
24,
{
"cpu_offload_val": True,
"use_tiling_vae_val": True,
"t5_cpu_offload_val": True,
"vae_cpu_offload_val": True,
"clip_cpu_offload_val": True,
},
),
(
16,
{
"cpu_offload_val": True,
"t5_cpu_offload_val": True,
"vae_cpu_offload_val": True,
"clip_cpu_offload_val": True,
"use_tiling_vae_val": True,
"offload_granularity_val": "phase",
"rope_chunk_val": True,
"rope_chunk_size_val": 100,
},
),
(
8,
{
"cpu_offload_val": True,
"t5_cpu_offload_val": True,
"vae_cpu_offload_val": True,
"clip_cpu_offload_val": True,
"use_tiling_vae_val": True,
"offload_granularity_val": "phase",
"rope_chunk_val": True,
"rope_chunk_size_val": 100,
"clean_cuda_cache_val": True,
},
),
]
else:
gpu_rules = [
(80, {}),
(40, {"cpu_offload_val": False, "t5_cpu_offload_val": True, "vae_cpu_offload_val": True, "clip_cpu_offload_val": True}),
(32, {"cpu_offload_val": True, "t5_cpu_offload_val": False, "vae_cpu_offload_val": False, "clip_cpu_offload_val": False}),
(
24,
{
"cpu_offload_val": True,
"t5_cpu_offload_val": True,
"vae_cpu_offload_val": True,
"clip_cpu_offload_val": True,
"use_tiling_vae_val": True,
},
),
(
16,
{
"cpu_offload_val": True,
"t5_cpu_offload_val": True,
"vae_cpu_offload_val": True,
"clip_cpu_offload_val": True,
"use_tiling_vae_val": True,
"offload_granularity_val": "phase",
},
),
(
8,
{
"cpu_offload_val": True,
"t5_cpu_offload_val": True,
"vae_cpu_offload_val": True,
"clip_cpu_offload_val": True,
"use_tiling_vae_val": True,
"offload_granularity_val": "phase",
},
),
]
cpu_rules = [
(128, {}),
(64, {}),
(32, {"unload_modules_val": True}),
(
16,
{
"lazy_load_val": True,
"unload_modules_val": True,
},
),
]
for threshold, updates in gpu_rules:
if gpu_memory >= threshold:
default_config.update(updates)
break
for threshold, updates in cpu_rules:
if cpu_memory >= threshold:
default_config.update(updates)
break
return (
gr.update(value=default_config["lazy_load_val"]),
gr.update(value=default_config["rope_chunk_val"]),
gr.update(value=default_config["rope_chunk_size_val"]),
gr.update(value=default_config["clean_cuda_cache_val"]),
gr.update(value=default_config["cpu_offload_val"]),
gr.update(value=default_config["offload_granularity_val"]),
gr.update(value=default_config["t5_cpu_offload_val"]),
gr.update(value=default_config["clip_cpu_offload_val"]),
gr.update(value=default_config["vae_cpu_offload_val"]),
gr.update(value=default_config["unload_modules_val"]),
gr.update(value=default_config["attention_type_val"]),
gr.update(value=default_config["quant_op_val"]),
gr.update(value=default_config["use_tiling_vae_val"]),
)
css = """
.main-content { max-width: 1600px; margin: auto; padding: 20px; }
.warning { color: #ff6b6b; font-weight: bold; }
/* Model configuration area styles */
.model-config {
margin-bottom: 20px !important;
border: 1px solid #e0e0e0;
border-radius: 12px;
padding: 15px;
background: linear-gradient(135deg, #f5f7fa 0%, #c3cfe2 100%);
}
/* Input parameters area styles */
.input-params {
margin-bottom: 20px !important;
border: 1px solid #e0e0e0;
border-radius: 12px;
padding: 15px;
background: linear-gradient(135deg, #fff5f5 0%, #ffeef0 100%);
}
/* Output video area styles */
.output-video {
border: 1px solid #e0e0e0;
border-radius: 12px;
padding: 20px;
background: linear-gradient(135deg, #e0f2fe 0%, #bae6fd 100%);
min-height: 400px;
}
/* Generate button styles */
.generate-btn {
width: 100%;
margin-top: 20px;
padding: 15px 30px !important;
font-size: 18px !important;
font-weight: bold !important;
background: linear-gradient(135deg, #667eea 0%, #764ba2 100%) !important;
border: none !important;
border-radius: 10px !important;
box-shadow: 0 4px 15px rgba(102, 126, 234, 0.4) !important;
transition: all 0.3s ease !important;
}
.generate-btn:hover {
transform: translateY(-2px);
box-shadow: 0 6px 20px rgba(102, 126, 234, 0.6) !important;
}
/* Accordion header styles */
.model-config .gr-accordion-header,
.input-params .gr-accordion-header,
.output-video .gr-accordion-header {
font-size: 20px !important;
font-weight: bold !important;
padding: 15px !important;
}
/* Optimize spacing */
.gr-row {
margin-bottom: 15px;
}
/* Video player styles */
.output-video video {
border-radius: 10px;
box-shadow: 0 4px 15px rgba(0, 0, 0, 0.1);
}
"""
def main():
with gr.Blocks(title="Lightx2v (Lightweight Video Inference and Generation Engine)") as demo:
gr.Markdown(f"# 🎬 LightX2V Video Generator")
gr.HTML(f"<style>{css}</style>")
# Main layout: left and right columns
with gr.Row():
# Left: configuration and input area
with gr.Column(scale=5):
# Model configuration area
with gr.Accordion("🗂️ Model Configuration", open=True, elem_classes=["model-config"]):
# FP8 support notice
if not is_fp8_supported_gpu():
gr.Markdown("⚠️ **Your device does not support FP8 inference**. Models containing FP8 have been automatically hidden.")
# Hidden state components
model_path_input = gr.Textbox(value=model_path, visible=False)
# Model type + Task type
with gr.Row():
model_type_input = gr.Radio(
label="Model Type",
choices=["wan2.1", "wan2.2"],
value="wan2.1",
info="wan2.2 requires separate high noise and low noise models",
)
task_type_input = gr.Radio(
label="Task Type",
choices=["i2v", "t2v"],
value="i2v",
info="i2v: Image-to-video, t2v: Text-to-video",
)
# wan2.1: Diffusion model (single row)
with gr.Row() as wan21_row:
dit_path_input = gr.Dropdown(
label="🎨 Diffusion Model",
choices=get_dit_choices(model_path, "wan2.1"),
value=get_dit_choices(model_path, "wan2.1")[0] if get_dit_choices(model_path, "wan2.1") else "",
allow_custom_value=True,
visible=True,
)
# wan2.2 specific: high noise model + low noise model (hidden by default)
with gr.Row(visible=False) as wan22_row:
high_noise_path_input = gr.Dropdown(
label="🔊 High Noise Model",
choices=get_high_noise_choices(model_path),
value=get_high_noise_choices(model_path)[0] if get_high_noise_choices(model_path) else "",
allow_custom_value=True,
)
low_noise_path_input = gr.Dropdown(
label="🔇 Low Noise Model",
choices=get_low_noise_choices(model_path),
value=get_low_noise_choices(model_path)[0] if get_low_noise_choices(model_path) else "",
allow_custom_value=True,
)
# Text encoder (single row)
with gr.Row():
t5_path_input = gr.Dropdown(
label="📝 Text Encoder",
choices=get_t5_choices(model_path),
value=get_t5_choices(model_path)[0] if get_t5_choices(model_path) else "",
allow_custom_value=True,
)
# Image encoder + VAE decoder
with gr.Row():
clip_path_input = gr.Dropdown(
label="🖼️ Image Encoder",
choices=get_clip_choices(model_path),
value=get_clip_choices(model_path)[0] if get_clip_choices(model_path) else "",
allow_custom_value=True,
)
vae_path_input = gr.Dropdown(
label="🎞️ VAE Decoder",
choices=get_vae_choices(model_path),
value=get_vae_choices(model_path)[0] if get_vae_choices(model_path) else "",
allow_custom_value=True,
)
# Attention operator and quantization matrix multiplication operator
with gr.Row():
attention_type = gr.Dropdown(
label="⚡ Attention Operator",
choices=[op[1] for op in attn_op_choices],
value=attn_op_choices[0][1] if attn_op_choices else "",
info="Use appropriate attention operators to accelerate inference",
)
quant_op = gr.Dropdown(
label="Quantization Matmul Operator",
choices=[op[1] for op in quant_op_choices],
value=quant_op_choices[0][1],
info="Select quantization matrix multiplication operator to accelerate inference",
interactive=True,
)
# Determine if model is distill version
def is_distill_model(model_type, dit_path, high_noise_path):
"""Determine if model is distill version based on model type and path"""
if model_type == "wan2.1":
check_name = dit_path.lower() if dit_path else ""
else:
check_name = high_noise_path.lower() if high_noise_path else ""
return "4step" in check_name
# Model type change event
def on_model_type_change(model_type, model_path_val):
if model_type == "wan2.2":
return gr.update(visible=False), gr.update(visible=True), gr.update()
else:
# Update wan2.1 Diffusion model options
dit_choices = get_dit_choices(model_path_val, "wan2.1")
return (
gr.update(visible=True),
gr.update(visible=False),
gr.update(choices=dit_choices, value=dit_choices[0] if dit_choices else ""),
)
model_type_input.change(
fn=on_model_type_change,
inputs=[model_type_input, model_path_input],
outputs=[wan21_row, wan22_row, dit_path_input],
)
# Input parameters area
with gr.Accordion("📥 Input Parameters", open=True, elem_classes=["input-params"]):
# Image input (shown for i2v)
with gr.Row(visible=True) as image_input_row:
image_path = gr.Image(
label="Input Image",
type="filepath",
height=300,
interactive=True,
)
# Task type change event
def on_task_type_change(task_type):
return gr.update(visible=(task_type == "i2v"))
task_type_input.change(
fn=on_task_type_change,
inputs=[task_type_input],
outputs=[image_input_row],
)
with gr.Row():
with gr.Column():
prompt = gr.Textbox(
label="Prompt",
lines=3,
placeholder="Describe the video content...",
max_lines=5,
)
with gr.Column():
negative_prompt = gr.Textbox(
label="Negative Prompt",
lines=3,
placeholder="What you don't want to appear in the video...",
max_lines=5,
value="Camera shake, bright tones, overexposed, static, blurred details, subtitles, style, works, paintings, images, static, overall gray, worst quality, low quality, JPEG compression residue, ugly, incomplete, extra fingers, poorly drawn hands, poorly drawn faces, deformed, disfigured, misshapen limbs, fused fingers, still picture, messy background, three legs, many people in the background, walking backwards",
)
with gr.Column():
resolution = gr.Dropdown(
choices=[
# 720p
("1280x720 (16:9, 720p)", "1280x720"),
("720x1280 (9:16, 720p)", "720x1280"),
("1280x544 (21:9, 720p)", "1280x544"),
("544x1280 (9:21, 720p)", "544x1280"),
("1104x832 (4:3, 720p)", "1104x832"),
("832x1104 (3:4, 720p)", "832x1104"),
("960x960 (1:1, 720p)", "960x960"),
# 480p
("960x544 (16:9, 540p)", "960x544"),
("544x960 (9:16, 540p)", "544x960"),
("832x480 (16:9, 480p)", "832x480"),
("480x832 (9:16, 480p)", "480x832"),
("832x624 (4:3, 480p)", "832x624"),
("624x832 (3:4, 480p)", "624x832"),
("720x720 (1:1, 480p)", "720x720"),
("512x512 (1:1, 480p)", "512x512"),
],
value="832x480",
label="Maximum Resolution",
)
with gr.Column(scale=9):
seed = gr.Slider(
label="Random Seed",
minimum=0,
maximum=MAX_NUMPY_SEED,
step=1,
value=generate_random_seed(),
)
with gr.Column():
default_dit = get_dit_choices(model_path, "wan2.1")[0] if get_dit_choices(model_path, "wan2.1") else ""
default_high_noise = get_high_noise_choices(model_path)[0] if get_high_noise_choices(model_path) else ""
default_is_distill = is_distill_model("wan2.1", default_dit, default_high_noise)
if default_is_distill:
infer_steps = gr.Slider(
label="Inference Steps",
minimum=1,
maximum=100,
step=1,
value=4,
info="Distill model inference steps default to 4.",
)
else:
infer_steps = gr.Slider(
label="Inference Steps",
minimum=1,
maximum=100,
step=1,
value=40,
info="Number of inference steps for video generation. Increasing steps may improve quality but reduce speed.",
)
# Dynamically update inference steps when model path changes
def update_infer_steps(model_type, dit_path, high_noise_path):
is_distill = is_distill_model(model_type, dit_path, high_noise_path)
if is_distill:
return gr.update(minimum=1, maximum=100, value=4, interactive=True)
else:
return gr.update(minimum=1, maximum=100, value=40, interactive=True)
# Listen to model path changes
dit_path_input.change(
fn=lambda mt, dp, hnp: update_infer_steps(mt, dp, hnp),
inputs=[model_type_input, dit_path_input, high_noise_path_input],
outputs=[infer_steps],
)
high_noise_path_input.change(
fn=lambda mt, dp, hnp: update_infer_steps(mt, dp, hnp),
inputs=[model_type_input, dit_path_input, high_noise_path_input],
outputs=[infer_steps],
)
model_type_input.change(
fn=lambda mt, dp, hnp: update_infer_steps(mt, dp, hnp),
inputs=[model_type_input, dit_path_input, high_noise_path_input],
outputs=[infer_steps],
)
# Set default CFG based on model class
# CFG scale factor: default to 1 for distill, otherwise 5
default_cfg_scale = 1 if default_is_distill else 5
# enable_cfg is not exposed to frontend, automatically set based on cfg_scale
# If cfg_scale == 1, then enable_cfg = False, otherwise enable_cfg = True
default_enable_cfg = False if default_cfg_scale == 1 else True
enable_cfg = gr.Checkbox(
label="Enable Classifier-Free Guidance",
value=default_enable_cfg,
visible=False, # Hidden, not exposed to frontend
)
with gr.Row():
sample_shift = gr.Slider(
label="Distribution Shift",
value=5,
minimum=0,
maximum=10,
step=1,
info="Controls the degree of distribution shift for samples. Larger values indicate more significant shifts.",
)
cfg_scale = gr.Slider(
label="CFG Scale Factor",
minimum=1,
maximum=10,
step=1,
value=default_cfg_scale,
info="Controls the influence strength of the prompt. Higher values give more influence to the prompt. When value is 1, CFG is automatically disabled.",
)
# Update enable_cfg based on cfg_scale
def update_enable_cfg(cfg_scale_val):
"""Automatically set enable_cfg based on cfg_scale value"""
if cfg_scale_val == 1:
return gr.update(value=False)
else:
return gr.update(value=True)
# Dynamically update CFG scale factor and enable_cfg when model path changes
def update_cfg_scale(model_type, dit_path, high_noise_path):
is_distill = is_distill_model(model_type, dit_path, high_noise_path)
if is_distill:
new_cfg_scale = 1
else:
new_cfg_scale = 5
new_enable_cfg = False if new_cfg_scale == 1 else True
return gr.update(value=new_cfg_scale), gr.update(value=new_enable_cfg)
dit_path_input.change(
fn=lambda mt, dp, hnp: update_cfg_scale(mt, dp, hnp),
inputs=[model_type_input, dit_path_input, high_noise_path_input],
outputs=[cfg_scale, enable_cfg],
)
high_noise_path_input.change(
fn=lambda mt, dp, hnp: update_cfg_scale(mt, dp, hnp),
inputs=[model_type_input, dit_path_input, high_noise_path_input],
outputs=[cfg_scale, enable_cfg],
)
model_type_input.change(
fn=lambda mt, dp, hnp: update_cfg_scale(mt, dp, hnp),
inputs=[model_type_input, dit_path_input, high_noise_path_input],
outputs=[cfg_scale, enable_cfg],
)
cfg_scale.change(
fn=update_enable_cfg,
inputs=[cfg_scale],
outputs=[enable_cfg],
)
with gr.Row():
fps = gr.Slider(
label="Frames Per Second (FPS)",
minimum=8,
maximum=30,
step=1,
value=16,
info="Frames per second of the video. Higher FPS results in smoother videos.",
)
num_frames = gr.Slider(
label="Total Frames",
minimum=16,
maximum=120,
step=1,
value=81,
info="Total number of frames in the video. More frames result in longer videos.",
)
save_result_path = gr.Textbox(
label="Output Video Path",
value=generate_unique_filename(output_dir),
info="Must include .mp4 extension. If left blank or using the default value, a unique filename will be automatically generated.",
visible=False, # Hide output path, auto-generated
)
with gr.Column(scale=4):
with gr.Accordion("📤 Generated Video", open=True, elem_classes=["output-video"]):
output_video = gr.Video(
label="",
height=600,
autoplay=True,
show_label=False,
)
infer_btn = gr.Button("🎬 Generate Video", variant="primary", size="lg", elem_classes=["generate-btn"])
rope_chunk = gr.Checkbox(label="Chunked Rotary Position Embedding", value=False, visible=False)
rope_chunk_size = gr.Slider(label="Rotary Embedding Chunk Size", value=100, minimum=100, maximum=10000, step=100, visible=False)
unload_modules = gr.Checkbox(label="Unload Modules", value=False, visible=False)
clean_cuda_cache = gr.Checkbox(label="Clean CUDA Memory Cache", value=False, visible=False)
cpu_offload = gr.Checkbox(label="CPU Offloading", value=False, visible=False)
lazy_load = gr.Checkbox(label="Enable Lazy Loading", value=False, visible=False)
offload_granularity = gr.Dropdown(label="Dit Offload Granularity", choices=["block", "phase"], value="phase", visible=False)
t5_cpu_offload = gr.Checkbox(label="T5 CPU Offloading", value=False, visible=False)
clip_cpu_offload = gr.Checkbox(label="CLIP CPU Offloading", value=False, visible=False)
vae_cpu_offload = gr.Checkbox(label="VAE CPU Offloading", value=False, visible=False)
use_tiling_vae = gr.Checkbox(label="VAE Tiling Inference", value=False, visible=False)
resolution.change(
fn=auto_configure,
inputs=[resolution],
outputs=[
lazy_load,
rope_chunk,
rope_chunk_size,
clean_cuda_cache,
cpu_offload,
offload_granularity,
t5_cpu_offload,
clip_cpu_offload,
vae_cpu_offload,
unload_modules,
attention_type,
quant_op,
use_tiling_vae,
],
)
demo.load(
fn=lambda res: auto_configure(res),
inputs=[resolution],
outputs=[
lazy_load,
rope_chunk,
rope_chunk_size,
clean_cuda_cache,
cpu_offload,
offload_granularity,
t5_cpu_offload,
clip_cpu_offload,
vae_cpu_offload,
unload_modules,
attention_type,
quant_op,
use_tiling_vae,
],
)
infer_btn.click(
fn=run_inference,
inputs=[
prompt,
negative_prompt,
save_result_path,
infer_steps,
num_frames,
resolution,
seed,
sample_shift,
enable_cfg,
cfg_scale,
fps,
use_tiling_vae,
lazy_load,
cpu_offload,
offload_granularity,
t5_cpu_offload,
clip_cpu_offload,
vae_cpu_offload,
unload_modules,
attention_type,
quant_op,
rope_chunk,
rope_chunk_size,
clean_cuda_cache,
model_path_input,
model_type_input,
task_type_input,
dit_path_input,
high_noise_path_input,
low_noise_path_input,
t5_path_input,
clip_path_input,
vae_path_input,
image_path,
],
outputs=output_video,
)
demo.launch(share=True, server_port=args.server_port, server_name=args.server_name, inbrowser=True, allowed_paths=[output_dir])
if __name__ == "__main__":
parser = argparse.ArgumentParser(description="Lightweight Video Generation")
parser.add_argument("--model_path", type=str, required=True, help="Model folder path")
parser.add_argument("--server_port", type=int, default=7862, help="Server port")
parser.add_argument("--server_name", type=str, default="0.0.0.0", help="Server IP")
parser.add_argument("--output_dir", type=str, default="./outputs", help="Output video save directory")
args = parser.parse_args()
global model_path, model_cls, output_dir
model_path = args.model_path
model_cls = "wan2.1"
output_dir = args.output_dir
main()
import argparse
import gc
import glob
import importlib.util
import json
import os
os.environ["PROFILING_DEBUG_LEVEL"] = "2"
os.environ["PYTORCH_CUDA_ALLOC_CONF"] = "expandable_segments:True"
os.environ["DTYPE"] = "BF16"
import random
from datetime import datetime
import gradio as gr
import psutil
import torch
from loguru import logger
from lightx2v.utils.input_info import set_input_info
from lightx2v.utils.set_config import get_default_config
try:
from flashinfer.rope import apply_rope_with_cos_sin_cache_inplace
except ImportError:
apply_rope_with_cos_sin_cache_inplace = None
logger.add(
"inference_logs.log",
rotation="100 MB",
encoding="utf-8",
enqueue=True,
backtrace=True,
diagnose=True,
)
MAX_NUMPY_SEED = 2**32 - 1
def scan_model_path_contents(model_path):
"""扫描 model_path 目录,返回可用的文件和子目录"""
if not model_path or not os.path.exists(model_path):
return {"dirs": [], "files": [], "safetensors_dirs": [], "pth_files": []}
dirs = []
files = []
safetensors_dirs = []
pth_files = []
try:
for item in os.listdir(model_path):
item_path = os.path.join(model_path, item)
if os.path.isdir(item_path):
dirs.append(item)
# 检查目录是否包含 safetensors 文件
if glob.glob(os.path.join(item_path, "*.safetensors")):
safetensors_dirs.append(item)
elif os.path.isfile(item_path):
files.append(item)
if item.endswith(".pth"):
pth_files.append(item)
except Exception as e:
logger.warning(f"扫描目录失败: {e}")
return {
"dirs": sorted(dirs),
"files": sorted(files),
"safetensors_dirs": sorted(safetensors_dirs),
"pth_files": sorted(pth_files),
}
def get_dit_choices(model_path, model_type="wan2.1"):
"""获取 Diffusion 模型可选项(根据模型类型筛选)"""
contents = scan_model_path_contents(model_path)
excluded_keywords = ["vae", "tae", "clip", "t5", "high_noise", "low_noise"]
fp8_supported = is_fp8_supported_gpu()
if model_type == "wan2.1":
# wan2.1: 筛选包含 wan2.1 或 Wan2.1 的文件/目录
def is_valid(name):
name_lower = name.lower()
if "wan2.1" not in name_lower:
return False
if not fp8_supported and "fp8" in name_lower:
return False
return not any(kw in name_lower for kw in excluded_keywords)
else:
# wan2.2: 筛选包含 wan2.2 或 Wan2.2 的文件/目录
def is_valid(name):
name_lower = name.lower()
if "wan2.2" not in name_lower:
return False
if not fp8_supported and "fp8" in name_lower:
return False
return not any(kw in name_lower for kw in excluded_keywords)
# 筛选符合条件的目录和文件
dir_choices = [d for d in contents["dirs"] if is_valid(d)]
file_choices = [f for f in contents["files"] if is_valid(f)]
choices = dir_choices + file_choices
return choices if choices else [""]
def get_high_noise_choices(model_path):
"""获取高噪模型可选项(包含 high_noise 的文件/目录)"""
contents = scan_model_path_contents(model_path)
fp8_supported = is_fp8_supported_gpu()
def is_valid(name):
name_lower = name.lower()
if not fp8_supported and "fp8" in name_lower:
return False
return "high_noise" in name_lower or "high-noise" in name_lower
dir_choices = [d for d in contents["dirs"] if is_valid(d)]
file_choices = [f for f in contents["files"] if is_valid(f)]
choices = dir_choices + file_choices
return choices if choices else [""]
def get_low_noise_choices(model_path):
"""获取低噪模型可选项(包含 low_noise 的文件/目录)"""
contents = scan_model_path_contents(model_path)
fp8_supported = is_fp8_supported_gpu()
def is_valid(name):
name_lower = name.lower()
if not fp8_supported and "fp8" in name_lower:
return False
return "low_noise" in name_lower or "low-noise" in name_lower
dir_choices = [d for d in contents["dirs"] if is_valid(d)]
file_choices = [f for f in contents["files"] if is_valid(f)]
choices = dir_choices + file_choices
return choices if choices else [""]
def get_t5_choices(model_path):
"""获取 T5 模型可选项(.pth 或 .safetensors 文件,包含 t5 关键字)"""
contents = scan_model_path_contents(model_path)
fp8_supported = is_fp8_supported_gpu()
# 从 .pth 文件中筛选
pth_choices = [f for f in contents["pth_files"] if "t5" in f.lower() and (fp8_supported or "fp8" not in f.lower())]
# 从 .safetensors 文件中筛选
safetensors_choices = [f for f in contents["files"] if f.endswith(".safetensors") and "t5" in f.lower() and (fp8_supported or "fp8" not in f.lower())]
# 从包含 safetensors 的目录中筛选
safetensors_dir_choices = [d for d in contents["safetensors_dirs"] if "t5" in d.lower() and (fp8_supported or "fp8" not in d.lower())]
choices = pth_choices + safetensors_choices + safetensors_dir_choices
return choices if choices else [""]
def get_clip_choices(model_path):
"""获取 CLIP 模型可选项(.pth 或 .safetensors 文件,包含 clip 关键字)"""
contents = scan_model_path_contents(model_path)
fp8_supported = is_fp8_supported_gpu()
# 从 .pth 文件中筛选
pth_choices = [f for f in contents["pth_files"] if "clip" in f.lower() and (fp8_supported or "fp8" not in f.lower())]
# 从 .safetensors 文件中筛选
safetensors_choices = [f for f in contents["files"] if f.endswith(".safetensors") and "clip" in f.lower() and (fp8_supported or "fp8" not in f.lower())]
# 从包含 safetensors 的目录中筛选
safetensors_dir_choices = [d for d in contents["safetensors_dirs"] if "clip" in d.lower() and (fp8_supported or "fp8" not in d.lower())]
choices = pth_choices + safetensors_choices + safetensors_dir_choices
return choices if choices else [""]
def get_vae_choices(model_path):
"""获取 VAE 模型可选项(.pth 或 .safetensors 文件,包含 vae/VAE/tae 关键字)"""
contents = scan_model_path_contents(model_path)
fp8_supported = is_fp8_supported_gpu()
# 从 .pth 文件中筛选
pth_choices = [f for f in contents["pth_files"] if any(kw in f.lower() for kw in ["vae", "tae"]) and (fp8_supported or "fp8" not in f.lower())]
# 从 .safetensors 文件中筛选
safetensors_choices = [f for f in contents["files"] if f.endswith(".safetensors") and any(kw in f.lower() for kw in ["vae", "tae"]) and (fp8_supported or "fp8" not in f.lower())]
# 从包含 safetensors 的目录中筛选
safetensors_dir_choices = [d for d in contents["safetensors_dirs"] if any(kw in d.lower() for kw in ["vae", "tae"]) and (fp8_supported or "fp8" not in d.lower())]
choices = pth_choices + safetensors_choices + safetensors_dir_choices
return choices if choices else [""]
def detect_quant_scheme(model_name):
"""根据模型名字自动检测量化精度
- 如果模型名字包含 "int8" → "int8"
- 如果模型名字包含 "fp8" 且设备支持 → "fp8"
- 否则返回 None(表示不使用量化)
"""
if not model_name:
return None
name_lower = model_name.lower()
if "int8" in name_lower:
return "int8"
elif "fp8" in name_lower:
if is_fp8_supported_gpu():
return "fp8"
else:
# 设备不支持fp8,返回None(使用默认精度)
return None
return None
def update_model_path_options(model_path, model_type="wan2.1"):
"""当 model_path 或 model_type 改变时,更新所有模型路径选择器"""
dit_choices = get_dit_choices(model_path, model_type)
high_noise_choices = get_high_noise_choices(model_path)
low_noise_choices = get_low_noise_choices(model_path)
t5_choices = get_t5_choices(model_path)
clip_choices = get_clip_choices(model_path)
vae_choices = get_vae_choices(model_path)
return (
gr.update(choices=dit_choices, value=dit_choices[0] if dit_choices else ""),
gr.update(choices=high_noise_choices, value=high_noise_choices[0] if high_noise_choices else ""),
gr.update(choices=low_noise_choices, value=low_noise_choices[0] if low_noise_choices else ""),
gr.update(choices=t5_choices, value=t5_choices[0] if t5_choices else ""),
gr.update(choices=clip_choices, value=clip_choices[0] if clip_choices else ""),
gr.update(choices=vae_choices, value=vae_choices[0] if vae_choices else ""),
)
def generate_random_seed():
return random.randint(0, MAX_NUMPY_SEED)
def is_module_installed(module_name):
try:
spec = importlib.util.find_spec(module_name)
return spec is not None
except ModuleNotFoundError:
return False
def get_available_quant_ops():
available_ops = []
vllm_installed = is_module_installed("vllm")
if vllm_installed:
available_ops.append(("vllm", True))
else:
available_ops.append(("vllm", False))
sgl_installed = is_module_installed("sgl_kernel")
if sgl_installed:
available_ops.append(("sgl", True))
else:
available_ops.append(("sgl", False))
q8f_installed = is_module_installed("q8_kernels")
if q8f_installed:
available_ops.append(("q8f", True))
else:
available_ops.append(("q8f", False))
return available_ops
def get_available_attn_ops():
available_ops = []
vllm_installed = is_module_installed("flash_attn")
if vllm_installed:
available_ops.append(("flash_attn2", True))
else:
available_ops.append(("flash_attn2", False))
sgl_installed = is_module_installed("flash_attn_interface")
if sgl_installed:
available_ops.append(("flash_attn3", True))
else:
available_ops.append(("flash_attn3", False))
sage_installed = is_module_installed("sageattention")
if sage_installed:
available_ops.append(("sage_attn2", True))
else:
available_ops.append(("sage_attn2", False))
sage3_installed = is_module_installed("sageattn3")
if sage3_installed:
available_ops.append(("sage_attn3", True))
else:
available_ops.append(("sage_attn3", False))
torch_installed = is_module_installed("torch")
if torch_installed:
available_ops.append(("torch_sdpa", True))
else:
available_ops.append(("torch_sdpa", False))
return available_ops
def get_gpu_memory(gpu_idx=0):
if not torch.cuda.is_available():
return 0
try:
with torch.cuda.device(gpu_idx):
memory_info = torch.cuda.mem_get_info()
total_memory = memory_info[1] / (1024**3) # Convert bytes to GB
return total_memory
except Exception as e:
logger.warning(f"获取GPU内存失败: {e}")
return 0
def get_cpu_memory():
available_bytes = psutil.virtual_memory().available
return available_bytes / 1024**3
def cleanup_memory():
gc.collect()
if torch.cuda.is_available():
torch.cuda.empty_cache()
torch.cuda.synchronize()
try:
import psutil
if hasattr(psutil, "virtual_memory"):
if os.name == "posix":
try:
os.system("sync")
except: # noqa
pass
except: # noqa
pass
def generate_unique_filename(output_dir):
os.makedirs(output_dir, exist_ok=True)
timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
return os.path.join(output_dir, f"{timestamp}.mp4")
def is_fp8_supported_gpu():
if not torch.cuda.is_available():
return False
compute_capability = torch.cuda.get_device_capability(0)
major, minor = compute_capability
return (major == 8 and minor == 9) or (major >= 9)
def is_ada_architecture_gpu():
if not torch.cuda.is_available():
return False
try:
gpu_name = torch.cuda.get_device_name(0).upper()
ada_keywords = ["RTX 40", "RTX40", "4090", "4080", "4070", "4060"]
return any(keyword in gpu_name for keyword in ada_keywords)
except Exception as e:
logger.warning(f"Failed to get GPU name: {e}")
return False
def get_quantization_options(model_path):
"""根据model_path动态获取量化选项"""
import os
# 检查子目录
subdirs = ["original", "fp8", "int8"]
has_subdirs = {subdir: os.path.exists(os.path.join(model_path, subdir)) for subdir in subdirs}
# 检查根目录下的原始文件
t5_bf16_exists = os.path.exists(os.path.join(model_path, "models_t5_umt5-xxl-enc-bf16.pth"))
clip_fp16_exists = os.path.exists(os.path.join(model_path, "models_clip_open-clip-xlm-roberta-large-vit-huge-14.pth"))
# 生成选项
def get_choices(has_subdirs, original_type, fp8_type, int8_type, fallback_type, has_original_file=False):
choices = []
if has_subdirs["original"]:
choices.append(original_type)
if has_subdirs["fp8"]:
choices.append(fp8_type)
if has_subdirs["int8"]:
choices.append(int8_type)
# 如果没有子目录但有原始文件,添加原始类型
if has_original_file:
if not choices or "original" not in choices:
choices.append(original_type)
# 如果没有任何选项,使用默认值
if not choices:
choices = [fallback_type]
return choices, choices[0]
# DIT选项
dit_choices, dit_default = get_choices(has_subdirs, "bf16", "fp8", "int8", "bf16")
# T5选项 - 检查是否有原始文件
t5_choices, t5_default = get_choices(has_subdirs, "bf16", "fp8", "int8", "bf16", t5_bf16_exists)
# CLIP选项 - 检查是否有原始文件
clip_choices, clip_default = get_choices(has_subdirs, "fp16", "fp8", "int8", "fp16", clip_fp16_exists)
return {"dit_choices": dit_choices, "dit_default": dit_default, "t5_choices": t5_choices, "t5_default": t5_default, "clip_choices": clip_choices, "clip_default": clip_default}
def determine_model_cls(model_type, dit_name, high_noise_name):
"""根据模型类型和文件名确定 model_cls"""
# 确定要检查的文件名
if model_type == "wan2.1":
check_name = dit_name.lower() if dit_name else ""
is_distill = "4step" in check_name
return "wan2.1_distill" if is_distill else "wan2.1"
else:
# wan2.2
check_name = high_noise_name.lower() if high_noise_name else ""
is_distill = "4step" in check_name
return "wan2.2_moe_distill" if is_distill else "wan2.2_moe"
global_runner = None
current_config = None
cur_dit_path = None
cur_t5_path = None
cur_clip_path = None
available_quant_ops = get_available_quant_ops()
quant_op_choices = []
for op_name, is_installed in available_quant_ops:
status_text = "✅ 已安装" if is_installed else "❌ 未安装"
display_text = f"{op_name} ({status_text})"
quant_op_choices.append((op_name, display_text))
available_attn_ops = get_available_attn_ops()
# 优先级顺序
attn_priority = ["sage_attn2", "flash_attn3", "flash_attn2", "torch_sdpa"]
# 按优先级排序,已安装的在前,未安装的在后
attn_op_choices = []
attn_op_dict = dict(available_attn_ops)
# 先添加已安装的(按优先级)
for op_name in attn_priority:
if op_name in attn_op_dict and attn_op_dict[op_name]:
status_text = "✅ 已安装"
display_text = f"{op_name} ({status_text})"
attn_op_choices.append((op_name, display_text))
# 再添加未安装的(按优先级)
for op_name in attn_priority:
if op_name in attn_op_dict and not attn_op_dict[op_name]:
status_text = "❌ 未安装"
display_text = f"{op_name} ({status_text})"
attn_op_choices.append((op_name, display_text))
# 添加其他不在优先级列表中的算子(已安装的在前)
other_ops = [(op_name, is_installed) for op_name, is_installed in available_attn_ops if op_name not in attn_priority]
for op_name, is_installed in sorted(other_ops, key=lambda x: not x[1]): # 已安装的在前
status_text = "✅ 已安装" if is_installed else "❌ 未安装"
display_text = f"{op_name} ({status_text})"
attn_op_choices.append((op_name, display_text))
def run_inference(
prompt,
negative_prompt,
save_result_path,
infer_steps,
num_frames,
resolution,
seed,
sample_shift,
enable_cfg,
cfg_scale,
fps,
use_tiling_vae,
lazy_load,
cpu_offload,
offload_granularity,
t5_cpu_offload,
clip_cpu_offload,
vae_cpu_offload,
unload_modules,
attention_type,
quant_op,
rope_chunk,
rope_chunk_size,
clean_cuda_cache,
model_path_input,
model_type_input,
task_type_input,
dit_path_input,
high_noise_path_input,
low_noise_path_input,
t5_path_input,
clip_path_input,
vae_path_input,
image_path=None,
):
cleanup_memory()
quant_op = quant_op.split("(")[0].strip()
attention_type = attention_type.split("(")[0].strip()
global global_runner, current_config, model_path, model_cls
global cur_dit_path, cur_t5_path, cur_clip_path
task = task_type_input
model_cls = determine_model_cls(model_type_input, dit_path_input, high_noise_path_input)
logger.info(f"自动确定 model_cls: {model_cls} (模型类型: {model_type_input})")
if model_type_input == "wan2.1":
dit_quant_detected = detect_quant_scheme(dit_path_input)
else:
dit_quant_detected = detect_quant_scheme(high_noise_path_input)
t5_quant_detected = detect_quant_scheme(t5_path_input)
clip_quant_detected = detect_quant_scheme(clip_path_input)
logger.info(f"自动检测量化精度 - DIT: {dit_quant_detected}, T5: {t5_quant_detected}, CLIP: {clip_quant_detected}")
if model_path_input and model_path_input.strip():
model_path = model_path_input.strip()
if os.path.exists(os.path.join(model_path, "config.json")):
with open(os.path.join(model_path, "config.json"), "r") as f:
model_config = json.load(f)
else:
model_config = {}
save_result_path = generate_unique_filename(output_dir)
is_dit_quant = dit_quant_detected != "bf16"
is_t5_quant = t5_quant_detected != "bf16"
is_clip_quant = clip_quant_detected != "fp16"
dit_quantized_ckpt = None
dit_original_ckpt = None
high_noise_quantized_ckpt = None
low_noise_quantized_ckpt = None
high_noise_original_ckpt = None
low_noise_original_ckpt = None
if is_dit_quant:
dit_quant_scheme = f"{dit_quant_detected}-{quant_op}"
if "wan2.1" in model_cls:
dit_quantized_ckpt = os.path.join(model_path, dit_path_input)
else:
high_noise_quantized_ckpt = os.path.join(model_path, high_noise_path_input)
low_noise_quantized_ckpt = os.path.join(model_path, low_noise_path_input)
else:
dit_quantized_ckpt = "Default"
if "wan2.1" in model_cls:
dit_original_ckpt = os.path.join(model_path, dit_path_input)
else:
high_noise_original_ckpt = os.path.join(model_path, high_noise_path_input)
low_noise_original_ckpt = os.path.join(model_path, low_noise_path_input)
# 使用前端选择的 T5 路径
if is_t5_quant:
t5_quantized_ckpt = os.path.join(model_path, t5_path_input)
t5_quant_scheme = f"{t5_quant_detected}-{quant_op}"
t5_original_ckpt = None
else:
t5_quantized_ckpt = None
t5_quant_scheme = None
t5_original_ckpt = os.path.join(model_path, t5_path_input)
# 使用前端选择的 CLIP 路径
if is_clip_quant:
clip_quantized_ckpt = os.path.join(model_path, clip_path_input)
clip_quant_scheme = f"{clip_quant_detected}-{quant_op}"
clip_original_ckpt = None
else:
clip_quantized_ckpt = None
clip_quant_scheme = None
clip_original_ckpt = os.path.join(model_path, clip_path_input)
if model_type_input == "wan2.1":
current_dit_path = dit_path_input
else:
current_dit_path = f"{high_noise_path_input}|{low_noise_path_input}" if high_noise_path_input and low_noise_path_input else None
current_t5_path = t5_path_input
current_clip_path = clip_path_input
needs_reinit = (
lazy_load
or unload_modules
or global_runner is None
or current_config is None
or cur_dit_path is None
or cur_dit_path != current_dit_path
or cur_t5_path is None
or cur_t5_path != current_t5_path
or cur_clip_path is None
or cur_clip_path != current_clip_path
)
if cfg_scale == 1:
enable_cfg = False
else:
enable_cfg = True
vae_name_lower = vae_path_input.lower() if vae_path_input else ""
use_tae = "tae" in vae_name_lower or "lighttae" in vae_name_lower
use_lightvae = "lightvae" in vae_name_lower
need_scaled = "lighttae" in vae_name_lower
logger.info(f"VAE 配置 - use_tae: {use_tae}, use_lightvae: {use_lightvae}, need_scaled: {need_scaled} (VAE: {vae_path_input})")
config_graio = {
"infer_steps": infer_steps,
"target_video_length": num_frames,
"target_width": int(resolution.split("x")[0]),
"target_height": int(resolution.split("x")[1]),
"self_attn_1_type": attention_type,
"cross_attn_1_type": attention_type,
"cross_attn_2_type": attention_type,
"enable_cfg": enable_cfg,
"sample_guide_scale": cfg_scale,
"sample_shift": sample_shift,
"fps": fps,
"feature_caching": "NoCaching",
"do_mm_calib": False,
"parallel_attn_type": None,
"parallel_vae": False,
"max_area": False,
"vae_stride": (4, 8, 8),
"patch_size": (1, 2, 2),
"lora_path": None,
"strength_model": 1.0,
"use_prompt_enhancer": False,
"text_len": 512,
"denoising_step_list": [1000, 750, 500, 250],
"cpu_offload": True if "wan2.2" in model_cls else cpu_offload,
"offload_granularity": "phase" if "wan2.2" in model_cls else offload_granularity,
"t5_cpu_offload": t5_cpu_offload,
"clip_cpu_offload": clip_cpu_offload,
"vae_cpu_offload": vae_cpu_offload,
"dit_quantized": is_dit_quant,
"dit_quant_scheme": dit_quant_scheme,
"dit_quantized_ckpt": dit_quantized_ckpt,
"dit_original_ckpt": dit_original_ckpt,
"high_noise_quantized_ckpt": high_noise_quantized_ckpt,
"low_noise_quantized_ckpt": low_noise_quantized_ckpt,
"high_noise_original_ckpt": high_noise_original_ckpt,
"low_noise_original_ckpt": low_noise_original_ckpt,
"t5_original_ckpt": t5_original_ckpt,
"t5_quantized": is_t5_quant,
"t5_quantized_ckpt": t5_quantized_ckpt,
"t5_quant_scheme": t5_quant_scheme,
"clip_original_ckpt": clip_original_ckpt,
"clip_quantized": is_clip_quant,
"clip_quantized_ckpt": clip_quantized_ckpt,
"clip_quant_scheme": clip_quant_scheme,
"vae_path": os.path.join(model_path, vae_path_input),
"use_tiling_vae": use_tiling_vae,
"use_tae": use_tae,
"use_lightvae": use_lightvae,
"need_scaled": need_scaled,
"lazy_load": lazy_load,
"rope_chunk": rope_chunk,
"rope_chunk_size": rope_chunk_size,
"clean_cuda_cache": clean_cuda_cache,
"unload_modules": unload_modules,
"seq_parallel": False,
"warm_up_cpu_buffers": False,
"boundary_step_index": 2,
"boundary": 0.900,
"use_image_encoder": False if "wan2.2" in model_cls else True,
"rope_type": "flashinfer" if apply_rope_with_cos_sin_cache_inplace else "torch",
}
args = argparse.Namespace(
model_cls=model_cls,
seed=seed,
task=task,
model_path=model_path,
prompt_enhancer=None,
prompt=prompt,
negative_prompt=negative_prompt,
image_path=image_path,
save_result_path=save_result_path,
return_result_tensor=False,
)
config = get_default_config()
config.update({k: v for k, v in vars(args).items()})
config.update(model_config)
config.update(config_graio)
logger.info(f"使用模型: {model_path}")
logger.info(f"推理配置:\n{json.dumps(config, indent=4, ensure_ascii=False)}")
# Initialize or reuse the runner
runner = global_runner
if needs_reinit:
if runner is not None:
del runner
torch.cuda.empty_cache()
gc.collect()
from lightx2v.infer import init_runner # noqa
runner = init_runner(config)
input_info = set_input_info(args)
current_config = config
cur_dit_path = current_dit_path
cur_t5_path = current_t5_path
cur_clip_path = current_clip_path
if not lazy_load:
global_runner = runner
else:
runner.config = config
runner.run_pipeline(input_info)
cleanup_memory()
return save_result_path
def handle_lazy_load_change(lazy_load_enabled):
"""Handle lazy_load checkbox change to automatically enable unload_modules"""
return gr.update(value=lazy_load_enabled)
def auto_configure(resolution):
"""根据机器配置和分辨率自动设置推理选项"""
default_config = {
"lazy_load_val": False,
"rope_chunk_val": False,
"rope_chunk_size_val": 100,
"clean_cuda_cache_val": False,
"cpu_offload_val": False,
"offload_granularity_val": "block",
"t5_cpu_offload_val": False,
"clip_cpu_offload_val": False,
"vae_cpu_offload_val": False,
"unload_modules_val": False,
"attention_type_val": attn_op_choices[0][1],
"quant_op_val": quant_op_choices[0][1],
"use_tiling_vae_val": False,
}
gpu_memory = round(get_gpu_memory())
cpu_memory = round(get_cpu_memory())
attn_priority = ["sage_attn3", "sage_attn2", "flash_attn3", "flash_attn2", "torch_sdpa"]
if is_ada_architecture_gpu():
quant_op_priority = ["q8f", "vllm", "sgl"]
else:
quant_op_priority = ["vllm", "sgl", "q8f"]
for op in attn_priority:
if dict(available_attn_ops).get(op):
default_config["attention_type_val"] = dict(attn_op_choices)[op]
break
for op in quant_op_priority:
if dict(available_quant_ops).get(op):
default_config["quant_op_val"] = dict(quant_op_choices)[op]
break
if resolution in [
"1280x720",
"720x1280",
"1280x544",
"544x1280",
"1104x832",
"832x1104",
"960x960",
]:
res = "720p"
elif resolution in [
"960x544",
"544x960",
]:
res = "540p"
else:
res = "480p"
if res == "720p":
gpu_rules = [
(80, {}),
(40, {"cpu_offload_val": False, "t5_cpu_offload_val": True, "vae_cpu_offload_val": True, "clip_cpu_offload_val": True}),
(32, {"cpu_offload_val": True, "t5_cpu_offload_val": False, "vae_cpu_offload_val": False, "clip_cpu_offload_val": False}),
(
24,
{
"cpu_offload_val": True,
"use_tiling_vae_val": True,
"t5_cpu_offload_val": True,
"vae_cpu_offload_val": True,
"clip_cpu_offload_val": True,
},
),
(
16,
{
"cpu_offload_val": True,
"t5_cpu_offload_val": True,
"vae_cpu_offload_val": True,
"clip_cpu_offload_val": True,
"use_tiling_vae_val": True,
"offload_granularity_val": "phase",
"rope_chunk_val": True,
"rope_chunk_size_val": 100,
},
),
(
8,
{
"cpu_offload_val": True,
"t5_cpu_offload_val": True,
"vae_cpu_offload_val": True,
"clip_cpu_offload_val": True,
"use_tiling_vae_val": True,
"offload_granularity_val": "phase",
"rope_chunk_val": True,
"rope_chunk_size_val": 100,
"clean_cuda_cache_val": True,
},
),
]
else:
gpu_rules = [
(80, {}),
(40, {"cpu_offload_val": False, "t5_cpu_offload_val": True, "vae_cpu_offload_val": True, "clip_cpu_offload_val": True}),
(32, {"cpu_offload_val": True, "t5_cpu_offload_val": False, "vae_cpu_offload_val": False, "clip_cpu_offload_val": False}),
(
24,
{
"cpu_offload_val": True,
"t5_cpu_offload_val": True,
"vae_cpu_offload_val": True,
"clip_cpu_offload_val": True,
"use_tiling_vae_val": True,
},
),
(
16,
{
"cpu_offload_val": True,
"t5_cpu_offload_val": True,
"vae_cpu_offload_val": True,
"clip_cpu_offload_val": True,
"use_tiling_vae_val": True,
"offload_granularity_val": "phase",
},
),
(
8,
{
"cpu_offload_val": True,
"t5_cpu_offload_val": True,
"vae_cpu_offload_val": True,
"clip_cpu_offload_val": True,
"use_tiling_vae_val": True,
"offload_granularity_val": "phase",
},
),
]
cpu_rules = [
(128, {}),
(64, {}),
(32, {"unload_modules_val": True}),
(
16,
{
"lazy_load_val": True,
"unload_modules_val": True,
},
),
]
for threshold, updates in gpu_rules:
if gpu_memory >= threshold:
default_config.update(updates)
break
for threshold, updates in cpu_rules:
if cpu_memory >= threshold:
default_config.update(updates)
break
return (
gr.update(value=default_config["lazy_load_val"]),
gr.update(value=default_config["rope_chunk_val"]),
gr.update(value=default_config["rope_chunk_size_val"]),
gr.update(value=default_config["clean_cuda_cache_val"]),
gr.update(value=default_config["cpu_offload_val"]),
gr.update(value=default_config["offload_granularity_val"]),
gr.update(value=default_config["t5_cpu_offload_val"]),
gr.update(value=default_config["clip_cpu_offload_val"]),
gr.update(value=default_config["vae_cpu_offload_val"]),
gr.update(value=default_config["unload_modules_val"]),
gr.update(value=default_config["attention_type_val"]),
gr.update(value=default_config["quant_op_val"]),
gr.update(value=default_config["use_tiling_vae_val"]),
)
css = """
.main-content { max-width: 1600px; margin: auto; padding: 20px; }
.warning { color: #ff6b6b; font-weight: bold; }
/* 模型配置区域样式 */
.model-config {
margin-bottom: 20px !important;
border: 1px solid #e0e0e0;
border-radius: 12px;
padding: 15px;
background: linear-gradient(135deg, #f5f7fa 0%, #c3cfe2 100%);
}
/* 输入参数区域样式 */
.input-params {
margin-bottom: 20px !important;
border: 1px solid #e0e0e0;
border-radius: 12px;
padding: 15px;
background: linear-gradient(135deg, #fff5f5 0%, #ffeef0 100%);
}
/* 输出视频区域样式 */
.output-video {
border: 1px solid #e0e0e0;
border-radius: 12px;
padding: 20px;
background: linear-gradient(135deg, #e0f2fe 0%, #bae6fd 100%);
min-height: 400px;
}
/* 生成按钮样式 */
.generate-btn {
width: 100%;
margin-top: 20px;
padding: 15px 30px !important;
font-size: 18px !important;
font-weight: bold !important;
background: linear-gradient(135deg, #667eea 0%, #764ba2 100%) !important;
border: none !important;
border-radius: 10px !important;
box-shadow: 0 4px 15px rgba(102, 126, 234, 0.4) !important;
transition: all 0.3s ease !important;
}
.generate-btn:hover {
transform: translateY(-2px);
box-shadow: 0 6px 20px rgba(102, 126, 234, 0.6) !important;
}
/* Accordion 标题样式 */
.model-config .gr-accordion-header,
.input-params .gr-accordion-header,
.output-video .gr-accordion-header {
font-size: 20px !important;
font-weight: bold !important;
padding: 15px !important;
}
/* 优化间距 */
.gr-row {
margin-bottom: 15px;
}
/* 视频播放器样式 */
.output-video video {
border-radius: 10px;
box-shadow: 0 4px 15px rgba(0, 0, 0, 0.1);
}
"""
def main():
with gr.Blocks(title="Lightx2v (轻量级视频推理和生成引擎)") as demo:
gr.Markdown(f"# 🎬 LightX2V 视频生成器")
gr.HTML(f"<style>{css}</style>")
# 主布局:左右分栏
with gr.Row():
# 左侧:配置和输入区域
with gr.Column(scale=5):
# 模型配置区域
with gr.Accordion("🗂️ 模型配置", open=True, elem_classes=["model-config"]):
# FP8 支持提示
if not is_fp8_supported_gpu():
gr.Markdown("⚠️ **您的设备不支持fp8推理**,已自动隐藏包含fp8的模型选项。")
# 隐藏的状态组件
model_path_input = gr.Textbox(value=model_path, visible=False)
# 模型类型 + 任务类型
with gr.Row():
model_type_input = gr.Radio(
label="模型类型",
choices=["wan2.1", "wan2.2"],
value="wan2.1",
info="wan2.2 需要分别指定高噪模型和低噪模型",
)
task_type_input = gr.Radio(
label="任务类型",
choices=["i2v", "t2v"],
value="i2v",
info="i2v: 图生视频, t2v: 文生视频",
)
# wan2.1:Diffusion模型(单独一行)
with gr.Row() as wan21_row:
dit_path_input = gr.Dropdown(
label="🎨 Diffusion模型",
choices=get_dit_choices(model_path, "wan2.1"),
value=get_dit_choices(model_path, "wan2.1")[0] if get_dit_choices(model_path, "wan2.1") else "",
allow_custom_value=True,
visible=True,
)
# wan2.2 专用:高噪模型 + 低噪模型(默认隐藏)
with gr.Row(visible=False) as wan22_row:
high_noise_path_input = gr.Dropdown(
label="🔊 高噪模型",
choices=get_high_noise_choices(model_path),
value=get_high_noise_choices(model_path)[0] if get_high_noise_choices(model_path) else "",
allow_custom_value=True,
)
low_noise_path_input = gr.Dropdown(
label="🔇 低噪模型",
choices=get_low_noise_choices(model_path),
value=get_low_noise_choices(model_path)[0] if get_low_noise_choices(model_path) else "",
allow_custom_value=True,
)
# 文本编码器(单独一行)
with gr.Row():
t5_path_input = gr.Dropdown(
label="📝 文本编码器",
choices=get_t5_choices(model_path),
value=get_t5_choices(model_path)[0] if get_t5_choices(model_path) else "",
allow_custom_value=True,
)
# 图像编码器 + VAE解码器
with gr.Row():
clip_path_input = gr.Dropdown(
label="🖼️ 图像编码器",
choices=get_clip_choices(model_path),
value=get_clip_choices(model_path)[0] if get_clip_choices(model_path) else "",
allow_custom_value=True,
)
vae_path_input = gr.Dropdown(
label="🎞️ VAE解码器",
choices=get_vae_choices(model_path),
value=get_vae_choices(model_path)[0] if get_vae_choices(model_path) else "",
allow_custom_value=True,
)
# 注意力算子和量化矩阵乘法算子
with gr.Row():
attention_type = gr.Dropdown(
label="⚡ 注意力算子",
choices=[op[1] for op in attn_op_choices],
value=attn_op_choices[0][1] if attn_op_choices else "",
info="使用适当的注意力算子加速推理",
)
quant_op = gr.Dropdown(
label="量化矩阵乘法算子",
choices=[op[1] for op in quant_op_choices],
value=quant_op_choices[0][1],
info="选择量化矩阵乘法算子以加速推理",
interactive=True,
)
# 判断模型是否是 distill 版本
def is_distill_model(model_type, dit_path, high_noise_path):
"""根据模型类型和路径判断是否是 distill 版本"""
if model_type == "wan2.1":
check_name = dit_path.lower() if dit_path else ""
else:
check_name = high_noise_path.lower() if high_noise_path else ""
return "4step" in check_name
# 模型类型切换事件
def on_model_type_change(model_type, model_path_val):
if model_type == "wan2.2":
return gr.update(visible=False), gr.update(visible=True), gr.update()
else:
# 更新 wan2.1 的 Diffusion 模型选项
dit_choices = get_dit_choices(model_path_val, "wan2.1")
return (
gr.update(visible=True),
gr.update(visible=False),
gr.update(choices=dit_choices, value=dit_choices[0] if dit_choices else ""),
)
model_type_input.change(
fn=on_model_type_change,
inputs=[model_type_input, model_path_input],
outputs=[wan21_row, wan22_row, dit_path_input],
)
# 输入参数区域
with gr.Accordion("📥 输入参数", open=True, elem_classes=["input-params"]):
# 图片输入(i2v 时显示)
with gr.Row(visible=True) as image_input_row:
image_path = gr.Image(
label="输入图像",
type="filepath",
height=300,
interactive=True,
)
# 任务类型切换事件
def on_task_type_change(task_type):
return gr.update(visible=(task_type == "i2v"))
task_type_input.change(
fn=on_task_type_change,
inputs=[task_type_input],
outputs=[image_input_row],
)
with gr.Row():
with gr.Column():
prompt = gr.Textbox(
label="提示词",
lines=3,
placeholder="描述视频内容...",
max_lines=5,
)
with gr.Column():
negative_prompt = gr.Textbox(
label="负向提示词",
lines=3,
placeholder="不希望出现在视频中的内容...",
max_lines=5,
value="镜头晃动,色调艳丽,过曝,静态,细节模糊不清,字幕,风格,作品,画作,画面,静止,整体发灰,最差质量,低质量,JPEG压缩残留,丑陋的,残缺的,多余的手指,画得不好的手部,画得不好的脸部,畸形的,毁容的,形态畸形的肢体,手指融合,静止不动的画面,杂乱的背景,三条腿,背景人很多,倒着走",
)
with gr.Column():
resolution = gr.Dropdown(
choices=[
# 720p
("1280x720 (16:9, 720p)", "1280x720"),
("720x1280 (9:16, 720p)", "720x1280"),
("1280x544 (21:9, 720p)", "1280x544"),
("544x1280 (9:21, 720p)", "544x1280"),
("1104x832 (4:3, 720p)", "1104x832"),
("832x1104 (3:4, 720p)", "832x1104"),
("960x960 (1:1, 720p)", "960x960"),
# 480p
("960x544 (16:9, 540p)", "960x544"),
("544x960 (9:16, 540p)", "544x960"),
("832x480 (16:9, 480p)", "832x480"),
("480x832 (9:16, 480p)", "480x832"),
("832x624 (4:3, 480p)", "832x624"),
("624x832 (3:4, 480p)", "624x832"),
("720x720 (1:1, 480p)", "720x720"),
("512x512 (1:1, 480p)", "512x512"),
],
value="832x480",
label="最大分辨率",
)
with gr.Column(scale=9):
seed = gr.Slider(
label="随机种子",
minimum=0,
maximum=MAX_NUMPY_SEED,
step=1,
value=generate_random_seed(),
)
with gr.Column():
default_dit = get_dit_choices(model_path, "wan2.1")[0] if get_dit_choices(model_path, "wan2.1") else ""
default_high_noise = get_high_noise_choices(model_path)[0] if get_high_noise_choices(model_path) else ""
default_is_distill = is_distill_model("wan2.1", default_dit, default_high_noise)
if default_is_distill:
infer_steps = gr.Slider(
label="推理步数",
minimum=1,
maximum=100,
step=1,
value=4,
info="蒸馏模型推理步数默认为4。",
)
else:
infer_steps = gr.Slider(
label="推理步数",
minimum=1,
maximum=100,
step=1,
value=40,
info="视频生成的推理步数。增加步数可能提高质量但降低速度。",
)
# 当模型路径改变时,动态更新推理步数
def update_infer_steps(model_type, dit_path, high_noise_path):
is_distill = is_distill_model(model_type, dit_path, high_noise_path)
if is_distill:
return gr.update(minimum=1, maximum=100, value=4, interactive=True)
else:
return gr.update(minimum=1, maximum=100, value=40, interactive=True)
# 监听模型路径变化
dit_path_input.change(
fn=lambda mt, dp, hnp: update_infer_steps(mt, dp, hnp),
inputs=[model_type_input, dit_path_input, high_noise_path_input],
outputs=[infer_steps],
)
high_noise_path_input.change(
fn=lambda mt, dp, hnp: update_infer_steps(mt, dp, hnp),
inputs=[model_type_input, dit_path_input, high_noise_path_input],
outputs=[infer_steps],
)
model_type_input.change(
fn=lambda mt, dp, hnp: update_infer_steps(mt, dp, hnp),
inputs=[model_type_input, dit_path_input, high_noise_path_input],
outputs=[infer_steps],
)
# 根据模型类别设置默认CFG
# CFG缩放因子:distill 时默认为 1,否则默认为 5
default_cfg_scale = 1 if default_is_distill else 5
# enable_cfg 不暴露到前端,根据 cfg_scale 自动设置
# 如果 cfg_scale == 1,则 enable_cfg = False,否则 enable_cfg = True
default_enable_cfg = False if default_cfg_scale == 1 else True
enable_cfg = gr.Checkbox(
label="启用无分类器引导",
value=default_enable_cfg,
visible=False, # 隐藏,不暴露到前端
)
with gr.Row():
sample_shift = gr.Slider(
label="分布偏移",
value=5,
minimum=0,
maximum=10,
step=1,
info="控制样本分布偏移的程度。值越大表示偏移越明显。",
)
cfg_scale = gr.Slider(
label="CFG缩放因子",
minimum=1,
maximum=10,
step=1,
value=default_cfg_scale,
info="控制提示词的影响强度。值越高,提示词的影响越大。当值为1时,自动禁用CFG。",
)
# 根据 cfg_scale 更新 enable_cfg
def update_enable_cfg(cfg_scale_val):
"""根据 cfg_scale 的值自动设置 enable_cfg"""
if cfg_scale_val == 1:
return gr.update(value=False)
else:
return gr.update(value=True)
# 当模型路径改变时,动态更新 CFG 缩放因子和 enable_cfg
def update_cfg_scale(model_type, dit_path, high_noise_path):
is_distill = is_distill_model(model_type, dit_path, high_noise_path)
if is_distill:
new_cfg_scale = 1
else:
new_cfg_scale = 5
new_enable_cfg = False if new_cfg_scale == 1 else True
return gr.update(value=new_cfg_scale), gr.update(value=new_enable_cfg)
dit_path_input.change(
fn=lambda mt, dp, hnp: update_cfg_scale(mt, dp, hnp),
inputs=[model_type_input, dit_path_input, high_noise_path_input],
outputs=[cfg_scale, enable_cfg],
)
high_noise_path_input.change(
fn=lambda mt, dp, hnp: update_cfg_scale(mt, dp, hnp),
inputs=[model_type_input, dit_path_input, high_noise_path_input],
outputs=[cfg_scale, enable_cfg],
)
model_type_input.change(
fn=lambda mt, dp, hnp: update_cfg_scale(mt, dp, hnp),
inputs=[model_type_input, dit_path_input, high_noise_path_input],
outputs=[cfg_scale, enable_cfg],
)
cfg_scale.change(
fn=update_enable_cfg,
inputs=[cfg_scale],
outputs=[enable_cfg],
)
with gr.Row():
fps = gr.Slider(
label="每秒帧数(FPS)",
minimum=8,
maximum=30,
step=1,
value=16,
info="视频的每秒帧数。较高的FPS会产生更流畅的视频。",
)
num_frames = gr.Slider(
label="总帧数",
minimum=16,
maximum=120,
step=1,
value=81,
info="视频中的总帧数。更多帧数会产生更长的视频。",
)
save_result_path = gr.Textbox(
label="输出视频路径",
value=generate_unique_filename(output_dir),
info="必须包含.mp4扩展名。如果留空或使用默认值,将自动生成唯一文件名。",
visible=False, # 隐藏输出路径,自动生成
)
with gr.Column(scale=4):
with gr.Accordion("📤 生成的视频", open=True, elem_classes=["output-video"]):
output_video = gr.Video(
label="",
height=600,
autoplay=True,
show_label=False,
)
infer_btn = gr.Button("🎬 生成视频", variant="primary", size="lg", elem_classes=["generate-btn"])
rope_chunk = gr.Checkbox(label="分块旋转位置编码", value=False, visible=False)
rope_chunk_size = gr.Slider(label="旋转编码块大小", value=100, minimum=100, maximum=10000, step=100, visible=False)
unload_modules = gr.Checkbox(label="卸载模块", value=False, visible=False)
clean_cuda_cache = gr.Checkbox(label="清理CUDA内存缓存", value=False, visible=False)
cpu_offload = gr.Checkbox(label="CPU卸载", value=False, visible=False)
lazy_load = gr.Checkbox(label="启用延迟加载", value=False, visible=False)
offload_granularity = gr.Dropdown(label="Dit卸载粒度", choices=["block", "phase"], value="phase", visible=False)
t5_cpu_offload = gr.Checkbox(label="T5 CPU卸载", value=False, visible=False)
clip_cpu_offload = gr.Checkbox(label="CLIP CPU卸载", value=False, visible=False)
vae_cpu_offload = gr.Checkbox(label="VAE CPU卸载", value=False, visible=False)
use_tiling_vae = gr.Checkbox(label="VAE分块推理", value=False, visible=False)
resolution.change(
fn=auto_configure,
inputs=[resolution],
outputs=[
lazy_load,
rope_chunk,
rope_chunk_size,
clean_cuda_cache,
cpu_offload,
offload_granularity,
t5_cpu_offload,
clip_cpu_offload,
vae_cpu_offload,
unload_modules,
attention_type,
quant_op,
use_tiling_vae,
],
)
demo.load(
fn=lambda res: auto_configure(res),
inputs=[resolution],
outputs=[
lazy_load,
rope_chunk,
rope_chunk_size,
clean_cuda_cache,
cpu_offload,
offload_granularity,
t5_cpu_offload,
clip_cpu_offload,
vae_cpu_offload,
unload_modules,
attention_type,
quant_op,
use_tiling_vae,
],
)
infer_btn.click(
fn=run_inference,
inputs=[
prompt,
negative_prompt,
save_result_path,
infer_steps,
num_frames,
resolution,
seed,
sample_shift,
enable_cfg,
cfg_scale,
fps,
use_tiling_vae,
lazy_load,
cpu_offload,
offload_granularity,
t5_cpu_offload,
clip_cpu_offload,
vae_cpu_offload,
unload_modules,
attention_type,
quant_op,
rope_chunk,
rope_chunk_size,
clean_cuda_cache,
model_path_input,
model_type_input,
task_type_input,
dit_path_input,
high_noise_path_input,
low_noise_path_input,
t5_path_input,
clip_path_input,
vae_path_input,
image_path,
],
outputs=output_video,
)
demo.launch(share=True, server_port=args.server_port, server_name=args.server_name, inbrowser=True, allowed_paths=[output_dir])
if __name__ == "__main__":
parser = argparse.ArgumentParser(description="轻量级视频生成")
parser.add_argument("--model_path", type=str, required=True, help="模型文件夹路径")
parser.add_argument("--server_port", type=int, default=7862, help="服务器端口")
parser.add_argument("--server_name", type=str, default="0.0.0.0", help="服务器IP")
parser.add_argument("--output_dir", type=str, default="./outputs", help="输出视频保存目录")
args = parser.parse_args()
global model_path, model_cls, output_dir
model_path = args.model_path
model_cls = "wan2.1"
output_dir = args.output_dir
main()
#!/bin/bash
# Lightx2v Gradio Demo Startup Script
# Supports both Image-to-Video (i2v) and Text-to-Video (t2v) modes
# ==================== Configuration Area ====================
# ⚠️ Important: Please modify the following paths according to your actual environment
# 🚨 Storage Performance Tips 🚨
# 💾 Strongly recommend storing model files on SSD solid-state drives!
# 📈 SSD can significantly improve model loading speed and inference performance
# 🐌 Using mechanical hard drives (HDD) may cause slow model loading and affect overall experience
# Lightx2v project root directory path
# Example: /home/user/lightx2v or /data/video_gen/lightx2v
lightx2v_path=/data/video_gen/lightx2v_debug/LightX2V
# Model path configuration
# Example: /path/to/Wan2.1-I2V-14B-720P-Lightx2v
model_path=/models/
# Server configuration
server_name="0.0.0.0"
server_port=8033
# Output directory configuration
output_dir="./outputs"
# GPU configuration
gpu_id=0
# ==================== Environment Variables Setup ====================
export CUDA_VISIBLE_DEVICES=$gpu_id
export CUDA_LAUNCH_BLOCKING=1
export PYTHONPATH=${lightx2v_path}:$PYTHONPATH
export PROFILING_DEBUG_LEVEL=2
export PYTORCH_CUDA_ALLOC_CONF=expandable_segments:True
# ==================== Parameter Parsing ====================
# Default interface language
lang="zh"
# 解析命令行参数
while [[ $# -gt 0 ]]; do
case $1 in
--lang)
lang="$2"
shift 2
;;
--port)
server_port="$2"
shift 2
;;
--gpu)
gpu_id="$2"
export CUDA_VISIBLE_DEVICES=$gpu_id
shift 2
;;
--output_dir)
output_dir="$2"
shift 2
;;
--model_path)
model_path="$2"
shift 2
;;
--help)
echo "🎬 Lightx2v Gradio Demo Startup Script"
echo "=========================================="
echo "Usage: $0 [options]"
echo ""
echo "📋 Available options:"
echo " --lang zh|en Interface language (default: zh)"
echo " zh: Chinese interface"
echo " en: English interface"
echo " --port PORT Server port (default: 8032)"
echo " --gpu GPU_ID GPU device ID (default: 0)"
echo " --model_path PATH Model path (default: configured in script)"
echo " --output_dir DIR Output video save directory (default: ./outputs)"
echo " --help Show this help message"
echo ""
echo "📝 Notes:"
echo " - Task type (i2v/t2v) and model type are selected in the web UI"
echo " - Model class is auto-detected based on selected diffusion model"
echo " - Edit script to configure model paths before first use"
echo " - Ensure required Python dependencies are installed"
echo " - Recommended to use GPU with 8GB+ VRAM"
echo " - 🚨 Strongly recommend storing models on SSD for better performance"
exit 0
;;
*)
echo "Unknown parameter: $1"
echo "Use --help to see help information"
exit 1
;;
esac
done
# ==================== Parameter Validation ====================
if [[ "$lang" != "zh" && "$lang" != "en" ]]; then
echo "Error: Language must be 'zh' or 'en'"
exit 1
fi
# Check if model path exists
if [[ ! -d "$model_path" ]]; then
echo "❌ Error: Model path does not exist"
echo "📁 Path: $model_path"
echo "🔧 Solutions:"
echo " 1. Check model path configuration in script"
echo " 2. Ensure model files are properly downloaded"
echo " 3. Verify path permissions are correct"
echo " 4. 💾 Recommend storing models on SSD for faster loading"
exit 1
fi
# Select demo file based on language
if [[ "$lang" == "zh" ]]; then
demo_file="gradio_demo_zh.py"
echo "🌏 Using Chinese interface"
else
demo_file="gradio_demo.py"
echo "🌏 Using English interface"
fi
# Check if demo file exists
if [[ ! -f "$demo_file" ]]; then
echo "❌ Error: Demo file does not exist"
echo "📄 File: $demo_file"
echo "🔧 Solutions:"
echo " 1. Ensure script is run in the correct directory"
echo " 2. Check if file has been renamed or moved"
echo " 3. Re-clone or download project files"
exit 1
fi
# ==================== System Information Display ====================
echo "=========================================="
echo "🚀 Lightx2v Gradio Demo Starting..."
echo "=========================================="
echo "📁 Project path: $lightx2v_path"
echo "🤖 Model path: $model_path"
echo "🌏 Interface language: $lang"
echo "🖥️ GPU device: $gpu_id"
echo "🌐 Server address: $server_name:$server_port"
echo "📁 Output directory: $output_dir"
echo "📝 Note: Task type and model class are selected in web UI"
echo "=========================================="
# Display system resource information
echo "💻 System resource information:"
free -h | grep -E "Mem|Swap"
echo ""
# Display GPU information
if command -v nvidia-smi &> /dev/null; then
echo "🎮 GPU information:"
nvidia-smi --query-gpu=name,memory.total,memory.free --format=csv,noheader,nounits | head -1
echo ""
fi
# ==================== Start Demo ====================
echo "🎬 Starting Gradio demo..."
echo "📱 Please access in browser: http://$server_name:$server_port"
echo "⏹️ Press Ctrl+C to stop service"
echo "🔄 First startup may take several minutes to load resources..."
echo "=========================================="
# Start Python demo
python $demo_file \
--model_path "$model_path" \
--server_name "$server_name" \
--server_port "$server_port" \
--output_dir "$output_dir"
# Display final system resource usage
echo ""
echo "=========================================="
echo "📊 Final system resource usage:"
free -h | grep -E "Mem|Swap"
@echo off
chcp 65001 >nul
echo 🎬 LightX2V Gradio Windows Startup Script
echo ==========================================
REM ==================== Configuration Area ====================
REM ⚠️ Important: Please modify the following paths according to your actual environment
REM 🚨 Storage Performance Tips 🚨
REM 💾 Strongly recommend storing model files on SSD solid-state drives!
REM 📈 SSD can significantly improve model loading speed and inference performance
REM 🐌 Using mechanical hard drives (HDD) may cause slow model loading and affect overall experience
REM LightX2V project root directory path
REM Example: D:\LightX2V
set lightx2v_path=/path/to/LightX2V
REM Model path configuration
REM Model root directory path
REM Example: D:\models\LightX2V
set model_path=/path/to/LightX2V
REM Server configuration
set server_name=127.0.0.1
set server_port=8032
REM Output directory configuration
set output_dir=./outputs
REM GPU configuration
set gpu_id=0
REM ==================== Environment Variables Setup ====================
set CUDA_VISIBLE_DEVICES=%gpu_id%
set PYTHONPATH=%lightx2v_path%;%PYTHONPATH%
set PROFILING_DEBUG_LEVEL=2
set PYTORCH_CUDA_ALLOC_CONF=expandable_segments:True
REM ==================== Parameter Parsing ====================
REM Default interface language
set lang=zh
REM Parse command line arguments
:parse_args
if "%1"=="" goto :end_parse
if "%1"=="--lang" (
set lang=%2
shift
shift
goto :parse_args
)
if "%1"=="--port" (
set server_port=%2
shift
shift
goto :parse_args
)
if "%1"=="--gpu" (
set gpu_id=%2
set CUDA_VISIBLE_DEVICES=%gpu_id%
shift
shift
goto :parse_args
)
if "%1"=="--output_dir" (
set output_dir=%2
shift
shift
goto :parse_args
)
if "%1"=="--help" (
echo 🎬 LightX2V Gradio Windows Startup Script
echo ==========================================
echo Usage: %0 [options]
echo.
echo 📋 Available options:
echo --lang zh^|en Interface language (default: zh)
echo zh: Chinese interface
echo en: English interface
echo --port PORT Server port (default: 8032)
echo --gpu GPU_ID GPU device ID (default: 0)
echo --output_dir OUTPUT_DIR
echo Output video save directory (default: ./outputs)
echo --help Show this help message
echo.
echo 🚀 Usage examples:
echo %0 # Default startup
echo %0 --lang zh --port 8032 # Start with specified parameters
echo %0 --lang en --port 7860 # English interface
echo %0 --gpu 1 --port 8032 # Use GPU 1
echo %0 --output_dir ./custom_output # Use custom output directory
echo.
echo 📝 Notes:
echo - Edit script to configure model path before first use
echo - Ensure required Python dependencies are installed
echo - Recommended to use GPU with 8GB+ VRAM
echo - 🚨 Strongly recommend storing models on SSD for better performance
pause
exit /b 0
)
echo Unknown parameter: %1
echo Use --help to see help information
pause
exit /b 1
:end_parse
REM ==================== Parameter Validation ====================
if "%lang%"=="zh" goto :valid_lang
if "%lang%"=="en" goto :valid_lang
echo Error: Language must be 'zh' or 'en'
pause
exit /b 1
:valid_lang
REM Check if model path exists
if not exist "%model_path%" (
echoError: Model path does not exist
echo 📁 Path: %model_path%
echo 🔧 Solutions:
echo 1. Check model path configuration in script
echo 2. Ensure model files are properly downloaded
echo 3. Verify path permissions are correct
echo 4. 💾 Recommend storing models on SSD for faster loading
pause
exit /b 1
)
REM Select demo file based on language
if "%lang%"=="zh" (
set demo_file=gradio_demo_zh.py
echo 🌏 Using Chinese interface
) else (
set demo_file=gradio_demo.py
echo 🌏 Using English interface
)
REM Check if demo file exists
if not exist "%demo_file%" (
echoError: Demo file does not exist
echo 📄 File: %demo_file%
echo 🔧 Solutions:
echo 1. Ensure script is run in the correct directory
echo 2. Check if file has been renamed or moved
echo 3. Re-clone or download project files
pause
exit /b 1
)
REM ==================== System Information Display ====================
echo ==========================================
echo 🚀 LightX2V Gradio Starting...
echo ==========================================
echo 📁 Project path: %lightx2v_path%
echo 🤖 Model path: %model_path%
echo 🌏 Interface language: %lang%
echo 🖥️ GPU device: %gpu_id%
echo 🌐 Server address: %server_name%:%server_port%
echo 📁 Output directory: %output_dir%
echo ==========================================
REM Display system resource information
echo 💻 System resource information:
wmic OS get TotalVisibleMemorySize,FreePhysicalMemory /format:table
REM Display GPU information
nvidia-smi --query-gpu=name,memory.total,memory.free --format=csv,noheader,nounits 2>nul
if errorlevel 1 (
echo 🎮 GPU information: Unable to get GPU info
) else (
echo 🎮 GPU information:
nvidia-smi --query-gpu=name,memory.total,memory.free --format=csv,noheader,nounits
)
REM ==================== Start Demo ====================
echo 🎬 Starting Gradio demo...
echo 📱 Please access in browser: http://%server_name%:%server_port%
echo ⏹️ Press Ctrl+C to stop service
echo 🔄 First startup may take several minutes to load resources...
echo ==========================================
REM Start Python demo
python %demo_file% ^
--model_path "%model_path%" ^
--server_name %server_name% ^
--server_port %server_port% ^
--output_dir "%output_dir%"
REM Display final system resource usage
echo.
echo ==========================================
echo 📊 Final system resource usage:
wmic OS get TotalVisibleMemorySize,FreePhysicalMemory /format:table
pause
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment