"vscode:/vscode.git/clone" did not exist on "3e8135ceb22da3f95d97ab5ab4f4a86ec888854f"
Commit e2778d0d authored by litzh's avatar litzh
Browse files

Initial commit

parents
Pipeline #3370 canceled with stages
---
name: Bug Report
about: Use this template to report bugs in the project.
title: "[Bug] "
labels: bug
assignees: ''
---
### Description
Briefly describe the bug you encountered.
### Steps to Reproduce
1. First step of the operation.
2. Second step of the operation.
3. ...
### Expected Result
Describe the normal behavior you expected.
### Actual Result
Describe the abnormal situation that actually occurred.
### Environment Information
- Operating System: [e.g., Ubuntu 22.04]
- Commit ID: [Version of the project]
### Log Information
Please provide relevant error logs or debugging information.
### Additional Information
If there is any other information that can help solve the problem, please add it here.
name: lint
on:
pull_request:
push:
concurrency:
group: ${{ github.workflow }}-${{ github.ref }}
cancel-in-progress: true
jobs:
lint:
runs-on: ubuntu-latest
steps:
- name: Checkout code
uses: actions/checkout@v4
- name: Set up Python 3.11
uses: actions/setup-python@v4
with:
python-version: '3.11'
- name: Cache Python dependencies
uses: actions/cache@v3
with:
path: ~/.cache/pip
key: ${{ runner.os }}-pip-${{ hashFiles('requirements.txt') }}
restore-keys: |
${{ runner.os }}-pip-
- name: Install pre-commit hook
run: |
pip install pre-commit ruff
- name: Check pre-commit config file
run: |
if [ ! -f ".pre-commit-config.yaml" ]; then
echo "Error: .pre-commit-config.yaml not found."
exit 1
fi
- name: Linting
run: |
echo "Running pre-commit on all files..."
pre-commit run --all-files || {
echo "Linting failed. Please check the above output for details."
exit 1
}
*.pth
*.pt
*.onnx
*.pk
*.model
*.zip
*.tar
*.pyc
*.log
*.o
*.so
*.a
*.exe
*.out
.idea
**.DS_Store**
**/__pycache__/**
**.swp
.vscode/
.env
.log
*.pid
*.ipynb*
*.mp4
build/
dist/
.cache/
server_cache/
app/.gradio/
*.pkl
save_results/*
*.egg-info/
# Follow https://verdantfox.com/blog/how-to-use-git-pre-commit-hooks-the-hard-way-and-the-easy-way
repos:
- repo: https://github.com/astral-sh/ruff-pre-commit
rev: v0.11.0
hooks:
- id: ruff
args: [--fix, --respect-gitignore, --config=pyproject.toml]
- id: ruff-format
args: [--config=pyproject.toml]
- repo: https://github.com/pre-commit/pre-commit-hooks
rev: v4.5.0
hooks:
- id: trailing-whitespace
- id: end-of-file-fixer
- id: check-yaml
- id: check-toml
- id: check-added-large-files
args: ['--maxkb=3000'] # Allow files up to 3MB
- id: check-case-conflict
- id: check-merge-conflict
- id: debug-statements
Apache License
Version 2.0, January 2004
http://www.apache.org/licenses/
TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
1. Definitions.
"License" shall mean the terms and conditions for use, reproduction,
and distribution as defined by Sections 1 through 9 of this document.
"Licensor" shall mean the copyright owner or entity authorized by
the copyright owner that is granting the License.
"Legal Entity" shall mean the union of the acting entity and all
other entities that control, are controlled by, or are under common
control with that entity. For the purposes of this definition,
"control" means (i) the power, direct or indirect, to cause the
direction or management of such entity, whether by contract or
otherwise, or (ii) ownership of fifty percent (50%) or more of the
outstanding shares, or (iii) beneficial ownership of such entity.
"You" (or "Your") shall mean an individual or Legal Entity
exercising permissions granted by this License.
"Source" form shall mean the preferred form for making modifications,
including but not limited to software source code, documentation
source, and configuration files.
"Object" form shall mean any form resulting from mechanical
transformation or translation of a Source form, including but
not limited to compiled object code, generated documentation,
and conversions to other media types.
"Work" shall mean the work of authorship, whether in Source or
Object form, made available under the License, as indicated by a
copyright notice that is included in or attached to the work
(an example is provided in the Appendix below).
"Derivative Works" shall mean any work, whether in Source or Object
form, that is based on (or derived from) the Work and for which the
editorial revisions, annotations, elaborations, or other modifications
represent, as a whole, an original work of authorship. For the purposes
of this License, Derivative Works shall not include works that remain
separable from, or merely link (or bind by name) to the interfaces of,
the Work and Derivative Works thereof.
"Contribution" shall mean any work of authorship, including
the original version of the Work and any modifications or additions
to that Work or Derivative Works thereof, that is intentionally
submitted to Licensor for inclusion in the Work by the copyright owner
or by an individual or Legal Entity authorized to submit on behalf of
the copyright owner. For the purposes of this definition, "submitted"
means any form of electronic, verbal, or written communication sent
to the Licensor or its representatives, including but not limited to
communication on electronic mailing lists, source code control systems,
and issue tracking systems that are managed by, or on behalf of, the
Licensor for the purpose of discussing and improving the Work, but
excluding communication that is conspicuously marked or otherwise
designated in writing by the copyright owner as "Not a Contribution."
"Contributor" shall mean Licensor and any individual or Legal Entity
on behalf of whom a Contribution has been received by Licensor and
subsequently incorporated within the Work.
2. Grant of Copyright License. Subject to the terms and conditions of
this License, each Contributor hereby grants to You a perpetual,
worldwide, non-exclusive, no-charge, royalty-free, irrevocable
copyright license to reproduce, prepare Derivative Works of,
publicly display, publicly perform, sublicense, and distribute the
Work and such Derivative Works in Source or Object form.
3. Grant of Patent License. Subject to the terms and conditions of
this License, each Contributor hereby grants to You a perpetual,
worldwide, non-exclusive, no-charge, royalty-free, irrevocable
(except as stated in this section) patent license to make, have made,
use, offer to sell, sell, import, and otherwise transfer the Work,
where such license applies only to those patent claims licensable
by such Contributor that are necessarily infringed by their
Contribution(s) alone or by combination of their Contribution(s)
with the Work to which such Contribution(s) was submitted. If You
institute patent litigation against any entity (including a
cross-claim or counterclaim in a lawsuit) alleging that the Work
or a Contribution incorporated within the Work constitutes direct
or contributory patent infringement, then any patent licenses
granted to You under this License for that Work shall terminate
as of the date such litigation is filed.
4. Redistribution. You may reproduce and distribute copies of the
Work or Derivative Works thereof in any medium, with or without
modifications, and in Source or Object form, provided that You
meet the following conditions:
(a) You must give any other recipients of the Work or
Derivative Works a copy of this License; and
(b) You must cause any modified files to carry prominent notices
stating that You changed the files; and
(c) You must retain, in the Source form of any Derivative Works
that You distribute, all copyright, patent, trademark, and
attribution notices from the Source form of the Work,
excluding those notices that do not pertain to any part of
the Derivative Works; and
(d) If the Work includes a "NOTICE" text file as part of its
distribution, then any Derivative Works that You distribute must
include a readable copy of the attribution notices contained
within such NOTICE file, excluding those notices that do not
pertain to any part of the Derivative Works, in at least one
of the following places: within a NOTICE text file distributed
as part of the Derivative Works; within the Source form or
documentation, if provided along with the Derivative Works; or,
within a display generated by the Derivative Works, if and
wherever such third-party notices normally appear. The contents
of the NOTICE file are for informational purposes only and
do not modify the License. You may add Your own attribution
notices within Derivative Works that You distribute, alongside
or as an addendum to the NOTICE text from the Work, provided
that such additional attribution notices cannot be construed
as modifying the License.
You may add Your own copyright statement to Your modifications and
may provide additional or different license terms and conditions
for use, reproduction, or distribution of Your modifications, or
for any such Derivative Works as a whole, provided Your use,
reproduction, and distribution of the Work otherwise complies with
the conditions stated in this License.
5. Submission of Contributions. Unless You explicitly state otherwise,
any Contribution intentionally submitted for inclusion in the Work
by You to the Licensor shall be under the terms and conditions of
this License, without any additional terms or conditions.
Notwithstanding the above, nothing herein shall supersede or modify
the terms of any separate license agreement you may have executed
with Licensor regarding such Contributions.
6. Trademarks. This License does not grant permission to use the trade
names, trademarks, service marks, or product names of the Licensor,
except as required for reasonable and customary use in describing the
origin of the Work and reproducing the content of the NOTICE file.
7. Disclaimer of Warranty. Unless required by applicable law or
agreed to in writing, Licensor provides the Work (and each
Contributor provides its Contributions) on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
implied, including, without limitation, any warranties or conditions
of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
PARTICULAR PURPOSE. You are solely responsible for determining the
appropriateness of using or redistributing the Work and assume any
risks associated with Your exercise of permissions under this License.
8. Limitation of Liability. In no event and under no legal theory,
whether in tort (including negligence), contract, or otherwise,
unless required by applicable law (such as deliberate and grossly
negligent acts) or agreed to in writing, shall any Contributor be
liable to You for damages, including any direct, indirect, special,
incidental, or consequential damages of any character arising as a
result of this License or out of the use or inability to use the
Work (including but not limited to damages for loss of goodwill,
work stoppage, computer failure or malfunction, or any and all
other commercial damages or losses), even if such Contributor
has been advised of the possibility of such damages.
9. Accepting Warranty or Additional Liability. While redistributing
the Work or Derivative Works thereof, You may choose to offer,
and charge a fee for, acceptance of support, warranty, indemnity,
or other liability obligations and/or rights consistent with this
License. However, in accepting such obligations, You may act only
on Your own behalf and on Your sole responsibility, not on behalf
of any other Contributor, and only if You agree to indemnify,
defend, and hold each Contributor harmless for any liability
incurred by, or claims asserted against, such Contributor by reason
of your accepting any such warranty or additional liability.
END OF TERMS AND CONDITIONS
APPENDIX: How to apply the Apache License to your work.
To apply the Apache License to your work, attach the following
boilerplate notice, with the fields enclosed by brackets "[]"
replaced with your own identifying information. (Don't include
the brackets!) The text should be enclosed in the appropriate
comment syntax for the file format. We also recommend that a
file or class name and description of purpose be included on the
same "printed page" as the copyright notice for easier
identification within third-party archives.
Copyright [yyyy] [name of copyright owner]
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
<div align="center" style="font-family: charter;">
<h1>⚡️ LightX2V:<br> Light Video Generation Inference Framework</h1>
<img alt="logo" src="assets/img_lightx2v.png" width=75%></img>
[![License](https://img.shields.io/badge/License-Apache_2.0-blue.svg)](https://opensource.org/licenses/Apache-2.0)
[![Ask DeepWiki](https://deepwiki.com/badge.svg)](https://deepwiki.com/ModelTC/lightx2v)
[![Doc](https://img.shields.io/badge/docs-English-99cc2)](https://lightx2v-en.readthedocs.io/en/latest)
[![Doc](https://img.shields.io/badge/文档-中文-99cc2)](https://lightx2v-zhcn.readthedocs.io/zh-cn/latest)
[![Papers](https://img.shields.io/badge/论文集-中文-99cc2)](https://lightx2v-papers-zhcn.readthedocs.io/zh-cn/latest)
[![Docker](https://img.shields.io/badge/Docker-2496ED?style=flat&logo=docker&logoColor=white)](https://hub.docker.com/r/lightx2v/lightx2v/tags)
**\[ English | [中文](README_zh.md) \]**
</div>
--------------------------------------------------------------------------------
**LightX2V** is an advanced lightweight image/video generation inference framework engineered to deliver efficient, high-performance image/video synthesis solutions. This unified platform integrates multiple state-of-the-art image/video generation techniques, supporting diverse generation tasks including text-to-video (T2V), image-to-video (I2V), text-to-image (T2I), image-editing (I2I). **X2V represents the transformation of different input modalities (X, such as text or images) into vision output (Vision)**.
> 🌐 **Try it online now!** Experience LightX2V without installation: **[LightX2V Online Service](https://x2v.light-ai.top/login)** - Free, lightweight, and fast AI digital human video generation platform.
> 👋 **Join our WeChat group! LightX2V Rotbot WeChat ID: random42seed**
## 🧾 Community Code Contribution Guidelines
Before submitting, please ensure that the code format conforms to the project standard. You can use the following execution command to ensure the consistency of project code format.
```bash
pip install ruff pre-commit
pre-commit run --all-files
```
Besides the contributions from the LightX2V team, we have received contributions from some community developers, including but not limited to:
- [triple-Mu](https://github.com/triple-Mu)
- [vivienfanghuagood](https://github.com/vivienfanghuagood)
- [yeahdongcn](https://github.com/yeahdongcn)
- [kikidouloveme79](https://github.com/kikidouloveme79)
## :fire: Latest News
- **January 20, 2026:** 🚀 We support the [LTX-2](https://huggingface.co/Lightricks/LTX-2) audio-video generation model, featuring CFG parallelism, block-level offload, and FP8 per-tensor quantization. Usage examples can be found in [examples/ltx2](https://github.com/ModelTC/LightX2V/tree/main/examples/ltx2) and [scripts/ltx2](https://github.com/ModelTC/LightX2V/tree/main/scripts/ltx2).
- **January 6, 2026:** 🚀 We updated the 8-step CFG/step-distilled models for [Qwen-Image-2512](https://huggingface.co/Qwen/Qwen-Image-2512) and [Qwen/Qwen-Image-Edit-2511](https://huggingface.co/Qwen/Qwen-Image-Edit-2511). You can download the corresponding weights from [Qwen-Image-Edit-2511-Lightning](https://huggingface.co/lightx2v/Qwen-Image-Edit-2511-Lightning) and [Qwen-Image-2512-Lightning](https://huggingface.co/lightx2v/Qwen-Image-2512-Lightning) for use. Usage tutorials can be found [here](https://github.com/ModelTC/LightX2V/tree/main/examples/qwen_image).
- **January 6, 2026:** 🚀 Supported deployment on Enflame S60 (GCU).
- **December 31, 2025:** 🚀 We support the [Qwen-Image-2512](https://huggingface.co/Qwen/Qwen-Image-2512) text-to-image model since Day 0. Our [HuggingFace](https://huggingface.co/lightx2v/Qwen-Image-2512-Lightning) has been updated with CFG / step-distilled LoRA. Usage examples can be found [here](https://github.com/ModelTC/LightX2V/tree/main/examples/qwen_image).
- **December 27, 2025:** 🚀 Supported deployment on MThreads MUSA.
- **December 25, 2025:** 🚀 Supported deployment on AMD ROCm and Ascend 910B.
- **December 23, 2025:** 🚀 We support the [Qwen-Image-Edit-2511](https://huggingface.co/Qwen/Qwen-Image-Edit-2511) image editing model since Day 0. On a single H100 GPU, LightX2V delivers approximately 1.4× speedup. We support for CFG parallelism, Ulysses parallelism, and efficient offloading technologies. Our [HuggingFace](https://huggingface.co/lightx2v/Qwen-Image-Edit-2511-Lightning) has been updated with CFG / step-distilled LoRA and FP8 weights. Usage examples can be found [here](https://github.com/ModelTC/LightX2V/tree/main/examples/qwen_image). Combined with LightX2V, 4-step CFG / step distillation, and the FP8 model, the maximum acceleration can reach up to approximately 42×. Feel free to try [LightX2V Online Service](https://x2v.light-ai.top/login) with *Image to Image* and *Qwen-Image-Edit-2511* model.
- **December 22, 2025:** 🚀 Added **Wan2.1 NVFP4 quantization-aware 4-step distilled models**; weights are available on HuggingFace: [Wan-NVFP4](https://huggingface.co/lightx2v/Wan-NVFP4).
- **December 15, 2025:** 🚀 Supported deployment on Hygon DCU.
- **December 4, 2025:** 🚀 Supported GGUF format model inference & deployment on Cambricon MLU590/MetaX C500.
- **November 24, 2025:** 🚀 We released 4-step distilled models for HunyuanVideo-1.5! These models enable **ultra-fast 4-step inference** without CFG requirements, achieving approximately **25x speedup** compared to standard 50-step inference. Both base and FP8 quantized versions are now available: [Hy1.5-Distill-Models](https://huggingface.co/lightx2v/Hy1.5-Distill-Models).
- **November 21, 2025:** 🚀 We support the [HunyuanVideo-1.5](https://huggingface.co/tencent/HunyuanVideo-1.5) video generation model since Day 0. With the same number of GPUs, LightX2V can achieve a speed improvement of over 2 times and supports deployment on GPUs with lower memory (such as the 24GB RTX 4090). It also supports CFG/Ulysses parallelism, efficient offloading, TeaCache/MagCache technologies, and more. We will soon update more models on our [HuggingFace page](https://huggingface.co/lightx2v), including step distillation, VAE distillation, and other related models. Quantized models and lightweight VAE models are now available: [Hy1.5-Quantized-Models](https://huggingface.co/lightx2v/Hy1.5-Quantized-Models) for quantized inference, and [LightTAE for HunyuanVideo-1.5](https://huggingface.co/lightx2v/Autoencoders/blob/main/lighttaehy1_5.safetensors) for fast VAE decoding. Refer to [this](https://github.com/ModelTC/LightX2V/tree/main/scripts/hunyuan_video_15) for usage tutorials, or check out the [examples directory](https://github.com/ModelTC/LightX2V/tree/main/examples) for code examples.
## 🏆 Performance Benchmarks (Updated on 2025.12.01)
### 📊 Cross-Framework Performance Comparison (H100)
| Framework | GPUs | Step Time | Speedup |
|-----------|---------|---------|---------|
| Diffusers | 1 | 9.77s/it | 1x |
| xDiT | 1 | 8.93s/it | 1.1x |
| FastVideo | 1 | 7.35s/it | 1.3x |
| SGL-Diffusion | 1 | 6.13s/it | 1.6x |
| **LightX2V** | 1 | **5.18s/it** | **1.9x** 🚀 |
| FastVideo | 8 | 2.94s/it | 1x |
| xDiT | 8 | 2.70s/it | 1.1x |
| SGL-Diffusion | 8 | 1.19s/it | 2.5x |
| **LightX2V** | 8 | **0.75s/it** | **3.9x** 🚀 |
### 📊 Cross-Framework Performance Comparison (RTX 4090D)
| Framework | GPUs | Step Time | Speedup |
|-----------|---------|---------|---------|
| Diffusers | 1 | 30.50s/it | 1x |
| FastVideo | 1 | 22.66s/it | 1.3x |
| xDiT | 1 | OOM | OOM |
| SGL-Diffusion | 1 | OOM | OOM |
| **LightX2V** | 1 | **20.26s/it** | **1.5x** 🚀 |
| FastVideo | 8 | 15.48s/it | 1x |
| xDiT | 8 | OOM | OOM |
| SGL-Diffusion | 8 | OOM | OOM |
| **LightX2V** | 8 | **4.75s/it** | **3.3x** 🚀 |
### 📊 LightX2V Performance Comparison
| Framework | GPU | Configuration | Step Time | Speedup |
|-----------|-----|---------------|-----------|---------------|
| **LightX2V** | H100 | 8 GPUs + cfg | 0.75s/it | 1x |
| **LightX2V** | H100 | 8 GPUs + no cfg | 0.39s/it | 1.9x |
| **LightX2V** | H100 | **8 GPUs + no cfg + fp8** | **0.35s/it** | **2.1x** 🚀 |
| **LightX2V** | 4090D | 8 GPUs + cfg | 4.75s/it | 1x |
| **LightX2V** | 4090D | 8 GPUs + no cfg | 3.13s/it | 1.5x |
| **LightX2V** | 4090D | **8 GPUs + no cfg + fp8** | **2.35s/it** | **2.0x** 🚀 |
**Note**: All the above performance data were tested on Wan2.1-I2V-14B-480P(40 steps, 81 frames). In addition, we also provide 4-step distilled models on the [HuggingFace page](https://huggingface.co/lightx2v).
## 💡 Quick Start
For comprehensive usage instructions, please refer to our documentation: **[English Docs](https://lightx2v-en.readthedocs.io/en/latest/) | [中文文档](https://lightx2v-zhcn.readthedocs.io/zh-cn/latest/)**
**We highly recommend using the Docker environment, as it is the simplest and fastest way to set up the environment. For details, please refer to the Quick Start section in the documentation.**
### Installation from Git
```bash
pip install -v git+https://github.com/ModelTC/LightX2V.git
```
### Building from Source
```bash
git clone https://github.com/ModelTC/LightX2V.git
cd LightX2V
uv pip install -v . # pip install -v .
```
### (Optional) Install Attention/Quantize Operators
For attention operators installation, please refer to our documentation: **[English Docs](https://lightx2v-en.readthedocs.io/en/latest/getting_started/quickstart.html#step-4-install-attention-operators) | [中文文档](https://lightx2v-zhcn.readthedocs.io/zh-cn/latest/getting_started/quickstart.html#id9)**
### Usage Example
```python
# examples/wan/wan_i2v.py
"""
Wan2.2 image-to-video generation example.
This example demonstrates how to use LightX2V with Wan2.2 model for I2V generation.
"""
from lightx2v import LightX2VPipeline
# Initialize pipeline for Wan2.2 I2V task
# For wan2.1, use model_cls="wan2.1"
pipe = LightX2VPipeline(
model_path="/path/to/Wan2.2-I2V-A14B",
model_cls="wan2.2_moe",
task="i2v",
)
# Alternative: create generator from config JSON file
# pipe.create_generator(
# config_json="configs/wan22/wan_moe_i2v.json"
# )
# Enable offloading to significantly reduce VRAM usage with minimal speed impact
# Suitable for RTX 30/40/50 consumer GPUs
pipe.enable_offload(
cpu_offload=True,
offload_granularity="block", # For Wan models, supports both "block" and "phase"
text_encoder_offload=True,
image_encoder_offload=False,
vae_offload=False,
)
# Create generator manually with specified parameters
pipe.create_generator(
attn_mode="sage_attn2",
infer_steps=40,
height=480, # Can be set to 720 for higher resolution
width=832, # Can be set to 1280 for higher resolution
num_frames=81,
guidance_scale=[3.5, 3.5], # For wan2.1, guidance_scale is a scalar (e.g., 5.0)
sample_shift=5.0,
)
# Generation parameters
seed = 42
prompt = "Summer beach vacation style, a white cat wearing sunglasses sits on a surfboard. The fluffy-furred feline gazes directly at the camera with a relaxed expression. Blurred beach scenery forms the background featuring crystal-clear waters, distant green hills, and a blue sky dotted with white clouds. The cat assumes a naturally relaxed posture, as if savoring the sea breeze and warm sunlight. A close-up shot highlights the feline's intricate details and the refreshing atmosphere of the seaside."
negative_prompt = "镜头晃动,色调艳丽,过曝,静态,细节模糊不清,字幕,风格,作品,画作,画面,静止,整体发灰,最差质量,低质量,JPEG压缩残留,丑陋的,残缺的,多余的手指,画得不好的手部,画得不好的脸部,畸形的,毁容的,形态畸形的肢体,手指融合,静止不动的画面,杂乱的背景,三条腿,背景人很多,倒着走"
image_path="/path/to/img_0.jpg"
save_result_path = "/path/to/save_results/output.mp4"
# Generate video
pipe.generate(
seed=seed,
image_path=image_path,
prompt=prompt,
negative_prompt=negative_prompt,
save_result_path=save_result_path,
)
```
**NVFP4 (quantization-aware 4-step) resources**
- Inference examples: `examples/wan/wan_i2v_nvfp4.py` (I2V) and `examples/wan/wan_t2v_nvfp4.py` (T2V).
- NVFP4 operator build/install guide: see `lightx2v_kernel/README.md`.
> 💡 **More Examples**: For more usage examples including quantization, offloading, caching, and other advanced configurations, please refer to the [examples directory](https://github.com/ModelTC/LightX2V/tree/main/examples).
## 🤖 Supported Model Ecosystem
### Official Open-Source Models
-[LTX-2](https://huggingface.co/Lightricks/LTX-2)
-[HunyuanVideo-1.5](https://huggingface.co/tencent/HunyuanVideo-1.5)
-[Wan2.1 & Wan2.2](https://huggingface.co/Wan-AI/)
-[Qwen-Image](https://huggingface.co/Qwen/Qwen-Image)
-[Qwen-Image-Edit](https://huggingface.co/spaces/Qwen/Qwen-Image-Edit)
-[Qwen-Image-Edit-2509](https://huggingface.co/Qwen/Qwen-Image-Edit-2509)
-[Qwen-Image-Edit-2511](https://huggingface.co/Qwen/Qwen-Image-Edit-2511)
### Quantized and Distilled Models/LoRAs (**🚀 Recommended: 4-step inference**)
-[Wan2.1-Distill-Models](https://huggingface.co/lightx2v/Wan2.1-Distill-Models)
-[Wan2.2-Distill-Models](https://huggingface.co/lightx2v/Wan2.2-Distill-Models)
-[Wan2.1-Distill-Loras](https://huggingface.co/lightx2v/Wan2.1-Distill-Loras)
-[Wan2.2-Distill-Loras](https://huggingface.co/lightx2v/Wan2.2-Distill-Loras)
-[Wan2.1-Distill-NVFP4](https://huggingface.co/lightx2v/Wan-NVFP4)
-[Qwen-Image-Edit-2511-Lightning](https://huggingface.co/lightx2v/Qwen-Image-Edit-2511-Lightning)
### Lightweight Autoencoder Models (**🚀 Recommended: fast inference & low memory usage**)
-[Autoencoders](https://huggingface.co/lightx2v/Autoencoders)
### Autoregressive Models
-[Wan2.1-T2V-CausVid](https://huggingface.co/lightx2v/Wan2.1-T2V-14B-CausVid)
-[Self-Forcing](https://github.com/guandeh17/Self-Forcing)
-[Matrix-Game-2.0](https://huggingface.co/Skywork/Matrix-Game-2.0)
🔔 Follow our [HuggingFace page](https://huggingface.co/lightx2v) for the latest model releases from our team.
💡 Refer to the [Model Structure Documentation](https://lightx2v-en.readthedocs.io/en/latest/getting_started/model_structure.html) to quickly get started with LightX2V
## 🚀 Frontend Interfaces
We provide multiple frontend interface deployment options:
- **🎨 Gradio Interface**: Clean and user-friendly web interface, perfect for quick experience and prototyping
- 📖 [Gradio Deployment Guide](https://lightx2v-en.readthedocs.io/en/latest/deploy_guides/deploy_gradio.html)
- **🎯 ComfyUI Interface**: Powerful node-based workflow interface, supporting complex video generation tasks
- 📖 [ComfyUI Deployment Guide](https://lightx2v-en.readthedocs.io/en/latest/deploy_guides/deploy_comfyui.html)
- **🚀 Windows One-Click Deployment**: Convenient deployment solution designed for Windows users, featuring automatic environment configuration and intelligent parameter optimization
- 📖 [Windows One-Click Deployment Guide](https://lightx2v-en.readthedocs.io/en/latest/deploy_guides/deploy_local_windows.html)
**💡 Recommended Solutions**:
- **First-time Users**: We recommend the Windows one-click deployment solution
- **Advanced Users**: We recommend the ComfyUI interface for more customization options
- **Quick Experience**: The Gradio interface provides the most intuitive operation experience
## 🚀 Core Features
### 🎯 **Ultimate Performance Optimization**
- **🔥 SOTA Inference Speed**: Achieve **~20x** acceleration via step distillation and system optimization (single GPU)
- **⚡️ Revolutionary 4-Step Distillation**: Compress original 40-50 step inference to just 4 steps without CFG requirements
- **🛠️ Advanced Operator Support**: Integrated with cutting-edge operators including [Sage Attention](https://github.com/thu-ml/SageAttention), [Flash Attention](https://github.com/Dao-AILab/flash-attention), [Radial Attention](https://github.com/mit-han-lab/radial-attention), [q8-kernel](https://github.com/KONAKONA666/q8_kernels), [sgl-kernel](https://github.com/sgl-project/sglang/tree/main/sgl-kernel), [vllm](https://github.com/vllm-project/vllm)
### 💾 **Resource-Efficient Deployment**
- **💡 Breaking Hardware Barriers**: Run 14B models for 480P/720P video generation with only **8GB VRAM + 16GB RAM**
- **🔧 Intelligent Parameter Offloading**: Advanced disk-CPU-GPU three-tier offloading architecture with phase/block-level granular management
- **⚙️ Comprehensive Quantization**: Support for `w8a8-int8`, `w8a8-fp8`, `w4a4-nvfp4` and other quantization strategies
### 🎨 **Rich Feature Ecosystem**
- **📈 Smart Feature Caching**: Intelligent caching mechanisms to eliminate redundant computations
- **🔄 Parallel Inference**: Multi-GPU parallel processing for enhanced performance
- **📱 Flexible Deployment Options**: Support for Gradio, service deployment, ComfyUI and other deployment methods
- **🎛️ Dynamic Resolution Inference**: Adaptive resolution adjustment for optimal generation quality
- **🎞️ Video Frame Interpolation**: RIFE-based frame interpolation for smooth frame rate enhancement
## 📚 Technical Documentation
### 📖 **Method Tutorials**
- [Model Quantization](https://lightx2v-en.readthedocs.io/en/latest/method_tutorials/quantization.html) - Comprehensive guide to quantization strategies
- [Feature Caching](https://lightx2v-en.readthedocs.io/en/latest/method_tutorials/cache.html) - Intelligent caching mechanisms
- [Attention Mechanisms](https://lightx2v-en.readthedocs.io/en/latest/method_tutorials/attention.html) - State-of-the-art attention operators
- [Parameter Offloading](https://lightx2v-en.readthedocs.io/en/latest/method_tutorials/offload.html) - Three-tier storage architecture
- [Parallel Inference](https://lightx2v-en.readthedocs.io/en/latest/method_tutorials/parallel.html) - Multi-GPU acceleration strategies
- [Changing Resolution Inference](https://lightx2v-en.readthedocs.io/en/latest/method_tutorials/changing_resolution.html) - U-shaped resolution strategy
- [Step Distillation](https://lightx2v-en.readthedocs.io/en/latest/method_tutorials/step_distill.html) - 4-step inference technology
- [Video Frame Interpolation](https://lightx2v-en.readthedocs.io/en/latest/method_tutorials/video_frame_interpolation.html) - Base on the RIFE technology
### 🛠️ **Deployment Guides**
- [Low-Resource Deployment](https://lightx2v-en.readthedocs.io/en/latest/deploy_guides/for_low_resource.html) - Optimized 8GB VRAM solutions
- [Low-Latency Deployment](https://lightx2v-en.readthedocs.io/en/latest/deploy_guides/for_low_latency.html) - Ultra-fast inference optimization
- [Gradio Deployment](https://lightx2v-en.readthedocs.io/en/latest/deploy_guides/deploy_gradio.html) - Web interface setup
- [Service Deployment](https://lightx2v-en.readthedocs.io/en/latest/deploy_guides/deploy_service.html) - Production API service deployment
- [Lora Model Deployment](https://lightx2v-en.readthedocs.io/en/latest/deploy_guides/lora_deploy.html) - Flexible Lora deployment
## 🤝 Acknowledgments
We sincerely thank all the model repositories and research communities that inspired and promoted the development of LightX2V. This framework is built on the collective efforts of the open-source community. It includes but is not limited to:
- [Tencent-Hunyuan](https://github.com/Tencent-Hunyuan)
- [Wan-Video](https://github.com/Wan-Video)
- [Qwen-Image](https://github.com/QwenLM/Qwen-Image)
- [LightLLM](https://github.com/ModelTC/LightLLM)
- [sglang](https://github.com/sgl-project/sglang)
- [vllm](https://github.com/vllm-project/vllm)
- [flash-attention](https://github.com/Dao-AILab/flash-attention)
- [SageAttention](https://github.com/thu-ml/SageAttention)
- [flashinfer](https://github.com/flashinfer-ai/flashinfer)
- [MagiAttention](https://github.com/SandAI-org/MagiAttention)
- [radial-attention](https://github.com/mit-han-lab/radial-attention)
- [xDiT](https://github.com/xdit-project/xDiT)
- [FastVideo](https://github.com/hao-ai-lab/FastVideo)
## 🌟 Star History
[![Star History Chart](https://api.star-history.com/svg?repos=ModelTC/lightx2v&type=Timeline)](https://star-history.com/#ModelTC/lightx2v&Timeline)
## ✏️ Citation
If you find LightX2V useful in your research, please consider citing our work:
```bibtex
@misc{lightx2v,
author = {LightX2V Contributors},
title = {LightX2V: Light Video Generation Inference Framework},
year = {2025},
publisher = {GitHub},
journal = {GitHub repository},
howpublished = {\url{https://github.com/ModelTC/lightx2v}},
}
```
## 📞 Contact & Support
For questions, suggestions, or support, please feel free to reach out through:
- 🐛 [GitHub Issues](https://github.com/ModelTC/lightx2v/issues) - Bug reports and feature requests
---
<div align="center">
Built with ❤️ by the LightX2V team
</div>
<div align="center" style="font-family: charter;">
<h1>⚡️ LightX2V:<br> 轻量级视频生成推理框架</h1>
<img alt="logo" src="assets/img_lightx2v.png" width=75%></img>
[![License](https://img.shields.io/badge/License-Apache_2.0-blue.svg)](https://opensource.org/licenses/Apache-2.0)
[![Ask DeepWiki](https://deepwiki.com/badge.svg)](https://deepwiki.com/ModelTC/lightx2v)
[![Doc](https://img.shields.io/badge/docs-English-99cc2)](https://lightx2v-en.readthedocs.io/en/latest)
[![Doc](https://img.shields.io/badge/文档-中文-99cc2)](https://lightx2v-zhcn.readthedocs.io/zh-cn/latest)
[![Papers](https://img.shields.io/badge/论文集-中文-99cc2)](https://lightx2v-papers-zhcn.readthedocs.io/zh-cn/latest)
[![Docker](https://img.shields.io/badge/Docker-2496ED?style=flat&logo=docker&logoColor=white)](https://hub.docker.com/r/lightx2v/lightx2v/tags)
**\[ [English](README.md) | 中文 \]**
</div>
--------------------------------------------------------------------------------
**LightX2V** 是一个先进的轻量级图像视频生成推理框架,专为提供高效、高性能的图像视频生成解决方案而设计。该统一平台集成了多种前沿的图像视频生成技术,支持文本生成视频(T2V)和图像生成视频(I2V),文本生图片(T2I),图像编辑(I2I)等多样化生成任务。**X2V 表示将不同的输入模态(X,如文本或图像)转换为视觉输出(Vision)**
> 🌐 **立即在线体验!** 无需安装即可体验 LightX2V:**[LightX2V 在线服务](https://x2v.light-ai.top/login)** - 免费、轻量、快速的AI数字人视频生成平台。
> 👋 **加入微信交流群,LightX2V加群机器人微信号: random42seed**
## 🧾 社区代码贡献指南
在提交之前,请确保代码格式符合项目规范。可以使用如下执行命令,确保项目代码格式的一致性。
```bash
pip install ruff pre-commit
pre-commit run --all-files
```
除了LightX2V团队的贡献,我们也收到一些社区开发者的贡献,包括但不限于:
- [triple-Mu](https://github.com/triple-Mu)
- [vivienfanghuagood](https://github.com/vivienfanghuagood)
- [yeahdongcn](https://github.com/yeahdongcn)
- [kikidouloveme79](https://github.com/kikidouloveme79)
## :fire: 最新动态
- **2026年1月20日:** 🚀 我们支持了[LTX-2](https://huggingface.co/Lightricks/LTX-2)音频-视频生成模型,包含CFG并行、block级别offload、FP8 per-tensor量化等先进特性。使用示例可参考[examples/ltx2](https://github.com/ModelTC/LightX2V/tree/main/examples/ltx2)[scripts/ltx2](https://github.com/ModelTC/LightX2V/tree/main/scripts/ltx2)
- **2026年1月6日:** 🚀 我们更新了[Qwen-Image-2512](https://huggingface.co/Qwen/Qwen-Image-2512)[Qwen/Qwen-Image-Edit-2511](https://huggingface.co/Qwen/Qwen-Image-Edit-2511)的8步的CFG/步数蒸馏模型。可以在[Qwen-Image-Edit-2511-Lightning](https://huggingface.co/lightx2v/Qwen-Image-Edit-2511-Lightning)[Qwen-Image-2512-Lightning](https://huggingface.co/lightx2v/Qwen-Image-2512-Lightning)下载对应的权重进行使用。使用教程参考[这里](https://github.com/ModelTC/LightX2V/tree/main/examples/qwen_image)
- **2026年1月6日:** 🚀 支持燧原 Enflame S60 (GCU) 的部署。
- **2025年12月31日:** 🚀 我们Day0支持了[Qwen-Image-2512](https://huggingface.co/Qwen/Qwen-Image-2512) 文生图模型. 我们的[HuggingFace](https://huggingface.co/lightx2v/Qwen-Image-2512-Lightning) 已经更新了CFG/步数蒸馏lora权重。使用方式可以参考[这里](https://github.com/ModelTC/LightX2V/tree/main/examples/qwen_image)
- **2025年12月27日:** 🚀 支持摩尔线程 MUSA 的部署。
- **2025年12月25日:** 🚀 支持 AMD ROCm 和 Ascend 910B 的部署。
- **2025年12月23日:** 🚀 我们Day0支持了[Qwen/Qwen-Image-Edit-2511](https://huggingface.co/Qwen/Qwen-Image-Edit-2511)的图像编辑模型,H100单卡,LightX2V可带来约1.4倍的速度提升,支持CFG并行/Ulysses并行,高效Offload等技术。我们的[HuggingFace](https://huggingface.co/lightx2v/Qwen-Image-Edit-2511-Lightning)已经更新了CFG/步数蒸馏lora和FP8权重。使用方式可以参考[这里](https://github.com/ModelTC/LightX2V/tree/main/examples/qwen_image)。结合LightX2V,4步CFG/步数蒸馏,FP8模型,最高可以加速约42倍。可以在[LightX2V 在线服务](https://x2v.light-ai.top/login)的图生图的Qwen-Image-Edit-2511进行体验。
- **2025年12月22日:** 🚀 新增 **Wan2.1 NVFP4 量化感知 4 步蒸馏模型** 支持;模型与权重已发布在 HuggingFace: [Wan-NVFP4](https://huggingface.co/lightx2v/Wan-NVFP4)
- **2025年12月15日:** 🚀 支持 海光DCU 硬件上的部署。
- **2025年12月4日:** 🚀 支持 GGUF 格式模型推理,以及在寒武纪 MLU590、MetaX C500 硬件上的部署。
- **2025年11月24日:** 🚀 我们发布了HunyuanVideo-1.5的4步蒸馏模型!这些模型支持**超快速4步推理**,无需CFG配置,相比标准50步推理可实现约**25倍加速**。现已提供基础版本和FP8量化版本:[Hy1.5-Distill-Models](https://huggingface.co/lightx2v/Hy1.5-Distill-Models)
- **2025年11月21日:** 🚀 我们Day0支持了[HunyuanVideo-1.5](https://huggingface.co/tencent/HunyuanVideo-1.5)的视频生成模型,同样GPU数量,LightX2V可带来约2倍以上的速度提升,并支持更低显存GPU部署(如24G RTX4090)。支持CFG并行/Ulysses并行,高效Offload,TeaCache/MagCache等技术。同时支持沐曦,寒武纪等国产芯片部署。我们很快将在我们的[HuggingFace主页](https://huggingface.co/lightx2v)更新更多模型,包括步数蒸馏,VAE蒸馏等相关模型。量化模型和轻量VAE模型现已可用:[Hy1.5-Quantized-Models](https://huggingface.co/lightx2v/Hy1.5-Quantized-Models)用于量化推理,[HunyuanVideo-1.5轻量TAE](https://huggingface.co/lightx2v/Autoencoders/blob/main/lighttaehy1_5.safetensors)用于快速VAE解码。使用教程参考[这里](https://github.com/ModelTC/LightX2V/tree/main/scripts/hunyuan_video_15),或查看[示例目录](https://github.com/ModelTC/LightX2V/tree/main/examples)获取代码示例。
## 🏆 性能测试数据 (更新于 2025.12.01)
### 📊 推理框架之间性能对比 (H100)
| Framework | GPUs | Step Time | Speedup |
|-----------|---------|---------|---------|
| Diffusers | 1 | 9.77s/it | 1x |
| xDiT | 1 | 8.93s/it | 1.1x |
| FastVideo | 1 | 7.35s/it | 1.3x |
| SGL-Diffusion | 1 | 6.13s/it | 1.6x |
| **LightX2V** | 1 | **5.18s/it** | **1.9x** 🚀 |
| FastVideo | 8 | 2.94s/it | 1x |
| xDiT | 8 | 2.70s/it | 1.1x |
| SGL-Diffusion | 8 | 1.19s/it | 2.5x |
| **LightX2V** | 8 | **0.75s/it** | **3.9x** 🚀 |
### 📊 推理框架之间性能对比 (RTX 4090D)
| Framework | GPUs | Step Time | Speedup |
|-----------|---------|---------|---------|
| Diffusers | 1 | 30.50s/it | 1x |
| FastVideo | 1 | 22.66s/it | 1.3x |
| xDiT | 1 | OOM | OOM |
| SGL-Diffusion | 1 | OOM | OOM |
| **LightX2V** | 1 | **20.26s/it** | **1.5x** 🚀 |
| FastVideo | 8 | 15.48s/it | 1x |
| xDiT | 8 | OOM | OOM |
| SGL-Diffusion | 8 | OOM | OOM |
| **LightX2V** | 8 | **4.75s/it** | **3.3x** 🚀 |
### 📊 LightX2V不同配置之间性能对比
| Framework | GPU | Configuration | Step Time | Speedup |
|-----------|-----|---------------|-----------|---------------|
| **LightX2V** | H100 | 8 GPUs + cfg | 0.75s/it | 1x |
| **LightX2V** | H100 | 8 GPUs + no cfg | 0.39s/it | 1.9x |
| **LightX2V** | H100 | **8 GPUs + no cfg + fp8** | **0.35s/it** | **2.1x** 🚀 |
| **LightX2V** | 4090D | 8 GPUs + cfg | 4.75s/it | 1x |
| **LightX2V** | 4090D | 8 GPUs + no cfg | 3.13s/it | 1.5x |
| **LightX2V** | 4090D | **8 GPUs + no cfg + fp8** | **2.35s/it** | **2.0x** 🚀 |
**注意**: 所有以上性能数据均在 Wan2.1-I2V-14B-480P(40 steps, 81 frames) 上测试。此外,我们[HuggingFace 主页](https://huggingface.co/lightx2v)还提供了4步蒸馏模型。
## 💡 快速开始
详细使用说明请参考我们的文档:**[英文文档](https://lightx2v-en.readthedocs.io/en/latest/) | [中文文档](https://lightx2v-zhcn.readthedocs.io/zh-cn/latest/)**
**我们强烈推荐使用 Docker 环境,这是最简单快捷的环境安装方式。具体参考:文档中的快速入门章节。**
### 从 Git 安装
```bash
pip install -v git+https://github.com/ModelTC/LightX2V.git
```
### 从源码构建
```bash
git clone https://github.com/ModelTC/LightX2V.git
cd LightX2V
uv pip install -v . # pip install -v .
```
### (可选)安装注意力/量化算子
注意力算子安装说明请参考我们的文档:**[英文文档](https://lightx2v-en.readthedocs.io/en/latest/getting_started/quickstart.html#step-4-install-attention-operators) | [中文文档](https://lightx2v-zhcn.readthedocs.io/zh-cn/latest/getting_started/quickstart.html#id9)**
### 使用示例
```python
# examples/wan/wan_i2v.py
"""
Wan2.2 image-to-video generation example.
This example demonstrates how to use LightX2V with Wan2.2 model for I2V generation.
"""
from lightx2v import LightX2VPipeline
# Initialize pipeline for Wan2.2 I2V task
# For wan2.1, use model_cls="wan2.1"
pipe = LightX2VPipeline(
model_path="/path/to/Wan2.2-I2V-A14B",
model_cls="wan2.2_moe",
task="i2v",
)
# Alternative: create generator from config JSON file
# pipe.create_generator(
# config_json="configs/wan22/wan_moe_i2v.json"
# )
# Enable offloading to significantly reduce VRAM usage with minimal speed impact
# Suitable for RTX 30/40/50 consumer GPUs
pipe.enable_offload(
cpu_offload=True,
offload_granularity="block", # For Wan models, supports both "block" and "phase"
text_encoder_offload=True,
image_encoder_offload=False,
vae_offload=False,
)
# Create generator manually with specified parameters
pipe.create_generator(
attn_mode="sage_attn2",
infer_steps=40,
height=480, # Can be set to 720 for higher resolution
width=832, # Can be set to 1280 for higher resolution
num_frames=81,
guidance_scale=[3.5, 3.5], # For wan2.1, guidance_scale is a scalar (e.g., 5.0)
sample_shift=5.0,
)
# Generation parameters
seed = 42
prompt = "Summer beach vacation style, a white cat wearing sunglasses sits on a surfboard. The fluffy-furred feline gazes directly at the camera with a relaxed expression. Blurred beach scenery forms the background featuring crystal-clear waters, distant green hills, and a blue sky dotted with white clouds. The cat assumes a naturally relaxed posture, as if savoring the sea breeze and warm sunlight. A close-up shot highlights the feline's intricate details and the refreshing atmosphere of the seaside."
negative_prompt = "镜头晃动,色调艳丽,过曝,静态,细节模糊不清,字幕,风格,作品,画作,画面,静止,整体发灰,最差质量,低质量,JPEG压缩残留,丑陋的,残缺的,多余的手指,画得不好的手部,画得不好的脸部,畸形的,毁容的,形态畸形的肢体,手指融合,静止不动的画面,杂乱的背景,三条腿,背景人很多,倒着走"
image_path="/path/to/img_0.jpg"
save_result_path = "/path/to/save_results/output.mp4"
# Generate video
pipe.generate(
seed=seed,
image_path=image_path,
prompt=prompt,
negative_prompt=negative_prompt,
save_result_path=save_result_path,
)
```
**NVFP4(量化感知 4 步)资源**
- 推理示例:`examples/wan/wan_i2v_nvfp4.py`(I2V),`examples/wan/wan_t2v_nvfp4.py`(T2V)。
- NVFP4 算子编译/安装指南:参见 `lightx2v_kernel/README.md`
> 💡 **更多示例**: 更多使用案例,包括量化、卸载、缓存等进阶配置,请参考 [examples 目录](https://github.com/ModelTC/LightX2V/tree/main/examples)。
## 🤖 支持的模型生态
### 官方开源模型
-[LTX-2](https://huggingface.co/Lightricks/LTX-2)
-[HunyuanVideo-1.5](https://huggingface.co/tencent/HunyuanVideo-1.5)
-[Wan2.1 & Wan2.2](https://huggingface.co/Wan-AI/)
-[Qwen-Image](https://huggingface.co/Qwen/Qwen-Image)
-[Qwen-Image-Edit](https://huggingface.co/spaces/Qwen/Qwen-Image-Edit)
-[Qwen-Image-Edit-2509](https://huggingface.co/Qwen/Qwen-Image-Edit-2509)
-[Qwen-Image-Edit-2511](https://huggingface.co/Qwen/Qwen-Image-Edit-2511)
### 量化模型和蒸馏模型/Lora (**🚀 推荐:4步推理**)
-[Wan2.1-Distill-Models](https://huggingface.co/lightx2v/Wan2.1-Distill-Models)
-[Wan2.2-Distill-Models](https://huggingface.co/lightx2v/Wan2.2-Distill-Models)
-[Wan2.1-Distill-Loras](https://huggingface.co/lightx2v/Wan2.1-Distill-Loras)
-[Wan2.2-Distill-Loras](https://huggingface.co/lightx2v/Wan2.2-Distill-Loras)
-[Wan2.1-Distill-NVFP4](https://huggingface.co/lightx2v/Wan-NVFP4)
-[Qwen-Image-Edit-2511-Lightning](https://huggingface.co/lightx2v/Qwen-Image-Edit-2511-Lightning)
### 轻量级自编码器模型(**🚀 推荐:推理快速 + 内存占用低**)
-[Autoencoders](https://huggingface.co/lightx2v/Autoencoders)
### 自回归模型
-[Wan2.1-T2V-CausVid](https://huggingface.co/lightx2v/Wan2.1-T2V-14B-CausVid)
-[Self-Forcing](https://github.com/guandeh17/Self-Forcing)
-[Matrix-Game-2.0](https://huggingface.co/Skywork/Matrix-Game-2.0)
🔔 可以关注我们的[HuggingFace主页](https://huggingface.co/lightx2v),及时获取我们团队的模型。
💡 参考[模型结构文档](https://lightx2v-zhcn.readthedocs.io/zh-cn/latest/getting_started/model_structure.html)快速上手 LightX2V
## 🚀 前端展示
我们提供了多种前端界面部署方式:
- **🎨 Gradio界面**: 简洁易用的Web界面,适合快速体验和原型开发
- 📖 [Gradio部署文档](https://lightx2v-zhcn.readthedocs.io/zh-cn/latest/deploy_guides/deploy_gradio.html)
- **🎯 ComfyUI界面**: 强大的节点式工作流界面,支持复杂的视频生成任务
- 📖 [ComfyUI部署文档](https://lightx2v-zhcn.readthedocs.io/zh-cn/latest/deploy_guides/deploy_comfyui.html)
- **🚀 Windows一键部署**: 专为Windows用户设计的便捷部署方案,支持自动环境配置和智能参数优化
- 📖 [Windows一键部署文档](https://lightx2v-zhcn.readthedocs.io/zh-cn/latest/deploy_guides/deploy_local_windows.html)
**💡 推荐方案**:
- **首次使用**: 建议选择Windows一键部署方案
- **高级用户**: 推荐使用ComfyUI界面获得更多自定义选项
- **快速体验**: Gradio界面提供最直观的操作体验
## 🚀 核心特性
### 🎯 **极致性能优化**
- **🔥 SOTA推理速度**: 通过步数蒸馏和系统优化实现**20倍**极速加速(单GPU)
- **⚡️ 革命性4步蒸馏**: 将原始40-50步推理压缩至仅需4步,且无需CFG配置
- **🛠️ 先进算子支持**: 集成顶尖算子,包括[Sage Attention](https://github.com/thu-ml/SageAttention)[Flash Attention](https://github.com/Dao-AILab/flash-attention)[Radial Attention](https://github.com/mit-han-lab/radial-attention)[q8-kernel](https://github.com/KONAKONA666/q8_kernels)[sgl-kernel](https://github.com/sgl-project/sglang/tree/main/sgl-kernel)[vllm](https://github.com/vllm-project/vllm)
### 💾 **资源高效部署**
- **💡 突破硬件限制**: **仅需8GB显存 + 16GB内存**即可运行14B模型生成480P/720P视频
- **🔧 智能参数卸载**: 先进的磁盘-CPU-GPU三级卸载架构,支持阶段/块级别的精细化管理
- **⚙️ 全面量化支持**: 支持`w8a8-int8``w8a8-fp8``w4a4-nvfp4`等多种量化策略
### 🎨 **丰富功能生态**
- **📈 智能特征缓存**: 智能缓存机制,消除冗余计算,提升效率
- **🔄 并行推理加速**: 多GPU并行处理,显著提升性能表现
- **📱 灵活部署选择**: 支持Gradio、服务化部署、ComfyUI等多种部署方式
- **🎛️ 动态分辨率推理**: 自适应分辨率调整,优化生成质量
- **🎞️ 视频帧插值**: 基于RIFE的帧插值技术,实现流畅的帧率提升
## 📚 技术文档
### 📖 **方法教程**
- [模型量化](https://lightx2v-zhcn.readthedocs.io/zh-cn/latest/method_tutorials/quantization.html) - 量化策略全面指南
- [特征缓存](https://lightx2v-zhcn.readthedocs.io/zh-cn/latest/method_tutorials/cache.html) - 智能缓存机制详解
- [注意力机制](https://lightx2v-zhcn.readthedocs.io/zh-cn/latest/method_tutorials/attention.html) - 前沿注意力算子
- [参数卸载](https://lightx2v-zhcn.readthedocs.io/zh-cn/latest/method_tutorials/offload.html) - 三级存储架构
- [并行推理](https://lightx2v-zhcn.readthedocs.io/zh-cn/latest/method_tutorials/parallel.html) - 多GPU加速策略
- [变分辨率推理](https://lightx2v-zhcn.readthedocs.io/zh-cn/latest/method_tutorials/changing_resolution.html) - U型分辨率策略
- [步数蒸馏](https://lightx2v-zhcn.readthedocs.io/zh-cn/latest/method_tutorials/step_distill.html) - 4步推理技术
- [视频帧插值](https://lightx2v-zhcn.readthedocs.io/zh-cn/latest/method_tutorials/video_frame_interpolation.html) - 基于RIFE的帧插值技术
### 🛠️ **部署指南**
- [低资源场景部署](https://lightx2v-zhcn.readthedocs.io/zh-cn/latest/deploy_guides/for_low_resource.html) - 优化的8GB显存解决方案
- [低延迟场景部署](https://lightx2v-zhcn.readthedocs.io/zh-cn/latest/deploy_guides/for_low_latency.html) - 极速推理优化
- [Gradio部署](https://lightx2v-zhcn.readthedocs.io/zh-cn/latest/deploy_guides/deploy_gradio.html) - Web界面搭建
- [服务化部署](https://lightx2v-zhcn.readthedocs.io/zh-cn/latest/deploy_guides/deploy_service.html) - 生产级API服务部署
- [Lora模型部署](https://lightx2v-zhcn.readthedocs.io/zh-cn/latest/deploy_guides/lora_deploy.html) - Lora灵活部署
## 🤝 致谢
我们向所有启发和促进LightX2V开发的模型仓库和研究社区表示诚挚的感谢。此框架基于开源社区的集体努力而构建。包括但不限于:
- [Tencent-Hunyuan](https://github.com/Tencent-Hunyuan)
- [Wan-Video](https://github.com/Wan-Video)
- [Qwen-Image](https://github.com/QwenLM/Qwen-Image)
- [LightLLM](https://github.com/ModelTC/LightLLM)
- [sglang](https://github.com/sgl-project/sglang)
- [vllm](https://github.com/vllm-project/vllm)
- [flash-attention](https://github.com/Dao-AILab/flash-attention)
- [SageAttention](https://github.com/thu-ml/SageAttention)
- [flashinfer](https://github.com/flashinfer-ai/flashinfer)
- [MagiAttention](https://github.com/SandAI-org/MagiAttention)
- [radial-attention](https://github.com/mit-han-lab/radial-attention)
- [xDiT](https://github.com/xdit-project/xDiT)
- [FastVideo](https://github.com/hao-ai-lab/FastVideo)
## 🌟 Star 历史
[![Star History Chart](https://api.star-history.com/svg?repos=ModelTC/lightx2v&type=Timeline)](https://star-history.com/#ModelTC/lightx2v&Timeline)
## ✏️ 引用
如果您发现LightX2V对您的研究有用,请考虑引用我们的工作:
```bibtex
@misc{lightx2v,
author = {LightX2V Contributors},
title = {LightX2V: Light Video Generation Inference Framework},
year = {2025},
publisher = {GitHub},
journal = {GitHub repository},
howpublished = {\url{https://github.com/ModelTC/lightx2v}},
}
```
## 📞 联系与支持
如有任何问题、建议或需要支持,欢迎通过以下方式联系我们:
- 🐛 [GitHub Issues](https://github.com/ModelTC/lightx2v/issues) - 错误报告和功能请求
---
<div align="center">
由 LightX2V 团队用 ❤️ 构建
</div>
# Gradio Demo
Please refer our gradio deployment doc:
[English doc: Gradio Deployment](https://lightx2v-en.readthedocs.io/en/latest/deploy_guides/deploy_gradio.html)
[中文文档: Gradio 部署](https://lightx2v-zhcn.readthedocs.io/zh-cn/latest/deploy_guides/deploy_gradio.html)
## 🚀 Quick Start (快速开始)
For Windows users, we provide a convenient one-click deployment solution with automatic environment configuration and intelligent parameter optimization. Please refer to the [One-Click Gradio Launch](https://lightx2v-en.readthedocs.io/en/latest/deploy_guides/deploy_local_windows.html) section for detailed instructions.
对于Windows用户,我们提供了便捷的一键部署方式,支持自动环境配置和智能参数优化。详细操作请参考[一键启动Gradio](https://lightx2v-zhcn.readthedocs.io/zh-cn/latest/deploy_guides/deploy_local_windows.html)章节。
"""
重构后的 Gradio Demo 主入口文件
整合了所有模块,支持中英文切换
"""
import argparse
import gc
import json
import logging
import os
import warnings
import torch
from loguru import logger
from utils.i18n import DEFAULT_LANG, set_language
from utils.model_utils import cleanup_memory, extract_op_name, get_model_configs
from utils.ui_builder import build_ui, generate_unique_filename, get_auto_config_dict
from lightx2v.utils.input_info import init_empty_input_info, update_input_info_from_dict
from lightx2v.utils.set_config import get_default_config
warnings.filterwarnings("ignore", category=UserWarning, module="huggingface_hub")
warnings.filterwarnings("ignore", category=UserWarning, module="huggingface_hub.utils")
logging.getLogger("httpx").setLevel(logging.ERROR)
logging.getLogger("httpcore").setLevel(logging.ERROR)
logging.getLogger("urllib3").setLevel(logging.ERROR)
os.environ["PROFILING_DEBUG_LEVEL"] = "2"
os.environ["PYTORCH_CUDA_ALLOC_CONF"] = "expandable_segments:True"
os.environ["DTYPE"] = "BF16"
logger.add(
"inference_logs.log",
rotation="100 MB",
encoding="utf-8",
enqueue=True,
backtrace=True,
diagnose=True,
)
global_runner = None
current_config = None
cur_dit_path = None
def run_inference(
prompt="",
negative_prompt="",
save_result_path="",
infer_steps=4,
num_frames=81,
resolution="480p",
seed=42,
sample_shift=5,
cfg_scale=1,
fps=16,
model_path_input=None,
model_type_input="wan2.1",
task_type_input="i2v",
dit_path_input=None,
high_noise_path_input=None,
low_noise_path_input=None,
t5_path_input=None,
clip_path_input="",
image_path=None,
vae_path=None,
qwen_image_dit_path_input=None,
qwen_image_vae_path_input=None,
qwen_image_scheduler_path_input=None,
qwen25vl_encoder_path_input=None,
z_image_dit_path_input=None,
z_image_vae_path_input=None,
z_image_scheduler_path_input=None,
qwen3_encoder_path_input=None,
aspect_ratio="1:1",
use_lora=None,
lora_path=None,
lora_strength=None,
high_noise_lora_path=None,
low_noise_lora_path=None,
high_noise_lora_strength=None,
low_noise_lora_strength=None,
):
cleanup_memory()
auto_config = get_auto_config_dict(model_type=model_type_input, resolution=resolution, num_frames=num_frames, task_type=task_type_input)
# 从 auto_config 中获取 offload 和 rope 相关配置
rope_chunk = auto_config["rope_chunk_val"]
rope_chunk_size = auto_config["rope_chunk_size_val"]
cpu_offload = auto_config["cpu_offload_val"]
offload_granularity = auto_config["offload_granularity_val"]
lazy_load = auto_config["lazy_load_val"]
t5_cpu_offload = auto_config["t5_cpu_offload_val"]
clip_cpu_offload = auto_config["clip_cpu_offload_val"]
vae_cpu_offload = auto_config["vae_cpu_offload_val"]
unload_modules = auto_config["unload_modules_val"]
attention_type = auto_config["attention_type_val"]
quant_op = auto_config["quant_op_val"]
use_tiling_vae = auto_config["use_tiling_vae_val"]
clean_cuda_cache = auto_config["clean_cuda_cache_val"]
quant_op = extract_op_name(quant_op)
attention_type = extract_op_name(attention_type)
task = task_type_input
is_image_output = task in ["i2i", "t2i"]
save_result_path = generate_unique_filename(output_dir, is_image=is_image_output)
if cfg_scale == 1:
enable_cfg = False
else:
enable_cfg = True
model_config = get_model_configs(
model_type_input,
model_path_input,
dit_path_input,
high_noise_path_input,
low_noise_path_input,
t5_path_input,
clip_path_input,
vae_path,
qwen_image_dit_path_input,
qwen_image_vae_path_input,
qwen_image_scheduler_path_input,
qwen25vl_encoder_path_input,
z_image_dit_path_input,
z_image_vae_path_input,
z_image_scheduler_path_input,
qwen3_encoder_path_input,
quant_op,
use_lora=use_lora,
lora_path=lora_path,
lora_strength=lora_strength,
high_noise_lora_path=high_noise_lora_path,
low_noise_lora_path=low_noise_lora_path,
high_noise_lora_strength=high_noise_lora_strength,
low_noise_lora_strength=low_noise_lora_strength,
)
model_cls = model_config["model_cls"]
model_path = model_config["model_path"]
global global_runner, current_config, cur_dit_path
logger.info(f"Auto-determined model_cls: {model_cls} (model type: {model_type_input})")
if model_cls.startswith("wan2.2"):
current_dit_path = f"{high_noise_path_input}|{low_noise_path_input}" if high_noise_path_input and low_noise_path_input else None
else:
current_dit_path = dit_path_input
needs_reinit = lazy_load or unload_modules or global_runner is None or cur_dit_path != current_dit_path
config_graio = {
"infer_steps": infer_steps,
"target_video_length": num_frames,
"resolution": resolution,
"resize_mode": "adaptive",
"self_attn_1_type": attention_type,
"cross_attn_1_type": attention_type,
"cross_attn_2_type": attention_type,
"attn_type": attention_type,
"enable_cfg": enable_cfg,
"sample_guide_scale": cfg_scale,
"sample_shift": sample_shift,
"fps": fps,
"feature_caching": "NoCaching",
"do_mm_calib": False,
"parallel_attn_type": None,
"parallel_vae": False,
"max_area": False,
"vae_stride": (4, 8, 8),
"patch_size": (1, 2, 2),
"lora_path": None,
"strength_model": 1.0,
"use_prompt_enhancer": False,
"text_len": 512,
"denoising_step_list": [1000, 750, 500, 250],
"cpu_offload": True if "wan2.2" in model_cls else cpu_offload,
"offload_granularity": ("phase" if "wan2.2" in model_cls else offload_granularity),
"t5_cpu_offload": t5_cpu_offload,
"clip_cpu_offload": clip_cpu_offload,
"vae_cpu_offload": vae_cpu_offload,
"use_tiling_vae": use_tiling_vae,
"lazy_load": lazy_load,
"rope_chunk": rope_chunk,
"rope_chunk_size": rope_chunk_size,
"clean_cuda_cache": clean_cuda_cache,
"unload_modules": unload_modules,
"seq_parallel": False,
"warm_up_cpu_buffers": False,
"boundary_step_index": 2,
"boundary": 0.900,
"use_image_encoder": False if "wan2.2" in model_cls else True,
"rope_type": "torch",
"t5_lazy_load": lazy_load,
"bucket_shape": {
"0.667": [[480, 832], [544, 960], [720, 960]],
"1.500": [[832, 480], [960, 544], [960, 720]],
"1.000": [[480, 480], [576, 576], [720, 720]],
},
"aspect_ratio": aspect_ratio,
}
args = argparse.Namespace(
model_cls=model_cls,
seed=seed,
task=task,
model_path=model_path,
prompt_enhancer=None,
prompt=prompt,
negative_prompt=negative_prompt,
image_path=image_path,
save_result_path=save_result_path,
return_result_tensor=False,
aspect_ratio=aspect_ratio,
target_shape=[],
)
input_info = init_empty_input_info(args.task)
config = get_default_config()
config.update({k: v for k, v in vars(args).items()})
config.update(config_graio)
config.update(model_config)
# 如果 model_config 中包含 lora_configs,设置 lora_dynamic_apply
if config.get("lora_configs"):
config["lora_dynamic_apply"] = True
logger.info(f"Using model: {model_path}")
logger.info(f"Inference config:\n{json.dumps(config, indent=4, ensure_ascii=False)}")
# 初始化或重用 runner
runner = global_runner
if needs_reinit:
if runner is not None:
del runner
torch.cuda.empty_cache()
gc.collect()
from lightx2v.infer import init_runner
runner = init_runner(config)
data = args.__dict__
update_input_info_from_dict(input_info, data)
current_config = config
cur_dit_path = current_dit_path
if not lazy_load:
global_runner = runner
else:
runner.config = config
data = args.__dict__
update_input_info_from_dict(input_info, data)
runner.run_pipeline(input_info)
cleanup_memory()
return save_result_path
def main(lang=DEFAULT_LANG):
"""主函数"""
set_language(lang)
demo = build_ui(
model_path=model_path,
output_dir=output_dir,
run_inference=run_inference,
lang=lang,
)
# 启动 Gradio 应用
demo.launch(
share=True,
server_port=args.server_port,
server_name=args.server_name,
inbrowser=True,
allowed_paths=[output_dir],
max_file_size="1gb",
)
if __name__ == "__main__":
parser = argparse.ArgumentParser(description="轻量级视频生成")
parser.add_argument("--model_path", type=str, required=True, help="模型文件夹路径")
parser.add_argument("--server_port", type=int, default=7862, help="服务器端口")
parser.add_argument("--server_name", type=str, default="0.0.0.0", help="服务器IP")
parser.add_argument("--output_dir", type=str, default="./outputs", help="输出视频保存目录")
parser.add_argument("--lang", type=str, default="zh", choices=["zh", "en"], help="界面语言")
args = parser.parse_args()
global model_path, model_cls, output_dir
model_path = args.model_path
model_cls = "wan2.1"
output_dir = args.output_dir
main(lang=args.lang)
#!/bin/bash
# Lightx2v Gradio Demo Startup Script
# Supports both Image-to-Video (i2v) and Text-to-Video (t2v) modes
# ==================== Configuration Area ====================
# ⚠️ Important: Please modify the following paths according to your actual environment
# 🚨 Storage Performance Tips 🚨
# 💾 Strongly recommend storing model files on SSD solid-state drives!
# 📈 SSD can significantly improve model loading speed and inference performance
# 🐌 Using mechanical hard drives (HDD) may cause slow model loading and affect overall experience
# Lightx2v project root directory path
# Example: /home/user/lightx2v or /data/video_gen/lightx2v
lightx2v_path=/data/video_gen/lightx2v_debug/LightX2V
# Model path configuration
# Example: /path/to/Wan2.1-I2V-14B-720P-Lightx2v
model_path=/models/
# Server configuration
server_name="0.0.0.0"
server_port=8036
# Output directory configuration
output_dir="./outputs"
# GPU configuration
gpu_id=0
# ==================== Environment Variables Setup ====================
export CUDA_VISIBLE_DEVICES=$gpu_id
export CUDA_LAUNCH_BLOCKING=1
export PYTHONPATH=${lightx2v_path}:$PYTHONPATH
export PROFILING_DEBUG_LEVEL=2
export PYTORCH_CUDA_ALLOC_CONF=expandable_segments:True
# ==================== Parameter Parsing ====================
# Default interface language
lang="zh"
# 解析命令行参数
while [[ $# -gt 0 ]]; do
case $1 in
--lang)
lang="$2"
shift 2
;;
--port)
server_port="$2"
shift 2
;;
--gpu)
gpu_id="$2"
export CUDA_VISIBLE_DEVICES=$gpu_id
shift 2
;;
--output_dir)
output_dir="$2"
shift 2
;;
--model_path)
model_path="$2"
shift 2
;;
--help)
echo "🎬 Lightx2v Gradio Demo Startup Script"
echo "=========================================="
echo "Usage: $0 [options]"
echo ""
echo "📋 Available options:"
echo " --lang zh|en Interface language (default: zh)"
echo " zh: Chinese interface"
echo " en: English interface"
echo " --port PORT Server port (default: 8032)"
echo " --gpu GPU_ID GPU device ID (default: 0)"
echo " --model_path PATH Model path (default: configured in script)"
echo " --output_dir DIR Output video save directory (default: ./outputs)"
echo " --help Show this help message"
echo ""
echo "📝 Notes:"
echo " - Task type (i2v/t2v) and model type are selected in the web UI"
echo " - Model class is auto-detected based on selected diffusion model"
echo " - Edit script to configure model paths before first use"
echo " - Ensure required Python dependencies are installed"
echo " - Recommended to use GPU with 8GB+ VRAM"
echo " - 🚨 Strongly recommend storing models on SSD for better performance"
exit 0
;;
*)
echo "Unknown parameter: $1"
echo "Use --help to see help information"
exit 1
;;
esac
done
# ==================== Parameter Validation ====================
if [[ "$lang" != "zh" && "$lang" != "en" ]]; then
echo "Error: Language must be 'zh' or 'en'"
exit 1
fi
# Check if model path exists
if [[ ! -d "$model_path" ]]; then
echo "❌ Error: Model path does not exist"
echo "📁 Path: $model_path"
echo "🔧 Solutions:"
echo " 1. Check model path configuration in script"
echo " 2. Ensure model files are properly downloaded"
echo " 3. Verify path permissions are correct"
echo " 4. 💾 Recommend storing models on SSD for faster loading"
exit 1
fi
# 使用新的统一入口文件
demo_file="gradio_demo.py"
echo "🌏 Using unified interface (language: $lang)"
# Check if demo file exists
if [[ ! -f "$demo_file" ]]; then
echo "❌ Error: Demo file does not exist"
echo "📄 File: $demo_file"
echo "🔧 Solutions:"
echo " 1. Ensure script is run in the correct directory"
echo " 2. Check if file has been renamed or moved"
echo " 3. Re-clone or download project files"
exit 1
fi
# ==================== System Information Display ====================
echo "=========================================="
echo "🚀 Lightx2v Gradio Demo Starting..."
echo "=========================================="
echo "📁 Project path: $lightx2v_path"
echo "🤖 Model path: $model_path"
echo "🌏 Interface language: $lang"
echo "🖥️ GPU device: $gpu_id"
echo "🌐 Server address: $server_name:$server_port"
echo "📁 Output directory: $output_dir"
echo "📝 Note: Task type and model class are selected in web UI"
echo "=========================================="
# Display system resource information
echo "💻 System resource information:"
free -h | grep -E "Mem|Swap"
echo ""
# Display GPU information
if command -v nvidia-smi &> /dev/null; then
echo "🎮 GPU information:"
nvidia-smi --query-gpu=name,memory.total,memory.free --format=csv,noheader,nounits | head -1
echo ""
fi
# ==================== Start Demo ====================
echo "🎬 Starting Gradio demo..."
echo "📱 Please access in browser: http://$server_name:$server_port"
echo "⏹️ Press Ctrl+C to stop service"
echo "🔄 First startup may take several minutes to load resources..."
echo "=========================================="
# Start Python demo
if [[ "$demo_file" == "gradio_demo.py" ]]; then
python $demo_file \
--model_path "$model_path" \
--server_name "$server_name" \
--server_port "$server_port" \
--output_dir "$output_dir" \
--lang "$lang"
else
python $demo_file \
--model_path "$model_path" \
--server_name "$server_name" \
--server_port "$server_port" \
--output_dir "$output_dir"
fi
# Display final system resource usage
echo ""
echo "=========================================="
echo "📊 Final system resource usage:"
free -h | grep -E "Mem|Swap"
@echo off
chcp 65001 >nul
echo 🎬 LightX2V Gradio Windows Startup Script
echo ==========================================
REM ==================== Configuration Area ====================
REM ⚠️ Important: Please modify the following paths according to your actual environment
REM 🚨 Storage Performance Tips 🚨
REM 💾 Strongly recommend storing model files on SSD solid-state drives!
REM 📈 SSD can significantly improve model loading speed and inference performance
REM 🐌 Using mechanical hard drives (HDD) may cause slow model loading and affect overall experience
REM LightX2V project root directory path
REM Example: D:\LightX2V
set lightx2v_path=/path/to/LightX2V
REM Model path configuration
REM Model root directory path
REM Example: D:\models\LightX2V
set model_path=/path/to/LightX2V
REM Server configuration
set server_name=127.0.0.1
set server_port=8032
REM Output directory configuration
set output_dir=./outputs
REM GPU configuration
set gpu_id=0
REM ==================== Environment Variables Setup ====================
set CUDA_VISIBLE_DEVICES=%gpu_id%
set PYTHONPATH=%lightx2v_path%;%PYTHONPATH%
set PROFILING_DEBUG_LEVEL=2
set PYTORCH_CUDA_ALLOC_CONF=expandable_segments:True
REM ==================== Parameter Parsing ====================
REM Default interface language
set lang=zh
REM Parse command line arguments
:parse_args
if "%1"=="" goto :end_parse
if "%1"=="--lang" (
set lang=%2
shift
shift
goto :parse_args
)
if "%1"=="--port" (
set server_port=%2
shift
shift
goto :parse_args
)
if "%1"=="--gpu" (
set gpu_id=%2
set CUDA_VISIBLE_DEVICES=%gpu_id%
shift
shift
goto :parse_args
)
if "%1"=="--output_dir" (
set output_dir=%2
shift
shift
goto :parse_args
)
if "%1"=="--help" (
echo 🎬 LightX2V Gradio Windows Startup Script
echo ==========================================
echo Usage: %0 [options]
echo.
echo 📋 Available options:
echo --lang zh^|en Interface language (default: zh)
echo zh: Chinese interface
echo en: English interface
echo --port PORT Server port (default: 8032)
echo --gpu GPU_ID GPU device ID (default: 0)
echo --output_dir OUTPUT_DIR
echo Output video save directory (default: ./outputs)
echo --help Show this help message
echo.
echo 🚀 Usage examples:
echo %0 # Default startup
echo %0 --lang zh --port 8032 # Start with specified parameters
echo %0 --lang en --port 7860 # English interface
echo %0 --gpu 1 --port 8032 # Use GPU 1
echo %0 --output_dir ./custom_output # Use custom output directory
echo.
echo 📝 Notes:
echo - Edit script to configure model path before first use
echo - Ensure required Python dependencies are installed
echo - Recommended to use GPU with 8GB+ VRAM
echo - 🚨 Strongly recommend storing models on SSD for better performance
pause
exit /b 0
)
echo Unknown parameter: %1
echo Use --help to see help information
pause
exit /b 1
:end_parse
REM ==================== Parameter Validation ====================
if "%lang%"=="zh" goto :valid_lang
if "%lang%"=="en" goto :valid_lang
echo Error: Language must be 'zh' or 'en'
pause
exit /b 1
:valid_lang
REM Check if model path exists
if not exist "%model_path%" (
echoError: Model path does not exist
echo 📁 Path: %model_path%
echo 🔧 Solutions:
echo 1. Check model path configuration in script
echo 2. Ensure model files are properly downloaded
echo 3. Verify path permissions are correct
echo 4. 💾 Recommend storing models on SSD for faster loading
pause
exit /b 1
)
REM Select demo file based on language
if "%lang%"=="zh" (
set demo_file=gradio_demo_zh.py
echo 🌏 Using Chinese interface
) else (
set demo_file=gradio_demo.py
echo 🌏 Using English interface
)
REM Check if demo file exists
if not exist "%demo_file%" (
echoError: Demo file does not exist
echo 📄 File: %demo_file%
echo 🔧 Solutions:
echo 1. Ensure script is run in the correct directory
echo 2. Check if file has been renamed or moved
echo 3. Re-clone or download project files
pause
exit /b 1
)
REM ==================== System Information Display ====================
echo ==========================================
echo 🚀 LightX2V Gradio Starting...
echo ==========================================
echo 📁 Project path: %lightx2v_path%
echo 🤖 Model path: %model_path%
echo 🌏 Interface language: %lang%
echo 🖥️ GPU device: %gpu_id%
echo 🌐 Server address: %server_name%:%server_port%
echo 📁 Output directory: %output_dir%
echo ==========================================
REM Display system resource information
echo 💻 System resource information:
wmic OS get TotalVisibleMemorySize,FreePhysicalMemory /format:table
REM Display GPU information
nvidia-smi --query-gpu=name,memory.total,memory.free --format=csv,noheader,nounits 2>nul
if errorlevel 1 (
echo 🎮 GPU information: Unable to get GPU info
) else (
echo 🎮 GPU information:
nvidia-smi --query-gpu=name,memory.total,memory.free --format=csv,noheader,nounits
)
REM ==================== Start Demo ====================
echo 🎬 Starting Gradio demo...
echo 📱 Please access in browser: http://%server_name%:%server_port%
echo ⏹️ Press Ctrl+C to stop service
echo 🔄 First startup may take several minutes to load resources...
echo ==========================================
REM Start Python demo
python %demo_file% ^
--model_path "%model_path%" ^
--server_name %server_name% ^
--server_port %server_port% ^
--output_dir "%output_dir%"
REM Display final system resource usage
echo.
echo ==========================================
echo 📊 Final system resource usage:
wmic OS get TotalVisibleMemorySize,FreePhysicalMemory /format:table
pause
"""国际化支持模块"""
import os
# 默认语言
DEFAULT_LANG = os.getenv("GRADIO_LANG", "zh")
# 翻译字典
TRANSLATIONS = {
"zh": {
"title": "🎬 LightX2V 图片/视频生成器",
"model_config": "🗂️ 模型配置",
"model_config_hint": "💡 **提示**:请确保以下每个模型选项至少有一个已下载✅的模型可用,否则可能无法正常生成视频。",
"fp8_not_supported": "⚠️ **您的设备不支持fp8推理**,已自动隐藏包含fp8的模型选项。",
"model_type": "模型类型",
"model_type_info": "Wan2.2 需要分别指定高噪模型和低噪模型; Qwen-Image-Edit-2511 用于图片编辑(i2i); Qwen-Image-2512 用于文本生成图片(t2i); Z-Image-Turbo 用于文本生成图片(t2i)",
"qwen3_encoder": "📝 Qwen3 编码器",
"scheduler": "⏱️ 调度器",
"qwen25vl_encoder": "📝 Qwen25-VL 编码器",
"task_type": "任务类型",
"task_type_info": "I2V: 图生视频, T2V: 文生视频, T2I: 文生图, I2I: 图片编辑",
"download_source": "📥 下载源",
"download_source_info": "选择模型下载源",
"diffusion_model": "🎨 Diffusion模型",
"high_noise_model": "🔊 高噪模型",
"low_noise_model": "🔇 低噪模型",
"text_encoder": "📝 文本编码器",
"text_encoder_tokenizer": "📝 文本编码器 Tokenizer",
"image_encoder": "🖼️ 图像编码器",
"image_encoder_tokenizer": "🖼️ 图像编码器 Tokenizer",
"vae": "🎞️ VAE编码/解码器",
"attention_operator": "⚡ 注意力算子",
"attention_operator_info": "使用适当的注意力算子加速推理",
"quant_operator": "⚡矩阵乘法算子",
"quant_operator_info": "选择低精度矩阵乘法算子以加速推理",
"input_params": "📥 输入参数",
"input_image": "输入图像(可拖入多张图片)",
"image_preview": "已上传的图片预览",
"image_path": "图片路径",
"prompt": "提示词",
"prompt_placeholder": "描述视频/图片内容...",
"negative_prompt": "负向提示词",
"negative_prompt_placeholder": "不希望出现在视频/图片中的内容...",
"max_resolution": "最大分辨率",
"max_resolution_info": "如果显存不足,可调低分辨率",
"random_seed": "随机种子",
"infer_steps": "推理步数",
"infer_steps_distill": "蒸馏模型推理步数默认为4。",
"infer_steps_info": "视频生成的推理步数。增加步数可能提高质量但降低速度。",
"sample_shift": "分布偏移",
"sample_shift_info": "控制样本分布偏移的程度。值越大表示偏移越明显。",
"cfg_scale": "CFG缩放因子",
"cfg_scale_info": "控制提示词的影响强度。值越高,提示词的影响越大。当值为1时,自动禁用CFG。",
"enable_cfg": "启用无分类器引导",
"fps": "每秒帧数(FPS)",
"fps_info": "视频的每秒帧数。较高的FPS会产生更流畅的视频。",
"num_frames": "总帧数",
"num_frames_info": "视频中的总帧数。更多帧数会产生更长的视频。",
"video_duration": "视频时长(秒)",
"video_duration_info": "视频的时长(秒)。实际帧数 = 时长 × FPS。",
"output_path": "输出视频路径",
"output_path_info": "必须包含.mp4扩展名。如果留空或使用默认值,将自动生成唯一文件名。",
"output_image_path": "输出图片路径",
"output_image_path_info": "必须包含.png扩展名。如果留空或使用默认值,将自动生成唯一文件名。",
"output_result": "📤 生成的结果",
"output_image": "输出图片",
"generate_video": "🎬 生成视频",
"generate_image": "🖼️ 生成图片",
"infer_steps_image_info": "图片编辑的推理步数,默认为8。",
"aspect_ratio": "宽高比",
"aspect_ratio_info": "选择生成图片的宽高比",
"model_config_hint_image": "💡 **提示**:请确保以下每个模型选项至少有一个已下载✅的模型可用,否则可能无法正常生成图片。",
"download": "📥 下载",
"downloaded": "✅ 已下载",
"not_downloaded": "❌ 未下载",
"download_complete": "✅ {model_name} 下载完成",
"download_start": "开始从 {source} 下载 {model_name}...",
"please_select_model": "请先选择模型",
"loading_models": "正在加载 Hugging Face 模型列表缓存...",
"models_loaded": "模型列表缓存加载完成",
"use_lora": "使用 LoRA",
"lora": "🎨 LoRA",
"lora_info": "选择要使用的 LoRA 模型",
"lora_strength": "LoRA 强度",
"lora_strength_info": "控制 LoRA 的影响强度,范围 0-10",
"high_noise_lora": "🔊 高噪模型 LoRA",
"high_noise_lora_info": "选择高噪模型使用的 LoRA",
"high_noise_lora_strength": "高噪模型 LoRA 强度",
"high_noise_lora_strength_info": "控制高噪模型 LoRA 的影响强度,范围 0-10",
"low_noise_lora": "🔇 低噪模型 LoRA",
"low_noise_lora_info": "选择低噪模型使用的 LoRA",
"low_noise_lora_strength": "低噪模型 LoRA 强度",
"low_noise_lora_strength_info": "控制低噪模型 LoRA 的影响强度,范围 0-10",
},
"en": {
"title": "🎬 LightX2V Image/Video Generator",
"model_config": "🗂️ Model Configuration",
"model_config_hint": "💡 **Tip**: Please ensure at least one downloaded ✅ model is available for each model option below, otherwise video generation may fail.",
"fp8_not_supported": "⚠️ **Your device does not support fp8 inference**, fp8 model options have been automatically hidden.",
"model_type": "Model Type",
"model_type_info": "Wan2.2 requires separate high-noise and low-noise models; Qwen-Image-Edit-2511 is for image editing (i2i); Qwen-Image-2512 is for text-to-image (t2i); Z-Image-Turbo is for text-to-image (t2i)",
"qwen3_encoder": "📝 Qwen3 Encoder",
"scheduler": "⏱️ Scheduler",
"qwen25vl_encoder": "📝 Qwen25-VL Encoder",
"task_type": "Task Type",
"task_type_info": "I2V: Image-to-Video, T2V: Text-to-Video, T2I: Text-to-Image, I2I: Image Editing",
"download_source": "📥 Download Source",
"download_source_info": "Select model download source",
"diffusion_model": "🎨 Diffusion Model",
"high_noise_model": "🔊 High Noise Model",
"low_noise_model": "🔇 Low Noise Model",
"text_encoder": "📝 Text Encoder",
"text_encoder_tokenizer": "📝 Text Encoder Tokenizer",
"image_encoder": "🖼️ Image Encoder",
"image_encoder_tokenizer": "🖼️ Image Encoder Tokenizer",
"vae": "🎞️ VAE Encoder/Decoder",
"attention_operator": "⚡ Attention Operator",
"attention_operator_info": "Use appropriate attention operator to accelerate inference",
"quant_operator": "⚡ Matrix Multiplication Operator",
"quant_operator_info": "Select low-precision matrix multiplication operator to accelerate inference",
"input_params": "📥 Input Parameters",
"input_image": "Input Image (drag multiple images)",
"image_preview": "Uploaded Image Preview",
"image_path": "Image Path",
"prompt": "Prompt",
"prompt_placeholder": "Describe video/image content...",
"negative_prompt": "Negative Prompt",
"negative_prompt_placeholder": "Content you don't want in the video/image...",
"max_resolution": "Max Resolution",
"max_resolution_info": "Reduce resolution if VRAM is insufficient",
"random_seed": "Random Seed",
"infer_steps": "Inference Steps",
"infer_steps_distill": "Distill model inference steps default to 4.",
"infer_steps_info": "Number of inference steps for video generation. More steps may improve quality but reduce speed.",
"sample_shift": "Sample Shift",
"sample_shift_info": "Control the degree of sample distribution shift. Higher values indicate more obvious shift.",
"cfg_scale": "CFG Scale",
"cfg_scale_info": "Control the influence strength of prompts. Higher values mean stronger prompt influence. When value is 1, CFG is automatically disabled.",
"enable_cfg": "Enable Classifier-Free Guidance",
"fps": "Frames Per Second (FPS)",
"fps_info": "Frames per second of the video. Higher FPS produces smoother videos.",
"num_frames": "Total Frames",
"num_frames_info": "Total number of frames in the video. More frames produce longer videos.",
"video_duration": "Video Duration (seconds)",
"video_duration_info": "Duration of the video in seconds. Actual frames = duration × FPS.",
"output_path": "Output Video Path",
"output_path_info": "Must include .mp4 extension. If left empty or using default value, a unique filename will be automatically generated.",
"output_image_path": "Output Image Path",
"output_image_path_info": "Must include .png extension. If left empty or using default value, a unique filename will be automatically generated.",
"output_result": "📤 Generated Result",
"output_image": "Output Image",
"generate_video": "🎬 Generate Video",
"generate_image": "🖼️ Generate Image",
"infer_steps_image_info": "Number of inference steps for image editing, default is 8.",
"aspect_ratio": "Aspect Ratio",
"aspect_ratio_info": "Select the aspect ratio for generated images",
"model_config_hint_image": "💡 **Tip**: Please ensure at least one downloaded ✅ model is available for each model option below, otherwise image generation may fail.",
"download": "📥 Download",
"downloaded": "✅ Downloaded",
"not_downloaded": "❌ Not Downloaded",
"download_complete": "✅ {model_name} download complete",
"download_start": "Starting to download {model_name} from {source}...",
"please_select_model": "Please select a model first",
"loading_models": "Loading Hugging Face model list cache...",
"models_loaded": "Model list cache loaded",
"use_lora": "Use LoRA",
"lora": "🎨 LoRA",
"lora_info": "Select LoRA model to use",
"lora_strength": "LoRA Strength",
"lora_strength_info": "Control LoRA influence strength, range 0-10",
"high_noise_lora": "🔊 High Noise Model LoRA",
"high_noise_lora_info": "Select high noise model LoRA to use",
"high_noise_lora_strength": "High Noise Model LoRA Strength",
"high_noise_lora_strength_info": "Control high noise model LoRA influence strength, range 0-10",
"low_noise_lora": "🔇 Low Noise Model LoRA",
"low_noise_lora_info": "Select low noise model LoRA to use",
"low_noise_lora_strength": "Low Noise Model LoRA Strength",
"low_noise_lora_strength_info": "Control low noise model LoRA influence strength, range 0-10",
},
}
def t(key: str, lang: str = None) -> str:
"""获取翻译文本"""
if lang is None:
lang = DEFAULT_LANG
if lang not in TRANSLATIONS:
lang = "zh"
return TRANSLATIONS[lang].get(key, key)
def set_language(lang: str):
"""设置语言"""
global DEFAULT_LANG
if lang in TRANSLATIONS:
DEFAULT_LANG = lang
os.environ["GRADIO_LANG"] = lang
import os
import random
from datetime import datetime
import gradio as gr
from utils.i18n import DEFAULT_LANG, t
from utils.model_components import (
build_qwen_image_components,
build_z_image_turbo_components,
)
from utils.model_utils import HF_AVAILABLE, MS_AVAILABLE, is_fp8_supported_gpu
MAX_NUMPY_SEED = 2**32 - 1
def generate_random_seed():
"""生成随机种子"""
return random.randint(0, MAX_NUMPY_SEED)
def generate_unique_filename(output_dir, is_image=False):
"""生成唯一文件名"""
if not os.path.exists(output_dir):
os.makedirs(output_dir, exist_ok=True)
timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
extension = ".png" if is_image else ".mp4"
filename = f"{timestamp}{extension}"
return os.path.join(output_dir, filename)
def build_image_page(
model_path,
output_dir,
run_inference,
lang=DEFAULT_LANG,
):
# 主布局:左右分栏
with gr.Row():
# 左侧:配置和输入区域
with gr.Column(scale=5):
# 模型配置区域
with gr.Accordion(t("model_config", lang), open=True, elem_classes=["model-config"]):
gr.Markdown(t("model_config_hint_image", lang))
# FP8 支持提示
if not is_fp8_supported_gpu():
gr.Markdown(t("fp8_not_supported", lang))
model_path_input = gr.Textbox(value=model_path, visible=False)
# 任务类型选择(前端显示大写,传递值小写)
task_type_input = gr.Radio(
label=t("task_type", lang),
choices=[("I2I", "i2i"), ("T2I", "t2i")],
value="i2i",
info=t("task_type_info", lang),
elem_classes=["horizontal-radio"],
)
# 模型类型选择(根据任务类型动态更新)
model_type_input = gr.Radio(
label=t("model_type", lang),
choices=["Qwen-Image-Edit-2511"],
value="Qwen-Image-Edit-2511",
info=t("model_type_info", lang),
elem_classes=["horizontal-radio"],
)
download_source_input = gr.Radio(
label=t("download_source", lang),
choices=(["huggingface", "modelscope"] if (HF_AVAILABLE and MS_AVAILABLE) else (["huggingface"] if HF_AVAILABLE else ["modelscope"] if MS_AVAILABLE else [])),
value=("modelscope" if MS_AVAILABLE else ("huggingface" if HF_AVAILABLE else None)),
info=t("download_source_info", lang),
visible=HF_AVAILABLE or MS_AVAILABLE,
elem_classes=["horizontal-radio"],
)
# 构建 Qwen 模型组件(默认显示)
qwen_image_components = build_qwen_image_components(model_path, model_path_input, download_source_input, model_type_input, lang)
qwen_image_dit_path_input = qwen_image_components["qwen_image_dit_path_input"]
qwen_image_vae_path_input = qwen_image_components["qwen_image_vae_path_input"]
qwen_image_scheduler_path_input = qwen_image_components["qwen_image_scheduler_path_input"]
qwen25vl_encoder_path_input = qwen_image_components["qwen25vl_encoder_path_input"]
qwen_image_dit_download_btn = qwen_image_components["qwen_image_dit_download_btn"]
qwen_image_vae_download_btn = qwen_image_components["qwen_image_vae_download_btn"]
qwen_image_scheduler_download_btn = qwen_image_components["qwen_image_scheduler_download_btn"]
qwen25vl_encoder_download_btn = qwen_image_components["qwen25vl_encoder_download_btn"]
qwen_image_use_lora_input = qwen_image_components["use_lora"]
qwen_image_lora_path_input = qwen_image_components["lora_path_input"]
qwen_image_lora_strength_input = qwen_image_components["lora_strength"]
# 构建 Z-Image-Turbo 模型组件(默认隐藏)
z_image_turbo_components = build_z_image_turbo_components(model_path, model_path_input, download_source_input, model_type_input, lang)
z_image_turbo_dit_path_input = z_image_turbo_components["z_image_turbo_dit_path_input"]
z_image_turbo_vae_path_input = z_image_turbo_components["z_image_turbo_vae_path_input"]
z_image_turbo_scheduler_path_input = z_image_turbo_components["z_image_turbo_scheduler_path_input"]
qwen3_encoder_path_input = z_image_turbo_components["qwen3_encoder_path_input"]
z_image_turbo_dit_download_btn = z_image_turbo_components["z_image_turbo_dit_download_btn"]
z_image_turbo_vae_download_btn = z_image_turbo_components["z_image_turbo_vae_download_btn"]
z_image_turbo_scheduler_download_btn = z_image_turbo_components["z_image_turbo_scheduler_download_btn"]
qwen3_encoder_download_btn = z_image_turbo_components["qwen3_encoder_download_btn"]
z_image_turbo_use_lora_input = z_image_turbo_components["use_lora"]
z_image_turbo_lora_path_input = z_image_turbo_components["lora_path_input"]
z_image_turbo_lora_strength_input = z_image_turbo_components["lora_strength"]
# 获取组件容器以便控制显示/隐藏
qwen_components_group = qwen_image_components.get("components_group", None)
z_image_turbo_components_group = z_image_turbo_components.get("components_group", None)
with gr.Accordion(t("input_params", lang), open=True, elem_classes=["input-params"]):
image_files = gr.File(
label=t("input_image", lang),
file_count="multiple",
file_types=["image"],
height=150,
interactive=True,
)
image_gallery = gr.Gallery(
label=t("image_preview", lang),
columns=4,
rows=2,
height=200,
object_fit="contain",
show_label=True,
)
image_path = gr.Textbox(label=t("image_path", lang), visible=False)
aspect_ratio = gr.Dropdown(
label=t("aspect_ratio", lang),
choices=["1:1", "16:9", "9:16", "4:3", "3:4"],
value="1:1",
visible=True,
info=t("aspect_ratio_info", lang),
)
# 任务类型变化处理函数
def on_task_type_change(task_type_val):
if task_type_val == "t2i":
# t2i 模式:显示 Qwen-Image-2512 和 Z-Image-Turbo,隐藏图片输入和宽高比
return (
gr.update(choices=["Z-Image-Turbo", "Qwen-Image-2512"], value="Z-Image-Turbo"), # model_type_input
gr.update(visible=False), # image_files
gr.update(visible=False), # image_gallery
gr.update(value=""), # image_path
gr.update(visible=False), # aspect_ratio
)
else:
# i2i 模式:显示 Qwen-Image-Edit-2511,显示图片输入和宽高比
return (
gr.update(choices=["Qwen-Image-Edit-2511"], value="Qwen-Image-Edit-2511"), # model_type_input
gr.update(visible=True), # image_files
gr.update(visible=True), # image_gallery
gr.update(), # image_path (保持不变)
gr.update(visible=True), # aspect_ratio
)
# 绑定任务类型变化事件
task_type_input.change(
fn=on_task_type_change,
inputs=[task_type_input],
outputs=[model_type_input, image_files, image_gallery, image_path, aspect_ratio],
)
# 模型类型变化处理函数
def on_model_type_change(model_type_val, model_path_val):
# 控制组件组显示/隐藏
show_qwen = model_type_val in ["Qwen-Image-2512", "Qwen-Image-Edit-2511"]
show_z_image_turbo = model_type_val == "Z-Image-Turbo"
# 导入更新函数
from utils.model_handlers import update_model_status
from utils.model_utils import extract_model_name
# 更新模型选择
if model_type_val == "Qwen-Image-2512":
from utils.model_choices import get_qwen_image_2512_dit_choices, get_qwen_image_2512_scheduler_choices, get_qwen_image_2512_vae_choices
dit_choices = get_qwen_image_2512_dit_choices(model_path_val)
vae_choices = get_qwen_image_2512_vae_choices(model_path_val)
scheduler_choices = get_qwen_image_2512_scheduler_choices(model_path_val)
# 更新下载按钮状态
from utils.model_choices import get_qwen25vl_encoder_choices
qwen25vl_encoder_choices = get_qwen25vl_encoder_choices(model_path_val)
dit_btn_update = update_model_status(model_path_val, extract_model_name(dit_choices[0]) if dit_choices else "", "qwen_image_2512_dit")
vae_btn_update = update_model_status(model_path_val, extract_model_name(vae_choices[0]) if vae_choices else "", "qwen_image_2512_vae")
scheduler_btn_update = update_model_status(model_path_val, extract_model_name(scheduler_choices[0]) if scheduler_choices else "", "qwen_image_2512_scheduler")
qwen25vl_btn_update = update_model_status(model_path_val, extract_model_name(qwen25vl_encoder_choices[0]) if qwen25vl_encoder_choices else "", "qwen25vl_encoder")
return (
gr.update(visible=show_qwen), # qwen_components_group
gr.update(visible=show_z_image_turbo), # z_image_turbo_components_group
gr.update(choices=dit_choices, value=dit_choices[0] if dit_choices else ""), # qwen_image_dit_path_input
gr.update(choices=vae_choices, value=vae_choices[0] if vae_choices else ""), # qwen_image_vae_path_input
gr.update(choices=scheduler_choices, value=scheduler_choices[0] if scheduler_choices else ""), # qwen_image_scheduler_path_input
gr.update(), # z_image_turbo_dit_path_input (保持不变)
gr.update(), # z_image_turbo_vae_path_input (保持不变)
gr.update(), # z_image_turbo_scheduler_path_input (保持不变)
gr.update(), # qwen3_encoder_path_input (保持不变)
gr.update(), # aspect_ratio (保持不变)
dit_btn_update, # qwen_image_dit_download_btn
vae_btn_update, # qwen_image_vae_download_btn
scheduler_btn_update, # qwen_image_scheduler_download_btn
qwen25vl_btn_update, # qwen25vl_encoder_download_btn
gr.update(), # z_image_turbo_dit_download_btn
gr.update(), # z_image_turbo_vae_download_btn
gr.update(), # z_image_turbo_scheduler_download_btn
gr.update(), # qwen3_encoder_download_btn
)
elif model_type_val == "Z-Image-Turbo":
from utils.model_choices import get_qwen3_encoder_choices, get_z_image_turbo_dit_choices, get_z_image_turbo_scheduler_choices, get_z_image_turbo_vae_choices
dit_choices = get_z_image_turbo_dit_choices(model_path_val)
vae_choices = get_z_image_turbo_vae_choices(model_path_val)
scheduler_choices = get_z_image_turbo_scheduler_choices(model_path_val)
qwen3_encoder_choices = get_qwen3_encoder_choices(model_path_val)
# 更新下载按钮状态
dit_btn_update = update_model_status(model_path_val, extract_model_name(dit_choices[0]) if dit_choices else "", "z_image_turbo_dit")
vae_btn_update = update_model_status(model_path_val, extract_model_name(vae_choices[0]) if vae_choices else "", "z_image_turbo_vae")
scheduler_btn_update = update_model_status(model_path_val, extract_model_name(scheduler_choices[0]) if scheduler_choices else "", "z_image_turbo_scheduler")
qwen3_btn_update = update_model_status(model_path_val, extract_model_name(qwen3_encoder_choices[0]) if qwen3_encoder_choices else "", "qwen3_encoder")
return (
gr.update(visible=show_qwen), # qwen_components_group
gr.update(visible=show_z_image_turbo), # z_image_turbo_components_group
gr.update(), # qwen_image_dit_path_input (保持不变)
gr.update(), # qwen_image_vae_path_input (保持不变)
gr.update(), # qwen_image_scheduler_path_input (保持不变)
gr.update(choices=dit_choices, value=dit_choices[0] if dit_choices else ""), # z_image_turbo_dit_path_input
gr.update(choices=vae_choices, value=vae_choices[0] if vae_choices else ""), # z_image_turbo_vae_path_input
gr.update(choices=scheduler_choices, value=scheduler_choices[0] if scheduler_choices else ""), # z_image_turbo_scheduler_path_input
gr.update(choices=qwen3_encoder_choices, value=qwen3_encoder_choices[0] if qwen3_encoder_choices else ""), # qwen3_encoder_path_input
gr.update(visible=True), # aspect_ratio
gr.update(), # qwen_image_dit_download_btn
gr.update(), # qwen_image_vae_download_btn
gr.update(), # qwen_image_scheduler_download_btn
gr.update(), # qwen25vl_encoder_download_btn
dit_btn_update, # z_image_turbo_dit_download_btn
vae_btn_update, # z_image_turbo_vae_download_btn
scheduler_btn_update, # z_image_turbo_scheduler_download_btn
qwen3_btn_update, # qwen3_encoder_download_btn
)
else: # Qwen-Image-Edit-2511
from utils.model_choices import get_qwen_image_dit_choices, get_qwen_image_scheduler_choices, get_qwen_image_vae_choices
dit_choices = get_qwen_image_dit_choices(model_path_val)
vae_choices = get_qwen_image_vae_choices(model_path_val)
scheduler_choices = get_qwen_image_scheduler_choices(model_path_val)
# 更新下载按钮状态
from utils.model_choices import get_qwen25vl_encoder_choices
qwen25vl_encoder_choices = get_qwen25vl_encoder_choices(model_path_val)
dit_btn_update = update_model_status(model_path_val, extract_model_name(dit_choices[0]) if dit_choices else "", "qwen_image_dit")
vae_btn_update = update_model_status(model_path_val, extract_model_name(vae_choices[0]) if vae_choices else "", "qwen_image_vae")
scheduler_btn_update = update_model_status(model_path_val, extract_model_name(scheduler_choices[0]) if scheduler_choices else "", "qwen_image_scheduler")
qwen25vl_btn_update = update_model_status(model_path_val, extract_model_name(qwen25vl_encoder_choices[0]) if qwen25vl_encoder_choices else "", "qwen25vl_encoder")
return (
gr.update(visible=show_qwen), # qwen_components_group
gr.update(visible=show_z_image_turbo), # z_image_turbo_components_group
gr.update(choices=dit_choices, value=dit_choices[0] if dit_choices else ""), # qwen_image_dit_path_input
gr.update(choices=vae_choices, value=vae_choices[0] if vae_choices else ""), # qwen_image_vae_path_input
gr.update(choices=scheduler_choices, value=scheduler_choices[0] if scheduler_choices else ""), # qwen_image_scheduler_path_input
gr.update(), # z_image_turbo_dit_path_input (保持不变)
gr.update(), # z_image_turbo_vae_path_input (保持不变)
gr.update(), # z_image_turbo_scheduler_path_input (保持不变)
gr.update(), # qwen3_encoder_path_input (保持不变)
gr.update(), # aspect_ratio
dit_btn_update, # qwen_image_dit_download_btn
vae_btn_update, # qwen_image_vae_download_btn
scheduler_btn_update, # qwen_image_scheduler_download_btn
qwen25vl_btn_update, # qwen25vl_encoder_download_btn
gr.update(), # z_image_turbo_dit_download_btn
gr.update(), # z_image_turbo_vae_download_btn
gr.update(), # z_image_turbo_scheduler_download_btn
gr.update(), # qwen3_encoder_download_btn
)
# 绑定模型类型变化事件
model_type_input.change(
fn=on_model_type_change,
inputs=[model_type_input, model_path_input],
outputs=[
qwen_components_group,
z_image_turbo_components_group,
qwen_image_dit_path_input,
qwen_image_vae_path_input,
qwen_image_scheduler_path_input,
z_image_turbo_dit_path_input,
z_image_turbo_vae_path_input,
z_image_turbo_scheduler_path_input,
qwen3_encoder_path_input,
aspect_ratio,
qwen_image_dit_download_btn,
qwen_image_vae_download_btn,
qwen_image_scheduler_download_btn,
qwen25vl_encoder_download_btn,
z_image_turbo_dit_download_btn,
z_image_turbo_vae_download_btn,
z_image_turbo_scheduler_download_btn,
qwen3_encoder_download_btn,
],
)
def update_image_path_and_gallery(files):
if files is None or len(files) == 0:
return "", []
paths = [f.name if hasattr(f, "name") else f for f in files]
return ",".join(paths), paths
image_files.change(
fn=update_image_path_and_gallery,
inputs=[image_files],
outputs=[image_path, image_gallery],
)
with gr.Row():
prompt = gr.Textbox(
label=t("prompt", lang),
lines=3,
placeholder=t("prompt_placeholder", lang),
max_lines=5,
)
negative_prompt = gr.Textbox(
label=t("negative_prompt", lang),
lines=3,
placeholder=t("negative_prompt_placeholder", lang),
max_lines=5,
value="",
)
with gr.Row():
seed = gr.Slider(
label=t("random_seed", lang),
minimum=0,
maximum=MAX_NUMPY_SEED,
step=1,
value=generate_random_seed(),
)
infer_steps = gr.Slider(
label=t("infer_steps", lang),
minimum=1,
maximum=100,
step=1,
value=8,
info=t("infer_steps_image_info", lang),
)
# aspect_ratio 已在上面定义,这里不需要重复定义
cfg_scale = gr.Slider(
label=t("cfg_scale", lang),
minimum=1,
maximum=10,
step=1,
value=1,
info=t("cfg_scale_info", lang),
)
# 模型类型变化时更新 cfg_scale 和 infer_steps(Z-Image-Turbo 默认值)
def on_model_type_change_for_defaults(model_type_val):
if model_type_val == "Z-Image-Turbo":
return (
gr.update(value=9), # infer_steps
gr.update(value=1), # cfg_scale
)
else:
return (
gr.update(), # infer_steps (保持不变)
gr.update(), # cfg_scale (保持不变)
)
# 绑定模型类型变化事件,更新默认值
model_type_input.change(
fn=on_model_type_change_for_defaults,
inputs=[model_type_input],
outputs=[infer_steps, cfg_scale],
)
save_result_path = gr.Textbox(
label=t("output_image_path", lang),
value=generate_unique_filename(output_dir, is_image=True),
info=t("output_image_path_info", lang),
visible=False,
)
# 右侧:输出区域
with gr.Column(scale=4):
with gr.Accordion(t("output_result", lang), open=True, elem_classes=["output-video"]):
output_image = gr.Image(
label=t("output_image", lang),
height=600,
show_label=False,
)
infer_btn = gr.Button(
t("generate_image", lang),
variant="primary",
size="lg",
elem_classes=["generate-btn"],
)
def run_inference_wrapper(
prompt_val,
negative_prompt_val,
save_result_path_val,
infer_steps_val,
seed_val,
cfg_scale_val,
model_path_val,
model_type_val,
task_type_val,
image_path_val,
qwen_image_dit_path_val,
qwen_image_vae_path_val,
qwen_image_scheduler_path_val,
qwen25vl_encoder_path_val,
z_image_turbo_dit_path_val,
z_image_turbo_vae_path_val,
z_image_turbo_scheduler_path_val,
qwen3_encoder_path_val,
aspect_ratio_val,
qwen_image_use_lora_val,
qwen_image_lora_path_val,
qwen_image_lora_strength_val,
z_image_turbo_use_lora_val,
z_image_turbo_lora_path_val,
z_image_turbo_lora_strength_val,
):
# 根据模型类型传递不同的参数
if model_type_val == "Z-Image-Turbo":
result = run_inference(
prompt=prompt_val,
negative_prompt=negative_prompt_val,
save_result_path=save_result_path_val,
infer_steps=infer_steps_val,
seed=seed_val,
cfg_scale=cfg_scale_val,
model_path_input=model_path_val,
model_type_input=model_type_val,
task_type_input=task_type_val,
image_path=None, # Z-Image-Turbo 只支持 t2i,不需要图片
qwen_image_dit_path_input=None,
qwen_image_vae_path_input=None,
qwen_image_scheduler_path_input=None,
qwen25vl_encoder_path_input=None,
z_image_dit_path_input=z_image_turbo_dit_path_val,
z_image_vae_path_input=z_image_turbo_vae_path_val,
z_image_scheduler_path_input=z_image_turbo_scheduler_path_val,
qwen3_encoder_path_input=qwen3_encoder_path_val,
aspect_ratio=aspect_ratio_val, # Z-Image-Turbo 需要 aspect_ratio
use_lora=z_image_turbo_use_lora_val,
lora_path=z_image_turbo_lora_path_val,
lora_strength=z_image_turbo_lora_strength_val,
)
else:
result = run_inference(
prompt=prompt_val,
negative_prompt=negative_prompt_val,
save_result_path=save_result_path_val,
infer_steps=infer_steps_val,
seed=seed_val,
cfg_scale=cfg_scale_val,
model_path_input=model_path_val,
model_type_input=model_type_val,
task_type_input=task_type_val,
image_path=image_path_val,
qwen_image_dit_path_input=qwen_image_dit_path_val,
qwen_image_vae_path_input=qwen_image_vae_path_val,
qwen_image_scheduler_path_input=qwen_image_scheduler_path_val,
qwen25vl_encoder_path_input=qwen25vl_encoder_path_val,
z_image_dit_path_input=None,
z_image_vae_path_input=None,
z_image_scheduler_path_input=None,
qwen3_encoder_path_input=None,
aspect_ratio=aspect_ratio_val,
use_lora=qwen_image_use_lora_val,
lora_path=qwen_image_lora_path_val,
lora_strength=qwen_image_lora_strength_val,
)
return gr.update(value=result)
infer_btn.click(
fn=run_inference_wrapper,
inputs=[
prompt,
negative_prompt,
save_result_path,
infer_steps,
seed,
cfg_scale,
model_path_input,
model_type_input,
task_type_input,
image_path,
qwen_image_dit_path_input,
qwen_image_vae_path_input,
qwen_image_scheduler_path_input,
qwen25vl_encoder_path_input,
z_image_turbo_dit_path_input,
z_image_turbo_vae_path_input,
z_image_turbo_scheduler_path_input,
qwen3_encoder_path_input,
aspect_ratio,
qwen_image_use_lora_input,
qwen_image_lora_path_input,
qwen_image_lora_strength_input,
z_image_turbo_use_lora_input,
z_image_turbo_lora_path_input,
z_image_turbo_lora_strength_input,
],
outputs=[output_image],
)
"""模型选择相关函数模块"""
import os
from utils.model_utils import (
HF_AVAILABLE,
format_model_choice,
get_hf_models,
scan_model_path_contents,
sort_model_choices,
)
def _get_models_from_repos(repo_ids, prefix_filter=None, keyword_filter=None):
"""从多个仓库获取模型并合并"""
all_models = []
for repo_id in repo_ids:
models = get_hf_models(repo_id, prefix_filter=prefix_filter, keyword_filter=keyword_filter) if HF_AVAILABLE else []
all_models.extend(models)
return list(set(all_models))
def _filter_and_format_models(hf_models, model_path, is_valid_func, require_safetensors=True, additional_filters=None):
"""通用的模型过滤和格式化函数"""
# 延迟导入避免循环依赖
from utils.model_utils import is_fp8_supported_gpu
fp8_supported = is_fp8_supported_gpu()
# 筛选 HF 模型
valid_hf_models = []
for m in hf_models:
if not is_valid_func(m, fp8_supported):
continue
if require_safetensors:
if m.endswith(".safetensors") or "_split" in m.lower():
valid_hf_models.append(m)
else:
valid_hf_models.append(m)
# 检查本地已存在的模型
contents = scan_model_path_contents(model_path)
local_models = []
if require_safetensors:
dir_choices = [d for d in contents["dirs"] if is_valid_func(d, fp8_supported) and ("_split" in d.lower() or d in contents["safetensors_dirs"])]
safetensors_choices = [f for f in contents["files"] if f.endswith(".safetensors") and is_valid_func(f, fp8_supported)]
safetensors_dir_choices = [d for d in contents["safetensors_dirs"] if is_valid_func(d, fp8_supported)]
local_models = dir_choices + safetensors_choices + safetensors_dir_choices
else:
for item in contents["dirs"] + contents["files"]:
if is_valid_func(item, fp8_supported):
local_models.append(item)
# 应用额外过滤器
if additional_filters:
valid_hf_models = [m for m in valid_hf_models if additional_filters(m)]
local_models = [m for m in local_models if additional_filters(m)]
# 合并 HF 和本地模型,去重,并按优先级排序
all_models = sort_model_choices(list(set(valid_hf_models + local_models)))
# 格式化选项,添加下载状态
formatted_choices = [format_model_choice(m, model_path) for m in all_models]
return formatted_choices if formatted_choices else [""]
def get_dit_choices(model_path, model_type="wan2.1", task_type=None, is_distill=None):
"""获取 Diffusion 模型可选项"""
excluded_keywords = ["vae", "tae", "clip", "t5", "high_noise", "low_noise"]
# 根据模型类型和是否 distill 选择仓库
if model_type == "wan2.1":
if is_distill is True:
repo_ids = ["lightx2v/wan2.1-Distill-Models"]
elif is_distill is False:
repo_ids = ["lightx2v/wan2.1-Official-Models"]
else:
repo_ids = ["lightx2v/wan2.1-Distill-Models", "lightx2v/wan2.1-Official-Models"]
else: # wan2.2
if is_distill is True:
repo_ids = ["lightx2v/wan2.2-Distill-Models"]
elif is_distill is False:
repo_ids = ["lightx2v/wan2.2-Official-Models"]
else:
repo_ids = ["lightx2v/wan2.2-Distill-Models", "lightx2v/wan2.2-Official-Models"]
hf_models = _get_models_from_repos(repo_ids, prefix_filter=model_type)
def is_valid(name, fp8_supported):
name_lower = name.lower()
if "comfyui" in name_lower:
return False
if model_type == "wan2.1":
if "wan2.1" not in name_lower:
return False
else:
if "wan2.2" not in name_lower:
return False
if task_type and task_type.lower() not in name_lower:
return False
if not fp8_supported and "fp8" in name_lower:
return False
return not any(kw in name_lower for kw in excluded_keywords)
return _filter_and_format_models(hf_models, model_path, is_valid)
def get_high_noise_choices(model_path, model_type="wan2.2", task_type=None, is_distill=None):
"""获取高噪模型可选项"""
if is_distill is True:
repo_ids = ["lightx2v/wan2.2-Distill-Models"]
elif is_distill is False:
repo_ids = ["lightx2v/wan2.2-Official-Models"]
else:
repo_ids = ["lightx2v/wan2.2-Distill-Models", "lightx2v/wan2.2-Official-Models"]
hf_models = _get_models_from_repos(repo_ids, keyword_filter="high_noise")
def is_valid(name, fp8_supported):
name_lower = name.lower()
if "comfyui" in name_lower:
return False
if model_type.lower() not in name_lower:
return False
if task_type and task_type.lower() not in name_lower:
return False
if not fp8_supported and "fp8" in name_lower:
return False
return "high_noise" in name_lower or "high-noise" in name_lower
return _filter_and_format_models(hf_models, model_path, is_valid)
def get_low_noise_choices(model_path, model_type="wan2.2", task_type=None, is_distill=None):
"""获取低噪模型可选项"""
if is_distill is True:
repo_ids = ["lightx2v/wan2.2-Distill-Models"]
elif is_distill is False:
repo_ids = ["lightx2v/wan2.2-Official-Models"]
else:
repo_ids = ["lightx2v/wan2.2-Distill-Models", "lightx2v/wan2.2-Official-Models"]
hf_models = _get_models_from_repos(repo_ids, keyword_filter="low_noise")
def is_valid(name, fp8_supported):
name_lower = name.lower()
if "comfyui" in name_lower:
return False
if model_type.lower() not in name_lower:
return False
if task_type and task_type.lower() not in name_lower:
return False
if not fp8_supported and "fp8" in name_lower:
return False
return "low_noise" in name_lower or "low-noise" in name_lower
return _filter_and_format_models(hf_models, model_path, is_valid)
def get_t5_model_choices(model_path):
"""获取 T5 模型可选项"""
repo_id = "lightx2v/Encoders"
hf_models = get_hf_models(repo_id) if HF_AVAILABLE else []
def is_valid(name, fp8_supported):
name_lower = name.lower()
if "comfyui" in name_lower or name == "google":
return False
if not fp8_supported and "fp8" in name_lower:
return False
return ("t5" in name_lower) and name.endswith(".safetensors")
return _filter_and_format_models(hf_models, model_path, is_valid)
def get_t5_tokenizer_choices(model_path):
"""获取 T5 Tokenizer 可选项"""
contents = scan_model_path_contents(model_path)
dir_choices = ["google"] if "google" in contents["dirs"] else []
repo_id = "lightx2v/Encoders"
hf_models = get_hf_models(repo_id) if HF_AVAILABLE else []
hf_google = ["google"] if "google" in hf_models else []
all_models = sorted(set(hf_google + dir_choices))
formatted_choices = [format_model_choice(m, model_path) for m in all_models]
return formatted_choices if formatted_choices else [""]
def get_clip_model_choices(model_path):
"""获取 CLIP 模型可选项"""
repo_id = "lightx2v/Encoders"
hf_models = get_hf_models(repo_id) if HF_AVAILABLE else []
def is_valid(name, fp8_supported):
name_lower = name.lower()
if "comfyui" in name_lower or name == "xlm-roberta-large":
return False
if not fp8_supported and "fp8" in name_lower:
return False
return ("clip" in name_lower) and name.endswith(".safetensors")
return _filter_and_format_models(hf_models, model_path, is_valid)
def get_clip_tokenizer_choices(model_path):
"""获取 CLIP Tokenizer 可选项"""
contents = scan_model_path_contents(model_path)
dir_choices = ["xlm-roberta-large"] if "xlm-roberta-large" in contents["dirs"] else []
repo_id = "lightx2v/Encoders"
hf_models = get_hf_models(repo_id) if HF_AVAILABLE else []
hf_xlm = ["xlm-roberta-large"] if "xlm-roberta-large" in hf_models else []
all_models = sorted(set(hf_xlm + dir_choices))
formatted_choices = [format_model_choice(m, model_path) for m in all_models]
return formatted_choices if formatted_choices else [""]
def get_vae_choices(model_path):
"""获取 VAE 编码器可选项,只返回 Wan2.1_VAE.safetensors"""
encoder_name = "Wan2.1_VAE.safetensors"
return [format_model_choice(encoder_name, model_path)]
def get_qwen_image_dit_choices(model_path):
"""获取 Qwen Image Edit Diffusion 模型可选项
注意:Diffusion 模型只从 model_path/Qwen-Image-Edit-2511 目录检索,如果不存在可以从 HF/MS 下载到子目录
"""
repo_id = "lightx2v/Qwen-Image-Edit-2511-Lightning"
model_subdir = "Qwen-Image-Edit-2511"
hf_models = get_hf_models(repo_id) if HF_AVAILABLE else []
def is_valid(name, fp8_supported):
name_lower = name.lower()
if "comfyui" in name_lower:
return False
if not fp8_supported and "fp8" in name_lower:
return False
if "qwen_image_edit_2511" not in name_lower:
return False
return name.endswith("lightning.safetensors") or name.endswith("_split") or "lightning_split" in name_lower
# 筛选 HF 模型
from utils.model_utils import is_fp8_supported_gpu
fp8_supported = is_fp8_supported_gpu()
valid_hf_models = [m for m in hf_models if is_valid(m, fp8_supported)]
# 本地扫描:只从 model_path/Qwen-Image-Edit-2511 检索,不从 model_path 根目录检索
subdir_path = os.path.join(model_path, model_subdir)
local_models = []
if os.path.exists(subdir_path) and os.path.isdir(subdir_path):
subdir_contents = scan_model_path_contents(subdir_path)
# 检查子目录中的文件和目录
for item in subdir_contents["files"] + subdir_contents["dirs"]:
if is_valid(item, fp8_supported):
# 添加子目录路径前缀
local_models.append(f"{model_subdir}/{item}")
# 合并 HF 和本地模型,添加子目录路径前缀
all_models = []
for m in valid_hf_models:
# HF 模型也需要添加子目录路径,确保下载到子目录
all_models.append(f"{model_subdir}/{m}")
all_models.extend(local_models)
all_models = sorted(set(all_models))
# 如果没有找到,返回空列表(让用户知道需要下载)
if not all_models:
# 至少返回一个选项,用于下载
all_models = [f"{model_subdir}/"] if valid_hf_models else []
formatted_choices = [format_model_choice(m, model_path) for m in all_models]
return formatted_choices if formatted_choices else [""]
def get_qwen_image_vae_choices(model_path):
"""获取 Qwen Image Edit VAE 可选项
注意:VAE 只从 model_path/Qwen-Image-Edit-2511/vae 目录检索,如果不存在可以从 HF/MS 下载到子目录
"""
repo_id = "Qwen/Qwen-Image-Edit-2511"
model_subdir = "Qwen-Image-Edit-2511"
# 从 HF/MS 缓存获取模型列表(get_hf_models 实际上支持 MS,会先尝试 MS)
remote_models = get_hf_models(repo_id) if HF_AVAILABLE else []
# 从远程仓库获取的模型选项,需要添加子目录路径
valid_remote_models = []
for m in remote_models:
if m.lower() == "vae":
# 添加子目录路径,确保下载到子目录
valid_remote_models.append(f"{model_subdir}/vae")
# 本地扫描:只从 model_path/Qwen-Image-Edit-2511/vae 检索,不从 model_path 根目录检索
subdir_path = os.path.join(model_path, model_subdir)
local_models = []
# 检查子目录是否存在
if os.path.exists(subdir_path) and os.path.isdir(subdir_path):
subdir_contents = scan_model_path_contents(subdir_path)
if "vae" in [d.lower() for d in subdir_contents["dirs"]]:
local_models.append(f"{model_subdir}/vae")
all_models = sorted(set(valid_remote_models + local_models))
# 如果没有找到,返回默认选项(包含子目录路径,用于下载)
if not all_models:
all_models = [f"{model_subdir}/vae"]
formatted_choices = [format_model_choice(m, model_path) for m in all_models]
return formatted_choices
def get_qwen_image_scheduler_choices(model_path):
"""获取 Qwen Image Edit Scheduler 可选项
注意:Scheduler 只从 model_path/Qwen-Image-Edit-2511/scheduler 目录检索,如果不存在可以从 HF/MS 下载到子目录
"""
repo_id = "Qwen/Qwen-Image-Edit-2511"
model_subdir = "Qwen-Image-Edit-2511"
# 从 HF/MS 缓存获取模型列表(get_hf_models 实际上支持 MS,会先尝试 MS)
remote_models = get_hf_models(repo_id) if HF_AVAILABLE else []
# 从远程仓库获取的模型选项,需要添加子目录路径
valid_remote_models = []
for m in remote_models:
if m.lower() == "scheduler":
# 添加子目录路径,确保下载到子目录
valid_remote_models.append(f"{model_subdir}/scheduler")
# 本地扫描:只从 model_path/Qwen-Image-Edit-2511/scheduler 检索,不从 model_path 根目录检索
subdir_path = os.path.join(model_path, model_subdir)
local_models = []
# 检查子目录是否存在
if os.path.exists(subdir_path) and os.path.isdir(subdir_path):
subdir_contents = scan_model_path_contents(subdir_path)
if "scheduler" in [d.lower() for d in subdir_contents["dirs"]]:
local_models.append(f"{model_subdir}/scheduler")
all_models = sorted(set(valid_remote_models + local_models))
# 如果没有找到,返回默认选项(包含子目录路径,用于下载)
if not all_models:
all_models = [f"{model_subdir}/scheduler"]
formatted_choices = [format_model_choice(m, model_path) for m in all_models]
return formatted_choices
def get_qwen25vl_encoder_choices(model_path):
"""获取 Qwen25-VL 编码器可选项"""
repo_id = "lightx2v/Encoders"
hf_models = get_hf_models(repo_id) if HF_AVAILABLE else []
valid_hf_models = [m for m in hf_models if "qwen25-vl-4bit-gptq" in m.lower() or "qwen25_vl_4bit_gptq" in m.lower()]
contents = scan_model_path_contents(model_path)
local_models = [d for d in contents["dirs"] if "qwen25-vl-4bit-gptq" in d.lower() or "qwen25_vl_4bit_gptq" in d.lower()]
all_models = sorted(set(valid_hf_models + local_models))
formatted_choices = [format_model_choice(m, model_path) for m in all_models]
return formatted_choices if formatted_choices else [""]
def get_qwen_image_2512_dit_choices(model_path):
"""获取 Qwen Image 2512 Diffusion 模型可选项
注意:Diffusion 模型只从 model_path/Qwen-Image-2512 目录检索,如果不存在可以从 HF/MS 下载到子目录
只显示包含 "qwen_image_2512" 字段的模型
"""
repo_id_lightning = "lightx2v/Qwen-Image-2512-Lightning"
repo_id_official = "Qwen/Qwen-Image-2512"
model_subdir = "Qwen-Image-2512"
hf_models_lightning = get_hf_models(repo_id_lightning) if HF_AVAILABLE else []
hf_models_official = get_hf_models(repo_id_official) if HF_AVAILABLE else []
# 合并两个仓库的模型
hf_models = hf_models_lightning + hf_models_official
def is_valid(name, fp8_supported):
name_lower = name.lower()
if "comfyui" in name_lower:
return False
if not fp8_supported and "fp8" in name_lower:
return False
# 只包含 "qwen_image_2512" 字段的模型
if "qwen_image_2512" not in name_lower:
return False
return name.endswith("lightning.safetensors") or name.endswith("_split") or "lightning_split" in name_lower or name.endswith(".safetensors")
# 筛选 HF 模型
from utils.model_utils import is_fp8_supported_gpu
fp8_supported = is_fp8_supported_gpu()
valid_hf_models = [m for m in hf_models if is_valid(m, fp8_supported)]
# 本地扫描:只从 model_path/Qwen-Image-2512 检索,不从 model_path 根目录检索
subdir_path = os.path.join(model_path, model_subdir)
local_models = []
if os.path.exists(subdir_path) and os.path.isdir(subdir_path):
subdir_contents = scan_model_path_contents(subdir_path)
# 检查子目录中的文件和目录
for item in subdir_contents["files"] + subdir_contents["dirs"]:
if is_valid(item, fp8_supported):
# 添加子目录路径前缀
local_models.append(f"{model_subdir}/{item}")
# 合并 HF 和本地模型,添加子目录路径前缀
all_models = []
for m in valid_hf_models:
# HF 模型也需要添加子目录路径,确保下载到子目录
all_models.append(f"{model_subdir}/{m}")
all_models.extend(local_models)
all_models = sorted(set(all_models))
# 如果没有找到,返回空列表(让用户知道需要下载)
if not all_models:
# 至少返回一个选项,用于下载
all_models = [f"{model_subdir}/"] if valid_hf_models else []
formatted_choices = [format_model_choice(m, model_path) for m in all_models]
return formatted_choices if formatted_choices else [""]
def get_qwen_image_2512_vae_choices(model_path):
"""获取 Qwen Image 2512 VAE 可选项
注意:VAE 只从 model_path/Qwen-Image-2512/vae 目录检索,如果不存在可以从 HF/MS 下载到子目录
"""
repo_id = "Qwen/Qwen-Image-2512"
model_subdir = "Qwen-Image-2512"
# 从 HF/MS 缓存获取模型列表(get_hf_models 实际上支持 MS,会先尝试 MS)
remote_models = get_hf_models(repo_id) if HF_AVAILABLE else []
# 从远程仓库获取的模型选项,需要添加子目录路径
valid_remote_models = []
for m in remote_models:
if m.lower() == "vae":
# 添加子目录路径,确保下载到子目录
valid_remote_models.append(f"{model_subdir}/vae")
# 本地扫描:只从 model_path/Qwen-Image-2512/vae 检索,不从 model_path 根目录检索
subdir_path = os.path.join(model_path, model_subdir)
local_models = []
# 检查子目录是否存在
if os.path.exists(subdir_path) and os.path.isdir(subdir_path):
subdir_contents = scan_model_path_contents(subdir_path)
if "vae" in [d.lower() for d in subdir_contents["dirs"]]:
local_models.append(f"{model_subdir}/vae")
all_models = sorted(set(valid_remote_models + local_models))
# 如果没有找到,返回默认选项(包含子目录路径,用于下载)
if not all_models:
all_models = [f"{model_subdir}/vae"]
formatted_choices = [format_model_choice(m, model_path) for m in all_models]
return formatted_choices
def get_qwen_image_2512_scheduler_choices(model_path):
"""获取 Qwen Image 2512 Scheduler 可选项
注意:Scheduler 只从 model_path/Qwen-Image-2512/scheduler 目录检索,如果不存在可以从 HF/MS 下载到子目录
"""
repo_id = "Qwen/Qwen-Image-2512"
model_subdir = "Qwen-Image-2512"
# 从 HF/MS 缓存获取模型列表(get_hf_models 实际上支持 MS,会先尝试 MS)
remote_models = get_hf_models(repo_id) if HF_AVAILABLE else []
# 从远程仓库获取的模型选项,需要添加子目录路径
valid_remote_models = []
for m in remote_models:
if m.lower() == "scheduler":
# 添加子目录路径,确保下载到子目录
valid_remote_models.append(f"{model_subdir}/scheduler")
# 本地扫描:只从 model_path/Qwen-Image-2512/scheduler 检索,不从 model_path 根目录检索
subdir_path = os.path.join(model_path, model_subdir)
local_models = []
# 检查子目录是否存在
if os.path.exists(subdir_path) and os.path.isdir(subdir_path):
subdir_contents = scan_model_path_contents(subdir_path)
if "scheduler" in [d.lower() for d in subdir_contents["dirs"]]:
local_models.append(f"{model_subdir}/scheduler")
all_models = sorted(set(valid_remote_models + local_models))
# 如果没有找到,返回默认选项(包含子目录路径,用于下载)
if not all_models:
all_models = [f"{model_subdir}/scheduler"]
formatted_choices = [format_model_choice(m, model_path) for m in all_models]
return formatted_choices
def get_z_image_turbo_dit_choices(model_path):
"""获取 Z-Image-Turbo Diffusion 模型可选项
注意:Diffusion 模型只从 model_path/Z-Image-Turbo 目录检索,如果不存在可以从 HF/MS 下载到子目录
"""
repo_id = "lightx2v/Z-Image-Turbo-Quantized"
model_subdir = "Z-Image-Turbo"
hf_models = get_hf_models(repo_id) if HF_AVAILABLE else []
def is_valid(name, fp8_supported):
name_lower = name.lower()
if "comfyui" in name_lower:
return False
if not fp8_supported and "fp8" in name_lower:
return False
if "z_image_turbo" not in name_lower and "z-image-turbo" not in name_lower:
return False
return name.endswith(".safetensors") or name.endswith("_split") or "_split" in name_lower
# 筛选 HF 模型
from utils.model_utils import is_fp8_supported_gpu
fp8_supported = is_fp8_supported_gpu()
valid_hf_models = [m for m in hf_models if is_valid(m, fp8_supported)]
# 本地扫描:只从 model_path/Z-Image-Turbo 检索,不从 model_path 根目录检索
subdir_path = os.path.join(model_path, model_subdir)
local_models = []
if os.path.exists(subdir_path) and os.path.isdir(subdir_path):
subdir_contents = scan_model_path_contents(subdir_path)
# 检查子目录中的文件和目录
for item in subdir_contents["files"] + subdir_contents["dirs"]:
if is_valid(item, fp8_supported):
# 添加子目录路径前缀
local_models.append(f"{model_subdir}/{item}")
# 合并 HF 和本地模型,添加子目录路径前缀
all_models = []
for m in valid_hf_models:
# HF 模型也需要添加子目录路径,确保下载到子目录
all_models.append(f"{model_subdir}/{m}")
all_models.extend(local_models)
all_models = sorted(set(all_models))
# 如果没有找到,返回空列表(让用户知道需要下载)
if not all_models:
# 至少返回一个选项,用于下载
all_models = [f"{model_subdir}/"] if valid_hf_models else []
formatted_choices = [format_model_choice(m, model_path) for m in all_models]
return formatted_choices if formatted_choices else [""]
def get_z_image_turbo_vae_choices(model_path):
"""获取 Z-Image-Turbo VAE 可选项
注意:VAE 只从 model_path/Z-Image-Turbo/vae 目录检索,如果不存在可以从 HF/MS 下载到子目录
"""
repo_id = "Tongyi-MAI/Z-Image-Turbo"
model_subdir = "Z-Image-Turbo"
# 从 HF/MS 缓存获取模型列表(get_hf_models 实际上支持 MS,会先尝试 MS)
remote_models = get_hf_models(repo_id) if HF_AVAILABLE else []
# 从远程仓库获取的模型选项,需要添加子目录路径
valid_remote_models = []
for m in remote_models:
if m.lower() == "vae":
# 添加子目录路径,确保下载到子目录
valid_remote_models.append(f"{model_subdir}/vae")
# 本地扫描:只从 model_path/Z-Image-Turbo/vae 检索,不从 model_path 根目录检索
subdir_path = os.path.join(model_path, model_subdir)
local_models = []
# 检查子目录是否存在
if os.path.exists(subdir_path) and os.path.isdir(subdir_path):
subdir_contents = scan_model_path_contents(subdir_path)
if "vae" in [d.lower() for d in subdir_contents["dirs"]]:
local_models.append(f"{model_subdir}/vae")
all_models = sorted(set(valid_remote_models + local_models))
# 如果没有找到,返回默认选项(包含子目录路径,用于下载)
if not all_models:
all_models = [f"{model_subdir}/vae"]
formatted_choices = [format_model_choice(m, model_path) for m in all_models]
return formatted_choices
def get_z_image_turbo_scheduler_choices(model_path):
"""获取 Z-Image-Turbo Scheduler 可选项
注意:Scheduler 只从 model_path/Z-Image-Turbo/scheduler 目录检索,如果不存在可以从 HF/MS 下载到子目录
"""
repo_id = "Tongyi-MAI/Z-Image-Turbo"
model_subdir = "Z-Image-Turbo"
# 从 HF/MS 缓存获取模型列表(get_hf_models 实际上支持 MS,会先尝试 MS)
remote_models = get_hf_models(repo_id) if HF_AVAILABLE else []
# 从远程仓库获取的模型选项,需要添加子目录路径
valid_remote_models = []
for m in remote_models:
if m.lower() == "scheduler":
# 添加子目录路径,确保下载到子目录
valid_remote_models.append(f"{model_subdir}/scheduler")
# 本地扫描:只从 model_path/Z-Image-Turbo/scheduler 检索,不从 model_path 根目录检索
subdir_path = os.path.join(model_path, model_subdir)
local_models = []
# 检查子目录是否存在
if os.path.exists(subdir_path) and os.path.isdir(subdir_path):
subdir_contents = scan_model_path_contents(subdir_path)
if "scheduler" in [d.lower() for d in subdir_contents["dirs"]]:
local_models.append(f"{model_subdir}/scheduler")
all_models = sorted(set(valid_remote_models + local_models))
# 如果没有找到,返回默认选项(包含子目录路径,用于下载)
if not all_models:
all_models = [f"{model_subdir}/scheduler"]
formatted_choices = [format_model_choice(m, model_path) for m in all_models]
return formatted_choices
def get_qwen3_encoder_choices(model_path):
"""获取 Qwen3 编码器可选项
注意:Qwen3 编码器只检测整个仓库目录(Qwen3-4B-GPTQ-Int4)是否存在,不存在可以下载
"""
repo_id = "JunHowie/Qwen3-4B-GPTQ-Int4"
repo_name = repo_id.split("/")[-1] # "Qwen3-4B-GPTQ-Int4"
# 只检测整个仓库目录是否存在
repo_path = os.path.join(model_path, repo_name)
if os.path.exists(repo_path) and os.path.isdir(repo_path):
# 如果目录存在,返回目录名
return [format_model_choice(repo_name, model_path)]
else:
# 如果目录不存在,也返回目录名(用于下载)
return [format_model_choice(repo_name, model_path)]
"""模型组件构建模块
为不同的模型类型创建 UI 组件
这些函数需要在 Gradio 的 with 块内调用
"""
import os
import gradio as gr
from utils.i18n import t
from utils.model_choices import (
get_dit_choices,
get_high_noise_choices,
get_low_noise_choices,
get_qwen3_encoder_choices,
get_qwen25vl_encoder_choices,
get_qwen_image_2512_dit_choices,
get_qwen_image_2512_scheduler_choices,
get_qwen_image_2512_vae_choices,
get_qwen_image_dit_choices,
get_qwen_image_scheduler_choices,
get_qwen_image_vae_choices,
get_z_image_turbo_dit_choices,
get_z_image_turbo_scheduler_choices,
get_z_image_turbo_vae_choices,
)
from utils.model_handlers import (
download_model_handler,
update_model_status,
)
def get_lora_choices(model_path):
"""获取 LoRA 模型可选项,从 model_path/loras 目录检索
Args:
model_path: 模型根路径
Returns:
list: LoRA 文件列表,如果没有则返回 [""]
"""
loras_dir = os.path.join(model_path, "loras")
if not os.path.exists(loras_dir):
return [""]
lora_files = []
# 支持常见的 LoRA 文件格式
lora_extensions = [".safetensors", ".pt", ".pth", ".ckpt"]
for item in os.listdir(loras_dir):
item_path = os.path.join(loras_dir, item)
if os.path.isfile(item_path):
# 检查是否是 LoRA 文件
if any(item.lower().endswith(ext) for ext in lora_extensions):
lora_files.append(item)
# 按文件名排序
lora_files.sort()
return lora_files if lora_files else [""]
def build_wan21_components(model_path, model_path_input, model_type_input, task_type_input, download_source_input, update_funcs, download_funcs, lang="zh"):
"""构建 wan2.1 模型相关组件
必须在 Gradio 的 with 块内调用
Returns:
dict: 包含所有 wan2.1 相关组件的字典
"""
# wan2.1:Diffusion模型
with gr.Column(elem_classes=["diffusion-model-group"]) as wan21_row:
with gr.Row():
with gr.Column(scale=5):
dit_choices_init = get_dit_choices(model_path, "wan2.1", "i2v")
dit_path_input = gr.Dropdown(
label="🎨 Diffusion模型",
choices=dit_choices_init,
value=dit_choices_init[0] if dit_choices_init else "",
allow_custom_value=True,
visible=True,
)
with gr.Column(scale=1, min_width=150):
# 初始化时检查模型状态,确定下载按钮的初始可见性
from utils.model_utils import check_model_exists, extract_model_name
dit_btn_visible = False
if dit_choices_init:
first_choice = dit_choices_init[0]
actual_name = extract_model_name(first_choice)
dit_exists = check_model_exists(model_path, actual_name)
dit_btn_visible = not dit_exists
dit_download_btn = gr.Button("📥 下载", visible=dit_btn_visible, size="sm", variant="secondary")
dit_download_status = gr.Markdown("", visible=False)
lora_choices_init = get_lora_choices(model_path)
with gr.Row():
with gr.Column(scale=1):
use_lora = gr.Checkbox(
label=t("use_lora", lang),
value=False,
)
lora_path_input = gr.Dropdown(
label=t("lora", lang),
choices=lora_choices_init,
value=lora_choices_init[0] if lora_choices_init and lora_choices_init[0] else "",
allow_custom_value=True,
visible=False,
info=t("lora_info", lang),
)
lora_strength = gr.Slider(
label=t("lora_strength", lang),
minimum=0.0,
maximum=10.0,
step=0.1,
value=1.0,
visible=False,
info=t("lora_strength_info", lang),
)
# 绑定事件
update_dit_status = update_funcs["update_dit_status"]
download_dit_model = download_funcs["download_dit_model"]
dit_path_input.change(
fn=lambda mp, mn, mt: update_dit_status(mp, mn, mt),
inputs=[model_path_input, dit_path_input, model_type_input],
outputs=[dit_download_btn],
)
dit_download_btn.click(
fn=download_dit_model,
inputs=[model_path_input, dit_path_input, model_type_input, task_type_input, download_source_input],
outputs=[dit_download_status, dit_download_btn, dit_path_input],
)
# LoRA 开关变化时显示/隐藏相关组件
def on_use_lora_change(use_lora_val):
return (
gr.update(visible=use_lora_val), # lora_path_input
gr.update(visible=use_lora_val), # lora_strength
)
use_lora.change(
fn=on_use_lora_change,
inputs=[use_lora],
outputs=[lora_path_input, lora_strength],
)
# 当 model_path 变化时更新 LoRA 选择
def update_lora_choices(model_path_val):
lora_choices = get_lora_choices(model_path_val)
return gr.update(choices=lora_choices, value=lora_choices[0] if lora_choices and lora_choices[0] else "")
model_path_input.change(
fn=update_lora_choices,
inputs=[model_path_input],
outputs=[lora_path_input],
)
return {
"wan21_row": wan21_row,
"dit_path_input": dit_path_input,
"dit_download_btn": dit_download_btn,
"dit_download_status": dit_download_status,
"use_lora": use_lora,
"lora_path_input": lora_path_input,
"lora_strength": lora_strength,
}
def build_wan22_components(model_path, model_path_input, task_type_input, download_source_input, update_funcs, download_funcs, lang="zh"):
"""构建 wan2.2 模型相关组件
必须在 Gradio 的 with 块内调用
Returns:
dict: 包含所有 wan2.2 相关组件的字典
"""
# wan2.2 专用:高噪模型 + 低噪模型
with gr.Row(visible=False, elem_classes=["wan22-row"]) as wan22_row:
with gr.Column(scale=1):
high_noise_choices_init = get_high_noise_choices(model_path, "wan2.2", "i2v")
high_noise_path_input = gr.Dropdown(
label="🔊 高噪模型",
choices=high_noise_choices_init,
value=high_noise_choices_init[0] if high_noise_choices_init else "",
allow_custom_value=True,
)
# 初始化时检查模型状态
from utils.model_utils import check_model_exists, extract_model_name
high_noise_btn_visible = False
if high_noise_choices_init:
first_choice = high_noise_choices_init[0]
actual_name = extract_model_name(first_choice)
high_noise_exists = check_model_exists(model_path, actual_name)
high_noise_btn_visible = not high_noise_exists
high_noise_download_btn = gr.Button("📥 下载", visible=high_noise_btn_visible, size="sm", variant="secondary")
high_noise_download_status = gr.Markdown("", visible=False)
with gr.Column(scale=1):
low_noise_choices_init = get_low_noise_choices(model_path, "wan2.2", "i2v")
low_noise_path_input = gr.Dropdown(
label="🔇 低噪模型",
choices=low_noise_choices_init,
value=low_noise_choices_init[0] if low_noise_choices_init else "",
allow_custom_value=True,
)
# 初始化时检查模型状态
low_noise_btn_visible = False
if low_noise_choices_init:
first_choice = low_noise_choices_init[0]
actual_name = extract_model_name(first_choice)
low_noise_exists = check_model_exists(model_path, actual_name)
low_noise_btn_visible = not low_noise_exists
low_noise_download_btn = gr.Button("📥 下载", visible=low_noise_btn_visible, size="sm", variant="secondary")
low_noise_download_status = gr.Markdown("", visible=False)
# LoRA 组件(Wan2.2 需要为 high_noise 和 low_noise 分别配置)
lora_choices_init = get_lora_choices(model_path)
with gr.Row():
with gr.Column(scale=1):
use_lora = gr.Checkbox(
label=t("use_lora", lang),
value=False,
)
# High Noise LoRA
high_noise_lora_path_input = gr.Dropdown(
label=t("high_noise_lora", lang),
choices=lora_choices_init,
value=lora_choices_init[0] if lora_choices_init and lora_choices_init[0] else "",
allow_custom_value=True,
visible=False,
info=t("high_noise_lora_info", lang),
)
high_noise_lora_strength = gr.Slider(
label=t("high_noise_lora_strength", lang),
minimum=0.0,
maximum=10.0,
step=0.1,
value=1.0,
visible=False,
info=t("high_noise_lora_strength_info", lang),
)
# Low Noise LoRA
low_noise_lora_path_input = gr.Dropdown(
label=t("low_noise_lora", lang),
choices=lora_choices_init,
value=lora_choices_init[0] if lora_choices_init and lora_choices_init[0] else "",
allow_custom_value=True,
visible=False,
info=t("low_noise_lora_info", lang),
)
low_noise_lora_strength = gr.Slider(
label=t("low_noise_lora_strength", lang),
minimum=0.0,
maximum=10.0,
step=0.1,
value=1.0,
visible=False,
info=t("low_noise_lora_strength_info", lang),
)
# 绑定事件
update_high_noise_status = update_funcs["update_high_noise_status"]
update_low_noise_status = update_funcs["update_low_noise_status"]
download_high_noise_model = download_funcs["download_high_noise_model"]
download_low_noise_model = download_funcs["download_low_noise_model"]
high_noise_path_input.change(
fn=update_high_noise_status,
inputs=[model_path_input, high_noise_path_input],
outputs=[high_noise_download_btn],
)
low_noise_path_input.change(
fn=update_low_noise_status,
inputs=[model_path_input, low_noise_path_input],
outputs=[low_noise_download_btn],
)
high_noise_download_btn.click(
fn=download_high_noise_model,
inputs=[model_path_input, high_noise_path_input, task_type_input, download_source_input],
outputs=[high_noise_download_status, high_noise_download_btn, high_noise_path_input],
)
low_noise_download_btn.click(
fn=download_low_noise_model,
inputs=[model_path_input, low_noise_path_input, task_type_input, download_source_input],
outputs=[low_noise_download_status, low_noise_download_btn, low_noise_path_input],
)
# LoRA 开关变化时显示/隐藏相关组件
def on_use_lora_change(use_lora_val):
return (
gr.update(visible=use_lora_val), # high_noise_lora_path_input
gr.update(visible=use_lora_val), # high_noise_lora_strength
gr.update(visible=use_lora_val), # low_noise_lora_path_input
gr.update(visible=use_lora_val), # low_noise_lora_strength
)
use_lora.change(
fn=on_use_lora_change,
inputs=[use_lora],
outputs=[high_noise_lora_path_input, high_noise_lora_strength, low_noise_lora_path_input, low_noise_lora_strength],
)
# 当 model_path 变化时更新 LoRA 选择
def update_lora_choices(model_path_val):
lora_choices = get_lora_choices(model_path_val)
return (
gr.update(choices=lora_choices, value=lora_choices[0] if lora_choices and lora_choices[0] else ""), # high_noise_lora_path_input
gr.update(choices=lora_choices, value=lora_choices[0] if lora_choices and lora_choices[0] else ""), # low_noise_lora_path_input
)
model_path_input.change(
fn=update_lora_choices,
inputs=[model_path_input],
outputs=[high_noise_lora_path_input, low_noise_lora_path_input],
)
return {
"wan22_row": wan22_row,
"high_noise_path_input": high_noise_path_input,
"high_noise_download_btn": high_noise_download_btn,
"high_noise_download_status": high_noise_download_status,
"low_noise_path_input": low_noise_path_input,
"low_noise_download_btn": low_noise_download_btn,
"low_noise_download_status": low_noise_download_status,
"use_lora": use_lora,
"high_noise_lora_path_input": high_noise_lora_path_input,
"high_noise_lora_strength": high_noise_lora_strength,
"low_noise_lora_path_input": low_noise_lora_path_input,
"low_noise_lora_strength": low_noise_lora_strength,
}
def build_qwen_image_components(model_path, model_path_input, download_source_input, model_type_input, lang="zh"):
"""构建 Qwen-Image-Edit-2511/Qwen-Image-2512 模型相关组件(用于图片页面)
必须在 Gradio 的 with 块内调用
Args:
model_path: 模型路径
model_path_input: 模型路径输入组件
download_source_input: 下载源输入组件
model_type_input: 模型类型输入组件(用于动态更新)
lang: 语言代码,默认为 "zh"
Returns:
dict: 包含所有 Qwen-Image 相关组件的字典
"""
# 根据模型类型选择函数
def get_dit_choices_func(model_path_val):
model_type_val = model_type_input.value if hasattr(model_type_input, "value") else "Qwen-Image-Edit-2511"
if model_type_val == "Qwen-Image-2512":
return get_qwen_image_2512_dit_choices(model_path_val)
else:
return get_qwen_image_dit_choices(model_path_val)
def get_vae_choices_func(model_path_val):
model_type_val = model_type_input.value if hasattr(model_type_input, "value") else "Qwen-Image-Edit-2511"
if model_type_val == "Qwen-Image-2512":
return get_qwen_image_2512_vae_choices(model_path_val)
else:
return get_qwen_image_vae_choices(model_path_val)
def get_scheduler_choices_func(model_path_val):
model_type_val = model_type_input.value if hasattr(model_type_input, "value") else "Qwen-Image-Edit-2511"
if model_type_val == "Qwen-Image-2512":
return get_qwen_image_2512_scheduler_choices(model_path_val)
else:
return get_qwen_image_scheduler_choices(model_path_val)
qwen_image_dit_choices_init = get_dit_choices_func(model_path)
qwen_image_vae_choices_init = get_vae_choices_func(model_path)
qwen_image_scheduler_choices_init = get_scheduler_choices_func(model_path)
# 初始化时检查模型状态,确定下载按钮的初始可见性
def get_initial_download_btn_visibility(choices, model_path_val, model_type_val, base_category):
if not choices:
return False
first_choice = choices[0]
from utils.model_utils import check_model_exists, extract_model_name
actual_name = extract_model_name(first_choice)
exists = check_model_exists(model_path_val, actual_name)
return not exists
initial_model_type = model_type_input.value if hasattr(model_type_input, "value") else "Qwen-Image-Edit-2511"
dit_btn_visible = get_initial_download_btn_visibility(qwen_image_dit_choices_init, model_path, initial_model_type, "dit")
vae_btn_visible = get_initial_download_btn_visibility(qwen_image_vae_choices_init, model_path, initial_model_type, "vae")
scheduler_btn_visible = get_initial_download_btn_visibility(qwen_image_scheduler_choices_init, model_path, initial_model_type, "scheduler")
# Qwen-Image 模型配置
with gr.Column(elem_classes=["diffusion-model-group"]) as qwen_components_group:
# Diffusion 模型
with gr.Row():
with gr.Column(scale=5):
qwen_image_dit_path_input = gr.Dropdown(
label=t("diffusion_model", lang),
choices=qwen_image_dit_choices_init,
value=qwen_image_dit_choices_init[0] if qwen_image_dit_choices_init else "",
allow_custom_value=True,
)
with gr.Column(scale=1, min_width=150):
qwen_image_dit_download_btn = gr.Button(t("download", lang), visible=dit_btn_visible, size="sm", variant="secondary")
qwen_image_dit_download_status = gr.Markdown("", visible=False)
# VAE 和 Scheduler
with gr.Row():
with gr.Column(scale=1):
qwen_image_vae_path_input = gr.Dropdown(
label=t("vae", lang),
choices=qwen_image_vae_choices_init,
value=qwen_image_vae_choices_init[0] if qwen_image_vae_choices_init else "",
allow_custom_value=True,
)
qwen_image_vae_download_btn = gr.Button(t("download", lang), visible=vae_btn_visible, size="sm", variant="secondary")
qwen_image_vae_download_status = gr.Markdown("", visible=False)
with gr.Column(scale=1):
qwen_image_scheduler_path_input = gr.Dropdown(
label=t("scheduler", lang),
choices=qwen_image_scheduler_choices_init,
value=qwen_image_scheduler_choices_init[0] if qwen_image_scheduler_choices_init else "",
allow_custom_value=True,
)
qwen_image_scheduler_download_btn = gr.Button(t("download", lang), visible=scheduler_btn_visible, size="sm", variant="secondary")
qwen_image_scheduler_download_status = gr.Markdown("", visible=False)
# Qwen25-VL 编码器
with gr.Row():
with gr.Column(scale=1):
qwen25vl_encoder_choices_init = get_qwen25vl_encoder_choices(model_path)
qwen25vl_encoder_path_input = gr.Dropdown(
label=t("qwen25vl_encoder", lang),
choices=qwen25vl_encoder_choices_init,
value=qwen25vl_encoder_choices_init[0] if qwen25vl_encoder_choices_init else "",
allow_custom_value=True,
)
# 初始化时检查模型状态
from utils.model_utils import check_model_exists, extract_model_name
qwen25vl_btn_visible = False
if qwen25vl_encoder_choices_init:
first_choice = qwen25vl_encoder_choices_init[0]
actual_name = extract_model_name(first_choice)
qwen25vl_exists = check_model_exists(model_path, actual_name)
qwen25vl_btn_visible = not qwen25vl_exists
qwen25vl_encoder_download_btn = gr.Button(t("download", lang), visible=qwen25vl_btn_visible, size="sm", variant="secondary")
qwen25vl_encoder_download_status = gr.Markdown("", visible=False)
# LoRA 组件
lora_choices_init = get_lora_choices(model_path)
with gr.Row():
with gr.Column(scale=1):
use_lora = gr.Checkbox(
label=t("use_lora", lang),
value=False,
)
lora_path_input = gr.Dropdown(
label=t("lora", lang),
choices=lora_choices_init,
value=lora_choices_init[0] if lora_choices_init and lora_choices_init[0] else "",
allow_custom_value=True,
visible=False,
info=t("lora_info", lang),
)
lora_strength = gr.Slider(
label=t("lora_strength", lang),
minimum=0.0,
maximum=10.0,
step=0.1,
value=1.0,
visible=False,
info=t("lora_strength_info", lang),
)
# LoRA 开关变化时显示/隐藏相关组件
def on_use_lora_change(use_lora_val):
return (
gr.update(visible=use_lora_val), # lora_path_input
gr.update(visible=use_lora_val), # lora_strength
)
use_lora.change(
fn=on_use_lora_change,
inputs=[use_lora],
outputs=[lora_path_input, lora_strength],
)
# 当 model_path 变化时更新 LoRA 选择
def update_lora_choices(model_path_val):
lora_choices = get_lora_choices(model_path_val)
return gr.update(choices=lora_choices, value=lora_choices[0] if lora_choices and lora_choices[0] else "")
model_path_input.change(
fn=update_lora_choices,
inputs=[model_path_input],
outputs=[lora_path_input],
)
# 模型类型变化时更新选择
def update_choices_on_model_type_change(model_type_val, model_path_val):
if model_type_val == "Qwen-Image-2512":
dit_choices = get_qwen_image_2512_dit_choices(model_path_val)
vae_choices = get_qwen_image_2512_vae_choices(model_path_val)
scheduler_choices = get_qwen_image_2512_scheduler_choices(model_path_val)
else:
dit_choices = get_qwen_image_dit_choices(model_path_val)
vae_choices = get_qwen_image_vae_choices(model_path_val)
scheduler_choices = get_qwen_image_scheduler_choices(model_path_val)
return (
gr.update(choices=dit_choices, value=dit_choices[0] if dit_choices else ""),
gr.update(choices=vae_choices, value=vae_choices[0] if vae_choices else ""),
gr.update(choices=scheduler_choices, value=scheduler_choices[0] if scheduler_choices else ""),
)
# 绑定下载按钮事件
def download_qwen_image_dit(model_path_val, model_name, model_type_val, download_source_val, progress=gr.Progress()):
# 从模型名称中提取实际名称
from utils.model_utils import extract_model_name
actual_name = extract_model_name(model_name)
actual_name_lower = actual_name.lower()
# 如果模型名称包含 "qwen_image_2512",自动使用 qwen_image_2512_dit category
# 否则根据 model_type_val 判断
if "qwen_image_2512" in actual_name_lower:
category = "qwen_image_2512_dit"
get_choices = get_qwen_image_2512_dit_choices
model_type_val = "Qwen-Image-2512" # 确保使用正确的模型类型
elif model_type_val == "Qwen-Image-2512":
category = "qwen_image_2512_dit"
get_choices = get_qwen_image_2512_dit_choices
else:
category = "qwen_image_dit"
get_choices = get_qwen_image_dit_choices
return download_model_handler(model_path_val, model_name, category, download_source_val, get_choices_func=get_choices, model_type_val=model_type_val, progress=progress)
def download_qwen_image_vae(model_path_val, model_name, download_source_val, progress=gr.Progress()):
model_type_val = model_type_input.value if hasattr(model_type_input, "value") else "Qwen-Image-Edit-2511"
category = "qwen_image_2512_vae" if model_type_val == "Qwen-Image-2512" else "qwen_image_vae"
get_choices = get_qwen_image_2512_vae_choices if model_type_val == "Qwen-Image-2512" else get_qwen_image_vae_choices
return download_model_handler(model_path_val, model_name, category, download_source_val, get_choices_func=get_choices, progress=progress)
def download_qwen_image_scheduler(model_path_val, model_name, download_source_val, progress=gr.Progress()):
model_type_val = model_type_input.value if hasattr(model_type_input, "value") else "Qwen-Image-Edit-2511"
category = "qwen_image_2512_scheduler" if model_type_val == "Qwen-Image-2512" else "qwen_image_scheduler"
get_choices = get_qwen_image_2512_scheduler_choices if model_type_val == "Qwen-Image-2512" else get_qwen_image_scheduler_choices
return download_model_handler(model_path_val, model_name, category, download_source_val, get_choices_func=get_choices, progress=progress)
def download_qwen25vl_encoder(model_path_val, model_name, download_source_val, progress=gr.Progress()):
return download_model_handler(model_path_val, model_name, "qwen25vl_encoder", download_source_val, get_choices_func=get_qwen25vl_encoder_choices, progress=progress)
# 绑定状态更新和下载事件
def get_model_category(model_type_val, base_category):
if model_type_val == "Qwen-Image-2512":
return f"qwen_image_2512_{base_category}"
return f"qwen_image_{base_category}"
def update_dit_status(model_path_val, model_name, model_type_val):
category = get_model_category(model_type_val, "dit")
return update_model_status(model_path_val, model_name, category)
def update_vae_status(model_path_val, model_name, model_type_val):
category = get_model_category(model_type_val, "vae")
return update_model_status(model_path_val, model_name, category)
def update_scheduler_status(model_path_val, model_name, model_type_val):
category = get_model_category(model_type_val, "scheduler")
return update_model_status(model_path_val, model_name, category)
qwen_image_dit_path_input.change(
fn=lambda mp, mn, mt: update_dit_status(mp, mn, mt),
inputs=[model_path_input, qwen_image_dit_path_input, model_type_input],
outputs=[qwen_image_dit_download_btn],
)
qwen_image_vae_path_input.change(
fn=lambda mp, mn, mt: update_vae_status(mp, mn, mt),
inputs=[model_path_input, qwen_image_vae_path_input, model_type_input],
outputs=[qwen_image_vae_download_btn],
)
qwen_image_scheduler_path_input.change(
fn=lambda mp, mn, mt: update_scheduler_status(mp, mn, mt),
inputs=[model_path_input, qwen_image_scheduler_path_input, model_type_input],
outputs=[qwen_image_scheduler_download_btn],
)
qwen25vl_encoder_path_input.change(
fn=lambda mp, mn: update_model_status(mp, mn, "qwen25vl_encoder"),
inputs=[model_path_input, qwen25vl_encoder_path_input],
outputs=[qwen25vl_encoder_download_btn],
)
qwen_image_dit_download_btn.click(
fn=download_qwen_image_dit,
inputs=[model_path_input, qwen_image_dit_path_input, model_type_input, download_source_input],
outputs=[qwen_image_dit_download_status, qwen_image_dit_download_btn, qwen_image_dit_path_input],
)
qwen_image_vae_download_btn.click(
fn=download_qwen_image_vae,
inputs=[model_path_input, qwen_image_vae_path_input, download_source_input],
outputs=[qwen_image_vae_download_status, qwen_image_vae_download_btn, qwen_image_vae_path_input],
)
qwen_image_scheduler_download_btn.click(
fn=download_qwen_image_scheduler,
inputs=[model_path_input, qwen_image_scheduler_path_input, download_source_input],
outputs=[qwen_image_scheduler_download_status, qwen_image_scheduler_download_btn, qwen_image_scheduler_path_input],
)
qwen25vl_encoder_download_btn.click(
fn=download_qwen25vl_encoder,
inputs=[model_path_input, qwen25vl_encoder_path_input, download_source_input],
outputs=[qwen25vl_encoder_download_status, qwen25vl_encoder_download_btn, qwen25vl_encoder_path_input],
)
return {
"qwen_image_dit_path_input": qwen_image_dit_path_input,
"qwen_image_dit_download_btn": qwen_image_dit_download_btn,
"qwen_image_dit_download_status": qwen_image_dit_download_status,
"qwen_image_vae_path_input": qwen_image_vae_path_input,
"qwen_image_vae_download_btn": qwen_image_vae_download_btn,
"qwen_image_vae_download_status": qwen_image_vae_download_status,
"qwen_image_scheduler_path_input": qwen_image_scheduler_path_input,
"qwen_image_scheduler_download_btn": qwen_image_scheduler_download_btn,
"qwen_image_scheduler_download_status": qwen_image_scheduler_download_status,
"qwen25vl_encoder_path_input": qwen25vl_encoder_path_input,
"qwen25vl_encoder_download_btn": qwen25vl_encoder_download_btn,
"qwen25vl_encoder_download_status": qwen25vl_encoder_download_status,
"use_lora": use_lora,
"lora_path_input": lora_path_input,
"lora_strength": lora_strength,
"components_group": qwen_components_group,
}
def build_z_image_turbo_components(model_path, model_path_input, download_source_input, model_type_input, lang="zh"):
"""构建 Z-Image-Turbo 模型相关组件(用于图片页面)
必须在 Gradio 的 with 块内调用
Args:
model_path: 模型路径
model_path_input: 模型路径输入组件
download_source_input: 下载源输入组件
model_type_input: 模型类型输入组件(用于动态更新)
lang: 语言代码,默认为 "zh"
Returns:
dict: 包含所有 Z-Image-Turbo 相关组件的字典
"""
# 初始化选择
z_image_turbo_dit_choices_init = get_z_image_turbo_dit_choices(model_path)
z_image_turbo_vae_choices_init = get_z_image_turbo_vae_choices(model_path)
z_image_turbo_scheduler_choices_init = get_z_image_turbo_scheduler_choices(model_path)
qwen3_encoder_choices_init = get_qwen3_encoder_choices(model_path)
# Z-Image-Turbo 模型配置
with gr.Column(elem_classes=["diffusion-model-group"], visible=False) as z_image_turbo_components_group:
# Diffusion 模型
with gr.Row():
with gr.Column(scale=5):
z_image_turbo_dit_path_input = gr.Dropdown(
label=t("diffusion_model", lang),
choices=z_image_turbo_dit_choices_init,
value=z_image_turbo_dit_choices_init[0] if z_image_turbo_dit_choices_init else "",
allow_custom_value=True,
)
with gr.Column(scale=1, min_width=150):
# 计算初始按钮可见性
import os
initial_dit_btn_visible = False
if z_image_turbo_dit_choices_init:
initial_dit_value = z_image_turbo_dit_choices_init[0]
from utils.model_utils import check_model_exists, extract_model_name
actual_name = extract_model_name(initial_dit_value)
initial_dit_btn_visible = not check_model_exists(model_path, actual_name)
z_image_turbo_dit_download_btn = gr.Button(t("download", lang), visible=initial_dit_btn_visible, size="sm", variant="secondary")
z_image_turbo_dit_download_status = gr.Markdown("", visible=False)
# VAE 和 Scheduler
with gr.Row():
with gr.Column(scale=1):
z_image_turbo_vae_path_input = gr.Dropdown(
label=t("vae", lang),
choices=z_image_turbo_vae_choices_init,
value=z_image_turbo_vae_choices_init[0] if z_image_turbo_vae_choices_init else "",
allow_custom_value=True,
)
# 计算初始按钮可见性
initial_vae_btn_visible = False
if z_image_turbo_vae_choices_init:
initial_vae_value = z_image_turbo_vae_choices_init[0]
from utils.model_utils import check_model_exists, extract_model_name
actual_name = extract_model_name(initial_vae_value)
initial_vae_btn_visible = not check_model_exists(model_path, actual_name)
z_image_turbo_vae_download_btn = gr.Button(t("download", lang), visible=initial_vae_btn_visible, size="sm", variant="secondary")
z_image_turbo_vae_download_status = gr.Markdown("", visible=False)
with gr.Column(scale=1):
z_image_turbo_scheduler_path_input = gr.Dropdown(
label=t("scheduler", lang),
choices=z_image_turbo_scheduler_choices_init,
value=z_image_turbo_scheduler_choices_init[0] if z_image_turbo_scheduler_choices_init else "",
allow_custom_value=True,
)
# 计算初始按钮可见性
initial_scheduler_btn_visible = False
if z_image_turbo_scheduler_choices_init:
initial_scheduler_value = z_image_turbo_scheduler_choices_init[0]
from utils.model_utils import check_model_exists, extract_model_name
actual_name = extract_model_name(initial_scheduler_value)
initial_scheduler_btn_visible = not check_model_exists(model_path, actual_name)
z_image_turbo_scheduler_download_btn = gr.Button(t("download", lang), visible=initial_scheduler_btn_visible, size="sm", variant="secondary")
z_image_turbo_scheduler_download_status = gr.Markdown("", visible=False)
# Qwen3 编码器
with gr.Row():
with gr.Column(scale=1):
qwen3_encoder_path_input = gr.Dropdown(
label=t("qwen3_encoder", lang),
choices=qwen3_encoder_choices_init,
value=qwen3_encoder_choices_init[0] if qwen3_encoder_choices_init else "",
allow_custom_value=True,
)
# 计算初始按钮可见性:检查整个仓库目录是否存在
import os
initial_qwen3_btn_visible = False
if qwen3_encoder_choices_init:
repo_id = "JunHowie/Qwen3-4B-GPTQ-Int4"
repo_name = repo_id.split("/")[-1] # "Qwen3-4B-GPTQ-Int4"
repo_path = os.path.join(model_path, repo_name)
initial_qwen3_btn_visible = not (os.path.exists(repo_path) and os.path.isdir(repo_path))
qwen3_encoder_download_btn = gr.Button(t("download", lang), visible=initial_qwen3_btn_visible, size="sm", variant="secondary")
qwen3_encoder_download_status = gr.Markdown("", visible=False)
# LoRA 组件
lora_choices_init = get_lora_choices(model_path)
with gr.Row():
with gr.Column(scale=1):
use_lora = gr.Checkbox(
label=t("use_lora", lang),
value=False,
)
lora_path_input = gr.Dropdown(
label=t("lora", lang),
choices=lora_choices_init,
value=lora_choices_init[0] if lora_choices_init and lora_choices_init[0] else "",
allow_custom_value=True,
visible=False,
info=t("lora_info", lang),
)
lora_strength = gr.Slider(
label=t("lora_strength", lang),
minimum=0.0,
maximum=10.0,
step=0.1,
value=1.0,
visible=False,
info=t("lora_strength_info", lang),
)
# 绑定下载按钮事件
def download_z_image_turbo_dit(model_path_val, model_name, download_source_val, progress=gr.Progress()):
return download_model_handler(model_path_val, model_name, "z_image_turbo_dit", download_source_val, get_choices_func=get_z_image_turbo_dit_choices, progress=progress)
def download_z_image_turbo_vae(model_path_val, model_name, download_source_val, progress=gr.Progress()):
return download_model_handler(model_path_val, model_name, "z_image_turbo_vae", download_source_val, get_choices_func=get_z_image_turbo_vae_choices, progress=progress)
def download_z_image_turbo_scheduler(model_path_val, model_name, download_source_val, progress=gr.Progress()):
return download_model_handler(model_path_val, model_name, "z_image_turbo_scheduler", download_source_val, get_choices_func=get_z_image_turbo_scheduler_choices, progress=progress)
def download_qwen3_encoder(model_path_val, model_name, download_source_val, progress=gr.Progress()):
return download_model_handler(model_path_val, model_name, "qwen3_encoder", download_source_val, get_choices_func=get_qwen3_encoder_choices, progress=progress)
# 绑定状态更新和下载事件
def update_dit_status(model_path_val, model_name):
return update_model_status(model_path_val, model_name, "z_image_turbo_dit")
def update_vae_status(model_path_val, model_name):
return update_model_status(model_path_val, model_name, "z_image_turbo_vae")
def update_scheduler_status(model_path_val, model_name):
return update_model_status(model_path_val, model_name, "z_image_turbo_scheduler")
def update_qwen3_encoder_status(model_path_val, model_name):
return update_model_status(model_path_val, model_name, "qwen3_encoder")
z_image_turbo_dit_path_input.change(
fn=update_dit_status,
inputs=[model_path_input, z_image_turbo_dit_path_input],
outputs=[z_image_turbo_dit_download_btn],
)
z_image_turbo_vae_path_input.change(
fn=update_vae_status,
inputs=[model_path_input, z_image_turbo_vae_path_input],
outputs=[z_image_turbo_vae_download_btn],
)
z_image_turbo_scheduler_path_input.change(
fn=update_scheduler_status,
inputs=[model_path_input, z_image_turbo_scheduler_path_input],
outputs=[z_image_turbo_scheduler_download_btn],
)
# Qwen3 编码器状态更新:检查整个仓库目录是否存在
def update_qwen3_encoder_status(model_path_val, model_name):
return update_model_status(model_path_val, model_name, "qwen3_encoder")
qwen3_encoder_path_input.change(
fn=update_qwen3_encoder_status,
inputs=[model_path_input, qwen3_encoder_path_input],
outputs=[qwen3_encoder_download_btn],
)
z_image_turbo_dit_download_btn.click(
fn=download_z_image_turbo_dit,
inputs=[model_path_input, z_image_turbo_dit_path_input, download_source_input],
outputs=[z_image_turbo_dit_download_status, z_image_turbo_dit_download_btn, z_image_turbo_dit_path_input],
)
z_image_turbo_vae_download_btn.click(
fn=download_z_image_turbo_vae,
inputs=[model_path_input, z_image_turbo_vae_path_input, download_source_input],
outputs=[z_image_turbo_vae_download_status, z_image_turbo_vae_download_btn, z_image_turbo_vae_path_input],
)
z_image_turbo_scheduler_download_btn.click(
fn=download_z_image_turbo_scheduler,
inputs=[model_path_input, z_image_turbo_scheduler_path_input, download_source_input],
outputs=[z_image_turbo_scheduler_download_status, z_image_turbo_scheduler_download_btn, z_image_turbo_scheduler_path_input],
)
qwen3_encoder_download_btn.click(
fn=download_qwen3_encoder,
inputs=[model_path_input, qwen3_encoder_path_input, download_source_input],
outputs=[qwen3_encoder_download_status, qwen3_encoder_download_btn, qwen3_encoder_path_input],
)
# LoRA 开关变化时显示/隐藏相关组件
def on_use_lora_change(use_lora_val):
return (
gr.update(visible=use_lora_val), # lora_path_input
gr.update(visible=use_lora_val), # lora_strength
)
use_lora.change(
fn=on_use_lora_change,
inputs=[use_lora],
outputs=[lora_path_input, lora_strength],
)
# 当 model_path 变化时更新 LoRA 选择
def update_lora_choices(model_path_val):
lora_choices = get_lora_choices(model_path_val)
return gr.update(choices=lora_choices, value=lora_choices[0] if lora_choices and lora_choices[0] else "")
model_path_input.change(
fn=update_lora_choices,
inputs=[model_path_input],
outputs=[lora_path_input],
)
return {
"z_image_turbo_dit_path_input": z_image_turbo_dit_path_input,
"z_image_turbo_dit_download_btn": z_image_turbo_dit_download_btn,
"z_image_turbo_dit_download_status": z_image_turbo_dit_download_status,
"z_image_turbo_vae_path_input": z_image_turbo_vae_path_input,
"z_image_turbo_vae_download_btn": z_image_turbo_vae_download_btn,
"z_image_turbo_vae_download_status": z_image_turbo_vae_download_status,
"z_image_turbo_scheduler_path_input": z_image_turbo_scheduler_path_input,
"z_image_turbo_scheduler_download_btn": z_image_turbo_scheduler_download_btn,
"z_image_turbo_scheduler_download_status": z_image_turbo_scheduler_download_status,
"qwen3_encoder_path_input": qwen3_encoder_path_input,
"qwen3_encoder_download_btn": qwen3_encoder_download_btn,
"qwen3_encoder_download_status": qwen3_encoder_download_status,
"use_lora": use_lora,
"lora_path_input": lora_path_input,
"lora_strength": lora_strength,
"components_group": z_image_turbo_components_group,
}
import os
import shutil
import warnings
import gradio as gr
from loguru import logger
from utils.model_utils import check_model_exists, extract_model_name, format_model_choice, is_distill_model_from_name
try:
from huggingface_hub import hf_hub_download
from huggingface_hub import snapshot_download as hf_snapshot_download
HF_AVAILABLE = True
except ImportError:
HF_AVAILABLE = False
try:
from modelscope.hub.snapshot_download import snapshot_download as ms_snapshot_download
MS_AVAILABLE = True
except ImportError:
MS_AVAILABLE = False
def download_model_from_hf(repo_id, model_name, model_path, progress=gr.Progress(), download_entire_repo=False):
"""从 Hugging Face 下载模型(支持文件和目录)
Args:
repo_id: 仓库 ID
model_name: 模型名称(文件或目录名)
model_path: 模型保存路径
progress: Gradio 进度条
download_entire_repo: 是否下载整个仓库(用于 Qwen3 编码器等)
"""
if not HF_AVAILABLE:
return f"❌ huggingface_hub 未安装,无法下载模型"
progress(0, desc=f"开始从 Hugging Face 下载 {model_name}...")
logger.info(f"开始从 Hugging Face {repo_id} 下载 {model_name}{model_path}")
target_path = os.path.join(model_path, model_name)
# 确保目标路径的父目录存在(如果 model_name 包含子目录路径,如 "Z-Image-Turbo/vae")
os.makedirs(os.path.dirname(target_path) if "/" in model_name else model_path, exist_ok=True)
os.makedirs(model_path, exist_ok=True)
# 如果指定下载整个仓库,直接下载整个仓库到目标目录
if download_entire_repo:
progress(0.1, desc=f"下载整个仓库 {repo_id}...")
logger.info(f"下载整个仓库 {repo_id}{target_path}")
if os.path.exists(target_path):
shutil.rmtree(target_path)
with warnings.catch_warnings():
warnings.simplefilter("ignore")
hf_snapshot_download(
repo_id=repo_id,
local_dir=model_path,
local_dir_use_symlinks=False,
repo_type="model",
)
# 移动文件到正确位置
repo_name = repo_id.split("/")[-1]
source_dir = os.path.join(model_path, repo_name)
if os.path.exists(source_dir) and source_dir != target_path:
if os.path.exists(target_path):
shutil.rmtree(target_path)
shutil.move(source_dir, target_path)
logger.info(f"整个仓库 {repo_id} 下载完成,已移动到 {target_path}")
progress(1.0, desc=f"✅ {model_name} 下载完成")
return f"✅ {model_name} 下载完成"
# 判断是文件还是目录
is_directory = not (model_name.endswith(".safetensors") or model_name.endswith(".pth"))
if is_directory:
# 下载目录
progress(0.1, desc=f"下载目录 {model_name}...")
logger.info(f"检测到 {model_name} 是目录,使用 snapshot_download")
if os.path.exists(target_path):
shutil.rmtree(target_path)
# 从 model_name 中提取实际的目录名(如果包含子目录路径,如 "Z-Image-Turbo/scheduler" -> "scheduler")
actual_dir_name = model_name.split("/")[-1] if "/" in model_name else model_name
with warnings.catch_warnings():
warnings.simplefilter("ignore")
hf_snapshot_download(
repo_id=repo_id,
allow_patterns=[f"{actual_dir_name}/**"],
local_dir=model_path,
local_dir_use_symlinks=False,
repo_type="model",
)
# 移动文件到正确位置
repo_name = repo_id.split("/")[-1]
# 先尝试从 repo_name/actual_dir_name 查找
source_dir = os.path.join(model_path, repo_name, actual_dir_name)
if os.path.exists(source_dir):
# 确保目标路径的父目录存在
os.makedirs(os.path.dirname(target_path), exist_ok=True)
shutil.move(source_dir, target_path)
repo_dir = os.path.join(model_path, repo_name)
if os.path.exists(repo_dir) and not os.listdir(repo_dir):
os.rmdir(repo_dir)
else:
# 尝试从 model_path/actual_dir_name 查找
source_dir = os.path.join(model_path, actual_dir_name)
if os.path.exists(source_dir) and source_dir != target_path:
os.makedirs(os.path.dirname(target_path), exist_ok=True)
shutil.move(source_dir, target_path)
logger.info(f"目录 {model_name} 下载完成,已移动到 {target_path}")
else:
# 下载文件
progress(0.1, desc=f"下载文件 {model_name}...")
logger.info(f"检测到 {model_name} 是文件,使用 hf_hub_download")
if os.path.exists(target_path):
os.remove(target_path)
# 如果 model_name 包含子目录路径,提取实际文件名
actual_file_name = model_name.split("/")[-1] if "/" in model_name else model_name
with warnings.catch_warnings():
warnings.simplefilter("ignore")
downloaded_path = hf_hub_download(
repo_id=repo_id,
filename=actual_file_name, # 使用实际文件名,不包含子目录路径
local_dir=model_path,
local_dir_use_symlinks=False,
repo_type="model",
)
# 如果 model_name 包含子目录路径,需要移动文件到子目录
if "/" in model_name:
target_dir = os.path.dirname(target_path)
os.makedirs(target_dir, exist_ok=True)
if os.path.exists(downloaded_path) and downloaded_path != target_path:
shutil.move(downloaded_path, target_path)
logger.info(f"文件 {model_name} 下载完成,保存到 {target_path}")
else:
logger.info(f"文件 {model_name} 下载完成,保存到 {downloaded_path}")
else:
logger.info(f"文件 {model_name} 下载完成,保存到 {downloaded_path}")
progress(1.0, desc=f"✅ {model_name} 下载完成")
return f"✅ {model_name} 下载完成"
def download_model_from_ms(repo_id, model_name, model_path, progress=gr.Progress(), download_entire_repo=False):
"""从 ModelScope 下载模型(支持文件和目录)
Args:
repo_id: 仓库 ID
model_name: 模型名称(文件或目录名)
model_path: 模型保存路径
progress: Gradio 进度条
download_entire_repo: 是否下载整个仓库(用于 Qwen3 编码器等)
"""
if not MS_AVAILABLE:
return f"❌ modelscope 未安装,无法下载模型"
progress(0, desc=f"开始从 ModelScope 下载 {model_name}...")
logger.info(f"开始从 ModelScope {repo_id} 下载 {model_name}{model_path}")
target_path = os.path.join(model_path, model_name)
# 确保目标路径的父目录存在(如果 model_name 包含子目录路径,如 "Z-Image-Turbo/vae")
os.makedirs(os.path.dirname(target_path) if "/" in model_name else model_path, exist_ok=True)
os.makedirs(model_path, exist_ok=True)
# 临时目录用于下载
temp_dir = os.path.join(model_path, f".temp_{model_name}")
if os.path.exists(temp_dir):
shutil.rmtree(temp_dir)
# 如果指定下载整个仓库,直接下载整个仓库到目标目录
if download_entire_repo:
progress(0.1, desc=f"下载整个仓库 {repo_id}...")
logger.info(f"下载整个仓库 {repo_id}{target_path}")
if os.path.exists(target_path):
shutil.rmtree(target_path)
with warnings.catch_warnings():
warnings.simplefilter("ignore")
downloaded_path = ms_snapshot_download(
model_id=repo_id,
cache_dir=temp_dir,
)
# 移动文件到目标位置
if os.path.exists(downloaded_path):
if os.path.exists(target_path):
shutil.rmtree(target_path)
shutil.move(downloaded_path, target_path)
# 清理临时目录
if os.path.exists(temp_dir):
shutil.rmtree(temp_dir)
logger.info(f"整个仓库 {repo_id} 下载完成,已移动到 {target_path}")
progress(1.0, desc=f"✅ {model_name} 下载完成")
return f"✅ {model_name} 下载完成"
# 判断是文件还是目录
is_directory = not (model_name.endswith(".safetensors") or model_name.endswith(".pth"))
is_file = not is_directory
# 处理目录下载
if is_directory:
progress(0.1, desc=f"下载目录 {model_name}...")
logger.info(f"检测到 {model_name} 是目录,使用 snapshot_download")
if os.path.exists(target_path):
shutil.rmtree(target_path)
# 从 model_name 中提取实际的目录名(如果包含子目录路径,如 "Z-Image-Turbo/scheduler" -> "scheduler")
actual_dir_name = model_name.split("/")[-1] if "/" in model_name else model_name
# 使用 snapshot_download 下载目录(使用实际的目录名)
with warnings.catch_warnings():
warnings.simplefilter("ignore")
downloaded_path = ms_snapshot_download(
model_id=repo_id,
cache_dir=temp_dir,
allow_patterns=[f"{actual_dir_name}/**"],
)
# 移动文件到目标位置
# ModelScope 下载后,文件在 downloaded_path/actual_dir_name
source_dir = os.path.join(downloaded_path, actual_dir_name)
if not os.path.exists(source_dir) and os.path.exists(downloaded_path):
# 如果找不到,尝试从下载路径中查找
for item in os.listdir(downloaded_path):
item_path = os.path.join(downloaded_path, item)
if actual_dir_name.lower() in item.lower() or (os.path.isdir(item_path) and item.lower() == actual_dir_name.lower()):
source_dir = item_path
break
if os.path.exists(source_dir):
# 确保目标路径的父目录存在
os.makedirs(os.path.dirname(target_path), exist_ok=True)
if os.path.exists(target_path):
shutil.rmtree(target_path)
shutil.move(source_dir, target_path)
logger.info(f"目录 {model_name} 下载完成,已移动到 {target_path}")
else:
logger.error(f"无法找到下载的目录:{source_dir},下载路径:{downloaded_path}")
# 清理临时目录
if os.path.exists(temp_dir):
shutil.rmtree(temp_dir)
logger.info(f"目录 {model_name} 下载完成,保存到 {target_path}")
# 处理文件下载
elif is_file:
progress(0.1, desc=f"下载文件 {model_name}...")
logger.info(f"检测到 {model_name} 是文件,使用 snapshot_download")
if os.path.exists(target_path):
os.remove(target_path)
os.makedirs(os.path.dirname(target_path) if "/" in model_name else model_path, exist_ok=True)
# 如果 model_name 包含子目录路径,提取实际文件名用于 allow_patterns
actual_file_name = model_name.split("/")[-1] if "/" in model_name else model_name
# 使用 snapshot_download 下载文件
with warnings.catch_warnings():
warnings.simplefilter("ignore")
downloaded_path = ms_snapshot_download(
model_id=repo_id,
cache_dir=temp_dir,
allow_patterns=[actual_file_name], # 只使用文件名,不包含子目录路径
)
# 查找并移动文件
# 如果 model_name 包含子目录路径(如 "Qwen-Image-2512/file.safetensors"),提取文件名
actual_file_name = model_name.split("/")[-1] if "/" in model_name else model_name
# 先尝试完整路径
source_file = os.path.join(downloaded_path, model_name)
if not os.path.exists(source_file):
# 尝试直接使用文件名
source_file = os.path.join(downloaded_path, actual_file_name)
if not os.path.exists(source_file):
# 如果找不到,从下载路径中递归查找
for root, dirs, files_list in os.walk(downloaded_path):
if actual_file_name in files_list:
source_file = os.path.join(root, actual_file_name)
break
if os.path.exists(source_file):
# 确保目标目录存在
target_dir = os.path.dirname(target_path)
if target_dir and not os.path.exists(target_dir):
os.makedirs(target_dir, exist_ok=True)
shutil.move(source_file, target_path)
logger.info(f"文件 {model_name} 下载完成,保存到 {target_path}")
else:
logger.error(f"❌ 下载失败:无法找到文件 {actual_file_name}{downloaded_path}")
return f"❌ 下载失败:无法找到文件 {actual_file_name}"
# 清理临时目录
if os.path.exists(temp_dir):
shutil.rmtree(temp_dir)
else:
return f"❌ 无法找到 {model_name}:既不是文件也不是目录"
progress(1.0, desc=f"✅ {model_name} 下载完成")
return f"✅ {model_name} 下载完成"
def download_model(repo_id, model_name, model_path, download_source="huggingface", progress=gr.Progress(), download_entire_repo=False):
"""统一的下载函数,根据下载源选择 Hugging Face 或 ModelScope
Args:
repo_id: 仓库 ID
model_name: 模型名称(文件或目录名)
model_path: 模型保存路径
download_source: 下载源 ("huggingface" 或 "modelscope")
progress: Gradio 进度条
download_entire_repo: 是否下载整个仓库(用于 Qwen3 编码器等)
"""
if download_source == "modelscope":
return download_model_from_ms(repo_id, model_name, model_path, progress, download_entire_repo)
else:
return download_model_from_hf(repo_id, model_name, model_path, progress, download_entire_repo)
def get_repo_id_for_model(model_type, is_distill, model_category="dit"):
"""根据模型类型、是否 distill 和模型类别获取对应的 Hugging Face 仓库 ID"""
if model_category == "dit":
if model_type == "wan2.1":
return "lightx2v/wan2.1-Distill-Models" if is_distill else "lightx2v/wan2.1-Official-Models"
elif model_type == "Qwen-Image-Edit-2511":
return "lightx2v/Qwen-Image-Edit-2511-Lightning"
else: # wan2.2
return "lightx2v/wan2.2-Distill-Models" if is_distill else "lightx2v/wan2.2-Official-Models"
elif model_category == "high_noise" or model_category == "low_noise":
if is_distill:
return "lightx2v/wan2.2-Distill-Models"
else:
return "lightx2v/wan2.2-Official-Models"
elif model_category == "t5" or model_category == "clip":
return "lightx2v/Encoders"
elif model_category == "qwen25vl":
return "lightx2v/Encoders"
elif model_category == "vae":
return "lightx2v/Autoencoders"
elif model_category == "qwen_image_vae" or model_category == "qwen_image_scheduler":
return "Qwen/Qwen-Image-Edit-2511"
return None
MODEL_CONFIG_MAP = {
"dit": {"get_repo_id": lambda model_type, is_distill, _: get_repo_id_for_model(model_type, is_distill, "dit")},
"high_noise": {"get_repo_id": lambda model_type, is_distill, _: get_repo_id_for_model("wan2.2", is_distill, "high_noise")},
"low_noise": {"get_repo_id": lambda model_type, is_distill, _: get_repo_id_for_model("wan2.2", is_distill, "low_noise")},
"t5": {"get_repo_id": lambda _, __, ___: get_repo_id_for_model(None, None, "t5")},
"clip": {"get_repo_id": lambda _, __, ___: get_repo_id_for_model(None, None, "clip")},
"vae": {"get_repo_id": lambda _, __, ___: get_repo_id_for_model(None, None, "vae")},
"vae_encoder": {"get_repo_id": lambda _, __, ___: get_repo_id_for_model(None, None, "vae")}, # 兼容旧代码
"vae_decoder": {"get_repo_id": lambda _, __, ___: get_repo_id_for_model(None, None, "vae")}, # 兼容旧代码
"qwen_image_dit": {"get_repo_id": lambda _, __, ___: "lightx2v/Qwen-Image-Edit-2511-Lightning"},
"qwen_image_vae": {"get_repo_id": lambda _, __, ___: "Qwen/Qwen-Image-Edit-2511"},
"qwen_image_scheduler": {"get_repo_id": lambda _, __, ___: "Qwen/Qwen-Image-Edit-2511"},
"qwen_image_2512_dit": {"get_repo_id": lambda _, __, ___: "lightx2v/Qwen-Image-2512-Lightning"},
"qwen_image_2512_vae": {"get_repo_id": lambda _, __, ___: "Qwen/Qwen-Image-2512"},
"qwen_image_2512_scheduler": {"get_repo_id": lambda _, __, ___: "Qwen/Qwen-Image-2512"},
"z_image_turbo_dit": {"get_repo_id": lambda _, __, ___: "lightx2v/Z-Image-Turbo-Quantized"},
"z_image_turbo_vae": {"get_repo_id": lambda _, __, ___: "Tongyi-MAI/Z-Image-Turbo"},
"z_image_turbo_scheduler": {"get_repo_id": lambda _, __, ___: "Tongyi-MAI/Z-Image-Turbo"},
"qwen3_encoder": {"get_repo_id": lambda _, __, ___: "JunHowie/Qwen3-4B-GPTQ-Int4"},
"qwen25vl_encoder": {"get_repo_id": lambda _, __, ___: "lightx2v/Encoders"},
}
# Tokenizer 配置
TOKENIZER_CONFIG = {
"t5": {"name": "google", "repo_id": get_repo_id_for_model(None, None, "t5")},
"clip": {"name": "xlm-roberta-large", "repo_id": get_repo_id_for_model(None, None, "clip")},
}
def update_model_status(model_path_val, model_name, model_category, model_type_val=None):
"""通用的模型状态更新函数
Args:
model_path_val: 模型路径
model_name: 模型名称(可能包含状态标识)
model_category: 模型类别 (dit, high_noise, low_noise, t5, clip, vae, qwen_image_dit, etc.)
model_type_val: 模型类型 (wan2.1, wan2.2, Qwen-Image-Edit-2511),某些类别需要此参数
Returns:
gr.update: 下载按钮的可见性更新
"""
if not model_name:
return gr.update(visible=False)
actual_name = extract_model_name(model_name)
# 对于 Qwen3 编码器,检查整个仓库目录是否存在
if model_category == "qwen3_encoder":
# Qwen3 编码器检查整个仓库目录 Qwen3-4B-GPTQ-Int4 是否存在
repo_id = "JunHowie/Qwen3-4B-GPTQ-Int4"
repo_name = repo_id.split("/")[-1] # "Qwen3-4B-GPTQ-Int4"
exists = check_model_exists(model_path_val, repo_name)
return gr.update(visible=not exists)
exists = check_model_exists(model_path_val, actual_name)
return gr.update(visible=not exists)
def update_tokenizer_status(model_path_val, tokenizer_type):
"""更新 Tokenizer 状态
Args:
model_path_val: 模型路径
tokenizer_type: tokenizer 类型 ("t5" 或 "clip")
Returns:
tuple: (dropdown_update, btn_visible_update)
"""
config = TOKENIZER_CONFIG.get(tokenizer_type)
if not config:
return gr.update(), gr.update(visible=False)
tokenizer_name = config["name"]
exists = check_model_exists(model_path_val, tokenizer_name)
if exists:
status_text = f"{tokenizer_name} ✅"
return gr.update(value=status_text), gr.update(visible=False)
else:
status_text = f"{tokenizer_name} ❌"
return gr.update(value=status_text), gr.update(visible=True)
def download_model_handler(model_path_val, model_name, model_category, download_source_val, get_choices_func=None, model_type_val=None, task_type_val=None, progress=gr.Progress()):
"""通用的模型下载处理函数
Args:
model_path_val: 模型路径
model_name: 模型名称(可能包含状态标识)
model_category: 模型类别
download_source_val: 下载源 ("huggingface" 或 "modelscope")
get_choices_func: 获取模型选项列表的函数(可选)
model_type_val: 模型类型(某些类别需要)
task_type_val: 任务类型(某些类别需要)
progress: Gradio 进度条
Returns:
tuple: (status_update, btn_visible_update, choices_update)
"""
if not model_name:
return gr.update(value="请先选择模型"), gr.update(visible=False), gr.update()
actual_name = extract_model_name(model_name)
# 获取 repo_id
config = MODEL_CONFIG_MAP.get(model_category)
if not config:
return gr.update(value=f"未知的模型类别: {model_category}"), gr.update(visible=False), gr.update()
is_distill = False
if model_category in ["dit", "high_noise", "low_noise"]:
is_distill = is_distill_model_from_name(actual_name)
repo_id = config["get_repo_id"](model_type_val, is_distill, model_category)
# 对于 Qwen3 编码器,始终下载整个仓库目录
if model_category == "qwen3_encoder":
# Qwen3 编码器下载整个仓库,使用仓库名作为目录名
repo_name = repo_id.split("/")[-1] # "Qwen3-4B-GPTQ-Int4"
actual_name = repo_name
# 对于 VAE、Scheduler 和 Diffusion 模型,需要下载到子目录
elif model_category in [
"qwen_image_vae",
"qwen_image_scheduler",
"qwen_image_dit",
"qwen_image_2512_vae",
"qwen_image_2512_scheduler",
"qwen_image_2512_dit",
"z_image_turbo_vae",
"z_image_turbo_scheduler",
"z_image_turbo_dit",
]:
# 确定模型子目录名称
if model_category in ["qwen_image_vae", "qwen_image_scheduler", "qwen_image_dit"]:
model_subdir = "Qwen-Image-Edit-2511"
elif model_category in ["qwen_image_2512_vae", "qwen_image_2512_scheduler", "qwen_image_2512_dit"]:
model_subdir = "Qwen-Image-2512"
elif model_category in ["z_image_turbo_vae", "z_image_turbo_scheduler", "z_image_turbo_dit"]:
model_subdir = "Z-Image-Turbo"
# 如果需要在子目录下载,修改 actual_name 为子目录路径
# 如果 actual_name 已经包含子目录路径,则提取文件名
if "/" in actual_name:
# 提取文件名(去掉子目录前缀)
parts = actual_name.split("/")
if len(parts) > 1:
actual_name = parts[-1]
# 添加子目录路径前缀
actual_name = f"{model_subdir}/{actual_name}"
# 下载模型
# 对于 Qwen3 编码器,下载整个仓库
download_entire_repo = model_category == "qwen3_encoder"
result = download_model(repo_id, actual_name, model_path_val, download_source_val, progress, download_entire_repo)
# 下载完成后,直接标记为已存在(使用 ✅ 状态),避免文件系统同步延迟导致的状态检查失败
formatted_name_with_status = format_model_choice(actual_name, model_path_val, status_emoji="✅")
# 更新状态(下载完成后,模型应该存在,所以隐藏下载按钮)
btn_visible = gr.update(visible=False)
# 更新选项列表(如果提供了获取选项的函数)
choices_update = gr.update()
if get_choices_func:
try:
if model_category in ["dit"] and model_type_val and task_type_val:
choices = get_choices_func(model_path_val, model_type_val, task_type_val)
elif model_category in ["high_noise", "low_noise"] and task_type_val:
choices = get_choices_func(model_path_val, "wan2.2", task_type_val)
else:
choices = get_choices_func(model_path_val)
# 使用带状态标识的格式化名称
choices_update = gr.update(choices=choices, value=formatted_name_with_status)
except Exception as e:
# 如果获取选项失败,只更新值
choices_update = gr.update(value=formatted_name_with_status)
return gr.update(value=result), btn_visible, choices_update
def download_tokenizer_handler(model_path_val, tokenizer_type, download_source_val, progress=gr.Progress()):
"""下载 Tokenizer 处理函数
Args:
model_path_val: 模型路径
tokenizer_type: tokenizer 类型 ("t5" 或 "clip")
download_source_val: 下载源
progress: Gradio 进度条
Returns:
tuple: (status_update, dropdown_update, btn_visible_update)
"""
config = TOKENIZER_CONFIG.get(tokenizer_type)
if not config:
return gr.update(value=f"未知的 tokenizer 类型: {tokenizer_type}"), gr.update(), gr.update(visible=False)
tokenizer_name = config["name"]
repo_id = config["repo_id"]
result = download_model(repo_id, tokenizer_name, model_path_val, download_source_val, progress)
dropdown_update, btn_visible = update_tokenizer_status(model_path_val, tokenizer_type)
return gr.update(value=result), dropdown_update, btn_visible
def create_update_status_wrappers():
"""创建所有模型状态更新的包装函数
Returns:
dict: 包含所有 update 函数的字典
"""
def update_dit_status(model_path_val, model_name, model_type_val):
return update_model_status(model_path_val, model_name, "dit", model_type_val)
def update_t5_model_status(model_path_val, model_name):
return update_model_status(model_path_val, model_name, "t5")
def update_t5_tokenizer_status(model_path_val):
return update_tokenizer_status(model_path_val, "t5")
def update_clip_model_status(model_path_val, model_name):
return update_model_status(model_path_val, model_name, "clip")
def update_clip_tokenizer_status(model_path_val):
return update_tokenizer_status(model_path_val, "clip")
def update_vae_status(model_path_val, model_name):
return update_model_status(model_path_val, model_name, "vae")
def update_vae_encoder_status(model_path_val, model_name):
return update_model_status(model_path_val, model_name, "vae") # 兼容旧代码,统一使用 vae
def update_vae_decoder_status(model_path_val, model_name):
return update_model_status(model_path_val, model_name, "vae") # 兼容旧代码,统一使用 vae
def update_high_noise_status(model_path_val, model_name):
return update_model_status(model_path_val, model_name, "high_noise", "wan2.2")
def update_low_noise_status(model_path_val, model_name):
return update_model_status(model_path_val, model_name, "low_noise", "wan2.2")
def update_qwen_image_dit_status(model_path_val, model_name):
return update_model_status(model_path_val, model_name, "qwen_image_dit")
def update_qwen_image_vae_status(model_path_val, model_name):
return update_model_status(model_path_val, model_name, "qwen_image_vae")
def update_qwen_image_scheduler_status(model_path_val, model_name):
return update_model_status(model_path_val, model_name, "qwen_image_scheduler")
def update_qwen25vl_encoder_status(model_path_val, model_name):
return update_model_status(model_path_val, model_name, "qwen25vl_encoder")
return {
"update_dit_status": update_dit_status,
"update_t5_model_status": update_t5_model_status,
"update_t5_tokenizer_status": update_t5_tokenizer_status,
"update_clip_model_status": update_clip_model_status,
"update_clip_tokenizer_status": update_clip_tokenizer_status,
"update_vae_status": update_vae_status,
"update_vae_encoder_status": update_vae_encoder_status, # 兼容旧代码
"update_vae_decoder_status": update_vae_decoder_status, # 兼容旧代码
"update_high_noise_status": update_high_noise_status,
"update_low_noise_status": update_low_noise_status,
"update_qwen_image_dit_status": update_qwen_image_dit_status,
"update_qwen_image_vae_status": update_qwen_image_vae_status,
"update_qwen_image_scheduler_status": update_qwen_image_scheduler_status,
"update_qwen25vl_encoder_status": update_qwen25vl_encoder_status,
}
def create_download_wrappers(get_choices_funcs):
"""创建所有模型下载的包装函数
Args:
get_choices_funcs: 包含所有 get_choices 函数的字典
Returns:
dict: 包含所有 download 函数的字典
"""
def download_dit_model(model_path_val, model_name, model_type_val, task_type_val, download_source_val, progress=gr.Progress()):
return download_model_handler(
model_path_val,
model_name,
"dit",
download_source_val,
get_choices_func=get_choices_funcs.get("get_dit_choices"),
model_type_val=model_type_val,
task_type_val=task_type_val,
progress=progress,
)
def download_t5_model(model_path_val, model_name, download_source_val, progress=gr.Progress()):
return download_model_handler(model_path_val, model_name, "t5", download_source_val, get_choices_func=get_choices_funcs.get("get_t5_model_choices"), progress=progress)
def download_t5_tokenizer(model_path_val, download_source_val, progress=gr.Progress()):
return download_tokenizer_handler(model_path_val, "t5", download_source_val, progress)
def download_clip_model(model_path_val, model_name, download_source_val, progress=gr.Progress()):
return download_model_handler(model_path_val, model_name, "clip", download_source_val, get_choices_func=get_choices_funcs.get("get_clip_model_choices"), progress=progress)
def download_clip_tokenizer(model_path_val, download_source_val, progress=gr.Progress()):
return download_tokenizer_handler(model_path_val, "clip", download_source_val, progress)
def download_vae(model_path_val, model_name, download_source_val, progress=gr.Progress()):
return download_model_handler(model_path_val, model_name, "vae", download_source_val, get_choices_func=get_choices_funcs.get("get_vae_choices"), progress=progress)
def download_vae_encoder(model_path_val, model_name, download_source_val, progress=gr.Progress()):
return download_vae(model_path_val, model_name, download_source_val, progress) # 兼容旧代码
def download_vae_decoder(model_path_val, model_name, download_source_val, progress=gr.Progress()):
return download_vae(model_path_val, model_name, download_source_val, progress) # 兼容旧代码
def download_high_noise_model(model_path_val, model_name, task_type_val, download_source_val, progress=gr.Progress()):
return download_model_handler(
model_path_val,
model_name,
"high_noise",
download_source_val,
get_choices_func=get_choices_funcs.get("get_high_noise_choices"),
model_type_val="wan2.2",
task_type_val=task_type_val,
progress=progress,
)
def download_low_noise_model(model_path_val, model_name, task_type_val, download_source_val, progress=gr.Progress()):
return download_model_handler(
model_path_val,
model_name,
"low_noise",
download_source_val,
get_choices_func=get_choices_funcs.get("get_low_noise_choices"),
model_type_val="wan2.2",
task_type_val=task_type_val,
progress=progress,
)
def download_qwen_image_dit(model_path_val, model_name, download_source_val, progress=gr.Progress()):
return download_model_handler(model_path_val, model_name, "qwen_image_dit", download_source_val, get_choices_func=get_choices_funcs.get("get_qwen_image_dit_choices"), progress=progress)
def download_qwen_image_vae(model_path_val, model_name, download_source_val, progress=gr.Progress()):
return download_model_handler(model_path_val, model_name, "qwen_image_vae", download_source_val, get_choices_func=get_choices_funcs.get("get_qwen_image_vae_choices"), progress=progress)
def download_qwen_image_scheduler(model_path_val, model_name, download_source_val, progress=gr.Progress()):
return download_model_handler(
model_path_val, model_name, "qwen_image_scheduler", download_source_val, get_choices_func=get_choices_funcs.get("get_qwen_image_scheduler_choices"), progress=progress
)
def download_qwen25vl_encoder(model_path_val, model_name, download_source_val, progress=gr.Progress()):
return download_model_handler(model_path_val, model_name, "qwen25vl_encoder", download_source_val, get_choices_func=get_choices_funcs.get("get_qwen25vl_encoder_choices"), progress=progress)
return {
"download_dit_model": download_dit_model,
"download_t5_model": download_t5_model,
"download_t5_tokenizer": download_t5_tokenizer,
"download_clip_model": download_clip_model,
"download_clip_tokenizer": download_clip_tokenizer,
"download_vae": download_vae,
"download_vae_encoder": download_vae_encoder, # 兼容旧代码
"download_vae_decoder": download_vae_decoder, # 兼容旧代码
"download_high_noise_model": download_high_noise_model,
"download_low_noise_model": download_low_noise_model,
"download_qwen_image_dit": download_qwen_image_dit,
"download_qwen_image_vae": download_qwen_image_vae,
"download_qwen_image_scheduler": download_qwen_image_scheduler,
"download_qwen25vl_encoder": download_qwen25vl_encoder,
}
import concurrent.futures
import glob
import os
from loguru import logger
try:
from huggingface_hub import HfApi, list_repo_files
HF_AVAILABLE = True
except ImportError:
HfApi = None
list_repo_files = None
HF_AVAILABLE = False
try:
from modelscope.hub.api import HubApi
MS_AVAILABLE = True
except ImportError:
HubApi = None
MS_AVAILABLE = False
import gc
import importlib.util
import re
import psutil
import torch
from loguru import logger
def is_module_installed(module_name):
"""检查模块是否已安装"""
spec = importlib.util.find_spec(module_name)
return spec is not None
def get_available_quant_ops():
"""获取可用的量化算子"""
available_ops = []
triton_installed = is_module_installed("triton")
if triton_installed:
available_ops.append(("triton", True))
else:
available_ops.append(("triton", False))
vllm_installed = is_module_installed("vllm")
if vllm_installed:
available_ops.append(("vllm", True))
else:
available_ops.append(("vllm", False))
sgl_installed = is_module_installed("sgl_kernel")
if sgl_installed:
available_ops.append(("sgl", True))
else:
available_ops.append(("sgl", False))
q8f_installed = is_module_installed("q8_kernels")
if q8f_installed:
available_ops.append(("q8f", True))
else:
available_ops.append(("q8f", False))
# 检测 torch 选项:需要同时满足 hasattr(torch, "_scaled_mm") 和安装了 torchao
torch_available = hasattr(torch, "_scaled_mm") and is_module_installed("torchao")
if torch_available:
available_ops.append(("torch", True))
else:
available_ops.append(("torch", False))
return available_ops
def get_available_attn_ops():
"""获取可用的注意力算子"""
available_ops = []
vllm_installed = is_module_installed("flash_attn")
if vllm_installed:
available_ops.append(("flash_attn2", True))
else:
available_ops.append(("flash_attn2", False))
sgl_installed = is_module_installed("flash_attn_interface")
if sgl_installed:
available_ops.append(("flash_attn3", True))
else:
available_ops.append(("flash_attn3", False))
sage_installed = is_module_installed("sageattention")
if sage_installed:
available_ops.append(("sage_attn2", True))
else:
available_ops.append(("sage_attn2", False))
sage3_installed = is_module_installed("sageattn3")
if sage3_installed:
available_ops.append(("sage_attn3", True))
else:
available_ops.append(("sage_attn3", False))
torch_installed = is_module_installed("torch")
if torch_installed:
available_ops.append(("torch_sdpa", True))
else:
available_ops.append(("torch_sdpa", False))
return available_ops
def get_gpu_memory(gpu_idx=0):
"""获取GPU显存(GB)"""
if not torch.cuda.is_available():
return 0
with torch.cuda.device(gpu_idx):
memory_info = torch.cuda.mem_get_info()
total_memory = memory_info[1] / (1024**3) # Convert bytes to GB
return total_memory
def get_cpu_memory():
"""获取CPU可用内存(GB)"""
available_bytes = psutil.virtual_memory().available
return available_bytes / 1024**3
def cleanup_memory():
"""清理内存"""
gc.collect()
if torch.cuda.is_available():
torch.cuda.empty_cache()
torch.cuda.synchronize()
def is_fp8_supported_gpu():
"""检查GPU是否支持fp8"""
if not torch.cuda.is_available():
return False
compute_capability = torch.cuda.get_device_capability(0)
major, minor = compute_capability
return (major == 8 and minor == 9) or (major >= 9)
def is_sm_greater_than_90():
"""检测计算能力是否大于 (9,0)"""
if not torch.cuda.is_available():
return False
compute_capability = torch.cuda.get_device_capability(0)
major, minor = compute_capability
return (major, minor) > (9, 0)
def get_gpu_generation():
"""检测GPU系列,返回 '40' 表示40系,'30' 表示30系,None 表示其他"""
if not torch.cuda.is_available():
return None
try:
gpu_name = torch.cuda.get_device_name(0)
gpu_name_lower = gpu_name.lower()
# 检测40系显卡 (RTX 40xx, RTX 4060, RTX 4070, RTX 4080, RTX 4090等)
if any(keyword in gpu_name_lower for keyword in ["rtx 40", "rtx40", "geforce rtx 40"]):
# 进一步检查是40xx系列
match = re.search(r"rtx\s*40\d+|40\d+", gpu_name_lower)
if match:
return "40"
# 检测30系显卡 (RTX 30xx, RTX 3060, RTX 3070, RTX 3080, RTX 3090等)
if any(keyword in gpu_name_lower for keyword in ["rtx 30", "rtx30", "geforce rtx 30"]):
# 进一步检查是30xx系列
match = re.search(r"rtx\s*30\d+|30\d+", gpu_name_lower)
if match:
return "30"
return None
except Exception as e:
logger.warning(f"无法检测GPU系列: {e}")
return None
# 模型列表缓存(避免每次从 HF 获取)
HF_MODELS_CACHE = {
"lightx2v/wan2.1-Distill-Models": [],
"lightx2v/wan2.1-Official-Models": [],
"lightx2v/wan2.2-Distill-Models": [],
"lightx2v/wan2.2-Official-Models": [],
"lightx2v/Encoders": [],
"lightx2v/Autoencoders": [],
"lightx2v/Qwen-Image-Edit-2511-Lightning": [],
"Qwen/Qwen-Image-Edit-2511": [],
"lightx2v/Qwen-Image-2512-Lightning": [],
"Qwen/Qwen-Image-2512": [],
"lightx2v/Z-Image-Turbo-Quantized": [],
"Tongyi-MAI/Z-Image-Turbo": [],
"JunHowie/Qwen3-4B-GPTQ-Int4": [],
}
def process_files(files, repo_id=None):
"""处理文件列表,提取模型名称"""
model_names = []
seen_dirs = set()
# 对于 Qwen/Qwen-Image-Edit-2511、Qwen/Qwen-Image-2512 和 Tongyi-MAI/Z-Image-Turbo 仓库,保留 vae 和 scheduler 目录
is_qwen_image_repo = repo_id in ["Qwen/Qwen-Image-Edit-2511", "Qwen/Qwen-Image-2512", "Tongyi-MAI/Z-Image-Turbo"]
# 对于 Qwen3 编码器仓库,整个仓库就是一个模型目录
is_qwen3_encoder_repo = repo_id == "JunHowie/Qwen3-4B-GPTQ-Int4"
for file in files:
# 排除包含comfyui的文件
if "comfyui" in file.lower():
continue
# 如果是顶层文件(不包含路径分隔符)
if "/" not in file:
# 对于 Qwen3 编码器仓库,不添加单个文件,只添加目录
# 因为 Qwen3 编码器应该下载整个仓库目录
if not is_qwen3_encoder_repo and file.endswith(".safetensors"):
model_names.append(file)
else:
# 提取顶层目录名(支持_split目录)
top_dir = file.split("/")[0]
if top_dir not in seen_dirs:
seen_dirs.add(top_dir)
# 对于 Qwen 仓库,保留 vae 和 scheduler 目录
if is_qwen_image_repo and top_dir.lower() in ["vae", "scheduler"]:
model_names.append(top_dir)
# 对于 Qwen3 编码器仓库,保留所有顶层目录(排除 comfyui)
elif is_qwen3_encoder_repo:
model_names.append(top_dir)
# 支持safetensors文件目录和_split分block存储目录
elif "_split" in top_dir or any(f.startswith(f"{top_dir}/") and f.endswith(".safetensors") for f in files):
model_names.append(top_dir)
return sorted(set(model_names))
def load_hf_models_cache():
"""从 Hugging Face 加载模型列表并缓存,如果 HF 超时或失败,则尝试使用 ModelScope"""
# 超时时间(秒)
HF_TIMEOUT = 30
for repo_id in HF_MODELS_CACHE.keys():
files = None
source = None
# 首先尝试从 ModelScope 获取
try:
if MS_AVAILABLE:
logger.info(f"Loading models from ModelScope {repo_id}...")
api = HubApi()
# ModelScope API 获取文件列表
model_files = api.get_model_files(model_id=repo_id, recursive=True)
# 提取文件路径
files = [file["Path"] for file in model_files if file.get("Type") == "blob"]
source = "ModelScope"
logger.info(f"Successfully loaded models from ModelScope {repo_id}")
except: # noqa E722
# 如果 ModelScope 失败,尝试从 Hugging Face 获取(带超时)
if files is None and HF_AVAILABLE:
logger.info(f"Loading models from Hugging Face {repo_id}...")
api = HfApi()
# 使用线程池执行器设置超时
with concurrent.futures.ThreadPoolExecutor() as executor:
future = executor.submit(list_repo_files, repo_id=repo_id, repo_type="model")
files = future.result(timeout=HF_TIMEOUT)
source = "Hugging Face"
# 处理文件列表
if files:
model_names = process_files(files, repo_id)
HF_MODELS_CACHE[repo_id] = model_names
logger.info(f"Loaded {len(HF_MODELS_CACHE[repo_id])} models from {source} {repo_id}")
else:
logger.warning(f"No files retrieved from {repo_id}, setting empty cache")
HF_MODELS_CACHE[repo_id] = []
def get_hf_models(repo_id, prefix_filter=None, keyword_filter=None):
"""从缓存的模型列表中获取模型(不再实时从 HF 获取)"""
if repo_id not in HF_MODELS_CACHE:
return []
models = HF_MODELS_CACHE[repo_id]
if prefix_filter:
models = [m for m in models if m.lower().startswith(prefix_filter.lower())]
if keyword_filter:
models = [m for m in models if keyword_filter.lower() in m.lower()]
return models
def extract_op_name(op_str):
"""从格式化的操作符名称中提取原始名称"""
if not op_str:
return ""
op_str = op_str.strip()
if op_str.startswith("✅"):
op_str = op_str[1:].strip()
elif op_str.startswith("❌"):
op_str = op_str[1:].strip()
if "(" in op_str:
op_str = op_str.split("(")[0].strip()
return op_str
def scan_model_path_contents(model_path):
"""扫描 model_path 目录,返回可用的文件和子目录"""
if not model_path or not os.path.exists(model_path):
return {"dirs": [], "files": [], "safetensors_dirs": [], "pth_files": []}
dirs = []
files = []
safetensors_dirs = []
pth_files = []
for item in os.listdir(model_path):
item_path = os.path.join(model_path, item)
if os.path.isdir(item_path):
dirs.append(item)
if glob.glob(os.path.join(item_path, "*.safetensors")):
safetensors_dirs.append(item)
elif os.path.isfile(item_path):
files.append(item)
if item.endswith(".pth"):
pth_files.append(item)
return {
"dirs": sorted(dirs),
"files": sorted(files),
"safetensors_dirs": sorted(safetensors_dirs),
"pth_files": sorted(pth_files),
}
def check_model_exists(model_path, model_name):
"""检查模型是否已下载
支持子目录路径,如 Qwen-Image-2512/qwen_image_2512_fp8_e4m3fn_scaled.safetensors
"""
if not model_path or not os.path.exists(model_path):
return False
# 处理包含子目录路径的模型名(如 Qwen-Image-2512/qwen_image_2512_fp8_e4m3fn_scaled.safetensors)
model_path_full = os.path.join(model_path, model_name)
# 检查是否存在(文件或目录)
if os.path.exists(model_path_full):
return True
# 额外检查:如果是 safetensors 文件,也检查同名目录(_split 目录)
if model_name.endswith(".safetensors"):
# 如果模型名包含子目录路径,提取文件名
if "/" in model_name:
base_name_with_subdir = model_name.replace(".safetensors", "")
# 检查子目录下的 _split 目录
split_dir = os.path.join(model_path, base_name_with_subdir + "_split")
if os.path.exists(split_dir):
return True
else:
# 检查同名目录(可能是分块存储)
base_name = model_name.replace(".safetensors", "")
split_dir = os.path.join(model_path, base_name + "_split")
if os.path.exists(split_dir):
return True
return False
def format_model_choice(model_name, model_path, status_emoji=None):
"""格式化模型选项,添加下载状态标识"""
if not model_name:
return ""
# 如果提供了状态 emoji,直接使用
if status_emoji is not None:
return f"{status_emoji} {model_name}"
# 否则检查本地是否存在
exists = check_model_exists(model_path, model_name)
emoji = "✅" if exists else "❌"
return f"{emoji} {model_name}"
def extract_model_name(formatted_name):
"""从格式化的选项名称中提取原始模型名称"""
if not formatted_name:
return ""
# 移除开头的 emoji 和空格
if formatted_name.startswith("✅ ") or formatted_name.startswith("❌ "):
return formatted_name[2:].strip()
return formatted_name.strip()
def sort_model_choices(models):
"""根据设备能力和模型类型对模型列表排序
排序规则:
- 如果设备支持 fp8:fp8+split > int8+split > fp8 > int8 > 其他
- 如果设备不支持 fp8:int8+split > int8 > 其他
"""
# 延迟导入避免循环依赖
from utils.model_utils import is_fp8_supported_gpu
fp8_supported = is_fp8_supported_gpu()
def get_priority(name):
name_lower = name.lower()
if fp8_supported:
# fp8 设备:fp8+split > int8+split > fp8 > int8 > 其他
if "fp8" in name_lower and "_split" in name_lower:
return 0 # 最高优先级
elif "int8" in name_lower and "_split" in name_lower:
return 1
elif "fp8" in name_lower:
return 2
elif "int8" in name_lower:
return 3
else:
return 4 # 其他
else:
# 非 fp8 设备:优先 int8+split,其次 int8
if "int8" in name_lower and "_split" in name_lower:
return 0 # 最高优先级
elif "int8" in name_lower:
return 1
else:
return 2 # 其他(fp8 已经被过滤掉了)
return sorted(models, key=lambda x: (get_priority(x), x.lower()))
def detect_quant_scheme(model_name):
"""根据模型名字自动检测量化精度
- 如果模型名字包含 "int8" → "int8"
- 如果模型名字包含 "fp8" 且设备支持 → "fp8"
- 否则返回 None(表示不使用量化)
"""
if not model_name:
return None
# 延迟导入避免循环依赖
from utils.model_utils import is_fp8_supported_gpu
name_lower = model_name.lower()
if "int8" in name_lower:
return "int8"
elif "fp8" in name_lower:
if is_fp8_supported_gpu():
return "fp8"
else:
# 设备不支持fp8,返回None(使用默认精度)
return None
return None
def is_distill_model_from_name(model_name):
"""根据模型名称判断是否是 distill 模型"""
if not model_name:
return None
return "4step" in model_name.lower() or "8step" in model_name.lower()
def get_quant_scheme(quant_detected, quant_op_val):
if quant_op_val == "torch":
return f"{quant_detected}-torchao"
elif quant_op_val == "triton":
return f"{quant_detected}-triton"
else:
return f"{quant_detected}-{quant_op_val}"
def build_wan21(
model_path_input,
dit_path_input,
t5_path_input,
clip_path_input,
vae_path_input,
quant_op,
):
wan21_config = {
"dim": 5120,
"eps": 1e-06,
"ffn_dim": 13824,
"freq_dim": 256,
"in_dim": 36,
"num_heads": 40,
"num_layers": 40,
"out_dim": 16,
"text_len": 512,
}
dit_path_input = extract_model_name(dit_path_input) if dit_path_input else ""
t5_path_input = extract_model_name(t5_path_input) if t5_path_input else ""
clip_path_input = extract_model_name(clip_path_input) if clip_path_input else ""
is_distill = is_distill_model_from_name(dit_path_input)
if is_distill:
wan21_config["model_cls"] = "wan2.1_distill"
else:
wan21_config["model_cls"] = "wan2.1"
wan21_config["dit_quantized_ckpt"] = None
wan21_config["dit_original_ckpt"] = None
dit_quant_detected = detect_quant_scheme(dit_path_input)
is_dit_quant = dit_quant_detected in ["fp8", "int8"]
if is_dit_quant:
wan21_config["dit_quantized"] = True
wan21_config["dit_quant_scheme"] = get_quant_scheme(dit_quant_detected, quant_op)
wan21_config["dit_quantized_ckpt"] = os.path.join(model_path_input, dit_path_input)
else:
wan21_config["dit_quantized"] = False
wan21_config["dit_quant_scheme"] = "Default"
wan21_config["dit_original_ckpt"] = os.path.join(model_path_input, dit_path_input)
wan21_config["t5_original_ckpt"] = None
wan21_config["t5_quantized_ckpt"] = None
t5_quant_detected = detect_quant_scheme(t5_path_input)
is_t5_quant = t5_quant_detected in ["fp8", "int8"]
if is_t5_quant:
wan21_config["t5_quantized"] = True
wan21_config["t5_quant_scheme"] = get_quant_scheme(t5_quant_detected, quant_op)
wan21_config["t5_quantized_ckpt"] = os.path.join(model_path_input, t5_path_input)
else:
wan21_config["t5_quantized"] = False
wan21_config["t5_quant_scheme"] = "Default"
wan21_config["t5_original_ckpt"] = os.path.join(model_path_input, t5_path_input)
wan21_config["clip_original_ckpt"] = None
wan21_config["clip_quantized_ckpt"] = None
clip_quant_detected = detect_quant_scheme(clip_path_input)
is_clip_quant = clip_quant_detected in ["fp8", "int8"]
if is_clip_quant:
wan21_config["clip_quantized"] = True # 启用量化
wan21_config["clip_quant_scheme"] = get_quant_scheme(clip_quant_detected, quant_op)
wan21_config["clip_quantized_ckpt"] = os.path.join(model_path_input, clip_path_input)
else:
wan21_config["clip_quantized"] = False # 禁用量化
wan21_config["clip_quant_scheme"] = "Default"
wan21_config["clip_original_ckpt"] = os.path.join(model_path_input, clip_path_input)
# 提取 VAE 路径,移除状态符号
vae_path_input = extract_model_name(vae_path_input) if vae_path_input else ""
wan21_config["vae_path"] = os.path.join(model_path_input, vae_path_input) if vae_path_input else None
wan21_config["model_path"] = model_path_input
return wan21_config
def build_wan22(
model_path_input,
high_noise_path_input,
low_noise_path_input,
t5_path_input,
vae_path_input,
quant_op,
):
wan22_config = {
"dim": 5120,
"eps": 1e-06,
"ffn_dim": 13824,
"freq_dim": 256,
"in_dim": 36,
"num_heads": 40,
"num_layers": 40,
"out_dim": 16,
"text_len": 512,
}
"""构建 Wan2.2 模型配置"""
high_noise_path_input = extract_model_name(high_noise_path_input) if high_noise_path_input else ""
low_noise_path_input = extract_model_name(low_noise_path_input) if low_noise_path_input else ""
t5_path_input = extract_model_name(t5_path_input) if t5_path_input else ""
is_distill = is_distill_model_from_name(high_noise_path_input)
if is_distill:
wan22_config["model_cls"] = "wan2.2_moe_distill"
else:
wan22_config["model_cls"] = "wan2.2_moe"
wan22_config["high_noise_quantized_ckpt"] = None
wan22_config["low_noise_quantized_ckpt"] = None
wan22_config["high_noise_original_ckpt"] = None
wan22_config["low_noise_original_ckpt"] = None
dit_quant_detected = detect_quant_scheme(high_noise_path_input)
is_dit_quant = dit_quant_detected in ["fp8", "int8"]
if is_dit_quant:
wan22_config["dit_quantized"] = True
wan22_config["dit_quant_scheme"] = get_quant_scheme(dit_quant_detected, quant_op)
wan22_config["high_noise_quant_scheme"] = get_quant_scheme(dit_quant_detected, quant_op)
wan22_config["high_noise_quantized_ckpt"] = os.path.join(model_path_input, high_noise_path_input)
wan22_config["low_noise_quantized_ckpt"] = os.path.join(model_path_input, low_noise_path_input)
else:
wan22_config["dit_quantized"] = False
wan22_config["dit_quant_scheme"] = "Default"
wan22_config["high_noise_quant_scheme"] = "Default"
wan22_config["high_noise_original_ckpt"] = os.path.join(model_path_input, high_noise_path_input)
wan22_config["low_noise_original_ckpt"] = os.path.join(model_path_input, low_noise_path_input)
wan22_config["t5_original_ckpt"] = None
wan22_config["t5_quantized_ckpt"] = None
t5_quant_detected = detect_quant_scheme(t5_path_input)
is_t5_quant = t5_quant_detected in ["fp8", "int8"]
if is_t5_quant:
wan22_config["t5_quantized"] = True
wan22_config["t5_quant_scheme"] = get_quant_scheme(t5_quant_detected, quant_op)
wan22_config["t5_quantized_ckpt"] = os.path.join(model_path_input, t5_path_input)
else:
wan22_config["t5_quantized"] = False
wan22_config["t5_quant_scheme"] = "Default"
wan22_config["t5_original_ckpt"] = os.path.join(model_path_input, t5_path_input)
# 提取 VAE 路径,移除状态符号
vae_path_input = extract_model_name(vae_path_input) if vae_path_input else ""
wan22_config["vae_path"] = os.path.join(model_path_input, vae_path_input) if vae_path_input else None
wan22_config["model_path"] = model_path_input
return wan22_config
def build_qwen_image(
model_type_input,
model_path_input,
qwen_image_dit_path_input,
qwen_image_vae_path_input,
qwen_image_scheduler_path_input,
qwen25vl_encoder_path_input,
quant_op,
):
# 共同配置
qwen_image_config = {
"model_cls": "qwen_image",
"attention_head_dim": 128,
"axes_dims_rope": [16, 56, 56],
"guidance_embeds": False,
"in_channels": 64,
"joint_attention_dim": 3584,
"num_attention_heads": 24,
"num_layers": 60,
"out_channels": 16,
"patch_size": 2,
"attention_out_dim": 3072,
"attention_dim_head": 128,
"transformer_in_channels": 64,
}
# 根据模型类型设置不同配置
if model_type_input == "Qwen-Image-Edit-2511":
qwen_image_config.update(
{
"CONDITION_IMAGE_SIZE": 147456,
"USE_IMAGE_ID_IN_PROMPT": True,
"prompt_template_encode": "<|im_start|>system\nDescribe the key features of the input image (color, shape, size, texture, objects, background), then explain how the user's text instruction should alter or modify the image. Generate a new image that meets the user's requirements while maintaining consistency with the original input where appropriate.<|im_end|>\n<|im_start|>user\n{}<|im_end|>\n<|im_start|>assistant\n",
"prompt_template_encode_start_idx": 64,
"vae_scale_factor": 8,
"zero_cond_t": True,
}
)
elif model_type_input == "Qwen-Image-2512":
qwen_image_config.update(
{
"prompt_template_encode": "<|im_start|>system\nDescribe the image by detailing the color, shape, size, texture, quantity, text, spatial relationships of the objects and background:<|im_end|>\n<|im_start|>user\n{}<|im_end|>\n<|im_start|>assistant\n",
"prompt_template_encode_start_idx": 34,
"zero_cond_t": False,
}
)
else:
raise ValueError(f"Invalid model type: {model_type_input}")
"""构建 Qwen-Image 模型配置"""
qwen_image_dit_path_input = extract_model_name(qwen_image_dit_path_input) if qwen_image_dit_path_input else ""
qwen_image_vae_path_input = extract_model_name(qwen_image_vae_path_input) if qwen_image_vae_path_input else ""
qwen_image_scheduler_path_input = extract_model_name(qwen_image_scheduler_path_input) if qwen_image_scheduler_path_input else ""
qwen25vl_encoder_path_input = extract_model_name(qwen25vl_encoder_path_input) if qwen25vl_encoder_path_input else ""
# 确定子目录路径
if model_type_input == "Qwen-Image-Edit-2511":
model_subdir = "Qwen-Image-Edit-2511"
elif model_type_input == "Qwen-Image-2512":
model_subdir = "Qwen-Image-2512"
else:
model_subdir = ""
# 处理包含子目录路径的输入(去掉子目录前缀,只保留文件名)
if qwen_image_dit_path_input and "/" in qwen_image_dit_path_input:
qwen_image_dit_path_input = qwen_image_dit_path_input.split("/")[-1]
if qwen_image_vae_path_input and "/" in qwen_image_vae_path_input:
qwen_image_vae_path_input = qwen_image_vae_path_input.split("/")[-1]
if qwen_image_scheduler_path_input and "/" in qwen_image_scheduler_path_input:
qwen_image_scheduler_path_input = qwen_image_scheduler_path_input.split("/")[-1]
qwen_dit_quant_detected = detect_quant_scheme(qwen_image_dit_path_input)
is_qwen_dit_quant = qwen_dit_quant_detected in ["fp8", "int8"]
if is_qwen_dit_quant:
qwen_image_config["dit_quantized"] = True
qwen_image_config["dit_quant_scheme"] = get_quant_scheme(qwen_dit_quant_detected, quant_op)
# 使用子目录路径构建完整路径
dit_path = os.path.join(model_path_input, model_subdir, qwen_image_dit_path_input) if qwen_image_dit_path_input else None
qwen_image_config["dit_quantized_ckpt"] = dit_path
qwen_image_config["dit_original_ckpt"] = None
else:
qwen_image_config["dit_quantized"] = False
qwen_image_config["dit_quant_scheme"] = "Default"
# 使用子目录路径构建完整路径
dit_path = os.path.join(model_path_input, model_subdir, qwen_image_dit_path_input) if qwen_image_dit_path_input else None
qwen_image_config["dit_original_ckpt"] = dit_path
qwen_image_config["dit_quantized_ckpt"] = None
# VAE 和 Scheduler 路径也使用子目录
vae_path = os.path.join(model_path_input, model_subdir, "vae") if qwen_image_vae_path_input else None
scheduler_path = os.path.join(model_path_input, model_subdir, "scheduler") if qwen_image_scheduler_path_input else None
qwen_image_config["vae_path"] = vae_path
qwen_image_config["scheduler_path"] = scheduler_path
qwen_image_config["qwen25vl_quantized"] = True
qwen_image_config["qwen25vl_quant_scheme"] = "int4"
qwen_image_config["qwen25vl_quantized_ckpt"] = os.path.join(model_path_input, qwen25vl_encoder_path_input) if qwen25vl_encoder_path_input else None
qwen_image_config["qwen25vl_tokenizer_path"] = os.path.join(model_path_input, qwen25vl_encoder_path_input) if qwen25vl_encoder_path_input else None
qwen_image_config["qwen25vl_processor_path"] = os.path.join(model_path_input, qwen25vl_encoder_path_input) if qwen25vl_encoder_path_input else None
# 使用子目录路径作为 model_path
if model_type_input == "Qwen-Image-Edit-2511":
qwen_image_config["model_path"] = os.path.join(model_path_input, "Qwen-Image-Edit-2511")
elif model_type_input == "Qwen-Image-2512":
qwen_image_config["model_path"] = os.path.join(model_path_input, "Qwen-Image-2512")
qwen_image_config["vae_scale_factor"] = 8
return qwen_image_config
def build_z_image(
model_path_input,
z_image_dit_path_input,
z_image_vae_path_input,
z_image_scheduler_path_input,
qwen3_encoder_path_input,
quant_op,
):
"""构建 Z-Image-Turbo 模型配置"""
# 共同配置
z_image_config = {
"all_f_patch_size": [1],
"all_patch_size": [2],
"axes_dims": [32, 48, 48],
"axes_lens": [1536, 512, 512],
"cap_feat_dim": 2560,
"dim": 3840,
"in_channels": 16,
"n_heads": 30,
"n_kv_heads": 30,
"n_layers": 30,
"n_refiner_layers": 2,
"norm_eps": 1e-05,
"qk_norm": True,
"rope_theta": 256.0,
"t_scale": 1000.0,
}
z_image_dit_path_input = extract_model_name(z_image_dit_path_input) if z_image_dit_path_input else ""
z_image_vae_path_input = extract_model_name(z_image_vae_path_input) if z_image_vae_path_input else ""
z_image_scheduler_path_input = extract_model_name(z_image_scheduler_path_input) if z_image_scheduler_path_input else ""
qwen3_encoder_path_input = extract_model_name(qwen3_encoder_path_input) if qwen3_encoder_path_input else ""
model_subdir = "Z-Image-Turbo"
# 处理包含子目录路径的输入(去掉子目录前缀,只保留文件名)
if z_image_dit_path_input and "/" in z_image_dit_path_input:
z_image_dit_path_input = z_image_dit_path_input.split("/")[-1]
if z_image_vae_path_input and "/" in z_image_vae_path_input:
z_image_vae_path_input = z_image_vae_path_input.split("/")[-1]
if z_image_scheduler_path_input and "/" in z_image_scheduler_path_input:
z_image_scheduler_path_input = z_image_scheduler_path_input.split("/")[-1]
z_image_dit_quant_detected = detect_quant_scheme(z_image_dit_path_input)
is_z_image_dit_quant = z_image_dit_quant_detected in ["fp8", "int8"]
if is_z_image_dit_quant:
z_image_config["dit_quantized"] = True
z_image_config["dit_quant_scheme"] = get_quant_scheme(z_image_dit_quant_detected, quant_op)
# 使用子目录路径构建完整路径
dit_path = os.path.join(model_path_input, model_subdir, z_image_dit_path_input) if z_image_dit_path_input else None
z_image_config["dit_quantized_ckpt"] = dit_path
z_image_config["dit_original_ckpt"] = None
else:
z_image_config["dit_quantized"] = False
z_image_config["dit_quant_scheme"] = "Default"
# 使用子目录路径构建完整路径
dit_path = os.path.join(model_path_input, model_subdir, z_image_dit_path_input) if z_image_dit_path_input else None
z_image_config["dit_original_ckpt"] = dit_path
z_image_config["dit_quantized_ckpt"] = None
# VAE 和 Scheduler 路径也使用子目录
vae_path = os.path.join(model_path_input, model_subdir, "vae") if z_image_vae_path_input else None
scheduler_path = os.path.join(model_path_input, model_subdir, "scheduler") if z_image_scheduler_path_input else None
z_image_config["vae_path"] = vae_path
z_image_config["scheduler_path"] = scheduler_path
z_image_config["qwen3_quantized"] = True
z_image_config["qwen3_quant_scheme"] = "int4"
z_image_config["qwen3_quantized_ckpt"] = os.path.join(model_path_input, qwen3_encoder_path_input) if qwen3_encoder_path_input else None
z_image_config["qwen3_tokenizer_path"] = os.path.join(model_path_input, qwen3_encoder_path_input) if qwen3_encoder_path_input else None
z_image_config["qwen3_processor_path"] = os.path.join(model_path_input, qwen3_encoder_path_input) if qwen3_encoder_path_input else None
# 使用子目录路径作为 model_path
z_image_config["model_path"] = os.path.join(model_path_input, "Z-Image-Turbo")
z_image_config["model_cls"] = "z_image"
z_image_config["vae_scale_factor"] = 8
z_image_config["patch_size"] = 2
return z_image_config
def get_model_configs(
model_type_input,
model_path_input,
dit_path_input,
high_noise_path_input,
low_noise_path_input,
t5_path_input,
clip_path_input,
vae_path_input,
qwen_image_dit_path_input,
qwen_image_vae_path_input,
qwen_image_scheduler_path_input,
qwen25vl_encoder_path_input,
z_image_dit_path_input,
z_image_vae_path_input,
z_image_scheduler_path_input,
qwen3_encoder_path_input,
quant_op,
use_lora=None,
lora_path=None,
lora_strength=None,
high_noise_lora_path=None,
low_noise_lora_path=None,
high_noise_lora_strength=None,
low_noise_lora_strength=None,
):
if model_path_input and model_path_input.strip():
model_path_input = model_path_input.strip()
# 构建 LoRA 配置(如果提供)
lora_configs = None
if use_lora:
if model_type_input == "Wan2.2":
lora_configs = []
# 处理 high_noise LoRA
if high_noise_lora_path and high_noise_lora_path.strip():
high_noise_lora_full_path = os.path.join(model_path_input, "loras", high_noise_lora_path.strip())
high_noise_strength = float(high_noise_lora_strength) if high_noise_lora_strength is not None else 1.0
lora_configs.append(
{
"name": "high_noise_model",
"path": high_noise_lora_full_path,
"strength": high_noise_strength,
}
)
elif lora_path and lora_path.strip():
# 如果没有分别提供,使用统一的 lora_path
high_noise_lora_full_path = os.path.join(model_path_input, "loras", lora_path.strip())
high_noise_strength = float(lora_strength) if lora_strength is not None else 1.0
lora_configs.append(
{
"name": "high_noise_model",
"path": high_noise_lora_full_path,
"strength": high_noise_strength,
}
)
# 处理 low_noise LoRA
if low_noise_lora_path and low_noise_lora_path.strip():
low_noise_lora_full_path = os.path.join(model_path_input, "loras", low_noise_lora_path.strip())
low_noise_strength = float(low_noise_lora_strength) if low_noise_lora_strength is not None else 1.0
lora_configs.append(
{
"name": "low_noise_model",
"path": low_noise_lora_full_path,
"strength": low_noise_strength,
}
)
elif lora_path and lora_path.strip():
# 如果没有分别提供,使用统一的 lora_path
low_noise_lora_full_path = os.path.join(model_path_input, "loras", lora_path.strip())
low_noise_strength = float(lora_strength) if lora_strength is not None else 1.0
lora_configs.append(
{
"name": "low_noise_model",
"path": low_noise_lora_full_path,
"strength": low_noise_strength,
}
)
# 如果没有任何 LoRA 配置,设置为 None
if not lora_configs:
lora_configs = None
else:
# 其他模型类型(Wan2.1, Qwen-Image, Z-Image-Turbo)
if lora_path and lora_path.strip():
lora_full_path = os.path.join(model_path_input, "loras", lora_path.strip())
lora_configs = [
{
"path": lora_full_path,
"strength": float(lora_strength) if lora_strength is not None else 1.0,
}
]
if model_type_input == "Wan2.1":
config = build_wan21(
model_path_input,
dit_path_input,
t5_path_input,
clip_path_input,
vae_path_input,
quant_op,
)
if lora_configs:
config["lora_configs"] = lora_configs
return config
elif model_type_input == "Wan2.2":
config = build_wan22(
model_path_input,
high_noise_path_input,
low_noise_path_input,
t5_path_input,
vae_path_input,
quant_op,
)
if lora_configs:
config["lora_configs"] = lora_configs
return config
elif model_type_input in ["Qwen-Image-Edit-2511", "Qwen-Image-2512"]:
config = build_qwen_image(
model_type_input,
model_path_input,
qwen_image_dit_path_input,
qwen_image_vae_path_input,
qwen_image_scheduler_path_input,
qwen25vl_encoder_path_input,
quant_op,
)
if lora_configs:
config["lora_configs"] = lora_configs
return config
elif model_type_input == "Z-Image-Turbo":
config = build_z_image(
model_path_input,
z_image_dit_path_input,
z_image_vae_path_input,
z_image_scheduler_path_input,
qwen3_encoder_path_input,
quant_op,
)
if lora_configs:
config["lora_configs"] = lora_configs
return config
import random
from datetime import datetime
import gradio as gr
from loguru import logger
from utils.i18n import DEFAULT_LANG, t
from utils.image_page import build_image_page
from utils.model_utils import (
get_cpu_memory,
get_gpu_generation,
get_gpu_memory,
)
from utils.video_page import build_video_page
def get_gpu_rules(resolution):
"""根据分辨率获取 GPU 规则
Args:
resolution: 分辨率 ("480p", "540p", "720p")
Returns:
list: GPU 规则列表,每个元素为 (threshold, config_dict)
"""
if resolution in ["540p", "720p"]:
return [
(80, {}),
(40, {"cpu_offload_val": False, "t5_cpu_offload_val": True, "vae_cpu_offload_val": True, "clip_cpu_offload_val": True}),
(32, {"cpu_offload_val": True, "t5_cpu_offload_val": False, "vae_cpu_offload_val": False, "clip_cpu_offload_val": False}),
(
24,
{
"cpu_offload_val": True,
"use_tiling_vae_val": True,
"t5_cpu_offload_val": True,
"vae_cpu_offload_val": True,
"clip_cpu_offload_val": True,
},
),
(
16,
{
"cpu_offload_val": True,
"t5_cpu_offload_val": True,
"vae_cpu_offload_val": True,
"clip_cpu_offload_val": True,
"use_tiling_vae_val": True,
"offload_granularity_val": "phase",
"rope_chunk_val": True,
"rope_chunk_size_val": 100,
},
),
(
8,
{
"cpu_offload_val": True,
"t5_cpu_offload_val": True,
"vae_cpu_offload_val": True,
"clip_cpu_offload_val": True,
"use_tiling_vae_val": True,
"offload_granularity_val": "phase",
"rope_chunk_val": True,
"rope_chunk_size_val": 100,
"clean_cuda_cache_val": True,
},
),
(
-1,
{
"cpu_offload_val": True,
"t5_cpu_offload_val": True,
"vae_cpu_offload_val": True,
"clip_cpu_offload_val": True,
"use_tiling_vae_val": True,
"offload_granularity_val": "phase",
"rope_chunk_val": True,
"rope_chunk_size_val": 100,
"clean_cuda_cache_val": True,
},
),
]
else:
return [
(80, {}),
(40, {"cpu_offload_val": False, "t5_cpu_offload_val": True, "vae_cpu_offload_val": True, "clip_cpu_offload_val": True}),
(32, {"cpu_offload_val": True, "t5_cpu_offload_val": False, "vae_cpu_offload_val": False, "clip_cpu_offload_val": False}),
(
24,
{
"cpu_offload_val": True,
"t5_cpu_offload_val": True,
"vae_cpu_offload_val": True,
"clip_cpu_offload_val": True,
"use_tiling_vae_val": True,
},
),
(
16,
{
"cpu_offload_val": True,
"t5_cpu_offload_val": True,
"vae_cpu_offload_val": True,
"clip_cpu_offload_val": True,
"use_tiling_vae_val": True,
"offload_granularity_val": "phase",
},
),
(
8,
{
"cpu_offload_val": True,
"t5_cpu_offload_val": True,
"vae_cpu_offload_val": True,
"clip_cpu_offload_val": True,
"use_tiling_vae_val": True,
"offload_granularity_val": "phase",
},
),
(
-1,
{
"cpu_offload_val": True,
"t5_cpu_offload_val": True,
"vae_cpu_offload_val": True,
"clip_cpu_offload_val": True,
"use_tiling_vae_val": True,
"offload_granularity_val": "phase",
},
),
]
def get_cpu_rules():
"""获取 CPU 规则
Returns:
list: CPU 规则列表,每个元素为 (threshold, config_dict)
"""
return [
(128, {}),
(64, {}),
(32, {"unload_modules_val": True}),
(
16,
{
"lazy_load_val": True,
"unload_modules_val": True,
},
),
(
-1,
{
"t5_lazy_load": True,
"lazy_load_val": True,
"unload_modules_val": True,
},
),
]
css = """
.main-content { max-width: 1600px; margin: auto; padding: 20px; }
.warning { color: #ff6b6b; font-weight: bold; }
/* 模型状态样式 */
.model-status {
margin: 0 !important;
padding: 0 !important;
font-size: 12px !important;
line-height: 1.2 !important;
min-height: 20px !important;
}
/* 模型配置区域样式 */
.model-config {
margin-bottom: 20px !important;
border: 1px solid #e0e0e0;
border-radius: 12px;
padding: 15px;
background: linear-gradient(135deg, #f5f7fa 0%, #c3cfe2 100%);
}
/* 输入参数区域样式 */
.input-params {
margin-bottom: 20px !important;
border: 1px solid #e0e0e0;
border-radius: 12px;
padding: 15px;
background: linear-gradient(135deg, #fff5f5 0%, #ffeef0 100%);
}
/* 输出视频区域样式 */
.output-video {
border: 1px solid #e0e0e0;
border-radius: 12px;
padding: 20px;
background: linear-gradient(135deg, #e0f2fe 0%, #bae6fd 100%);
min-height: 400px;
}
/* 生成按钮样式 */
.generate-btn {
width: 100%;
margin-top: 20px;
padding: 15px 30px !important;
font-size: 18px !important;
font-weight: bold !important;
background: linear-gradient(135deg, #667eea 0%, #764ba2 100%) !important;
border: none !important;
border-radius: 10px !important;
box-shadow: 0 4px 15px rgba(102, 126, 234, 0.4) !important;
transition: all 0.3s ease !important;
}
.generate-btn:hover {
transform: translateY(-2px);
box-shadow: 0 6px 20px rgba(102, 126, 234, 0.6) !important;
}
/* Accordion 标题样式 */
.model-config .gr-accordion-header,
.input-params .gr-accordion-header,
.output-video .gr-accordion-header {
font-size: 20px !important;
font-weight: bold !important;
padding: 15px !important;
}
/* 优化间距 */
.gr-row {
margin-bottom: 15px;
}
/* 视频播放器样式 */
.output-video video {
border-radius: 10px;
box-shadow: 0 4px 15px rgba(0, 0, 0, 0.1);
}
/* Diffusion模型容器 */
.diffusion-model-group {
margin-bottom: 20px !important;
border: none !important;
outline: none !important;
box-shadow: none !important;
}
/* 移除 Gradio 组件的默认边框 */
.diffusion-model-group > div,
.diffusion-model-group .gr-group,
.diffusion-model-group .gr-box {
border: none !important;
outline: none !important;
box-shadow: none !important;
}
/* 编码器组容器(文本编码器、图像编码器) */
.encoder-group {
margin-bottom: 20px !important;
}
/* VAE组容器 */
.vae-group {
margin-bottom: 20px !important;
}
/* 模型组标题样式 */
.model-group-title {
font-size: 16px !important;
font-weight: 600 !important;
margin-bottom: 12px !important;
color: #24292f !important;
}
/* 下载按钮样式 */
.download-btn {
width: 100% !important;
margin-top: 8px !important;
border-radius: 6px !important;
transition: all 0.2s ease !important;
}
.download-btn:hover {
transform: translateY(-1px) !important;
box-shadow: 0 2px 4px rgba(0, 0, 0, 0.1) !important;
}
/* 水平排列的Radio按钮 */
.horizontal-radio .form-radio {
display: flex !important;
flex-direction: row !important;
gap: 20px !important;
}
.horizontal-radio .form-radio > label {
margin-right: 20px !important;
}
/* wan2.2 行样式 - 去掉上边框和分隔线 */
.wan22-row {
border-top: none !important;
border-bottom: none !important;
margin-top: 0 !important;
padding-top: 0 !important;
}
.wan22-row > div {
border-top: none !important;
border-bottom: none !important;
}
.wan22-row .gr-column {
border-top: none !important;
border-bottom: none !important;
border-left: none !important;
border-right: none !important;
}
.wan22-row .gr-column:first-child {
border-right: none !important;
}
.wan22-row .gr-column:last-child {
border-left: none !important;
}
"""
MAX_NUMPY_SEED = 2**32 - 1
def generate_random_seed():
"""生成随机种子"""
return random.randint(0, MAX_NUMPY_SEED)
def generate_unique_filename(output_dir, is_image=False):
"""生成唯一文件名"""
import os
os.makedirs(output_dir, exist_ok=True)
timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
ext = ".png" if is_image else ".mp4"
return os.path.join(output_dir, f"{timestamp}{ext}")
def get_auto_config_dict(
model_type,
resolution,
num_frames=81,
task_type=None,
):
if task_type in ["i2i", "t2i"]:
attention_type_val = "torch_sdpa"
else:
attention_type_val = "sage_attn2"
gpu_gen = get_gpu_generation()
if gpu_gen == "40":
quant_op_val = "q8f"
elif gpu_gen == "30":
quant_op_val = "vllm"
else:
quant_op_val = "triton"
default_config = {
"lazy_load_val": False,
"rope_chunk_val": False,
"rope_chunk_size_val": 100,
"clean_cuda_cache_val": False,
"cpu_offload_val": False,
"offload_granularity_val": "block",
"t5_cpu_offload_val": False,
"clip_cpu_offload_val": False,
"vae_cpu_offload_val": False,
"unload_modules_val": False,
"attention_type_val": attention_type_val,
"quant_op_val": quant_op_val,
"use_tiling_vae_val": False,
}
# If num_frames > 81, set rope_chunk to True regardless of resolution
if num_frames is not None and num_frames > 81:
default_config["rope_chunk_val"] = True
gpu_memory = round(get_gpu_memory())
cpu_memory = round(get_cpu_memory())
gpu_rules = get_gpu_rules(resolution)
cpu_rules = get_cpu_rules()
for threshold, updates in gpu_rules:
if gpu_memory >= threshold:
default_config.update(updates)
break
for threshold, updates in cpu_rules:
if cpu_memory >= threshold:
default_config.update(updates)
break
if model_type == "Z-Image-Turbo":
default_config["lazy_load_val"] = False
if default_config["cpu_offload_val"]:
default_config["offload_granularity_val"] = "model"
default_config["quant_op_val"] = "triton"
return default_config
def build_ui(
model_path,
output_dir,
run_inference,
lang=DEFAULT_LANG,
):
# 在启动时加载 Hugging Face 模型列表缓存
logger.info(t("loading_models", lang))
from utils.model_utils import load_hf_models_cache
load_hf_models_cache()
logger.info(t("models_loaded", lang))
with gr.Blocks(title="Lightx2v (轻量级视频推理和生成引擎)") as demo:
gr.Markdown(f"# 🎬 LightX2V 图片/视频生成器")
gr.HTML(f"<style>{css}</style>")
# 使用 Tabs 分成两个页面
with gr.Tabs():
# 图片生成页面
with gr.Tab("🖼️ 图片生成"):
build_image_page(
model_path,
output_dir,
run_inference,
lang,
)
# 视频生成页面
with gr.Tab("🎬 视频生成"):
build_video_page(
model_path,
output_dir,
run_inference,
lang,
)
# 返回 demo 对象,由调用者负责 launch
return demo
import random
import gradio as gr
from utils.i18n import DEFAULT_LANG, t
from utils.model_choices import (
get_clip_model_choices,
get_dit_choices,
get_high_noise_choices,
get_low_noise_choices,
get_t5_model_choices,
get_vae_choices,
)
from utils.model_components import build_wan21_components, build_wan22_components
from utils.model_handlers import TOKENIZER_CONFIG, create_download_wrappers, create_update_status_wrappers
from utils.model_utils import HF_AVAILABLE, MS_AVAILABLE, check_model_exists, extract_model_name, is_fp8_supported_gpu
MAX_NUMPY_SEED = 2**32 - 1
def generate_random_seed():
"""生成随机种子"""
return random.randint(0, MAX_NUMPY_SEED)
def generate_unique_filename(output_dir, is_image=False):
"""生成唯一文件名"""
import os
from datetime import datetime
timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
ext = ".png" if is_image else ".mp4"
filename = f"{timestamp}{ext}"
return os.path.join(output_dir, filename)
def build_video_page(
model_path,
output_dir,
run_inference,
lang=DEFAULT_LANG,
):
"""构建视频生成页面 - 重构简化版"""
# 创建 update 和 download 包装函数
update_funcs = create_update_status_wrappers()
get_choices_funcs = {
"get_dit_choices": get_dit_choices,
"get_high_noise_choices": get_high_noise_choices,
"get_low_noise_choices": get_low_noise_choices,
"get_t5_model_choices": get_t5_model_choices,
"get_clip_model_choices": get_clip_model_choices,
"get_vae_choices": get_vae_choices,
}
download_funcs = create_download_wrappers(get_choices_funcs)
# 判断模型是否是 distill 版本
def is_distill_model(model_type, dit_path, high_noise_path):
if model_type == "Wan2.1":
check_name = dit_path.lower() if dit_path else ""
else:
check_name = high_noise_path.lower() if high_noise_path else ""
return "4step" in check_name
# 主布局:左右分栏
with gr.Row():
# 左侧:配置和输入区域
with gr.Column(scale=5):
# 模型配置区域
with gr.Accordion(t("model_config", lang), open=True, elem_classes=["model-config"]):
gr.Markdown(t("model_config_hint", lang))
# FP8 支持提示
if not is_fp8_supported_gpu():
gr.Markdown(t("fp8_not_supported", lang))
model_path_input = gr.Textbox(value=model_path, visible=False)
# 模型类型 + 任务类型 + 下载源
with gr.Row():
model_type_input = gr.Radio(
label=t("model_type", lang),
choices=["Wan2.1", "Wan2.2"],
value="Wan2.1",
info=t("model_type_info", lang),
elem_classes=["horizontal-radio"],
)
task_type_input = gr.Radio(
label=t("task_type", lang),
choices=[("I2V", "i2v"), ("T2V", "t2v")],
value="i2v",
info=t("task_type_info", lang),
elem_classes=["horizontal-radio"],
)
download_source_input = gr.Radio(
label=t("download_source", lang),
choices=(["huggingface", "modelscope"] if (HF_AVAILABLE and MS_AVAILABLE) else (["huggingface"] if HF_AVAILABLE else ["modelscope"] if MS_AVAILABLE else [])),
value=("modelscope" if MS_AVAILABLE else ("huggingface" if HF_AVAILABLE else None)),
info=t("download_source_info", lang),
visible=HF_AVAILABLE or MS_AVAILABLE,
elem_classes=["horizontal-radio"],
)
# wan2.1 和 wan2.2 组件
wan21_components = build_wan21_components(model_path, model_path_input, model_type_input, task_type_input, download_source_input, update_funcs, download_funcs, lang)
wan21_row = wan21_components["wan21_row"]
dit_path_input = wan21_components["dit_path_input"]
dit_download_btn = wan21_components["dit_download_btn"]
wan21_use_lora_input = wan21_components["use_lora"]
wan21_lora_path_input = wan21_components["lora_path_input"]
wan21_lora_strength_input = wan21_components["lora_strength"]
wan22_components = build_wan22_components(model_path, model_path_input, task_type_input, download_source_input, update_funcs, download_funcs, lang)
wan22_row = wan22_components["wan22_row"]
high_noise_path_input = wan22_components["high_noise_path_input"]
low_noise_path_input = wan22_components["low_noise_path_input"]
high_noise_download_btn = wan22_components["high_noise_download_btn"]
low_noise_download_btn = wan22_components["low_noise_download_btn"]
wan22_use_lora_input = wan22_components["use_lora"]
wan22_high_noise_lora_path_input = wan22_components["high_noise_lora_path_input"]
wan22_high_noise_lora_strength_input = wan22_components["high_noise_lora_strength"]
wan22_low_noise_lora_path_input = wan22_components["low_noise_lora_path_input"]
wan22_low_noise_lora_strength_input = wan22_components["low_noise_lora_strength"]
# 文本编码器(模型 + Tokenizer)
with gr.Row() as t5_row:
with gr.Column(scale=1):
t5_model_choices_init = get_t5_model_choices(model_path)
t5_path_input = gr.Dropdown(
label=t("text_encoder", lang),
choices=t5_model_choices_init,
value=t5_model_choices_init[0] if t5_model_choices_init else "",
allow_custom_value=True,
)
# 初始化时检查 T5 模型状态
t5_btn_visible = False
if t5_model_choices_init:
first_choice = t5_model_choices_init[0]
actual_name = extract_model_name(first_choice)
t5_exists = check_model_exists(model_path, actual_name)
t5_btn_visible = not t5_exists
t5_download_btn = gr.Button(t("download", lang), visible=t5_btn_visible, size="sm", variant="secondary")
t5_download_status = gr.Markdown("", visible=False)
with gr.Column(scale=1):
# 初始化时检查 T5 Tokenizer 状态
t5_tokenizer_config = TOKENIZER_CONFIG.get("t5")
t5_tokenizer_name = t5_tokenizer_config["name"] if t5_tokenizer_config else "google"
t5_tokenizer_exists = check_model_exists(model_path, t5_tokenizer_name) if t5_tokenizer_config else False
t5_tokenizer_status_text = f"{t5_tokenizer_name} ✅" if t5_tokenizer_exists else f"{t5_tokenizer_name} ❌"
t5_tokenizer_btn_visible = not t5_tokenizer_exists
t5_tokenizer_hint = gr.Dropdown(
label=t("text_encoder_tokenizer", lang),
choices=["google ✅", "google ❌"],
value=t5_tokenizer_status_text,
interactive=False,
)
t5_tokenizer_download_btn = gr.Button(t("download", lang), visible=t5_tokenizer_btn_visible, size="sm", variant="secondary")
t5_tokenizer_download_status = gr.Markdown("", visible=False)
# 图像编码器(模型 + Tokenizer,条件显示)
with gr.Row(visible=True) as clip_row:
with gr.Column(scale=1):
clip_model_choices_init = get_clip_model_choices(model_path)
clip_path_input = gr.Dropdown(
label=t("image_encoder", lang),
choices=clip_model_choices_init,
value=clip_model_choices_init[0] if clip_model_choices_init else "",
allow_custom_value=True,
)
# 初始化时检查 CLIP 模型状态
clip_btn_visible = False
if clip_model_choices_init:
first_choice = clip_model_choices_init[0]
actual_name = extract_model_name(first_choice)
clip_exists = check_model_exists(model_path, actual_name)
clip_btn_visible = not clip_exists
clip_download_btn = gr.Button(t("download", lang), visible=clip_btn_visible, size="sm", variant="secondary")
clip_download_status = gr.Markdown("", visible=False)
with gr.Column(scale=1):
# 初始化时检查 CLIP Tokenizer 状态
clip_tokenizer_config = TOKENIZER_CONFIG.get("clip")
clip_tokenizer_name = clip_tokenizer_config["name"] if clip_tokenizer_config else "xlm-roberta-large"
clip_tokenizer_exists = check_model_exists(model_path, clip_tokenizer_name) if clip_tokenizer_config else False
clip_tokenizer_status_text = f"{clip_tokenizer_name} ✅" if clip_tokenizer_exists else f"{clip_tokenizer_name} ❌"
clip_tokenizer_btn_visible = not clip_tokenizer_exists
clip_tokenizer_hint = gr.Dropdown(
label=t("image_encoder_tokenizer", lang),
choices=["xlm-roberta-large ✅", "xlm-roberta-large ❌"],
value=clip_tokenizer_status_text,
interactive=False,
)
clip_tokenizer_download_btn = gr.Button(t("download", lang), visible=clip_tokenizer_btn_visible, size="sm", variant="secondary")
clip_tokenizer_download_status = gr.Markdown("", visible=False)
# VAE
with gr.Row() as vae_row:
vae_choices_init = get_vae_choices(model_path)
vae_path_input = gr.Dropdown(
label=t("vae", lang),
choices=vae_choices_init,
value=(vae_choices_init[0] if vae_choices_init else ""),
allow_custom_value=True,
interactive=True,
)
# 初始化时检查 VAE 状态
vae_btn_visible = False
if vae_choices_init:
first_choice = vae_choices_init[0]
actual_name = extract_model_name(first_choice)
vae_exists = check_model_exists(model_path, actual_name)
vae_btn_visible = not vae_exists
vae_download_btn = gr.Button(t("download", lang), visible=vae_btn_visible, size="sm", variant="secondary")
vae_download_status = gr.Markdown("", visible=False)
# 使用预创建的包装函数
update_dit_status = update_funcs["update_dit_status"]
update_high_noise_status = update_funcs["update_high_noise_status"]
update_low_noise_status = update_funcs["update_low_noise_status"]
update_t5_model_status = update_funcs["update_t5_model_status"]
update_t5_tokenizer_status = update_funcs["update_t5_tokenizer_status"]
update_clip_model_status = update_funcs["update_clip_model_status"]
update_clip_tokenizer_status = update_funcs["update_clip_tokenizer_status"]
update_vae_status = update_funcs["update_vae_status"]
download_t5_model = download_funcs["download_t5_model"]
download_t5_tokenizer = download_funcs["download_t5_tokenizer"]
download_clip_model = download_funcs["download_clip_model"]
download_clip_tokenizer = download_funcs["download_clip_tokenizer"]
download_vae = download_funcs["download_vae"]
# T5 模型路径变化时,只更新 T5 模型状态(不更新 Tokenizer)
t5_path_input.change(
fn=update_t5_model_status,
inputs=[model_path_input, t5_path_input],
outputs=[t5_download_btn],
)
# CLIP 模型路径变化时,只更新 CLIP 模型状态(不更新 Tokenizer)
clip_path_input.change(
fn=update_clip_model_status,
inputs=[model_path_input, clip_path_input],
outputs=[clip_download_btn],
)
# Tokenizer 独立检测(当 model_path 变化时,同时更新所有 Tokenizer)
def update_all_tokenizers(model_path_val):
"""更新所有 Tokenizer 状态"""
t5_tokenizer_result = update_t5_tokenizer_status(model_path_val)
clip_tokenizer_result = update_clip_tokenizer_status(model_path_val)
return (
t5_tokenizer_result[0], # t5_tokenizer_hint
t5_tokenizer_result[1], # t5_tokenizer_download_btn
clip_tokenizer_result[0], # clip_tokenizer_hint
clip_tokenizer_result[1], # clip_tokenizer_download_btn
)
# 绑定 model_path 变化事件,更新所有 Tokenizer 状态
model_path_input.change(
fn=update_all_tokenizers,
inputs=[model_path_input],
outputs=[
t5_tokenizer_hint,
t5_tokenizer_download_btn,
clip_tokenizer_hint,
clip_tokenizer_download_btn,
],
)
vae_path_input.change(
fn=update_vae_status,
inputs=[model_path_input, vae_path_input],
outputs=[vae_download_btn],
)
# 绑定下载按钮事件
t5_download_btn.click(
fn=download_t5_model,
inputs=[model_path_input, t5_path_input, download_source_input],
outputs=[t5_download_status, t5_download_btn, t5_path_input],
)
t5_tokenizer_download_btn.click(
fn=download_t5_tokenizer,
inputs=[model_path_input, download_source_input],
outputs=[
t5_tokenizer_download_status,
t5_tokenizer_hint,
t5_tokenizer_download_btn,
],
)
clip_download_btn.click(
fn=download_clip_model,
inputs=[model_path_input, clip_path_input, download_source_input],
outputs=[clip_download_status, clip_download_btn, clip_path_input],
)
clip_tokenizer_download_btn.click(
fn=download_clip_tokenizer,
inputs=[model_path_input, download_source_input],
outputs=[
clip_tokenizer_download_status,
clip_tokenizer_hint,
clip_tokenizer_download_btn,
],
)
vae_download_btn.click(
fn=download_vae,
inputs=[
model_path_input,
vae_path_input,
download_source_input,
],
outputs=[
vae_download_status,
vae_download_btn,
vae_path_input,
],
)
# 统一的模型组件更新函数
def update_model_components(model_type, task_type, model_path_val):
"""统一更新所有模型组件"""
show_clip = model_type == "Wan2.1" and task_type == "i2v"
show_image_input = task_type == "i2v"
is_wan21 = model_type == "Wan2.1"
# 获取模型选项
t5_choices = get_t5_model_choices(model_path_val)
vae_choices = get_vae_choices(model_path_val)
clip_choices = get_clip_model_choices(model_path_val) if show_clip else []
# 更新 Tokenizer 状态
t5_tokenizer_result = update_t5_tokenizer_status(model_path_val)
clip_tokenizer_result = update_clip_tokenizer_status(model_path_val)
# 更新模型下载按钮状态
t5_btn_update = update_t5_model_status(model_path_val, t5_choices[0] if t5_choices else "")
clip_btn_update = update_clip_model_status(model_path_val, clip_choices[0] if clip_choices else "") if show_clip else gr.update()
vae_btn_update = update_vae_status(model_path_val, vae_choices[0] if vae_choices else "")
if is_wan21:
dit_choices = get_dit_choices(model_path_val, "wan2.1", task_type)
# 更新 DIT 下载按钮状态
from utils.model_utils import extract_model_name
dit_btn_update = update_dit_status(model_path_val, extract_model_name(dit_choices[0]) if dit_choices else "", "wan2.1") if dit_choices else gr.update(visible=False)
return (
gr.update(visible=True), # wan21_row
gr.update(visible=False), # wan22_row
gr.update(choices=dit_choices, value=dit_choices[0] if dit_choices else "", visible=True), # dit_path_input
gr.update(), # high_noise_path_input
gr.update(), # low_noise_path_input
gr.update(visible=show_clip), # clip_row
gr.update(visible=True), # vae_row
gr.update(visible=True), # t5_row
gr.update(choices=t5_choices, value=t5_choices[0] if t5_choices else ""), # t5_path_input
gr.update(choices=clip_choices, value=clip_choices[0] if clip_choices else ""), # clip_path_input
gr.update(choices=vae_choices, value=vae_choices[0] if vae_choices else ""), # vae_path_input
gr.update(visible=show_image_input), # image_input_row
gr.update(label=t("output_video_path", lang)), # save_result_path
t5_tokenizer_result[0], # t5_tokenizer_hint
t5_tokenizer_result[1], # t5_tokenizer_download_btn
clip_tokenizer_result[0], # clip_tokenizer_hint
clip_tokenizer_result[1], # clip_tokenizer_download_btn
t5_btn_update, # t5_download_btn
clip_btn_update, # clip_download_btn
vae_btn_update, # vae_download_btn
dit_btn_update, # dit_download_btn
gr.update(), # high_noise_download_btn
gr.update(), # low_noise_download_btn
)
else: # wan2.2
high_noise_choices = get_high_noise_choices(model_path_val, "wan2.2", task_type)
low_noise_choices = get_low_noise_choices(model_path_val, "wan2.2", task_type)
# 更新 high_noise 和 low_noise 下载按钮状态
from utils.model_utils import extract_model_name
high_noise_btn_update = (
update_high_noise_status(model_path_val, extract_model_name(high_noise_choices[0]) if high_noise_choices else "") if high_noise_choices else gr.update(visible=False)
)
low_noise_btn_update = (
update_low_noise_status(model_path_val, extract_model_name(low_noise_choices[0]) if low_noise_choices else "") if low_noise_choices else gr.update(visible=False)
)
return (
gr.update(visible=False), # wan21_row
gr.update(visible=True), # wan22_row
gr.update(visible=False), # dit_path_input
gr.update(choices=high_noise_choices, value=high_noise_choices[0] if high_noise_choices else ""), # high_noise_path_input
gr.update(choices=low_noise_choices, value=low_noise_choices[0] if low_noise_choices else ""), # low_noise_path_input
gr.update(visible=show_clip), # clip_row
gr.update(visible=True), # vae_row
gr.update(visible=True), # t5_row
gr.update(choices=t5_choices, value=t5_choices[0] if t5_choices else ""), # t5_path_input
gr.update(choices=clip_choices, value=clip_choices[0] if clip_choices else ""), # clip_path_input
gr.update(choices=vae_choices, value=vae_choices[0] if vae_choices else ""), # vae_path_input
gr.update(visible=show_image_input), # image_input_row
gr.update(label=t("output_video_path", lang)), # save_result_path
t5_tokenizer_result[0], # t5_tokenizer_hint
t5_tokenizer_result[1], # t5_tokenizer_download_btn
clip_tokenizer_result[0], # clip_tokenizer_hint
clip_tokenizer_result[1], # clip_tokenizer_download_btn
t5_btn_update, # t5_download_btn
clip_btn_update, # clip_download_btn
vae_btn_update, # vae_download_btn
gr.update(), # dit_download_btn
high_noise_btn_update, # high_noise_download_btn
low_noise_btn_update, # low_noise_download_btn
)
# 输入参数区域
with gr.Accordion(t("input_params", lang), open=True, elem_classes=["input-params"]):
# 图片输入(i2v 时显示)
with gr.Column(visible=True) as image_input_row:
image_files = gr.File(
label=t("input_images", lang),
file_count="multiple",
file_types=["image"],
height=150,
interactive=True,
)
# 图片预览 Gallery
image_gallery = gr.Gallery(
label=t("uploaded_images_preview", lang),
columns=4,
rows=2,
height=200,
object_fit="contain",
show_label=True,
)
# 将多个文件路径转换为逗号分隔的字符串
image_path = gr.Textbox(
label=t("image_path", lang),
visible=False,
)
def update_image_path_and_gallery(files):
if files is None or len(files) == 0:
return "", []
# 提取文件路径
paths = [f.name if hasattr(f, "name") else f for f in files]
# 返回逗号分隔的路径和图片列表用于 Gallery 显示
return ",".join(paths), paths
image_files.change(
fn=update_image_path_and_gallery,
inputs=[image_files],
outputs=[image_path, image_gallery],
)
with gr.Row():
with gr.Column():
prompt = gr.Textbox(
label=t("prompt", lang),
lines=3,
placeholder=t("prompt_placeholder", lang),
max_lines=5,
)
with gr.Column():
negative_prompt = gr.Textbox(
label=t("negative_prompt", lang),
lines=3,
placeholder=t("negative_prompt_placeholder", lang),
max_lines=5,
value="镜头晃动,色调艳丽,过曝,静态,细节模糊不清,字幕,风格,作品,画作,画面,静止,整体发灰,最差质量,低质量,JPEG压缩残留,丑陋的,残缺的,多余的手指,画得不好的手部,画得不好的脸部,畸形的,毁容的,形态畸形的肢体,手指融合,静止不动的画面,杂乱的背景,三条腿,背景人很多,倒着走",
)
with gr.Column(visible=True):
resolution = gr.Dropdown(
choices=["480p", "540p", "720p"],
value="480p",
label=t("max_resolution", lang),
info=t("max_resolution_info", lang),
)
with gr.Column(scale=9):
seed = gr.Slider(
label=t("random_seed", lang),
minimum=0,
maximum=MAX_NUMPY_SEED,
step=1,
value=generate_random_seed(),
)
with gr.Column():
default_dit = get_dit_choices(model_path, "wan2.1", "i2v")[0] if get_dit_choices(model_path, "wan2.1", "i2v") else ""
default_high_noise = get_high_noise_choices(model_path, "wan2.2", "i2v")[0] if get_high_noise_choices(model_path, "wan2.2", "i2v") else ""
default_is_distill = is_distill_model("wan2.1", default_dit, default_high_noise)
if default_is_distill:
infer_steps = gr.Slider(
label=t("infer_steps", lang),
minimum=1,
maximum=100,
step=1,
value=4,
info=t("infer_steps_distill", lang),
)
else:
infer_steps = gr.Slider(
label=t("infer_steps", lang),
minimum=1,
maximum=100,
step=1,
value=40,
)
# 根据模型类别设置默认CFG
default_cfg_scale = 1 if default_is_distill else 5
with gr.Row():
sample_shift = gr.Slider(
label=t("sample_shift", lang),
value=5,
minimum=0,
maximum=10,
step=1,
info=t("sample_shift_info", lang),
)
cfg_scale = gr.Slider(
label=t("cfg_scale", lang),
minimum=1,
maximum=10,
step=1,
value=default_cfg_scale,
info=t("cfg_scale_info", lang),
)
# 统一的模型参数更新函数(推理步数和CFG)
def update_model_params(model_type, dit_path, high_noise_path):
is_distill = is_distill_model(model_type, dit_path, high_noise_path)
return (
gr.update(value=4 if is_distill else 40), # infer_steps
gr.update(value=1 if is_distill else 5), # cfg_scale
)
# 监听模型路径和类型变化
inputs = [model_type_input, dit_path_input, high_noise_path_input]
for trigger in [dit_path_input, high_noise_path_input, model_type_input]:
trigger.change(
fn=update_model_params,
inputs=inputs,
outputs=[infer_steps, cfg_scale],
)
with gr.Row(visible=True):
fps = gr.Slider(
label=t("fps", lang),
minimum=8,
maximum=30,
step=1,
value=16,
info=t("fps_info", lang),
)
video_duration = gr.Slider(
label=t("video_duration", lang),
minimum=1.0,
maximum=10.0,
step=0.1,
value=5.0,
info=t("video_duration_info", lang),
)
save_result_path = gr.Textbox(
label=t("output_video_path", lang),
value=generate_unique_filename(output_dir),
info=t("output_video_path_info", lang),
visible=False, # 隐藏输出路径,自动生成
)
# 模型类型和任务类型切换都使用同一个函数(在 image_input_row 和 save_result_path 定义后绑定)
model_type_input.change(
fn=update_model_components,
inputs=[model_type_input, task_type_input, model_path_input],
outputs=[
wan21_row,
wan22_row,
dit_path_input,
high_noise_path_input,
low_noise_path_input,
clip_row,
vae_row,
t5_row,
t5_path_input,
clip_path_input,
vae_path_input,
image_input_row,
save_result_path,
t5_tokenizer_hint,
t5_tokenizer_download_btn,
clip_tokenizer_hint,
clip_tokenizer_download_btn,
t5_download_btn,
clip_download_btn,
vae_download_btn,
dit_download_btn,
high_noise_download_btn,
low_noise_download_btn,
],
)
task_type_input.change(
fn=update_model_components,
inputs=[model_type_input, task_type_input, model_path_input],
outputs=[
wan21_row,
wan22_row,
dit_path_input,
high_noise_path_input,
low_noise_path_input,
clip_row,
vae_row,
t5_row,
t5_path_input,
clip_path_input,
vae_path_input,
image_input_row,
save_result_path,
t5_tokenizer_hint,
t5_tokenizer_download_btn,
clip_tokenizer_hint,
clip_tokenizer_download_btn,
t5_download_btn,
clip_download_btn,
vae_download_btn,
dit_download_btn,
high_noise_download_btn,
low_noise_download_btn,
],
)
# 右侧:输出区域
with gr.Column(scale=4):
with gr.Accordion(t("output_result", lang), open=True, elem_classes=["output-video"]):
output_video = gr.Video(
label="",
height=600,
autoplay=True,
show_label=False,
visible=True,
)
infer_btn = gr.Button(
t("generate_video", lang),
variant="primary",
size="lg",
elem_classes=["generate-btn"],
)
# 包装推理函数,视频页面只返回视频,只传递必要的参数
def run_inference_wrapper(
prompt_val,
negative_prompt_val,
save_result_path_val,
infer_steps_val,
video_duration_val,
resolution_val,
seed_val,
sample_shift_val,
cfg_scale_val,
fps_val,
model_path_val,
model_type_val,
task_type_val,
dit_path_val,
high_noise_path_val,
low_noise_path_val,
t5_path_val,
clip_path_val,
vae_path_val,
image_path_val,
wan21_use_lora_val,
wan21_lora_path_val,
wan21_lora_strength_val,
wan22_use_lora_val,
wan22_high_noise_lora_path_val,
wan22_high_noise_lora_strength_val,
wan22_low_noise_lora_path_val,
wan22_low_noise_lora_strength_val,
):
# 视频页面不需要 qwen_image 相关参数,使用 None
# offload 和 rope 相关参数已在 run_inference 中通过 get_auto_config_dict 自动决定
# 将视频时长(秒)转换为帧数
num_frames_val = int(video_duration_val * fps_val) + 1
# 提取所有路径参数中的模型名,移除状态符号(✅/❌)
from utils.model_utils import extract_model_name
dit_path_val = extract_model_name(dit_path_val) if dit_path_val else None
high_noise_path_val = extract_model_name(high_noise_path_val) if high_noise_path_val else None
low_noise_path_val = extract_model_name(low_noise_path_val) if low_noise_path_val else None
t5_path_val = extract_model_name(t5_path_val) if t5_path_val else None
clip_path_val = extract_model_name(clip_path_val) if clip_path_val else ""
vae_path_val = extract_model_name(vae_path_val) if vae_path_val else None
if model_type_val == "Wan2.1":
use_lora = wan21_use_lora_val
lora_path = wan21_lora_path_val
lora_strength = wan21_lora_strength_val
low_noise_lora_path = None
low_noise_lora_strength = None
high_noise_lora_path = None
high_noise_lora_strength = None
else:
use_lora = wan22_use_lora_val
high_noise_lora_path = wan22_high_noise_lora_path_val
high_noise_lora_strength = wan22_high_noise_lora_strength_val
low_noise_lora_path = wan22_low_noise_lora_path_val
low_noise_lora_strength = wan22_low_noise_lora_strength_val
lora_path = None
lora_strength = None
result = run_inference(
prompt=prompt_val,
negative_prompt=negative_prompt_val,
save_result_path=save_result_path_val,
infer_steps=infer_steps_val,
num_frames=num_frames_val,
resolution=resolution_val,
seed=seed_val,
sample_shift=sample_shift_val,
cfg_scale=cfg_scale_val, # enable_cfg 会根据 cfg_scale 在 run_inference 内部自动设置
fps=fps_val,
model_path_input=model_path_val,
model_type_input=model_type_val,
task_type_input=task_type_val,
dit_path_input=dit_path_val,
high_noise_path_input=high_noise_path_val,
low_noise_path_input=low_noise_path_val,
t5_path_input=t5_path_val,
clip_path_input=clip_path_val,
vae_path=vae_path_val,
image_path=image_path_val,
use_lora=use_lora,
lora_path=lora_path,
lora_strength=lora_strength,
high_noise_lora_path=high_noise_lora_path,
high_noise_lora_strength=high_noise_lora_strength,
low_noise_lora_path=low_noise_lora_path,
low_noise_lora_strength=low_noise_lora_strength,
)
return gr.update(value=result)
infer_btn.click(
fn=run_inference_wrapper,
inputs=[
prompt,
negative_prompt,
save_result_path,
infer_steps,
video_duration,
resolution,
seed,
sample_shift,
cfg_scale,
fps,
model_path_input,
model_type_input,
task_type_input,
dit_path_input,
high_noise_path_input,
low_noise_path_input,
t5_path_input,
clip_path_input,
vae_path_input,
image_path,
wan21_use_lora_input,
wan21_lora_path_input,
wan21_lora_strength_input,
wan22_use_lora_input,
wan22_high_noise_lora_path_input,
wan22_high_noise_lora_strength_input,
wan22_low_noise_lora_path_input,
wan22_low_noise_lora_strength_input,
],
outputs=[output_video],
)
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment